1use core::fmt;
16use std::{borrow::Borrow, cmp::min, iter, ops::Deref};
17
18use async_trait::async_trait;
19use deadpool_sqlite::Object as SqliteAsyncConn;
20use itertools::Itertools;
21use matrix_sdk_store_encryption::StoreCipher;
22use ruma::time::SystemTime;
23use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
24use serde::{de::DeserializeOwned, Serialize};
25
26use crate::{
27 error::{Error, Result},
28 OpenStoreError, RuntimeConfig,
29};
30
/// A key used in SQL queries: either arbitrary raw bytes, or a fixed-size
/// 32-byte value.
///
/// Both variants dereference to a `[u8]` slice, so a `Key` can be compared
/// against or looked up by plain bytes.
// NOTE: the derived `PartialOrd`/`Ord` compare the variant first (`Plain`
// sorts before `Hashed`), so the order of the variants below is significant.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Key {
    /// The key bytes, used as-is.
    Plain(Vec<u8>),
    /// A 32-byte value. Presumably a hash of the original key; its producer
    /// is not visible here (TODO confirm).
    Hashed([u8; 32]),
}

impl Deref for Key {
    type Target = [u8];

    /// Views either variant as a plain byte slice.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Plain(bytes) => bytes.as_slice(),
            Self::Hashed(digest) => digest.as_slice(),
        }
    }
}

impl Borrow<[u8]> for Key {
    /// Borrows exactly the bytes exposed by the `Deref` implementation.
    fn borrow(&self) -> &[u8] {
        Deref::deref(self)
    }
}
53
54impl rusqlite::ToSql for Key {
55 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
56 self.deref().to_sql()
57 }
58}
59
60#[async_trait]
61pub(crate) trait SqliteAsyncConnExt {
62 async fn execute<P>(
63 &self,
64 sql: impl AsRef<str> + Send + 'static,
65 params: P,
66 ) -> rusqlite::Result<usize>
67 where
68 P: Params + Send + 'static;
69
70 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;
71
72 async fn prepare<T, F>(
73 &self,
74 sql: impl AsRef<str> + Send + 'static,
75 f: F,
76 ) -> rusqlite::Result<T>
77 where
78 T: Send + 'static,
79 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;
80
81 async fn query_row<T, P, F>(
82 &self,
83 sql: impl AsRef<str> + Send + 'static,
84 params: P,
85 f: F,
86 ) -> rusqlite::Result<T>
87 where
88 T: Send + 'static,
89 P: Params + Send + 'static,
90 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
91
92 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
93 where
94 T: Send + 'static,
95 E: From<rusqlite::Error> + Send + 'static,
96 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;
97
98 async fn chunk_large_query_over<Query, Res>(
99 &self,
100 mut keys_to_chunk: Vec<Key>,
101 result_capacity: Option<usize>,
102 do_query: Query,
103 ) -> Result<Vec<Res>>
104 where
105 Res: Send + 'static,
106 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
107
108 async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
117 let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;
118
119 if optimize {
120 self.optimize().await?;
121 }
122
123 self.cache_size(cache_size).await?;
124 self.journal_size_limit(journal_size_limit).await?;
125
126 Ok(())
127 }
128
129 async fn optimize(&self) -> Result<()> {
139 self.execute_batch("PRAGMA optimize = 0x10002;").await?;
140 Ok(())
141 }
142
143 async fn cache_size(&self, cache_size: u32) -> Result<()> {
149 let n = cache_size / 1024;
152
153 self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
154 Ok(())
155 }
156
157 async fn journal_size_limit(&self, limit: u32) -> Result<()> {
175 self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
176 Ok(())
177 }
178
179 async fn vacuum(&self) -> Result<()> {
183 if let Err(error) = self.execute_batch("VACUUM").await {
184 #[cfg(not(any(test, debug_assertions)))]
187 tracing::warn!("Failed to vacuum database: {error}");
188
189 #[cfg(any(test, debug_assertions))]
191 return Err(error.into());
192 }
193
194 Ok(())
195 }
196}
197
198#[async_trait]
199impl SqliteAsyncConnExt for SqliteAsyncConn {
200 async fn execute<P>(
201 &self,
202 sql: impl AsRef<str> + Send + 'static,
203 params: P,
204 ) -> rusqlite::Result<usize>
205 where
206 P: Params + Send + 'static,
207 {
208 self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
209 }
210
211 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
212 self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
213 }
214
215 async fn prepare<T, F>(
216 &self,
217 sql: impl AsRef<str> + Send + 'static,
218 f: F,
219 ) -> rusqlite::Result<T>
220 where
221 T: Send + 'static,
222 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
223 {
224 self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
225 }
226
227 async fn query_row<T, P, F>(
228 &self,
229 sql: impl AsRef<str> + Send + 'static,
230 params: P,
231 f: F,
232 ) -> rusqlite::Result<T>
233 where
234 T: Send + 'static,
235 P: Params + Send + 'static,
236 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
237 {
238 self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
239 }
240
241 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
242 where
243 T: Send + 'static,
244 E: From<rusqlite::Error> + Send + 'static,
245 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
246 {
247 self.interact(move |conn| {
248 let txn = conn.transaction()?;
249 let result = f(&txn)?;
250 txn.commit()?;
251 Ok(result)
252 })
253 .await
254 .unwrap()
255 }
256
257 async fn chunk_large_query_over<Query, Res>(
264 &self,
265 keys_to_chunk: Vec<Key>,
266 result_capacity: Option<usize>,
267 do_query: Query,
268 ) -> Result<Vec<Res>>
269 where
270 Res: Send + 'static,
271 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
272 {
273 self.with_transaction(move |txn| {
274 txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
275 })
276 .await
277 }
278}
279
280pub(crate) trait SqliteTransactionExt {
281 fn chunk_large_query_over<Key, Query, Res>(
282 &self,
283 keys_to_chunk: Vec<Key>,
284 result_capacity: Option<usize>,
285 do_query: Query,
286 ) -> Result<Vec<Res>>
287 where
288 Res: Send + 'static,
289 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
290}
291
292impl SqliteTransactionExt for Transaction<'_> {
293 fn chunk_large_query_over<Key, Query, Res>(
294 &self,
295 mut keys_to_chunk: Vec<Key>,
296 result_capacity: Option<usize>,
297 do_query: Query,
298 ) -> Result<Vec<Res>>
299 where
300 Res: Send + 'static,
301 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
302 {
303 let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER) / 2;
306 let maximum_chunk_size: usize = maximum_chunk_size
307 .try_into()
308 .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
309
310 if keys_to_chunk.len() < maximum_chunk_size {
311 let chunk = keys_to_chunk;
313
314 Ok(do_query(self, chunk)?)
315 } else {
316 let capacity = result_capacity.unwrap_or_default();
320 let mut all_results = Vec::with_capacity(capacity);
321
322 while !keys_to_chunk.is_empty() {
323 let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
325 let chunk = keys_to_chunk;
326 keys_to_chunk = tail;
327
328 all_results.extend(do_query(self, chunk)?);
329 }
330
331 Ok(all_results)
332 }
333 }
334}
335
336pub(crate) trait SqliteKeyValueStoreConnExt {
348 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
350
351 fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
353 let serialized_value = rmp_serde::to_vec_named(&value)?;
354 self.set_kv(key, &serialized_value)?;
355
356 Ok(())
357 }
358
359 fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
361
362 fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
364 self.set_kv("version", &[version])
365 }
366}
367
368impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
369 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
370 self.execute(
371 "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
372 (key, value),
373 )?;
374 Ok(())
375 }
376
377 fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
378 self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
379 Ok(())
380 }
381}
382
383#[async_trait]
395pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
396 async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
398 self.query_row(
399 "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
400 (),
401 |row| row.get(0),
402 )
403 .await
404 }
405
406 async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
408 let key = key.to_owned();
409 self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
410 .await
411 .optional()
412 }
413
414 async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
416 let Some(bytes) = self.get_kv(key).await? else {
417 return Ok(None);
418 };
419
420 Ok(Some(rmp_serde::from_slice(&bytes)?))
421 }
422
423 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;
425
426 async fn set_serialized_kv<T: Serialize + Send + 'static>(
428 &self,
429 key: &str,
430 value: T,
431 ) -> Result<()>;
432
433 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
435
436 async fn db_version(&self) -> Result<u8, OpenStoreError> {
438 let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;
439
440 if kv_exists {
441 match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
442 Some([v]) => Ok(*v),
443 Some(_) => Err(OpenStoreError::InvalidVersion),
444 None => Err(OpenStoreError::MissingVersion),
445 }
446 } else {
447 Ok(0)
448 }
449 }
450
451 async fn get_or_create_store_cipher(
453 &self,
454 passphrase: &str,
455 ) -> Result<StoreCipher, OpenStoreError> {
456 let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;
457
458 let cipher = if let Some(encrypted) = encrypted_cipher {
459 StoreCipher::import(passphrase, &encrypted)?
460 } else {
461 let cipher = StoreCipher::new()?;
462 #[cfg(not(test))]
463 let export = cipher.export(passphrase);
464 #[cfg(test)]
465 let export = cipher._insecure_export_fast_for_testing(passphrase);
466 self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
467 cipher
468 };
469
470 Ok(cipher)
471 }
472}
473
474#[async_trait]
475impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
476 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
477 let key = key.to_owned();
478 self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
479
480 Ok(())
481 }
482
483 async fn set_serialized_kv<T: Serialize + Send + 'static>(
484 &self,
485 key: &str,
486 value: T,
487 ) -> Result<()> {
488 let key = key.to_owned();
489 self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
490
491 Ok(())
492 }
493
494 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
495 let key = key.to_owned();
496 self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
497
498 Ok(())
499 }
500}
501
502pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
504 assert_ne!(count, 0, "Can't generate zero repeated vars");
505
506 iter::repeat_n("?", count).format(",")
507}
508
/// Converts `time` into whole seconds since the Unix epoch, as an `i64`.
///
/// Sub-second precision is truncated. A time before the epoch, or a second
/// count that does not fit in an `i64`, yields `0` instead of an error.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    let Ok(since_epoch) = time.duration_since(SystemTime::UNIX_EPOCH) else {
        // `time` predates the epoch (e.g. a badly wrong system clock).
        return 0;
    };

    i64::try_from(since_epoch.as_secs()).unwrap_or(0)
}
521
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        let cases = [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")];
        for (count, expected) in cases {
            assert_eq!(repeat_vars(count).to_string(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        let epoch = SystemTime::UNIX_EPOCH;
        let one_minute = Duration::from_secs(60);

        assert_eq!(time_to_timestamp(epoch), 0);
        assert_eq!(time_to_timestamp(epoch + one_minute), 60);

        // Times before the epoch are clamped to zero.
        assert_eq!(time_to_timestamp(epoch - one_minute), 0);
    }
}