matrix_sdk_sqlite/
crypto_store.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    borrow::Cow,
17    collections::HashMap,
18    fmt,
19    path::Path,
20    sync::{Arc, RwLock},
21};
22
23use async_trait::async_trait;
24use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
25use matrix_sdk_crypto::{
26    olm::{
27        InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
28        PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
29    },
30    store::{
31        BackupKeys, Changes, CryptoStore, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
32        RoomSettings,
33    },
34    types::events::room_key_withheld::RoomKeyWithheldEvent,
35    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
36};
37use matrix_sdk_store_encryption::StoreCipher;
38use ruma::{
39    events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
40    RoomId, TransactionId, UserId,
41};
42use rusqlite::{named_params, params_from_iter, OptionalExtension};
43use serde::{de::DeserializeOwned, Serialize};
44use tokio::{fs, sync::Mutex};
45use tracing::{debug, instrument, warn};
46use vodozemac::Curve25519PublicKey;
47
48use crate::{
49    error::{Error, Result},
50    utils::{
51        repeat_vars, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
52        SqliteKeyValueStoreConnExt,
53    },
54    OpenStoreError,
55};
56
57/// A sqlite based cryptostore.
58#[derive(Clone)]
59pub struct SqliteCryptoStore {
60    store_cipher: Option<Arc<StoreCipher>>,
61    pool: SqlitePool,
62
63    // DB values cached in memory
64    static_account: Arc<RwLock<Option<StaticAccountData>>>,
65    save_changes_lock: Arc<Mutex<()>>,
66}
67
68#[cfg(not(tarpaulin_include))]
69impl fmt::Debug for SqliteCryptoStore {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
72    }
73}
74
75impl SqliteCryptoStore {
76    /// Open the sqlite-based crypto store at the given path using the given
77    /// passphrase to encrypt private data.
78    pub async fn open(
79        path: impl AsRef<Path>,
80        passphrase: Option<&str>,
81    ) -> Result<Self, OpenStoreError> {
82        let path = path.as_ref();
83        fs::create_dir_all(path).await.map_err(OpenStoreError::CreateDir)?;
84        let cfg = deadpool_sqlite::Config::new(path.join("matrix-sdk-crypto.sqlite3"));
85        let pool = cfg.create_pool(Runtime::Tokio1)?;
86
87        Self::open_with_pool(pool, passphrase).await
88    }
89
90    /// Create a sqlite-based crypto store using the given sqlite database pool.
91    /// The given passphrase will be used to encrypt private data.
92    pub async fn open_with_pool(
93        pool: SqlitePool,
94        passphrase: Option<&str>,
95    ) -> Result<Self, OpenStoreError> {
96        let conn = pool.get().await?;
97        conn.set_journal_size_limit().await?;
98
99        let version = conn.db_version().await?;
100        run_migrations(&conn, version).await?;
101        conn.optimize().await?;
102
103        let store_cipher = match passphrase {
104            Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
105            None => None,
106        };
107
108        Ok(SqliteCryptoStore {
109            store_cipher,
110            pool,
111            static_account: Arc::new(RwLock::new(None)),
112            save_changes_lock: Default::default(),
113        })
114    }
115
116    fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
117        if let Some(key) = &self.store_cipher {
118            let encrypted = key.encrypt_value_data(value)?;
119            Ok(rmp_serde::to_vec_named(&encrypted)?)
120        } else {
121            Ok(value)
122        }
123    }
124
125    fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
126        if let Some(key) = &self.store_cipher {
127            let encrypted = rmp_serde::from_slice(value)?;
128            let decrypted = key.decrypt_value_data(encrypted)?;
129            Ok(Cow::Owned(decrypted))
130        } else {
131            Ok(Cow::Borrowed(value))
132        }
133    }
134
135    fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
136        let serialized = serde_json::to_vec(value)?;
137        self.encode_value(serialized)
138    }
139
140    fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
141        let decoded = self.decode_value(data)?;
142        Ok(serde_json::from_slice(&decoded)?)
143    }
144
145    fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
146        let serialized = rmp_serde::to_vec_named(value)?;
147        self.encode_value(serialized)
148    }
149
150    fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
151        let decoded = self.decode_value(value)?;
152        Ok(rmp_serde::from_slice(&decoded)?)
153    }
154
155    fn deserialize_and_unpickle_inbound_group_session(
156        &self,
157        value: Vec<u8>,
158        backed_up: bool,
159    ) -> Result<InboundGroupSession> {
160        let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
161
162        // The `backed_up` SQL column is the source of truth, because we update it
163        // inside `mark_inbound_group_sessions_as_backed_up` and don't update
164        // the pickled value inside the `data` column (until now, when we are puling it
165        // out of the DB).
166        pickle.backed_up = backed_up;
167
168        Ok(InboundGroupSession::from_pickle(pickle)?)
169    }
170
171    fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
172        let mut request: GossipRequest = self.deserialize_value(value)?;
173        // sent_out SQL column is source of truth, sent_out field in serialized value
174        // needed for other stores though
175        request.sent_out = sent_out;
176        Ok(request)
177    }
178
179    fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
180        let bytes = key.as_ref();
181        if let Some(store_cipher) = &self.store_cipher {
182            Key::Hashed(store_cipher.hash_key(table_name, bytes))
183        } else {
184            Key::Plain(bytes.to_owned())
185        }
186    }
187
188    fn get_static_account(&self) -> Option<StaticAccountData> {
189        self.static_account.read().unwrap().clone()
190    }
191
192    async fn acquire(&self) -> Result<SqliteAsyncConn> {
193        Ok(self.pool.get().await?)
194    }
195}
196
/// Key under which the dehydrated device pickle key is kept in the key/value
/// table.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";

/// Schema version this code expects; `run_migrations` upgrades older
/// databases to it.
const DATABASE_VERSION: u8 = 9;
201
202/// Run migrations for the given version of the database.
203async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
204    if version == 0 {
205        debug!("Creating database");
206    } else if version < DATABASE_VERSION {
207        debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
208    } else {
209        return Ok(());
210    }
211
212    if version < 1 {
213        // First turn on WAL mode, this can't be done in the transaction, it fails with
214        // the error message: "cannot change into wal mode from within a transaction".
215        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
216        conn.with_transaction(|txn| {
217            txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
218            txn.set_db_version(1)
219        })
220        .await?;
221    }
222
223    if version < 2 {
224        conn.with_transaction(|txn| {
225            txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
226            txn.set_db_version(2)
227        })
228        .await?;
229    }
230
231    if version < 3 {
232        conn.with_transaction(|txn| {
233            txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
234            txn.set_db_version(3)
235        })
236        .await?;
237    }
238
239    if version < 4 {
240        conn.with_transaction(|txn| {
241            txn.execute_batch(include_str!(
242                "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
243            ))?;
244            txn.set_db_version(4)
245        })
246        .await?;
247    }
248
249    if version < 5 {
250        conn.with_transaction(|txn| {
251            txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
252            txn.set_db_version(5)
253        })
254        .await?;
255    }
256
257    if version < 6 {
258        conn.with_transaction(|txn| {
259            txn.execute_batch(include_str!(
260                "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
261            ))?;
262            txn.set_db_version(6)
263        })
264        .await?;
265    }
266
267    if version < 7 {
268        conn.with_transaction(|txn| {
269            txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
270            txn.set_db_version(7)
271        })
272        .await?;
273    }
274
275    if version < 8 {
276        conn.with_transaction(|txn| {
277            txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
278            txn.set_db_version(8)
279        })
280        .await?;
281    }
282
283    if version < 9 {
284        conn.with_transaction(|txn| {
285            txn.execute_batch(include_str!(
286                "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
287            ))?;
288            txn.set_db_version(9)
289        })
290        .await?;
291    }
292
293    Ok(())
294}
295
296trait SqliteConnectionExt {
297    fn set_session(
298        &self,
299        session_id: &[u8],
300        sender_key: &[u8],
301        data: &[u8],
302    ) -> rusqlite::Result<()>;
303
304    fn set_inbound_group_session(
305        &self,
306        room_id: &[u8],
307        session_id: &[u8],
308        data: &[u8],
309        backed_up: bool,
310        sender_key: Option<&[u8]>,
311        sender_data_type: Option<u8>,
312    ) -> rusqlite::Result<()>;
313
314    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
315
316    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
317    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
318
319    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
320
321    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
322
323    fn set_key_request(
324        &self,
325        request_id: &[u8],
326        sent_out: bool,
327        data: &[u8],
328    ) -> rusqlite::Result<()>;
329
330    fn set_direct_withheld(
331        &self,
332        session_id: &[u8],
333        room_id: &[u8],
334        data: &[u8],
335    ) -> rusqlite::Result<()>;
336
337    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
338
339    fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
340}
341
342impl SqliteConnectionExt for rusqlite::Connection {
343    fn set_session(
344        &self,
345        session_id: &[u8],
346        sender_key: &[u8],
347        data: &[u8],
348    ) -> rusqlite::Result<()> {
349        self.execute(
350            "INSERT INTO session (session_id, sender_key, data)
351             VALUES (?1, ?2, ?3)
352             ON CONFLICT (session_id) DO UPDATE SET data = ?3",
353            (session_id, sender_key, data),
354        )?;
355        Ok(())
356    }
357
358    fn set_inbound_group_session(
359        &self,
360        room_id: &[u8],
361        session_id: &[u8],
362        data: &[u8],
363        backed_up: bool,
364        sender_key: Option<&[u8]>,
365        sender_data_type: Option<u8>,
366    ) -> rusqlite::Result<()> {
367        self.execute(
368            "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
369             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
370             ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
371            (session_id, room_id, data, backed_up, sender_key, sender_data_type),
372        )?;
373        Ok(())
374    }
375
376    fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
377        self.execute(
378            "INSERT INTO outbound_group_session (room_id, data) \
379             VALUES (?1, ?2)
380             ON CONFLICT (room_id) DO UPDATE SET data = ?2",
381            (room_id, data),
382        )?;
383        Ok(())
384    }
385
386    fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
387        self.execute(
388            "INSERT INTO device (user_id, device_id, data) \
389             VALUES (?1, ?2, ?3)
390             ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
391            (user_id, device_id, data),
392        )?;
393        Ok(())
394    }
395
396    fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
397        self.execute(
398            "DELETE FROM device WHERE user_id = ? AND device_id = ?",
399            (user_id, device_id),
400        )?;
401        Ok(())
402    }
403
404    fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
405        self.execute(
406            "INSERT INTO identity (user_id, data) \
407             VALUES (?1, ?2)
408             ON CONFLICT (user_id) DO UPDATE SET data = ?2",
409            (user_id, data),
410        )?;
411        Ok(())
412    }
413
414    fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
415        self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
416        Ok(())
417    }
418
419    fn set_key_request(
420        &self,
421        request_id: &[u8],
422        sent_out: bool,
423        data: &[u8],
424    ) -> rusqlite::Result<()> {
425        self.execute(
426            "INSERT INTO key_requests (request_id, sent_out, data)
427            VALUES (?1, ?2, ?3)
428            ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
429            (request_id, sent_out, data),
430        )?;
431        Ok(())
432    }
433
434    fn set_direct_withheld(
435        &self,
436        session_id: &[u8],
437        room_id: &[u8],
438        data: &[u8],
439    ) -> rusqlite::Result<()> {
440        self.execute(
441            "INSERT INTO direct_withheld_info (session_id, room_id, data)
442            VALUES (?1, ?2, ?3)
443            ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
444            (session_id, room_id, data),
445        )?;
446        Ok(())
447    }
448
449    fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
450        self.execute(
451            "INSERT INTO room_settings (room_id, data)
452            VALUES (?1, ?2)
453            ON CONFLICT (room_id) DO UPDATE SET data = ?2",
454            (room_id, data),
455        )?;
456        Ok(())
457    }
458
459    fn set_secret(&self, secret_name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
460        self.execute(
461            "INSERT INTO secrets (secret_name, data)
462            VALUES (?1, ?2)",
463            (secret_name, data),
464        )?;
465
466        Ok(())
467    }
468}
469
470#[async_trait]
471trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
472    async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
473        Ok(self
474            .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
475                stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
476            })
477            .await?)
478    }
479
480    async fn get_inbound_group_session(
481        &self,
482        session_id: Key,
483    ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
484        Ok(self
485            .query_row(
486                "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
487                (session_id,),
488                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
489            )
490            .await
491            .optional()?)
492    }
493
494    async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
495        Ok(self
496            .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
497                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
498            })
499            .await?)
500    }
501
502    async fn get_inbound_group_session_counts(
503        &self,
504        _backup_version: Option<&str>,
505    ) -> Result<RoomKeyCounts> {
506        let total = self
507            .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
508            .await?;
509        let backed_up = self
510            .query_row(
511                "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
512                (),
513                |row| row.get(0),
514            )
515            .await?;
516        Ok(RoomKeyCounts { total, backed_up })
517    }
518
519    async fn get_inbound_group_sessions_for_device_batch(
520        &self,
521        sender_key: Key,
522        sender_data_type: SenderDataType,
523        after_session_id: Option<Key>,
524        limit: usize,
525    ) -> Result<Vec<(Vec<u8>, bool)>> {
526        Ok(self
527            .prepare(
528                "
529                SELECT data, backed_up
530                FROM inbound_group_session
531                WHERE sender_key = :sender_key
532                    AND sender_data_type = :sender_data_type
533                    AND session_id > :after_session_id
534                ORDER BY session_id
535                LIMIT :limit
536                ",
537                move |mut stmt| {
538                    let sender_data_type = sender_data_type as u8;
539
540                    // If we are not provided with an `after_session_id`, use a key which will sort
541                    // before all real keys: the empty string.
542                    let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
543
544                    stmt.query(named_params! {
545                        ":sender_key": sender_key,
546                        ":sender_data_type": sender_data_type,
547                        ":after_session_id": after_session_id,
548                        ":limit": limit,
549                    })?
550                    .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
551                    .collect()
552                },
553            )
554            .await?)
555    }
556
557    async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
558        Ok(self
559            .prepare(
560                "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
561                move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
562            )
563            .await?)
564    }
565
566    async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
567        if session_ids.is_empty() {
568            // We are not expecting to be called with an empty list of sessions
569            warn!("No sessions to mark as backed up!");
570            return Ok(());
571        }
572
573        let session_ids_len = session_ids.len();
574
575        self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
576            // Safety: placeholders is not generated using any user input except the number
577            // of session IDs, so it is safe from injection.
578            let sql_params = repeat_vars(session_ids_len);
579            let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
580            txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
581            Ok(Vec::<()>::new())
582        }).await?;
583
584        Ok(())
585    }
586
587    async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
588        self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
589        Ok(())
590    }
591
592    async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
593        Ok(self
594            .query_row(
595                "SELECT data FROM outbound_group_session WHERE room_id = ?",
596                (room_id,),
597                |row| row.get(0),
598            )
599            .await
600            .optional()?)
601    }
602
603    async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
604        Ok(self
605            .query_row(
606                "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
607                (user_id, device_id),
608                |row| row.get(0),
609            )
610            .await
611            .optional()?)
612    }
613
614    async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
615        Ok(self
616            .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
617                stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
618            })
619            .await?)
620    }
621
622    async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
623        Ok(self
624            .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
625            .await
626            .optional()?)
627    }
628
629    async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
630        Ok(self
631            .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
632                row.get::<_, i32>(0)
633            })
634            .await?
635            > 0)
636    }
637
638    async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
639        Ok(self
640            .prepare("SELECT data FROM tracked_user", |mut stmt| {
641                stmt.query(())?.mapped(|row| row.get(0)).collect()
642            })
643            .await?)
644    }
645
646    async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
647        Ok(self
648            .prepare(
649                "INSERT INTO tracked_user (user_id, data) \
650                 VALUES (?1, ?2) \
651                 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
652                |mut stmt| {
653                    for (user_id, data) in users {
654                        stmt.execute((user_id, data))?;
655                    }
656
657                    Ok(())
658                },
659            )
660            .await?)
661    }
662
663    async fn get_outgoing_secret_request(
664        &self,
665        request_id: Key,
666    ) -> Result<Option<(Vec<u8>, bool)>> {
667        Ok(self
668            .query_row(
669                "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
670                (request_id,),
671                |row| Ok((row.get(0)?, row.get(1)?)),
672            )
673            .await
674            .optional()?)
675    }
676
677    async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
678        Ok(self
679            .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
680                stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
681            })
682            .await?)
683    }
684
685    async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
686        Ok(self
687            .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
688                stmt.query(())?.mapped(|row| row.get(0)).collect()
689            })
690            .await?)
691    }
692
693    async fn delete_key_request(&self, request_id: Key) -> Result<()> {
694        self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
695        Ok(())
696    }
697
698    async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
699        Ok(self
700            .prepare("SELECT data FROM secrets WHERE secret_name = ?", |mut stmt| {
701                stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
702            })
703            .await?)
704    }
705
706    async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
707        self.execute("DELETE FROM secrets WHERE secret_name = ?", (secret_name,)).await?;
708        Ok(())
709    }
710
711    async fn get_direct_withheld_info(
712        &self,
713        session_id: Key,
714        room_id: Key,
715    ) -> Result<Option<Vec<u8>>> {
716        Ok(self
717            .query_row(
718                "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
719                (session_id, room_id),
720                |row| row.get(0),
721            )
722            .await
723            .optional()?)
724    }
725
726    async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
727        Ok(self
728            .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
729                row.get(0)
730            })
731            .await
732            .optional()?)
733    }
734}
735
736#[async_trait]
737impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
738
739#[async_trait]
740impl CryptoStore for SqliteCryptoStore {
741    type Error = Error;
742
743    async fn load_account(&self) -> Result<Option<Account>> {
744        let conn = self.acquire().await?;
745        if let Some(pickle) = conn.get_kv("account").await? {
746            let pickle = self.deserialize_value(&pickle)?;
747
748            let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
749
750            *self.static_account.write().unwrap() = Some(account.static_data().clone());
751
752            Ok(Some(account))
753        } else {
754            Ok(None)
755        }
756    }
757
758    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
759        let conn = self.acquire().await?;
760        if let Some(i) = conn.get_kv("identity").await? {
761            let pickle = self.deserialize_value(&i)?;
762            Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
763        } else {
764            Ok(None)
765        }
766    }
767
768    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
769        // Serialize calls to `save_pending_changes`; there are multiple await points
770        // below, and we're pickling data as we go, so we don't want to
771        // invalidate data we've previously read and overwrite it in the store.
772        // TODO: #2000 should make this lock go away, or change its shape.
773        let _guard = self.save_changes_lock.lock().await;
774
775        let pickled_account = if let Some(account) = changes.account {
776            *self.static_account.write().unwrap() = Some(account.static_data().clone());
777            Some(account.pickle())
778        } else {
779            None
780        };
781
782        let this = self.clone();
783        self.acquire()
784            .await?
785            .with_transaction(move |txn| {
786                if let Some(pickled_account) = pickled_account {
787                    let serialized_account = this.serialize_value(&pickled_account)?;
788                    txn.set_kv("account", &serialized_account)?;
789                }
790
791                Ok::<_, Error>(())
792            })
793            .await?;
794
795        Ok(())
796    }
797
798    async fn save_changes(&self, changes: Changes) -> Result<()> {
799        // Serialize calls to `save_changes`; there are multiple await points below, and
800        // we're pickling data as we go, so we don't want to invalidate data
801        // we've previously read and overwrite it in the store.
802        // TODO: #2000 should make this lock go away, or change its shape.
803        let _guard = self.save_changes_lock.lock().await;
804
805        let pickled_private_identity =
806            if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
807
808        let mut session_changes = Vec::new();
809
810        for session in changes.sessions {
811            let session_id = self.encode_key("session", session.session_id());
812            let sender_key = self.encode_key("session", session.sender_key().to_base64());
813            let pickle = session.pickle().await;
814            session_changes.push((session_id, sender_key, pickle));
815        }
816
817        let mut inbound_session_changes = Vec::new();
818        for session in changes.inbound_group_sessions {
819            let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
820            let session_id = self.encode_key("inbound_group_session", session.session_id());
821            let pickle = session.pickle().await;
822            let sender_key =
823                self.encode_key("inbound_group_session", session.sender_key().to_base64());
824            inbound_session_changes.push((room_id, session_id, pickle, sender_key));
825        }
826
827        let mut outbound_session_changes = Vec::new();
828        for session in changes.outbound_group_sessions {
829            let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
830            let pickle = session.pickle().await;
831            outbound_session_changes.push((room_id, pickle));
832        }
833
834        let this = self.clone();
835        self.acquire()
836            .await?
837            .with_transaction(move |txn| {
838                if let Some(pickled_private_identity) = &pickled_private_identity {
839                    let serialized_private_identity =
840                        this.serialize_value(pickled_private_identity)?;
841                    txn.set_kv("identity", &serialized_private_identity)?;
842                }
843
844                if let Some(token) = &changes.next_batch_token {
845                    let serialized_token = this.serialize_value(token)?;
846                    txn.set_kv("next_batch_token", &serialized_token)?;
847                }
848
849                if let Some(decryption_key) = &changes.backup_decryption_key {
850                    let serialized_decryption_key = this.serialize_value(decryption_key)?;
851                    txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
852                }
853
854                if let Some(backup_version) = &changes.backup_version {
855                    let serialized_backup_version = this.serialize_value(backup_version)?;
856                    txn.set_kv("backup_version_v1", &serialized_backup_version)?;
857                }
858
859                if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
860                    let serialized_pickle_key = this.serialize_value(pickle_key)?;
861                    txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
862                }
863
864                for device in changes.devices.new.iter().chain(&changes.devices.changed) {
865                    let user_id = this.encode_key("device", device.user_id().as_bytes());
866                    let device_id = this.encode_key("device", device.device_id().as_bytes());
867                    let data = this.serialize_value(&device)?;
868                    txn.set_device(&user_id, &device_id, &data)?;
869                }
870
871                for device in &changes.devices.deleted {
872                    let user_id = this.encode_key("device", device.user_id().as_bytes());
873                    let device_id = this.encode_key("device", device.device_id().as_bytes());
874                    txn.delete_device(&user_id, &device_id)?;
875                }
876
877                for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
878                    let user_id = this.encode_key("identity", identity.user_id().as_bytes());
879                    let data = this.serialize_value(&identity)?;
880                    txn.set_identity(&user_id, &data)?;
881                }
882
883                for (session_id, sender_key, pickle) in &session_changes {
884                    let serialized_session = this.serialize_value(&pickle)?;
885                    txn.set_session(session_id, sender_key, &serialized_session)?;
886                }
887
888                for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
889                    let serialized_session = this.serialize_value(&pickle)?;
890                    txn.set_inbound_group_session(
891                        room_id,
892                        session_id,
893                        &serialized_session,
894                        pickle.backed_up,
895                        Some(sender_key),
896                        Some(pickle.sender_data.to_type() as u8),
897                    )?;
898                }
899
900                for (room_id, pickle) in &outbound_session_changes {
901                    let serialized_session = this.serialize_json(&pickle)?;
902                    txn.set_outbound_group_session(room_id, &serialized_session)?;
903                }
904
905                for hash in &changes.message_hashes {
906                    let hash = rmp_serde::to_vec(hash)?;
907                    txn.add_olm_hash(&hash)?;
908                }
909
910                for request in changes.key_requests {
911                    let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
912                    let serialized_request = this.serialize_value(&request)?;
913                    txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
914                }
915
916                for (room_id, data) in changes.withheld_session_info {
917                    for (session_id, event) in data {
918                        let session_id = this.encode_key("direct_withheld_info", session_id);
919                        let room_id = this.encode_key("direct_withheld_info", &room_id);
920                        let serialized_info = this.serialize_json(&event)?;
921                        txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
922                    }
923                }
924
925                for (room_id, settings) in changes.room_settings {
926                    let room_id = this.encode_key("room_settings", room_id.as_bytes());
927                    let value = this.serialize_value(&settings)?;
928                    txn.set_room_settings(&room_id, &value)?;
929                }
930
931                for secret in changes.secrets {
932                    let secret_name = this.encode_key("secrets", secret.secret_name.to_string());
933                    let value = this.serialize_json(&secret)?;
934                    txn.set_secret(&secret_name, &value)?;
935                }
936
937                Ok::<_, Error>(())
938            })
939            .await?;
940
941        Ok(())
942    }
943
944    async fn save_inbound_group_sessions(
945        &self,
946        sessions: Vec<InboundGroupSession>,
947        backed_up_to_version: Option<&str>,
948    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
949        // Sanity-check that the data in the sessions corresponds to backed_up_version
950        sessions.iter().for_each(|s| {
951            let backed_up = s.backed_up();
952            if backed_up != backed_up_to_version.is_some() {
953                warn!(
954                    backed_up,
955                    backed_up_to_version,
956                    "Session backed-up flag does not correspond to backup version setting",
957                );
958            }
959        });
960
961        // Currently, this store doesn't save the backup version separately, so this
962        // just delegates to save_changes.
963        self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
964    }
965
966    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
967        let device_keys = self.get_own_device().await?.as_device_keys().clone();
968
969        let sessions: Vec<_> = self
970            .acquire()
971            .await?
972            .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
973            .await?
974            .into_iter()
975            .map(|bytes| {
976                let pickle = self.deserialize_value(&bytes)?;
977                Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
978            })
979            .collect::<Result<_>>()?;
980
981        if sessions.is_empty() {
982            Ok(None)
983        } else {
984            Ok(Some(sessions))
985        }
986    }
987
988    #[instrument(skip(self))]
989    async fn get_inbound_group_session(
990        &self,
991        room_id: &RoomId,
992        session_id: &str,
993    ) -> Result<Option<InboundGroupSession>> {
994        let session_id = self.encode_key("inbound_group_session", session_id);
995        let Some((room_id_from_db, value, backed_up)) =
996            self.acquire().await?.get_inbound_group_session(session_id).await?
997        else {
998            return Ok(None);
999        };
1000
1001        let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1002        if *room_id != room_id_from_db {
1003            warn!("expected room_id for session_id doesn't match what's in the DB");
1004            return Ok(None);
1005        }
1006
1007        Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1008    }
1009
1010    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1011        self.acquire()
1012            .await?
1013            .get_inbound_group_sessions()
1014            .await?
1015            .into_iter()
1016            .map(|(value, backed_up)| {
1017                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1018            })
1019            .collect()
1020    }
1021
1022    async fn get_inbound_group_sessions_for_device_batch(
1023        &self,
1024        sender_key: Curve25519PublicKey,
1025        sender_data_type: SenderDataType,
1026        after_session_id: Option<String>,
1027        limit: usize,
1028    ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1029        let after_session_id =
1030            after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1031        let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1032
1033        self.acquire()
1034            .await?
1035            .get_inbound_group_sessions_for_device_batch(
1036                sender_key,
1037                sender_data_type,
1038                after_session_id,
1039                limit,
1040            )
1041            .await?
1042            .into_iter()
1043            .map(|(value, backed_up)| {
1044                self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1045            })
1046            .collect()
1047    }
1048
1049    async fn inbound_group_session_counts(
1050        &self,
1051        backup_version: Option<&str>,
1052    ) -> Result<RoomKeyCounts> {
1053        Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
1054    }
1055
1056    async fn inbound_group_sessions_for_backup(
1057        &self,
1058        _backup_version: &str,
1059        limit: usize,
1060    ) -> Result<Vec<InboundGroupSession>> {
1061        self.acquire()
1062            .await?
1063            .get_inbound_group_sessions_for_backup(limit)
1064            .await?
1065            .into_iter()
1066            .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1067            .collect()
1068    }
1069
1070    async fn mark_inbound_group_sessions_as_backed_up(
1071        &self,
1072        _backup_version: &str,
1073        session_ids: &[(&RoomId, &str)],
1074    ) -> Result<()> {
1075        Ok(self
1076            .acquire()
1077            .await?
1078            .mark_inbound_group_sessions_as_backed_up(
1079                session_ids
1080                    .iter()
1081                    .map(|(_, s)| self.encode_key("inbound_group_session", s))
1082                    .collect(),
1083            )
1084            .await?)
1085    }
1086
1087    async fn reset_backup_state(&self) -> Result<()> {
1088        Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
1089    }
1090
1091    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1092        let conn = self.acquire().await?;
1093
1094        let backup_version = conn
1095            .get_kv("backup_version_v1")
1096            .await?
1097            .map(|value| self.deserialize_value(&value))
1098            .transpose()?;
1099
1100        let decryption_key = conn
1101            .get_kv("recovery_key_v1")
1102            .await?
1103            .map(|value| self.deserialize_value(&value))
1104            .transpose()?;
1105
1106        Ok(BackupKeys { backup_version, decryption_key })
1107    }
1108
1109    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1110        let conn = self.acquire().await?;
1111
1112        conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1113            .await?
1114            .map(|value| self.deserialize_value(&value))
1115            .transpose()
1116    }
1117
1118    async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1119        let conn = self.acquire().await?;
1120        conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
1121
1122        Ok(())
1123    }
1124    async fn get_outbound_group_session(
1125        &self,
1126        room_id: &RoomId,
1127    ) -> Result<Option<OutboundGroupSession>> {
1128        let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1129        let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
1130            return Ok(None);
1131        };
1132
1133        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1134
1135        let pickle = self.deserialize_json(&value)?;
1136        let session = OutboundGroupSession::from_pickle(
1137            account_info.device_id,
1138            account_info.identity_keys,
1139            pickle,
1140        )
1141        .map_err(|_| Error::Unpickle)?;
1142
1143        return Ok(Some(session));
1144    }
1145
1146    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1147        self.acquire()
1148            .await?
1149            .get_tracked_users()
1150            .await?
1151            .iter()
1152            .map(|value| self.deserialize_value(value))
1153            .collect()
1154    }
1155
1156    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1157        let users: Vec<(Key, Vec<u8>)> = tracked_users
1158            .iter()
1159            .map(|(u, d)| {
1160                let user_id = self.encode_key("tracked_users", u.as_bytes());
1161                let data =
1162                    self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1163                Ok((user_id, data))
1164            })
1165            .collect::<Result<_>>()?;
1166
1167        Ok(self.acquire().await?.add_tracked_users(users).await?)
1168    }
1169
1170    async fn get_device(
1171        &self,
1172        user_id: &UserId,
1173        device_id: &DeviceId,
1174    ) -> Result<Option<DeviceData>> {
1175        let user_id = self.encode_key("device", user_id.as_bytes());
1176        let device_id = self.encode_key("device", device_id.as_bytes());
1177        Ok(self
1178            .acquire()
1179            .await?
1180            .get_device(user_id, device_id)
1181            .await?
1182            .map(|value| self.deserialize_value(&value))
1183            .transpose()?)
1184    }
1185
1186    async fn get_user_devices(
1187        &self,
1188        user_id: &UserId,
1189    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1190        let user_id = self.encode_key("device", user_id.as_bytes());
1191        self.acquire()
1192            .await?
1193            .get_user_devices(user_id)
1194            .await?
1195            .into_iter()
1196            .map(|value| {
1197                let device: DeviceData = self.deserialize_value(&value)?;
1198                Ok((device.device_id().to_owned(), device))
1199            })
1200            .collect()
1201    }
1202
1203    async fn get_own_device(&self) -> Result<DeviceData> {
1204        let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1205
1206        Ok(self
1207            .get_device(&account_info.user_id, &account_info.device_id)
1208            .await?
1209            .expect("We should be able to find our own device."))
1210    }
1211
1212    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1213        let user_id = self.encode_key("identity", user_id.as_bytes());
1214        Ok(self
1215            .acquire()
1216            .await?
1217            .get_user_identity(user_id)
1218            .await?
1219            .map(|value| self.deserialize_value(&value))
1220            .transpose()?)
1221    }
1222
1223    async fn is_message_known(
1224        &self,
1225        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1226    ) -> Result<bool> {
1227        let value = rmp_serde::to_vec(message_hash)?;
1228        Ok(self.acquire().await?.has_olm_hash(value).await?)
1229    }
1230
1231    async fn get_outgoing_secret_requests(
1232        &self,
1233        request_id: &TransactionId,
1234    ) -> Result<Option<GossipRequest>> {
1235        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1236        Ok(self
1237            .acquire()
1238            .await?
1239            .get_outgoing_secret_request(request_id)
1240            .await?
1241            .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1242            .transpose()?)
1243    }
1244
1245    async fn get_secret_request_by_info(
1246        &self,
1247        key_info: &SecretInfo,
1248    ) -> Result<Option<GossipRequest>> {
1249        let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
1250        for (request, sent_out) in requests {
1251            let request = self.deserialize_key_request(&request, sent_out)?;
1252            if request.info == *key_info {
1253                return Ok(Some(request));
1254            }
1255        }
1256        Ok(None)
1257    }
1258
1259    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1260        self.acquire()
1261            .await?
1262            .get_unsent_secret_requests()
1263            .await?
1264            .iter()
1265            .map(|value| {
1266                let request = self.deserialize_key_request(value, false)?;
1267                Ok(request)
1268            })
1269            .collect()
1270    }
1271
1272    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1273        let request_id = self.encode_key("key_requests", request_id.as_bytes());
1274        Ok(self.acquire().await?.delete_key_request(request_id).await?)
1275    }
1276
1277    async fn get_secrets_from_inbox(
1278        &self,
1279        secret_name: &SecretName,
1280    ) -> Result<Vec<GossippedSecret>> {
1281        let secret_name = self.encode_key("secrets", secret_name.to_string());
1282
1283        self.acquire()
1284            .await?
1285            .get_secrets_from_inbox(secret_name)
1286            .await?
1287            .into_iter()
1288            .map(|value| self.deserialize_json(value.as_ref()))
1289            .collect()
1290    }
1291
1292    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1293        let secret_name = self.encode_key("secrets", secret_name.to_string());
1294        self.acquire().await?.delete_secrets_from_inbox(secret_name).await
1295    }
1296
1297    async fn get_withheld_info(
1298        &self,
1299        room_id: &RoomId,
1300        session_id: &str,
1301    ) -> Result<Option<RoomKeyWithheldEvent>> {
1302        let room_id = self.encode_key("direct_withheld_info", room_id);
1303        let session_id = self.encode_key("direct_withheld_info", session_id);
1304
1305        self.acquire()
1306            .await?
1307            .get_direct_withheld_info(session_id, room_id)
1308            .await?
1309            .map(|value| {
1310                let info = self.deserialize_json::<RoomKeyWithheldEvent>(&value)?;
1311                Ok(info)
1312            })
1313            .transpose()
1314    }
1315
1316    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1317        let room_id = self.encode_key("room_settings", room_id.as_bytes());
1318        let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
1319            return Ok(None);
1320        };
1321
1322        let settings = self.deserialize_value(&value)?;
1323
1324        return Ok(Some(settings));
1325    }
1326
1327    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1328        let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
1329            return Ok(None);
1330        };
1331        let value = if let Some(cipher) = &self.store_cipher {
1332            let encrypted = rmp_serde::from_slice(&serialized)?;
1333            cipher.decrypt_value_data(encrypted)?
1334        } else {
1335            serialized
1336        };
1337
1338        Ok(Some(value))
1339    }
1340
1341    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1342        let serialized = if let Some(cipher) = &self.store_cipher {
1343            let encrypted = cipher.encrypt_value_data(value)?;
1344            rmp_serde::to_vec_named(&encrypted)?
1345        } else {
1346            value
1347        };
1348
1349        self.acquire().await?.set_kv(key, serialized).await?;
1350        Ok(())
1351    }
1352
1353    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1354        let key = key.to_owned();
1355        self.acquire()
1356            .await?
1357            .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1358            .await
1359            .unwrap()?;
1360        Ok(())
1361    }
1362
1363    async fn try_take_leased_lock(
1364        &self,
1365        lease_duration_ms: u32,
1366        key: &str,
1367        holder: &str,
1368    ) -> Result<bool> {
1369        let key = key.to_owned();
1370        let holder = holder.to_owned();
1371
1372        let now_ts: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
1373        let expiration_ts = now_ts + lease_duration_ms as u64;
1374
1375        let num_touched = self
1376            .acquire()
1377            .await?
1378            .with_transaction(move |txn| {
1379                txn.execute(
1380                    "INSERT INTO lease_locks (key, holder, expiration_ts)
1381                    VALUES (?1, ?2, ?3)
1382                    ON CONFLICT (key)
1383                    DO
1384                        UPDATE SET holder = ?2, expiration_ts = ?3
1385                        WHERE holder = ?2
1386                        OR expiration_ts < ?4
1387                ",
1388                    (key, holder, expiration_ts, now_ts),
1389                )
1390            })
1391            .await?;
1392
1393        Ok(num_touched == 1)
1394    }
1395
1396    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1397        let conn = self.acquire().await?;
1398        if let Some(token) = conn.get_kv("next_batch_token").await? {
1399            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1400            Ok(maybe_token)
1401        } else {
1402            Ok(None)
1403        }
1404    }
1405}
1406
1407#[cfg(test)]
1408mod tests {
1409    use std::path::Path;
1410
1411    use matrix_sdk_common::deserialized_responses::WithheldCode;
1412    use matrix_sdk_crypto::{
1413        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1414        store::CryptoStore,
1415    };
1416    use matrix_sdk_test::async_test;
1417    use once_cell::sync::Lazy;
1418    use ruma::{device_id, room_id, user_id};
1419    use similar_asserts::assert_eq;
1420    use tempfile::{tempdir, TempDir};
1421    use tokio::fs;
1422
1423    use super::SqliteCryptoStore;
1424
1425    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());
1426
1427    struct TestDb {
1428        // Needs to be kept alive because the Drop implementation for TempDir deletes the
1429        // directory.
1430        _dir: TempDir,
1431        database: SqliteCryptoStore,
1432    }
1433
1434    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1435        let db_name = "matrix-sdk-crypto.sqlite3";
1436
1437        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1438        let database_path = manifest_path.join(data_path).join(db_name);
1439
1440        let tmpdir = tempdir().unwrap();
1441        let destination = tmpdir.path().join(db_name);
1442
1443        // Copy the test database to the tempdir so our test runs are idempotent.
1444        std::fs::copy(&database_path, destination).unwrap();
1445
1446        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1447            .await
1448            .expect("Can't open the test store");
1449
1450        TestDb { _dir: tmpdir, database }
1451    }
1452
1453    /// Test that we didn't regress in our storage layer by loading data from a
1454    /// pre-filled database, or in other words use a test vector for this.
1455    #[async_test]
1456    async fn test_open_test_vector_store() {
1457        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;
1458
1459        let account = database
1460            .load_account()
1461            .await
1462            .unwrap()
1463            .expect("The test database is prefilled with data, we should find an account");
1464
1465        let user_id = account.user_id();
1466        let device_id = account.device_id();
1467
1468        assert_eq!(
1469            user_id.as_str(),
1470            "@pjtest:synapse-oidc.element.dev",
1471            "The user ID should match to the one we expect."
1472        );
1473
1474        assert_eq!(
1475            device_id.as_str(),
1476            "v4TqgcuIH6",
1477            "The device ID should match to the one we expect."
1478        );
1479
1480        let device = database
1481            .get_device(user_id, device_id)
1482            .await
1483            .unwrap()
1484            .expect("Our own device should be found in the store.");
1485
1486        assert_eq!(device.device_id(), device_id);
1487        assert_eq!(device.user_id(), user_id);
1488
1489        assert_eq!(
1490            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
1491            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
1492        );
1493
1494        assert_eq!(
1495            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
1496            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
1497        );
1498
1499        let identity = database
1500            .get_user_identity(user_id)
1501            .await
1502            .unwrap()
1503            .expect("The store should contain an identity.");
1504
1505        assert_eq!(identity.user_id(), user_id);
1506
1507        let identity = identity
1508            .own()
1509            .expect("The identity should be of the correct type, it should be our own identity.");
1510
1511        let master_key = identity
1512            .master_key()
1513            .get_first_key()
1514            .expect("Our own identity should have a master key");
1515
1516        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
1517    }
1518
1519    /// Test that we didn't regress in our storage layer by loading data from a
1520    /// pre-filled database, or in other words use a test vector for this.
1521    #[async_test]
1522    async fn test_open_test_vector_encrypted_store() {
1523        let TestDb { _dir: _, database } = get_test_db(
1524            "testing/data/storage/alice",
1525            Some(concat!(
1526                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
1527                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
1528                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
1529                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
1530                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
1531            )),
1532        )
1533        .await;
1534
1535        let account = database
1536            .load_account()
1537            .await
1538            .unwrap()
1539            .expect("The test database is prefilled with data, we should find an account");
1540
1541        let user_id = account.user_id();
1542        let device_id = account.device_id();
1543
1544        assert_eq!(
1545            user_id.as_str(),
1546            "@alice:localhost",
1547            "The user ID should match to the one we expect."
1548        );
1549
1550        assert_eq!(
1551            device_id.as_str(),
1552            "JVVORTHFXY",
1553            "The device ID should match to the one we expect."
1554        );
1555
1556        let tracked_users =
1557            database.load_tracked_users().await.expect("Should be tracking some users");
1558
1559        assert_eq!(tracked_users.len(), 6);
1560
1561        let known_users = vec![
1562            user_id!("@alice:localhost"),
1563            user_id!("@dehydration3:localhost"),
1564            user_id!("@eve:localhost"),
1565            user_id!("@bob:localhost"),
1566            user_id!("@malo:localhost"),
1567            user_id!("@carl:localhost"),
1568        ];
1569
1570        // load the identities
1571        for user_id in known_users {
1572            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
1573        }
1574
1575        let carl_identity =
1576            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
1577
1578        assert_eq!(
1579            carl_identity.master_key().get_first_key().unwrap().to_base64(),
1580            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
1581        );
1582        assert!(!carl_identity.was_previously_verified());
1583
1584        let bob_identity =
1585            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
1586
1587        assert_eq!(
1588            bob_identity.master_key().get_first_key().unwrap().to_base64(),
1589            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
1590        );
1591        // Bob is verified so this flag should be set
1592        assert!(bob_identity.was_previously_verified());
1593
1594        let known_devices = vec![
1595            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
1596            // a dehydrated one
1597            (
1598                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
1599                user_id!("@alice:localhost"),
1600            ),
1601            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
1602            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
1603            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
1604            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
1605            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
1606            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1607            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1608            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1609            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1610            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
1611            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
1612            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
1613        ];
1614
1615        for (device_id, user_id) in known_devices {
1616            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
1617        }
1618
1619        let known_sender_key_to_session_count = vec![
1620            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
1621            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
1622            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
1623            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
1624            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
1625            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
1626            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
1627            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
1628            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
1629            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
1630            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
1631            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
1632            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
1633            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
1634            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
1635        ];
1636
1637        for (id, count) in known_sender_key_to_session_count {
1638            let olm_sessions =
1639                database.get_sessions(id).await.expect("Should have some olm sessions");
1640
1641            println!("### Session id: {:?}", id);
1642            assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
1643        }
1644
1645        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
1646        assert_eq!(inbound_group_sessions.len(), 15);
1647        let known_inbound_group_sessions = vec![
1648            (
1649                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
1650                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1651            ),
1652            (
1653                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
1654                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1655            ),
1656            (
1657                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1658                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1659            ),
1660            (
1661                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
1662                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1663            ),
1664            (
1665                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
1666                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1667            ),
1668            (
1669                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
1670                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1671            ),
1672            (
1673                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
1674                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1675            ),
1676            (
1677                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
1678                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1679            ),
1680            (
1681                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
1682                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1683            ),
1684            (
1685                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1686                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1687            ),
1688            (
1689                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
1690                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1691            ),
1692            (
1693                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
1694                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1695            ),
1696            (
1697                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
1698                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1699            ),
1700            (
1701                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1702                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1703            ),
1704            (
1705                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
1706                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1707            ),
1708        ];
1709
1710        // ensure we can load them all
1711        for (session_id, room_id) in &known_inbound_group_sessions {
1712            database
1713                .get_inbound_group_session(room_id, session_id)
1714                .await
1715                .expect("Should be able to load inbound group session")
1716                .unwrap();
1717        }
1718
1719        let bob_sender_verified = database
1720            .get_inbound_group_session(
1721                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1722                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1723            )
1724            .await
1725            .unwrap()
1726            .unwrap();
1727
1728        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
1729        assert!(bob_sender_verified.backed_up());
1730        assert!(!bob_sender_verified.has_been_imported());
1731
1732        let alice_unknown_device = database
1733            .get_inbound_group_session(
1734                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1735                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1736            )
1737            .await
1738            .unwrap()
1739            .unwrap();
1740
1741        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
1742        assert!(alice_unknown_device.backed_up());
1743        assert!(alice_unknown_device.has_been_imported());
1744
1745        let carl_tofu_session = database
1746            .get_inbound_group_session(
1747                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1748                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1749            )
1750            .await
1751            .unwrap()
1752            .unwrap();
1753
1754        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
1755        assert!(carl_tofu_session.backed_up());
1756        assert!(!carl_tofu_session.has_been_imported());
1757
1758        // Load outbound sessions
1759        database
1760            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
1761            .await
1762            .unwrap()
1763            .unwrap();
1764        database
1765            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
1766            .await
1767            .unwrap()
1768            .unwrap();
1769        database
1770            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
1771            .await
1772            .unwrap()
1773            .unwrap();
1774
1775        let withheld_info = database
1776            .get_withheld_info(
1777                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1778                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
1779            )
1780            .await
1781            .expect("This session should be withheld")
1782            .unwrap();
1783
1784        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
1785
1786        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
1787        assert_eq!(backup_keys.backup_version.unwrap(), "6");
1788        assert!(backup_keys.decryption_key.is_some());
1789    }
1790    async fn get_store(
1791        name: &str,
1792        passphrase: Option<&str>,
1793        clear_data: bool,
1794    ) -> SqliteCryptoStore {
1795        let tmpdir_path = TMP_DIR.path().join(name);
1796
1797        if clear_data {
1798            let _ = fs::remove_dir_all(&tmpdir_path).await;
1799        }
1800
1801        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
1802            .await
1803            .expect("Can't create a passphrase protected store")
1804    }
1805
1806    cryptostore_integration_tests!();
1807    cryptostore_integration_tests_time!();
1808}
1809
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Scratch directory for this module's stores; each store lives in a
    /// sub-directory named by the `name` argument of [`get_store`].
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Passphrase substituted when a caller of [`get_store`] passes `None`, so
    /// every store opened by this module is passphrase protected.
    const DEFAULT_PASSPHRASE: &str = "default_test_password";

    /// Open (or create) a passphrase-protected test store named `name`.
    ///
    /// A missing `passphrase` falls back to [`DEFAULT_PASSPHRASE`]. When
    /// `clear_data` is set, the store's directory is removed beforehand.
    ///
    /// # Panics
    ///
    /// Panics if the store path is not valid UTF-8 or the store cannot be opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        if clear_data {
            // The directory may not exist yet, so the outcome is ignored.
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        let effective_passphrase = Some(passphrase.unwrap_or(DEFAULT_PASSPHRASE));
        let store_path = store_dir.to_str().unwrap();

        SqliteCryptoStore::open(store_path, effective_passphrase)
            .await
            .expect("Can't create a passphrase protected store")
    }

    // Run the shared crypto store integration suites against encrypted stores.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}