// matrix_sdk_sqlite/utils.rs

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15use core::fmt;
16use std::{borrow::Borrow, cmp::min, iter, ops::Deref};
17
18use async_trait::async_trait;
19use deadpool_sqlite::Object as SqliteAsyncConn;
20use itertools::Itertools;
21use matrix_sdk_store_encryption::StoreCipher;
22use ruma::time::SystemTime;
23use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
24use serde::{de::DeserializeOwned, Serialize};
25
26use crate::{
27    error::{Error, Result},
28    OpenStoreError, RuntimeConfig,
29};
30
/// A database key, stored either as the plain bytes or as a 32-byte value
/// derived from them (see the `Hashed` variant's fixed size).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Key {
    Plain(Vec<u8>),
    Hashed([u8; 32]),
}

impl Deref for Key {
    type Target = [u8];

    /// View the key as a raw byte slice, whichever variant it is.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Plain(bytes) => bytes.as_slice(),
            Self::Hashed(hash) => hash.as_slice(),
        }
    }
}

impl Borrow<[u8]> for Key {
    /// Borrow the key's bytes, delegating to the `Deref` implementation.
    fn borrow(&self) -> &[u8] {
        &**self
    }
}
53
54impl rusqlite::ToSql for Key {
55    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
56        self.deref().to_sql()
57    }
58}
59
/// Extension trait providing `async` convenience methods on top of a pooled
/// SQLite connection ([`SqliteAsyncConn`]).
///
/// The SQL strings, parameters and closures are moved into the task that
/// runs them, hence the `Send + 'static` bounds throughout.
#[async_trait]
pub(crate) trait SqliteAsyncConnExt {
    /// Execute the given SQL statement with the given parameters, returning
    /// the number of affected rows.
    async fn execute<P>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
    ) -> rusqlite::Result<usize>
    where
        P: Params + Send + 'static;

    /// Execute a batch of SQL statements that take no parameters.
    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;

    /// Prepare the given SQL statement and hand it to the closure `f`.
    async fn prepare<T, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;

    /// Run the given SQL query and map its first result row through `f`.
    async fn query_row<T, P, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        P: Params + Send + 'static,
        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;

    /// Run the closure `f` inside a transaction, committing it if `f`
    /// succeeds.
    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
    where
        T: Send + 'static,
        E: From<rusqlite::Error> + Send + 'static,
        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;

    /// Run a potentially large query over chunks of keys, inside a single
    /// transaction (see [`SqliteTransactionExt::chunk_large_query_over`]).
    async fn chunk_large_query_over<Query, Res>(
        &self,
        mut keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;

    /// Apply the [`RuntimeConfig`].
    ///
    /// It will call the `Self::optimize`, `Self::cache_size` or
    /// `Self::journal_size_limit` methods automatically based on the
    /// `RuntimeConfig` values.
    ///
    /// It is possible to call these methods individually though. This
    /// `apply_runtime_config` method allows to automate this process.
    async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
        let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;

        if optimize {
            self.optimize().await?;
        }

        self.cache_size(cache_size).await?;
        self.journal_size_limit(journal_size_limit).await?;

        Ok(())
    }

    /// Optimize the database.
    ///
    /// The SQLite documentation recommends to run this regularly and after any
    /// schema change. The easiest is to do it consistently when the store is
    /// constructed, after eventual migrations.
    ///
    /// See [`PRAGMA optimize`] to learn more.
    ///
    /// [`PRAGMA optimize`]: https://www.sqlite.org/pragma.html#pragma_optimize
    async fn optimize(&self) -> Result<()> {
        // 0x10002 = analyze tables as needed, but don't run `ANALYZE` beyond
        // the default limit (see the pragma documentation linked above).
        self.execute_batch("PRAGMA optimize = 0x10002;").await?;
        Ok(())
    }

    /// Define the maximum size in **bytes** the SQLite cache can use.
    ///
    /// See [`PRAGMA cache_size`] to learn more.
    ///
    /// [`PRAGMA cache_size`]: https://www.sqlite.org/pragma.html#pragma_cache_size
    async fn cache_size(&self, cache_size: u32) -> Result<()> {
        // `N` in `PRAGMA cache_size = -N` is expressed in kibibytes.
        // `cache_size` is expressed in bytes. Let's convert.
        let n = cache_size / 1024;

        self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
        Ok(())
    }

    /// Limit the size of the WAL file, in **bytes**.
    ///
    /// By default, while the DB connections of the databases are open, [the
    /// size of the WAL file can keep increasing][size_wal_file] depending on
    /// the size needed for the transactions. A critical case is `VACUUM`
    /// which basically writes the content of the DB file to the WAL file
    /// before writing it back to the DB file, so we end up taking twice the
    /// size of the database.
    ///
    /// By setting this limit, the WAL file is truncated after its content is
    /// written to the database, if it is bigger than the limit.
    ///
    /// See [`PRAGMA journal_size_limit`] to learn more. The value `limit`
    /// corresponds to `N` in `PRAGMA journal_size_limit = N`.
    ///
    /// [size_wal_file]: https://www.sqlite.org/wal.html#avoiding_excessively_large_wal_files
    /// [`PRAGMA journal_size_limit`]: https://www.sqlite.org/pragma.html#pragma_journal_size_limit
    async fn journal_size_limit(&self, limit: u32) -> Result<()> {
        self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
        Ok(())
    }

    /// Defragment the database and free space on the filesystem.
    ///
    /// Only returns an error in test and debug builds; in release builds the
    /// error is only logged.
    async fn vacuum(&self) -> Result<()> {
        if let Err(error) = self.execute_batch("VACUUM").await {
            // Since this is an optimisation step, do not propagate the error
            // but log it.
            #[cfg(not(any(test, debug_assertions)))]
            tracing::warn!("Failed to vacuum database: {error}");

            // We want to know if there is an error with this step during tests.
            #[cfg(any(test, debug_assertions))]
            return Err(error.into());
        }

        Ok(())
    }
}
197
#[async_trait]
impl SqliteAsyncConnExt for SqliteAsyncConn {
    async fn execute<P>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
    ) -> rusqlite::Result<usize>
    where
        P: Params + Send + 'static,
    {
        // `interact` runs the closure with the underlying connection. The
        // outer `unwrap` is on `interact`'s own error, not on the SQL
        // result: it panics only if the interact task itself failed.
        self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
    }

    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
        self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
    }

    async fn prepare<T, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
    {
        // The statement is prepared and consumed inside the closure, as it
        // borrows from the connection and cannot cross the `await`.
        self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
    }

    async fn query_row<T, P, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        P: Params + Send + 'static,
        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
    {
        self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
    }

    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
    where
        T: Send + 'static,
        E: From<rusqlite::Error> + Send + 'static,
        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
    {
        self.interact(move |conn| {
            // The transaction is only committed when `f` succeeds; an early
            // return via `?` drops `txn`, which rolls it back.
            let txn = conn.transaction()?;
            let result = f(&txn)?;
            txn.commit()?;
            Ok(result)
        })
        .await
        .unwrap()
    }

    /// Chunk a large query over some keys.
    ///
    /// Imagine there is a _dynamic_ query that runs potentially large number of
    /// parameters, so much that the maximum number of parameters can be hit.
    /// Then, this helper is for you. It will execute the query on chunks of
    /// parameters.
    async fn chunk_large_query_over<Query, Res>(
        &self,
        keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
    {
        // Delegate to the `Transaction` implementation so all chunks run in
        // a single transaction.
        self.with_transaction(move |txn| {
            txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
        })
        .await
    }
}
279
/// Extension trait to run a chunked query on a [`Transaction`].
pub(crate) trait SqliteTransactionExt {
    /// Run `do_query` over `keys_to_chunk`, splitting the keys into several
    /// queries if they would exceed the SQLite variable limit.
    ///
    /// NOTE: `Key` here is a method-level generic parameter; it shadows, and
    /// is unrelated to, the crate-level [`Key`] enum defined above.
    fn chunk_large_query_over<Key, Query, Res>(
        &self,
        keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
}
291
292impl SqliteTransactionExt for Transaction<'_> {
293    fn chunk_large_query_over<Key, Query, Res>(
294        &self,
295        mut keys_to_chunk: Vec<Key>,
296        result_capacity: Option<usize>,
297        do_query: Query,
298    ) -> Result<Vec<Res>>
299    where
300        Res: Send + 'static,
301        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
302    {
303        // Divide by 2 to allow space for more static parameters (not part of
304        // `keys_to_chunk`).
305        let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER) / 2;
306        let maximum_chunk_size: usize = maximum_chunk_size
307            .try_into()
308            .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
309
310        if keys_to_chunk.len() < maximum_chunk_size {
311            // Chunking isn't necessary.
312            let chunk = keys_to_chunk;
313
314            Ok(do_query(self, chunk)?)
315        } else {
316            // Chunking _is_ necessary.
317
318            // Define the accumulator.
319            let capacity = result_capacity.unwrap_or_default();
320            let mut all_results = Vec::with_capacity(capacity);
321
322            while !keys_to_chunk.is_empty() {
323                // Chunk and run the query.
324                let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
325                let chunk = keys_to_chunk;
326                keys_to_chunk = tail;
327
328                all_results.extend(do_query(self, chunk)?);
329            }
330
331            Ok(all_results)
332        }
333    }
334}
335
336/// Extension trait for a [`rusqlite::Connection`] that contains a key-value
337/// table named `kv`.
338///
339/// The table should be created like this:
340///
341/// ```sql
342/// CREATE TABLE "kv" (
343///     "key" TEXT PRIMARY KEY NOT NULL,
344///     "value" BLOB NOT NULL
345/// );
346/// ```
347pub(crate) trait SqliteKeyValueStoreConnExt {
348    /// Store the given value for the given key.
349    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
350
351    /// Store the given value for the given key by serializing it.
352    fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
353        let serialized_value = rmp_serde::to_vec_named(&value)?;
354        self.set_kv(key, &serialized_value)?;
355
356        Ok(())
357    }
358
359    /// Removes the current key and value if exists.
360    fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
361
362    /// Set the version of the database.
363    fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
364        self.set_kv("version", &[version])
365    }
366}
367
368impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
369    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
370        self.execute(
371            "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
372            (key, value),
373        )?;
374        Ok(())
375    }
376
377    fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
378        self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
379        Ok(())
380    }
381}
382
/// Extension trait for an [`SqliteAsyncConn`] that contains a key-value
/// table named `kv`.
///
/// The table should be created like this:
///
/// ```sql
/// CREATE TABLE "kv" (
///     "key" TEXT PRIMARY KEY NOT NULL,
///     "value" BLOB NOT NULL
/// );
/// ```
#[async_trait]
pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
    /// Whether the `kv` table exists in this database.
    async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
        self.query_row(
            "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
            (),
            |row| row.get(0),
        )
        .await
    }

    /// Get the stored value for the given key.
    ///
    /// Returns `None` when no entry exists for `key`.
    async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
        // Owned copy, as the parameters are moved into the query task.
        let key = key.to_owned();
        self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
            .await
            .optional()
    }

    /// Get the stored value for the given key, deserialized from the
    /// MessagePack format.
    async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
        let Some(bytes) = self.get_kv(key).await? else {
            return Ok(None);
        };

        Ok(Some(rmp_serde::from_slice(&bytes)?))
    }

    /// Store the given value for the given key.
    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;

    /// Store the given value for the given key by serializing it.
    async fn set_serialized_kv<T: Serialize + Send + 'static>(
        &self,
        key: &str,
        value: T,
    ) -> Result<()>;

    /// Clears the given value for the given key.
    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;

    /// Get the version of the database.
    ///
    /// Returns 0 when the `kv` table does not exist yet (fresh database).
    async fn db_version(&self) -> Result<u8, OpenStoreError> {
        let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;

        if kv_exists {
            match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
                // The version is stored as a single byte.
                Some([v]) => Ok(*v),
                Some(_) => Err(OpenStoreError::InvalidVersion),
                None => Err(OpenStoreError::MissingVersion),
            }
        } else {
            Ok(0)
        }
    }

    /// Get the [`StoreCipher`] of the database or create it.
    async fn get_or_create_store_cipher(
        &self,
        passphrase: &str,
    ) -> Result<StoreCipher, OpenStoreError> {
        let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;

        let cipher = if let Some(encrypted) = encrypted_cipher {
            // A cipher was previously saved: decrypt it with the passphrase.
            StoreCipher::import(passphrase, &encrypted)?
        } else {
            // First run: create a cipher and persist it, encrypted with the
            // passphrase.
            let cipher = StoreCipher::new()?;
            // In tests, use the faster (but insecure) export to keep the
            // test suite quick.
            #[cfg(not(test))]
            let export = cipher.export(passphrase);
            #[cfg(test)]
            let export = cipher._insecure_export_fast_for_testing(passphrase);
            self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
            cipher
        };

        Ok(cipher)
    }
}
473
474#[async_trait]
475impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
476    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
477        let key = key.to_owned();
478        self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
479
480        Ok(())
481    }
482
483    async fn set_serialized_kv<T: Serialize + Send + 'static>(
484        &self,
485        key: &str,
486        value: T,
487    ) -> Result<()> {
488        let key = key.to_owned();
489        self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
490
491        Ok(())
492    }
493
494    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
495        let key = key.to_owned();
496        self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
497
498        Ok(())
499    }
500}
501
/// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated.
///
/// # Panics
///
/// Panics if `count` is zero, as an empty SQL variable list is invalid.
pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
    assert_ne!(count, 0, "Can't generate zero repeated vars");

    // Build the list eagerly. Unlike `itertools`' lazy `format` adapter —
    // which panics if it is displayed more than once — a `String` can be
    // formatted any number of times.
    vec!["?"; count].join(",")
}
508
/// Convert the given `SystemTime` to a timestamp, as the number of seconds
/// since Unix Epoch.
///
/// Returns an `i64` as it is the numeric type used by SQLite.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(duration) => {
            // Saturate to 0 if the seconds don't fit in an `i64`; unlikely
            // unless the system clock is seriously wrong, but a value is
            // always needed.
            i64::try_from(duration.as_secs()).unwrap_or(0)
        }
        // Pre-epoch times also fall back to 0.
        Err(_) => 0,
    }
}
521
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    /// `repeat_vars` must produce a comma-separated list of `?` placeholders.
    #[test]
    fn can_generate_repeated_vars() {
        assert_eq!(repeat_vars(1).to_string(), "?");
        assert_eq!(repeat_vars(2).to_string(), "?,?");
        assert_eq!(repeat_vars(5).to_string(), "?,?,?,?,?");
    }

    /// Zero placeholders would produce invalid SQL, so asking for zero is a
    /// caller bug and must panic.
    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    /// Timestamps count whole seconds since the Unix epoch, clamped to 0 for
    /// pre-epoch times.
    #[test]
    fn test_time_to_timestamp() {
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH), 0);
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH + Duration::from_secs(60)), 60);

        // Fallback value on overflow.
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH - Duration::from_secs(60)), 0);
    }
}