matrix_sdk_sqlite/
crypto_store.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    borrow::Cow,
17    collections::HashMap,
18    fmt,
19    path::Path,
20    sync::{Arc, RwLock},
21};
22
23use async_trait::async_trait;
24use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
25use matrix_sdk_crypto::{
26    olm::{
27        InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
28        PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
29    },
30    store::{
31        BackupKeys, Changes, CryptoStore, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
32        RoomSettings,
33    },
34    types::events::room_key_withheld::RoomKeyWithheldEvent,
35    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
36};
37use matrix_sdk_store_encryption::StoreCipher;
38use ruma::{
39    events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
40    RoomId, TransactionId, UserId,
41};
42use rusqlite::{named_params, params_from_iter, OptionalExtension};
43use serde::{de::DeserializeOwned, Serialize};
44use tokio::{fs, sync::Mutex};
45use tracing::{debug, instrument, warn};
46use vodozemac::Curve25519PublicKey;
47
48use crate::{
49    error::{Error, Result},
50    utils::{
51        repeat_vars, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
52        SqliteKeyValueStoreConnExt,
53    },
54    OpenStoreError, SqliteStoreConfig,
55};
56
57/// The database name.
58const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
59
60/// A sqlite based cryptostore.
61#[derive(Clone)]
62pub struct SqliteCryptoStore {
63    store_cipher: Option<Arc<StoreCipher>>,
64    pool: SqlitePool,
65
66    // DB values cached in memory
67    static_account: Arc<RwLock<Option<StaticAccountData>>>,
68    save_changes_lock: Arc<Mutex<()>>,
69}
70
71#[cfg(not(tarpaulin_include))]
72impl fmt::Debug for SqliteCryptoStore {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
75    }
76}
77
78impl SqliteCryptoStore {
79    /// Open the sqlite-based crypto store at the given path using the given
80    /// passphrase to encrypt private data.
81    pub async fn open(
82        path: impl AsRef<Path>,
83        passphrase: Option<&str>,
84    ) -> Result<Self, OpenStoreError> {
85        Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await
86    }
87
88    /// Open the sqlite-based crypto store with the config open config.
89    pub async fn open_with_config(config: SqliteStoreConfig) -> Result<Self, OpenStoreError> {
90        let SqliteStoreConfig { path, passphrase, pool_config, runtime_config } = config;
91
92        fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?;
93
94        let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME));
95        config.pool = Some(pool_config);
96
97        let pool = config.create_pool(Runtime::Tokio1)?;
98
99        let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
100        this.pool.get().await?.apply_runtime_config(runtime_config).await?;
101
102        Ok(this)
103    }
104
105    /// Create a sqlite-based crypto store using the given sqlite database pool.
106    /// The given passphrase will be used to encrypt private data.
107    async fn open_with_pool(
108        pool: SqlitePool,
109        passphrase: Option<&str>,
110    ) -> Result<Self, OpenStoreError> {
111        let conn = pool.get().await?;
112
113        let version = conn.db_version().await?;
114        run_migrations(&conn, version).await?;
115
116        let store_cipher = match passphrase {
117            Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
118            None => None,
119        };
120
121        Ok(SqliteCryptoStore {
122            store_cipher,
123            pool,
124            static_account: Arc::new(RwLock::new(None)),
125            save_changes_lock: Default::default(),
126        })
127    }
128
129    fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
130        if let Some(key) = &self.store_cipher {
131            let encrypted = key.encrypt_value_data(value)?;
132            Ok(rmp_serde::to_vec_named(&encrypted)?)
133        } else {
134            Ok(value)
135        }
136    }
137
138    fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
139        if let Some(key) = &self.store_cipher {
140            let encrypted = rmp_serde::from_slice(value)?;
141            let decrypted = key.decrypt_value_data(encrypted)?;
142            Ok(Cow::Owned(decrypted))
143        } else {
144            Ok(Cow::Borrowed(value))
145        }
146    }
147
148    fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
149        let serialized = serde_json::to_vec(value)?;
150        self.encode_value(serialized)
151    }
152
153    fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
154        let decoded = self.decode_value(data)?;
155        Ok(serde_json::from_slice(&decoded)?)
156    }
157
158    fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
159        let serialized = rmp_serde::to_vec_named(value)?;
160        self.encode_value(serialized)
161    }
162
163    fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
164        let decoded = self.decode_value(value)?;
165        Ok(rmp_serde::from_slice(&decoded)?)
166    }
167
168    fn deserialize_and_unpickle_inbound_group_session(
169        &self,
170        value: Vec<u8>,
171        backed_up: bool,
172    ) -> Result<InboundGroupSession> {
173        let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
174
175        // The `backed_up` SQL column is the source of truth, because we update it
176        // inside `mark_inbound_group_sessions_as_backed_up` and don't update
177        // the pickled value inside the `data` column (until now, when we are puling it
178        // out of the DB).
179        pickle.backed_up = backed_up;
180
181        Ok(InboundGroupSession::from_pickle(pickle)?)
182    }
183
184    fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
185        let mut request: GossipRequest = self.deserialize_value(value)?;
186        // sent_out SQL column is source of truth, sent_out field in serialized value
187        // needed for other stores though
188        request.sent_out = sent_out;
189        Ok(request)
190    }
191
192    fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
193        let bytes = key.as_ref();
194        if let Some(store_cipher) = &self.store_cipher {
195            Key::Hashed(store_cipher.hash_key(table_name, bytes))
196        } else {
197            Key::Plain(bytes.to_owned())
198        }
199    }
200
201    fn get_static_account(&self) -> Option<StaticAccountData> {
202        self.static_account.read().unwrap().clone()
203    }
204
205    async fn acquire(&self) -> Result<SqliteAsyncConn> {
206        Ok(self.pool.get().await?)
207    }
208}
209
210const DATABASE_VERSION: u8 = 9;
211
212/// key for the dehydrated device pickle key in the key/value table.
213const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
214
215/// Run migrations for the given version of the database.
216async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
217    if version == 0 {
218        debug!("Creating database");
219    } else if version < DATABASE_VERSION {
220        debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
221    } else {
222        return Ok(());
223    }
224
225    if version < 1 {
226        // First turn on WAL mode, this can't be done in the transaction, it fails with
227        // the error message: "cannot change into wal mode from within a transaction".
228        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
229        conn.with_transaction(|txn| {
230            txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
231            txn.set_db_version(1)
232        })
233        .await?;
234    }
235
236    if version < 2 {
237        conn.with_transaction(|txn| {
238            txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
239            txn.set_db_version(2)
240        })
241        .await?;
242    }
243
244    if version < 3 {
245        conn.with_transaction(|txn| {
246            txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
247            txn.set_db_version(3)
248        })
249        .await?;
250    }
251
252    if version < 4 {
253        conn.with_transaction(|txn| {
254            txn.execute_batch(include_str!(
255                "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
256            ))?;
257            txn.set_db_version(4)
258        })
259        .await?;
260    }
261
262    if version < 5 {
263        conn.with_transaction(|txn| {
264            txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
265            txn.set_db_version(5)
266        })
267        .await?;
268    }
269
270    if version < 6 {
271        conn.with_transaction(|txn| {
272            txn.execute_batch(include_str!(
273                "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
274            ))?;
275            txn.set_db_version(6)
276        })
277        .await?;
278    }
279
280    if version < 7 {
281        conn.with_transaction(|txn| {
282            txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
283            txn.set_db_version(7)
284        })
285        .await?;
286    }
287
288    if version < 8 {
289        conn.with_transaction(|txn| {
290            txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
291            txn.set_db_version(8)
292        })
293        .await?;
294    }
295
296    if version < 9 {
297        conn.with_transaction(|txn| {
298            txn.execute_batch(include_str!(
299                "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
300            ))?;
301            txn.set_db_version(9)
302        })
303        .await?;
304    }
305
306    Ok(())
307}
308
309trait SqliteConnectionExt {
310    fn set_session(
311        &self,
312        session_id: &[u8],
313        sender_key: &[u8],
314        data: &[u8],
315    ) -> rusqlite::Result<()>;
316
317    fn set_inbound_group_session(
318        &self,
319        room_id: &[u8],
320        session_id: &[u8],
321        data: &[u8],
322        backed_up: bool,
323        sender_key: Option<&[u8]>,
324        sender_data_type: Option<u8>,
325    ) -> rusqlite::Result<()>;
326
327    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
328
329    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
330    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
331
332    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
333
334    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
335
336    fn set_key_request(
337        &self,
338        request_id: &[u8],
339        sent_out: bool,
340        data: &[u8],
341    ) -> rusqlite::Result<()>;
342
343    fn set_direct_withheld(
344        &self,
345        session_id: &[u8],
346        room_id: &[u8],
347        data: &[u8],
348    ) -> rusqlite::Result<()>;
349
350    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
351
352    fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
353}
354
355impl SqliteConnectionExt for rusqlite::Connection {
356    fn set_session(
357        &self,
358        session_id: &[u8],
359        sender_key: &[u8],
360        data: &[u8],
361    ) -> rusqlite::Result<()> {
362        self.execute(
363            "INSERT INTO session (session_id, sender_key, data)
364             VALUES (?1, ?2, ?3)
365             ON CONFLICT (session_id) DO UPDATE SET data = ?3",
366            (session_id, sender_key, data),
367        )?;
368        Ok(())
369    }
370
371    fn set_inbound_group_session(
372        &self,
373        room_id: &[u8],
374        session_id: &[u8],
375        data: &[u8],
376        backed_up: bool,
377        sender_key: Option<&[u8]>,
378        sender_data_type: Option<u8>,
379    ) -> rusqlite::Result<()> {
380        self.execute(
381            "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
382             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
383             ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
384            (session_id, room_id, data, backed_up, sender_key, sender_data_type),
385        )?;
386        Ok(())
387    }
388
389    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
390        self.execute(
391            "INSERT INTO outbound_group_session (room_id, data) \
392             VALUES (?1, ?2)
393             ON CONFLICT (room_id) DO UPDATE SET data = ?2",
394            (room_id, data),
395        )?;
396        Ok(())
397    }
398
399    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
400        self.execute(
401            "INSERT INTO device (user_id, device_id, data) \
402             VALUES (?1, ?2, ?3)
403             ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
404            (user_id, device_id, data),
405        )?;
406        Ok(())
407    }
408
409    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
410        self.execute(
411            "DELETE FROM device WHERE user_id = ? AND device_id = ?",
412            (user_id, device_id),
413        )?;
414        Ok(())
415    }
416
417    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
418        self.execute(
419            "INSERT INTO identity (user_id, data) \
420             VALUES (?1, ?2)
421             ON CONFLICT (user_id) DO UPDATE SET data = ?2",
422            (user_id, data),
423        )?;
424        Ok(())
425    }
426
427    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
428        self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
429        Ok(())
430    }
431
432    fn set_key_request(
433        &self,
434        request_id: &[u8],
435        sent_out: bool,
436        data: &[u8],
437    ) -> rusqlite::Result<()> {
438        self.execute(
439            "INSERT INTO key_requests (request_id, sent_out, data)
440            VALUES (?1, ?2, ?3)
441            ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
442            (request_id, sent_out, data),
443        )?;
444        Ok(())
445    }
446
447    fn set_direct_withheld(
448        &self,
449        session_id: &[u8],
450        room_id: &[u8],
451        data: &[u8],
452    ) -> rusqlite::Result<()> {
453        self.execute(
454            "INSERT INTO direct_withheld_info (session_id, room_id, data)
455            VALUES (?1, ?2, ?3)
456            ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
457            (session_id, room_id, data),
458        )?;
459        Ok(())
460    }
461
462    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
463        self.execute(
464            "INSERT INTO room_settings (room_id, data)
465            VALUES (?1, ?2)
466            ON CONFLICT (room_id) DO UPDATE SET data = ?2",
467            (room_id, data),
468        )?;
469        Ok(())
470    }
471
472    fn set_secret(&self, secret_name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
473        self.execute(
474            "INSERT INTO secrets (secret_name, data)
475            VALUES (?1, ?2)",
476            (secret_name, data),
477        )?;
478
479        Ok(())
480    }
481}
482
483#[async_trait]
484trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
485    async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
486        Ok(self
487            .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
488                stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
489            })
490            .await?)
491    }
492
493    async fn get_inbound_group_session(
494        &self,
495        session_id: Key,
496    ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
497        Ok(self
498            .query_row(
499                "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
500                (session_id,),
501                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
502            )
503            .await
504            .optional()?)
505    }
506
507    async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
508        Ok(self
509            .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
510                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
511            })
512            .await?)
513    }
514
515    async fn get_inbound_group_session_counts(
516        &self,
517        _backup_version: Option<&str>,
518    ) -> Result<RoomKeyCounts> {
519        let total = self
520            .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
521            .await?;
522        let backed_up = self
523            .query_row(
524                "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
525                (),
526                |row| row.get(0),
527            )
528            .await?;
529        Ok(RoomKeyCounts { total, backed_up })
530    }
531
532    async fn get_inbound_group_sessions_for_device_batch(
533        &self,
534        sender_key: Key,
535        sender_data_type: SenderDataType,
536        after_session_id: Option<Key>,
537        limit: usize,
538    ) -> Result<Vec<(Vec<u8>, bool)>> {
539        Ok(self
540            .prepare(
541                "
542                SELECT data, backed_up
543                FROM inbound_group_session
544                WHERE sender_key = :sender_key
545                    AND sender_data_type = :sender_data_type
546                    AND session_id > :after_session_id
547                ORDER BY session_id
548                LIMIT :limit
549                ",
550                move |mut stmt| {
551                    let sender_data_type = sender_data_type as u8;
552
553                    // If we are not provided with an `after_session_id`, use a key which will sort
554                    // before all real keys: the empty string.
555                    let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
556
557                    stmt.query(named_params! {
558                        ":sender_key": sender_key,
559                        ":sender_data_type": sender_data_type,
560                        ":after_session_id": after_session_id,
561                        ":limit": limit,
562                    })?
563                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
564                    .collect()
565                },
566            )
567            .await?)
568    }
569
570    async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
571        Ok(self
572            .prepare(
573                "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
574                move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
575            )
576            .await?)
577    }
578
579    async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
580        if session_ids.is_empty() {
581            // We are not expecting to be called with an empty list of sessions
582            warn!("No sessions to mark as backed up!");
583            return Ok(());
584        }
585
586        let session_ids_len = session_ids.len();
587
588        self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
589            // Safety: placeholders is not generated using any user input except the number
590            // of session IDs, so it is safe from injection.
591            let sql_params = repeat_vars(session_ids_len);
592            let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
593            txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
594            Ok(Vec::<()>::new())
595        }).await?;
596
597        Ok(())
598    }
599
600    async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
601        self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
602        Ok(())
603    }
604
605    async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
606        Ok(self
607            .query_row(
608                "SELECT data FROM outbound_group_session WHERE room_id = ?",
609                (room_id,),
610                |row| row.get(0),
611            )
612            .await
613            .optional()?)
614    }
615
616    async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
617        Ok(self
618            .query_row(
619                "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
620                (user_id, device_id),
621                |row| row.get(0),
622            )
623            .await
624            .optional()?)
625    }
626
627    async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
628        Ok(self
629            .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
630                stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
631            })
632            .await?)
633    }
634
635    async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
636        Ok(self
637            .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
638            .await
639            .optional()?)
640    }
641
642    async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
643        Ok(self
644            .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
645                row.get::<_, i32>(0)
646            })
647            .await?
648            > 0)
649    }
650
651    async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
652        Ok(self
653            .prepare("SELECT data FROM tracked_user", |mut stmt| {
654                stmt.query(())?.mapped(|row| row.get(0)).collect()
655            })
656            .await?)
657    }
658
659    async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
660        Ok(self
661            .prepare(
662                "INSERT INTO tracked_user (user_id, data) \
663                 VALUES (?1, ?2) \
664                 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
665                |mut stmt| {
666                    for (user_id, data) in users {
667                        stmt.execute((user_id, data))?;
668                    }
669
670                    Ok(())
671                },
672            )
673            .await?)
674    }
675
676    async fn get_outgoing_secret_request(
677        &self,
678        request_id: Key,
679    ) -> Result<Option<(Vec<u8>, bool)>> {
680        Ok(self
681            .query_row(
682                "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
683                (request_id,),
684                |row| Ok((row.get(0)?, row.get(1)?)),
685            )
686            .await
687            .optional()?)
688    }
689
690    async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
691        Ok(self
692            .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
693                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
694            })
695            .await?)
696    }
697
698    async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
699        Ok(self
700            .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
701                stmt.query(())?.mapped(|row| row.get(0)).collect()
702            })
703            .await?)
704    }
705
706    async fn delete_key_request(&self, request_id: Key) -> Result<()> {
707        self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
708        Ok(())
709    }
710
711    async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
712        Ok(self
713            .prepare("SELECT data FROM secrets WHERE secret_name = ?", |mut stmt| {
714                stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
715            })
716            .await?)
717    }
718
719    async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
720        self.execute("DELETE FROM secrets WHERE secret_name = ?", (secret_name,)).await?;
721        Ok(())
722    }
723
724    async fn get_direct_withheld_info(
725        &self,
726        session_id: Key,
727        room_id: Key,
728    ) -> Result<Option<Vec<u8>>> {
729        Ok(self
730            .query_row(
731                "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
732                (session_id, room_id),
733                |row| row.get(0),
734            )
735            .await
736            .optional()?)
737    }
738
739    async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
740        Ok(self
741            .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
742                row.get(0)
743            })
744            .await
745            .optional()?)
746    }
747}
748
749#[async_trait]
750impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
751
752#[async_trait]
753impl CryptoStore for SqliteCryptoStore {
754    type Error = Error;
755
756    async fn load_account(&self) -> Result<Option<Account>> {
757        let conn = self.acquire().await?;
758        if let Some(pickle) = conn.get_kv("account").await? {
759            let pickle = self.deserialize_value(&pickle)?;
760
761            let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
762
763            *self.static_account.write().unwrap() = Some(account.static_data().clone());
764
765            Ok(Some(account))
766        } else {
767            Ok(None)
768        }
769    }
770
771    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
772        let conn = self.acquire().await?;
773        if let Some(i) = conn.get_kv("identity").await? {
774            let pickle = self.deserialize_value(&i)?;
775            Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
776        } else {
777            Ok(None)
778        }
779    }
780
781    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
782        // Serialize calls to `save_pending_changes`; there are multiple await points
783        // below, and we're pickling data as we go, so we don't want to
784        // invalidate data we've previously read and overwrite it in the store.
785        // TODO: #2000 should make this lock go away, or change its shape.
786        let _guard = self.save_changes_lock.lock().await;
787
788        let pickled_account = if let Some(account) = changes.account {
789            *self.static_account.write().unwrap() = Some(account.static_data().clone());
790            Some(account.pickle())
791        } else {
792            None
793        };
794
795        let this = self.clone();
796        self.acquire()
797            .await?
798            .with_transaction(move |txn| {
799                if let Some(pickled_account) = pickled_account {
800                    let serialized_account = this.serialize_value(&pickled_account)?;
801                    txn.set_kv("account", &serialized_account)?;
802                }
803
804                Ok::<_, Error>(())
805            })
806            .await?;
807
808        Ok(())
809    }
810
811    async fn save_changes(&self, changes: Changes) -> Result<()> {
812        // Serialize calls to `save_changes`; there are multiple await points below, and
813        // we're pickling data as we go, so we don't want to invalidate data
814        // we've previously read and overwrite it in the store.
815        // TODO: #2000 should make this lock go away, or change its shape.
816        let _guard = self.save_changes_lock.lock().await;
817
818        let pickled_private_identity =
819            if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
820
821        let mut session_changes = Vec::new();
822
823        for session in changes.sessions {
824            let session_id = self.encode_key("session", session.session_id());
825            let sender_key = self.encode_key("session", session.sender_key().to_base64());
826            let pickle = session.pickle().await;
827            session_changes.push((session_id, sender_key, pickle));
828        }
829
830        let mut inbound_session_changes = Vec::new();
831        for session in changes.inbound_group_sessions {
832            let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
833            let session_id = self.encode_key("inbound_group_session", session.session_id());
834            let pickle = session.pickle().await;
835            let sender_key =
836                self.encode_key("inbound_group_session", session.sender_key().to_base64());
837            inbound_session_changes.push((room_id, session_id, pickle, sender_key));
838        }
839
840        let mut outbound_session_changes = Vec::new();
841        for session in changes.outbound_group_sessions {
842            let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
843            let pickle = session.pickle().await;
844            outbound_session_changes.push((room_id, pickle));
845        }
846
847        let this = self.clone();
848        self.acquire()
849            .await?
850            .with_transaction(move |txn| {
851                if let Some(pickled_private_identity) = &pickled_private_identity {
852                    let serialized_private_identity =
853                        this.serialize_value(pickled_private_identity)?;
854                    txn.set_kv("identity", &serialized_private_identity)?;
855                }
856
857                if let Some(token) = &changes.next_batch_token {
858                    let serialized_token = this.serialize_value(token)?;
859                    txn.set_kv("next_batch_token", &serialized_token)?;
860                }
861
862                if let Some(decryption_key) = &changes.backup_decryption_key {
863                    let serialized_decryption_key = this.serialize_value(decryption_key)?;
864                    txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
865                }
866
867                if let Some(backup_version) = &changes.backup_version {
868                    let serialized_backup_version = this.serialize_value(backup_version)?;
869                    txn.set_kv("backup_version_v1", &serialized_backup_version)?;
870                }
871
872                if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
873                    let serialized_pickle_key = this.serialize_value(pickle_key)?;
874                    txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
875                }
876
877                for device in changes.devices.new.iter().chain(&changes.devices.changed) {
878                    let user_id = this.encode_key("device", device.user_id().as_bytes());
879                    let device_id = this.encode_key("device", device.device_id().as_bytes());
880                    let data = this.serialize_value(&device)?;
881                    txn.set_device(&user_id, &device_id, &data)?;
882                }
883
884                for device in &changes.devices.deleted {
885                    let user_id = this.encode_key("device", device.user_id().as_bytes());
886                    let device_id = this.encode_key("device", device.device_id().as_bytes());
887                    txn.delete_device(&user_id, &device_id)?;
888                }
889
890                for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
891                    let user_id = this.encode_key("identity", identity.user_id().as_bytes());
892                    let data = this.serialize_value(&identity)?;
893                    txn.set_identity(&user_id, &data)?;
894                }
895
896                for (session_id, sender_key, pickle) in &session_changes {
897                    let serialized_session = this.serialize_value(&pickle)?;
898                    txn.set_session(session_id, sender_key, &serialized_session)?;
899                }
900
901                for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
902                    let serialized_session = this.serialize_value(&pickle)?;
903                    txn.set_inbound_group_session(
904                        room_id,
905                        session_id,
906                        &serialized_session,
907                        pickle.backed_up,
908                        Some(sender_key),
909                        Some(pickle.sender_data.to_type() as u8),
910                    )?;
911                }
912
913                for (room_id, pickle) in &outbound_session_changes {
914                    let serialized_session = this.serialize_json(&pickle)?;
915                    txn.set_outbound_group_session(room_id, &serialized_session)?;
916                }
917
918                for hash in &changes.message_hashes {
919                    let hash = rmp_serde::to_vec(hash)?;
920                    txn.add_olm_hash(&hash)?;
921                }
922
923                for request in changes.key_requests {
924                    let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
925                    let serialized_request = this.serialize_value(&request)?;
926                    txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
927                }
928
929                for (room_id, data) in changes.withheld_session_info {
930                    for (session_id, event) in data {
931                        let session_id = this.encode_key("direct_withheld_info", session_id);
932                        let room_id = this.encode_key("direct_withheld_info", &room_id);
933                        let serialized_info = this.serialize_json(&event)?;
934                        txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
935                    }
936                }
937
938                for (room_id, settings) in changes.room_settings {
939                    let room_id = this.encode_key("room_settings", room_id.as_bytes());
940                    let value = this.serialize_value(&settings)?;
941                    txn.set_room_settings(&room_id, &value)?;
942                }
943
944                for secret in changes.secrets {
945                    let secret_name = this.encode_key("secrets", secret.secret_name.to_string());
946                    let value = this.serialize_json(&secret)?;
947                    txn.set_secret(&secret_name, &value)?;
948                }
949
950                Ok::<_, Error>(())
951            })
952            .await?;
953
954        Ok(())
955    }
956
957    async fn save_inbound_group_sessions(
958        &self,
959        sessions: Vec<InboundGroupSession>,
960        backed_up_to_version: Option<&str>,
961    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
962        // Sanity-check that the data in the sessions corresponds to backed_up_version
963        sessions.iter().for_each(|s| {
964            let backed_up = s.backed_up();
965            if backed_up != backed_up_to_version.is_some() {
966                warn!(
967                    backed_up,
968                    backed_up_to_version,
969                    "Session backed-up flag does not correspond to backup version setting",
970                );
971            }
972        });
973
974        // Currently, this store doesn't save the backup version separately, so this
975        // just delegates to save_changes.
976        self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
977    }
978
979    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
980        let device_keys = self.get_own_device().await?.as_device_keys().clone();
981
982        let sessions: Vec<_> = self
983            .acquire()
984            .await?
985            .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
986            .await?
987            .into_iter()
988            .map(|bytes| {
989                let pickle = self.deserialize_value(&bytes)?;
990                Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
991            })
992            .collect::<Result<_>>()?;
993
994        if sessions.is_empty() {
995            Ok(None)
996        } else {
997            Ok(Some(sessions))
998        }
999    }
1000
1001    #[instrument(skip(self))]
1002    async fn get_inbound_group_session(
1003        &self,
1004        room_id: &RoomId,
1005        session_id: &str,
1006    ) -> Result<Option<InboundGroupSession>> {
1007        let session_id = self.encode_key("inbound_group_session", session_id);
1008        let Some((room_id_from_db, value, backed_up)) =
1009            self.acquire().await?.get_inbound_group_session(session_id).await?
1010        else {
1011            return Ok(None);
1012        };
1013
1014        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1015        if *room_id != room_id_from_db {
1016            warn!("expected room_id for session_id doesn't match what's in the DB");
1017            return Ok(None);
1018        }
1019
1020        Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1021    }
1022
1023    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1024        self.acquire()
1025            .await?
1026            .get_inbound_group_sessions()
1027            .await?
1028            .into_iter()
1029            .map(|(value, backed_up)| {
1030                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1031            })
1032            .collect()
1033    }
1034
1035    async fn get_inbound_group_sessions_for_device_batch(
1036        &self,
1037        sender_key: Curve25519PublicKey,
1038        sender_data_type: SenderDataType,
1039        after_session_id: Option<String>,
1040        limit: usize,
1041    ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1042        let after_session_id =
1043            after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1044        let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1045
1046        self.acquire()
1047            .await?
1048            .get_inbound_group_sessions_for_device_batch(
1049                sender_key,
1050                sender_data_type,
1051                after_session_id,
1052                limit,
1053            )
1054            .await?
1055            .into_iter()
1056            .map(|(value, backed_up)| {
1057                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1058            })
1059            .collect()
1060    }
1061
1062    async fn inbound_group_session_counts(
1063        &self,
1064        backup_version: Option<&str>,
1065    ) -> Result<RoomKeyCounts> {
1066        Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
1067    }
1068
1069    async fn inbound_group_sessions_for_backup(
1070        &self,
1071        _backup_version: &str,
1072        limit: usize,
1073    ) -> Result<Vec<InboundGroupSession>> {
1074        self.acquire()
1075            .await?
1076            .get_inbound_group_sessions_for_backup(limit)
1077            .await?
1078            .into_iter()
1079            .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1080            .collect()
1081    }
1082
1083    async fn mark_inbound_group_sessions_as_backed_up(
1084        &self,
1085        _backup_version: &str,
1086        session_ids: &[(&RoomId, &str)],
1087    ) -> Result<()> {
1088        Ok(self
1089            .acquire()
1090            .await?
1091            .mark_inbound_group_sessions_as_backed_up(
1092                session_ids
1093                    .iter()
1094                    .map(|(_, s)| self.encode_key("inbound_group_session", s))
1095                    .collect(),
1096            )
1097            .await?)
1098    }
1099
1100    async fn reset_backup_state(&self) -> Result<()> {
1101        Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
1102    }
1103
1104    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1105        let conn = self.acquire().await?;
1106
1107        let backup_version = conn
1108            .get_kv("backup_version_v1")
1109            .await?
1110            .map(|value| self.deserialize_value(&value))
1111            .transpose()?;
1112
1113        let decryption_key = conn
1114            .get_kv("recovery_key_v1")
1115            .await?
1116            .map(|value| self.deserialize_value(&value))
1117            .transpose()?;
1118
1119        Ok(BackupKeys { backup_version, decryption_key })
1120    }
1121
1122    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1123        let conn = self.acquire().await?;
1124
1125        conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1126            .await?
1127            .map(|value| self.deserialize_value(&value))
1128            .transpose()
1129    }
1130
1131    async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1132        let conn = self.acquire().await?;
1133        conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
1134
1135        Ok(())
1136    }
1137    async fn get_outbound_group_session(
1138        &self,
1139        room_id: &RoomId,
1140    ) -> Result<Option<OutboundGroupSession>> {
1141        let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1142        let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
1143            return Ok(None);
1144        };
1145
1146        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1147
1148        let pickle = self.deserialize_json(&value)?;
1149        let session = OutboundGroupSession::from_pickle(
1150            account_info.device_id,
1151            account_info.identity_keys,
1152            pickle,
1153        )
1154        .map_err(|_| Error::Unpickle)?;
1155
1156        return Ok(Some(session));
1157    }
1158
1159    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1160        self.acquire()
1161            .await?
1162            .get_tracked_users()
1163            .await?
1164            .iter()
1165            .map(|value| self.deserialize_value(value))
1166            .collect()
1167    }
1168
1169    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1170        let users: Vec<(Key, Vec<u8>)> = tracked_users
1171            .iter()
1172            .map(|(u, d)| {
1173                let user_id = self.encode_key("tracked_users", u.as_bytes());
1174                let data =
1175                    self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1176                Ok((user_id, data))
1177            })
1178            .collect::<Result<_>>()?;
1179
1180        Ok(self.acquire().await?.add_tracked_users(users).await?)
1181    }
1182
1183    async fn get_device(
1184        &self,
1185        user_id: &UserId,
1186        device_id: &DeviceId,
1187    ) -> Result<Option<DeviceData>> {
1188        let user_id = self.encode_key("device", user_id.as_bytes());
1189        let device_id = self.encode_key("device", device_id.as_bytes());
1190        Ok(self
1191            .acquire()
1192            .await?
1193            .get_device(user_id, device_id)
1194            .await?
1195            .map(|value| self.deserialize_value(&value))
1196            .transpose()?)
1197    }
1198
1199    async fn get_user_devices(
1200        &self,
1201        user_id: &UserId,
1202    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1203        let user_id = self.encode_key("device", user_id.as_bytes());
1204        self.acquire()
1205            .await?
1206            .get_user_devices(user_id)
1207            .await?
1208            .into_iter()
1209            .map(|value| {
1210                let device: DeviceData = self.deserialize_value(&value)?;
1211                Ok((device.device_id().to_owned(), device))
1212            })
1213            .collect()
1214    }
1215
1216    async fn get_own_device(&self) -> Result<DeviceData> {
1217        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1218
1219        Ok(self
1220            .get_device(&account_info.user_id, &account_info.device_id)
1221            .await?
1222            .expect("We should be able to find our own device."))
1223    }
1224
1225    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1226        let user_id = self.encode_key("identity", user_id.as_bytes());
1227        Ok(self
1228            .acquire()
1229            .await?
1230            .get_user_identity(user_id)
1231            .await?
1232            .map(|value| self.deserialize_value(&value))
1233            .transpose()?)
1234    }
1235
1236    async fn is_message_known(
1237        &self,
1238        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1239    ) -> Result<bool> {
1240        let value = rmp_serde::to_vec(message_hash)?;
1241        Ok(self.acquire().await?.has_olm_hash(value).await?)
1242    }
1243
1244    async fn get_outgoing_secret_requests(
1245        &self,
1246        request_id: &TransactionId,
1247    ) -> Result<Option<GossipRequest>> {
1248        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1249        Ok(self
1250            .acquire()
1251            .await?
1252            .get_outgoing_secret_request(request_id)
1253            .await?
1254            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1255            .transpose()?)
1256    }
1257
1258    async fn get_secret_request_by_info(
1259        &self,
1260        key_info: &SecretInfo,
1261    ) -> Result<Option<GossipRequest>> {
1262        let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
1263        for (request, sent_out) in requests {
1264            let request = self.deserialize_key_request(&request, sent_out)?;
1265            if request.info == *key_info {
1266                return Ok(Some(request));
1267            }
1268        }
1269        Ok(None)
1270    }
1271
1272    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1273        self.acquire()
1274            .await?
1275            .get_unsent_secret_requests()
1276            .await?
1277            .iter()
1278            .map(|value| {
1279                let request = self.deserialize_key_request(value, false)?;
1280                Ok(request)
1281            })
1282            .collect()
1283    }
1284
1285    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1286        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1287        Ok(self.acquire().await?.delete_key_request(request_id).await?)
1288    }
1289
1290    async fn get_secrets_from_inbox(
1291        &self,
1292        secret_name: &SecretName,
1293    ) -> Result<Vec<GossippedSecret>> {
1294        let secret_name = self.encode_key("secrets", secret_name.to_string());
1295
1296        self.acquire()
1297            .await?
1298            .get_secrets_from_inbox(secret_name)
1299            .await?
1300            .into_iter()
1301            .map(|value| self.deserialize_json(value.as_ref()))
1302            .collect()
1303    }
1304
1305    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1306        let secret_name = self.encode_key("secrets", secret_name.to_string());
1307        self.acquire().await?.delete_secrets_from_inbox(secret_name).await
1308    }
1309
1310    async fn get_withheld_info(
1311        &self,
1312        room_id: &RoomId,
1313        session_id: &str,
1314    ) -> Result<Option<RoomKeyWithheldEvent>> {
1315        let room_id = self.encode_key("direct_withheld_info", room_id);
1316        let session_id = self.encode_key("direct_withheld_info", session_id);
1317
1318        self.acquire()
1319            .await?
1320            .get_direct_withheld_info(session_id, room_id)
1321            .await?
1322            .map(|value| {
1323                let info = self.deserialize_json::<RoomKeyWithheldEvent>(&value)?;
1324                Ok(info)
1325            })
1326            .transpose()
1327    }
1328
1329    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1330        let room_id = self.encode_key("room_settings", room_id.as_bytes());
1331        let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
1332            return Ok(None);
1333        };
1334
1335        let settings = self.deserialize_value(&value)?;
1336
1337        return Ok(Some(settings));
1338    }
1339
1340    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1341        let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
1342            return Ok(None);
1343        };
1344        let value = if let Some(cipher) = &self.store_cipher {
1345            let encrypted = rmp_serde::from_slice(&serialized)?;
1346            cipher.decrypt_value_data(encrypted)?
1347        } else {
1348            serialized
1349        };
1350
1351        Ok(Some(value))
1352    }
1353
1354    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1355        let serialized = if let Some(cipher) = &self.store_cipher {
1356            let encrypted = cipher.encrypt_value_data(value)?;
1357            rmp_serde::to_vec_named(&encrypted)?
1358        } else {
1359            value
1360        };
1361
1362        self.acquire().await?.set_kv(key, serialized).await?;
1363        Ok(())
1364    }
1365
1366    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1367        let key = key.to_owned();
1368        self.acquire()
1369            .await?
1370            .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1371            .await
1372            .unwrap()?;
1373        Ok(())
1374    }
1375
1376    async fn try_take_leased_lock(
1377        &self,
1378        lease_duration_ms: u32,
1379        key: &str,
1380        holder: &str,
1381    ) -> Result<bool> {
1382        let key = key.to_owned();
1383        let holder = holder.to_owned();
1384
1385        let now_ts: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
1386        let expiration_ts = now_ts + lease_duration_ms as u64;
1387
1388        let num_touched = self
1389            .acquire()
1390            .await?
1391            .with_transaction(move |txn| {
1392                txn.execute(
1393                    "INSERT INTO lease_locks (key, holder, expiration_ts)
1394                    VALUES (?1, ?2, ?3)
1395                    ON CONFLICT (key)
1396                    DO
1397                        UPDATE SET holder = ?2, expiration_ts = ?3
1398                        WHERE holder = ?2
1399                        OR expiration_ts < ?4
1400                ",
1401                    (key, holder, expiration_ts, now_ts),
1402                )
1403            })
1404            .await?;
1405
1406        Ok(num_touched == 1)
1407    }
1408
1409    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1410        let conn = self.acquire().await?;
1411        if let Some(token) = conn.get_kv("next_batch_token").await? {
1412            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1413            Ok(maybe_token)
1414        } else {
1415            Ok(None)
1416        }
1417    }
1418}
1419
1420#[cfg(test)]
1421mod tests {
1422    use std::path::Path;
1423
1424    use matrix_sdk_common::deserialized_responses::WithheldCode;
1425    use matrix_sdk_crypto::{
1426        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1427        store::CryptoStore,
1428    };
1429    use matrix_sdk_test::async_test;
1430    use once_cell::sync::Lazy;
1431    use ruma::{device_id, room_id, user_id};
1432    use similar_asserts::assert_eq;
1433    use tempfile::{tempdir, TempDir};
1434    use tokio::fs;
1435
1436    use super::SqliteCryptoStore;
1437    use crate::SqliteStoreConfig;
1438
    /// Shared temporary directory for tests that only need a fresh, empty store
    /// location (e.g. `test_pool_size`). It is created lazily on first use and
    /// panics if the directory cannot be created.
    // NOTE(review): statics are never dropped, so this directory is not removed
    // when the test process exits.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// A store opened on a private temporary copy of a test-vector database.
    struct TestDb {
        // Needs to be kept alive because the Drop implementation for TempDir deletes the
        // directory.
        _dir: TempDir,
        /// The store opened on the copied database file.
        database: SqliteCryptoStore,
    }
1447
1448    fn copy_db(data_path: &str) -> TempDir {
1449        let db_name = super::DATABASE_NAME;
1450
1451        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1452        let database_path = manifest_path.join(data_path).join(db_name);
1453
1454        let tmpdir = tempdir().unwrap();
1455        let destination = tmpdir.path().join(db_name);
1456
1457        // Copy the test database to the tempdir so our test runs are idempotent.
1458        std::fs::copy(&database_path, destination).unwrap();
1459
1460        tmpdir
1461    }
1462
1463    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1464        let tmpdir = copy_db(data_path);
1465
1466        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1467            .await
1468            .expect("Can't open the test store");
1469
1470        TestDb { _dir: tmpdir, database }
1471    }
1472
1473    #[async_test]
1474    async fn test_pool_size() {
1475        let store_open_config =
1476            SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);
1477
1478        let store = SqliteCryptoStore::open_with_config(store_open_config).await.unwrap();
1479
1480        assert_eq!(store.pool.status().max_size, 42);
1481    }
1482
1483    /// Test that we didn't regress in our storage layer by loading data from a
1484    /// pre-filled database, or in other words use a test vector for this.
1485    #[async_test]
1486    async fn test_open_test_vector_store() {
1487        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;
1488
1489        let account = database
1490            .load_account()
1491            .await
1492            .unwrap()
1493            .expect("The test database is prefilled with data, we should find an account");
1494
1495        let user_id = account.user_id();
1496        let device_id = account.device_id();
1497
1498        assert_eq!(
1499            user_id.as_str(),
1500            "@pjtest:synapse-oidc.element.dev",
1501            "The user ID should match to the one we expect."
1502        );
1503
1504        assert_eq!(
1505            device_id.as_str(),
1506            "v4TqgcuIH6",
1507            "The device ID should match to the one we expect."
1508        );
1509
1510        let device = database
1511            .get_device(user_id, device_id)
1512            .await
1513            .unwrap()
1514            .expect("Our own device should be found in the store.");
1515
1516        assert_eq!(device.device_id(), device_id);
1517        assert_eq!(device.user_id(), user_id);
1518
1519        assert_eq!(
1520            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
1521            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
1522        );
1523
1524        assert_eq!(
1525            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
1526            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
1527        );
1528
1529        let identity = database
1530            .get_user_identity(user_id)
1531            .await
1532            .unwrap()
1533            .expect("The store should contain an identity.");
1534
1535        assert_eq!(identity.user_id(), user_id);
1536
1537        let identity = identity
1538            .own()
1539            .expect("The identity should be of the correct type, it should be our own identity.");
1540
1541        let master_key = identity
1542            .master_key()
1543            .get_first_key()
1544            .expect("Our own identity should have a master key");
1545
1546        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
1547    }
1548
1549    /// Test that we didn't regress in our storage layer by loading data from a
1550    /// pre-filled database, or in other words use a test vector for this.
1551    #[async_test]
1552    async fn test_open_test_vector_encrypted_store() {
1553        let TestDb { _dir: _, database } = get_test_db(
1554            "testing/data/storage/alice",
1555            Some(concat!(
1556                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
1557                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
1558                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
1559                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
1560                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
1561            )),
1562        )
1563        .await;
1564
1565        let account = database
1566            .load_account()
1567            .await
1568            .unwrap()
1569            .expect("The test database is prefilled with data, we should find an account");
1570
1571        let user_id = account.user_id();
1572        let device_id = account.device_id();
1573
1574        assert_eq!(
1575            user_id.as_str(),
1576            "@alice:localhost",
1577            "The user ID should match to the one we expect."
1578        );
1579
1580        assert_eq!(
1581            device_id.as_str(),
1582            "JVVORTHFXY",
1583            "The device ID should match to the one we expect."
1584        );
1585
1586        let tracked_users =
1587            database.load_tracked_users().await.expect("Should be tracking some users");
1588
1589        assert_eq!(tracked_users.len(), 6);
1590
1591        let known_users = vec![
1592            user_id!("@alice:localhost"),
1593            user_id!("@dehydration3:localhost"),
1594            user_id!("@eve:localhost"),
1595            user_id!("@bob:localhost"),
1596            user_id!("@malo:localhost"),
1597            user_id!("@carl:localhost"),
1598        ];
1599
1600        // load the identities
1601        for user_id in known_users {
1602            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
1603        }
1604
1605        let carl_identity =
1606            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
1607
1608        assert_eq!(
1609            carl_identity.master_key().get_first_key().unwrap().to_base64(),
1610            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
1611        );
1612        assert!(!carl_identity.was_previously_verified());
1613
1614        let bob_identity =
1615            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
1616
1617        assert_eq!(
1618            bob_identity.master_key().get_first_key().unwrap().to_base64(),
1619            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
1620        );
1621        // Bob is verified so this flag should be set
1622        assert!(bob_identity.was_previously_verified());
1623
1624        let known_devices = vec![
1625            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
1626            // a dehydrated one
1627            (
1628                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
1629                user_id!("@alice:localhost"),
1630            ),
1631            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
1632            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
1633            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
1634            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
1635            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
1636            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1637            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1638            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1639            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1640            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
1641            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
1642            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
1643        ];
1644
1645        for (device_id, user_id) in known_devices {
1646            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
1647        }
1648
1649        let known_sender_key_to_session_count = vec![
1650            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
1651            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
1652            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
1653            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
1654            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
1655            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
1656            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
1657            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
1658            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
1659            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
1660            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
1661            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
1662            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
1663            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
1664            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
1665        ];
1666
1667        for (id, count) in known_sender_key_to_session_count {
1668            let olm_sessions =
1669                database.get_sessions(id).await.expect("Should have some olm sessions");
1670
1671            println!("### Session id: {:?}", id);
1672            assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
1673        }
1674
1675        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
1676        assert_eq!(inbound_group_sessions.len(), 15);
1677        let known_inbound_group_sessions = vec![
1678            (
1679                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
1680                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1681            ),
1682            (
1683                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
1684                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1685            ),
1686            (
1687                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1688                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1689            ),
1690            (
1691                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
1692                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1693            ),
1694            (
1695                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
1696                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1697            ),
1698            (
1699                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
1700                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1701            ),
1702            (
1703                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
1704                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1705            ),
1706            (
1707                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
1708                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1709            ),
1710            (
1711                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
1712                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1713            ),
1714            (
1715                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1716                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1717            ),
1718            (
1719                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
1720                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1721            ),
1722            (
1723                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
1724                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1725            ),
1726            (
1727                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
1728                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1729            ),
1730            (
1731                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1732                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1733            ),
1734            (
1735                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
1736                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1737            ),
1738        ];
1739
1740        // ensure we can load them all
1741        for (session_id, room_id) in &known_inbound_group_sessions {
1742            database
1743                .get_inbound_group_session(room_id, session_id)
1744                .await
1745                .expect("Should be able to load inbound group session")
1746                .unwrap();
1747        }
1748
1749        let bob_sender_verified = database
1750            .get_inbound_group_session(
1751                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1752                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1753            )
1754            .await
1755            .unwrap()
1756            .unwrap();
1757
1758        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
1759        assert!(bob_sender_verified.backed_up());
1760        assert!(!bob_sender_verified.has_been_imported());
1761
1762        let alice_unknown_device = database
1763            .get_inbound_group_session(
1764                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1765                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1766            )
1767            .await
1768            .unwrap()
1769            .unwrap();
1770
1771        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
1772        assert!(alice_unknown_device.backed_up());
1773        assert!(alice_unknown_device.has_been_imported());
1774
1775        let carl_tofu_session = database
1776            .get_inbound_group_session(
1777                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1778                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1779            )
1780            .await
1781            .unwrap()
1782            .unwrap();
1783
1784        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
1785        assert!(carl_tofu_session.backed_up());
1786        assert!(!carl_tofu_session.has_been_imported());
1787
1788        // Load outbound sessions
1789        database
1790            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
1791            .await
1792            .unwrap()
1793            .unwrap();
1794        database
1795            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
1796            .await
1797            .unwrap()
1798            .unwrap();
1799        database
1800            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
1801            .await
1802            .unwrap()
1803            .unwrap();
1804
1805        let withheld_info = database
1806            .get_withheld_info(
1807                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1808                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
1809            )
1810            .await
1811            .expect("This session should be withheld")
1812            .unwrap();
1813
1814        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
1815
1816        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
1817        assert_eq!(backup_keys.backup_version.unwrap(), "6");
1818        assert!(backup_keys.decryption_key.is_some());
1819    }
1820
1821    async fn get_store(
1822        name: &str,
1823        passphrase: Option<&str>,
1824        clear_data: bool,
1825    ) -> SqliteCryptoStore {
1826        let tmpdir_path = TMP_DIR.path().join(name);
1827
1828        if clear_data {
1829            let _ = fs::remove_dir_all(&tmpdir_path).await;
1830        }
1831
1832        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
1833            .await
1834            .expect("Can't create a passphrase protected store")
1835    }
1836
    // Expand the shared crypto-store integration test suites from `matrix_sdk_crypto`
    // into this module.
    // NOTE(review): the macros presumably build their stores through the local
    // `get_store` helper — confirm against the macro definitions.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
1839}
1840
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Temporary directory shared by every store created in this module; each store
    /// lives in its own subdirectory named after the store.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Opens the store called `name`, always with encryption enabled.
    ///
    /// When `clear_data` is true, the store's directory is wiped first. When no
    /// `passphrase` is supplied, the fixed fallback `"default_test_password"` is used, so
    /// every store opened by this module is encrypted.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or if the store cannot be opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        if clear_data {
            // Errors are deliberately ignored (e.g. the directory may not exist yet).
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        let effective_passphrase = passphrase.or(Some("default_test_password"));
        let store_path = store_dir.to_str().unwrap();

        SqliteCryptoStore::open(store_path, effective_passphrase)
            .await
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}