// matrix_sdk_sqlite/utils.rs

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::fmt;
use std::{borrow::Borrow, cmp::min, iter, ops::Deref};

use async_trait::async_trait;
use deadpool_sqlite::Object as SqliteAsyncConn;
use itertools::Itertools;
use matrix_sdk_store_encryption::StoreCipher;
use ruma::time::SystemTime;
use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
use serde::{de::DeserializeOwned, Serialize};

use crate::{
    error::{Error, Result},
    OpenStoreError,
};

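/// A key to use in SQLite queries, stored either as plain bytes or as a
/// 32-byte hash of the original key.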
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Key {
    Plain(Vec<u8>),
    Hashed([u8; 32]),
}

impl Deref for Key {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        match self {
            Key::Plain(slice) => slice,
            Key::Hashed(bytes) => bytes,
        }
    }
}

impl Borrow<[u8]> for Key {
    fn borrow(&self) -> &[u8] {
        self.deref()
    }
}

impl rusqlite::ToSql for Key {
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        self.deref().to_sql()
    }
}

#[async_trait]
pub(crate) trait SqliteAsyncConnExt {
    async fn execute<P>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
    ) -> rusqlite::Result<usize>
    where
        P: Params + Send + 'static;

    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;

    async fn prepare<T, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;

    async fn query_row<T, P, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        P: Params + Send + 'static,
        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;

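    /// Run the given function inside a single transaction, committing if it
    /// returns `Ok` and rolling the transaction back if it returns `Err`.
    ///
    /// A minimal sketch of a caller, assuming the `kv` table described
    /// further down in this module (the key is illustrative):
    ///
    /// ```rust,ignore
    /// let deleted = conn
    ///     .with_transaction(|txn| -> rusqlite::Result<usize> {
    ///         txn.execute("DELETE FROM kv WHERE key = ?1", ("stale_key",))
    ///     })
    ///     .await?;
    /// ```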
    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
    where
        T: Send + 'static,
        E: From<rusqlite::Error> + Send + 'static,
        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;

    async fn chunk_large_query_over<Query, Res>(
        &self,
        mut keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;

    /// Optimize the database.
    ///
    /// [The SQLite docs] recommend running this regularly and after any schema
    /// change. The easiest approach is to do it consistently when the state
    /// store is constructed, after any migrations.
    ///
    /// [The SQLite docs]: https://www.sqlite.org/pragma.html#pragma_optimize
    async fn optimize(&self) -> Result<()> {
        self.execute_batch("PRAGMA optimize=0x10002;").await?;
        Ok(())
    }

    /// Limit the size of the WAL file.
    ///
    /// By default, while the database connections are open, [the size of the
    /// WAL file can keep increasing] depending on the size needed for the
    /// transactions. A critical case is `VACUUM`, which essentially writes
    /// the content of the DB file into the WAL file before writing it back
    /// to the DB file, so we can end up using twice the size of the
    /// database.
    ///
    /// By setting this limit, the WAL file is truncated after its content is
    /// written to the database, if it is bigger than the limit.
    ///
    /// The limit is set to 10 MB.
    ///
    /// [the size of the WAL file can keep increasing]: https://www.sqlite.org/wal.html#avoiding_excessively_large_wal_files
    async fn set_journal_size_limit(&self) -> Result<()> {
        self.execute_batch("PRAGMA journal_size_limit = 10000000;").await.map_err(Error::from)?;
        Ok(())
    }

    /// Defragment the database and free space on the filesystem.
    ///
    /// Only returns an error in tests and debug builds; otherwise the error
    /// is only logged.
    async fn vacuum(&self) -> Result<()> {
        if let Err(error) = self.execute_batch("VACUUM").await {
            // Since this is an optimisation step, do not propagate the error
            // but log it.
            #[cfg(not(any(test, debug_assertions)))]
            tracing::warn!("Failed to vacuum database: {error}");

            // We want to know if there is an error with this step during tests.
            #[cfg(any(test, debug_assertions))]
            return Err(error.into());
        }

        Ok(())
    }
}

#[async_trait]
impl SqliteAsyncConnExt for SqliteAsyncConn {
    async fn execute<P>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
    ) -> rusqlite::Result<usize>
    where
        P: Params + Send + 'static,
    {
        self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
    }

    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
        self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
    }

    async fn prepare<T, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
    {
        self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
    }

    async fn query_row<T, P, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        P: Params + Send + 'static,
        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
    {
        self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
    }

    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
    where
        T: Send + 'static,
        E: From<rusqlite::Error> + Send + 'static,
        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
    {
        self.interact(move |conn| {
            let txn = conn.transaction()?;
            let result = f(&txn)?;
            txn.commit()?;
            Ok(result)
        })
        .await
        .unwrap()
    }

    /// Chunk a large query over some keys.
    ///
    /// Imagine there is a _dynamic_ query that binds a potentially large
    /// number of parameters, so many that the maximum number of parameters
    /// supported by SQLite can be hit. Then this helper is for you: it
    /// executes the query over chunks of parameters.
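    ///
    /// A minimal sketch of a caller, assuming a hypothetical `event` table
    /// with an `event_id` key column and a `content` blob column (the names
    /// are illustrative, not part of this module):
    ///
    /// ```rust,ignore
    /// let contents = conn
    ///     .chunk_large_query_over(keys, None, |txn, chunk: Vec<Key>| {
    ///         // Build `IN (?,?,…)` with exactly `chunk.len()` placeholders.
    ///         let sql = format!(
    ///             "SELECT content FROM event WHERE event_id IN ({})",
    ///             repeat_vars(chunk.len())
    ///         );
    ///         let mut stmt = txn.prepare(&sql)?;
    ///         let rows = stmt
    ///             .query_map(rusqlite::params_from_iter(chunk), |row| row.get(0))?
    ///             .collect::<Result<Vec<Vec<u8>>, _>>()?;
    ///         Ok(rows)
    ///     })
    ///     .await?;
    /// ```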
    async fn chunk_large_query_over<Query, Res>(
        &self,
        keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
    {
        self.with_transaction(move |txn| {
            txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
        })
        .await
    }
}

pub(crate) trait SqliteTransactionExt {
    fn chunk_large_query_over<Key, Query, Res>(
        &self,
        keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
}

impl SqliteTransactionExt for Transaction<'_> {
    fn chunk_large_query_over<Key, Query, Res>(
        &self,
        mut keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
    {
        // Divide by 2 to allow space for more static parameters (not part of
        // `keys_to_chunk`).
        let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER) / 2;
        let maximum_chunk_size: usize = maximum_chunk_size
            .try_into()
            .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;

        if keys_to_chunk.len() < maximum_chunk_size {
            // Chunking isn't necessary.
            let chunk = keys_to_chunk;

            Ok(do_query(self, chunk)?)
        } else {
            // Chunking _is_ necessary.

            // Define the accumulator.
            let capacity = result_capacity.unwrap_or_default();
            let mut all_results = Vec::with_capacity(capacity);

            while !keys_to_chunk.is_empty() {
                // Chunk and run the query.
                let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
                let chunk = keys_to_chunk;
                keys_to_chunk = tail;

                all_results.extend(do_query(self, chunk)?);
            }

            Ok(all_results)
        }
    }
}

/// Extension trait for a [`rusqlite::Connection`] that contains a key-value
/// table named `kv`.
///
/// The table should be created like this:
///
/// ```sql
/// CREATE TABLE "kv" (
///     "key" TEXT PRIMARY KEY NOT NULL,
///     "value" BLOB NOT NULL
/// );
/// ```
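///
/// A minimal usage sketch, assuming an in-memory connection and the schema
/// above (the `greeting` key is purely illustrative):
///
/// ```rust,ignore
/// let conn = rusqlite::Connection::open_in_memory()?;
/// conn.execute_batch(
///     r#"CREATE TABLE "kv" ("key" TEXT PRIMARY KEY NOT NULL, "value" BLOB NOT NULL);"#,
/// )?;
/// conn.set_db_version(1)?;
/// conn.set_kv("greeting", b"hello")?;
/// ```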
pub(crate) trait SqliteKeyValueStoreConnExt {
    /// Store the given value for the given key.
    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;

    /// Store the given value for the given key by serializing it.
    fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
        let serialized_value = rmp_serde::to_vec_named(&value)?;
        self.set_kv(key, &serialized_value)?;

        Ok(())
    }

    /// Remove the given key and its value, if present.
    fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;

    /// Set the version of the database.
    fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
        self.set_kv("version", &[version])
    }
}

impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
        self.execute(
            "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
            (key, value),
        )?;
        Ok(())
    }

    fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
        self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
        Ok(())
    }
}

/// Extension trait for an [`SqliteAsyncConn`] that contains a key-value
/// table named `kv`.
///
/// The table should be created like this:
///
/// ```sql
/// CREATE TABLE "kv" (
///     "key" TEXT PRIMARY KEY NOT NULL,
///     "value" BLOB NOT NULL
/// );
/// ```
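///
/// A minimal sketch of a round-trip through the serialized helpers, assuming
/// `conn` is an [`SqliteAsyncConn`] whose database already contains the `kv`
/// table (the key name is illustrative):
///
/// ```rust,ignore
/// conn.set_serialized_kv("sync_token", "s12345_67890".to_owned()).await?;
/// let token: Option<String> = conn.get_serialized_kv("sync_token").await?;
/// assert_eq!(token.as_deref(), Some("s12345_67890"));
/// ```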
#[async_trait]
pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
    /// Whether the `kv` table exists in this database.
    async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
        self.query_row(
            "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
            (),
            |row| row.get(0),
        )
        .await
    }

    /// Get the stored value for the given key.
    async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
        let key = key.to_owned();
        self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
            .await
            .optional()
    }

    /// Get the stored serialized value for the given key.
    async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
        let Some(bytes) = self.get_kv(key).await? else {
            return Ok(None);
        };

        Ok(Some(rmp_serde::from_slice(&bytes)?))
    }

    /// Store the given value for the given key.
    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;

    /// Store the given value for the given key by serializing it.
    async fn set_serialized_kv<T: Serialize + Send + 'static>(
        &self,
        key: &str,
        value: T,
    ) -> Result<()>;

    /// Clear the stored value for the given key.
    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;

    /// Get the version of the database.
    async fn db_version(&self) -> Result<u8, OpenStoreError> {
        let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;

        if kv_exists {
            match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
                Some([v]) => Ok(*v),
                Some(_) => Err(OpenStoreError::InvalidVersion),
                None => Err(OpenStoreError::MissingVersion),
            }
        } else {
            Ok(0)
        }
    }

    /// Get the [`StoreCipher`] of the database or create it.
    async fn get_or_create_store_cipher(
        &self,
        passphrase: &str,
    ) -> Result<StoreCipher, OpenStoreError> {
        let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;

        let cipher = if let Some(encrypted) = encrypted_cipher {
            StoreCipher::import(passphrase, &encrypted)?
        } else {
            let cipher = StoreCipher::new()?;
            #[cfg(not(test))]
            let export = cipher.export(passphrase);
            #[cfg(test)]
            let export = cipher._insecure_export_fast_for_testing(passphrase);
            self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
            cipher
        };

        Ok(cipher)
    }
}

#[async_trait]
impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
        let key = key.to_owned();
        self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;

        Ok(())
    }

    async fn set_serialized_kv<T: Serialize + Send + 'static>(
        &self,
        key: &str,
        value: T,
    ) -> Result<()> {
        let key = key.to_owned();
        self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;

        Ok(())
    }

    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
        let key = key.to_owned();
        self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;

        Ok(())
    }
}

/// Repeat `?`, `count` times, comma-separated (e.g. for an SQL `IN (…)` list).
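///
/// A minimal usage sketch, assuming a plain [`rusqlite::Connection`] named
/// `conn` and the `kv` table described above:
///
/// ```rust,ignore
/// let keys = ["version", "cipher"];
/// let sql = format!("SELECT value FROM kv WHERE key IN ({})", repeat_vars(keys.len()));
/// let mut stmt = conn.prepare(&sql)?;
/// let mut rows = stmt.query(rusqlite::params_from_iter(keys))?;
/// ```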
pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
    assert_ne!(count, 0, "Can't generate zero repeated vars");

    iter::repeat_n("?", count).format(",")
}

/// Convert the given `SystemTime` to a timestamp, as the number of seconds
/// since Unix Epoch.
///
/// Returns an `i64` as it is the numeric type used by SQLite.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    time.duration_since(SystemTime::UNIX_EPOCH)
        .ok()
        .and_then(|d| d.as_secs().try_into().ok())
        // It is unlikely to happen unless the time on the system is seriously wrong, but we always
        // need a value.
        .unwrap_or(0)
}

#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        assert_eq!(repeat_vars(1).to_string(), "?");
        assert_eq!(repeat_vars(2).to_string(), "?,?");
        assert_eq!(repeat_vars(5).to_string(), "?,?,?,?,?");
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH), 0);
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH + Duration::from_secs(60)), 60);

        // Fallback value on overflow.
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH - Duration::from_secs(60)), 0);
    }
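
    // An added sanity check (a small sketch) for the `Key` wrapper defined
    // above: both variants should expose their bytes through `Deref`.
    #[test]
    fn key_derefs_to_its_bytes() {
        let plain = Key::Plain(b"event_id".to_vec());
        assert_eq!(&*plain, b"event_id");

        let hashed = Key::Hashed([0xab_u8; 32]);
        assert_eq!(&*hashed, &[0xab_u8; 32][..]);
    }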
}