Skip to main content

matrix_sdk_sqlite/
crypto_store.rs

1// Copyright 2022, 2026 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    fmt,
18    ops::Deref,
19    path::{Path, PathBuf},
20    sync::{Arc, RwLock},
21};
22
23use async_trait::async_trait;
24use deadpool::managed::PoolConfig;
25use matrix_sdk_base::cross_process_lock::CrossProcessLockGeneration;
26use matrix_sdk_crypto::{
27    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
28    olm::{
29        InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
30        PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
31    },
32    store::{
33        CryptoStore,
34        types::{
35            BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
36            RoomKeyWithheldEntry, RoomPendingKeyBundleDetails, RoomSettings,
37            StoredRoomKeyBundleData,
38        },
39    },
40};
41use matrix_sdk_store_encryption::StoreCipher;
42use ruma::{
43    DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, RoomId, TransactionId, UserId,
44    events::secret::request::SecretName,
45};
46use rusqlite::{OptionalExtension, named_params, params_from_iter};
47use tokio::{
48    fs,
49    sync::{Mutex, OwnedMutexGuard},
50};
51use tracing::{debug, instrument, warn};
52use vodozemac::Curve25519PublicKey;
53use zeroize::Zeroizing;
54
55use crate::{
56    OpenStoreError, RuntimeConfig, Secret, SqliteStoreConfig,
57    connection::{self, Connection as SqliteAsyncConn, Pool as SqlitePool, SqliteConnections},
58    error::{Error, Result},
59    utils::{
60        EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
61        SqliteKeyValueStoreConnExt, repeat_vars,
62    },
63};
64
/// File name of the crypto store's SQLite database, handed to
/// `SqliteStoreConfig::build_pool_of_connections` when the store is opened.
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
67
68/// An SQLite-based crypto store.
69#[derive(Clone)]
70pub struct SqliteCryptoStore {
71    store_cipher: Option<Arc<StoreCipher>>,
72
73    /// `Some` when active, `None` when closed.
74    /// The outer `Mutex` serialises close/reopen with connection access.
75    connections: Arc<Mutex<Option<SqliteConnections>>>,
76
77    /// Retained so we can rebuild the pool on reopen.
78    db_path: PathBuf,
79
80    /// Retained so we can rebuild the pool on reopen.
81    pool_config: PoolConfig,
82
83    /// Retained so we can re-apply runtime config on reopen.
84    runtime_config: RuntimeConfig,
85
86    // DB values cached in memory
87    static_account: Arc<RwLock<Option<StaticAccountData>>>,
88    save_changes_lock: Arc<Mutex<()>>,
89}
90
91#[cfg(not(tarpaulin_include))]
92impl fmt::Debug for SqliteCryptoStore {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
95    }
96}
97
98impl EncryptableStore for SqliteCryptoStore {
99    fn get_cypher(&self) -> Option<&StoreCipher> {
100        self.store_cipher.as_deref()
101    }
102}
103
104impl SqliteCryptoStore {
105    /// Create an `SqliteCryptoStore` struct without trying to create the
106    /// database or migrate to a newer version.  This is only for use
107    /// internally, and for testing.
108    ///
109    /// # Arguments
110    ///
111    /// * `secret` - The secret used to encrypt the data.
112    ///
113    /// * `pool` - A connection pool to use for reading from the store.
114    ///
115    /// * `conn` - The connection to use for writing to the store.
116    pub(crate) async fn create_raw(
117        secret: Option<Secret>,
118        pool: SqlitePool,
119        conn: SqliteAsyncConn,
120        pool_config: PoolConfig,
121        runtime_config: RuntimeConfig,
122    ) -> Result<Self, OpenStoreError> {
123        let store_cipher = match secret {
124            Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s).await?)),
125            None => None,
126        };
127
128        let db_path = pool.manager().database_path.clone();
129
130        Ok(Self {
131            store_cipher,
132            connections: Arc::new(Mutex::new(Some(SqliteConnections {
133                pool,
134                write_connection: Arc::new(Mutex::new(conn)),
135            }))),
136            db_path,
137            pool_config,
138            runtime_config,
139            static_account: Arc::new(RwLock::new(None)),
140            save_changes_lock: Default::default(),
141        })
142    }
143
144    /// Open the SQLite-based crypto store at the given path using the given
145    /// passphrase to encrypt private data.
146    pub async fn open(
147        path: impl AsRef<Path>,
148        passphrase: Option<&str>,
149    ) -> Result<Self, OpenStoreError> {
150        Self::open_with_config(&SqliteStoreConfig::new(path).passphrase(passphrase)).await
151    }
152
153    /// Open the SQLite-based crypto store at the given path using the given
154    /// key to encrypt private data.
155    pub async fn open_with_key(
156        path: impl AsRef<Path>,
157        key: Option<&[u8; 32]>,
158    ) -> Result<Self, OpenStoreError> {
159        Self::open_with_config(&SqliteStoreConfig::new(path).key(key)).await
160    }
161
162    /// Open the SQLite-based crypto store with the config open config.
163    pub async fn open_with_config(config: &SqliteStoreConfig) -> Result<Self, OpenStoreError> {
164        fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir)?;
165
166        let pool = config.build_pool_of_connections(DATABASE_NAME)?;
167        let pool_config = config.pool_config();
168        let runtime_config = config.runtime_config();
169
170        let this =
171            Self::open_with_pool(pool, config.secret.clone(), pool_config, runtime_config).await?;
172        this.read().await?.apply_runtime_config(runtime_config).await?;
173
174        Ok(this)
175    }
176
177    /// Create an SQLite-based crypto store using the given SQLite database
178    /// pool. The given secret will be used to encrypt private data.
179    async fn open_with_pool(
180        pool: SqlitePool,
181        secret: Option<Secret>,
182        pool_config: PoolConfig,
183        runtime_config: RuntimeConfig,
184    ) -> Result<Self, OpenStoreError> {
185        let conn = pool.get().await?;
186
187        let version = conn.db_version().await?;
188        debug!("Opened sqlite store with version {}", version);
189
190        let version = initialize_store(&conn, version).await?;
191
192        let store = Self::create_raw(secret, pool, conn, pool_config, runtime_config).await?;
193
194        run_migrations(&store, version, None).await?;
195
196        store.write().await?.wal_checkpoint().await;
197
198        Ok(store)
199    }
200
201    fn deserialize_and_unpickle_inbound_group_session(
202        &self,
203        value: Vec<u8>,
204        backed_up: bool,
205    ) -> Result<InboundGroupSession> {
206        let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
207
208        // The `backed_up` SQL column is the source of truth, because we update it
209        // inside `mark_inbound_group_sessions_as_backed_up` and don't update
210        // the pickled value inside the `data` column (until now, when we are puling it
211        // out of the DB).
212        pickle.backed_up = backed_up;
213
214        Ok(InboundGroupSession::from_pickle(pickle)?)
215    }
216
217    fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
218        let mut request: GossipRequest = self.deserialize_value(value)?;
219        // sent_out SQL column is source of truth, sent_out field in serialized value
220        // needed for other stores though
221        request.sent_out = sent_out;
222        Ok(request)
223    }
224
225    fn get_static_account(&self) -> Option<StaticAccountData> {
226        self.static_account.read().unwrap().clone()
227    }
228
229    /// Acquire a connection for executing read operations.
230    #[instrument(skip_all)]
231    async fn read(&self) -> Result<SqliteAsyncConn> {
232        let pool = {
233            let guard = self.connections.lock().await;
234            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
235            conns.pool.clone()
236        };
237        Ok(pool.get().await?)
238    }
239
240    /// Acquire a connection for executing write operations.
241    #[instrument(skip_all)]
242    pub(crate) async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
243        let write_connection = {
244            let guard = self.connections.lock().await;
245            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
246            conns.write_connection.clone()
247        };
248        Ok(write_connection.lock_owned().await)
249    }
250}
251
/// Schema version the crypto store ends up at after `run_migrations`.
///
/// Must equal the number of the last migration step in `run_migrations`
/// (currently `017_*`). It is compared against the on-disk version in
/// `initialize_store`.
const DATABASE_VERSION: u8 = 17;

/// Key for the dehydrated device pickle key in the key/value table.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
256
257/// Initialize the database to version 1
258///
259/// This must be done before creating the store cipher, because the store cipher
260/// requires the `kv` table.
261///
262/// # Arguments
263///
264/// * `conn` - The connection to use.
265///
266/// * `version` - the current version of the database.
267pub(crate) async fn initialize_store(conn: &SqliteAsyncConn, version: u8) -> Result<u8> {
268    if version == 0 {
269        debug!("Creating database");
270    } else if version < DATABASE_VERSION {
271        debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
272    } else {
273        return Ok(version);
274    }
275
276    if version < 1 {
277        debug!("Creating database");
278        // First turn on WAL mode, this can't be done in the transaction, it fails with
279        // the error message: "cannot change into wal mode from within a transaction".
280        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
281        conn.with_transaction(|txn| {
282            txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
283            txn.set_db_version(1)
284        })
285        .await?;
286        return Ok(1);
287    }
288
289    Ok(version)
290}
291
292/// Run migrations for the given version of the database.
293///
294/// # Arguments
295///
296/// * `store` - The store to run the migrations on
297///
298/// * `version` - The current version of the database.
299///
300/// * `max_version` - The maximum version that the database will be migrated to.
301///   Only used for testing, so will only be checked for the versions that are
302///   needed for tests.
303pub(crate) async fn run_migrations(
304    store: &SqliteCryptoStore,
305    version: u8,
306    max_version: Option<u8>,
307) -> Result<()> {
308    let conn = store.write().await?;
309
310    if version < 2 {
311        debug!("Upgrading database to version 2");
312        conn.with_transaction(|txn| {
313            txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
314            txn.set_db_version(2)
315        })
316        .await?;
317    }
318
319    if version < 3 {
320        debug!("Upgrading database to version 3");
321        conn.with_transaction(|txn| {
322            txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
323            txn.set_db_version(3)
324        })
325        .await?;
326    }
327
328    if version < 4 {
329        debug!("Upgrading database to version 4");
330        conn.with_transaction(|txn| {
331            txn.execute_batch(include_str!(
332                "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
333            ))?;
334            txn.set_db_version(4)
335        })
336        .await?;
337    }
338
339    if version < 5 {
340        debug!("Upgrading database to version 5");
341        conn.with_transaction(|txn| {
342            txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
343            txn.set_db_version(5)
344        })
345        .await?;
346    }
347
348    if version < 6 {
349        debug!("Upgrading database to version 6");
350        conn.with_transaction(|txn| {
351            txn.execute_batch(include_str!(
352                "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
353            ))?;
354            txn.set_db_version(6)
355        })
356        .await?;
357    }
358
359    if version < 7 {
360        debug!("Upgrading database to version 7");
361        conn.with_transaction(|txn| {
362            txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
363            txn.set_db_version(7)
364        })
365        .await?;
366    }
367
368    if version < 8 {
369        debug!("Upgrading database to version 8");
370        conn.with_transaction(|txn| {
371            txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
372            txn.set_db_version(8)
373        })
374        .await?;
375    }
376
377    if version < 9 {
378        debug!("Upgrading database to version 9");
379        conn.with_transaction(|txn| {
380            txn.execute_batch(include_str!(
381                "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
382            ))?;
383            txn.set_db_version(9)
384        })
385        .await?;
386    }
387
388    if version < 10 {
389        debug!("Upgrading database to version 10");
390        conn.with_transaction(|txn| {
391            txn.execute_batch(include_str!(
392                "../migrations/crypto_store/010_received_room_key_bundles.sql"
393            ))?;
394            txn.set_db_version(10)
395        })
396        .await?;
397    }
398
399    if version < 11 {
400        debug!("Upgrading database to version 11");
401        conn.with_transaction(|txn| {
402            txn.execute_batch(include_str!(
403                "../migrations/crypto_store/011_received_room_key_bundles_with_curve_key.sql"
404            ))?;
405            txn.set_db_version(11)
406        })
407        .await?;
408    }
409
410    if version < 12 {
411        debug!("Upgrading database to version 12");
412        conn.with_transaction(|txn| {
413            txn.execute_batch(include_str!(
414                "../migrations/crypto_store/012_withheld_code_by_room.sql"
415            ))?;
416            txn.set_db_version(12)
417        })
418        .await?;
419    }
420
421    if version < 13 {
422        debug!("Upgrading database to version 13");
423        conn.with_transaction(|txn| {
424            txn.execute_batch(include_str!(
425                "../migrations/crypto_store/013_lease_locks_with_generation.sql"
426            ))?;
427            txn.set_db_version(13)
428        })
429        .await?;
430    }
431
432    if version < 14 {
433        debug!("Upgrading database to version 14");
434        conn.with_transaction(|txn| {
435            txn.execute_batch(include_str!(
436                "../migrations/crypto_store/014_room_key_backups_fully_downloaded.sql"
437            ))?;
438            txn.set_db_version(14)
439        })
440        .await?;
441    }
442
443    if version < 15 {
444        debug!("Upgrading database to version 15");
445        conn.with_transaction(|txn| {
446            txn.execute_batch(include_str!(
447                "../migrations/crypto_store/015_rooms_pending_key_bundle.sql"
448            ))?;
449            txn.set_db_version(15)
450        })
451        .await?;
452    }
453
454    if version < 16 {
455        debug!("Upgrading database to version 16");
456        conn.with_transaction(|txn| {
457            txn.execute_batch(include_str!(
458                "../migrations/crypto_store/016_remove_old_generation_counter.sql"
459            ))?;
460            txn.set_db_version(16)
461        })
462        .await?;
463    }
464
465    if max_version.is_some_and(|max_version| max_version < 17) {
466        return Ok(());
467    }
468
469    if version < 17 {
470        let store = store.clone();
471        conn.with_transaction(move |txn| {
472            txn.execute_batch(include_str!(
473                "../migrations/crypto_store/017_add_new_secrets_inbox.sql"
474            ))?;
475            let mut select_query = txn.prepare("SELECT data FROM secrets")?;
476            let mut secrets = select_query.query([])?;
477            let mut insert_query = txn.prepare(
478                "INSERT OR IGNORE INTO secrets_inbox (secret_name, secret)
479            VALUES (?1, ?2)",
480            )?;
481            while let Some(row) = secrets.next()? {
482                let Ok(secret) =
483                    store.deserialize_json::<GossippedSecret>(row.get::<_, Vec<u8>>(0)?.as_ref())
484                else {
485                    continue;
486                };
487                let Ok(encoded_secret) = store.serialize_json(&secret.event.content.secret) else {
488                    continue;
489                };
490                insert_query.execute((
491                    store.encode_key("secrets_inbox", secret.secret_name.to_string()),
492                    &encoded_secret,
493                ))?;
494            }
495            txn.execute_batch(include_str!(
496                "../migrations/crypto_store/017_drop_old_secrets_inbox.sql"
497            ))?;
498            txn.set_db_version(17)
499        })
500        .await?;
501    }
502
503    Ok(())
504}
505
506trait SqliteConnectionExt {
507    fn set_session(
508        &self,
509        session_id: &[u8],
510        sender_key: &[u8],
511        data: &[u8],
512    ) -> rusqlite::Result<()>;
513
514    fn set_inbound_group_session(
515        &self,
516        room_id: &[u8],
517        session_id: &[u8],
518        data: &[u8],
519        backed_up: bool,
520        sender_key: Option<&[u8]>,
521        sender_data_type: Option<u8>,
522    ) -> rusqlite::Result<()>;
523
524    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
525
526    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
527    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
528
529    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
530
531    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
532
533    fn set_key_request(
534        &self,
535        request_id: &[u8],
536        sent_out: bool,
537        data: &[u8],
538    ) -> rusqlite::Result<()>;
539
540    fn set_direct_withheld(
541        &self,
542        session_id: &[u8],
543        room_id: &[u8],
544        data: &[u8],
545    ) -> rusqlite::Result<()>;
546
547    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
548
549    fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
550
551    fn set_received_room_key_bundle(
552        &self,
553        room_id: &[u8],
554        user_id: &[u8],
555        data: &[u8],
556    ) -> rusqlite::Result<()>;
557
558    fn set_has_downloaded_all_room_keys(&self, room_id: &[u8]) -> rusqlite::Result<()>;
559
560    fn set_room_pending_key_bundle(
561        &self,
562        room_id: &[u8],
563        details: Option<&[u8]>,
564    ) -> rusqlite::Result<()>;
565}
566
567impl SqliteConnectionExt for rusqlite::Connection {
568    fn set_session(
569        &self,
570        session_id: &[u8],
571        sender_key: &[u8],
572        data: &[u8],
573    ) -> rusqlite::Result<()> {
574        self.execute(
575            "INSERT INTO session (session_id, sender_key, data)
576             VALUES (?1, ?2, ?3)
577             ON CONFLICT (session_id) DO UPDATE SET data = ?3",
578            (session_id, sender_key, data),
579        )?;
580        Ok(())
581    }
582
583    fn set_inbound_group_session(
584        &self,
585        room_id: &[u8],
586        session_id: &[u8],
587        data: &[u8],
588        backed_up: bool,
589        sender_key: Option<&[u8]>,
590        sender_data_type: Option<u8>,
591    ) -> rusqlite::Result<()> {
592        self.execute(
593            "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
594             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
595             ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
596            (session_id, room_id, data, backed_up, sender_key, sender_data_type),
597        )?;
598        Ok(())
599    }
600
601    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
602        self.execute(
603            "INSERT INTO outbound_group_session (room_id, data) \
604             VALUES (?1, ?2)
605             ON CONFLICT (room_id) DO UPDATE SET data = ?2",
606            (room_id, data),
607        )?;
608        Ok(())
609    }
610
611    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
612        self.execute(
613            "INSERT INTO device (user_id, device_id, data) \
614             VALUES (?1, ?2, ?3)
615             ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
616            (user_id, device_id, data),
617        )?;
618        Ok(())
619    }
620
621    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
622        self.execute(
623            "DELETE FROM device WHERE user_id = ? AND device_id = ?",
624            (user_id, device_id),
625        )?;
626        Ok(())
627    }
628
629    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
630        self.execute(
631            "INSERT INTO identity (user_id, data) \
632             VALUES (?1, ?2)
633             ON CONFLICT (user_id) DO UPDATE SET data = ?2",
634            (user_id, data),
635        )?;
636        Ok(())
637    }
638
639    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
640        self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
641        Ok(())
642    }
643
644    fn set_key_request(
645        &self,
646        request_id: &[u8],
647        sent_out: bool,
648        data: &[u8],
649    ) -> rusqlite::Result<()> {
650        self.execute(
651            "INSERT INTO key_requests (request_id, sent_out, data)
652            VALUES (?1, ?2, ?3)
653            ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
654            (request_id, sent_out, data),
655        )?;
656        Ok(())
657    }
658
659    fn set_direct_withheld(
660        &self,
661        session_id: &[u8],
662        room_id: &[u8],
663        data: &[u8],
664    ) -> rusqlite::Result<()> {
665        self.execute(
666            "INSERT INTO direct_withheld_info (session_id, room_id, data)
667            VALUES (?1, ?2, ?3)
668            ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
669            (session_id, room_id, data),
670        )?;
671        Ok(())
672    }
673
674    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
675        self.execute(
676            "INSERT INTO room_settings (room_id, data)
677            VALUES (?1, ?2)
678            ON CONFLICT (room_id) DO UPDATE SET data = ?2",
679            (room_id, data),
680        )?;
681        Ok(())
682    }
683
684    fn set_secret(&self, secret_name: &[u8], secret: &[u8]) -> rusqlite::Result<()> {
685        // Ignore duplicate values, since we may get set the same secret
686        // multiple times.
687        self.execute(
688            "INSERT OR IGNORE INTO secrets_inbox (secret_name, secret)
689            VALUES (?1, ?2)",
690            (secret_name, secret),
691        )?;
692
693        Ok(())
694    }
695
696    fn set_received_room_key_bundle(
697        &self,
698        room_id: &[u8],
699        sender_user_id: &[u8],
700        data: &[u8],
701    ) -> rusqlite::Result<()> {
702        self.execute(
703            "INSERT INTO received_room_key_bundle(room_id, sender_user_id, bundle_data)
704            VALUES (?1, ?2, ?3)
705            ON CONFLICT (room_id, sender_user_id) DO UPDATE SET bundle_data = ?3",
706            (room_id, sender_user_id, data),
707        )?;
708        Ok(())
709    }
710
711    fn set_room_pending_key_bundle(
712        &self,
713        room_id: &[u8],
714        data: Option<&[u8]>,
715    ) -> rusqlite::Result<()> {
716        if let Some(data) = data {
717            self.execute(
718                "INSERT INTO rooms_pending_key_bundle (room_id, data)
719                 VALUES (?1, ?2)
720                 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
721                (room_id, data),
722            )?;
723        } else {
724            self.execute("DELETE FROM rooms_pending_key_bundle WHERE room_id = ?1", (room_id,))?;
725        }
726        Ok(())
727    }
728
729    fn set_has_downloaded_all_room_keys(&self, room_id: &[u8]) -> rusqlite::Result<()> {
730        self.execute(
731            "INSERT INTO room_key_backups_fully_downloaded(room_id)
732             VALUES (?1)
733             ON CONFLICT(room_id) DO NOTHING",
734            (room_id,),
735        )?;
736        Ok(())
737    }
738}
739
740#[async_trait]
741trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
742    async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
743        Ok(self
744            .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
745                stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
746            })
747            .await?)
748    }
749
750    async fn get_inbound_group_session(
751        &self,
752        session_id: Key,
753    ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
754        Ok(self
755            .query_row(
756                "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
757                (session_id,),
758                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
759            )
760            .await
761            .optional()?)
762    }
763
764    async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
765        Ok(self
766            .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
767                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
768            })
769            .await?)
770    }
771
772    async fn get_inbound_group_session_counts(
773        &self,
774        _backup_version: Option<&str>,
775    ) -> Result<RoomKeyCounts> {
776        let total = self
777            .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
778            .await?;
779        let backed_up = self
780            .query_row(
781                "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
782                (),
783                |row| row.get(0),
784            )
785            .await?;
786        Ok(RoomKeyCounts { total, backed_up })
787    }
788
789    async fn get_inbound_group_sessions_by_room_id(
790        &self,
791        room_id: Key,
792    ) -> Result<Vec<(Vec<u8>, bool)>> {
793        Ok(self
794            .prepare(
795                "SELECT data, backed_up FROM inbound_group_session WHERE room_id = :room_id",
796                move |mut stmt| {
797                    stmt.query(named_params! {
798                        ":room_id": room_id,
799                    })?
800                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
801                    .collect()
802                },
803            )
804            .await?)
805    }
806
807    async fn get_inbound_group_sessions_for_device_batch(
808        &self,
809        sender_key: Key,
810        sender_data_type: SenderDataType,
811        after_session_id: Option<Key>,
812        limit: usize,
813    ) -> Result<Vec<(Vec<u8>, bool)>> {
814        Ok(self
815            .prepare(
816                "
817                SELECT data, backed_up
818                FROM inbound_group_session
819                WHERE sender_key = :sender_key
820                    AND sender_data_type = :sender_data_type
821                    AND session_id > :after_session_id
822                ORDER BY session_id
823                LIMIT :limit
824                ",
825                move |mut stmt| {
826                    let sender_data_type = sender_data_type as u8;
827
828                    // If we are not provided with an `after_session_id`, use a key which will sort
829                    // before all real keys: the empty string.
830                    let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
831
832                    stmt.query(named_params! {
833                        ":sender_key": sender_key,
834                        ":sender_data_type": sender_data_type,
835                        ":after_session_id": after_session_id,
836                        ":limit": limit,
837                    })?
838                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
839                    .collect()
840                },
841            )
842            .await?)
843    }
844
845    async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
846        Ok(self
847            .prepare(
848                "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
849                move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
850            )
851            .await?)
852    }
853
854    async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
855        if session_ids.is_empty() {
856            // We are not expecting to be called with an empty list of sessions
857            warn!("No sessions to mark as backed up!");
858            return Ok(());
859        }
860
861        let session_ids_len = session_ids.len();
862
863        self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
864            // Safety: placeholders is not generated using any user input except the number
865            // of session IDs, so it is safe from injection.
866            let sql_params = repeat_vars(session_ids_len);
867            let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
868            txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
869            Ok(Vec::<()>::new())
870        }).await?;
871
872        Ok(())
873    }
874
875    async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
876        self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
877        Ok(())
878    }
879
880    async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
881        Ok(self
882            .query_row(
883                "SELECT data FROM outbound_group_session WHERE room_id = ?",
884                (room_id,),
885                |row| row.get(0),
886            )
887            .await
888            .optional()?)
889    }
890
891    async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
892        Ok(self
893            .query_row(
894                "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
895                (user_id, device_id),
896                |row| row.get(0),
897            )
898            .await
899            .optional()?)
900    }
901
902    async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
903        Ok(self
904            .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
905                stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
906            })
907            .await?)
908    }
909
910    async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
911        Ok(self
912            .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
913            .await
914            .optional()?)
915    }
916
917    async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
918        Ok(self
919            .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
920                row.get::<_, i32>(0)
921            })
922            .await?
923            > 0)
924    }
925
926    async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
927        Ok(self
928            .prepare("SELECT data FROM tracked_user", |mut stmt| {
929                stmt.query(())?.mapped(|row| row.get(0)).collect()
930            })
931            .await?)
932    }
933
934    async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
935        Ok(self
936            .prepare(
937                "INSERT INTO tracked_user (user_id, data) \
938                 VALUES (?1, ?2) \
939                 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
940                |mut stmt| {
941                    for (user_id, data) in users {
942                        stmt.execute((user_id, data))?;
943                    }
944
945                    Ok(())
946                },
947            )
948            .await?)
949    }
950
951    async fn get_outgoing_secret_request(
952        &self,
953        request_id: Key,
954    ) -> Result<Option<(Vec<u8>, bool)>> {
955        Ok(self
956            .query_row(
957                "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
958                (request_id,),
959                |row| Ok((row.get(0)?, row.get(1)?)),
960            )
961            .await
962            .optional()?)
963    }
964
965    async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
966        Ok(self
967            .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
968                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
969            })
970            .await?)
971    }
972
973    async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
974        Ok(self
975            .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
976                stmt.query(())?.mapped(|row| row.get(0)).collect()
977            })
978            .await?)
979    }
980
981    async fn delete_key_request(&self, request_id: Key) -> Result<()> {
982        self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
983        Ok(())
984    }
985
986    async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
987        Ok(self
988            .prepare("SELECT secret FROM secrets_inbox WHERE secret_name = ?", |mut stmt| {
989                stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
990            })
991            .await?)
992    }
993
994    async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
995        self.execute("DELETE FROM secrets_inbox WHERE secret_name = ?", (secret_name,)).await?;
996        Ok(())
997    }
998
999    async fn get_direct_withheld_info(
1000        &self,
1001        session_id: Key,
1002        room_id: Key,
1003    ) -> Result<Option<Vec<u8>>> {
1004        Ok(self
1005            .query_row(
1006                "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
1007                (session_id, room_id),
1008                |row| row.get(0),
1009            )
1010            .await
1011            .optional()?)
1012    }
1013
1014    async fn get_withheld_sessions_by_room_id(&self, room_id: Key) -> Result<Vec<Vec<u8>>> {
1015        Ok(self
1016            .prepare("SELECT data FROM direct_withheld_info WHERE room_id = ?1", |mut stmt| {
1017                stmt.query((room_id,))?.mapped(|row| row.get(0)).collect()
1018            })
1019            .await?)
1020    }
1021
1022    async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
1023        Ok(self
1024            .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
1025                row.get(0)
1026            })
1027            .await
1028            .optional()?)
1029    }
1030
1031    async fn get_received_room_key_bundle(
1032        &self,
1033        room_id: Key,
1034        sender_user: Key,
1035    ) -> Result<Option<Vec<u8>>> {
1036        Ok(self
1037            .query_row(
1038                "SELECT bundle_data FROM received_room_key_bundle WHERE room_id = ? AND sender_user_id = ?",
1039                (room_id, sender_user),
1040                |row| { row.get(0) },
1041            )
1042            .await
1043            .optional()?)
1044    }
1045
1046    async fn get_room_pending_key_bundle(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
1047        Ok(self
1048            .query_row(
1049                "SELECT data FROM rooms_pending_key_bundle WHERE room_id = ?",
1050                (room_id,),
1051                |row| row.get(0),
1052            )
1053            .await
1054            .optional()?)
1055    }
1056
1057    async fn get_all_rooms_pending_key_bundle(&self) -> Result<Vec<Vec<u8>>> {
1058        Ok(self
1059            .query_many("SELECT data FROM rooms_pending_key_bundle", (), |row| row.get(0))
1060            .await?)
1061    }
1062
1063    async fn has_downloaded_all_room_keys(&self, room_id: Key) -> Result<bool> {
1064        Ok(self
1065            .query_row(
1066                "SELECT EXISTS (SELECT 1 FROM room_key_backups_fully_downloaded WHERE room_id = ?)",
1067                (room_id,),
1068                |row| row.get(0),
1069            )
1070            .await?)
1071    }
1072}
1073
// `SqliteAsyncConn` gets all of its crypto-store query helpers from the default
// method bodies of `SqliteObjectCryptoStoreExt`; nothing is overridden here.
#[async_trait]
impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
1076
1077#[async_trait]
1078impl CryptoStore for SqliteCryptoStore {
1079    type Error = Error;
1080
1081    async fn load_account(&self) -> Result<Option<Account>> {
1082        let conn = self.read().await?;
1083        if let Some(pickle) = conn.get_kv("account").await? {
1084            let pickle = self.deserialize_value(&pickle)?;
1085
1086            let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
1087
1088            *self.static_account.write().unwrap() = Some(account.static_data().clone());
1089
1090            Ok(Some(account))
1091        } else {
1092            Ok(None)
1093        }
1094    }
1095
1096    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
1097        let conn = self.read().await?;
1098        if let Some(i) = conn.get_kv("identity").await? {
1099            let pickle = self.deserialize_value(&i)?;
1100            Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
1101        } else {
1102            Ok(None)
1103        }
1104    }
1105
1106    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
1107        // Serialize calls to `save_pending_changes`; there are multiple await points
1108        // below, and we're pickling data as we go, so we don't want to
1109        // invalidate data we've previously read and overwrite it in the store.
1110        // TODO: #2000 should make this lock go away, or change its shape.
1111        let _guard = self.save_changes_lock.lock().await;
1112
1113        let pickled_account = if let Some(account) = changes.account {
1114            *self.static_account.write().unwrap() = Some(account.static_data().clone());
1115            Some(account.pickle())
1116        } else {
1117            None
1118        };
1119
1120        let this = self.clone();
1121        self.write()
1122            .await?
1123            .with_transaction(move |txn| {
1124                if let Some(pickled_account) = pickled_account {
1125                    let serialized_account = this.serialize_value(&pickled_account)?;
1126                    txn.set_kv("account", &serialized_account)?;
1127                }
1128
1129                Ok::<_, Error>(())
1130            })
1131            .await?;
1132
1133        Ok(())
1134    }
1135
1136    async fn save_changes(&self, changes: Changes) -> Result<()> {
1137        // Serialize calls to `save_changes`; there are multiple await points below, and
1138        // we're pickling data as we go, so we don't want to invalidate data
1139        // we've previously read and overwrite it in the store.
1140        // TODO: #2000 should make this lock go away, or change its shape.
1141        let _guard = self.save_changes_lock.lock().await;
1142
1143        let pickled_private_identity =
1144            if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
1145
1146        let mut session_changes = Vec::new();
1147
1148        for session in changes.sessions {
1149            let session_id = self.encode_key("session", session.session_id());
1150            let sender_key = self.encode_key("session", session.sender_key().to_base64());
1151            let pickle = session.pickle().await;
1152            session_changes.push((session_id, sender_key, pickle));
1153        }
1154
1155        let mut inbound_session_changes = Vec::new();
1156        for session in changes.inbound_group_sessions {
1157            let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
1158            let session_id = self.encode_key("inbound_group_session", session.session_id());
1159            let pickle = session.pickle().await;
1160            let sender_key =
1161                self.encode_key("inbound_group_session", session.sender_key().to_base64());
1162            inbound_session_changes.push((room_id, session_id, pickle, sender_key));
1163        }
1164
1165        let mut outbound_session_changes = Vec::new();
1166        for session in changes.outbound_group_sessions {
1167            let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
1168            let pickle = session.pickle().await;
1169            outbound_session_changes.push((room_id, pickle));
1170        }
1171
1172        let this = self.clone();
1173        self.write()
1174            .await?
1175            .with_transaction(move |txn| {
1176                if let Some(pickled_private_identity) = &pickled_private_identity {
1177                    let serialized_private_identity =
1178                        this.serialize_value(pickled_private_identity)?;
1179                    txn.set_kv("identity", &serialized_private_identity)?;
1180                }
1181
1182                if let Some(token) = &changes.next_batch_token {
1183                    let serialized_token = this.serialize_value(token)?;
1184                    txn.set_kv("next_batch_token", &serialized_token)?;
1185                }
1186
1187                if let Some(decryption_key) = &changes.backup_decryption_key {
1188                    let serialized_decryption_key = this.serialize_value(decryption_key)?;
1189                    txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
1190                }
1191
1192                if let Some(backup_version) = &changes.backup_version {
1193                    let serialized_backup_version = this.serialize_value(backup_version)?;
1194                    txn.set_kv("backup_version_v1", &serialized_backup_version)?;
1195                }
1196
1197                if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
1198                    let serialized_pickle_key = this.serialize_value(pickle_key)?;
1199                    txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
1200                }
1201
1202                for device in changes.devices.new.iter().chain(&changes.devices.changed) {
1203                    let user_id = this.encode_key("device", device.user_id().as_bytes());
1204                    let device_id = this.encode_key("device", device.device_id().as_bytes());
1205                    let data = this.serialize_value(&device)?;
1206                    txn.set_device(&user_id, &device_id, &data)?;
1207                }
1208
1209                for device in &changes.devices.deleted {
1210                    let user_id = this.encode_key("device", device.user_id().as_bytes());
1211                    let device_id = this.encode_key("device", device.device_id().as_bytes());
1212                    txn.delete_device(&user_id, &device_id)?;
1213                }
1214
1215                for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
1216                    let user_id = this.encode_key("identity", identity.user_id().as_bytes());
1217                    let data = this.serialize_value(&identity)?;
1218                    txn.set_identity(&user_id, &data)?;
1219                }
1220
1221                for (session_id, sender_key, pickle) in &session_changes {
1222                    let serialized_session = this.serialize_value(&pickle)?;
1223                    txn.set_session(session_id, sender_key, &serialized_session)?;
1224                }
1225
1226                for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
1227                    let serialized_session = this.serialize_value(&pickle)?;
1228                    txn.set_inbound_group_session(
1229                        room_id,
1230                        session_id,
1231                        &serialized_session,
1232                        pickle.backed_up,
1233                        Some(sender_key),
1234                        Some(pickle.sender_data.to_type() as u8),
1235                    )?;
1236                }
1237
1238                for (room_id, pickle) in &outbound_session_changes {
1239                    let serialized_session = this.serialize_json(&pickle)?;
1240                    txn.set_outbound_group_session(room_id, &serialized_session)?;
1241                }
1242
1243                for hash in &changes.message_hashes {
1244                    let hash = rmp_serde::to_vec(hash)?;
1245                    txn.add_olm_hash(&hash)?;
1246                }
1247
1248                for request in changes.key_requests {
1249                    let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
1250                    let serialized_request = this.serialize_value(&request)?;
1251                    txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
1252                }
1253
1254                for (room_id, data) in changes.withheld_session_info {
1255                    for (session_id, event) in data {
1256                        let session_id = this.encode_key("direct_withheld_info", session_id);
1257                        let room_id = this.encode_key("direct_withheld_info", &room_id);
1258                        let serialized_info = this.serialize_json(&event)?;
1259                        txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
1260                    }
1261                }
1262
1263                for (room_id, settings) in changes.room_settings {
1264                    let room_id = this.encode_key("room_settings", room_id.as_bytes());
1265                    let value = this.serialize_value(&settings)?;
1266                    txn.set_room_settings(&room_id, &value)?;
1267                }
1268
1269                for secret in changes.secrets {
1270                    let secret_name =
1271                        this.encode_key("secrets_inbox", secret.secret_name.to_string());
1272                    let value = this.serialize_json(secret.secret.deref())?;
1273                    txn.set_secret(&secret_name, &value)?;
1274                }
1275
1276                for bundle in changes.received_room_key_bundles {
1277                    let room_id =
1278                        this.encode_key("received_room_key_bundle", &bundle.bundle_data.room_id);
1279                    let user_id = this.encode_key("received_room_key_bundle", &bundle.sender_user);
1280                    let value = this.serialize_value(&bundle)?;
1281                    txn.set_received_room_key_bundle(&room_id, &user_id, &value)?;
1282                }
1283
1284                for room in changes.room_key_backups_fully_downloaded {
1285                    let room_id = this.encode_key("room_key_backups_fully_downloaded", &room);
1286                    txn.set_has_downloaded_all_room_keys(&room_id)?;
1287                }
1288
1289                for (room, details) in changes.rooms_pending_key_bundle {
1290                    let room_id = this.encode_key("rooms_pending_key_bundle", &room);
1291                    let value = details.as_ref().map(|d| this.serialize_value(d)).transpose()?;
1292                    txn.set_room_pending_key_bundle(&room_id, value.as_deref())?;
1293                }
1294
1295                Ok::<_, Error>(())
1296            })
1297            .await?;
1298
1299        Ok(())
1300    }
1301
1302    async fn save_inbound_group_sessions(
1303        &self,
1304        sessions: Vec<InboundGroupSession>,
1305        backed_up_to_version: Option<&str>,
1306    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
1307        // Sanity-check that the data in the sessions corresponds to backed_up_version
1308        sessions.iter().for_each(|s| {
1309            let backed_up = s.backed_up();
1310            if backed_up != backed_up_to_version.is_some() {
1311                warn!(
1312                    backed_up,
1313                    backed_up_to_version,
1314                    "Session backed-up flag does not correspond to backup version setting",
1315                );
1316            }
1317        });
1318
1319        // Currently, this store doesn't save the backup version separately, so this
1320        // just delegates to save_changes.
1321        self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
1322    }
1323
1324    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
1325        let device_keys = self.get_own_device().await?.as_device_keys().clone();
1326
1327        let sessions: Vec<_> = self
1328            .read()
1329            .await?
1330            .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
1331            .await?
1332            .into_iter()
1333            .map(|bytes| {
1334                let pickle = self.deserialize_value(&bytes)?;
1335                Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
1336            })
1337            .collect::<Result<_>>()?;
1338
1339        if sessions.is_empty() { Ok(None) } else { Ok(Some(sessions)) }
1340    }
1341
1342    #[instrument(skip(self))]
1343    async fn get_inbound_group_session(
1344        &self,
1345        room_id: &RoomId,
1346        session_id: &str,
1347    ) -> Result<Option<InboundGroupSession>> {
1348        let session_id = self.encode_key("inbound_group_session", session_id);
1349        let Some((room_id_from_db, value, backed_up)) =
1350            self.read().await?.get_inbound_group_session(session_id).await?
1351        else {
1352            return Ok(None);
1353        };
1354
1355        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1356        if *room_id != room_id_from_db {
1357            warn!("expected room_id for session_id doesn't match what's in the DB");
1358            return Ok(None);
1359        }
1360
1361        Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1362    }
1363
1364    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1365        self.read()
1366            .await?
1367            .get_inbound_group_sessions()
1368            .await?
1369            .into_iter()
1370            .map(|(value, backed_up)| {
1371                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1372            })
1373            .collect()
1374    }
1375
1376    async fn get_inbound_group_sessions_by_room_id(
1377        &self,
1378        room_id: &RoomId,
1379    ) -> Result<Vec<InboundGroupSession>> {
1380        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1381        self.read()
1382            .await?
1383            .get_inbound_group_sessions_by_room_id(room_id)
1384            .await?
1385            .into_iter()
1386            .map(|(value, backed_up)| {
1387                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1388            })
1389            .collect()
1390    }
1391
1392    async fn get_inbound_group_sessions_for_device_batch(
1393        &self,
1394        sender_key: Curve25519PublicKey,
1395        sender_data_type: SenderDataType,
1396        after_session_id: Option<String>,
1397        limit: usize,
1398    ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1399        let after_session_id =
1400            after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1401        let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1402
1403        self.read()
1404            .await?
1405            .get_inbound_group_sessions_for_device_batch(
1406                sender_key,
1407                sender_data_type,
1408                after_session_id,
1409                limit,
1410            )
1411            .await?
1412            .into_iter()
1413            .map(|(value, backed_up)| {
1414                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1415            })
1416            .collect()
1417    }
1418
1419    async fn inbound_group_session_counts(
1420        &self,
1421        backup_version: Option<&str>,
1422    ) -> Result<RoomKeyCounts> {
1423        Ok(self.read().await?.get_inbound_group_session_counts(backup_version).await?)
1424    }
1425
1426    async fn inbound_group_sessions_for_backup(
1427        &self,
1428        _backup_version: &str,
1429        limit: usize,
1430    ) -> Result<Vec<InboundGroupSession>> {
1431        self.read()
1432            .await?
1433            .get_inbound_group_sessions_for_backup(limit)
1434            .await?
1435            .into_iter()
1436            .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1437            .collect()
1438    }
1439
1440    async fn mark_inbound_group_sessions_as_backed_up(
1441        &self,
1442        _backup_version: &str,
1443        session_ids: &[(&RoomId, &str)],
1444    ) -> Result<()> {
1445        Ok(self
1446            .write()
1447            .await?
1448            .mark_inbound_group_sessions_as_backed_up(
1449                session_ids
1450                    .iter()
1451                    .map(|(_, s)| self.encode_key("inbound_group_session", s))
1452                    .collect(),
1453            )
1454            .await?)
1455    }
1456
1457    async fn reset_backup_state(&self) -> Result<()> {
1458        Ok(self.write().await?.reset_inbound_group_session_backup_state().await?)
1459    }
1460
1461    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1462        let conn = self.read().await?;
1463
1464        let backup_version = conn
1465            .get_kv("backup_version_v1")
1466            .await?
1467            .map(|value| self.deserialize_value(&value))
1468            .transpose()?;
1469
1470        let decryption_key = conn
1471            .get_kv("recovery_key_v1")
1472            .await?
1473            .map(|value| self.deserialize_value(&value))
1474            .transpose()?;
1475
1476        Ok(BackupKeys { backup_version, decryption_key })
1477    }
1478
1479    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1480        let conn = self.read().await?;
1481
1482        conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1483            .await?
1484            .map(|value| self.deserialize_value(&value))
1485            .transpose()
1486    }
1487
1488    async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1489        Ok(self.write().await?.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?)
1490    }
1491    async fn get_outbound_group_session(
1492        &self,
1493        room_id: &RoomId,
1494    ) -> Result<Option<OutboundGroupSession>> {
1495        let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1496        let Some(value) = self.read().await?.get_outbound_group_session(room_id).await? else {
1497            return Ok(None);
1498        };
1499
1500        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1501
1502        let pickle = self.deserialize_json(&value)?;
1503        let session = OutboundGroupSession::from_pickle(
1504            account_info.device_id,
1505            account_info.identity_keys,
1506            pickle,
1507        )
1508        .map_err(|_| Error::Unpickle)?;
1509
1510        return Ok(Some(session));
1511    }
1512
1513    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1514        self.read()
1515            .await?
1516            .get_tracked_users()
1517            .await?
1518            .iter()
1519            .map(|value| self.deserialize_value(value))
1520            .collect()
1521    }
1522
1523    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1524        let users: Vec<(Key, Vec<u8>)> = tracked_users
1525            .iter()
1526            .map(|(u, d)| {
1527                let user_id = self.encode_key("tracked_users", u.as_bytes());
1528                let data =
1529                    self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1530                Ok((user_id, data))
1531            })
1532            .collect::<Result<_>>()?;
1533
1534        Ok(self.write().await?.add_tracked_users(users).await?)
1535    }
1536
1537    async fn get_device(
1538        &self,
1539        user_id: &UserId,
1540        device_id: &DeviceId,
1541    ) -> Result<Option<DeviceData>> {
1542        let user_id = self.encode_key("device", user_id.as_bytes());
1543        let device_id = self.encode_key("device", device_id.as_bytes());
1544        Ok(self
1545            .read()
1546            .await?
1547            .get_device(user_id, device_id)
1548            .await?
1549            .map(|value| self.deserialize_value(&value))
1550            .transpose()?)
1551    }
1552
1553    async fn get_user_devices(
1554        &self,
1555        user_id: &UserId,
1556    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1557        let user_id = self.encode_key("device", user_id.as_bytes());
1558        self.read()
1559            .await?
1560            .get_user_devices(user_id)
1561            .await?
1562            .into_iter()
1563            .map(|value| {
1564                let device: DeviceData = self.deserialize_value(&value)?;
1565                Ok((device.device_id().to_owned(), device))
1566            })
1567            .collect()
1568    }
1569
1570    async fn get_own_device(&self) -> Result<DeviceData> {
1571        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1572
1573        Ok(self
1574            .get_device(&account_info.user_id, &account_info.device_id)
1575            .await?
1576            .expect("We should be able to find our own device."))
1577    }
1578
1579    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1580        let user_id = self.encode_key("identity", user_id.as_bytes());
1581        Ok(self
1582            .read()
1583            .await?
1584            .get_user_identity(user_id)
1585            .await?
1586            .map(|value| self.deserialize_value(&value))
1587            .transpose()?)
1588    }
1589
1590    async fn is_message_known(
1591        &self,
1592        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1593    ) -> Result<bool> {
1594        let value = rmp_serde::to_vec(message_hash)?;
1595        Ok(self.read().await?.has_olm_hash(value).await?)
1596    }
1597
1598    async fn get_outgoing_secret_requests(
1599        &self,
1600        request_id: &TransactionId,
1601    ) -> Result<Option<GossipRequest>> {
1602        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1603        Ok(self
1604            .read()
1605            .await?
1606            .get_outgoing_secret_request(request_id)
1607            .await?
1608            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1609            .transpose()?)
1610    }
1611
1612    async fn get_secret_request_by_info(
1613        &self,
1614        key_info: &SecretInfo,
1615    ) -> Result<Option<GossipRequest>> {
1616        let requests = self.read().await?.get_outgoing_secret_requests().await?;
1617        for (request, sent_out) in requests {
1618            let request = self.deserialize_key_request(&request, sent_out)?;
1619            if request.info == *key_info {
1620                return Ok(Some(request));
1621            }
1622        }
1623        Ok(None)
1624    }
1625
1626    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1627        self.read()
1628            .await?
1629            .get_unsent_secret_requests()
1630            .await?
1631            .iter()
1632            .map(|value| {
1633                let request = self.deserialize_key_request(value, false)?;
1634                Ok(request)
1635            })
1636            .collect()
1637    }
1638
1639    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1640        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1641        Ok(self.write().await?.delete_key_request(request_id).await?)
1642    }
1643
1644    async fn get_secrets_from_inbox(
1645        &self,
1646        secret_name: &SecretName,
1647    ) -> Result<Vec<Zeroizing<String>>> {
1648        let secret_name = self.encode_key("secrets_inbox", secret_name.to_string());
1649
1650        self.read()
1651            .await?
1652            .get_secrets_from_inbox(secret_name)
1653            .await?
1654            .into_iter()
1655            .map(|value| self.deserialize_json(value.as_ref()).map(|value: String| value.into()))
1656            .collect()
1657    }
1658
1659    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1660        let secret_name = self.encode_key("secrets_inbox", secret_name.to_string());
1661        self.write().await?.delete_secrets_from_inbox(secret_name).await
1662    }
1663
1664    async fn get_withheld_info(
1665        &self,
1666        room_id: &RoomId,
1667        session_id: &str,
1668    ) -> Result<Option<RoomKeyWithheldEntry>> {
1669        let room_id = self.encode_key("direct_withheld_info", room_id);
1670        let session_id = self.encode_key("direct_withheld_info", session_id);
1671
1672        self.read()
1673            .await?
1674            .get_direct_withheld_info(session_id, room_id)
1675            .await?
1676            .map(|value| {
1677                let info = self.deserialize_json::<RoomKeyWithheldEntry>(&value)?;
1678                Ok(info)
1679            })
1680            .transpose()
1681    }
1682
1683    async fn get_withheld_sessions_by_room_id(
1684        &self,
1685        room_id: &RoomId,
1686    ) -> matrix_sdk_crypto::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
1687        let room_id = self.encode_key("direct_withheld_info", room_id);
1688
1689        self.read()
1690            .await?
1691            .get_withheld_sessions_by_room_id(room_id)
1692            .await?
1693            .into_iter()
1694            .map(|value| self.deserialize_json(&value))
1695            .collect()
1696    }
1697
1698    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1699        let room_id = self.encode_key("room_settings", room_id.as_bytes());
1700        let Some(value) = self.read().await?.get_room_settings(room_id).await? else {
1701            return Ok(None);
1702        };
1703
1704        let settings = self.deserialize_value(&value)?;
1705
1706        return Ok(Some(settings));
1707    }
1708
1709    async fn get_received_room_key_bundle_data(
1710        &self,
1711        room_id: &RoomId,
1712        user_id: &UserId,
1713    ) -> Result<Option<StoredRoomKeyBundleData>> {
1714        let room_id = self.encode_key("received_room_key_bundle", room_id);
1715        let user_id = self.encode_key("received_room_key_bundle", user_id);
1716        self.read()
1717            .await?
1718            .get_received_room_key_bundle(room_id, user_id)
1719            .await?
1720            .map(|value| self.deserialize_value(&value))
1721            .transpose()
1722    }
1723
1724    async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
1725        let room_id = self.encode_key("room_key_backups_fully_downloaded", room_id);
1726        self.read().await?.has_downloaded_all_room_keys(room_id).await
1727    }
1728
1729    async fn get_pending_key_bundle_details_for_room(
1730        &self,
1731        room_id: &RoomId,
1732    ) -> Result<Option<RoomPendingKeyBundleDetails>> {
1733        let room_id = self.encode_key("rooms_pending_key_bundle", room_id.as_bytes());
1734        let Some(value) = self.read().await?.get_room_pending_key_bundle(room_id).await? else {
1735            return Ok(None);
1736        };
1737
1738        let details = self.deserialize_value(&value)?;
1739        Ok(Some(details))
1740    }
1741
1742    async fn get_all_rooms_pending_key_bundles(&self) -> Result<Vec<RoomPendingKeyBundleDetails>> {
1743        let details = self.read().await?.get_all_rooms_pending_key_bundle().await?;
1744        let room_ids = details
1745            .into_iter()
1746            .map(|value| self.deserialize_value(&value))
1747            .collect::<Result<_, _>>()?;
1748        Ok(room_ids)
1749    }
1750
1751    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1752        let Some(serialized) = self.read().await?.get_kv(key).await? else {
1753            return Ok(None);
1754        };
1755        let value = if let Some(cipher) = &self.store_cipher {
1756            let encrypted = rmp_serde::from_slice(&serialized)?;
1757            cipher.decrypt_value_data(encrypted)?
1758        } else {
1759            serialized
1760        };
1761
1762        Ok(Some(value))
1763    }
1764
1765    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1766        let serialized = if let Some(cipher) = &self.store_cipher {
1767            let encrypted = cipher.encrypt_value_data(value)?;
1768            rmp_serde::to_vec_named(&encrypted)?
1769        } else {
1770            value
1771        };
1772
1773        self.write().await?.set_kv(key, serialized).await?;
1774        Ok(())
1775    }
1776
1777    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1778        let key = key.to_owned();
1779        self.write()
1780            .await?
1781            .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1782            .await
1783            .unwrap()?;
1784        Ok(())
1785    }
1786
1787    #[instrument(skip(self))]
1788    async fn try_take_leased_lock(
1789        &self,
1790        lease_duration_ms: u32,
1791        key: &str,
1792        holder: &str,
1793    ) -> Result<Option<CrossProcessLockGeneration>> {
1794        let key = key.to_owned();
1795        let holder = holder.to_owned();
1796
1797        let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
1798        let expiration = now + lease_duration_ms as u64;
1799
1800        // Learn about the `excluded` keyword in https://sqlite.org/lang_upsert.html.
1801        let generation = self
1802            .write()
1803            .await?
1804            .with_transaction(move |txn| {
1805                txn.query_row(
1806                    "INSERT INTO lease_locks (key, holder, expiration)
1807                    VALUES (?1, ?2, ?3)
1808                    ON CONFLICT (key)
1809                    DO
1810                        UPDATE SET
1811                            holder = excluded.holder,
1812                            expiration = excluded.expiration,
1813                            generation =
1814                                CASE holder
1815                                    WHEN excluded.holder THEN generation
1816                                    ELSE generation + 1
1817                                END
1818                        WHERE
1819                            holder = excluded.holder
1820                            OR expiration < ?4
1821                    RETURNING generation
1822                    ",
1823                    (key, holder, expiration, now),
1824                    |row| row.get(0),
1825                )
1826                .optional()
1827            })
1828            .await?;
1829
1830        Ok(generation)
1831    }
1832
1833    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1834        let conn = self.read().await?;
1835        if let Some(token) = conn.get_kv("next_batch_token").await? {
1836            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1837            Ok(maybe_token)
1838        } else {
1839            Ok(None)
1840        }
1841    }
1842
1843    async fn close(&self) -> Result<()> {
1844        connection::close_connections(&self.connections, "Crypto store").await;
1845        Ok(())
1846    }
1847
1848    async fn reopen(&self) -> Result<()> {
1849        connection::reopen_connections(
1850            &self.connections,
1851            self.db_path.clone(),
1852            self.pool_config,
1853            self.runtime_config,
1854        )
1855        .await?;
1856        Ok(())
1857    }
1858
1859    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1860        Ok(Some(self.read().await?.get_db_size().await?))
1861    }
1862}
1863
#[cfg(test)]
mod tests {
    use std::{path::Path, sync::LazyLock};

    use matrix_sdk_common::deserialized_responses::WithheldCode;
    use matrix_sdk_crypto::{
        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
        store::CryptoStore,
    };
    use matrix_sdk_test::async_test;
    use ruma::{device_id, room_id, user_id};
    use similar_asserts::assert_eq;
    use tempfile::{TempDir, tempdir};
    use tokio::fs;

    use super::SqliteCryptoStore;
    use crate::SqliteStoreConfig;

    /// Shared scratch directory for tests that open stores by name.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// A store opened on a private copy of a test-vector database.
    struct TestDb {
        // Needs to be kept alive because the Drop implementation for TempDir deletes the
        // directory.
        _dir: TempDir,
        database: SqliteCryptoStore,
    }

    /// Copy the database file found under `data_path` (resolved two levels above
    /// this crate's `CARGO_MANIFEST_DIR`) into a fresh temporary directory and
    /// return that directory.
    fn copy_db(data_path: &str) -> TempDir {
        let db_name = super::DATABASE_NAME;

        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
        let database_path = manifest_path.join(data_path).join(db_name);

        let tmpdir = tempdir().unwrap();
        let destination = tmpdir.path().join(db_name);

        // Copy the test database to the tempdir so our test runs are idempotent.
        std::fs::copy(&database_path, destination).unwrap();

        tmpdir
    }

    /// Open a copy of the test-vector database at `data_path`, optionally with a
    /// passphrase.
    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
        let tmpdir = copy_db(data_path);

        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
            .await
            .expect("Can't open the test store");

        TestDb { _dir: tmpdir, database }
    }

    #[async_test]
    async fn test_pool_size() {
        let store_open_config =
            SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);

        let store = SqliteCryptoStore::open_with_config(&store_open_config).await.unwrap();

        let guard = store.connections.lock().await;
        let conns = guard.as_ref().unwrap();
        assert_eq!(conns.pool.status().max_size, 42);
    }

    /// Test that we didn't regress in our storage layer by loading data from a
    /// pre-filled database, or in other words use a test vector for this.
    #[async_test]
    async fn test_open_test_vector_store() {
        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;

        let account = database
            .load_account()
            .await
            .unwrap()
            .expect("The test database is prefilled with data, we should find an account");

        let user_id = account.user_id();
        let device_id = account.device_id();

        assert_eq!(
            user_id.as_str(),
            "@pjtest:synapse-oidc.element.dev",
            "The user ID should match to the one we expect."
        );

        assert_eq!(
            device_id.as_str(),
            "v4TqgcuIH6",
            "The device ID should match to the one we expect."
        );

        let device = database
            .get_device(user_id, device_id)
            .await
            .unwrap()
            .expect("Our own device should be found in the store.");

        assert_eq!(device.device_id(), device_id);
        assert_eq!(device.user_id(), user_id);

        assert_eq!(
            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
        );

        assert_eq!(
            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
        );

        let identity = database
            .get_user_identity(user_id)
            .await
            .unwrap()
            .expect("The store should contain an identity.");

        assert_eq!(identity.user_id(), user_id);

        let identity = identity
            .own()
            .expect("The identity should be of the correct type, it should be our own identity.");

        let master_key = identity
            .master_key()
            .get_first_key()
            .expect("Our own identity should have a master key");

        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
    }

    /// Test that we didn't regress in our storage layer by loading data from a
    /// pre-filled database, or in other words use a test vector for this.
    #[async_test]
    async fn test_open_test_vector_encrypted_store() {
        let TestDb { _dir: _, database } = get_test_db(
            "testing/data/storage/alice",
            Some(concat!(
                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
            )),
        )
        .await;

        let account = database
            .load_account()
            .await
            .unwrap()
            .expect("The test database is prefilled with data, we should find an account");

        let user_id = account.user_id();
        let device_id = account.device_id();

        assert_eq!(
            user_id.as_str(),
            "@alice:localhost",
            "The user ID should match to the one we expect."
        );

        assert_eq!(
            device_id.as_str(),
            "JVVORTHFXY",
            "The device ID should match to the one we expect."
        );

        let tracked_users =
            database.load_tracked_users().await.expect("Should be tracking some users");

        assert_eq!(tracked_users.len(), 6);

        let known_users = [
            user_id!("@alice:localhost"),
            user_id!("@dehydration3:localhost"),
            user_id!("@eve:localhost"),
            user_id!("@bob:localhost"),
            user_id!("@malo:localhost"),
            user_id!("@carl:localhost"),
        ];

        // load the identities
        for user_id in known_users {
            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
        }

        let carl_identity =
            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();

        assert_eq!(
            carl_identity.master_key().get_first_key().unwrap().to_base64(),
            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
        );
        assert!(!carl_identity.was_previously_verified());

        let bob_identity =
            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();

        assert_eq!(
            bob_identity.master_key().get_first_key().unwrap().to_base64(),
            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
        );
        // Bob is verified so this flag should be set
        assert!(bob_identity.was_previously_verified());

        // Each (device, owner) pair is listed once.
        let known_devices = [
            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
            // a dehydrated one
            (
                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
                user_id!("@alice:localhost"),
            ),
            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
        ];

        for (device_id, user_id) in known_devices {
            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
        }

        let known_sender_key_to_session_count = [
            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
        ];

        for (id, count) in known_sender_key_to_session_count {
            let olm_sessions =
                database.get_sessions(id).await.expect("Should have some olm sessions");

            // The sender key is part of the failure message instead of a debug print.
            assert_eq!(
                olm_sessions.map_or(0, |v| v.len()),
                count,
                "unexpected Olm session count for sender key {}",
                id
            );
        }

        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
        assert_eq!(inbound_group_sessions.len(), 15);
        let known_inbound_group_sessions = [
            (
                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
            ),
            (
                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
            ),
        ];

        // ensure we can load them all
        for (session_id, room_id) in &known_inbound_group_sessions {
            database
                .get_inbound_group_session(room_id, session_id)
                .await
                .expect("Should be able to load inbound group session")
                .unwrap();
        }

        let bob_sender_verified = database
            .get_inbound_group_session(
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
            )
            .await
            .unwrap()
            .unwrap();

        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
        assert!(bob_sender_verified.backed_up());
        assert!(!bob_sender_verified.has_been_imported());

        let alice_unknown_device = database
            .get_inbound_group_session(
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
            )
            .await
            .unwrap()
            .unwrap();

        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
        assert!(alice_unknown_device.backed_up());
        assert!(alice_unknown_device.has_been_imported());

        let carl_tofu_session = database
            .get_inbound_group_session(
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
            )
            .await
            .unwrap()
            .unwrap();

        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
        assert!(carl_tofu_session.backed_up());
        assert!(!carl_tofu_session.has_been_imported());

        // Load outbound sessions
        database
            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
            .await
            .unwrap()
            .unwrap();
        database
            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
            .await
            .unwrap()
            .unwrap();
        database
            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
            .await
            .unwrap()
            .unwrap();

        let withheld_info = database
            .get_withheld_info(
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
            )
            .await
            .expect("This session should be withheld")
            .unwrap();

        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);

        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
        assert_eq!(backup_keys.backup_version.unwrap(), "6");
        assert!(backup_keys.decryption_key.is_some());
    }

    /// Test that we migrate the secrets inbox properly.
    ///
    /// The format for the secrets inbox changed in version 15.  Previously, the
    /// secrets inbox stored a full `GossippedSecret` struct.  In version 15,
    /// the secrets inbox now stores only the secret.
    // NOTE(review): the test below writes the old format into a store migrated
    // with `Some(16)`, which looks inconsistent with "version 15" above; confirm
    // which schema version introduced the new format.
    #[async_test]
    async fn test_secrets_inbox_migration() {
        use std::ops::Deref;

        use matrix_sdk_crypto::{
            GossipRequest, GossippedSecret, SecretInfo,
            types::events::{
                olm_v1::{DecryptedSecretSendEvent, OlmV1Keys},
                secret_send::SecretSendContent,
            },
            vodozemac::Ed25519SecretKey,
        };
        use ruma::{TransactionId, events::secret::request::SecretName, owned_user_id};

        use crate::utils::{EncryptableStore, SqliteAsyncConnExt};

        // Create a database with version 16
        let tmpdir = tempdir().unwrap();
        let config = SqliteStoreConfig::new(tmpdir.path());
        let pool = config.build_pool_of_connections(super::DATABASE_NAME).unwrap();
        let conn = pool.get().await.unwrap();
        let version = super::initialize_store(&conn, 0).await.unwrap();
        let old_data_store = SqliteCryptoStore::create_raw(
            config.secret.clone(),
            pool,
            conn,
            config.pool_config(),
            config.runtime_config(),
        )
        .await
        .unwrap();
        super::run_migrations(&old_data_store, version, Some(16)).await.unwrap();
        old_data_store.write().await.unwrap().wal_checkpoint().await;

        // Store a secret using the old format
        let secret = GossippedSecret {
            secret_name: SecretName::CrossSigningMasterKey,
            gossip_request: GossipRequest {
                request_recipient: owned_user_id!("@alice:example.com"),
                request_id: TransactionId::new(),
                info: SecretInfo::SecretRequest(SecretName::CrossSigningMasterKey),
                sent_out: true,
            },
            event: DecryptedSecretSendEvent {
                sender: owned_user_id!("@alice:example.com"),
                recipient: owned_user_id!("@alice:example.com"),
                keys: OlmV1Keys { ed25519: Ed25519SecretKey::new().public_key() },
                recipient_keys: OlmV1Keys { ed25519: Ed25519SecretKey::new().public_key() },
                sender_device_keys: None,
                content: SecretSendContent::new(
                    "abc".into(),
                    "It is a secret to everybody".to_owned(),
                ),
            },
        };
        let value = old_data_store.serialize_json(&secret).unwrap();
        old_data_store
            .write()
            .await
            .unwrap()
            .prepare("INSERT INTO secrets (secret_name, data) VALUES (?1, ?2)", |mut stmt| {
                stmt.execute((SecretName::CrossSigningMasterKey.to_string(), value))
            })
            .await
            .unwrap();

        // After we open the store, the data will be migrated
        let store = SqliteCryptoStore::open_with_config(&config).await.unwrap();

        // and we should be able to read the secrets from the inbox
        let secrets =
            store.get_secrets_from_inbox(&SecretName::CrossSigningMasterKey).await.unwrap();
        assert_eq!(secrets.len(), 1);
        assert_eq!(secrets[0].deref(), "It is a secret to everybody");
    }

    /// Open (optionally wiping first) an unencrypted-by-default store named
    /// `name` under the shared temp dir; used by the integration test macros.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let tmpdir_path = TMP_DIR.path().join(name);

        if clear_data {
            let _ = fs::remove_dir_all(&tmpdir_path).await;
        }

        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
            .await
            .expect("Can't create a secret protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}
2366
#[cfg(test)]
mod encrypted_tests {
    use std::sync::LazyLock;

    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use tempfile::{TempDir, tempdir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Shared scratch directory for every store opened by this module.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// Open a passphrase-protected store named `name` under the shared temp dir.
    ///
    /// Every store in this module is encrypted: when the caller passes no
    /// passphrase, a fixed default one is used instead. With `clear_data`, any
    /// previous directory for `name` is removed first (best effort).
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        if clear_data {
            // Failures (e.g. the directory not existing yet) are deliberately ignored.
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        let effective_passphrase = passphrase.unwrap_or("default_test_password");

        SqliteCryptoStore::open(store_dir.to_str().unwrap(), Some(effective_passphrase))
            .await
            .expect("Can't create a secret protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}
2399
2400#[cfg(test)]
2401mod close_reopen_tests {
2402    use std::sync::LazyLock;
2403
2404    use matrix_sdk_crypto::store::CryptoStore;
2405    use matrix_sdk_test::async_test;
2406    use tempfile::{TempDir, tempdir};
2407
2408    use super::SqliteCryptoStore;
2409
2410    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
2411
2412    async fn new_store(name: &str) -> SqliteCryptoStore {
2413        let tmpdir_path = TMP_DIR.path().join(name);
2414        SqliteCryptoStore::open(tmpdir_path, None).await.unwrap()
2415    }
2416
2417    #[async_test]
2418    async fn test_close_completes_without_timeout() {
2419        let store = new_store("close_no_timeout").await;
2420
2421        // Close should complete quickly without hitting the 5s timeout.
2422        let start = std::time::Instant::now();
2423        store.close().await.unwrap();
2424        let elapsed = start.elapsed();
2425
2426        assert!(
2427            elapsed < std::time::Duration::from_secs(2),
2428            "close() took {elapsed:?}, expected < 2s (no timeout)"
2429        );
2430
2431        // Connections should be None after close.
2432        let guard = store.connections.lock().await;
2433        assert!(guard.is_none(), "connections should be None after close");
2434    }
2435
2436    #[async_test]
2437    async fn test_reopen_restores_connections() {
2438        let store = new_store("reopen_restores").await;
2439
2440        store.close().await.unwrap();
2441
2442        {
2443            let guard = store.connections.lock().await;
2444            assert!(guard.is_none());
2445        }
2446
2447        store.reopen().await.unwrap();
2448
2449        {
2450            let guard = store.connections.lock().await;
2451            assert!(guard.is_some(), "connections should be Some after reopen");
2452        }
2453    }
2454
2455    #[async_test]
2456    async fn test_close_is_idempotent() {
2457        let store = new_store("close_idempotent").await;
2458
2459        store.close().await.unwrap();
2460        // Second close should be a no-op.
2461        store.close().await.unwrap();
2462
2463        let guard = store.connections.lock().await;
2464        assert!(guard.is_none());
2465    }
2466
2467    #[async_test]
2468    async fn test_reopen_is_idempotent() {
2469        let store = new_store("reopen_idempotent").await;
2470
2471        // Reopen on an active store should be a no-op.
2472        store.reopen().await.unwrap();
2473
2474        let guard = store.connections.lock().await;
2475        assert!(guard.is_some());
2476    }
2477
2478    #[async_test]
2479    async fn test_read_fails_when_closed() {
2480        let store = new_store("read_fails_closed").await;
2481        store.close().await.unwrap();
2482
2483        let err = store.load_account().await;
2484        assert!(err.is_err(), "read should fail when closed");
2485
2486        let err_msg = err.unwrap_err().to_string();
2487        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
2488    }
2489
2490    #[async_test]
2491    async fn test_operations_work_after_reopen() {
2492        let store = new_store("ops_after_reopen").await;
2493
2494        store.close().await.unwrap();
2495        store.reopen().await.unwrap();
2496
2497        // A read operation should work immediately after reopen.
2498        let account = store.load_account().await;
2499        assert!(account.is_ok(), "load_account should succeed after reopen");
2500        // No account was saved, so this should be None.
2501        assert!(account.unwrap().is_none());
2502    }
2503
2504    #[async_test]
2505    async fn test_multiple_close_reopen_cycles() {
2506        let store = new_store("multi_cycles").await;
2507
2508        for _ in 0..5 {
2509            store.close().await.unwrap();
2510            store.reopen().await.unwrap();
2511
2512            // After each cycle, the store should be fully operational.
2513            let account = store.load_account().await;
2514            assert!(account.is_ok(), "store should work after close/reopen cycle");
2515        }
2516    }
2517
2518    #[async_test]
2519    async fn test_pool_is_fully_drained_after_close() {
2520        let store = new_store("pool_drained").await;
2521
2522        // Do a few reads to exercise the pool.
2523        let _ = store.load_account().await;
2524        let _ = store.load_account().await;
2525
2526        store.close().await.unwrap();
2527
2528        // After close, the connections field should be None.
2529        let guard = store.connections.lock().await;
2530        assert!(guard.is_none(), "all connections should be released after close");
2531    }
2532
2533    #[async_test]
2534    async fn test_close_waits_for_held_read_connection_to_drain() {
2535        let store = new_store("held_read_drain").await;
2536
2537        // Acquire a read connection and hold it, simulating an in-flight read.
2538        let held_conn = store.read().await.unwrap();
2539
2540        // Spawn close in a background task — it will close the pool and then
2541        // poll-wait for pool.status().size == 0 in the drain loop.
2542        let store_clone = store.clone();
2543        let close_handle = tokio::spawn(async move {
2544            store_clone.close().await.unwrap();
2545        });
2546
2547        // Give close() a moment to close the pool and enter the drain loop.
2548        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
2549
2550        // The close task should still be running because we hold a connection.
2551        assert!(!close_handle.is_finished(), "close should be waiting for the held connection");
2552
2553        // Release the held connection — this lets pool.status().size drop to 0.
2554        drop(held_conn);
2555
2556        // Now close should complete promptly (well within the 5s timeout).
2557        let timeout = tokio::time::timeout(std::time::Duration::from_secs(3), close_handle).await;
2558        assert!(timeout.is_ok(), "close should complete after the held connection is released");
2559        timeout.unwrap().unwrap();
2560
2561        // Verify the store is fully closed.
2562        let guard = store.connections.lock().await;
2563        assert!(guard.is_none(), "connections should be None after close");
2564    }
2565}