1use core::fmt;
16use std::{borrow::Borrow, cmp::min, iter, ops::Deref};
17
18use async_trait::async_trait;
19use deadpool_sqlite::Object as SqliteAsyncConn;
20use itertools::Itertools;
21use matrix_sdk_store_encryption::StoreCipher;
22use ruma::time::SystemTime;
23use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
24use serde::{de::DeserializeOwned, Serialize};
25
26use crate::{
27 error::{Error, Result},
28 OpenStoreError,
29};
30
/// A key as stored in / looked up from the database: either raw bytes
/// (`Plain`) or a fixed 32-byte value (`Hashed`).
///
/// Both variants dereference to their bytes (`[u8]`), which is how the key is
/// bound as a SQL parameter.
///
/// Note: the derived ordering compares the variant first, in declaration
/// order, so every `Plain` key sorts before every `Hashed` key. Do not reorder
/// the variants without checking callers that rely on this ordering.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Key {
    /// Key bytes used as-is.
    Plain(Vec<u8>),
    /// A 32-byte key. NOTE(review): presumably a hash of the plain key —
    /// confirm with the code that constructs it.
    Hashed([u8; 32]),
}
36
37impl Deref for Key {
38 type Target = [u8];
39
40 fn deref(&self) -> &Self::Target {
41 match self {
42 Key::Plain(slice) => slice,
43 Key::Hashed(bytes) => bytes,
44 }
45 }
46}
47
48impl Borrow<[u8]> for Key {
49 fn borrow(&self) -> &[u8] {
50 self.deref()
51 }
52}
53
54impl rusqlite::ToSql for Key {
55 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
56 self.deref().to_sql()
57 }
58}
59
60#[async_trait]
61pub(crate) trait SqliteAsyncConnExt {
62 async fn execute<P>(
63 &self,
64 sql: impl AsRef<str> + Send + 'static,
65 params: P,
66 ) -> rusqlite::Result<usize>
67 where
68 P: Params + Send + 'static;
69
70 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;
71
72 async fn prepare<T, F>(
73 &self,
74 sql: impl AsRef<str> + Send + 'static,
75 f: F,
76 ) -> rusqlite::Result<T>
77 where
78 T: Send + 'static,
79 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;
80
81 async fn query_row<T, P, F>(
82 &self,
83 sql: impl AsRef<str> + Send + 'static,
84 params: P,
85 f: F,
86 ) -> rusqlite::Result<T>
87 where
88 T: Send + 'static,
89 P: Params + Send + 'static,
90 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
91
92 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
93 where
94 T: Send + 'static,
95 E: From<rusqlite::Error> + Send + 'static,
96 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;
97
98 async fn chunk_large_query_over<Query, Res>(
99 &self,
100 mut keys_to_chunk: Vec<Key>,
101 result_capacity: Option<usize>,
102 do_query: Query,
103 ) -> Result<Vec<Res>>
104 where
105 Res: Send + 'static,
106 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
107
108 async fn optimize(&self) -> Result<()> {
116 self.execute_batch("PRAGMA optimize=0x10002;").await?;
117 Ok(())
118 }
119
120 async fn set_journal_size_limit(&self) -> Result<()> {
136 self.execute_batch("PRAGMA journal_size_limit = 10000000;").await.map_err(Error::from)?;
137 Ok(())
138 }
139
140 async fn vacuum(&self) -> Result<()> {
144 if let Err(error) = self.execute_batch("VACUUM").await {
145 #[cfg(not(any(test, debug_assertions)))]
148 tracing::warn!("Failed to vacuum database: {error}");
149
150 #[cfg(any(test, debug_assertions))]
152 return Err(error.into());
153 }
154
155 Ok(())
156 }
157}
158
159#[async_trait]
160impl SqliteAsyncConnExt for SqliteAsyncConn {
161 async fn execute<P>(
162 &self,
163 sql: impl AsRef<str> + Send + 'static,
164 params: P,
165 ) -> rusqlite::Result<usize>
166 where
167 P: Params + Send + 'static,
168 {
169 self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
170 }
171
172 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
173 self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
174 }
175
176 async fn prepare<T, F>(
177 &self,
178 sql: impl AsRef<str> + Send + 'static,
179 f: F,
180 ) -> rusqlite::Result<T>
181 where
182 T: Send + 'static,
183 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
184 {
185 self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
186 }
187
188 async fn query_row<T, P, F>(
189 &self,
190 sql: impl AsRef<str> + Send + 'static,
191 params: P,
192 f: F,
193 ) -> rusqlite::Result<T>
194 where
195 T: Send + 'static,
196 P: Params + Send + 'static,
197 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
198 {
199 self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
200 }
201
202 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
203 where
204 T: Send + 'static,
205 E: From<rusqlite::Error> + Send + 'static,
206 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
207 {
208 self.interact(move |conn| {
209 let txn = conn.transaction()?;
210 let result = f(&txn)?;
211 txn.commit()?;
212 Ok(result)
213 })
214 .await
215 .unwrap()
216 }
217
218 async fn chunk_large_query_over<Query, Res>(
225 &self,
226 keys_to_chunk: Vec<Key>,
227 result_capacity: Option<usize>,
228 do_query: Query,
229 ) -> Result<Vec<Res>>
230 where
231 Res: Send + 'static,
232 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
233 {
234 self.with_transaction(move |txn| {
235 txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
236 })
237 .await
238 }
239}
240
241pub(crate) trait SqliteTransactionExt {
242 fn chunk_large_query_over<Key, Query, Res>(
243 &self,
244 keys_to_chunk: Vec<Key>,
245 result_capacity: Option<usize>,
246 do_query: Query,
247 ) -> Result<Vec<Res>>
248 where
249 Res: Send + 'static,
250 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
251}
252
253impl SqliteTransactionExt for Transaction<'_> {
254 fn chunk_large_query_over<Key, Query, Res>(
255 &self,
256 mut keys_to_chunk: Vec<Key>,
257 result_capacity: Option<usize>,
258 do_query: Query,
259 ) -> Result<Vec<Res>>
260 where
261 Res: Send + 'static,
262 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
263 {
264 let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER) / 2;
267 let maximum_chunk_size: usize = maximum_chunk_size
268 .try_into()
269 .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
270
271 if keys_to_chunk.len() < maximum_chunk_size {
272 let chunk = keys_to_chunk;
274
275 Ok(do_query(self, chunk)?)
276 } else {
277 let capacity = result_capacity.unwrap_or_default();
281 let mut all_results = Vec::with_capacity(capacity);
282
283 while !keys_to_chunk.is_empty() {
284 let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
286 let chunk = keys_to_chunk;
287 keys_to_chunk = tail;
288
289 all_results.extend(do_query(self, chunk)?);
290 }
291
292 Ok(all_results)
293 }
294 }
295}
296
297pub(crate) trait SqliteKeyValueStoreConnExt {
309 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
311
312 fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
314 let serialized_value = rmp_serde::to_vec_named(&value)?;
315 self.set_kv(key, &serialized_value)?;
316
317 Ok(())
318 }
319
320 fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
322
323 fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
325 self.set_kv("version", &[version])
326 }
327}
328
329impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
330 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
331 self.execute(
332 "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
333 (key, value),
334 )?;
335 Ok(())
336 }
337
338 fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
339 self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
340 Ok(())
341 }
342}
343
344#[async_trait]
356pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
357 async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
359 self.query_row(
360 "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
361 (),
362 |row| row.get(0),
363 )
364 .await
365 }
366
367 async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
369 let key = key.to_owned();
370 self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
371 .await
372 .optional()
373 }
374
375 async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
377 let Some(bytes) = self.get_kv(key).await? else {
378 return Ok(None);
379 };
380
381 Ok(Some(rmp_serde::from_slice(&bytes)?))
382 }
383
384 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;
386
387 async fn set_serialized_kv<T: Serialize + Send + 'static>(
389 &self,
390 key: &str,
391 value: T,
392 ) -> Result<()>;
393
394 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
396
397 async fn db_version(&self) -> Result<u8, OpenStoreError> {
399 let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;
400
401 if kv_exists {
402 match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
403 Some([v]) => Ok(*v),
404 Some(_) => Err(OpenStoreError::InvalidVersion),
405 None => Err(OpenStoreError::MissingVersion),
406 }
407 } else {
408 Ok(0)
409 }
410 }
411
412 async fn get_or_create_store_cipher(
414 &self,
415 passphrase: &str,
416 ) -> Result<StoreCipher, OpenStoreError> {
417 let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;
418
419 let cipher = if let Some(encrypted) = encrypted_cipher {
420 StoreCipher::import(passphrase, &encrypted)?
421 } else {
422 let cipher = StoreCipher::new()?;
423 #[cfg(not(test))]
424 let export = cipher.export(passphrase);
425 #[cfg(test)]
426 let export = cipher._insecure_export_fast_for_testing(passphrase);
427 self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
428 cipher
429 };
430
431 Ok(cipher)
432 }
433}
434
435#[async_trait]
436impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
437 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
438 let key = key.to_owned();
439 self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
440
441 Ok(())
442 }
443
444 async fn set_serialized_kv<T: Serialize + Send + 'static>(
445 &self,
446 key: &str,
447 value: T,
448 ) -> Result<()> {
449 let key = key.to_owned();
450 self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
451
452 Ok(())
453 }
454
455 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
456 let key = key.to_owned();
457 self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
458
459 Ok(())
460 }
461}
462
463pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
465 assert_ne!(count, 0, "Can't generate zero repeated vars");
466
467 iter::repeat_n("?", count).format(",")
468}
469
/// Converts `time` into whole seconds since the Unix epoch.
///
/// Times before the epoch, and second counts that don't fit in an `i64`,
/// both map to `0`.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(0),
        Err(_before_epoch) => 0,
    }
}
482
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        let cases = [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")];
        for (count, expected) in cases {
            assert_eq!(repeat_vars(count).to_string(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        let _ = repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        let minute = Duration::from_secs(60);

        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH), 0);
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH + minute), 60);
        // Times before the epoch are reported as 0 rather than a negative value.
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH - minute), 0);
    }
}