matrix_sdk_sqlite/
crypto_store.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    fmt,
18    path::Path,
19    sync::{Arc, RwLock},
20};
21
22use async_trait::async_trait;
23use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
24use matrix_sdk_crypto::{
25    olm::{
26        InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
27        PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
28    },
29    store::{
30        types::{
31            BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
32            StoredRoomKeyBundleData,
33        },
34        CryptoStore,
35    },
36    types::events::room_key_withheld::RoomKeyWithheldEvent,
37    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
38};
39use matrix_sdk_store_encryption::StoreCipher;
40use ruma::{
41    events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
42    RoomId, TransactionId, UserId,
43};
44use rusqlite::{named_params, params_from_iter, OptionalExtension};
45use tokio::{fs, sync::Mutex};
46use tracing::{debug, instrument, warn};
47use vodozemac::Curve25519PublicKey;
48
49use crate::{
50    error::{Error, Result},
51    utils::{
52        repeat_vars, EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
53        SqliteKeyValueStoreConnExt,
54    },
55    OpenStoreError, SqliteStoreConfig,
56};
57
/// The database name.
///
/// This file name is joined onto the directory given in the store
/// configuration when the connection pool is created.
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
60
/// An SQLite-based crypto store.
///
/// Every field is either reference counted or a pool handle, and the struct
/// derives `Clone`.
#[derive(Clone)]
pub struct SqliteCryptoStore {
    /// Cipher obtained from the passphrase when the store was opened; `None`
    /// when the store was opened without a passphrase.
    store_cipher: Option<Arc<StoreCipher>>,
    /// Pool handing out the SQLite connections used for every query.
    pool: SqlitePool,

    // DB values cached in memory
    /// Static account data, refreshed whenever an account is loaded from or
    /// saved to the store.
    static_account: Arc<RwLock<Option<StaticAccountData>>>,
    /// Async lock taken by `save_changes` and `save_pending_changes` so that
    /// the two never run concurrently.
    save_changes_lock: Arc<Mutex<()>>,
}
71
72#[cfg(not(tarpaulin_include))]
73impl fmt::Debug for SqliteCryptoStore {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
76    }
77}
78
79impl EncryptableStore for SqliteCryptoStore {
80    fn get_cypher(&self) -> Option<&StoreCipher> {
81        self.store_cipher.as_deref()
82    }
83}
84
85impl SqliteCryptoStore {
86    /// Open the SQLite-based crypto store at the given path using the given
87    /// passphrase to encrypt private data.
88    pub async fn open(
89        path: impl AsRef<Path>,
90        passphrase: Option<&str>,
91    ) -> Result<Self, OpenStoreError> {
92        Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await
93    }
94
95    /// Open the SQLite-based crypto store with the config open config.
96    pub async fn open_with_config(config: SqliteStoreConfig) -> Result<Self, OpenStoreError> {
97        let SqliteStoreConfig { path, passphrase, pool_config, runtime_config } = config;
98
99        fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?;
100
101        let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME));
102        config.pool = Some(pool_config);
103
104        let pool = config.create_pool(Runtime::Tokio1)?;
105
106        let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
107        this.pool.get().await?.apply_runtime_config(runtime_config).await?;
108
109        Ok(this)
110    }
111
112    /// Create an SQLite-based crypto store using the given SQLite database
113    /// pool. The given passphrase will be used to encrypt private data.
114    async fn open_with_pool(
115        pool: SqlitePool,
116        passphrase: Option<&str>,
117    ) -> Result<Self, OpenStoreError> {
118        let conn = pool.get().await?;
119
120        let version = conn.db_version().await?;
121        debug!("Opened sqlite store with version {}", version);
122        run_migrations(&conn, version).await?;
123
124        let store_cipher = match passphrase {
125            Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
126            None => None,
127        };
128
129        Ok(SqliteCryptoStore {
130            store_cipher,
131            pool,
132            static_account: Arc::new(RwLock::new(None)),
133            save_changes_lock: Default::default(),
134        })
135    }
136
137    fn deserialize_and_unpickle_inbound_group_session(
138        &self,
139        value: Vec<u8>,
140        backed_up: bool,
141    ) -> Result<InboundGroupSession> {
142        let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
143
144        // The `backed_up` SQL column is the source of truth, because we update it
145        // inside `mark_inbound_group_sessions_as_backed_up` and don't update
146        // the pickled value inside the `data` column (until now, when we are puling it
147        // out of the DB).
148        pickle.backed_up = backed_up;
149
150        Ok(InboundGroupSession::from_pickle(pickle)?)
151    }
152
153    fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
154        let mut request: GossipRequest = self.deserialize_value(value)?;
155        // sent_out SQL column is source of truth, sent_out field in serialized value
156        // needed for other stores though
157        request.sent_out = sent_out;
158        Ok(request)
159    }
160
161    fn get_static_account(&self) -> Option<StaticAccountData> {
162        self.static_account.read().unwrap().clone()
163    }
164
165    async fn acquire(&self) -> Result<SqliteAsyncConn> {
166        Ok(self.pool.get().await?)
167    }
168}
169
/// The schema version a database has once every migration in
/// `run_migrations` has been applied. A non-zero stored version that is not
/// lower than this one makes `run_migrations` return without doing anything.
const DATABASE_VERSION: u8 = 10;

/// Key for the dehydrated device pickle key in the key/value table.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
174
175/// Run migrations for the given version of the database.
176async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
177    if version == 0 {
178        debug!("Creating database");
179    } else if version < DATABASE_VERSION {
180        debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
181    } else {
182        return Ok(());
183    }
184
185    if version < 1 {
186        // First turn on WAL mode, this can't be done in the transaction, it fails with
187        // the error message: "cannot change into wal mode from within a transaction".
188        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
189        conn.with_transaction(|txn| {
190            txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
191            txn.set_db_version(1)
192        })
193        .await?;
194    }
195
196    if version < 2 {
197        conn.with_transaction(|txn| {
198            txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
199            txn.set_db_version(2)
200        })
201        .await?;
202    }
203
204    if version < 3 {
205        conn.with_transaction(|txn| {
206            txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
207            txn.set_db_version(3)
208        })
209        .await?;
210    }
211
212    if version < 4 {
213        conn.with_transaction(|txn| {
214            txn.execute_batch(include_str!(
215                "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
216            ))?;
217            txn.set_db_version(4)
218        })
219        .await?;
220    }
221
222    if version < 5 {
223        conn.with_transaction(|txn| {
224            txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
225            txn.set_db_version(5)
226        })
227        .await?;
228    }
229
230    if version < 6 {
231        conn.with_transaction(|txn| {
232            txn.execute_batch(include_str!(
233                "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
234            ))?;
235            txn.set_db_version(6)
236        })
237        .await?;
238    }
239
240    if version < 7 {
241        conn.with_transaction(|txn| {
242            txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
243            txn.set_db_version(7)
244        })
245        .await?;
246    }
247
248    if version < 8 {
249        conn.with_transaction(|txn| {
250            txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
251            txn.set_db_version(8)
252        })
253        .await?;
254    }
255
256    if version < 9 {
257        conn.with_transaction(|txn| {
258            txn.execute_batch(include_str!(
259                "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
260            ))?;
261            txn.set_db_version(9)
262        })
263        .await?;
264    }
265
266    if version < 10 {
267        conn.with_transaction(|txn| {
268            txn.execute_batch(include_str!(
269                "../migrations/crypto_store/010_received_room_key_bundles.sql"
270            ))?;
271            txn.set_db_version(10)
272        })
273        .await?;
274    }
275
276    Ok(())
277}
278
279trait SqliteConnectionExt {
280    fn set_session(
281        &self,
282        session_id: &[u8],
283        sender_key: &[u8],
284        data: &[u8],
285    ) -> rusqlite::Result<()>;
286
287    fn set_inbound_group_session(
288        &self,
289        room_id: &[u8],
290        session_id: &[u8],
291        data: &[u8],
292        backed_up: bool,
293        sender_key: Option<&[u8]>,
294        sender_data_type: Option<u8>,
295    ) -> rusqlite::Result<()>;
296
297    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
298
299    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
300    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
301
302    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
303
304    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
305
306    fn set_key_request(
307        &self,
308        request_id: &[u8],
309        sent_out: bool,
310        data: &[u8],
311    ) -> rusqlite::Result<()>;
312
313    fn set_direct_withheld(
314        &self,
315        session_id: &[u8],
316        room_id: &[u8],
317        data: &[u8],
318    ) -> rusqlite::Result<()>;
319
320    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
321
322    fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
323
324    fn set_received_room_key_bundle(
325        &self,
326        room_id: &[u8],
327        user_id: &[u8],
328        data: &[u8],
329    ) -> rusqlite::Result<()>;
330}
331
332impl SqliteConnectionExt for rusqlite::Connection {
333    fn set_session(
334        &self,
335        session_id: &[u8],
336        sender_key: &[u8],
337        data: &[u8],
338    ) -> rusqlite::Result<()> {
339        self.execute(
340            "INSERT INTO session (session_id, sender_key, data)
341             VALUES (?1, ?2, ?3)
342             ON CONFLICT (session_id) DO UPDATE SET data = ?3",
343            (session_id, sender_key, data),
344        )?;
345        Ok(())
346    }
347
348    fn set_inbound_group_session(
349        &self,
350        room_id: &[u8],
351        session_id: &[u8],
352        data: &[u8],
353        backed_up: bool,
354        sender_key: Option<&[u8]>,
355        sender_data_type: Option<u8>,
356    ) -> rusqlite::Result<()> {
357        self.execute(
358            "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
359             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
360             ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
361            (session_id, room_id, data, backed_up, sender_key, sender_data_type),
362        )?;
363        Ok(())
364    }
365
366    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
367        self.execute(
368            "INSERT INTO outbound_group_session (room_id, data) \
369             VALUES (?1, ?2)
370             ON CONFLICT (room_id) DO UPDATE SET data = ?2",
371            (room_id, data),
372        )?;
373        Ok(())
374    }
375
376    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
377        self.execute(
378            "INSERT INTO device (user_id, device_id, data) \
379             VALUES (?1, ?2, ?3)
380             ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
381            (user_id, device_id, data),
382        )?;
383        Ok(())
384    }
385
386    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
387        self.execute(
388            "DELETE FROM device WHERE user_id = ? AND device_id = ?",
389            (user_id, device_id),
390        )?;
391        Ok(())
392    }
393
394    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
395        self.execute(
396            "INSERT INTO identity (user_id, data) \
397             VALUES (?1, ?2)
398             ON CONFLICT (user_id) DO UPDATE SET data = ?2",
399            (user_id, data),
400        )?;
401        Ok(())
402    }
403
404    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
405        self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
406        Ok(())
407    }
408
409    fn set_key_request(
410        &self,
411        request_id: &[u8],
412        sent_out: bool,
413        data: &[u8],
414    ) -> rusqlite::Result<()> {
415        self.execute(
416            "INSERT INTO key_requests (request_id, sent_out, data)
417            VALUES (?1, ?2, ?3)
418            ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
419            (request_id, sent_out, data),
420        )?;
421        Ok(())
422    }
423
424    fn set_direct_withheld(
425        &self,
426        session_id: &[u8],
427        room_id: &[u8],
428        data: &[u8],
429    ) -> rusqlite::Result<()> {
430        self.execute(
431            "INSERT INTO direct_withheld_info (session_id, room_id, data)
432            VALUES (?1, ?2, ?3)
433            ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
434            (session_id, room_id, data),
435        )?;
436        Ok(())
437    }
438
439    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
440        self.execute(
441            "INSERT INTO room_settings (room_id, data)
442            VALUES (?1, ?2)
443            ON CONFLICT (room_id) DO UPDATE SET data = ?2",
444            (room_id, data),
445        )?;
446        Ok(())
447    }
448
449    fn set_secret(&self, secret_name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
450        self.execute(
451            "INSERT INTO secrets (secret_name, data)
452            VALUES (?1, ?2)",
453            (secret_name, data),
454        )?;
455
456        Ok(())
457    }
458
459    fn set_received_room_key_bundle(
460        &self,
461        room_id: &[u8],
462        sender_user_id: &[u8],
463        data: &[u8],
464    ) -> rusqlite::Result<()> {
465        self.execute(
466            "INSERT INTO received_room_key_bundle(room_id, sender_user_id, bundle_data)
467            VALUES (?1, ?2, ?3)
468            ON CONFLICT (room_id, sender_user_id) DO UPDATE SET bundle_data = ?3",
469            (room_id, sender_user_id, data),
470        )?;
471        Ok(())
472    }
473}
474
475#[async_trait]
476trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
477    async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
478        Ok(self
479            .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
480                stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
481            })
482            .await?)
483    }
484
485    async fn get_inbound_group_session(
486        &self,
487        session_id: Key,
488    ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
489        Ok(self
490            .query_row(
491                "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
492                (session_id,),
493                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
494            )
495            .await
496            .optional()?)
497    }
498
499    async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
500        Ok(self
501            .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
502                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
503            })
504            .await?)
505    }
506
507    async fn get_inbound_group_session_counts(
508        &self,
509        _backup_version: Option<&str>,
510    ) -> Result<RoomKeyCounts> {
511        let total = self
512            .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
513            .await?;
514        let backed_up = self
515            .query_row(
516                "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
517                (),
518                |row| row.get(0),
519            )
520            .await?;
521        Ok(RoomKeyCounts { total, backed_up })
522    }
523
524    async fn get_inbound_group_sessions_for_device_batch(
525        &self,
526        sender_key: Key,
527        sender_data_type: SenderDataType,
528        after_session_id: Option<Key>,
529        limit: usize,
530    ) -> Result<Vec<(Vec<u8>, bool)>> {
531        Ok(self
532            .prepare(
533                "
534                SELECT data, backed_up
535                FROM inbound_group_session
536                WHERE sender_key = :sender_key
537                    AND sender_data_type = :sender_data_type
538                    AND session_id > :after_session_id
539                ORDER BY session_id
540                LIMIT :limit
541                ",
542                move |mut stmt| {
543                    let sender_data_type = sender_data_type as u8;
544
545                    // If we are not provided with an `after_session_id`, use a key which will sort
546                    // before all real keys: the empty string.
547                    let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
548
549                    stmt.query(named_params! {
550                        ":sender_key": sender_key,
551                        ":sender_data_type": sender_data_type,
552                        ":after_session_id": after_session_id,
553                        ":limit": limit,
554                    })?
555                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
556                    .collect()
557                },
558            )
559            .await?)
560    }
561
562    async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
563        Ok(self
564            .prepare(
565                "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
566                move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
567            )
568            .await?)
569    }
570
571    async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
572        if session_ids.is_empty() {
573            // We are not expecting to be called with an empty list of sessions
574            warn!("No sessions to mark as backed up!");
575            return Ok(());
576        }
577
578        let session_ids_len = session_ids.len();
579
580        self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
581            // Safety: placeholders is not generated using any user input except the number
582            // of session IDs, so it is safe from injection.
583            let sql_params = repeat_vars(session_ids_len);
584            let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
585            txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
586            Ok(Vec::<()>::new())
587        }).await?;
588
589        Ok(())
590    }
591
592    async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
593        self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
594        Ok(())
595    }
596
597    async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
598        Ok(self
599            .query_row(
600                "SELECT data FROM outbound_group_session WHERE room_id = ?",
601                (room_id,),
602                |row| row.get(0),
603            )
604            .await
605            .optional()?)
606    }
607
608    async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
609        Ok(self
610            .query_row(
611                "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
612                (user_id, device_id),
613                |row| row.get(0),
614            )
615            .await
616            .optional()?)
617    }
618
619    async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
620        Ok(self
621            .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
622                stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
623            })
624            .await?)
625    }
626
627    async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
628        Ok(self
629            .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
630            .await
631            .optional()?)
632    }
633
634    async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
635        Ok(self
636            .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
637                row.get::<_, i32>(0)
638            })
639            .await?
640            > 0)
641    }
642
643    async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
644        Ok(self
645            .prepare("SELECT data FROM tracked_user", |mut stmt| {
646                stmt.query(())?.mapped(|row| row.get(0)).collect()
647            })
648            .await?)
649    }
650
651    async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
652        Ok(self
653            .prepare(
654                "INSERT INTO tracked_user (user_id, data) \
655                 VALUES (?1, ?2) \
656                 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
657                |mut stmt| {
658                    for (user_id, data) in users {
659                        stmt.execute((user_id, data))?;
660                    }
661
662                    Ok(())
663                },
664            )
665            .await?)
666    }
667
668    async fn get_outgoing_secret_request(
669        &self,
670        request_id: Key,
671    ) -> Result<Option<(Vec<u8>, bool)>> {
672        Ok(self
673            .query_row(
674                "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
675                (request_id,),
676                |row| Ok((row.get(0)?, row.get(1)?)),
677            )
678            .await
679            .optional()?)
680    }
681
682    async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
683        Ok(self
684            .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
685                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
686            })
687            .await?)
688    }
689
690    async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
691        Ok(self
692            .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
693                stmt.query(())?.mapped(|row| row.get(0)).collect()
694            })
695            .await?)
696    }
697
698    async fn delete_key_request(&self, request_id: Key) -> Result<()> {
699        self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
700        Ok(())
701    }
702
703    async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
704        Ok(self
705            .prepare("SELECT data FROM secrets WHERE secret_name = ?", |mut stmt| {
706                stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
707            })
708            .await?)
709    }
710
711    async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
712        self.execute("DELETE FROM secrets WHERE secret_name = ?", (secret_name,)).await?;
713        Ok(())
714    }
715
716    async fn get_direct_withheld_info(
717        &self,
718        session_id: Key,
719        room_id: Key,
720    ) -> Result<Option<Vec<u8>>> {
721        Ok(self
722            .query_row(
723                "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
724                (session_id, room_id),
725                |row| row.get(0),
726            )
727            .await
728            .optional()?)
729    }
730
731    async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
732        Ok(self
733            .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
734                row.get(0)
735            })
736            .await
737            .optional()?)
738    }
739
740    async fn get_received_room_key_bundle(
741        &self,
742        room_id: Key,
743        sender_user: Key,
744    ) -> Result<Option<Vec<u8>>> {
745        Ok(self
746            .query_row(
747                "SELECT bundle_data FROM received_room_key_bundle WHERE room_id = ? AND sender_user_id = ?",
748                (room_id, sender_user),
749                |row| { row.get(0) },
750            )
751            .await
752            .optional()?)
753    }
754}
755
// Every method of `SqliteObjectCryptoStoreExt` has a provided body, so the
// implementation for the pooled async connection is intentionally empty.
#[async_trait]
impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
758
759#[async_trait]
760impl CryptoStore for SqliteCryptoStore {
761    type Error = Error;
762
763    async fn load_account(&self) -> Result<Option<Account>> {
764        let conn = self.acquire().await?;
765        if let Some(pickle) = conn.get_kv("account").await? {
766            let pickle = self.deserialize_value(&pickle)?;
767
768            let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
769
770            *self.static_account.write().unwrap() = Some(account.static_data().clone());
771
772            Ok(Some(account))
773        } else {
774            Ok(None)
775        }
776    }
777
778    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
779        let conn = self.acquire().await?;
780        if let Some(i) = conn.get_kv("identity").await? {
781            let pickle = self.deserialize_value(&i)?;
782            Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
783        } else {
784            Ok(None)
785        }
786    }
787
788    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
789        // Serialize calls to `save_pending_changes`; there are multiple await points
790        // below, and we're pickling data as we go, so we don't want to
791        // invalidate data we've previously read and overwrite it in the store.
792        // TODO: #2000 should make this lock go away, or change its shape.
793        let _guard = self.save_changes_lock.lock().await;
794
795        let pickled_account = if let Some(account) = changes.account {
796            *self.static_account.write().unwrap() = Some(account.static_data().clone());
797            Some(account.pickle())
798        } else {
799            None
800        };
801
802        let this = self.clone();
803        self.acquire()
804            .await?
805            .with_transaction(move |txn| {
806                if let Some(pickled_account) = pickled_account {
807                    let serialized_account = this.serialize_value(&pickled_account)?;
808                    txn.set_kv("account", &serialized_account)?;
809                }
810
811                Ok::<_, Error>(())
812            })
813            .await?;
814
815        Ok(())
816    }
817
818    async fn save_changes(&self, changes: Changes) -> Result<()> {
819        // Serialize calls to `save_changes`; there are multiple await points below, and
820        // we're pickling data as we go, so we don't want to invalidate data
821        // we've previously read and overwrite it in the store.
822        // TODO: #2000 should make this lock go away, or change its shape.
823        let _guard = self.save_changes_lock.lock().await;
824
825        let pickled_private_identity =
826            if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
827
828        let mut session_changes = Vec::new();
829
830        for session in changes.sessions {
831            let session_id = self.encode_key("session", session.session_id());
832            let sender_key = self.encode_key("session", session.sender_key().to_base64());
833            let pickle = session.pickle().await;
834            session_changes.push((session_id, sender_key, pickle));
835        }
836
837        let mut inbound_session_changes = Vec::new();
838        for session in changes.inbound_group_sessions {
839            let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
840            let session_id = self.encode_key("inbound_group_session", session.session_id());
841            let pickle = session.pickle().await;
842            let sender_key =
843                self.encode_key("inbound_group_session", session.sender_key().to_base64());
844            inbound_session_changes.push((room_id, session_id, pickle, sender_key));
845        }
846
847        let mut outbound_session_changes = Vec::new();
848        for session in changes.outbound_group_sessions {
849            let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
850            let pickle = session.pickle().await;
851            outbound_session_changes.push((room_id, pickle));
852        }
853
854        let this = self.clone();
855        self.acquire()
856            .await?
857            .with_transaction(move |txn| {
858                if let Some(pickled_private_identity) = &pickled_private_identity {
859                    let serialized_private_identity =
860                        this.serialize_value(pickled_private_identity)?;
861                    txn.set_kv("identity", &serialized_private_identity)?;
862                }
863
864                if let Some(token) = &changes.next_batch_token {
865                    let serialized_token = this.serialize_value(token)?;
866                    txn.set_kv("next_batch_token", &serialized_token)?;
867                }
868
869                if let Some(decryption_key) = &changes.backup_decryption_key {
870                    let serialized_decryption_key = this.serialize_value(decryption_key)?;
871                    txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
872                }
873
874                if let Some(backup_version) = &changes.backup_version {
875                    let serialized_backup_version = this.serialize_value(backup_version)?;
876                    txn.set_kv("backup_version_v1", &serialized_backup_version)?;
877                }
878
879                if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
880                    let serialized_pickle_key = this.serialize_value(pickle_key)?;
881                    txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
882                }
883
884                for device in changes.devices.new.iter().chain(&changes.devices.changed) {
885                    let user_id = this.encode_key("device", device.user_id().as_bytes());
886                    let device_id = this.encode_key("device", device.device_id().as_bytes());
887                    let data = this.serialize_value(&device)?;
888                    txn.set_device(&user_id, &device_id, &data)?;
889                }
890
891                for device in &changes.devices.deleted {
892                    let user_id = this.encode_key("device", device.user_id().as_bytes());
893                    let device_id = this.encode_key("device", device.device_id().as_bytes());
894                    txn.delete_device(&user_id, &device_id)?;
895                }
896
897                for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
898                    let user_id = this.encode_key("identity", identity.user_id().as_bytes());
899                    let data = this.serialize_value(&identity)?;
900                    txn.set_identity(&user_id, &data)?;
901                }
902
903                for (session_id, sender_key, pickle) in &session_changes {
904                    let serialized_session = this.serialize_value(&pickle)?;
905                    txn.set_session(session_id, sender_key, &serialized_session)?;
906                }
907
908                for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
909                    let serialized_session = this.serialize_value(&pickle)?;
910                    txn.set_inbound_group_session(
911                        room_id,
912                        session_id,
913                        &serialized_session,
914                        pickle.backed_up,
915                        Some(sender_key),
916                        Some(pickle.sender_data.to_type() as u8),
917                    )?;
918                }
919
920                for (room_id, pickle) in &outbound_session_changes {
921                    let serialized_session = this.serialize_json(&pickle)?;
922                    txn.set_outbound_group_session(room_id, &serialized_session)?;
923                }
924
925                for hash in &changes.message_hashes {
926                    let hash = rmp_serde::to_vec(hash)?;
927                    txn.add_olm_hash(&hash)?;
928                }
929
930                for request in changes.key_requests {
931                    let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
932                    let serialized_request = this.serialize_value(&request)?;
933                    txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
934                }
935
936                for (room_id, data) in changes.withheld_session_info {
937                    for (session_id, event) in data {
938                        let session_id = this.encode_key("direct_withheld_info", session_id);
939                        let room_id = this.encode_key("direct_withheld_info", &room_id);
940                        let serialized_info = this.serialize_json(&event)?;
941                        txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
942                    }
943                }
944
945                for (room_id, settings) in changes.room_settings {
946                    let room_id = this.encode_key("room_settings", room_id.as_bytes());
947                    let value = this.serialize_value(&settings)?;
948                    txn.set_room_settings(&room_id, &value)?;
949                }
950
951                for secret in changes.secrets {
952                    let secret_name = this.encode_key("secrets", secret.secret_name.to_string());
953                    let value = this.serialize_json(&secret)?;
954                    txn.set_secret(&secret_name, &value)?;
955                }
956
957                for bundle in changes.received_room_key_bundles {
958                    let room_id =
959                        this.encode_key("received_room_key_bundle", &bundle.bundle_data.room_id);
960                    let user_id = this.encode_key("received_room_key_bundle", &bundle.sender_user);
961                    let value = this.serialize_value(&bundle)?;
962                    txn.set_received_room_key_bundle(&room_id, &user_id, &value)?;
963                }
964
965                Ok::<_, Error>(())
966            })
967            .await?;
968
969        Ok(())
970    }
971
972    async fn save_inbound_group_sessions(
973        &self,
974        sessions: Vec<InboundGroupSession>,
975        backed_up_to_version: Option<&str>,
976    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
977        // Sanity-check that the data in the sessions corresponds to backed_up_version
978        sessions.iter().for_each(|s| {
979            let backed_up = s.backed_up();
980            if backed_up != backed_up_to_version.is_some() {
981                warn!(
982                    backed_up,
983                    backed_up_to_version,
984                    "Session backed-up flag does not correspond to backup version setting",
985                );
986            }
987        });
988
989        // Currently, this store doesn't save the backup version separately, so this
990        // just delegates to save_changes.
991        self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
992    }
993
994    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
995        let device_keys = self.get_own_device().await?.as_device_keys().clone();
996
997        let sessions: Vec<_> = self
998            .acquire()
999            .await?
1000            .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
1001            .await?
1002            .into_iter()
1003            .map(|bytes| {
1004                let pickle = self.deserialize_value(&bytes)?;
1005                Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
1006            })
1007            .collect::<Result<_>>()?;
1008
1009        if sessions.is_empty() {
1010            Ok(None)
1011        } else {
1012            Ok(Some(sessions))
1013        }
1014    }
1015
1016    #[instrument(skip(self))]
1017    async fn get_inbound_group_session(
1018        &self,
1019        room_id: &RoomId,
1020        session_id: &str,
1021    ) -> Result<Option<InboundGroupSession>> {
1022        let session_id = self.encode_key("inbound_group_session", session_id);
1023        let Some((room_id_from_db, value, backed_up)) =
1024            self.acquire().await?.get_inbound_group_session(session_id).await?
1025        else {
1026            return Ok(None);
1027        };
1028
1029        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1030        if *room_id != room_id_from_db {
1031            warn!("expected room_id for session_id doesn't match what's in the DB");
1032            return Ok(None);
1033        }
1034
1035        Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1036    }
1037
1038    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1039        self.acquire()
1040            .await?
1041            .get_inbound_group_sessions()
1042            .await?
1043            .into_iter()
1044            .map(|(value, backed_up)| {
1045                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1046            })
1047            .collect()
1048    }
1049
1050    async fn get_inbound_group_sessions_for_device_batch(
1051        &self,
1052        sender_key: Curve25519PublicKey,
1053        sender_data_type: SenderDataType,
1054        after_session_id: Option<String>,
1055        limit: usize,
1056    ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1057        let after_session_id =
1058            after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1059        let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1060
1061        self.acquire()
1062            .await?
1063            .get_inbound_group_sessions_for_device_batch(
1064                sender_key,
1065                sender_data_type,
1066                after_session_id,
1067                limit,
1068            )
1069            .await?
1070            .into_iter()
1071            .map(|(value, backed_up)| {
1072                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1073            })
1074            .collect()
1075    }
1076
1077    async fn inbound_group_session_counts(
1078        &self,
1079        backup_version: Option<&str>,
1080    ) -> Result<RoomKeyCounts> {
1081        Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
1082    }
1083
1084    async fn inbound_group_sessions_for_backup(
1085        &self,
1086        _backup_version: &str,
1087        limit: usize,
1088    ) -> Result<Vec<InboundGroupSession>> {
1089        self.acquire()
1090            .await?
1091            .get_inbound_group_sessions_for_backup(limit)
1092            .await?
1093            .into_iter()
1094            .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1095            .collect()
1096    }
1097
1098    async fn mark_inbound_group_sessions_as_backed_up(
1099        &self,
1100        _backup_version: &str,
1101        session_ids: &[(&RoomId, &str)],
1102    ) -> Result<()> {
1103        Ok(self
1104            .acquire()
1105            .await?
1106            .mark_inbound_group_sessions_as_backed_up(
1107                session_ids
1108                    .iter()
1109                    .map(|(_, s)| self.encode_key("inbound_group_session", s))
1110                    .collect(),
1111            )
1112            .await?)
1113    }
1114
1115    async fn reset_backup_state(&self) -> Result<()> {
1116        Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
1117    }
1118
1119    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1120        let conn = self.acquire().await?;
1121
1122        let backup_version = conn
1123            .get_kv("backup_version_v1")
1124            .await?
1125            .map(|value| self.deserialize_value(&value))
1126            .transpose()?;
1127
1128        let decryption_key = conn
1129            .get_kv("recovery_key_v1")
1130            .await?
1131            .map(|value| self.deserialize_value(&value))
1132            .transpose()?;
1133
1134        Ok(BackupKeys { backup_version, decryption_key })
1135    }
1136
1137    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1138        let conn = self.acquire().await?;
1139
1140        conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1141            .await?
1142            .map(|value| self.deserialize_value(&value))
1143            .transpose()
1144    }
1145
1146    async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1147        let conn = self.acquire().await?;
1148        conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
1149
1150        Ok(())
1151    }
1152    async fn get_outbound_group_session(
1153        &self,
1154        room_id: &RoomId,
1155    ) -> Result<Option<OutboundGroupSession>> {
1156        let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1157        let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
1158            return Ok(None);
1159        };
1160
1161        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1162
1163        let pickle = self.deserialize_json(&value)?;
1164        let session = OutboundGroupSession::from_pickle(
1165            account_info.device_id,
1166            account_info.identity_keys,
1167            pickle,
1168        )
1169        .map_err(|_| Error::Unpickle)?;
1170
1171        return Ok(Some(session));
1172    }
1173
1174    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1175        self.acquire()
1176            .await?
1177            .get_tracked_users()
1178            .await?
1179            .iter()
1180            .map(|value| self.deserialize_value(value))
1181            .collect()
1182    }
1183
1184    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1185        let users: Vec<(Key, Vec<u8>)> = tracked_users
1186            .iter()
1187            .map(|(u, d)| {
1188                let user_id = self.encode_key("tracked_users", u.as_bytes());
1189                let data =
1190                    self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1191                Ok((user_id, data))
1192            })
1193            .collect::<Result<_>>()?;
1194
1195        Ok(self.acquire().await?.add_tracked_users(users).await?)
1196    }
1197
1198    async fn get_device(
1199        &self,
1200        user_id: &UserId,
1201        device_id: &DeviceId,
1202    ) -> Result<Option<DeviceData>> {
1203        let user_id = self.encode_key("device", user_id.as_bytes());
1204        let device_id = self.encode_key("device", device_id.as_bytes());
1205        Ok(self
1206            .acquire()
1207            .await?
1208            .get_device(user_id, device_id)
1209            .await?
1210            .map(|value| self.deserialize_value(&value))
1211            .transpose()?)
1212    }
1213
1214    async fn get_user_devices(
1215        &self,
1216        user_id: &UserId,
1217    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1218        let user_id = self.encode_key("device", user_id.as_bytes());
1219        self.acquire()
1220            .await?
1221            .get_user_devices(user_id)
1222            .await?
1223            .into_iter()
1224            .map(|value| {
1225                let device: DeviceData = self.deserialize_value(&value)?;
1226                Ok((device.device_id().to_owned(), device))
1227            })
1228            .collect()
1229    }
1230
1231    async fn get_own_device(&self) -> Result<DeviceData> {
1232        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1233
1234        Ok(self
1235            .get_device(&account_info.user_id, &account_info.device_id)
1236            .await?
1237            .expect("We should be able to find our own device."))
1238    }
1239
1240    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1241        let user_id = self.encode_key("identity", user_id.as_bytes());
1242        Ok(self
1243            .acquire()
1244            .await?
1245            .get_user_identity(user_id)
1246            .await?
1247            .map(|value| self.deserialize_value(&value))
1248            .transpose()?)
1249    }
1250
1251    async fn is_message_known(
1252        &self,
1253        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1254    ) -> Result<bool> {
1255        let value = rmp_serde::to_vec(message_hash)?;
1256        Ok(self.acquire().await?.has_olm_hash(value).await?)
1257    }
1258
1259    async fn get_outgoing_secret_requests(
1260        &self,
1261        request_id: &TransactionId,
1262    ) -> Result<Option<GossipRequest>> {
1263        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1264        Ok(self
1265            .acquire()
1266            .await?
1267            .get_outgoing_secret_request(request_id)
1268            .await?
1269            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1270            .transpose()?)
1271    }
1272
1273    async fn get_secret_request_by_info(
1274        &self,
1275        key_info: &SecretInfo,
1276    ) -> Result<Option<GossipRequest>> {
1277        let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
1278        for (request, sent_out) in requests {
1279            let request = self.deserialize_key_request(&request, sent_out)?;
1280            if request.info == *key_info {
1281                return Ok(Some(request));
1282            }
1283        }
1284        Ok(None)
1285    }
1286
1287    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1288        self.acquire()
1289            .await?
1290            .get_unsent_secret_requests()
1291            .await?
1292            .iter()
1293            .map(|value| {
1294                let request = self.deserialize_key_request(value, false)?;
1295                Ok(request)
1296            })
1297            .collect()
1298    }
1299
1300    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1301        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1302        Ok(self.acquire().await?.delete_key_request(request_id).await?)
1303    }
1304
1305    async fn get_secrets_from_inbox(
1306        &self,
1307        secret_name: &SecretName,
1308    ) -> Result<Vec<GossippedSecret>> {
1309        let secret_name = self.encode_key("secrets", secret_name.to_string());
1310
1311        self.acquire()
1312            .await?
1313            .get_secrets_from_inbox(secret_name)
1314            .await?
1315            .into_iter()
1316            .map(|value| self.deserialize_json(value.as_ref()))
1317            .collect()
1318    }
1319
1320    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1321        let secret_name = self.encode_key("secrets", secret_name.to_string());
1322        self.acquire().await?.delete_secrets_from_inbox(secret_name).await
1323    }
1324
1325    async fn get_withheld_info(
1326        &self,
1327        room_id: &RoomId,
1328        session_id: &str,
1329    ) -> Result<Option<RoomKeyWithheldEvent>> {
1330        let room_id = self.encode_key("direct_withheld_info", room_id);
1331        let session_id = self.encode_key("direct_withheld_info", session_id);
1332
1333        self.acquire()
1334            .await?
1335            .get_direct_withheld_info(session_id, room_id)
1336            .await?
1337            .map(|value| {
1338                let info = self.deserialize_json::<RoomKeyWithheldEvent>(&value)?;
1339                Ok(info)
1340            })
1341            .transpose()
1342    }
1343
1344    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1345        let room_id = self.encode_key("room_settings", room_id.as_bytes());
1346        let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
1347            return Ok(None);
1348        };
1349
1350        let settings = self.deserialize_value(&value)?;
1351
1352        return Ok(Some(settings));
1353    }
1354
1355    async fn get_received_room_key_bundle_data(
1356        &self,
1357        room_id: &RoomId,
1358        user_id: &UserId,
1359    ) -> Result<Option<StoredRoomKeyBundleData>> {
1360        let room_id = self.encode_key("received_room_key_bundle", room_id);
1361        let user_id = self.encode_key("received_room_key_bundle", user_id);
1362        self.acquire()
1363            .await?
1364            .get_received_room_key_bundle(room_id, user_id)
1365            .await?
1366            .map(|value| self.deserialize_value(&value))
1367            .transpose()
1368    }
1369
1370    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1371        let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
1372            return Ok(None);
1373        };
1374        let value = if let Some(cipher) = &self.store_cipher {
1375            let encrypted = rmp_serde::from_slice(&serialized)?;
1376            cipher.decrypt_value_data(encrypted)?
1377        } else {
1378            serialized
1379        };
1380
1381        Ok(Some(value))
1382    }
1383
1384    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1385        let serialized = if let Some(cipher) = &self.store_cipher {
1386            let encrypted = cipher.encrypt_value_data(value)?;
1387            rmp_serde::to_vec_named(&encrypted)?
1388        } else {
1389            value
1390        };
1391
1392        self.acquire().await?.set_kv(key, serialized).await?;
1393        Ok(())
1394    }
1395
1396    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1397        let key = key.to_owned();
1398        self.acquire()
1399            .await?
1400            .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1401            .await
1402            .unwrap()?;
1403        Ok(())
1404    }
1405
1406    async fn try_take_leased_lock(
1407        &self,
1408        lease_duration_ms: u32,
1409        key: &str,
1410        holder: &str,
1411    ) -> Result<bool> {
1412        let key = key.to_owned();
1413        let holder = holder.to_owned();
1414
1415        let now_ts: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
1416        let expiration_ts = now_ts + lease_duration_ms as u64;
1417
1418        let num_touched = self
1419            .acquire()
1420            .await?
1421            .with_transaction(move |txn| {
1422                txn.execute(
1423                    "INSERT INTO lease_locks (key, holder, expiration_ts)
1424                    VALUES (?1, ?2, ?3)
1425                    ON CONFLICT (key)
1426                    DO
1427                        UPDATE SET holder = ?2, expiration_ts = ?3
1428                        WHERE holder = ?2
1429                        OR expiration_ts < ?4
1430                ",
1431                    (key, holder, expiration_ts, now_ts),
1432                )
1433            })
1434            .await?;
1435
1436        Ok(num_touched == 1)
1437    }
1438
1439    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1440        let conn = self.acquire().await?;
1441        if let Some(token) = conn.get_kv("next_batch_token").await? {
1442            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1443            Ok(maybe_token)
1444        } else {
1445            Ok(None)
1446        }
1447    }
1448}
1449
1450#[cfg(test)]
1451mod tests {
1452    use std::path::Path;
1453
1454    use matrix_sdk_common::deserialized_responses::WithheldCode;
1455    use matrix_sdk_crypto::{
1456        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1457        store::CryptoStore,
1458    };
1459    use matrix_sdk_test::async_test;
1460    use once_cell::sync::Lazy;
1461    use ruma::{device_id, room_id, user_id};
1462    use similar_asserts::assert_eq;
1463    use tempfile::{tempdir, TempDir};
1464    use tokio::fs;
1465
1466    use super::SqliteCryptoStore;
1467    use crate::SqliteStoreConfig;
1468
    /// Temporary directory created lazily on first access; `test_pool_size`
    /// places its store inside it.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());
1470
    /// A store opened on a copy of a test database, bundled with the temporary
    /// directory that holds the copy.
    struct TestDb {
        // Needs to be kept alive because the Drop implementation for TempDir deletes the
        // directory.
        _dir: TempDir,
        /// The store opened from the database inside `_dir`.
        database: SqliteCryptoStore,
    }
1477
1478    fn copy_db(data_path: &str) -> TempDir {
1479        let db_name = super::DATABASE_NAME;
1480
1481        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1482        let database_path = manifest_path.join(data_path).join(db_name);
1483
1484        let tmpdir = tempdir().unwrap();
1485        let destination = tmpdir.path().join(db_name);
1486
1487        // Copy the test database to the tempdir so our test runs are idempotent.
1488        std::fs::copy(&database_path, destination).unwrap();
1489
1490        tmpdir
1491    }
1492
1493    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1494        let tmpdir = copy_db(data_path);
1495
1496        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1497            .await
1498            .expect("Can't open the test store");
1499
1500        TestDb { _dir: tmpdir, database }
1501    }
1502
1503    #[async_test]
1504    async fn test_pool_size() {
1505        let store_open_config =
1506            SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);
1507
1508        let store = SqliteCryptoStore::open_with_config(store_open_config).await.unwrap();
1509
1510        assert_eq!(store.pool.status().max_size, 42);
1511    }
1512
1513    /// Test that we didn't regress in our storage layer by loading data from a
1514    /// pre-filled database, or in other words use a test vector for this.
1515    #[async_test]
1516    async fn test_open_test_vector_store() {
1517        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;
1518
1519        let account = database
1520            .load_account()
1521            .await
1522            .unwrap()
1523            .expect("The test database is prefilled with data, we should find an account");
1524
1525        let user_id = account.user_id();
1526        let device_id = account.device_id();
1527
1528        assert_eq!(
1529            user_id.as_str(),
1530            "@pjtest:synapse-oidc.element.dev",
1531            "The user ID should match to the one we expect."
1532        );
1533
1534        assert_eq!(
1535            device_id.as_str(),
1536            "v4TqgcuIH6",
1537            "The device ID should match to the one we expect."
1538        );
1539
1540        let device = database
1541            .get_device(user_id, device_id)
1542            .await
1543            .unwrap()
1544            .expect("Our own device should be found in the store.");
1545
1546        assert_eq!(device.device_id(), device_id);
1547        assert_eq!(device.user_id(), user_id);
1548
1549        assert_eq!(
1550            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
1551            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
1552        );
1553
1554        assert_eq!(
1555            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
1556            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
1557        );
1558
1559        let identity = database
1560            .get_user_identity(user_id)
1561            .await
1562            .unwrap()
1563            .expect("The store should contain an identity.");
1564
1565        assert_eq!(identity.user_id(), user_id);
1566
1567        let identity = identity
1568            .own()
1569            .expect("The identity should be of the correct type, it should be our own identity.");
1570
1571        let master_key = identity
1572            .master_key()
1573            .get_first_key()
1574            .expect("Our own identity should have a master key");
1575
1576        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
1577    }
1578
1579    /// Test that we didn't regress in our storage layer by loading data from a
1580    /// pre-filled database, or in other words use a test vector for this.
1581    #[async_test]
1582    async fn test_open_test_vector_encrypted_store() {
1583        let TestDb { _dir: _, database } = get_test_db(
1584            "testing/data/storage/alice",
1585            Some(concat!(
1586                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
1587                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
1588                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
1589                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
1590                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
1591            )),
1592        )
1593        .await;
1594
1595        let account = database
1596            .load_account()
1597            .await
1598            .unwrap()
1599            .expect("The test database is prefilled with data, we should find an account");
1600
1601        let user_id = account.user_id();
1602        let device_id = account.device_id();
1603
1604        assert_eq!(
1605            user_id.as_str(),
1606            "@alice:localhost",
1607            "The user ID should match to the one we expect."
1608        );
1609
1610        assert_eq!(
1611            device_id.as_str(),
1612            "JVVORTHFXY",
1613            "The device ID should match to the one we expect."
1614        );
1615
1616        let tracked_users =
1617            database.load_tracked_users().await.expect("Should be tracking some users");
1618
1619        assert_eq!(tracked_users.len(), 6);
1620
1621        let known_users = vec![
1622            user_id!("@alice:localhost"),
1623            user_id!("@dehydration3:localhost"),
1624            user_id!("@eve:localhost"),
1625            user_id!("@bob:localhost"),
1626            user_id!("@malo:localhost"),
1627            user_id!("@carl:localhost"),
1628        ];
1629
1630        // load the identities
1631        for user_id in known_users {
1632            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
1633        }
1634
1635        let carl_identity =
1636            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
1637
1638        assert_eq!(
1639            carl_identity.master_key().get_first_key().unwrap().to_base64(),
1640            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
1641        );
1642        assert!(!carl_identity.was_previously_verified());
1643
1644        let bob_identity =
1645            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
1646
1647        assert_eq!(
1648            bob_identity.master_key().get_first_key().unwrap().to_base64(),
1649            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
1650        );
1651        // Bob is verified so this flag should be set
1652        assert!(bob_identity.was_previously_verified());
1653
1654        let known_devices = vec![
1655            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
1656            // a dehydrated one
1657            (
1658                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
1659                user_id!("@alice:localhost"),
1660            ),
1661            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
1662            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
1663            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
1664            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
1665            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
1666            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1667            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1668            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1669            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1670            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
1671            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
1672            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
1673        ];
1674
1675        for (device_id, user_id) in known_devices {
1676            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
1677        }
1678
1679        let known_sender_key_to_session_count = vec![
1680            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
1681            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
1682            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
1683            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
1684            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
1685            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
1686            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
1687            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
1688            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
1689            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
1690            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
1691            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
1692            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
1693            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
1694            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
1695        ];
1696
1697        for (id, count) in known_sender_key_to_session_count {
1698            let olm_sessions =
1699                database.get_sessions(id).await.expect("Should have some olm sessions");
1700
1701            println!("### Session id: {id:?}");
1702            assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
1703        }
1704
1705        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
1706        assert_eq!(inbound_group_sessions.len(), 15);
1707        let known_inbound_group_sessions = vec![
1708            (
1709                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
1710                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1711            ),
1712            (
1713                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
1714                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1715            ),
1716            (
1717                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1718                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1719            ),
1720            (
1721                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
1722                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1723            ),
1724            (
1725                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
1726                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1727            ),
1728            (
1729                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
1730                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1731            ),
1732            (
1733                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
1734                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1735            ),
1736            (
1737                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
1738                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1739            ),
1740            (
1741                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
1742                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1743            ),
1744            (
1745                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1746                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1747            ),
1748            (
1749                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
1750                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1751            ),
1752            (
1753                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
1754                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1755            ),
1756            (
1757                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
1758                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1759            ),
1760            (
1761                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1762                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1763            ),
1764            (
1765                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
1766                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1767            ),
1768        ];
1769
1770        // ensure we can load them all
1771        for (session_id, room_id) in &known_inbound_group_sessions {
1772            database
1773                .get_inbound_group_session(room_id, session_id)
1774                .await
1775                .expect("Should be able to load inbound group session")
1776                .unwrap();
1777        }
1778
1779        let bob_sender_verified = database
1780            .get_inbound_group_session(
1781                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1782                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1783            )
1784            .await
1785            .unwrap()
1786            .unwrap();
1787
1788        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
1789        assert!(bob_sender_verified.backed_up());
1790        assert!(!bob_sender_verified.has_been_imported());
1791
1792        let alice_unknown_device = database
1793            .get_inbound_group_session(
1794                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1795                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1796            )
1797            .await
1798            .unwrap()
1799            .unwrap();
1800
1801        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
1802        assert!(alice_unknown_device.backed_up());
1803        assert!(alice_unknown_device.has_been_imported());
1804
1805        let carl_tofu_session = database
1806            .get_inbound_group_session(
1807                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1808                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1809            )
1810            .await
1811            .unwrap()
1812            .unwrap();
1813
1814        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
1815        assert!(carl_tofu_session.backed_up());
1816        assert!(!carl_tofu_session.has_been_imported());
1817
1818        // Load outbound sessions
1819        database
1820            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
1821            .await
1822            .unwrap()
1823            .unwrap();
1824        database
1825            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
1826            .await
1827            .unwrap()
1828            .unwrap();
1829        database
1830            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
1831            .await
1832            .unwrap()
1833            .unwrap();
1834
1835        let withheld_info = database
1836            .get_withheld_info(
1837                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1838                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
1839            )
1840            .await
1841            .expect("This session should be withheld")
1842            .unwrap();
1843
1844        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
1845
1846        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
1847        assert_eq!(backup_keys.backup_version.unwrap(), "6");
1848        assert!(backup_keys.decryption_key.is_some());
1849    }
1850
1851    async fn get_store(
1852        name: &str,
1853        passphrase: Option<&str>,
1854        clear_data: bool,
1855    ) -> SqliteCryptoStore {
1856        let tmpdir_path = TMP_DIR.path().join(name);
1857
1858        if clear_data {
1859            let _ = fs::remove_dir_all(&tmpdir_path).await;
1860        }
1861
1862        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
1863            .await
1864            .expect("Can't create a passphrase protected store")
1865    }
1866
    // Expand the shared crypto-store integration test suites into this module.
    // NOTE(review): the generated tests presumably obtain their store through
    // the `get_store` helper defined in this module — confirm against the macro
    // definitions.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
1869}
1870
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Temporary directory, created lazily on first access, under which every
    /// store opened by this module lives.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Opens the [`SqliteCryptoStore`] stored under `name` inside [`TMP_DIR`],
    /// always with a passphrase.
    ///
    /// When `passphrase` is `None`, a fixed default test password is used
    /// instead, so the store is never opened without one. When `clear_data` is
    /// `true`, the store directory is removed before opening.
    ///
    /// # Panics
    ///
    /// Panics if the store path is not valid UTF-8 or if the store cannot be
    /// opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        // Best-effort wipe: any error from the removal is deliberately ignored.
        if clear_data {
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        // Substitute the default password so a passphrase is always supplied.
        let effective_passphrase = passphrase.or(Some("default_test_password"));
        let store_path = store_dir.to_str().unwrap();

        SqliteCryptoStore::open(store_path, effective_passphrase)
            .await
            .expect("Can't create a passphrase protected store")
    }

    // Expand the shared crypto-store integration test suites into this module.
    // NOTE(review): the generated tests presumably obtain their store through
    // `get_store` above — confirm against the macro definitions.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}