matrix_sdk_sqlite/
crypto_store.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    fmt,
18    path::Path,
19    sync::{Arc, RwLock},
20};
21
22use async_trait::async_trait;
23use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
24use matrix_sdk_crypto::{
25    olm::{
26        InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
27        PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
28    },
29    store::{
30        types::{
31            BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
32            StoredRoomKeyBundleData,
33        },
34        CryptoStore,
35    },
36    types::events::room_key_withheld::RoomKeyWithheldEvent,
37    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
38};
39use matrix_sdk_store_encryption::StoreCipher;
40use ruma::{
41    events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
42    RoomId, TransactionId, UserId,
43};
44use rusqlite::{named_params, params_from_iter, OptionalExtension};
45use tokio::{fs, sync::Mutex};
46use tracing::{debug, instrument, warn};
47use vodozemac::Curve25519PublicKey;
48
49use crate::{
50    error::{Error, Result},
51    utils::{
52        repeat_vars, EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
53        SqliteKeyValueStoreConnExt,
54    },
55    OpenStoreError, SqliteStoreConfig,
56};
57
/// The database name: the file created inside the store directory
/// (`open_with_config` joins it onto the configured path).
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
60
/// An SQLite-based crypto store.
///
/// Clones share the cipher, the cached account data and the save lock, since
/// those fields are wrapped in `Arc`s.
#[derive(Clone)]
pub struct SqliteCryptoStore {
    /// Cipher handed out through `EncryptableStore::get_cypher`; `None` when
    /// the store was opened without a passphrase.
    store_cipher: Option<Arc<StoreCipher>>,
    /// Pool of connections to the database file.
    pool: SqlitePool,

    // DB values cached in memory
    /// Copy of the account's static data, filled in when an account is loaded
    /// or saved.
    static_account: Arc<RwLock<Option<StaticAccountData>>>,
    /// Held for the whole duration of `save_changes` / `save_pending_changes`
    /// so that concurrent saves do not interleave.
    save_changes_lock: Arc<Mutex<()>>,
}
71
72#[cfg(not(tarpaulin_include))]
73impl fmt::Debug for SqliteCryptoStore {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
76    }
77}
78
79impl EncryptableStore for SqliteCryptoStore {
80    fn get_cypher(&self) -> Option<&StoreCipher> {
81        self.store_cipher.as_deref()
82    }
83}
84
85impl SqliteCryptoStore {
86    /// Open the SQLite-based crypto store at the given path using the given
87    /// passphrase to encrypt private data.
88    pub async fn open(
89        path: impl AsRef<Path>,
90        passphrase: Option<&str>,
91    ) -> Result<Self, OpenStoreError> {
92        Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await
93    }
94
95    /// Open the SQLite-based crypto store with the config open config.
96    pub async fn open_with_config(config: SqliteStoreConfig) -> Result<Self, OpenStoreError> {
97        let SqliteStoreConfig { path, passphrase, pool_config, runtime_config } = config;
98
99        fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?;
100
101        let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME));
102        config.pool = Some(pool_config);
103
104        let pool = config.create_pool(Runtime::Tokio1)?;
105
106        let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
107        this.pool.get().await?.apply_runtime_config(runtime_config).await?;
108
109        Ok(this)
110    }
111
112    /// Create an SQLite-based crypto store using the given SQLite database
113    /// pool. The given passphrase will be used to encrypt private data.
114    async fn open_with_pool(
115        pool: SqlitePool,
116        passphrase: Option<&str>,
117    ) -> Result<Self, OpenStoreError> {
118        let conn = pool.get().await?;
119
120        let version = conn.db_version().await?;
121        debug!("Opened sqlite store with version {}", version);
122        run_migrations(&conn, version).await?;
123
124        let store_cipher = match passphrase {
125            Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
126            None => None,
127        };
128
129        Ok(SqliteCryptoStore {
130            store_cipher,
131            pool,
132            static_account: Arc::new(RwLock::new(None)),
133            save_changes_lock: Default::default(),
134        })
135    }
136
137    fn deserialize_and_unpickle_inbound_group_session(
138        &self,
139        value: Vec<u8>,
140        backed_up: bool,
141    ) -> Result<InboundGroupSession> {
142        let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
143
144        // The `backed_up` SQL column is the source of truth, because we update it
145        // inside `mark_inbound_group_sessions_as_backed_up` and don't update
146        // the pickled value inside the `data` column (until now, when we are puling it
147        // out of the DB).
148        pickle.backed_up = backed_up;
149
150        Ok(InboundGroupSession::from_pickle(pickle)?)
151    }
152
153    fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
154        let mut request: GossipRequest = self.deserialize_value(value)?;
155        // sent_out SQL column is source of truth, sent_out field in serialized value
156        // needed for other stores though
157        request.sent_out = sent_out;
158        Ok(request)
159    }
160
161    fn get_static_account(&self) -> Option<StaticAccountData> {
162        self.static_account.read().unwrap().clone()
163    }
164
165    async fn acquire(&self) -> Result<SqliteAsyncConn> {
166        Ok(self.pool.get().await?)
167    }
168}
169
/// Latest schema version. Databases reporting a lower version are upgraded
/// by `run_migrations`.
const DATABASE_VERSION: u8 = 11;

/// Key for the dehydrated device pickle key in the key/value table.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
174
175/// Run migrations for the given version of the database.
176async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
177    if version == 0 {
178        debug!("Creating database");
179    } else if version < DATABASE_VERSION {
180        debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
181    } else {
182        return Ok(());
183    }
184
185    if version < 1 {
186        // First turn on WAL mode, this can't be done in the transaction, it fails with
187        // the error message: "cannot change into wal mode from within a transaction".
188        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
189        conn.with_transaction(|txn| {
190            txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
191            txn.set_db_version(1)
192        })
193        .await?;
194    }
195
196    if version < 2 {
197        conn.with_transaction(|txn| {
198            txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
199            txn.set_db_version(2)
200        })
201        .await?;
202    }
203
204    if version < 3 {
205        conn.with_transaction(|txn| {
206            txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
207            txn.set_db_version(3)
208        })
209        .await?;
210    }
211
212    if version < 4 {
213        conn.with_transaction(|txn| {
214            txn.execute_batch(include_str!(
215                "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
216            ))?;
217            txn.set_db_version(4)
218        })
219        .await?;
220    }
221
222    if version < 5 {
223        conn.with_transaction(|txn| {
224            txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
225            txn.set_db_version(5)
226        })
227        .await?;
228    }
229
230    if version < 6 {
231        conn.with_transaction(|txn| {
232            txn.execute_batch(include_str!(
233                "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
234            ))?;
235            txn.set_db_version(6)
236        })
237        .await?;
238    }
239
240    if version < 7 {
241        conn.with_transaction(|txn| {
242            txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
243            txn.set_db_version(7)
244        })
245        .await?;
246    }
247
248    if version < 8 {
249        conn.with_transaction(|txn| {
250            txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
251            txn.set_db_version(8)
252        })
253        .await?;
254    }
255
256    if version < 9 {
257        conn.with_transaction(|txn| {
258            txn.execute_batch(include_str!(
259                "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
260            ))?;
261            txn.set_db_version(9)
262        })
263        .await?;
264    }
265
266    if version < 10 {
267        conn.with_transaction(|txn| {
268            txn.execute_batch(include_str!(
269                "../migrations/crypto_store/010_received_room_key_bundles.sql"
270            ))?;
271            txn.set_db_version(10)
272        })
273        .await?;
274    }
275
276    if version < 11 {
277        conn.with_transaction(|txn| {
278            txn.execute_batch(include_str!(
279                "../migrations/crypto_store/011_received_room_key_bundles_with_curve_key.sql"
280            ))?;
281            txn.set_db_version(11)
282        })
283        .await?;
284    }
285
286    Ok(())
287}
288
289trait SqliteConnectionExt {
290    fn set_session(
291        &self,
292        session_id: &[u8],
293        sender_key: &[u8],
294        data: &[u8],
295    ) -> rusqlite::Result<()>;
296
297    fn set_inbound_group_session(
298        &self,
299        room_id: &[u8],
300        session_id: &[u8],
301        data: &[u8],
302        backed_up: bool,
303        sender_key: Option<&[u8]>,
304        sender_data_type: Option<u8>,
305    ) -> rusqlite::Result<()>;
306
307    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
308
309    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
310    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
311
312    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
313
314    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
315
316    fn set_key_request(
317        &self,
318        request_id: &[u8],
319        sent_out: bool,
320        data: &[u8],
321    ) -> rusqlite::Result<()>;
322
323    fn set_direct_withheld(
324        &self,
325        session_id: &[u8],
326        room_id: &[u8],
327        data: &[u8],
328    ) -> rusqlite::Result<()>;
329
330    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
331
332    fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
333
334    fn set_received_room_key_bundle(
335        &self,
336        room_id: &[u8],
337        user_id: &[u8],
338        data: &[u8],
339    ) -> rusqlite::Result<()>;
340}
341
342impl SqliteConnectionExt for rusqlite::Connection {
343    fn set_session(
344        &self,
345        session_id: &[u8],
346        sender_key: &[u8],
347        data: &[u8],
348    ) -> rusqlite::Result<()> {
349        self.execute(
350            "INSERT INTO session (session_id, sender_key, data)
351             VALUES (?1, ?2, ?3)
352             ON CONFLICT (session_id) DO UPDATE SET data = ?3",
353            (session_id, sender_key, data),
354        )?;
355        Ok(())
356    }
357
358    fn set_inbound_group_session(
359        &self,
360        room_id: &[u8],
361        session_id: &[u8],
362        data: &[u8],
363        backed_up: bool,
364        sender_key: Option<&[u8]>,
365        sender_data_type: Option<u8>,
366    ) -> rusqlite::Result<()> {
367        self.execute(
368            "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
369             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
370             ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
371            (session_id, room_id, data, backed_up, sender_key, sender_data_type),
372        )?;
373        Ok(())
374    }
375
376    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
377        self.execute(
378            "INSERT INTO outbound_group_session (room_id, data) \
379             VALUES (?1, ?2)
380             ON CONFLICT (room_id) DO UPDATE SET data = ?2",
381            (room_id, data),
382        )?;
383        Ok(())
384    }
385
386    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
387        self.execute(
388            "INSERT INTO device (user_id, device_id, data) \
389             VALUES (?1, ?2, ?3)
390             ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
391            (user_id, device_id, data),
392        )?;
393        Ok(())
394    }
395
396    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
397        self.execute(
398            "DELETE FROM device WHERE user_id = ? AND device_id = ?",
399            (user_id, device_id),
400        )?;
401        Ok(())
402    }
403
404    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
405        self.execute(
406            "INSERT INTO identity (user_id, data) \
407             VALUES (?1, ?2)
408             ON CONFLICT (user_id) DO UPDATE SET data = ?2",
409            (user_id, data),
410        )?;
411        Ok(())
412    }
413
414    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
415        self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
416        Ok(())
417    }
418
419    fn set_key_request(
420        &self,
421        request_id: &[u8],
422        sent_out: bool,
423        data: &[u8],
424    ) -> rusqlite::Result<()> {
425        self.execute(
426            "INSERT INTO key_requests (request_id, sent_out, data)
427            VALUES (?1, ?2, ?3)
428            ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
429            (request_id, sent_out, data),
430        )?;
431        Ok(())
432    }
433
434    fn set_direct_withheld(
435        &self,
436        session_id: &[u8],
437        room_id: &[u8],
438        data: &[u8],
439    ) -> rusqlite::Result<()> {
440        self.execute(
441            "INSERT INTO direct_withheld_info (session_id, room_id, data)
442            VALUES (?1, ?2, ?3)
443            ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
444            (session_id, room_id, data),
445        )?;
446        Ok(())
447    }
448
449    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
450        self.execute(
451            "INSERT INTO room_settings (room_id, data)
452            VALUES (?1, ?2)
453            ON CONFLICT (room_id) DO UPDATE SET data = ?2",
454            (room_id, data),
455        )?;
456        Ok(())
457    }
458
459    fn set_secret(&self, secret_name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
460        self.execute(
461            "INSERT INTO secrets (secret_name, data)
462            VALUES (?1, ?2)",
463            (secret_name, data),
464        )?;
465
466        Ok(())
467    }
468
469    fn set_received_room_key_bundle(
470        &self,
471        room_id: &[u8],
472        sender_user_id: &[u8],
473        data: &[u8],
474    ) -> rusqlite::Result<()> {
475        self.execute(
476            "INSERT INTO received_room_key_bundle(room_id, sender_user_id, bundle_data)
477            VALUES (?1, ?2, ?3)
478            ON CONFLICT (room_id, sender_user_id) DO UPDATE SET bundle_data = ?3",
479            (room_id, sender_user_id, data),
480        )?;
481        Ok(())
482    }
483}
484
485#[async_trait]
486trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
487    async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
488        Ok(self
489            .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
490                stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
491            })
492            .await?)
493    }
494
495    async fn get_inbound_group_session(
496        &self,
497        session_id: Key,
498    ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
499        Ok(self
500            .query_row(
501                "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
502                (session_id,),
503                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
504            )
505            .await
506            .optional()?)
507    }
508
509    async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
510        Ok(self
511            .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
512                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
513            })
514            .await?)
515    }
516
517    async fn get_inbound_group_session_counts(
518        &self,
519        _backup_version: Option<&str>,
520    ) -> Result<RoomKeyCounts> {
521        let total = self
522            .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
523            .await?;
524        let backed_up = self
525            .query_row(
526                "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
527                (),
528                |row| row.get(0),
529            )
530            .await?;
531        Ok(RoomKeyCounts { total, backed_up })
532    }
533
534    async fn get_inbound_group_sessions_for_device_batch(
535        &self,
536        sender_key: Key,
537        sender_data_type: SenderDataType,
538        after_session_id: Option<Key>,
539        limit: usize,
540    ) -> Result<Vec<(Vec<u8>, bool)>> {
541        Ok(self
542            .prepare(
543                "
544                SELECT data, backed_up
545                FROM inbound_group_session
546                WHERE sender_key = :sender_key
547                    AND sender_data_type = :sender_data_type
548                    AND session_id > :after_session_id
549                ORDER BY session_id
550                LIMIT :limit
551                ",
552                move |mut stmt| {
553                    let sender_data_type = sender_data_type as u8;
554
555                    // If we are not provided with an `after_session_id`, use a key which will sort
556                    // before all real keys: the empty string.
557                    let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
558
559                    stmt.query(named_params! {
560                        ":sender_key": sender_key,
561                        ":sender_data_type": sender_data_type,
562                        ":after_session_id": after_session_id,
563                        ":limit": limit,
564                    })?
565                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
566                    .collect()
567                },
568            )
569            .await?)
570    }
571
572    async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
573        Ok(self
574            .prepare(
575                "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
576                move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
577            )
578            .await?)
579    }
580
581    async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
582        if session_ids.is_empty() {
583            // We are not expecting to be called with an empty list of sessions
584            warn!("No sessions to mark as backed up!");
585            return Ok(());
586        }
587
588        let session_ids_len = session_ids.len();
589
590        self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
591            // Safety: placeholders is not generated using any user input except the number
592            // of session IDs, so it is safe from injection.
593            let sql_params = repeat_vars(session_ids_len);
594            let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
595            txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
596            Ok(Vec::<()>::new())
597        }).await?;
598
599        Ok(())
600    }
601
602    async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
603        self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
604        Ok(())
605    }
606
607    async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
608        Ok(self
609            .query_row(
610                "SELECT data FROM outbound_group_session WHERE room_id = ?",
611                (room_id,),
612                |row| row.get(0),
613            )
614            .await
615            .optional()?)
616    }
617
618    async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
619        Ok(self
620            .query_row(
621                "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
622                (user_id, device_id),
623                |row| row.get(0),
624            )
625            .await
626            .optional()?)
627    }
628
629    async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
630        Ok(self
631            .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
632                stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
633            })
634            .await?)
635    }
636
637    async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
638        Ok(self
639            .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
640            .await
641            .optional()?)
642    }
643
644    async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
645        Ok(self
646            .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
647                row.get::<_, i32>(0)
648            })
649            .await?
650            > 0)
651    }
652
653    async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
654        Ok(self
655            .prepare("SELECT data FROM tracked_user", |mut stmt| {
656                stmt.query(())?.mapped(|row| row.get(0)).collect()
657            })
658            .await?)
659    }
660
661    async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
662        Ok(self
663            .prepare(
664                "INSERT INTO tracked_user (user_id, data) \
665                 VALUES (?1, ?2) \
666                 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
667                |mut stmt| {
668                    for (user_id, data) in users {
669                        stmt.execute((user_id, data))?;
670                    }
671
672                    Ok(())
673                },
674            )
675            .await?)
676    }
677
678    async fn get_outgoing_secret_request(
679        &self,
680        request_id: Key,
681    ) -> Result<Option<(Vec<u8>, bool)>> {
682        Ok(self
683            .query_row(
684                "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
685                (request_id,),
686                |row| Ok((row.get(0)?, row.get(1)?)),
687            )
688            .await
689            .optional()?)
690    }
691
692    async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
693        Ok(self
694            .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
695                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
696            })
697            .await?)
698    }
699
700    async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
701        Ok(self
702            .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
703                stmt.query(())?.mapped(|row| row.get(0)).collect()
704            })
705            .await?)
706    }
707
708    async fn delete_key_request(&self, request_id: Key) -> Result<()> {
709        self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
710        Ok(())
711    }
712
713    async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
714        Ok(self
715            .prepare("SELECT data FROM secrets WHERE secret_name = ?", |mut stmt| {
716                stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
717            })
718            .await?)
719    }
720
721    async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
722        self.execute("DELETE FROM secrets WHERE secret_name = ?", (secret_name,)).await?;
723        Ok(())
724    }
725
726    async fn get_direct_withheld_info(
727        &self,
728        session_id: Key,
729        room_id: Key,
730    ) -> Result<Option<Vec<u8>>> {
731        Ok(self
732            .query_row(
733                "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
734                (session_id, room_id),
735                |row| row.get(0),
736            )
737            .await
738            .optional()?)
739    }
740
741    async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
742        Ok(self
743            .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
744                row.get(0)
745            })
746            .await
747            .optional()?)
748    }
749
750    async fn get_received_room_key_bundle(
751        &self,
752        room_id: Key,
753        sender_user: Key,
754    ) -> Result<Option<Vec<u8>>> {
755        Ok(self
756            .query_row(
757                "SELECT bundle_data FROM received_room_key_bundle WHERE room_id = ? AND sender_user_id = ?",
758                (room_id, sender_user),
759                |row| { row.get(0) },
760            )
761            .await
762            .optional()?)
763    }
764}
765
// Pooled connections get every helper of `SqliteObjectCryptoStoreExt` through
// the trait's default method implementations; nothing is overridden here.
#[async_trait]
impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
768
769#[async_trait]
770impl CryptoStore for SqliteCryptoStore {
771    type Error = Error;
772
773    async fn load_account(&self) -> Result<Option<Account>> {
774        let conn = self.acquire().await?;
775        if let Some(pickle) = conn.get_kv("account").await? {
776            let pickle = self.deserialize_value(&pickle)?;
777
778            let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
779
780            *self.static_account.write().unwrap() = Some(account.static_data().clone());
781
782            Ok(Some(account))
783        } else {
784            Ok(None)
785        }
786    }
787
788    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
789        let conn = self.acquire().await?;
790        if let Some(i) = conn.get_kv("identity").await? {
791            let pickle = self.deserialize_value(&i)?;
792            Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
793        } else {
794            Ok(None)
795        }
796    }
797
798    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
799        // Serialize calls to `save_pending_changes`; there are multiple await points
800        // below, and we're pickling data as we go, so we don't want to
801        // invalidate data we've previously read and overwrite it in the store.
802        // TODO: #2000 should make this lock go away, or change its shape.
803        let _guard = self.save_changes_lock.lock().await;
804
805        let pickled_account = if let Some(account) = changes.account {
806            *self.static_account.write().unwrap() = Some(account.static_data().clone());
807            Some(account.pickle())
808        } else {
809            None
810        };
811
812        let this = self.clone();
813        self.acquire()
814            .await?
815            .with_transaction(move |txn| {
816                if let Some(pickled_account) = pickled_account {
817                    let serialized_account = this.serialize_value(&pickled_account)?;
818                    txn.set_kv("account", &serialized_account)?;
819                }
820
821                Ok::<_, Error>(())
822            })
823            .await?;
824
825        Ok(())
826    }
827
828    async fn save_changes(&self, changes: Changes) -> Result<()> {
829        // Serialize calls to `save_changes`; there are multiple await points below, and
830        // we're pickling data as we go, so we don't want to invalidate data
831        // we've previously read and overwrite it in the store.
832        // TODO: #2000 should make this lock go away, or change its shape.
833        let _guard = self.save_changes_lock.lock().await;
834
835        let pickled_private_identity =
836            if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
837
838        let mut session_changes = Vec::new();
839
840        for session in changes.sessions {
841            let session_id = self.encode_key("session", session.session_id());
842            let sender_key = self.encode_key("session", session.sender_key().to_base64());
843            let pickle = session.pickle().await;
844            session_changes.push((session_id, sender_key, pickle));
845        }
846
847        let mut inbound_session_changes = Vec::new();
848        for session in changes.inbound_group_sessions {
849            let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
850            let session_id = self.encode_key("inbound_group_session", session.session_id());
851            let pickle = session.pickle().await;
852            let sender_key =
853                self.encode_key("inbound_group_session", session.sender_key().to_base64());
854            inbound_session_changes.push((room_id, session_id, pickle, sender_key));
855        }
856
857        let mut outbound_session_changes = Vec::new();
858        for session in changes.outbound_group_sessions {
859            let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
860            let pickle = session.pickle().await;
861            outbound_session_changes.push((room_id, pickle));
862        }
863
864        let this = self.clone();
865        self.acquire()
866            .await?
867            .with_transaction(move |txn| {
868                if let Some(pickled_private_identity) = &pickled_private_identity {
869                    let serialized_private_identity =
870                        this.serialize_value(pickled_private_identity)?;
871                    txn.set_kv("identity", &serialized_private_identity)?;
872                }
873
874                if let Some(token) = &changes.next_batch_token {
875                    let serialized_token = this.serialize_value(token)?;
876                    txn.set_kv("next_batch_token", &serialized_token)?;
877                }
878
879                if let Some(decryption_key) = &changes.backup_decryption_key {
880                    let serialized_decryption_key = this.serialize_value(decryption_key)?;
881                    txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
882                }
883
884                if let Some(backup_version) = &changes.backup_version {
885                    let serialized_backup_version = this.serialize_value(backup_version)?;
886                    txn.set_kv("backup_version_v1", &serialized_backup_version)?;
887                }
888
889                if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
890                    let serialized_pickle_key = this.serialize_value(pickle_key)?;
891                    txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
892                }
893
894                for device in changes.devices.new.iter().chain(&changes.devices.changed) {
895                    let user_id = this.encode_key("device", device.user_id().as_bytes());
896                    let device_id = this.encode_key("device", device.device_id().as_bytes());
897                    let data = this.serialize_value(&device)?;
898                    txn.set_device(&user_id, &device_id, &data)?;
899                }
900
901                for device in &changes.devices.deleted {
902                    let user_id = this.encode_key("device", device.user_id().as_bytes());
903                    let device_id = this.encode_key("device", device.device_id().as_bytes());
904                    txn.delete_device(&user_id, &device_id)?;
905                }
906
907                for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
908                    let user_id = this.encode_key("identity", identity.user_id().as_bytes());
909                    let data = this.serialize_value(&identity)?;
910                    txn.set_identity(&user_id, &data)?;
911                }
912
913                for (session_id, sender_key, pickle) in &session_changes {
914                    let serialized_session = this.serialize_value(&pickle)?;
915                    txn.set_session(session_id, sender_key, &serialized_session)?;
916                }
917
918                for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
919                    let serialized_session = this.serialize_value(&pickle)?;
920                    txn.set_inbound_group_session(
921                        room_id,
922                        session_id,
923                        &serialized_session,
924                        pickle.backed_up,
925                        Some(sender_key),
926                        Some(pickle.sender_data.to_type() as u8),
927                    )?;
928                }
929
930                for (room_id, pickle) in &outbound_session_changes {
931                    let serialized_session = this.serialize_json(&pickle)?;
932                    txn.set_outbound_group_session(room_id, &serialized_session)?;
933                }
934
935                for hash in &changes.message_hashes {
936                    let hash = rmp_serde::to_vec(hash)?;
937                    txn.add_olm_hash(&hash)?;
938                }
939
940                for request in changes.key_requests {
941                    let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
942                    let serialized_request = this.serialize_value(&request)?;
943                    txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
944                }
945
946                for (room_id, data) in changes.withheld_session_info {
947                    for (session_id, event) in data {
948                        let session_id = this.encode_key("direct_withheld_info", session_id);
949                        let room_id = this.encode_key("direct_withheld_info", &room_id);
950                        let serialized_info = this.serialize_json(&event)?;
951                        txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
952                    }
953                }
954
955                for (room_id, settings) in changes.room_settings {
956                    let room_id = this.encode_key("room_settings", room_id.as_bytes());
957                    let value = this.serialize_value(&settings)?;
958                    txn.set_room_settings(&room_id, &value)?;
959                }
960
961                for secret in changes.secrets {
962                    let secret_name = this.encode_key("secrets", secret.secret_name.to_string());
963                    let value = this.serialize_json(&secret)?;
964                    txn.set_secret(&secret_name, &value)?;
965                }
966
967                for bundle in changes.received_room_key_bundles {
968                    let room_id =
969                        this.encode_key("received_room_key_bundle", &bundle.bundle_data.room_id);
970                    let user_id = this.encode_key("received_room_key_bundle", &bundle.sender_user);
971                    let value = this.serialize_value(&bundle)?;
972                    txn.set_received_room_key_bundle(&room_id, &user_id, &value)?;
973                }
974
975                Ok::<_, Error>(())
976            })
977            .await?;
978
979        Ok(())
980    }
981
982    async fn save_inbound_group_sessions(
983        &self,
984        sessions: Vec<InboundGroupSession>,
985        backed_up_to_version: Option<&str>,
986    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
987        // Sanity-check that the data in the sessions corresponds to backed_up_version
988        sessions.iter().for_each(|s| {
989            let backed_up = s.backed_up();
990            if backed_up != backed_up_to_version.is_some() {
991                warn!(
992                    backed_up,
993                    backed_up_to_version,
994                    "Session backed-up flag does not correspond to backup version setting",
995                );
996            }
997        });
998
999        // Currently, this store doesn't save the backup version separately, so this
1000        // just delegates to save_changes.
1001        self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
1002    }
1003
1004    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
1005        let device_keys = self.get_own_device().await?.as_device_keys().clone();
1006
1007        let sessions: Vec<_> = self
1008            .acquire()
1009            .await?
1010            .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
1011            .await?
1012            .into_iter()
1013            .map(|bytes| {
1014                let pickle = self.deserialize_value(&bytes)?;
1015                Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
1016            })
1017            .collect::<Result<_>>()?;
1018
1019        if sessions.is_empty() {
1020            Ok(None)
1021        } else {
1022            Ok(Some(sessions))
1023        }
1024    }
1025
1026    #[instrument(skip(self))]
1027    async fn get_inbound_group_session(
1028        &self,
1029        room_id: &RoomId,
1030        session_id: &str,
1031    ) -> Result<Option<InboundGroupSession>> {
1032        let session_id = self.encode_key("inbound_group_session", session_id);
1033        let Some((room_id_from_db, value, backed_up)) =
1034            self.acquire().await?.get_inbound_group_session(session_id).await?
1035        else {
1036            return Ok(None);
1037        };
1038
1039        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1040        if *room_id != room_id_from_db {
1041            warn!("expected room_id for session_id doesn't match what's in the DB");
1042            return Ok(None);
1043        }
1044
1045        Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1046    }
1047
1048    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1049        self.acquire()
1050            .await?
1051            .get_inbound_group_sessions()
1052            .await?
1053            .into_iter()
1054            .map(|(value, backed_up)| {
1055                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1056            })
1057            .collect()
1058    }
1059
1060    async fn get_inbound_group_sessions_for_device_batch(
1061        &self,
1062        sender_key: Curve25519PublicKey,
1063        sender_data_type: SenderDataType,
1064        after_session_id: Option<String>,
1065        limit: usize,
1066    ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1067        let after_session_id =
1068            after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1069        let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1070
1071        self.acquire()
1072            .await?
1073            .get_inbound_group_sessions_for_device_batch(
1074                sender_key,
1075                sender_data_type,
1076                after_session_id,
1077                limit,
1078            )
1079            .await?
1080            .into_iter()
1081            .map(|(value, backed_up)| {
1082                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1083            })
1084            .collect()
1085    }
1086
1087    async fn inbound_group_session_counts(
1088        &self,
1089        backup_version: Option<&str>,
1090    ) -> Result<RoomKeyCounts> {
1091        Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
1092    }
1093
1094    async fn inbound_group_sessions_for_backup(
1095        &self,
1096        _backup_version: &str,
1097        limit: usize,
1098    ) -> Result<Vec<InboundGroupSession>> {
1099        self.acquire()
1100            .await?
1101            .get_inbound_group_sessions_for_backup(limit)
1102            .await?
1103            .into_iter()
1104            .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1105            .collect()
1106    }
1107
1108    async fn mark_inbound_group_sessions_as_backed_up(
1109        &self,
1110        _backup_version: &str,
1111        session_ids: &[(&RoomId, &str)],
1112    ) -> Result<()> {
1113        Ok(self
1114            .acquire()
1115            .await?
1116            .mark_inbound_group_sessions_as_backed_up(
1117                session_ids
1118                    .iter()
1119                    .map(|(_, s)| self.encode_key("inbound_group_session", s))
1120                    .collect(),
1121            )
1122            .await?)
1123    }
1124
1125    async fn reset_backup_state(&self) -> Result<()> {
1126        Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
1127    }
1128
1129    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1130        let conn = self.acquire().await?;
1131
1132        let backup_version = conn
1133            .get_kv("backup_version_v1")
1134            .await?
1135            .map(|value| self.deserialize_value(&value))
1136            .transpose()?;
1137
1138        let decryption_key = conn
1139            .get_kv("recovery_key_v1")
1140            .await?
1141            .map(|value| self.deserialize_value(&value))
1142            .transpose()?;
1143
1144        Ok(BackupKeys { backup_version, decryption_key })
1145    }
1146
1147    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1148        let conn = self.acquire().await?;
1149
1150        conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1151            .await?
1152            .map(|value| self.deserialize_value(&value))
1153            .transpose()
1154    }
1155
1156    async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1157        let conn = self.acquire().await?;
1158        conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
1159
1160        Ok(())
1161    }
1162    async fn get_outbound_group_session(
1163        &self,
1164        room_id: &RoomId,
1165    ) -> Result<Option<OutboundGroupSession>> {
1166        let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1167        let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
1168            return Ok(None);
1169        };
1170
1171        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1172
1173        let pickle = self.deserialize_json(&value)?;
1174        let session = OutboundGroupSession::from_pickle(
1175            account_info.device_id,
1176            account_info.identity_keys,
1177            pickle,
1178        )
1179        .map_err(|_| Error::Unpickle)?;
1180
1181        return Ok(Some(session));
1182    }
1183
1184    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1185        self.acquire()
1186            .await?
1187            .get_tracked_users()
1188            .await?
1189            .iter()
1190            .map(|value| self.deserialize_value(value))
1191            .collect()
1192    }
1193
1194    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1195        let users: Vec<(Key, Vec<u8>)> = tracked_users
1196            .iter()
1197            .map(|(u, d)| {
1198                let user_id = self.encode_key("tracked_users", u.as_bytes());
1199                let data =
1200                    self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1201                Ok((user_id, data))
1202            })
1203            .collect::<Result<_>>()?;
1204
1205        Ok(self.acquire().await?.add_tracked_users(users).await?)
1206    }
1207
1208    async fn get_device(
1209        &self,
1210        user_id: &UserId,
1211        device_id: &DeviceId,
1212    ) -> Result<Option<DeviceData>> {
1213        let user_id = self.encode_key("device", user_id.as_bytes());
1214        let device_id = self.encode_key("device", device_id.as_bytes());
1215        Ok(self
1216            .acquire()
1217            .await?
1218            .get_device(user_id, device_id)
1219            .await?
1220            .map(|value| self.deserialize_value(&value))
1221            .transpose()?)
1222    }
1223
1224    async fn get_user_devices(
1225        &self,
1226        user_id: &UserId,
1227    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1228        let user_id = self.encode_key("device", user_id.as_bytes());
1229        self.acquire()
1230            .await?
1231            .get_user_devices(user_id)
1232            .await?
1233            .into_iter()
1234            .map(|value| {
1235                let device: DeviceData = self.deserialize_value(&value)?;
1236                Ok((device.device_id().to_owned(), device))
1237            })
1238            .collect()
1239    }
1240
1241    async fn get_own_device(&self) -> Result<DeviceData> {
1242        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1243
1244        Ok(self
1245            .get_device(&account_info.user_id, &account_info.device_id)
1246            .await?
1247            .expect("We should be able to find our own device."))
1248    }
1249
1250    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1251        let user_id = self.encode_key("identity", user_id.as_bytes());
1252        Ok(self
1253            .acquire()
1254            .await?
1255            .get_user_identity(user_id)
1256            .await?
1257            .map(|value| self.deserialize_value(&value))
1258            .transpose()?)
1259    }
1260
1261    async fn is_message_known(
1262        &self,
1263        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1264    ) -> Result<bool> {
1265        let value = rmp_serde::to_vec(message_hash)?;
1266        Ok(self.acquire().await?.has_olm_hash(value).await?)
1267    }
1268
1269    async fn get_outgoing_secret_requests(
1270        &self,
1271        request_id: &TransactionId,
1272    ) -> Result<Option<GossipRequest>> {
1273        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1274        Ok(self
1275            .acquire()
1276            .await?
1277            .get_outgoing_secret_request(request_id)
1278            .await?
1279            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1280            .transpose()?)
1281    }
1282
1283    async fn get_secret_request_by_info(
1284        &self,
1285        key_info: &SecretInfo,
1286    ) -> Result<Option<GossipRequest>> {
1287        let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
1288        for (request, sent_out) in requests {
1289            let request = self.deserialize_key_request(&request, sent_out)?;
1290            if request.info == *key_info {
1291                return Ok(Some(request));
1292            }
1293        }
1294        Ok(None)
1295    }
1296
1297    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1298        self.acquire()
1299            .await?
1300            .get_unsent_secret_requests()
1301            .await?
1302            .iter()
1303            .map(|value| {
1304                let request = self.deserialize_key_request(value, false)?;
1305                Ok(request)
1306            })
1307            .collect()
1308    }
1309
1310    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1311        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1312        Ok(self.acquire().await?.delete_key_request(request_id).await?)
1313    }
1314
1315    async fn get_secrets_from_inbox(
1316        &self,
1317        secret_name: &SecretName,
1318    ) -> Result<Vec<GossippedSecret>> {
1319        let secret_name = self.encode_key("secrets", secret_name.to_string());
1320
1321        self.acquire()
1322            .await?
1323            .get_secrets_from_inbox(secret_name)
1324            .await?
1325            .into_iter()
1326            .map(|value| self.deserialize_json(value.as_ref()))
1327            .collect()
1328    }
1329
1330    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1331        let secret_name = self.encode_key("secrets", secret_name.to_string());
1332        self.acquire().await?.delete_secrets_from_inbox(secret_name).await
1333    }
1334
1335    async fn get_withheld_info(
1336        &self,
1337        room_id: &RoomId,
1338        session_id: &str,
1339    ) -> Result<Option<RoomKeyWithheldEvent>> {
1340        let room_id = self.encode_key("direct_withheld_info", room_id);
1341        let session_id = self.encode_key("direct_withheld_info", session_id);
1342
1343        self.acquire()
1344            .await?
1345            .get_direct_withheld_info(session_id, room_id)
1346            .await?
1347            .map(|value| {
1348                let info = self.deserialize_json::<RoomKeyWithheldEvent>(&value)?;
1349                Ok(info)
1350            })
1351            .transpose()
1352    }
1353
1354    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1355        let room_id = self.encode_key("room_settings", room_id.as_bytes());
1356        let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
1357            return Ok(None);
1358        };
1359
1360        let settings = self.deserialize_value(&value)?;
1361
1362        return Ok(Some(settings));
1363    }
1364
1365    async fn get_received_room_key_bundle_data(
1366        &self,
1367        room_id: &RoomId,
1368        user_id: &UserId,
1369    ) -> Result<Option<StoredRoomKeyBundleData>> {
1370        let room_id = self.encode_key("received_room_key_bundle", room_id);
1371        let user_id = self.encode_key("received_room_key_bundle", user_id);
1372        self.acquire()
1373            .await?
1374            .get_received_room_key_bundle(room_id, user_id)
1375            .await?
1376            .map(|value| self.deserialize_value(&value))
1377            .transpose()
1378    }
1379
1380    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1381        let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
1382            return Ok(None);
1383        };
1384        let value = if let Some(cipher) = &self.store_cipher {
1385            let encrypted = rmp_serde::from_slice(&serialized)?;
1386            cipher.decrypt_value_data(encrypted)?
1387        } else {
1388            serialized
1389        };
1390
1391        Ok(Some(value))
1392    }
1393
1394    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1395        let serialized = if let Some(cipher) = &self.store_cipher {
1396            let encrypted = cipher.encrypt_value_data(value)?;
1397            rmp_serde::to_vec_named(&encrypted)?
1398        } else {
1399            value
1400        };
1401
1402        self.acquire().await?.set_kv(key, serialized).await?;
1403        Ok(())
1404    }
1405
1406    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1407        let key = key.to_owned();
1408        self.acquire()
1409            .await?
1410            .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1411            .await
1412            .unwrap()?;
1413        Ok(())
1414    }
1415
1416    async fn try_take_leased_lock(
1417        &self,
1418        lease_duration_ms: u32,
1419        key: &str,
1420        holder: &str,
1421    ) -> Result<bool> {
1422        let key = key.to_owned();
1423        let holder = holder.to_owned();
1424
1425        let now_ts: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
1426        let expiration_ts = now_ts + lease_duration_ms as u64;
1427
1428        let num_touched = self
1429            .acquire()
1430            .await?
1431            .with_transaction(move |txn| {
1432                txn.execute(
1433                    "INSERT INTO lease_locks (key, holder, expiration_ts)
1434                    VALUES (?1, ?2, ?3)
1435                    ON CONFLICT (key)
1436                    DO
1437                        UPDATE SET holder = ?2, expiration_ts = ?3
1438                        WHERE holder = ?2
1439                        OR expiration_ts < ?4
1440                ",
1441                    (key, holder, expiration_ts, now_ts),
1442                )
1443            })
1444            .await?;
1445
1446        Ok(num_touched == 1)
1447    }
1448
1449    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1450        let conn = self.acquire().await?;
1451        if let Some(token) = conn.get_kv("next_batch_token").await? {
1452            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1453            Ok(maybe_token)
1454        } else {
1455            Ok(None)
1456        }
1457    }
1458}
1459
1460#[cfg(test)]
1461mod tests {
1462    use std::path::Path;
1463
1464    use matrix_sdk_common::deserialized_responses::WithheldCode;
1465    use matrix_sdk_crypto::{
1466        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1467        store::CryptoStore,
1468    };
1469    use matrix_sdk_test::async_test;
1470    use once_cell::sync::Lazy;
1471    use ruma::{device_id, room_id, user_id};
1472    use similar_asserts::assert_eq;
1473    use tempfile::{tempdir, TempDir};
1474    use tokio::fs;
1475
1476    use super::SqliteCryptoStore;
1477    use crate::SqliteStoreConfig;
1478
    /// Shared temporary directory for tests that create a fresh store on disk (for
    /// example `test_pool_size`); it is created lazily on first access.
    ///
    /// NOTE(review): destructors of `static` items never run, so this directory is
    /// not removed when the test binary exits.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());
1480
    /// A store opened on a private, temporary copy of a fixture database.
    struct TestDb {
        // Needs to be kept alive because the Drop implementation for TempDir deletes the
        // directory.
        _dir: TempDir,
        // The store opened on the copied database file inside `_dir`.
        database: SqliteCryptoStore,
    }
1487
1488    fn copy_db(data_path: &str) -> TempDir {
1489        let db_name = super::DATABASE_NAME;
1490
1491        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1492        let database_path = manifest_path.join(data_path).join(db_name);
1493
1494        let tmpdir = tempdir().unwrap();
1495        let destination = tmpdir.path().join(db_name);
1496
1497        // Copy the test database to the tempdir so our test runs are idempotent.
1498        std::fs::copy(&database_path, destination).unwrap();
1499
1500        tmpdir
1501    }
1502
1503    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1504        let tmpdir = copy_db(data_path);
1505
1506        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1507            .await
1508            .expect("Can't open the test store");
1509
1510        TestDb { _dir: tmpdir, database }
1511    }
1512
1513    #[async_test]
1514    async fn test_pool_size() {
1515        let store_open_config =
1516            SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);
1517
1518        let store = SqliteCryptoStore::open_with_config(store_open_config).await.unwrap();
1519
1520        assert_eq!(store.pool.status().max_size, 42);
1521    }
1522
    /// Guard against storage-layer regressions: open a pre-filled, unencrypted
    /// fixture database (a test vector) and check that the account, our own device
    /// and our own identity load with the expected IDs and keys.
    #[async_test]
    async fn test_open_test_vector_store() {
        // No passphrase: this fixture is not encrypted.
        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;

        let account = database
            .load_account()
            .await
            .unwrap()
            .expect("The test database is prefilled with data, we should find an account");

        let user_id = account.user_id();
        let device_id = account.device_id();

        assert_eq!(
            user_id.as_str(),
            "@pjtest:synapse-oidc.element.dev",
            "The user ID should match to the one we expect."
        );

        assert_eq!(
            device_id.as_str(),
            "v4TqgcuIH6",
            "The device ID should match to the one we expect."
        );

        // The account's own device must be retrievable and carry the expected keys.
        let device = database
            .get_device(user_id, device_id)
            .await
            .unwrap()
            .expect("Our own device should be found in the store.");

        assert_eq!(device.device_id(), device_id);
        assert_eq!(device.user_id(), user_id);

        assert_eq!(
            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
        );

        assert_eq!(
            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
        );

        // The stored identity for our user must be our *own* identity, with a master key.
        let identity = database
            .get_user_identity(user_id)
            .await
            .unwrap()
            .expect("The store should contain an identity.");

        assert_eq!(identity.user_id(), user_id);

        let identity = identity
            .own()
            .expect("The identity should be of the correct type, it should be our own identity.");

        let master_key = identity
            .master_key()
            .get_first_key()
            .expect("Our own identity should have a master key");

        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
    }
1588
1589    /// Test that we didn't regress in our storage layer by loading data from a
1590    /// pre-filled database, or in other words use a test vector for this.
1591    #[async_test]
1592    async fn test_open_test_vector_encrypted_store() {
1593        let TestDb { _dir: _, database } = get_test_db(
1594            "testing/data/storage/alice",
1595            Some(concat!(
1596                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
1597                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
1598                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
1599                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
1600                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
1601            )),
1602        )
1603        .await;
1604
1605        let account = database
1606            .load_account()
1607            .await
1608            .unwrap()
1609            .expect("The test database is prefilled with data, we should find an account");
1610
1611        let user_id = account.user_id();
1612        let device_id = account.device_id();
1613
1614        assert_eq!(
1615            user_id.as_str(),
1616            "@alice:localhost",
1617            "The user ID should match to the one we expect."
1618        );
1619
1620        assert_eq!(
1621            device_id.as_str(),
1622            "JVVORTHFXY",
1623            "The device ID should match to the one we expect."
1624        );
1625
1626        let tracked_users =
1627            database.load_tracked_users().await.expect("Should be tracking some users");
1628
1629        assert_eq!(tracked_users.len(), 6);
1630
1631        let known_users = vec![
1632            user_id!("@alice:localhost"),
1633            user_id!("@dehydration3:localhost"),
1634            user_id!("@eve:localhost"),
1635            user_id!("@bob:localhost"),
1636            user_id!("@malo:localhost"),
1637            user_id!("@carl:localhost"),
1638        ];
1639
1640        // load the identities
1641        for user_id in known_users {
1642            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
1643        }
1644
1645        let carl_identity =
1646            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
1647
1648        assert_eq!(
1649            carl_identity.master_key().get_first_key().unwrap().to_base64(),
1650            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
1651        );
1652        assert!(!carl_identity.was_previously_verified());
1653
1654        let bob_identity =
1655            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
1656
1657        assert_eq!(
1658            bob_identity.master_key().get_first_key().unwrap().to_base64(),
1659            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
1660        );
1661        // Bob is verified so this flag should be set
1662        assert!(bob_identity.was_previously_verified());
1663
1664        let known_devices = vec![
1665            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
1666            // a dehydrated one
1667            (
1668                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
1669                user_id!("@alice:localhost"),
1670            ),
1671            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
1672            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
1673            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
1674            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
1675            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
1676            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1677            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1678            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1679            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1680            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
1681            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
1682            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
1683        ];
1684
1685        for (device_id, user_id) in known_devices {
1686            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
1687        }
1688
1689        let known_sender_key_to_session_count = vec![
1690            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
1691            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
1692            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
1693            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
1694            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
1695            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
1696            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
1697            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
1698            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
1699            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
1700            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
1701            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
1702            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
1703            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
1704            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
1705        ];
1706
1707        for (id, count) in known_sender_key_to_session_count {
1708            let olm_sessions =
1709                database.get_sessions(id).await.expect("Should have some olm sessions");
1710
1711            println!("### Session id: {id:?}");
1712            assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
1713        }
1714
1715        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
1716        assert_eq!(inbound_group_sessions.len(), 15);
1717        let known_inbound_group_sessions = vec![
1718            (
1719                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
1720                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1721            ),
1722            (
1723                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
1724                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1725            ),
1726            (
1727                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1728                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1729            ),
1730            (
1731                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
1732                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1733            ),
1734            (
1735                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
1736                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1737            ),
1738            (
1739                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
1740                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1741            ),
1742            (
1743                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
1744                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1745            ),
1746            (
1747                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
1748                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1749            ),
1750            (
1751                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
1752                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1753            ),
1754            (
1755                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1756                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1757            ),
1758            (
1759                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
1760                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1761            ),
1762            (
1763                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
1764                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1765            ),
1766            (
1767                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
1768                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1769            ),
1770            (
1771                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1772                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1773            ),
1774            (
1775                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
1776                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1777            ),
1778        ];
1779
1780        // ensure we can load them all
1781        for (session_id, room_id) in &known_inbound_group_sessions {
1782            database
1783                .get_inbound_group_session(room_id, session_id)
1784                .await
1785                .expect("Should be able to load inbound group session")
1786                .unwrap();
1787        }
1788
1789        let bob_sender_verified = database
1790            .get_inbound_group_session(
1791                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1792                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1793            )
1794            .await
1795            .unwrap()
1796            .unwrap();
1797
1798        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
1799        assert!(bob_sender_verified.backed_up());
1800        assert!(!bob_sender_verified.has_been_imported());
1801
1802        let alice_unknown_device = database
1803            .get_inbound_group_session(
1804                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1805                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1806            )
1807            .await
1808            .unwrap()
1809            .unwrap();
1810
1811        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
1812        assert!(alice_unknown_device.backed_up());
1813        assert!(alice_unknown_device.has_been_imported());
1814
1815        let carl_tofu_session = database
1816            .get_inbound_group_session(
1817                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1818                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1819            )
1820            .await
1821            .unwrap()
1822            .unwrap();
1823
1824        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
1825        assert!(carl_tofu_session.backed_up());
1826        assert!(!carl_tofu_session.has_been_imported());
1827
1828        // Load outbound sessions
1829        database
1830            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
1831            .await
1832            .unwrap()
1833            .unwrap();
1834        database
1835            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
1836            .await
1837            .unwrap()
1838            .unwrap();
1839        database
1840            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
1841            .await
1842            .unwrap()
1843            .unwrap();
1844
1845        let withheld_info = database
1846            .get_withheld_info(
1847                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1848                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
1849            )
1850            .await
1851            .expect("This session should be withheld")
1852            .unwrap();
1853
1854        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
1855
1856        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
1857        assert_eq!(backup_keys.backup_version.unwrap(), "6");
1858        assert!(backup_keys.decryption_key.is_some());
1859    }
1860
1861    async fn get_store(
1862        name: &str,
1863        passphrase: Option<&str>,
1864        clear_data: bool,
1865    ) -> SqliteCryptoStore {
1866        let tmpdir_path = TMP_DIR.path().join(name);
1867
1868        if clear_data {
1869            let _ = fs::remove_dir_all(&tmpdir_path).await;
1870        }
1871
1872        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
1873            .await
1874            .expect("Can't create a passphrase protected store")
1875    }
1876
    // Instantiate the shared `matrix_sdk_crypto` crypto store integration test
    // suites for this module. NOTE(review): presumably these macros generate
    // tests that obtain stores via the `get_store` helper defined in this module,
    // and the `_time` variant covers time-dependent cases — inferred from the
    // macro names, confirm against the macro definitions.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
1879}
1880
#[cfg(test)]
mod encrypted_tests {
    // Runs the shared crypto store integration suites against stores that are
    // always opened with a passphrase.
    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Temporary directory under which every store of this module is created.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Opens the store called `name` inside `TMP_DIR`, always with a passphrase.
    ///
    /// When the caller passes `None`, the fixed passphrase
    /// `"default_test_password"` is used instead, so every store opened here is
    /// passphrase protected. With `clear_data` set, the store's directory is
    /// deleted first; errors from that deletion are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the store path is not valid UTF-8 or if the store cannot be
    /// opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        if clear_data {
            // Best effort: the directory may not exist yet.
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        // Fall back to a fixed passphrase so the store is always encrypted.
        let effective_passphrase = passphrase.or(Some("default_test_password"));
        let path = store_dir.to_str().unwrap();

        SqliteCryptoStore::open(path, effective_passphrase)
            .await
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}