Skip to main content

matrix_sdk_sqlite/
crypto_store.rs

1// Copyright 2022, 2026 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    fmt,
18    ops::Deref,
19    path::{Path, PathBuf},
20    sync::{Arc, RwLock},
21};
22
23use async_trait::async_trait;
24use deadpool::managed::PoolConfig;
25use matrix_sdk_base::cross_process_lock::CrossProcessLockGeneration;
26use matrix_sdk_crypto::{
27    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
28    olm::{
29        InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
30        PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
31    },
32    store::{
33        CryptoStore,
34        types::{
35            BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
36            RoomKeyWithheldEntry, RoomPendingKeyBundleDetails, RoomSettings,
37            StoredRoomKeyBundleData,
38        },
39    },
40};
41use matrix_sdk_store_encryption::StoreCipher;
42use ruma::{
43    DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, RoomId, TransactionId, UserId,
44    events::secret::request::SecretName,
45};
46use rusqlite::{OptionalExtension, named_params, params_from_iter};
47use tokio::{
48    fs,
49    sync::{Mutex, OwnedMutexGuard},
50};
51use tracing::{debug, instrument, warn};
52use vodozemac::Curve25519PublicKey;
53use zeroize::Zeroizing;
54
55use crate::{
56    OpenStoreError, RuntimeConfig, Secret, SqliteStoreConfig,
57    connection::{self, Connection as SqliteAsyncConn, Pool as SqlitePool, SqliteConnections},
58    error::{Error, Result},
59    utils::{
60        EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
61        SqliteKeyValueStoreConnExt, repeat_vars,
62    },
63};
64
/// File name of the SQLite database that backs the crypto store; it is passed
/// to `build_pool_of_connections` when the store is opened.
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
67
68/// An SQLite-based crypto store.
69#[derive(Clone)]
70pub struct SqliteCryptoStore {
71    store_cipher: Option<Arc<StoreCipher>>,
72
73    /// `Some` when active, `None` when closed.
74    /// The outer `Mutex` serialises close/reopen with connection access.
75    connections: Arc<Mutex<Option<SqliteConnections>>>,
76
77    /// Retained so we can rebuild the pool on reopen.
78    db_path: PathBuf,
79
80    /// Retained so we can rebuild the pool on reopen.
81    pool_config: PoolConfig,
82
83    /// Retained so we can re-apply runtime config on reopen.
84    runtime_config: RuntimeConfig,
85
86    // DB values cached in memory
87    static_account: Arc<RwLock<Option<StaticAccountData>>>,
88    save_changes_lock: Arc<Mutex<()>>,
89}
90
91#[cfg(not(tarpaulin_include))]
92impl fmt::Debug for SqliteCryptoStore {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
95    }
96}
97
98impl EncryptableStore for SqliteCryptoStore {
99    fn get_cypher(&self) -> Option<&StoreCipher> {
100        self.store_cipher.as_deref()
101    }
102}
103
104impl SqliteCryptoStore {
105    /// Create an `SqliteCryptoStore` struct without trying to create the
106    /// database or migrate to a newer version.  This is only for use
107    /// internally, and for testing.
108    ///
109    /// # Arguments
110    ///
111    /// * `secret` - The secret used to encrypt the data.
112    ///
113    /// * `pool` - A connection pool to use for reading from the store.
114    ///
115    /// * `conn` - The connection to use for writing to the store.
116    pub(crate) async fn create_raw(
117        secret: Option<Secret>,
118        pool: SqlitePool,
119        conn: SqliteAsyncConn,
120        pool_config: PoolConfig,
121        runtime_config: RuntimeConfig,
122    ) -> Result<Self, OpenStoreError> {
123        let store_cipher = match secret {
124            Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s).await?)),
125            None => None,
126        };
127
128        let db_path = pool.manager().database_path.clone();
129
130        Ok(Self {
131            store_cipher,
132            connections: Arc::new(Mutex::new(Some(SqliteConnections {
133                pool,
134                write_connection: Arc::new(Mutex::new(conn)),
135            }))),
136            db_path,
137            pool_config,
138            runtime_config,
139            static_account: Arc::new(RwLock::new(None)),
140            save_changes_lock: Default::default(),
141        })
142    }
143
144    /// Open the SQLite-based crypto store at the given path using the given
145    /// passphrase to encrypt private data.
146    pub async fn open(
147        path: impl AsRef<Path>,
148        passphrase: Option<&str>,
149    ) -> Result<Self, OpenStoreError> {
150        Self::open_with_config(&SqliteStoreConfig::new(path).passphrase(passphrase)).await
151    }
152
153    /// Open the SQLite-based crypto store at the given path using the given
154    /// key to encrypt private data.
155    pub async fn open_with_key(
156        path: impl AsRef<Path>,
157        key: Option<&[u8; 32]>,
158    ) -> Result<Self, OpenStoreError> {
159        Self::open_with_config(&SqliteStoreConfig::new(path).key(key)).await
160    }
161
162    /// Open the SQLite-based crypto store with the config open config.
163    pub async fn open_with_config(config: &SqliteStoreConfig) -> Result<Self, OpenStoreError> {
164        fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir)?;
165
166        let pool = config.build_pool_of_connections(DATABASE_NAME)?;
167        let pool_config = config.pool_config();
168        let runtime_config = config.runtime_config();
169
170        let this =
171            Self::open_with_pool(pool, config.secret.clone(), pool_config, runtime_config).await?;
172        this.read().await?.apply_runtime_config(runtime_config).await?;
173
174        Ok(this)
175    }
176
177    /// Create an SQLite-based crypto store using the given SQLite database
178    /// pool. The given secret will be used to encrypt private data.
179    async fn open_with_pool(
180        pool: SqlitePool,
181        secret: Option<Secret>,
182        pool_config: PoolConfig,
183        runtime_config: RuntimeConfig,
184    ) -> Result<Self, OpenStoreError> {
185        let conn = pool.get().await?;
186
187        let version = conn.db_version().await?;
188        debug!("Opened sqlite store with version {}", version);
189
190        let version = initialize_store(&conn, version).await?;
191
192        let store = Self::create_raw(secret, pool, conn, pool_config, runtime_config).await?;
193
194        run_migrations(&store, version, None).await?;
195
196        store.write().await?.wal_checkpoint().await;
197
198        Ok(store)
199    }
200
201    fn deserialize_and_unpickle_inbound_group_session(
202        &self,
203        value: Vec<u8>,
204        backed_up: bool,
205    ) -> Result<InboundGroupSession> {
206        let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
207
208        // The `backed_up` SQL column is the source of truth, because we update it
209        // inside `mark_inbound_group_sessions_as_backed_up` and don't update
210        // the pickled value inside the `data` column (until now, when we are puling it
211        // out of the DB).
212        pickle.backed_up = backed_up;
213
214        Ok(InboundGroupSession::from_pickle(pickle)?)
215    }
216
217    fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
218        let mut request: GossipRequest = self.deserialize_value(value)?;
219        // sent_out SQL column is source of truth, sent_out field in serialized value
220        // needed for other stores though
221        request.sent_out = sent_out;
222        Ok(request)
223    }
224
225    fn get_static_account(&self) -> Option<StaticAccountData> {
226        self.static_account.read().unwrap().clone()
227    }
228
229    /// Acquire a connection for executing read operations.
230    #[instrument(skip_all)]
231    async fn read(&self) -> Result<SqliteAsyncConn> {
232        let pool = {
233            let guard = self.connections.lock().await;
234            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
235            conns.pool.clone()
236        };
237        Ok(pool.get().await?)
238    }
239
240    /// Acquire a connection for executing write operations.
241    #[instrument(skip_all)]
242    pub(crate) async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
243        let write_connection = {
244            let guard = self.connections.lock().await;
245            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
246            conns.write_connection.clone()
247        };
248        Ok(write_connection.lock_owned().await)
249    }
250}
251
/// Schema version the database is at once every migration in
/// `run_migrations` has been applied.
///
/// This must match the last `set_db_version` call in `run_migrations`
/// (currently version 18); `initialize_store` compares the on-disk version
/// against it to decide whether the database needs upgrading.
const DATABASE_VERSION: u8 = 18;

/// Key for the dehydrated device pickle key in the key/value table.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
256
257/// Initialize the database to version 1
258///
259/// This must be done before creating the store cipher, because the store cipher
260/// requires the `kv` table.
261///
262/// # Arguments
263///
264/// * `conn` - The connection to use.
265///
266/// * `version` - the current version of the database.
267pub(crate) async fn initialize_store(conn: &SqliteAsyncConn, version: u8) -> Result<u8> {
268    if version == 0 {
269        debug!("Creating database");
270    } else if version < DATABASE_VERSION {
271        debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
272    } else {
273        return Ok(version);
274    }
275
276    if version < 1 {
277        debug!("Creating database");
278        // First turn on WAL mode, this can't be done in the transaction, it fails with
279        // the error message: "cannot change into wal mode from within a transaction".
280        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
281        conn.with_transaction(|txn| {
282            txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
283            txn.set_db_version(1)
284        })
285        .await?;
286        return Ok(1);
287    }
288
289    Ok(version)
290}
291
292/// Run migrations for the given version of the database.
293///
294/// # Arguments
295///
296/// * `store` - The store to run the migrations on
297///
298/// * `version` - The current version of the database.
299///
300/// * `max_version` - The maximum version that the database will be migrated to.
301///   Only used for testing, so will only be checked for the versions that are
302///   needed for tests.
303pub(crate) async fn run_migrations(
304    store: &SqliteCryptoStore,
305    version: u8,
306    max_version: Option<u8>,
307) -> Result<()> {
308    let conn = store.write().await?;
309
310    if version < 2 {
311        debug!("Upgrading database to version 2");
312        conn.with_transaction(|txn| {
313            txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
314            txn.set_db_version(2)
315        })
316        .await?;
317    }
318
319    if version < 3 {
320        debug!("Upgrading database to version 3");
321        conn.with_transaction(|txn| {
322            txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
323            txn.set_db_version(3)
324        })
325        .await?;
326    }
327
328    if version < 4 {
329        debug!("Upgrading database to version 4");
330        conn.with_transaction(|txn| {
331            txn.execute_batch(include_str!(
332                "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
333            ))?;
334            txn.set_db_version(4)
335        })
336        .await?;
337    }
338
339    if version < 5 {
340        debug!("Upgrading database to version 5");
341        conn.with_transaction(|txn| {
342            txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
343            txn.set_db_version(5)
344        })
345        .await?;
346    }
347
348    if version < 6 {
349        debug!("Upgrading database to version 6");
350        conn.with_transaction(|txn| {
351            txn.execute_batch(include_str!(
352                "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
353            ))?;
354            txn.set_db_version(6)
355        })
356        .await?;
357    }
358
359    if version < 7 {
360        debug!("Upgrading database to version 7");
361        conn.with_transaction(|txn| {
362            txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
363            txn.set_db_version(7)
364        })
365        .await?;
366    }
367
368    if version < 8 {
369        debug!("Upgrading database to version 8");
370        conn.with_transaction(|txn| {
371            txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
372            txn.set_db_version(8)
373        })
374        .await?;
375    }
376
377    if version < 9 {
378        debug!("Upgrading database to version 9");
379        conn.with_transaction(|txn| {
380            txn.execute_batch(include_str!(
381                "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
382            ))?;
383            txn.set_db_version(9)
384        })
385        .await?;
386    }
387
388    if version < 10 {
389        debug!("Upgrading database to version 10");
390        conn.with_transaction(|txn| {
391            txn.execute_batch(include_str!(
392                "../migrations/crypto_store/010_received_room_key_bundles.sql"
393            ))?;
394            txn.set_db_version(10)
395        })
396        .await?;
397    }
398
399    if version < 11 {
400        debug!("Upgrading database to version 11");
401        conn.with_transaction(|txn| {
402            txn.execute_batch(include_str!(
403                "../migrations/crypto_store/011_received_room_key_bundles_with_curve_key.sql"
404            ))?;
405            txn.set_db_version(11)
406        })
407        .await?;
408    }
409
410    if version < 12 {
411        debug!("Upgrading database to version 12");
412        conn.with_transaction(|txn| {
413            txn.execute_batch(include_str!(
414                "../migrations/crypto_store/012_withheld_code_by_room.sql"
415            ))?;
416            txn.set_db_version(12)
417        })
418        .await?;
419    }
420
421    if version < 13 {
422        debug!("Upgrading database to version 13");
423        conn.with_transaction(|txn| {
424            txn.execute_batch(include_str!(
425                "../migrations/crypto_store/013_lease_locks_with_generation.sql"
426            ))?;
427            txn.set_db_version(13)
428        })
429        .await?;
430    }
431
432    if version < 14 {
433        debug!("Upgrading database to version 14");
434        conn.with_transaction(|txn| {
435            txn.execute_batch(include_str!(
436                "../migrations/crypto_store/014_room_key_backups_fully_downloaded.sql"
437            ))?;
438            txn.set_db_version(14)
439        })
440        .await?;
441    }
442
443    if version < 15 {
444        debug!("Upgrading database to version 15");
445        conn.with_transaction(|txn| {
446            txn.execute_batch(include_str!(
447                "../migrations/crypto_store/015_rooms_pending_key_bundle.sql"
448            ))?;
449            txn.set_db_version(15)
450        })
451        .await?;
452    }
453
454    if version < 16 {
455        debug!("Upgrading database to version 16");
456        conn.with_transaction(|txn| {
457            txn.execute_batch(include_str!(
458                "../migrations/crypto_store/016_remove_old_generation_counter.sql"
459            ))?;
460            txn.set_db_version(16)
461        })
462        .await?;
463    }
464
465    if max_version.is_some_and(|max_version| max_version < 17) {
466        return Ok(());
467    }
468
469    if version < 17 {
470        debug!("Upgrading database to version 17");
471        let store = store.clone();
472        conn.with_transaction(move |txn| {
473            txn.execute_batch(include_str!(
474                "../migrations/crypto_store/017_add_new_secrets_inbox.sql"
475            ))?;
476            let mut select_query = txn.prepare("SELECT data FROM secrets")?;
477            let mut secrets = select_query.query([])?;
478            let mut insert_query = txn.prepare(
479                "INSERT OR IGNORE INTO secrets_inbox (secret_name, secret)
480            VALUES (?1, ?2)",
481            )?;
482            while let Some(row) = secrets.next()? {
483                let Ok(secret) =
484                    store.deserialize_json::<GossippedSecret>(row.get::<_, Vec<u8>>(0)?.as_ref())
485                else {
486                    continue;
487                };
488                let Ok(encoded_secret) = store.serialize_json(&secret.event.content.secret) else {
489                    continue;
490                };
491                insert_query.execute((
492                    store.encode_key("secrets_inbox", secret.secret_name.to_string()),
493                    &encoded_secret,
494                ))?;
495            }
496            txn.execute_batch(include_str!(
497                "../migrations/crypto_store/017_drop_old_secrets_inbox.sql"
498            ))?;
499            txn.set_db_version(17)
500        })
501        .await?;
502    }
503
504    if version < 18 {
505        debug!("Upgrading database to version 18");
506        let store = store.clone();
507        conn.with_transaction(move |txn| {
508            txn.execute_batch(include_str!(
509                "../migrations/crypto_store/018_add_gossip_request_info.sql"
510            ))?;
511            let mut select_query =
512                txn.prepare("SELECT request_id, sent_out, data FROM key_requests")?;
513            let mut requests = select_query.query([])?;
514            let mut update_query =
515                txn.prepare("UPDATE OR REPLACE key_requests SET info = ?1 WHERE request_id = ?2")?;
516            while let Some(row) = requests.next()? {
517                let Ok(request) = store.deserialize_key_request(
518                    row.get::<_, Vec<u8>>(2)?.as_ref(),
519                    row.get::<_, bool>(1)?,
520                ) else {
521                    continue;
522                };
523                let info = store.encode_key("key_requests", request.info.as_key());
524                update_query.execute((info, row.get::<_, Vec<u8>>(0)?))?;
525            }
526            txn.set_db_version(18)
527        })
528        .await?;
529    }
530
531    Ok(())
532}
533
534trait SqliteConnectionExt {
535    fn set_session(
536        &self,
537        session_id: &[u8],
538        sender_key: &[u8],
539        data: &[u8],
540    ) -> rusqlite::Result<()>;
541
542    fn set_inbound_group_session(
543        &self,
544        room_id: &[u8],
545        session_id: &[u8],
546        data: &[u8],
547        backed_up: bool,
548        sender_key: Option<&[u8]>,
549        sender_data_type: Option<u8>,
550    ) -> rusqlite::Result<()>;
551
552    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
553
554    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
555    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
556
557    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
558
559    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
560
561    fn set_key_request(
562        &self,
563        request_id: &[u8],
564        sent_out: bool,
565        data: &[u8],
566        info: &[u8],
567    ) -> rusqlite::Result<()>;
568
569    fn set_direct_withheld(
570        &self,
571        session_id: &[u8],
572        room_id: &[u8],
573        data: &[u8],
574    ) -> rusqlite::Result<()>;
575
576    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
577
578    fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
579
580    fn set_received_room_key_bundle(
581        &self,
582        room_id: &[u8],
583        user_id: &[u8],
584        data: &[u8],
585    ) -> rusqlite::Result<()>;
586
587    fn set_has_downloaded_all_room_keys(&self, room_id: &[u8]) -> rusqlite::Result<()>;
588
589    fn set_room_pending_key_bundle(
590        &self,
591        room_id: &[u8],
592        details: Option<&[u8]>,
593    ) -> rusqlite::Result<()>;
594}
595
596impl SqliteConnectionExt for rusqlite::Connection {
597    fn set_session(
598        &self,
599        session_id: &[u8],
600        sender_key: &[u8],
601        data: &[u8],
602    ) -> rusqlite::Result<()> {
603        self.execute(
604            "INSERT INTO session (session_id, sender_key, data)
605             VALUES (?1, ?2, ?3)
606             ON CONFLICT (session_id) DO UPDATE SET data = ?3",
607            (session_id, sender_key, data),
608        )?;
609        Ok(())
610    }
611
612    fn set_inbound_group_session(
613        &self,
614        room_id: &[u8],
615        session_id: &[u8],
616        data: &[u8],
617        backed_up: bool,
618        sender_key: Option<&[u8]>,
619        sender_data_type: Option<u8>,
620    ) -> rusqlite::Result<()> {
621        self.execute(
622            "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
623             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
624             ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
625            (session_id, room_id, data, backed_up, sender_key, sender_data_type),
626        )?;
627        Ok(())
628    }
629
630    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
631        self.execute(
632            "INSERT INTO outbound_group_session (room_id, data) \
633             VALUES (?1, ?2)
634             ON CONFLICT (room_id) DO UPDATE SET data = ?2",
635            (room_id, data),
636        )?;
637        Ok(())
638    }
639
640    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
641        self.execute(
642            "INSERT INTO device (user_id, device_id, data) \
643             VALUES (?1, ?2, ?3)
644             ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
645            (user_id, device_id, data),
646        )?;
647        Ok(())
648    }
649
650    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
651        self.execute(
652            "DELETE FROM device WHERE user_id = ? AND device_id = ?",
653            (user_id, device_id),
654        )?;
655        Ok(())
656    }
657
658    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
659        self.execute(
660            "INSERT INTO identity (user_id, data) \
661             VALUES (?1, ?2)
662             ON CONFLICT (user_id) DO UPDATE SET data = ?2",
663            (user_id, data),
664        )?;
665        Ok(())
666    }
667
668    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
669        self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
670        Ok(())
671    }
672
673    fn set_key_request(
674        &self,
675        request_id: &[u8],
676        sent_out: bool,
677        data: &[u8],
678        info: &[u8],
679    ) -> rusqlite::Result<()> {
680        // The first `ON CONFLICT` cause will update a request if we try to save
681        // it again.  The second `ON CONFLICT` will replace an old request for
682        // the same key/secret with the new request.
683        self.execute(
684            "INSERT INTO key_requests (request_id, sent_out, data, info)
685            VALUES (?1, ?2, ?3, ?4)
686            ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3, info = ?4
687            ON CONFLICT (info) DO UPDATE SET request_id = ?1, sent_out = ?2, data = ?3",
688            (request_id, sent_out, data, info),
689        )?;
690        Ok(())
691    }
692
693    fn set_direct_withheld(
694        &self,
695        session_id: &[u8],
696        room_id: &[u8],
697        data: &[u8],
698    ) -> rusqlite::Result<()> {
699        self.execute(
700            "INSERT INTO direct_withheld_info (session_id, room_id, data)
701            VALUES (?1, ?2, ?3)
702            ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
703            (session_id, room_id, data),
704        )?;
705        Ok(())
706    }
707
708    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
709        self.execute(
710            "INSERT INTO room_settings (room_id, data)
711            VALUES (?1, ?2)
712            ON CONFLICT (room_id) DO UPDATE SET data = ?2",
713            (room_id, data),
714        )?;
715        Ok(())
716    }
717
718    fn set_secret(&self, secret_name: &[u8], secret: &[u8]) -> rusqlite::Result<()> {
719        // Ignore duplicate values, since we may get set the same secret
720        // multiple times.
721        self.execute(
722            "INSERT OR IGNORE INTO secrets_inbox (secret_name, secret)
723            VALUES (?1, ?2)",
724            (secret_name, secret),
725        )?;
726
727        Ok(())
728    }
729
730    fn set_received_room_key_bundle(
731        &self,
732        room_id: &[u8],
733        sender_user_id: &[u8],
734        data: &[u8],
735    ) -> rusqlite::Result<()> {
736        self.execute(
737            "INSERT INTO received_room_key_bundle(room_id, sender_user_id, bundle_data)
738            VALUES (?1, ?2, ?3)
739            ON CONFLICT (room_id, sender_user_id) DO UPDATE SET bundle_data = ?3",
740            (room_id, sender_user_id, data),
741        )?;
742        Ok(())
743    }
744
745    fn set_room_pending_key_bundle(
746        &self,
747        room_id: &[u8],
748        data: Option<&[u8]>,
749    ) -> rusqlite::Result<()> {
750        if let Some(data) = data {
751            self.execute(
752                "INSERT INTO rooms_pending_key_bundle (room_id, data)
753                 VALUES (?1, ?2)
754                 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
755                (room_id, data),
756            )?;
757        } else {
758            self.execute("DELETE FROM rooms_pending_key_bundle WHERE room_id = ?1", (room_id,))?;
759        }
760        Ok(())
761    }
762
763    fn set_has_downloaded_all_room_keys(&self, room_id: &[u8]) -> rusqlite::Result<()> {
764        self.execute(
765            "INSERT INTO room_key_backups_fully_downloaded(room_id)
766             VALUES (?1)
767             ON CONFLICT(room_id) DO NOTHING",
768            (room_id,),
769        )?;
770        Ok(())
771    }
772}
773
774#[async_trait]
775trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
776    async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
777        Ok(self
778            .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
779                stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
780            })
781            .await?)
782    }
783
784    async fn get_inbound_group_session(
785        &self,
786        session_id: Key,
787    ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
788        Ok(self
789            .query_row(
790                "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
791                (session_id,),
792                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
793            )
794            .await
795            .optional()?)
796    }
797
798    async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
799        Ok(self
800            .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
801                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
802            })
803            .await?)
804    }
805
806    async fn get_inbound_group_session_counts(
807        &self,
808        _backup_version: Option<&str>,
809    ) -> Result<RoomKeyCounts> {
810        let total = self
811            .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
812            .await?;
813        let backed_up = self
814            .query_row(
815                "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
816                (),
817                |row| row.get(0),
818            )
819            .await?;
820        Ok(RoomKeyCounts { total, backed_up })
821    }
822
823    async fn get_inbound_group_sessions_by_room_id(
824        &self,
825        room_id: Key,
826    ) -> Result<Vec<(Vec<u8>, bool)>> {
827        Ok(self
828            .prepare(
829                "SELECT data, backed_up FROM inbound_group_session WHERE room_id = :room_id",
830                move |mut stmt| {
831                    stmt.query(named_params! {
832                        ":room_id": room_id,
833                    })?
834                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
835                    .collect()
836                },
837            )
838            .await?)
839    }
840
841    async fn get_inbound_group_sessions_for_device_batch(
842        &self,
843        sender_key: Key,
844        sender_data_type: SenderDataType,
845        after_session_id: Option<Key>,
846        limit: usize,
847    ) -> Result<Vec<(Vec<u8>, bool)>> {
848        Ok(self
849            .prepare(
850                "
851                SELECT data, backed_up
852                FROM inbound_group_session
853                WHERE sender_key = :sender_key
854                    AND sender_data_type = :sender_data_type
855                    AND session_id > :after_session_id
856                ORDER BY session_id
857                LIMIT :limit
858                ",
859                move |mut stmt| {
860                    let sender_data_type = sender_data_type as u8;
861
862                    // If we are not provided with an `after_session_id`, use a key which will sort
863                    // before all real keys: the empty string.
864                    let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
865
866                    stmt.query(named_params! {
867                        ":sender_key": sender_key,
868                        ":sender_data_type": sender_data_type,
869                        ":after_session_id": after_session_id,
870                        ":limit": limit,
871                    })?
872                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
873                    .collect()
874                },
875            )
876            .await?)
877    }
878
879    async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
880        Ok(self
881            .prepare(
882                "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
883                move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
884            )
885            .await?)
886    }
887
888    async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
889        if session_ids.is_empty() {
890            // We are not expecting to be called with an empty list of sessions
891            warn!("No sessions to mark as backed up!");
892            return Ok(());
893        }
894
895        let session_ids_len = session_ids.len();
896
897        self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
898            // Safety: placeholders is not generated using any user input except the number
899            // of session IDs, so it is safe from injection.
900            let sql_params = repeat_vars(session_ids_len);
901            let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
902            txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
903            Ok(Vec::<()>::new())
904        }).await?;
905
906        Ok(())
907    }
908
909    async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
910        self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
911        Ok(())
912    }
913
914    async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
915        Ok(self
916            .query_row(
917                "SELECT data FROM outbound_group_session WHERE room_id = ?",
918                (room_id,),
919                |row| row.get(0),
920            )
921            .await
922            .optional()?)
923    }
924
925    async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
926        Ok(self
927            .query_row(
928                "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
929                (user_id, device_id),
930                |row| row.get(0),
931            )
932            .await
933            .optional()?)
934    }
935
936    async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
937        Ok(self
938            .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
939                stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
940            })
941            .await?)
942    }
943
944    async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
945        Ok(self
946            .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
947            .await
948            .optional()?)
949    }
950
951    async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
952        Ok(self
953            .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
954                row.get::<_, i32>(0)
955            })
956            .await?
957            > 0)
958    }
959
960    async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
961        Ok(self
962            .prepare("SELECT data FROM tracked_user", |mut stmt| {
963                stmt.query(())?.mapped(|row| row.get(0)).collect()
964            })
965            .await?)
966    }
967
968    async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
969        Ok(self
970            .prepare(
971                "INSERT INTO tracked_user (user_id, data) \
972                 VALUES (?1, ?2) \
973                 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
974                |mut stmt| {
975                    for (user_id, data) in users {
976                        stmt.execute((user_id, data))?;
977                    }
978
979                    Ok(())
980                },
981            )
982            .await?)
983    }
984
985    async fn get_outgoing_secret_request(
986        &self,
987        request_id: Key,
988    ) -> Result<Option<(Vec<u8>, bool)>> {
989        Ok(self
990            .query_row(
991                "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
992                (request_id,),
993                |row| Ok((row.get(0)?, row.get(1)?)),
994            )
995            .await
996            .optional()?)
997    }
998
999    async fn get_secret_request_by_info(&self, info: Key) -> Result<Option<(Vec<u8>, bool)>> {
1000        Ok(self
1001            .query_row("SELECT data, sent_out FROM key_requests WHERE info = ?", (info,), |row| {
1002                Ok((row.get(0)?, row.get(1)?))
1003            })
1004            .await
1005            .optional()?)
1006    }
1007
1008    async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
1009        Ok(self
1010            .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
1011                stmt.query(())?.mapped(|row| row.get(0)).collect()
1012            })
1013            .await?)
1014    }
1015
1016    async fn delete_key_request(&self, request_id: Key) -> Result<()> {
1017        self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
1018        Ok(())
1019    }
1020
1021    async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
1022        Ok(self
1023            .prepare("SELECT secret FROM secrets_inbox WHERE secret_name = ?", |mut stmt| {
1024                stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
1025            })
1026            .await?)
1027    }
1028
1029    async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
1030        self.execute("DELETE FROM secrets_inbox WHERE secret_name = ?", (secret_name,)).await?;
1031        Ok(())
1032    }
1033
1034    async fn get_direct_withheld_info(
1035        &self,
1036        session_id: Key,
1037        room_id: Key,
1038    ) -> Result<Option<Vec<u8>>> {
1039        Ok(self
1040            .query_row(
1041                "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
1042                (session_id, room_id),
1043                |row| row.get(0),
1044            )
1045            .await
1046            .optional()?)
1047    }
1048
1049    async fn get_withheld_sessions_by_room_id(&self, room_id: Key) -> Result<Vec<Vec<u8>>> {
1050        Ok(self
1051            .prepare("SELECT data FROM direct_withheld_info WHERE room_id = ?1", |mut stmt| {
1052                stmt.query((room_id,))?.mapped(|row| row.get(0)).collect()
1053            })
1054            .await?)
1055    }
1056
1057    async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
1058        Ok(self
1059            .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
1060                row.get(0)
1061            })
1062            .await
1063            .optional()?)
1064    }
1065
1066    async fn get_received_room_key_bundle(
1067        &self,
1068        room_id: Key,
1069        sender_user: Key,
1070    ) -> Result<Option<Vec<u8>>> {
1071        Ok(self
1072            .query_row(
1073                "SELECT bundle_data FROM received_room_key_bundle WHERE room_id = ? AND sender_user_id = ?",
1074                (room_id, sender_user),
1075                |row| { row.get(0) },
1076            )
1077            .await
1078            .optional()?)
1079    }
1080
1081    async fn get_room_pending_key_bundle(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
1082        Ok(self
1083            .query_row(
1084                "SELECT data FROM rooms_pending_key_bundle WHERE room_id = ?",
1085                (room_id,),
1086                |row| row.get(0),
1087            )
1088            .await
1089            .optional()?)
1090    }
1091
1092    async fn get_all_rooms_pending_key_bundle(&self) -> Result<Vec<Vec<u8>>> {
1093        Ok(self
1094            .query_many("SELECT data FROM rooms_pending_key_bundle", (), |row| row.get(0))
1095            .await?)
1096    }
1097
1098    async fn has_downloaded_all_room_keys(&self, room_id: Key) -> Result<bool> {
1099        Ok(self
1100            .query_row(
1101                "SELECT EXISTS (SELECT 1 FROM room_key_backups_fully_downloaded WHERE room_id = ?)",
1102                (room_id,),
1103                |row| row.get(0),
1104            )
1105            .await?)
1106    }
1107}
1108
1109#[async_trait]
1110impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
1111
1112#[async_trait]
1113impl CryptoStore for SqliteCryptoStore {
    /// Failures from this store are reported using the `Error` type in scope in this module.
    type Error = Error;
1115
1116    async fn load_account(&self) -> Result<Option<Account>> {
1117        let conn = self.read().await?;
1118        if let Some(pickle) = conn.get_kv("account").await? {
1119            let pickle = self.deserialize_value(&pickle)?;
1120
1121            let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
1122
1123            *self.static_account.write().unwrap() = Some(account.static_data().clone());
1124
1125            Ok(Some(account))
1126        } else {
1127            Ok(None)
1128        }
1129    }
1130
1131    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
1132        let conn = self.read().await?;
1133        if let Some(i) = conn.get_kv("identity").await? {
1134            let pickle = self.deserialize_value(&i)?;
1135            Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
1136        } else {
1137            Ok(None)
1138        }
1139    }
1140
1141    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
1142        // Serialize calls to `save_pending_changes`; there are multiple await points
1143        // below, and we're pickling data as we go, so we don't want to
1144        // invalidate data we've previously read and overwrite it in the store.
1145        // TODO: #2000 should make this lock go away, or change its shape.
1146        let _guard = self.save_changes_lock.lock().await;
1147
1148        let pickled_account = if let Some(account) = changes.account {
1149            *self.static_account.write().unwrap() = Some(account.static_data().clone());
1150            Some(account.pickle())
1151        } else {
1152            None
1153        };
1154
1155        let this = self.clone();
1156        self.write()
1157            .await?
1158            .with_transaction(move |txn| {
1159                if let Some(pickled_account) = pickled_account {
1160                    let serialized_account = this.serialize_value(&pickled_account)?;
1161                    txn.set_kv("account", &serialized_account)?;
1162                }
1163
1164                Ok::<_, Error>(())
1165            })
1166            .await?;
1167
1168        Ok(())
1169    }
1170
    /// Persist the contents of `changes` to the database.
    ///
    /// Everything that needs async pickling is pickled first: the private cross-signing
    /// identity and all Olm, inbound and outbound group sessions. All writes then happen
    /// inside a single `with_transaction` call on a write connection. That is presumably one
    /// SQLite transaction, so either all of the writes below land or none do (TODO: confirm
    /// against `with_transaction`). Lookup columns hold `encode_key` output tagged with the
    /// table name. Values mostly go through `serialize_value`; a few use `serialize_json`.
    async fn save_changes(&self, changes: Changes) -> Result<()> {
        // Serialize calls to `save_changes`; there are multiple await points below, and
        // we're pickling data as we go, so we don't want to invalidate data
        // we've previously read and overwrite it in the store.
        // TODO: #2000 should make this lock go away, or change its shape.
        let _guard = self.save_changes_lock.lock().await;

        // Pickling is async, so all of it happens here, before entering the (non-async)
        // transaction closure further down.
        let pickled_private_identity =
            if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };

        // (session_id, sender_key, pickle) for each Olm session.
        let mut session_changes = Vec::new();

        for session in changes.sessions {
            let session_id = self.encode_key("session", session.session_id());
            let sender_key = self.encode_key("session", session.sender_key().to_base64());
            let pickle = session.pickle().await;
            session_changes.push((session_id, sender_key, pickle));
        }

        // (room_id, session_id, pickle, sender_key) for each inbound group session.
        let mut inbound_session_changes = Vec::new();
        for session in changes.inbound_group_sessions {
            let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
            let session_id = self.encode_key("inbound_group_session", session.session_id());
            let pickle = session.pickle().await;
            let sender_key =
                self.encode_key("inbound_group_session", session.sender_key().to_base64());
            inbound_session_changes.push((room_id, session_id, pickle, sender_key));
        }

        // (room_id, pickle) for each outbound group session.
        let mut outbound_session_changes = Vec::new();
        for session in changes.outbound_group_sessions {
            let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
            let pickle = session.pickle().await;
            outbound_session_changes.push((room_id, pickle));
        }

        // The transaction closure is `move`, so it takes its own clone of the store rather
        // than borrowing `self`.
        let this = self.clone();
        self.write()
            .await?
            .with_transaction(move |txn| {
                // Single-value entries in the key/value table, each written only when the
                // corresponding change is present.
                if let Some(pickled_private_identity) = &pickled_private_identity {
                    let serialized_private_identity =
                        this.serialize_value(pickled_private_identity)?;
                    txn.set_kv("identity", &serialized_private_identity)?;
                }

                if let Some(token) = &changes.next_batch_token {
                    let serialized_token = this.serialize_value(token)?;
                    txn.set_kv("next_batch_token", &serialized_token)?;
                }

                if let Some(decryption_key) = &changes.backup_decryption_key {
                    let serialized_decryption_key = this.serialize_value(decryption_key)?;
                    txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
                }

                if let Some(backup_version) = &changes.backup_version {
                    let serialized_backup_version = this.serialize_value(backup_version)?;
                    txn.set_kv("backup_version_v1", &serialized_backup_version)?;
                }

                if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
                    let serialized_pickle_key = this.serialize_value(pickle_key)?;
                    txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
                }

                // New and changed devices are written the same way.
                for device in changes.devices.new.iter().chain(&changes.devices.changed) {
                    let user_id = this.encode_key("device", device.user_id().as_bytes());
                    let device_id = this.encode_key("device", device.device_id().as_bytes());
                    let data = this.serialize_value(&device)?;
                    txn.set_device(&user_id, &device_id, &data)?;
                }

                // Devices that have been removed.
                for device in &changes.devices.deleted {
                    let user_id = this.encode_key("device", device.user_id().as_bytes());
                    let device_id = this.encode_key("device", device.device_id().as_bytes());
                    txn.delete_device(&user_id, &device_id)?;
                }

                // Changed and new user identities are written the same way.
                for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
                    let user_id = this.encode_key("identity", identity.user_id().as_bytes());
                    let data = this.serialize_value(&identity)?;
                    txn.set_identity(&user_id, &data)?;
                }

                for (session_id, sender_key, pickle) in &session_changes {
                    let serialized_session = this.serialize_value(&pickle)?;
                    txn.set_session(session_id, sender_key, &serialized_session)?;
                }

                // Besides the serialized pickle, inbound group sessions also pass the backup
                // flag, the encoded sender key and the sender-data type discriminant, which
                // the query helpers filter on.
                for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
                    let serialized_session = this.serialize_value(&pickle)?;
                    txn.set_inbound_group_session(
                        room_id,
                        session_id,
                        &serialized_session,
                        pickle.backed_up,
                        Some(sender_key),
                        Some(pickle.sender_data.to_type() as u8),
                    )?;
                }

                // Outbound group sessions are stored as JSON.
                for (room_id, pickle) in &outbound_session_changes {
                    let serialized_session = this.serialize_json(&pickle)?;
                    txn.set_outbound_group_session(room_id, &serialized_session)?;
                }

                // Olm message hashes are MessagePack-encoded with `rmp_serde` directly, not
                // via `serialize_value`.
                for hash in &changes.message_hashes {
                    let hash = rmp_serde::to_vec(hash)?;
                    txn.add_olm_hash(&hash)?;
                }

                // Key requests are keyed by request ID and also record their `SecretInfo` key,
                // so they can be found either way.
                for request in changes.key_requests {
                    let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
                    let serialized_request = this.serialize_value(&request)?;
                    let serialized_info = this.encode_key("key_requests", request.info.as_key());
                    txn.set_key_request(
                        &request_id,
                        request.sent_out,
                        &serialized_request,
                        &serialized_info,
                    )?;
                }

                // One row per (session, room) pair of withheld notices, stored as JSON.
                for (room_id, data) in changes.withheld_session_info {
                    for (session_id, event) in data {
                        let session_id = this.encode_key("direct_withheld_info", session_id);
                        let room_id = this.encode_key("direct_withheld_info", &room_id);
                        let serialized_info = this.serialize_json(&event)?;
                        txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
                    }
                }

                for (room_id, settings) in changes.room_settings {
                    let room_id = this.encode_key("room_settings", room_id.as_bytes());
                    let value = this.serialize_value(&settings)?;
                    txn.set_room_settings(&room_id, &value)?;
                }

                // Inbox secrets are stored as JSON under their encoded secret name.
                for secret in changes.secrets {
                    let secret_name =
                        this.encode_key("secrets_inbox", secret.secret_name.to_string());
                    let value = this.serialize_json(secret.secret.deref())?;
                    txn.set_secret(&secret_name, &value)?;
                }

                for bundle in changes.received_room_key_bundles {
                    let room_id =
                        this.encode_key("received_room_key_bundle", &bundle.bundle_data.room_id);
                    let user_id = this.encode_key("received_room_key_bundle", &bundle.sender_user);
                    let value = this.serialize_value(&bundle)?;
                    txn.set_received_room_key_bundle(&room_id, &user_id, &value)?;
                }

                for room in changes.room_key_backups_fully_downloaded {
                    let room_id = this.encode_key("room_key_backups_fully_downloaded", &room);
                    txn.set_has_downloaded_all_room_keys(&room_id)?;
                }

                // `None` details are passed through as `None`. Presumably that clears the
                // room's entry (TODO: confirm in `set_room_pending_key_bundle`).
                for (room, details) in changes.rooms_pending_key_bundle {
                    let room_id = this.encode_key("rooms_pending_key_bundle", &room);
                    let value = details.as_ref().map(|d| this.serialize_value(d)).transpose()?;
                    txn.set_room_pending_key_bundle(&room_id, value.as_deref())?;
                }

                Ok::<_, Error>(())
            })
            .await?;

        Ok(())
    }
1342
1343    async fn save_inbound_group_sessions(
1344        &self,
1345        sessions: Vec<InboundGroupSession>,
1346        backed_up_to_version: Option<&str>,
1347    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
1348        // Sanity-check that the data in the sessions corresponds to backed_up_version
1349        sessions.iter().for_each(|s| {
1350            let backed_up = s.backed_up();
1351            if backed_up != backed_up_to_version.is_some() {
1352                warn!(
1353                    backed_up,
1354                    backed_up_to_version,
1355                    "Session backed-up flag does not correspond to backup version setting",
1356                );
1357            }
1358        });
1359
1360        // Currently, this store doesn't save the backup version separately, so this
1361        // just delegates to save_changes.
1362        self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
1363    }
1364
1365    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
1366        let device_keys = self.get_own_device().await?.as_device_keys().clone();
1367
1368        let sessions: Vec<_> = self
1369            .read()
1370            .await?
1371            .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
1372            .await?
1373            .into_iter()
1374            .map(|bytes| {
1375                let pickle = self.deserialize_value(&bytes)?;
1376                Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
1377            })
1378            .collect::<Result<_>>()?;
1379
1380        if sessions.is_empty() { Ok(None) } else { Ok(Some(sessions)) }
1381    }
1382
1383    #[instrument(skip(self))]
1384    async fn get_inbound_group_session(
1385        &self,
1386        room_id: &RoomId,
1387        session_id: &str,
1388    ) -> Result<Option<InboundGroupSession>> {
1389        let session_id = self.encode_key("inbound_group_session", session_id);
1390        let Some((room_id_from_db, value, backed_up)) =
1391            self.read().await?.get_inbound_group_session(session_id).await?
1392        else {
1393            return Ok(None);
1394        };
1395
1396        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1397        if *room_id != room_id_from_db {
1398            warn!("expected room_id for session_id doesn't match what's in the DB");
1399            return Ok(None);
1400        }
1401
1402        Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1403    }
1404
1405    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1406        self.read()
1407            .await?
1408            .get_inbound_group_sessions()
1409            .await?
1410            .into_iter()
1411            .map(|(value, backed_up)| {
1412                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1413            })
1414            .collect()
1415    }
1416
1417    async fn get_inbound_group_sessions_by_room_id(
1418        &self,
1419        room_id: &RoomId,
1420    ) -> Result<Vec<InboundGroupSession>> {
1421        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1422        self.read()
1423            .await?
1424            .get_inbound_group_sessions_by_room_id(room_id)
1425            .await?
1426            .into_iter()
1427            .map(|(value, backed_up)| {
1428                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1429            })
1430            .collect()
1431    }
1432
1433    async fn get_inbound_group_sessions_for_device_batch(
1434        &self,
1435        sender_key: Curve25519PublicKey,
1436        sender_data_type: SenderDataType,
1437        after_session_id: Option<String>,
1438        limit: usize,
1439    ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1440        let after_session_id =
1441            after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1442        let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1443
1444        self.read()
1445            .await?
1446            .get_inbound_group_sessions_for_device_batch(
1447                sender_key,
1448                sender_data_type,
1449                after_session_id,
1450                limit,
1451            )
1452            .await?
1453            .into_iter()
1454            .map(|(value, backed_up)| {
1455                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1456            })
1457            .collect()
1458    }
1459
1460    async fn inbound_group_session_counts(
1461        &self,
1462        backup_version: Option<&str>,
1463    ) -> Result<RoomKeyCounts> {
1464        Ok(self.read().await?.get_inbound_group_session_counts(backup_version).await?)
1465    }
1466
1467    async fn inbound_group_sessions_for_backup(
1468        &self,
1469        _backup_version: &str,
1470        limit: usize,
1471    ) -> Result<Vec<InboundGroupSession>> {
1472        self.read()
1473            .await?
1474            .get_inbound_group_sessions_for_backup(limit)
1475            .await?
1476            .into_iter()
1477            .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1478            .collect()
1479    }
1480
1481    async fn mark_inbound_group_sessions_as_backed_up(
1482        &self,
1483        _backup_version: &str,
1484        session_ids: &[(&RoomId, &str)],
1485    ) -> Result<()> {
1486        Ok(self
1487            .write()
1488            .await?
1489            .mark_inbound_group_sessions_as_backed_up(
1490                session_ids
1491                    .iter()
1492                    .map(|(_, s)| self.encode_key("inbound_group_session", s))
1493                    .collect(),
1494            )
1495            .await?)
1496    }
1497
1498    async fn reset_backup_state(&self) -> Result<()> {
1499        Ok(self.write().await?.reset_inbound_group_session_backup_state().await?)
1500    }
1501
1502    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1503        let conn = self.read().await?;
1504
1505        let backup_version = conn
1506            .get_kv("backup_version_v1")
1507            .await?
1508            .map(|value| self.deserialize_value(&value))
1509            .transpose()?;
1510
1511        let decryption_key = conn
1512            .get_kv("recovery_key_v1")
1513            .await?
1514            .map(|value| self.deserialize_value(&value))
1515            .transpose()?;
1516
1517        Ok(BackupKeys { backup_version, decryption_key })
1518    }
1519
1520    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1521        let conn = self.read().await?;
1522
1523        conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1524            .await?
1525            .map(|value| self.deserialize_value(&value))
1526            .transpose()
1527    }
1528
1529    async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1530        Ok(self.write().await?.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?)
1531    }
1532    async fn get_outbound_group_session(
1533        &self,
1534        room_id: &RoomId,
1535    ) -> Result<Option<OutboundGroupSession>> {
1536        let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1537        let Some(value) = self.read().await?.get_outbound_group_session(room_id).await? else {
1538            return Ok(None);
1539        };
1540
1541        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1542
1543        let pickle = self.deserialize_json(&value)?;
1544        let session = OutboundGroupSession::from_pickle(
1545            account_info.device_id,
1546            account_info.identity_keys,
1547            pickle,
1548        )
1549        .map_err(|_| Error::Unpickle)?;
1550
1551        return Ok(Some(session));
1552    }
1553
1554    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1555        self.read()
1556            .await?
1557            .get_tracked_users()
1558            .await?
1559            .iter()
1560            .map(|value| self.deserialize_value(value))
1561            .collect()
1562    }
1563
1564    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1565        let users: Vec<(Key, Vec<u8>)> = tracked_users
1566            .iter()
1567            .map(|(u, d)| {
1568                let user_id = self.encode_key("tracked_users", u.as_bytes());
1569                let data =
1570                    self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1571                Ok((user_id, data))
1572            })
1573            .collect::<Result<_>>()?;
1574
1575        Ok(self.write().await?.add_tracked_users(users).await?)
1576    }
1577
1578    async fn get_device(
1579        &self,
1580        user_id: &UserId,
1581        device_id: &DeviceId,
1582    ) -> Result<Option<DeviceData>> {
1583        let user_id = self.encode_key("device", user_id.as_bytes());
1584        let device_id = self.encode_key("device", device_id.as_bytes());
1585        Ok(self
1586            .read()
1587            .await?
1588            .get_device(user_id, device_id)
1589            .await?
1590            .map(|value| self.deserialize_value(&value))
1591            .transpose()?)
1592    }
1593
1594    async fn get_user_devices(
1595        &self,
1596        user_id: &UserId,
1597    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1598        let user_id = self.encode_key("device", user_id.as_bytes());
1599        self.read()
1600            .await?
1601            .get_user_devices(user_id)
1602            .await?
1603            .into_iter()
1604            .map(|value| {
1605                let device: DeviceData = self.deserialize_value(&value)?;
1606                Ok((device.device_id().to_owned(), device))
1607            })
1608            .collect()
1609    }
1610
1611    async fn get_own_device(&self) -> Result<DeviceData> {
1612        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1613
1614        Ok(self
1615            .get_device(&account_info.user_id, &account_info.device_id)
1616            .await?
1617            .expect("We should be able to find our own device."))
1618    }
1619
1620    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1621        let user_id = self.encode_key("identity", user_id.as_bytes());
1622        Ok(self
1623            .read()
1624            .await?
1625            .get_user_identity(user_id)
1626            .await?
1627            .map(|value| self.deserialize_value(&value))
1628            .transpose()?)
1629    }
1630
1631    async fn is_message_known(
1632        &self,
1633        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1634    ) -> Result<bool> {
1635        let value = rmp_serde::to_vec(message_hash)?;
1636        Ok(self.read().await?.has_olm_hash(value).await?)
1637    }
1638
1639    async fn get_outgoing_secret_requests(
1640        &self,
1641        request_id: &TransactionId,
1642    ) -> Result<Option<GossipRequest>> {
1643        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1644        Ok(self
1645            .read()
1646            .await?
1647            .get_outgoing_secret_request(request_id)
1648            .await?
1649            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1650            .transpose()?)
1651    }
1652
1653    async fn get_secret_request_by_info(
1654        &self,
1655        key_info: &SecretInfo,
1656    ) -> Result<Option<GossipRequest>> {
1657        let key_info = self.encode_key("key_requests", key_info.as_key());
1658        Ok(self
1659            .read()
1660            .await?
1661            .get_secret_request_by_info(key_info)
1662            .await?
1663            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1664            .transpose()?)
1665    }
1666
1667    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1668        self.read()
1669            .await?
1670            .get_unsent_secret_requests()
1671            .await?
1672            .iter()
1673            .map(|value| {
1674                let request = self.deserialize_key_request(value, false)?;
1675                Ok(request)
1676            })
1677            .collect()
1678    }
1679
1680    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1681        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1682        Ok(self.write().await?.delete_key_request(request_id).await?)
1683    }
1684
1685    async fn get_secrets_from_inbox(
1686        &self,
1687        secret_name: &SecretName,
1688    ) -> Result<Vec<Zeroizing<String>>> {
1689        let secret_name = self.encode_key("secrets_inbox", secret_name.to_string());
1690
1691        self.read()
1692            .await?
1693            .get_secrets_from_inbox(secret_name)
1694            .await?
1695            .into_iter()
1696            .map(|value| self.deserialize_json(value.as_ref()).map(|value: String| value.into()))
1697            .collect()
1698    }
1699
1700    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1701        let secret_name = self.encode_key("secrets_inbox", secret_name.to_string());
1702        self.write().await?.delete_secrets_from_inbox(secret_name).await
1703    }
1704
1705    async fn get_withheld_info(
1706        &self,
1707        room_id: &RoomId,
1708        session_id: &str,
1709    ) -> Result<Option<RoomKeyWithheldEntry>> {
1710        let room_id = self.encode_key("direct_withheld_info", room_id);
1711        let session_id = self.encode_key("direct_withheld_info", session_id);
1712
1713        self.read()
1714            .await?
1715            .get_direct_withheld_info(session_id, room_id)
1716            .await?
1717            .map(|value| {
1718                let info = self.deserialize_json::<RoomKeyWithheldEntry>(&value)?;
1719                Ok(info)
1720            })
1721            .transpose()
1722    }
1723
1724    async fn get_withheld_sessions_by_room_id(
1725        &self,
1726        room_id: &RoomId,
1727    ) -> matrix_sdk_crypto::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
1728        let room_id = self.encode_key("direct_withheld_info", room_id);
1729
1730        self.read()
1731            .await?
1732            .get_withheld_sessions_by_room_id(room_id)
1733            .await?
1734            .into_iter()
1735            .map(|value| self.deserialize_json(&value))
1736            .collect()
1737    }
1738
1739    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1740        let room_id = self.encode_key("room_settings", room_id.as_bytes());
1741        let Some(value) = self.read().await?.get_room_settings(room_id).await? else {
1742            return Ok(None);
1743        };
1744
1745        let settings = self.deserialize_value(&value)?;
1746
1747        return Ok(Some(settings));
1748    }
1749
1750    async fn get_received_room_key_bundle_data(
1751        &self,
1752        room_id: &RoomId,
1753        user_id: &UserId,
1754    ) -> Result<Option<StoredRoomKeyBundleData>> {
1755        let room_id = self.encode_key("received_room_key_bundle", room_id);
1756        let user_id = self.encode_key("received_room_key_bundle", user_id);
1757        self.read()
1758            .await?
1759            .get_received_room_key_bundle(room_id, user_id)
1760            .await?
1761            .map(|value| self.deserialize_value(&value))
1762            .transpose()
1763    }
1764
1765    async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
1766        let room_id = self.encode_key("room_key_backups_fully_downloaded", room_id);
1767        self.read().await?.has_downloaded_all_room_keys(room_id).await
1768    }
1769
1770    async fn get_pending_key_bundle_details_for_room(
1771        &self,
1772        room_id: &RoomId,
1773    ) -> Result<Option<RoomPendingKeyBundleDetails>> {
1774        let room_id = self.encode_key("rooms_pending_key_bundle", room_id.as_bytes());
1775        let Some(value) = self.read().await?.get_room_pending_key_bundle(room_id).await? else {
1776            return Ok(None);
1777        };
1778
1779        let details = self.deserialize_value(&value)?;
1780        Ok(Some(details))
1781    }
1782
1783    async fn get_all_rooms_pending_key_bundles(&self) -> Result<Vec<RoomPendingKeyBundleDetails>> {
1784        let details = self.read().await?.get_all_rooms_pending_key_bundle().await?;
1785        let room_ids = details
1786            .into_iter()
1787            .map(|value| self.deserialize_value(&value))
1788            .collect::<Result<_, _>>()?;
1789        Ok(room_ids)
1790    }
1791
1792    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1793        let Some(serialized) = self.read().await?.get_kv(key).await? else {
1794            return Ok(None);
1795        };
1796        let value = if let Some(cipher) = &self.store_cipher {
1797            let encrypted = rmp_serde::from_slice(&serialized)?;
1798            cipher.decrypt_value_data(encrypted)?
1799        } else {
1800            serialized
1801        };
1802
1803        Ok(Some(value))
1804    }
1805
1806    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1807        let serialized = if let Some(cipher) = &self.store_cipher {
1808            let encrypted = cipher.encrypt_value_data(value)?;
1809            rmp_serde::to_vec_named(&encrypted)?
1810        } else {
1811            value
1812        };
1813
1814        self.write().await?.set_kv(key, serialized).await?;
1815        Ok(())
1816    }
1817
1818    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1819        let key = key.to_owned();
1820        self.write()
1821            .await?
1822            .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1823            .await
1824            .unwrap()?;
1825        Ok(())
1826    }
1827
    /// Tries to acquire or renew the lease lock named `key` on behalf of
    /// `holder`, for `lease_duration_ms` milliseconds from now.
    ///
    /// Returns `Some(generation)` when the lock was inserted or updated.
    ///
    /// Returns `None` when the row for `key` belongs to a different holder whose
    /// lease has not expired yet. In that case the `WHERE` clause rejects the
    /// update, `RETURNING` produces no row, and `.optional()` turns that into
    /// `None`.
    #[instrument(skip(self))]
    async fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<Option<CrossProcessLockGeneration>> {
        // Owned copies, because they are moved into the transaction closure below.
        let key = key.to_owned();
        let holder = holder.to_owned();

        let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
        // u32 -> u64 is a widening cast, so it cannot truncate.
        let expiration = now + lease_duration_ms as u64;

        // Learn about the `excluded` keyword in https://sqlite.org/lang_upsert.html.
        //
        // The upsert behaves as follows:
        // - No row for `key`: a new row is inserted. `generation` is not listed in
        //   the INSERT, so it presumably takes the column default from the schema
        //   (not visible here).
        // - Row exists: it is updated only when the same holder is renewing or the
        //   current lease has expired (`expiration < ?4`, where ?4 is `now`).
        // - Inside DO UPDATE, bare column names (`holder`, `generation`) refer to
        //   the existing row. The generation is therefore kept on a renewal by the
        //   same holder and incremented when the lock changes hands.
        let generation = self
            .write()
            .await?
            .with_transaction(move |txn| {
                txn.query_row(
                    "INSERT INTO lease_locks (key, holder, expiration)
                    VALUES (?1, ?2, ?3)
                    ON CONFLICT (key)
                    DO
                        UPDATE SET
                            holder = excluded.holder,
                            expiration = excluded.expiration,
                            generation =
                                CASE holder
                                    WHEN excluded.holder THEN generation
                                    ELSE generation + 1
                                END
                        WHERE
                            holder = excluded.holder
                            OR expiration < ?4
                    RETURNING generation
                    ",
                    (key, holder, expiration, now),
                    |row| row.get(0),
                )
                .optional()
            })
            .await?;

        Ok(generation)
    }
1873
1874    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1875        let conn = self.read().await?;
1876        if let Some(token) = conn.get_kv("next_batch_token").await? {
1877            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1878            Ok(maybe_token)
1879        } else {
1880            Ok(None)
1881        }
1882    }
1883
1884    async fn close(&self) -> Result<()> {
1885        connection::close_connections(&self.connections, "Crypto store").await;
1886        Ok(())
1887    }
1888
1889    async fn reopen(&self) -> Result<()> {
1890        connection::reopen_connections(
1891            &self.connections,
1892            self.db_path.clone(),
1893            self.pool_config,
1894            self.runtime_config,
1895        )
1896        .await?;
1897        Ok(())
1898    }
1899
1900    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1901        Ok(Some(self.read().await?.get_db_size().await?))
1902    }
1903}
1904
1905#[cfg(test)]
1906mod tests {
1907    use std::{path::Path, sync::LazyLock};
1908
1909    use matrix_sdk_common::deserialized_responses::WithheldCode;
1910    use matrix_sdk_crypto::{
1911        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1912        store::CryptoStore,
1913    };
1914    use matrix_sdk_test::async_test;
1915    use ruma::{device_id, room_id, user_id};
1916    use similar_asserts::assert_eq;
1917    use tempfile::{TempDir, tempdir};
1918    use tokio::fs;
1919
1920    use super::SqliteCryptoStore;
1921    use crate::SqliteStoreConfig;
1922
1923    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
1924
    /// A store opened on a copied fixture database, bundled with the temporary
    /// directory holding that copy.
    struct TestDb {
        // Needs to be kept alive because the Drop implementation for TempDir deletes the
        // directory. When destructuring, bind it to a named variable such as `_dir`.
        // A bare `_` pattern does not bind, so the TempDir would be dropped at the
        // end of the `let` statement.
        _dir: TempDir,
        database: SqliteCryptoStore,
    }
1931
1932    fn copy_db(data_path: &str) -> TempDir {
1933        let db_name = super::DATABASE_NAME;
1934
1935        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1936        let database_path = manifest_path.join(data_path).join(db_name);
1937
1938        let tmpdir = tempdir().unwrap();
1939        let destination = tmpdir.path().join(db_name);
1940
1941        // Copy the test database to the tempdir so our test runs are idempotent.
1942        std::fs::copy(&database_path, destination).unwrap();
1943
1944        tmpdir
1945    }
1946
1947    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1948        let tmpdir = copy_db(data_path);
1949
1950        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1951            .await
1952            .expect("Can't open the test store");
1953
1954        TestDb { _dir: tmpdir, database }
1955    }
1956
1957    #[async_test]
1958    async fn test_pool_size() {
1959        let store_open_config =
1960            SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);
1961
1962        let store = SqliteCryptoStore::open_with_config(&store_open_config).await.unwrap();
1963
1964        let guard = store.connections.lock().await;
1965        let conns = guard.as_ref().unwrap();
1966        assert_eq!(conns.pool.status().max_size, 42);
1967    }
1968
1969    /// Test that we didn't regress in our storage layer by loading data from a
1970    /// pre-filled database, or in other words use a test vector for this.
1971    #[async_test]
1972    async fn test_open_test_vector_store() {
1973        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;
1974
1975        let account = database
1976            .load_account()
1977            .await
1978            .unwrap()
1979            .expect("The test database is prefilled with data, we should find an account");
1980
1981        let user_id = account.user_id();
1982        let device_id = account.device_id();
1983
1984        assert_eq!(
1985            user_id.as_str(),
1986            "@pjtest:synapse-oidc.element.dev",
1987            "The user ID should match to the one we expect."
1988        );
1989
1990        assert_eq!(
1991            device_id.as_str(),
1992            "v4TqgcuIH6",
1993            "The device ID should match to the one we expect."
1994        );
1995
1996        let device = database
1997            .get_device(user_id, device_id)
1998            .await
1999            .unwrap()
2000            .expect("Our own device should be found in the store.");
2001
2002        assert_eq!(device.device_id(), device_id);
2003        assert_eq!(device.user_id(), user_id);
2004
2005        assert_eq!(
2006            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
2007            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
2008        );
2009
2010        assert_eq!(
2011            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
2012            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
2013        );
2014
2015        let identity = database
2016            .get_user_identity(user_id)
2017            .await
2018            .unwrap()
2019            .expect("The store should contain an identity.");
2020
2021        assert_eq!(identity.user_id(), user_id);
2022
2023        let identity = identity
2024            .own()
2025            .expect("The identity should be of the correct type, it should be our own identity.");
2026
2027        let master_key = identity
2028            .master_key()
2029            .get_first_key()
2030            .expect("Our own identity should have a master key");
2031
2032        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
2033    }
2034
2035    /// Test that we didn't regress in our storage layer by loading data from a
2036    /// pre-filled database, or in other words use a test vector for this.
2037    #[async_test]
2038    async fn test_open_test_vector_encrypted_store() {
2039        let TestDb { _dir: _, database } = get_test_db(
2040            "testing/data/storage/alice",
2041            Some(concat!(
2042                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
2043                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
2044                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
2045                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
2046                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
2047            )),
2048        )
2049        .await;
2050
2051        let account = database
2052            .load_account()
2053            .await
2054            .unwrap()
2055            .expect("The test database is prefilled with data, we should find an account");
2056
2057        let user_id = account.user_id();
2058        let device_id = account.device_id();
2059
2060        assert_eq!(
2061            user_id.as_str(),
2062            "@alice:localhost",
2063            "The user ID should match to the one we expect."
2064        );
2065
2066        assert_eq!(
2067            device_id.as_str(),
2068            "JVVORTHFXY",
2069            "The device ID should match to the one we expect."
2070        );
2071
2072        let tracked_users =
2073            database.load_tracked_users().await.expect("Should be tracking some users");
2074
2075        assert_eq!(tracked_users.len(), 6);
2076
2077        let known_users = vec![
2078            user_id!("@alice:localhost"),
2079            user_id!("@dehydration3:localhost"),
2080            user_id!("@eve:localhost"),
2081            user_id!("@bob:localhost"),
2082            user_id!("@malo:localhost"),
2083            user_id!("@carl:localhost"),
2084        ];
2085
2086        // load the identities
2087        for user_id in known_users {
2088            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
2089        }
2090
2091        let carl_identity =
2092            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
2093
2094        assert_eq!(
2095            carl_identity.master_key().get_first_key().unwrap().to_base64(),
2096            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
2097        );
2098        assert!(!carl_identity.was_previously_verified());
2099
2100        let bob_identity =
2101            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
2102
2103        assert_eq!(
2104            bob_identity.master_key().get_first_key().unwrap().to_base64(),
2105            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
2106        );
2107        // Bob is verified so this flag should be set
2108        assert!(bob_identity.was_previously_verified());
2109
2110        let known_devices = vec![
2111            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
2112            // a dehydrated one
2113            (
2114                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
2115                user_id!("@alice:localhost"),
2116            ),
2117            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
2118            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
2119            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
2120            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
2121            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
2122            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
2123            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
2124            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
2125            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
2126            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
2127            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
2128            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
2129        ];
2130
2131        for (device_id, user_id) in known_devices {
2132            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
2133        }
2134
2135        let known_sender_key_to_session_count = vec![
2136            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
2137            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
2138            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
2139            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
2140            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
2141            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
2142            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
2143            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
2144            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
2145            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
2146            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
2147            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
2148            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
2149            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
2150            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
2151        ];
2152
2153        for (id, count) in known_sender_key_to_session_count {
2154            let olm_sessions =
2155                database.get_sessions(id).await.expect("Should have some olm sessions");
2156
2157            println!("### Session id: {id:?}");
2158            assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
2159        }
2160
2161        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
2162        assert_eq!(inbound_group_sessions.len(), 15);
2163        let known_inbound_group_sessions = vec![
2164            (
2165                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
2166                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2167            ),
2168            (
2169                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
2170                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2171            ),
2172            (
2173                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
2174                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2175            ),
2176            (
2177                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
2178                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2179            ),
2180            (
2181                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
2182                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2183            ),
2184            (
2185                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
2186                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2187            ),
2188            (
2189                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
2190                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2191            ),
2192            (
2193                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
2194                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2195            ),
2196            (
2197                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
2198                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2199            ),
2200            (
2201                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
2202                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2203            ),
2204            (
2205                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
2206                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2207            ),
2208            (
2209                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
2210                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2211            ),
2212            (
2213                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
2214                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2215            ),
2216            (
2217                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
2218                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2219            ),
2220            (
2221                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
2222                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2223            ),
2224        ];
2225
2226        // ensure we can load them all
2227        for (session_id, room_id) in &known_inbound_group_sessions {
2228            database
2229                .get_inbound_group_session(room_id, session_id)
2230                .await
2231                .expect("Should be able to load inbound group session")
2232                .unwrap();
2233        }
2234
2235        let bob_sender_verified = database
2236            .get_inbound_group_session(
2237                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2238                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
2239            )
2240            .await
2241            .unwrap()
2242            .unwrap();
2243
2244        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
2245        assert!(bob_sender_verified.backed_up());
2246        assert!(!bob_sender_verified.has_been_imported());
2247
2248        let alice_unknown_device = database
2249            .get_inbound_group_session(
2250                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2251                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
2252            )
2253            .await
2254            .unwrap()
2255            .unwrap();
2256
2257        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
2258        assert!(alice_unknown_device.backed_up());
2259        assert!(alice_unknown_device.has_been_imported());
2260
2261        let carl_tofu_session = database
2262            .get_inbound_group_session(
2263                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2264                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
2265            )
2266            .await
2267            .unwrap()
2268            .unwrap();
2269
2270        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
2271        assert!(carl_tofu_session.backed_up());
2272        assert!(!carl_tofu_session.has_been_imported());
2273
2274        // Load outbound sessions
2275        database
2276            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
2277            .await
2278            .unwrap()
2279            .unwrap();
2280        database
2281            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
2282            .await
2283            .unwrap()
2284            .unwrap();
2285        database
2286            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
2287            .await
2288            .unwrap()
2289            .unwrap();
2290
2291        let withheld_info = database
2292            .get_withheld_info(
2293                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2294                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
2295            )
2296            .await
2297            .expect("This session should be withheld")
2298            .unwrap();
2299
2300        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
2301
2302        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
2303        assert_eq!(backup_keys.backup_version.unwrap(), "6");
2304        assert!(backup_keys.decryption_key.is_some());
2305    }
2306
    /// Test that we migrate the secrets inbox properly.
    ///
    /// The format for the secrets inbox changed in version 17.  Previously, the
    /// secrets inbox stored a full `GossippedSecret` struct.  In version 17,
    /// the secrets inbox now stores only the secret.
    ///
    /// The test proceeds in three steps:
    /// 1. Build a database migrated only up to version 16.
    /// 2. Insert an inbox row in the old format with raw SQL.
    /// 3. Reopen the store normally and check that the secret reads back.
    #[async_test]
    async fn test_secrets_inbox_migration() {
        use std::ops::Deref;

        use matrix_sdk_crypto::{
            GossipRequest, GossippedSecret, SecretInfo,
            types::events::{
                olm_v1::{DecryptedSecretSendEvent, OlmV1Keys},
                secret_send::SecretSendContent,
            },
            vodozemac::Ed25519SecretKey,
        };
        use ruma::{TransactionId, events::secret::request::SecretName, owned_user_id};

        use crate::utils::{EncryptableStore, SqliteAsyncConnExt};

        // Create a database with version 16
        let tmpdir = tempdir().unwrap();
        let config = SqliteStoreConfig::new(tmpdir.path());
        let pool = config.build_pool_of_connections(super::DATABASE_NAME).unwrap();
        let conn = pool.get().await.unwrap();
        let version = super::initialize_store(&conn, 0).await.unwrap();
        let old_data_store = SqliteCryptoStore::create_raw(
            config.secret.clone(),
            pool,
            conn,
            config.pool_config(),
            config.runtime_config(),
        )
        .await
        .unwrap();
        // Stop migrating at version 16, i.e. just before the inbox format change.
        super::run_migrations(&old_data_store, version, Some(16)).await.unwrap();
        old_data_store.write().await.unwrap().wal_checkpoint().await;

        // Store a secret using the old format
        let secret = GossippedSecret {
            secret_name: SecretName::CrossSigningMasterKey,
            gossip_request: GossipRequest {
                request_recipient: owned_user_id!("@alice:example.com"),
                request_id: TransactionId::new(),
                info: SecretInfo::SecretRequest(SecretName::CrossSigningMasterKey),
                sent_out: true,
            },
            event: DecryptedSecretSendEvent {
                sender: owned_user_id!("@alice:example.com"),
                recipient: owned_user_id!("@alice:example.com"),
                keys: OlmV1Keys { ed25519: Ed25519SecretKey::new().public_key() },
                recipient_keys: OlmV1Keys { ed25519: Ed25519SecretKey::new().public_key() },
                sender_device_keys: None,
                content: SecretSendContent::new(
                    "abc".into(),
                    "It is a secret to everybody".to_owned(),
                ),
            },
        };
        // The old format persisted the whole `GossippedSecret` via `serialize_json`.
        // It is written with raw SQL because the current API stores the new format.
        let value = old_data_store.serialize_json(&secret).unwrap();
        old_data_store
            .write()
            .await
            .unwrap()
            .prepare("INSERT INTO secrets (secret_name, data) VALUES (?1, ?2)", |mut stmt| {
                stmt.execute((SecretName::CrossSigningMasterKey.to_string(), value))
            })
            .await
            .unwrap();

        // After we open the store, the data will be migrated
        let store = SqliteCryptoStore::open_with_config(&config).await.unwrap();

        // and we should be able to read the secrets from the inbox
        let secrets =
            store.get_secrets_from_inbox(&SecretName::CrossSigningMasterKey).await.unwrap();
        assert_eq!(secrets.len(), 1);
        assert_eq!(secrets[0].deref(), "It is a secret to everybody");
    }
2387
2388    /// Test that we migrate the key requests table properly.
2389    ///
2390    /// Version 18 added a new column, with a unique index on the column,
2391    /// meaning that it can now only store one request per requested secret/key.
2392    /// Test that when we migrate from an older version that has multiple
2393    /// requests for the same secret, it only keeps one.
2394    #[async_test]
2395    async fn test_key_requests_migration() {
2396        use matrix_sdk_crypto::{GossipRequest, SecretInfo};
2397        use ruma::{TransactionId, events::secret::request::SecretName, owned_user_id};
2398
2399        use crate::utils::{EncryptableStore, SqliteAsyncConnExt};
2400
2401        // Create a database with version 16
2402        let tmpdir = tempdir().unwrap();
2403        let config = SqliteStoreConfig::new(tmpdir.path());
2404        let pool = config.build_pool_of_connections(super::DATABASE_NAME).unwrap();
2405        let conn = pool.get().await.unwrap();
2406        let version = super::initialize_store(&conn, 0).await.unwrap();
2407        let old_data_store = SqliteCryptoStore::create_raw(
2408            config.secret.clone(),
2409            pool,
2410            conn,
2411            config.pool_config(),
2412            config.runtime_config(),
2413        )
2414        .await
2415        .unwrap();
2416        super::run_migrations(&old_data_store, version, Some(16)).await.unwrap();
2417        old_data_store.write().await.unwrap().wal_checkpoint().await;
2418
2419        // Store a secret using the old format
2420        let recovery_request1 = GossipRequest {
2421            request_recipient: owned_user_id!("@alice:example.com"),
2422            request_id: TransactionId::new(),
2423            info: SecretInfo::SecretRequest(SecretName::RecoveryKey),
2424            sent_out: true,
2425        };
2426        let serialized_recovery_request1 =
2427            old_data_store.serialize_value(&recovery_request1).unwrap();
2428        let recovery_request2 = GossipRequest {
2429            request_recipient: owned_user_id!("@alice:example.com"),
2430            request_id: TransactionId::new(),
2431            info: SecretInfo::SecretRequest(SecretName::RecoveryKey),
2432            sent_out: true,
2433        };
2434        let serialized_recovery_request2 =
2435            old_data_store.serialize_value(&recovery_request2).unwrap();
2436        let msk_request = GossipRequest {
2437            request_recipient: owned_user_id!("@alice:example.com"),
2438            request_id: TransactionId::new(),
2439            info: SecretInfo::SecretRequest(SecretName::CrossSigningMasterKey),
2440            sent_out: true,
2441        };
2442        let serialized_msk_request = old_data_store.serialize_value(&msk_request).unwrap();
2443        let recovery_request1_clone = recovery_request1.clone();
2444        let recovery_request2_clone = recovery_request2.clone();
2445        let msk_request_clone = msk_request.clone();
2446        old_data_store
2447            .write()
2448            .await
2449            .unwrap()
2450            .prepare(
2451                "INSERT INTO key_requests (request_id, sent_out, data) VALUES (?1, ?2, ?3)",
2452                move |mut stmt| {
2453                    stmt.execute((
2454                        old_data_store.encode_key(
2455                            "key_requests",
2456                            recovery_request1_clone.request_id.as_bytes(),
2457                        ),
2458                        recovery_request1_clone.sent_out,
2459                        serialized_recovery_request1,
2460                    ))?;
2461                    stmt.execute((
2462                        old_data_store.encode_key(
2463                            "key_requests",
2464                            recovery_request2_clone.request_id.as_bytes(),
2465                        ),
2466                        recovery_request2_clone.sent_out,
2467                        serialized_recovery_request2,
2468                    ))?;
2469                    stmt.execute((
2470                        old_data_store
2471                            .encode_key("key_requests", msk_request_clone.request_id.as_bytes()),
2472                        msk_request_clone.sent_out,
2473                        serialized_msk_request,
2474                    ))
2475                },
2476            )
2477            .await
2478            .unwrap();
2479
2480        // After we open the store, the data will be migrated
2481        let store = SqliteCryptoStore::open_with_config(&config).await.unwrap();
2482
2483        // and we should be able to read one request for the recovery key and
2484        // one request for the MSK
2485        if let Some(GossipRequest {
2486            request_id,
2487            info: SecretInfo::SecretRequest(SecretName::RecoveryKey),
2488            ..
2489        }) = store
2490            .get_secret_request_by_info(&SecretInfo::SecretRequest(SecretName::RecoveryKey))
2491            .await
2492            .unwrap()
2493        {
2494            if request_id == recovery_request1.request_id {
2495                assert!(
2496                    store
2497                        .get_outgoing_secret_requests(&recovery_request2.request_id)
2498                        .await
2499                        .unwrap()
2500                        .is_none()
2501                );
2502            } else if request_id == recovery_request2.request_id {
2503                assert!(
2504                    store
2505                        .get_outgoing_secret_requests(&recovery_request1.request_id)
2506                        .await
2507                        .unwrap()
2508                        .is_none()
2509                );
2510            } else {
2511                panic!("unexpected record found");
2512            }
2513        } else {
2514            panic!("expected to get a secret request");
2515        }
2516        if let Some(GossipRequest {
2517            request_id,
2518            info: SecretInfo::SecretRequest(SecretName::CrossSigningMasterKey),
2519            ..
2520        }) = store
2521            .get_secret_request_by_info(&SecretInfo::SecretRequest(
2522                SecretName::CrossSigningMasterKey,
2523            ))
2524            .await
2525            .unwrap()
2526        {
2527            assert_eq!(request_id, msk_request.request_id);
2528        } else {
2529            panic!("expected to get a secret request");
2530        }
2531    }
2532
2533    async fn get_store(
2534        name: &str,
2535        passphrase: Option<&str>,
2536        clear_data: bool,
2537    ) -> SqliteCryptoStore {
2538        let tmpdir_path = TMP_DIR.path().join(name);
2539
2540        if clear_data {
2541            let _ = fs::remove_dir_all(&tmpdir_path).await;
2542        }
2543
2544        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
2545            .await
2546            .expect("Can't create a secret protected store")
2547    }
2548
    // Expand the shared `CryptoStore` integration test suites from
    // `matrix_sdk_crypto`; they presumably drive the store through the
    // `get_store` helper defined in this module.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
2551}
2552
#[cfg(test)]
mod encrypted_tests {
    use std::sync::LazyLock;

    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use tempfile::{TempDir, tempdir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Scratch directory shared by every store opened in this module.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// Passphrase used whenever a test does not supply one, so that every store
    /// opened by this module is passphrase-protected.
    const FALLBACK_PASSPHRASE: &str = "default_test_password";

    /// Opens the store called `name` inside [`TMP_DIR`], always with a
    /// passphrase: the caller's one if given, otherwise
    /// [`FALLBACK_PASSPHRASE`].
    ///
    /// When `clear_data` is set, the store's directory is removed first (best
    /// effort). Panics if the path is not valid UTF-8 or the store cannot be
    /// opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        if clear_data {
            // Ignore the result: the directory may simply not exist yet.
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        let effective_passphrase = passphrase.unwrap_or(FALLBACK_PASSPHRASE);
        SqliteCryptoStore::open(store_dir.to_str().unwrap(), Some(effective_passphrase))
            .await
            .expect("Can't create a secret protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}
2585
#[cfg(test)]
mod close_reopen_tests {
    //! Tests for closing and reopening the connection pool of a
    //! [`SqliteCryptoStore`].

    use std::{
        sync::LazyLock,
        time::{Duration, Instant},
    };

    use matrix_sdk_crypto::store::CryptoStore;
    use matrix_sdk_test::async_test;
    use tempfile::{TempDir, tempdir};

    use super::SqliteCryptoStore;

    /// Scratch directory shared by every test in this module; each test passes
    /// its own `name`, so the stores live in separate sub-directories.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// Opens a store without a passphrase in the sub-directory `name` of
    /// [`TMP_DIR`]. Panics if the store cannot be opened.
    async fn new_store(name: &str) -> SqliteCryptoStore {
        let tmpdir_path = TMP_DIR.path().join(name);
        SqliteCryptoStore::open(tmpdir_path, None).await.unwrap()
    }

    #[async_test]
    async fn test_close_completes_without_timeout() {
        let store = new_store("close_no_timeout").await;

        // Close should complete quickly without hitting the 5s timeout.
        let start = Instant::now();
        store.close().await.unwrap();
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_secs(2),
            "close() took {elapsed:?}, expected < 2s (no timeout)"
        );

        // Connections should be None after close.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }

    #[async_test]
    async fn test_reopen_restores_connections() {
        let store = new_store("reopen_restores").await;

        store.close().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_none());
        }

        store.reopen().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_some(), "connections should be Some after reopen");
        }
    }

    #[async_test]
    async fn test_close_is_idempotent() {
        let store = new_store("close_idempotent").await;

        store.close().await.unwrap();
        // Second close should be a no-op.
        store.close().await.unwrap();

        let guard = store.connections.lock().await;
        assert!(guard.is_none());
    }

    #[async_test]
    async fn test_reopen_is_idempotent() {
        let store = new_store("reopen_idempotent").await;

        // Reopen on an active store should be a no-op.
        store.reopen().await.unwrap();

        let guard = store.connections.lock().await;
        assert!(guard.is_some());
    }

    #[async_test]
    async fn test_read_fails_when_closed() {
        let store = new_store("read_fails_closed").await;
        store.close().await.unwrap();

        // `expect_err` reports the unexpected `Ok` value if the read succeeds.
        let err = store.load_account().await.expect_err("read should fail when closed");

        let err_msg = err.to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_operations_work_after_reopen() {
        let store = new_store("ops_after_reopen").await;

        store.close().await.unwrap();
        store.reopen().await.unwrap();

        // A read operation should work immediately after reopen; `expect`
        // includes the underlying error in the failure message.
        let account =
            store.load_account().await.expect("load_account should succeed after reopen");
        // No account was saved, so this should be None.
        assert!(account.is_none());
    }

    #[async_test]
    async fn test_multiple_close_reopen_cycles() {
        let store = new_store("multi_cycles").await;

        for cycle in 0..5 {
            store.close().await.unwrap();
            store.reopen().await.unwrap();

            // After each cycle, the store should be fully operational. Report
            // the failing cycle and the error instead of a bare assertion.
            if let Err(error) = store.load_account().await {
                panic!("store should work after close/reopen cycle {cycle}: {error:?}");
            }
        }
    }

    #[async_test]
    async fn test_pool_is_fully_drained_after_close() {
        let store = new_store("pool_drained").await;

        // Do a few reads to exercise the pool. Unwrap them so a failing read
        // cannot silently leave the pool unexercised.
        store.load_account().await.unwrap();
        store.load_account().await.unwrap();

        store.close().await.unwrap();

        // After close, the connections field should be None.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "all connections should be released after close");
    }

    #[async_test]
    async fn test_close_waits_for_held_read_connection_to_drain() {
        let store = new_store("held_read_drain").await;

        // Acquire a read connection and hold it, simulating an in-flight read.
        let held_conn = store.read().await.unwrap();

        // Spawn close in a background task — it will close the pool and then
        // poll-wait for pool.status().size == 0 in the drain loop.
        let store_clone = store.clone();
        let close_handle = tokio::spawn(async move {
            store_clone.close().await.unwrap();
        });

        // Give close() a moment to close the pool and enter the drain loop.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // The close task should still be running because we hold a connection.
        assert!(!close_handle.is_finished(), "close should be waiting for the held connection");

        // Release the held connection — this lets pool.status().size drop to 0.
        drop(held_conn);

        // Now close should complete promptly (well within the 5s timeout). The
        // outer `expect` covers the timeout, the inner one a panic in the task.
        tokio::time::timeout(Duration::from_secs(3), close_handle)
            .await
            .expect("close should complete after the held connection is released")
            .expect("close task should not panic");

        // Verify the store is fully closed.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }
}