1use std::{
16 collections::HashMap,
17 fmt,
18 path::Path,
19 sync::{Arc, RwLock},
20};
21
22use async_trait::async_trait;
23use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
24use matrix_sdk_crypto::{
25 olm::{
26 InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
27 PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
28 },
29 store::{
30 types::{
31 BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
32 StoredRoomKeyBundleData,
33 },
34 CryptoStore,
35 },
36 types::events::room_key_withheld::RoomKeyWithheldEvent,
37 Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
38};
39use matrix_sdk_store_encryption::StoreCipher;
40use ruma::{
41 events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
42 RoomId, TransactionId, UserId,
43};
44use rusqlite::{named_params, params_from_iter, OptionalExtension};
45use tokio::{fs, sync::Mutex};
46use tracing::{debug, instrument, warn};
47use vodozemac::Curve25519PublicKey;
48
49use crate::{
50 error::{Error, Result},
51 utils::{
52 repeat_vars, EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
53 SqliteKeyValueStoreConnExt,
54 },
55 OpenStoreError, SqliteStoreConfig,
56};
57
/// File name of the SQLite database created inside the store directory.
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
60
/// An SQLite-based crypto store.
///
/// Clones share the same cipher, cached account data and save lock, since
/// those fields are wrapped in `Arc`.
#[derive(Clone)]
pub struct SqliteCryptoStore {
    /// Cipher handed out through `EncryptableStore::get_cypher`; only present
    /// when the store was opened with a passphrase.
    store_cipher: Option<Arc<StoreCipher>>,
    /// Pool of connections to the underlying database.
    pool: SqlitePool,

    /// In-memory copy of the static account data, populated by
    /// `load_account` and `save_pending_changes`.
    static_account: Arc<RwLock<Option<StaticAccountData>>>,
    /// Held for the duration of `save_changes` / `save_pending_changes` so
    /// that concurrent saves do not interleave.
    save_changes_lock: Arc<Mutex<()>>,
}
71
72#[cfg(not(tarpaulin_include))]
73impl fmt::Debug for SqliteCryptoStore {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
76 }
77}
78
79impl EncryptableStore for SqliteCryptoStore {
80 fn get_cypher(&self) -> Option<&StoreCipher> {
81 self.store_cipher.as_deref()
82 }
83}
84
85impl SqliteCryptoStore {
86 pub async fn open(
89 path: impl AsRef<Path>,
90 passphrase: Option<&str>,
91 ) -> Result<Self, OpenStoreError> {
92 Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await
93 }
94
95 pub async fn open_with_config(config: SqliteStoreConfig) -> Result<Self, OpenStoreError> {
97 let SqliteStoreConfig { path, passphrase, pool_config, runtime_config } = config;
98
99 fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?;
100
101 let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME));
102 config.pool = Some(pool_config);
103
104 let pool = config.create_pool(Runtime::Tokio1)?;
105
106 let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
107 this.pool.get().await?.apply_runtime_config(runtime_config).await?;
108
109 Ok(this)
110 }
111
112 async fn open_with_pool(
115 pool: SqlitePool,
116 passphrase: Option<&str>,
117 ) -> Result<Self, OpenStoreError> {
118 let conn = pool.get().await?;
119
120 let version = conn.db_version().await?;
121 debug!("Opened sqlite store with version {}", version);
122 run_migrations(&conn, version).await?;
123
124 let store_cipher = match passphrase {
125 Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
126 None => None,
127 };
128
129 Ok(SqliteCryptoStore {
130 store_cipher,
131 pool,
132 static_account: Arc::new(RwLock::new(None)),
133 save_changes_lock: Default::default(),
134 })
135 }
136
137 fn deserialize_and_unpickle_inbound_group_session(
138 &self,
139 value: Vec<u8>,
140 backed_up: bool,
141 ) -> Result<InboundGroupSession> {
142 let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
143
144 pickle.backed_up = backed_up;
149
150 Ok(InboundGroupSession::from_pickle(pickle)?)
151 }
152
153 fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
154 let mut request: GossipRequest = self.deserialize_value(value)?;
155 request.sent_out = sent_out;
158 Ok(request)
159 }
160
161 fn get_static_account(&self) -> Option<StaticAccountData> {
162 self.static_account.read().unwrap().clone()
163 }
164
165 async fn acquire(&self) -> Result<SqliteAsyncConn> {
166 Ok(self.pool.get().await?)
167 }
168}
169
/// Schema version that `run_migrations` brings the database up to; databases
/// reporting a lower version are upgraded when the store is opened.
const DATABASE_VERSION: u8 = 10;

/// Key-value store key under which the dehydrated device pickle key is saved.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
174
175async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
177 if version == 0 {
178 debug!("Creating database");
179 } else if version < DATABASE_VERSION {
180 debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
181 } else {
182 return Ok(());
183 }
184
185 if version < 1 {
186 conn.execute_batch("PRAGMA journal_mode = wal;").await?;
189 conn.with_transaction(|txn| {
190 txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
191 txn.set_db_version(1)
192 })
193 .await?;
194 }
195
196 if version < 2 {
197 conn.with_transaction(|txn| {
198 txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
199 txn.set_db_version(2)
200 })
201 .await?;
202 }
203
204 if version < 3 {
205 conn.with_transaction(|txn| {
206 txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
207 txn.set_db_version(3)
208 })
209 .await?;
210 }
211
212 if version < 4 {
213 conn.with_transaction(|txn| {
214 txn.execute_batch(include_str!(
215 "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
216 ))?;
217 txn.set_db_version(4)
218 })
219 .await?;
220 }
221
222 if version < 5 {
223 conn.with_transaction(|txn| {
224 txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
225 txn.set_db_version(5)
226 })
227 .await?;
228 }
229
230 if version < 6 {
231 conn.with_transaction(|txn| {
232 txn.execute_batch(include_str!(
233 "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
234 ))?;
235 txn.set_db_version(6)
236 })
237 .await?;
238 }
239
240 if version < 7 {
241 conn.with_transaction(|txn| {
242 txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
243 txn.set_db_version(7)
244 })
245 .await?;
246 }
247
248 if version < 8 {
249 conn.with_transaction(|txn| {
250 txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
251 txn.set_db_version(8)
252 })
253 .await?;
254 }
255
256 if version < 9 {
257 conn.with_transaction(|txn| {
258 txn.execute_batch(include_str!(
259 "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
260 ))?;
261 txn.set_db_version(9)
262 })
263 .await?;
264 }
265
266 if version < 10 {
267 conn.with_transaction(|txn| {
268 txn.execute_batch(include_str!(
269 "../migrations/crypto_store/010_received_room_key_bundles.sql"
270 ))?;
271 txn.set_db_version(10)
272 })
273 .await?;
274 }
275
276 Ok(())
277}
278
279trait SqliteConnectionExt {
280 fn set_session(
281 &self,
282 session_id: &[u8],
283 sender_key: &[u8],
284 data: &[u8],
285 ) -> rusqlite::Result<()>;
286
287 fn set_inbound_group_session(
288 &self,
289 room_id: &[u8],
290 session_id: &[u8],
291 data: &[u8],
292 backed_up: bool,
293 sender_key: Option<&[u8]>,
294 sender_data_type: Option<u8>,
295 ) -> rusqlite::Result<()>;
296
297 fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
298
299 fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
300 fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
301
302 fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
303
304 fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
305
306 fn set_key_request(
307 &self,
308 request_id: &[u8],
309 sent_out: bool,
310 data: &[u8],
311 ) -> rusqlite::Result<()>;
312
313 fn set_direct_withheld(
314 &self,
315 session_id: &[u8],
316 room_id: &[u8],
317 data: &[u8],
318 ) -> rusqlite::Result<()>;
319
320 fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
321
322 fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
323
324 fn set_received_room_key_bundle(
325 &self,
326 room_id: &[u8],
327 user_id: &[u8],
328 data: &[u8],
329 ) -> rusqlite::Result<()>;
330}
331
332impl SqliteConnectionExt for rusqlite::Connection {
333 fn set_session(
334 &self,
335 session_id: &[u8],
336 sender_key: &[u8],
337 data: &[u8],
338 ) -> rusqlite::Result<()> {
339 self.execute(
340 "INSERT INTO session (session_id, sender_key, data)
341 VALUES (?1, ?2, ?3)
342 ON CONFLICT (session_id) DO UPDATE SET data = ?3",
343 (session_id, sender_key, data),
344 )?;
345 Ok(())
346 }
347
348 fn set_inbound_group_session(
349 &self,
350 room_id: &[u8],
351 session_id: &[u8],
352 data: &[u8],
353 backed_up: bool,
354 sender_key: Option<&[u8]>,
355 sender_data_type: Option<u8>,
356 ) -> rusqlite::Result<()> {
357 self.execute(
358 "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
359 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
360 ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
361 (session_id, room_id, data, backed_up, sender_key, sender_data_type),
362 )?;
363 Ok(())
364 }
365
366 fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
367 self.execute(
368 "INSERT INTO outbound_group_session (room_id, data) \
369 VALUES (?1, ?2)
370 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
371 (room_id, data),
372 )?;
373 Ok(())
374 }
375
376 fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
377 self.execute(
378 "INSERT INTO device (user_id, device_id, data) \
379 VALUES (?1, ?2, ?3)
380 ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
381 (user_id, device_id, data),
382 )?;
383 Ok(())
384 }
385
386 fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
387 self.execute(
388 "DELETE FROM device WHERE user_id = ? AND device_id = ?",
389 (user_id, device_id),
390 )?;
391 Ok(())
392 }
393
394 fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
395 self.execute(
396 "INSERT INTO identity (user_id, data) \
397 VALUES (?1, ?2)
398 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
399 (user_id, data),
400 )?;
401 Ok(())
402 }
403
404 fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
405 self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
406 Ok(())
407 }
408
409 fn set_key_request(
410 &self,
411 request_id: &[u8],
412 sent_out: bool,
413 data: &[u8],
414 ) -> rusqlite::Result<()> {
415 self.execute(
416 "INSERT INTO key_requests (request_id, sent_out, data)
417 VALUES (?1, ?2, ?3)
418 ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
419 (request_id, sent_out, data),
420 )?;
421 Ok(())
422 }
423
424 fn set_direct_withheld(
425 &self,
426 session_id: &[u8],
427 room_id: &[u8],
428 data: &[u8],
429 ) -> rusqlite::Result<()> {
430 self.execute(
431 "INSERT INTO direct_withheld_info (session_id, room_id, data)
432 VALUES (?1, ?2, ?3)
433 ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
434 (session_id, room_id, data),
435 )?;
436 Ok(())
437 }
438
439 fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
440 self.execute(
441 "INSERT INTO room_settings (room_id, data)
442 VALUES (?1, ?2)
443 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
444 (room_id, data),
445 )?;
446 Ok(())
447 }
448
449 fn set_secret(&self, secret_name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
450 self.execute(
451 "INSERT INTO secrets (secret_name, data)
452 VALUES (?1, ?2)",
453 (secret_name, data),
454 )?;
455
456 Ok(())
457 }
458
459 fn set_received_room_key_bundle(
460 &self,
461 room_id: &[u8],
462 sender_user_id: &[u8],
463 data: &[u8],
464 ) -> rusqlite::Result<()> {
465 self.execute(
466 "INSERT INTO received_room_key_bundle(room_id, sender_user_id, bundle_data)
467 VALUES (?1, ?2, ?3)
468 ON CONFLICT (room_id, sender_user_id) DO UPDATE SET bundle_data = ?3",
469 (room_id, sender_user_id, data),
470 )?;
471 Ok(())
472 }
473}
474
475#[async_trait]
476trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
477 async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
478 Ok(self
479 .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
480 stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
481 })
482 .await?)
483 }
484
485 async fn get_inbound_group_session(
486 &self,
487 session_id: Key,
488 ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
489 Ok(self
490 .query_row(
491 "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
492 (session_id,),
493 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
494 )
495 .await
496 .optional()?)
497 }
498
499 async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
500 Ok(self
501 .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
502 stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
503 })
504 .await?)
505 }
506
507 async fn get_inbound_group_session_counts(
508 &self,
509 _backup_version: Option<&str>,
510 ) -> Result<RoomKeyCounts> {
511 let total = self
512 .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
513 .await?;
514 let backed_up = self
515 .query_row(
516 "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
517 (),
518 |row| row.get(0),
519 )
520 .await?;
521 Ok(RoomKeyCounts { total, backed_up })
522 }
523
524 async fn get_inbound_group_sessions_for_device_batch(
525 &self,
526 sender_key: Key,
527 sender_data_type: SenderDataType,
528 after_session_id: Option<Key>,
529 limit: usize,
530 ) -> Result<Vec<(Vec<u8>, bool)>> {
531 Ok(self
532 .prepare(
533 "
534 SELECT data, backed_up
535 FROM inbound_group_session
536 WHERE sender_key = :sender_key
537 AND sender_data_type = :sender_data_type
538 AND session_id > :after_session_id
539 ORDER BY session_id
540 LIMIT :limit
541 ",
542 move |mut stmt| {
543 let sender_data_type = sender_data_type as u8;
544
545 let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
548
549 stmt.query(named_params! {
550 ":sender_key": sender_key,
551 ":sender_data_type": sender_data_type,
552 ":after_session_id": after_session_id,
553 ":limit": limit,
554 })?
555 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
556 .collect()
557 },
558 )
559 .await?)
560 }
561
562 async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
563 Ok(self
564 .prepare(
565 "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
566 move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
567 )
568 .await?)
569 }
570
571 async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
572 if session_ids.is_empty() {
573 warn!("No sessions to mark as backed up!");
575 return Ok(());
576 }
577
578 let session_ids_len = session_ids.len();
579
580 self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
581 let sql_params = repeat_vars(session_ids_len);
584 let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
585 txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
586 Ok(Vec::<()>::new())
587 }).await?;
588
589 Ok(())
590 }
591
592 async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
593 self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
594 Ok(())
595 }
596
597 async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
598 Ok(self
599 .query_row(
600 "SELECT data FROM outbound_group_session WHERE room_id = ?",
601 (room_id,),
602 |row| row.get(0),
603 )
604 .await
605 .optional()?)
606 }
607
608 async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
609 Ok(self
610 .query_row(
611 "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
612 (user_id, device_id),
613 |row| row.get(0),
614 )
615 .await
616 .optional()?)
617 }
618
619 async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
620 Ok(self
621 .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
622 stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
623 })
624 .await?)
625 }
626
627 async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
628 Ok(self
629 .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
630 .await
631 .optional()?)
632 }
633
634 async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
635 Ok(self
636 .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
637 row.get::<_, i32>(0)
638 })
639 .await?
640 > 0)
641 }
642
643 async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
644 Ok(self
645 .prepare("SELECT data FROM tracked_user", |mut stmt| {
646 stmt.query(())?.mapped(|row| row.get(0)).collect()
647 })
648 .await?)
649 }
650
651 async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
652 Ok(self
653 .prepare(
654 "INSERT INTO tracked_user (user_id, data) \
655 VALUES (?1, ?2) \
656 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
657 |mut stmt| {
658 for (user_id, data) in users {
659 stmt.execute((user_id, data))?;
660 }
661
662 Ok(())
663 },
664 )
665 .await?)
666 }
667
668 async fn get_outgoing_secret_request(
669 &self,
670 request_id: Key,
671 ) -> Result<Option<(Vec<u8>, bool)>> {
672 Ok(self
673 .query_row(
674 "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
675 (request_id,),
676 |row| Ok((row.get(0)?, row.get(1)?)),
677 )
678 .await
679 .optional()?)
680 }
681
682 async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
683 Ok(self
684 .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
685 stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
686 })
687 .await?)
688 }
689
690 async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
691 Ok(self
692 .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
693 stmt.query(())?.mapped(|row| row.get(0)).collect()
694 })
695 .await?)
696 }
697
698 async fn delete_key_request(&self, request_id: Key) -> Result<()> {
699 self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
700 Ok(())
701 }
702
703 async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
704 Ok(self
705 .prepare("SELECT data FROM secrets WHERE secret_name = ?", |mut stmt| {
706 stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
707 })
708 .await?)
709 }
710
711 async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
712 self.execute("DELETE FROM secrets WHERE secret_name = ?", (secret_name,)).await?;
713 Ok(())
714 }
715
716 async fn get_direct_withheld_info(
717 &self,
718 session_id: Key,
719 room_id: Key,
720 ) -> Result<Option<Vec<u8>>> {
721 Ok(self
722 .query_row(
723 "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
724 (session_id, room_id),
725 |row| row.get(0),
726 )
727 .await
728 .optional()?)
729 }
730
731 async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
732 Ok(self
733 .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
734 row.get(0)
735 })
736 .await
737 .optional()?)
738 }
739
740 async fn get_received_room_key_bundle(
741 &self,
742 room_id: Key,
743 sender_user: Key,
744 ) -> Result<Option<Vec<u8>>> {
745 Ok(self
746 .query_row(
747 "SELECT bundle_data FROM received_room_key_bundle WHERE room_id = ? AND sender_user_id = ?",
748 (room_id, sender_user),
749 |row| { row.get(0) },
750 )
751 .await
752 .optional()?)
753 }
754}
755
// Every method of `SqliteObjectCryptoStoreExt` has a default implementation,
// so pooled connections get the crypto-store queries with an empty impl.
#[async_trait]
impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
758
759#[async_trait]
760impl CryptoStore for SqliteCryptoStore {
/// Failures are reported with this crate's [`Error`] type.
type Error = Error;
762
763 async fn load_account(&self) -> Result<Option<Account>> {
764 let conn = self.acquire().await?;
765 if let Some(pickle) = conn.get_kv("account").await? {
766 let pickle = self.deserialize_value(&pickle)?;
767
768 let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
769
770 *self.static_account.write().unwrap() = Some(account.static_data().clone());
771
772 Ok(Some(account))
773 } else {
774 Ok(None)
775 }
776 }
777
778 async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
779 let conn = self.acquire().await?;
780 if let Some(i) = conn.get_kv("identity").await? {
781 let pickle = self.deserialize_value(&i)?;
782 Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
783 } else {
784 Ok(None)
785 }
786 }
787
788 async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
789 let _guard = self.save_changes_lock.lock().await;
794
795 let pickled_account = if let Some(account) = changes.account {
796 *self.static_account.write().unwrap() = Some(account.static_data().clone());
797 Some(account.pickle())
798 } else {
799 None
800 };
801
802 let this = self.clone();
803 self.acquire()
804 .await?
805 .with_transaction(move |txn| {
806 if let Some(pickled_account) = pickled_account {
807 let serialized_account = this.serialize_value(&pickled_account)?;
808 txn.set_kv("account", &serialized_account)?;
809 }
810
811 Ok::<_, Error>(())
812 })
813 .await?;
814
815 Ok(())
816 }
817
818 async fn save_changes(&self, changes: Changes) -> Result<()> {
819 let _guard = self.save_changes_lock.lock().await;
824
825 let pickled_private_identity =
826 if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
827
828 let mut session_changes = Vec::new();
829
830 for session in changes.sessions {
831 let session_id = self.encode_key("session", session.session_id());
832 let sender_key = self.encode_key("session", session.sender_key().to_base64());
833 let pickle = session.pickle().await;
834 session_changes.push((session_id, sender_key, pickle));
835 }
836
837 let mut inbound_session_changes = Vec::new();
838 for session in changes.inbound_group_sessions {
839 let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
840 let session_id = self.encode_key("inbound_group_session", session.session_id());
841 let pickle = session.pickle().await;
842 let sender_key =
843 self.encode_key("inbound_group_session", session.sender_key().to_base64());
844 inbound_session_changes.push((room_id, session_id, pickle, sender_key));
845 }
846
847 let mut outbound_session_changes = Vec::new();
848 for session in changes.outbound_group_sessions {
849 let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
850 let pickle = session.pickle().await;
851 outbound_session_changes.push((room_id, pickle));
852 }
853
854 let this = self.clone();
855 self.acquire()
856 .await?
857 .with_transaction(move |txn| {
858 if let Some(pickled_private_identity) = &pickled_private_identity {
859 let serialized_private_identity =
860 this.serialize_value(pickled_private_identity)?;
861 txn.set_kv("identity", &serialized_private_identity)?;
862 }
863
864 if let Some(token) = &changes.next_batch_token {
865 let serialized_token = this.serialize_value(token)?;
866 txn.set_kv("next_batch_token", &serialized_token)?;
867 }
868
869 if let Some(decryption_key) = &changes.backup_decryption_key {
870 let serialized_decryption_key = this.serialize_value(decryption_key)?;
871 txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
872 }
873
874 if let Some(backup_version) = &changes.backup_version {
875 let serialized_backup_version = this.serialize_value(backup_version)?;
876 txn.set_kv("backup_version_v1", &serialized_backup_version)?;
877 }
878
879 if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
880 let serialized_pickle_key = this.serialize_value(pickle_key)?;
881 txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
882 }
883
884 for device in changes.devices.new.iter().chain(&changes.devices.changed) {
885 let user_id = this.encode_key("device", device.user_id().as_bytes());
886 let device_id = this.encode_key("device", device.device_id().as_bytes());
887 let data = this.serialize_value(&device)?;
888 txn.set_device(&user_id, &device_id, &data)?;
889 }
890
891 for device in &changes.devices.deleted {
892 let user_id = this.encode_key("device", device.user_id().as_bytes());
893 let device_id = this.encode_key("device", device.device_id().as_bytes());
894 txn.delete_device(&user_id, &device_id)?;
895 }
896
897 for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
898 let user_id = this.encode_key("identity", identity.user_id().as_bytes());
899 let data = this.serialize_value(&identity)?;
900 txn.set_identity(&user_id, &data)?;
901 }
902
903 for (session_id, sender_key, pickle) in &session_changes {
904 let serialized_session = this.serialize_value(&pickle)?;
905 txn.set_session(session_id, sender_key, &serialized_session)?;
906 }
907
908 for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
909 let serialized_session = this.serialize_value(&pickle)?;
910 txn.set_inbound_group_session(
911 room_id,
912 session_id,
913 &serialized_session,
914 pickle.backed_up,
915 Some(sender_key),
916 Some(pickle.sender_data.to_type() as u8),
917 )?;
918 }
919
920 for (room_id, pickle) in &outbound_session_changes {
921 let serialized_session = this.serialize_json(&pickle)?;
922 txn.set_outbound_group_session(room_id, &serialized_session)?;
923 }
924
925 for hash in &changes.message_hashes {
926 let hash = rmp_serde::to_vec(hash)?;
927 txn.add_olm_hash(&hash)?;
928 }
929
930 for request in changes.key_requests {
931 let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
932 let serialized_request = this.serialize_value(&request)?;
933 txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
934 }
935
936 for (room_id, data) in changes.withheld_session_info {
937 for (session_id, event) in data {
938 let session_id = this.encode_key("direct_withheld_info", session_id);
939 let room_id = this.encode_key("direct_withheld_info", &room_id);
940 let serialized_info = this.serialize_json(&event)?;
941 txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
942 }
943 }
944
945 for (room_id, settings) in changes.room_settings {
946 let room_id = this.encode_key("room_settings", room_id.as_bytes());
947 let value = this.serialize_value(&settings)?;
948 txn.set_room_settings(&room_id, &value)?;
949 }
950
951 for secret in changes.secrets {
952 let secret_name = this.encode_key("secrets", secret.secret_name.to_string());
953 let value = this.serialize_json(&secret)?;
954 txn.set_secret(&secret_name, &value)?;
955 }
956
957 for bundle in changes.received_room_key_bundles {
958 let room_id =
959 this.encode_key("received_room_key_bundle", &bundle.bundle_data.room_id);
960 let user_id = this.encode_key("received_room_key_bundle", &bundle.sender_user);
961 let value = this.serialize_value(&bundle)?;
962 txn.set_received_room_key_bundle(&room_id, &user_id, &value)?;
963 }
964
965 Ok::<_, Error>(())
966 })
967 .await?;
968
969 Ok(())
970 }
971
972 async fn save_inbound_group_sessions(
973 &self,
974 sessions: Vec<InboundGroupSession>,
975 backed_up_to_version: Option<&str>,
976 ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
977 sessions.iter().for_each(|s| {
979 let backed_up = s.backed_up();
980 if backed_up != backed_up_to_version.is_some() {
981 warn!(
982 backed_up,
983 backed_up_to_version,
984 "Session backed-up flag does not correspond to backup version setting",
985 );
986 }
987 });
988
989 self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
992 }
993
994 async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
995 let device_keys = self.get_own_device().await?.as_device_keys().clone();
996
997 let sessions: Vec<_> = self
998 .acquire()
999 .await?
1000 .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
1001 .await?
1002 .into_iter()
1003 .map(|bytes| {
1004 let pickle = self.deserialize_value(&bytes)?;
1005 Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
1006 })
1007 .collect::<Result<_>>()?;
1008
1009 if sessions.is_empty() {
1010 Ok(None)
1011 } else {
1012 Ok(Some(sessions))
1013 }
1014 }
1015
1016 #[instrument(skip(self))]
1017 async fn get_inbound_group_session(
1018 &self,
1019 room_id: &RoomId,
1020 session_id: &str,
1021 ) -> Result<Option<InboundGroupSession>> {
1022 let session_id = self.encode_key("inbound_group_session", session_id);
1023 let Some((room_id_from_db, value, backed_up)) =
1024 self.acquire().await?.get_inbound_group_session(session_id).await?
1025 else {
1026 return Ok(None);
1027 };
1028
1029 let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1030 if *room_id != room_id_from_db {
1031 warn!("expected room_id for session_id doesn't match what's in the DB");
1032 return Ok(None);
1033 }
1034
1035 Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1036 }
1037
1038 async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1039 self.acquire()
1040 .await?
1041 .get_inbound_group_sessions()
1042 .await?
1043 .into_iter()
1044 .map(|(value, backed_up)| {
1045 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1046 })
1047 .collect()
1048 }
1049
1050 async fn get_inbound_group_sessions_for_device_batch(
1051 &self,
1052 sender_key: Curve25519PublicKey,
1053 sender_data_type: SenderDataType,
1054 after_session_id: Option<String>,
1055 limit: usize,
1056 ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1057 let after_session_id =
1058 after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1059 let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1060
1061 self.acquire()
1062 .await?
1063 .get_inbound_group_sessions_for_device_batch(
1064 sender_key,
1065 sender_data_type,
1066 after_session_id,
1067 limit,
1068 )
1069 .await?
1070 .into_iter()
1071 .map(|(value, backed_up)| {
1072 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1073 })
1074 .collect()
1075 }
1076
1077 async fn inbound_group_session_counts(
1078 &self,
1079 backup_version: Option<&str>,
1080 ) -> Result<RoomKeyCounts> {
1081 Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
1082 }
1083
1084 async fn inbound_group_sessions_for_backup(
1085 &self,
1086 _backup_version: &str,
1087 limit: usize,
1088 ) -> Result<Vec<InboundGroupSession>> {
1089 self.acquire()
1090 .await?
1091 .get_inbound_group_sessions_for_backup(limit)
1092 .await?
1093 .into_iter()
1094 .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1095 .collect()
1096 }
1097
1098 async fn mark_inbound_group_sessions_as_backed_up(
1099 &self,
1100 _backup_version: &str,
1101 session_ids: &[(&RoomId, &str)],
1102 ) -> Result<()> {
1103 Ok(self
1104 .acquire()
1105 .await?
1106 .mark_inbound_group_sessions_as_backed_up(
1107 session_ids
1108 .iter()
1109 .map(|(_, s)| self.encode_key("inbound_group_session", s))
1110 .collect(),
1111 )
1112 .await?)
1113 }
1114
1115 async fn reset_backup_state(&self) -> Result<()> {
1116 Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
1117 }
1118
1119 async fn load_backup_keys(&self) -> Result<BackupKeys> {
1120 let conn = self.acquire().await?;
1121
1122 let backup_version = conn
1123 .get_kv("backup_version_v1")
1124 .await?
1125 .map(|value| self.deserialize_value(&value))
1126 .transpose()?;
1127
1128 let decryption_key = conn
1129 .get_kv("recovery_key_v1")
1130 .await?
1131 .map(|value| self.deserialize_value(&value))
1132 .transpose()?;
1133
1134 Ok(BackupKeys { backup_version, decryption_key })
1135 }
1136
1137 async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1138 let conn = self.acquire().await?;
1139
1140 conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1141 .await?
1142 .map(|value| self.deserialize_value(&value))
1143 .transpose()
1144 }
1145
1146 async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1147 let conn = self.acquire().await?;
1148 conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
1149
1150 Ok(())
1151 }
1152 async fn get_outbound_group_session(
1153 &self,
1154 room_id: &RoomId,
1155 ) -> Result<Option<OutboundGroupSession>> {
1156 let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1157 let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
1158 return Ok(None);
1159 };
1160
1161 let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1162
1163 let pickle = self.deserialize_json(&value)?;
1164 let session = OutboundGroupSession::from_pickle(
1165 account_info.device_id,
1166 account_info.identity_keys,
1167 pickle,
1168 )
1169 .map_err(|_| Error::Unpickle)?;
1170
1171 return Ok(Some(session));
1172 }
1173
1174 async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1175 self.acquire()
1176 .await?
1177 .get_tracked_users()
1178 .await?
1179 .iter()
1180 .map(|value| self.deserialize_value(value))
1181 .collect()
1182 }
1183
1184 async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1185 let users: Vec<(Key, Vec<u8>)> = tracked_users
1186 .iter()
1187 .map(|(u, d)| {
1188 let user_id = self.encode_key("tracked_users", u.as_bytes());
1189 let data =
1190 self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1191 Ok((user_id, data))
1192 })
1193 .collect::<Result<_>>()?;
1194
1195 Ok(self.acquire().await?.add_tracked_users(users).await?)
1196 }
1197
1198 async fn get_device(
1199 &self,
1200 user_id: &UserId,
1201 device_id: &DeviceId,
1202 ) -> Result<Option<DeviceData>> {
1203 let user_id = self.encode_key("device", user_id.as_bytes());
1204 let device_id = self.encode_key("device", device_id.as_bytes());
1205 Ok(self
1206 .acquire()
1207 .await?
1208 .get_device(user_id, device_id)
1209 .await?
1210 .map(|value| self.deserialize_value(&value))
1211 .transpose()?)
1212 }
1213
1214 async fn get_user_devices(
1215 &self,
1216 user_id: &UserId,
1217 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1218 let user_id = self.encode_key("device", user_id.as_bytes());
1219 self.acquire()
1220 .await?
1221 .get_user_devices(user_id)
1222 .await?
1223 .into_iter()
1224 .map(|value| {
1225 let device: DeviceData = self.deserialize_value(&value)?;
1226 Ok((device.device_id().to_owned(), device))
1227 })
1228 .collect()
1229 }
1230
1231 async fn get_own_device(&self) -> Result<DeviceData> {
1232 let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1233
1234 Ok(self
1235 .get_device(&account_info.user_id, &account_info.device_id)
1236 .await?
1237 .expect("We should be able to find our own device."))
1238 }
1239
1240 async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1241 let user_id = self.encode_key("identity", user_id.as_bytes());
1242 Ok(self
1243 .acquire()
1244 .await?
1245 .get_user_identity(user_id)
1246 .await?
1247 .map(|value| self.deserialize_value(&value))
1248 .transpose()?)
1249 }
1250
1251 async fn is_message_known(
1252 &self,
1253 message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1254 ) -> Result<bool> {
1255 let value = rmp_serde::to_vec(message_hash)?;
1256 Ok(self.acquire().await?.has_olm_hash(value).await?)
1257 }
1258
1259 async fn get_outgoing_secret_requests(
1260 &self,
1261 request_id: &TransactionId,
1262 ) -> Result<Option<GossipRequest>> {
1263 let request_id = self.encode_key("key_requests", request_id.as_bytes());
1264 Ok(self
1265 .acquire()
1266 .await?
1267 .get_outgoing_secret_request(request_id)
1268 .await?
1269 .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1270 .transpose()?)
1271 }
1272
1273 async fn get_secret_request_by_info(
1274 &self,
1275 key_info: &SecretInfo,
1276 ) -> Result<Option<GossipRequest>> {
1277 let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
1278 for (request, sent_out) in requests {
1279 let request = self.deserialize_key_request(&request, sent_out)?;
1280 if request.info == *key_info {
1281 return Ok(Some(request));
1282 }
1283 }
1284 Ok(None)
1285 }
1286
1287 async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1288 self.acquire()
1289 .await?
1290 .get_unsent_secret_requests()
1291 .await?
1292 .iter()
1293 .map(|value| {
1294 let request = self.deserialize_key_request(value, false)?;
1295 Ok(request)
1296 })
1297 .collect()
1298 }
1299
1300 async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1301 let request_id = self.encode_key("key_requests", request_id.as_bytes());
1302 Ok(self.acquire().await?.delete_key_request(request_id).await?)
1303 }
1304
1305 async fn get_secrets_from_inbox(
1306 &self,
1307 secret_name: &SecretName,
1308 ) -> Result<Vec<GossippedSecret>> {
1309 let secret_name = self.encode_key("secrets", secret_name.to_string());
1310
1311 self.acquire()
1312 .await?
1313 .get_secrets_from_inbox(secret_name)
1314 .await?
1315 .into_iter()
1316 .map(|value| self.deserialize_json(value.as_ref()))
1317 .collect()
1318 }
1319
1320 async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1321 let secret_name = self.encode_key("secrets", secret_name.to_string());
1322 self.acquire().await?.delete_secrets_from_inbox(secret_name).await
1323 }
1324
1325 async fn get_withheld_info(
1326 &self,
1327 room_id: &RoomId,
1328 session_id: &str,
1329 ) -> Result<Option<RoomKeyWithheldEvent>> {
1330 let room_id = self.encode_key("direct_withheld_info", room_id);
1331 let session_id = self.encode_key("direct_withheld_info", session_id);
1332
1333 self.acquire()
1334 .await?
1335 .get_direct_withheld_info(session_id, room_id)
1336 .await?
1337 .map(|value| {
1338 let info = self.deserialize_json::<RoomKeyWithheldEvent>(&value)?;
1339 Ok(info)
1340 })
1341 .transpose()
1342 }
1343
1344 async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1345 let room_id = self.encode_key("room_settings", room_id.as_bytes());
1346 let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
1347 return Ok(None);
1348 };
1349
1350 let settings = self.deserialize_value(&value)?;
1351
1352 return Ok(Some(settings));
1353 }
1354
1355 async fn get_received_room_key_bundle_data(
1356 &self,
1357 room_id: &RoomId,
1358 user_id: &UserId,
1359 ) -> Result<Option<StoredRoomKeyBundleData>> {
1360 let room_id = self.encode_key("received_room_key_bundle", room_id);
1361 let user_id = self.encode_key("received_room_key_bundle", user_id);
1362 self.acquire()
1363 .await?
1364 .get_received_room_key_bundle(room_id, user_id)
1365 .await?
1366 .map(|value| self.deserialize_value(&value))
1367 .transpose()
1368 }
1369
1370 async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1371 let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
1372 return Ok(None);
1373 };
1374 let value = if let Some(cipher) = &self.store_cipher {
1375 let encrypted = rmp_serde::from_slice(&serialized)?;
1376 cipher.decrypt_value_data(encrypted)?
1377 } else {
1378 serialized
1379 };
1380
1381 Ok(Some(value))
1382 }
1383
1384 async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1385 let serialized = if let Some(cipher) = &self.store_cipher {
1386 let encrypted = cipher.encrypt_value_data(value)?;
1387 rmp_serde::to_vec_named(&encrypted)?
1388 } else {
1389 value
1390 };
1391
1392 self.acquire().await?.set_kv(key, serialized).await?;
1393 Ok(())
1394 }
1395
1396 async fn remove_custom_value(&self, key: &str) -> Result<()> {
1397 let key = key.to_owned();
1398 self.acquire()
1399 .await?
1400 .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1401 .await
1402 .unwrap()?;
1403 Ok(())
1404 }
1405
    /// Try to take, or renew, the lease lock `key` on behalf of `holder`.
    ///
    /// A single upsert into `lease_locks` either inserts a fresh row for `key`,
    /// or updates the existing row. The update happens only when the row is
    /// already held by `holder` or its `expiration_ts` is earlier than now.
    /// On success the new expiry is `now + lease_duration_ms`, in milliseconds
    /// since the Unix epoch.
    ///
    /// Returns `true` if exactly one row was inserted or updated, meaning the
    /// lock is now held by `holder`. Returns `false` if another holder's
    /// unexpired lease left the row untouched.
    async fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<bool> {
        // Owned copies, because the `move` closure below must own its captures.
        let key = key.to_owned();
        let holder = holder.to_owned();

        let now_ts: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
        let expiration_ts = now_ts + lease_duration_ms as u64;

        let num_touched = self
            .acquire()
            .await?
            .with_transaction(move |txn| {
                // ?1 = key, ?2 = holder, ?3 = new expiry, ?4 = now. The
                // ON CONFLICT update is skipped (0 rows) when the WHERE clause
                // does not match.
                txn.execute(
                    "INSERT INTO lease_locks (key, holder, expiration_ts)
                    VALUES (?1, ?2, ?3)
                    ON CONFLICT (key)
                    DO
                        UPDATE SET holder = ?2, expiration_ts = ?3
                        WHERE holder = ?2
                        OR expiration_ts < ?4
                ",
                    (key, holder, expiration_ts, now_ts),
                )
            })
            .await?;

        Ok(num_touched == 1)
    }
1438
1439 async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1440 let conn = self.acquire().await?;
1441 if let Some(token) = conn.get_kv("next_batch_token").await? {
1442 let maybe_token: Option<String> = self.deserialize_value(&token)?;
1443 Ok(maybe_token)
1444 } else {
1445 Ok(None)
1446 }
1447 }
1448}
1449
#[cfg(test)]
mod tests {
    use std::path::Path;

    use matrix_sdk_common::deserialized_responses::WithheldCode;
    use matrix_sdk_crypto::{
        cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
        store::CryptoStore,
    };
    use matrix_sdk_test::async_test;
    use once_cell::sync::Lazy;
    use ruma::{device_id, room_id, user_id};
    use similar_asserts::assert_eq;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;
    use crate::SqliteStoreConfig;

    /// Shared scratch directory for stores created by the tests in this module.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// A store opened on a private, temporary copy of a checked-in database.
    struct TestDb {
        // Held only to keep the temporary directory (and the copied database
        // inside it) alive for as long as `database` is used.
        _dir: TempDir,
        database: SqliteCryptoStore,
    }

    /// Copy the crypto database from `data_path` into a new temporary directory.
    ///
    /// `data_path` is resolved two levels above this crate's manifest directory
    /// (presumably the workspace root).
    fn copy_db(data_path: &str) -> TempDir {
        let db_name = super::DATABASE_NAME;

        let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
        let database_path = manifest_path.join(data_path).join(db_name);

        let tmpdir = tempdir().unwrap();
        let destination = tmpdir.path().join(db_name);

        // Work on a copy so the checked-in test vector is never modified.
        std::fs::copy(&database_path, destination).unwrap();

        tmpdir
    }

    /// Open a store on a temporary copy of the database at `data_path`.
    async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
        let tmpdir = copy_db(data_path);

        let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
            .await
            .expect("Can't open the test store");

        TestDb { _dir: tmpdir, database }
    }

    #[async_test]
    async fn test_pool_size() {
        let store_open_config =
            SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);

        let store = SqliteCryptoStore::open_with_config(store_open_config).await.unwrap();

        assert_eq!(store.pool.status().max_size, 42);
    }

    /// Open the unencrypted test-vector database and check the account, own
    /// device, and own identity it contains.
    #[async_test]
    async fn test_open_test_vector_store() {
        let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;

        let account = database
            .load_account()
            .await
            .unwrap()
            .expect("The test database is prefilled with data, we should find an account");

        let user_id = account.user_id();
        let device_id = account.device_id();

        assert_eq!(
            user_id.as_str(),
            "@pjtest:synapse-oidc.element.dev",
            "The user ID should match to the one we expect."
        );

        assert_eq!(
            device_id.as_str(),
            "v4TqgcuIH6",
            "The device ID should match to the one we expect."
        );

        let device = database
            .get_device(user_id, device_id)
            .await
            .unwrap()
            .expect("Our own device should be found in the store.");

        assert_eq!(device.device_id(), device_id);
        assert_eq!(device.user_id(), user_id);

        assert_eq!(
            device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
            "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
        );

        assert_eq!(
            device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
            "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
        );

        let identity = database
            .get_user_identity(user_id)
            .await
            .unwrap()
            .expect("The store should contain an identity.");

        assert_eq!(identity.user_id(), user_id);

        let identity = identity
            .own()
            .expect("The identity should be of the correct type, it should be our own identity.");

        let master_key = identity
            .master_key()
            .get_first_key()
            .expect("Our own identity should have a master key");

        assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
    }

    /// Open the passphrase-protected test-vector database and check that the
    /// account, tracked users, identities, devices, Olm and Megolm sessions,
    /// withheld info, and backup keys it contains can all be loaded.
    #[async_test]
    async fn test_open_test_vector_encrypted_store() {
        let TestDb { _dir: _, database } = get_test_db(
            "testing/data/storage/alice",
            Some(concat!(
                "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
                "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
                "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
                "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
                "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
            )),
        )
        .await;

        let account = database
            .load_account()
            .await
            .unwrap()
            .expect("The test database is prefilled with data, we should find an account");

        let user_id = account.user_id();
        let device_id = account.device_id();

        assert_eq!(
            user_id.as_str(),
            "@alice:localhost",
            "The user ID should match to the one we expect."
        );

        assert_eq!(
            device_id.as_str(),
            "JVVORTHFXY",
            "The device ID should match to the one we expect."
        );

        let tracked_users =
            database.load_tracked_users().await.expect("Should be tracking some users");

        assert_eq!(tracked_users.len(), 6);

        let known_users = [
            user_id!("@alice:localhost"),
            user_id!("@dehydration3:localhost"),
            user_id!("@eve:localhost"),
            user_id!("@bob:localhost"),
            user_id!("@malo:localhost"),
            user_id!("@carl:localhost"),
        ];

        // Every tracked user should have a loadable identity.
        for user_id in known_users {
            database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
        }

        let carl_identity =
            database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();

        assert_eq!(
            carl_identity.master_key().get_first_key().unwrap().to_base64(),
            "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
        );
        assert!(!carl_identity.was_previously_verified());

        let bob_identity =
            database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();

        assert_eq!(
            bob_identity.master_key().get_first_key().unwrap().to_base64(),
            "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
        );
        assert!(bob_identity.was_previously_verified());

        let known_devices = [
            (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
            (
                // NOTE(review): this ID has the shape of a Curve25519 key, and
                // it also appears as a sender key below. Presumably this is a
                // dehydrated device; confirm.
                device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
                user_id!("@alice:localhost"),
            ),
            (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
            (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
            (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
            (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
            (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
            (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
            (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
            (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
            (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
            (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
        ];

        for (device_id, user_id) in known_devices {
            database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
        }

        // Expected number of Olm sessions per sender (Curve25519) key.
        let known_sender_key_to_session_count = [
            ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
            ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
            ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
            ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
            ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
            ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
            ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
            ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
            ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
            ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
            ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
            ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
            ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
            ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
            ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
        ];

        for (id, count) in known_sender_key_to_session_count {
            let olm_sessions =
                database.get_sessions(id).await.expect("Should have some olm sessions");

            // Report the sender key in the failure message instead of printing
            // it unconditionally on every iteration.
            assert_eq!(
                olm_sessions.map_or(0, |v| v.len()),
                count,
                "Unexpected number of Olm sessions for sender key {id}"
            );
        }

        let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
        assert_eq!(inbound_group_sessions.len(), 15);
        let known_inbound_group_sessions = [
            (
                "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
            ),
            (
                "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
            ),
            (
                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
            ),
            (
                "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
            ),
        ];

        // Every known inbound group session should load individually.
        for (session_id, room_id) in &known_inbound_group_sessions {
            database
                .get_inbound_group_session(room_id, session_id)
                .await
                .expect("Should be able to load inbound group session")
                .unwrap();
        }

        let bob_sender_verified = database
            .get_inbound_group_session(
                room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
                "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
            )
            .await
            .unwrap()
            .unwrap();

        assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
        assert!(bob_sender_verified.backed_up());
        assert!(!bob_sender_verified.has_been_imported());

        let alice_unknown_device = database
            .get_inbound_group_session(
                room_id!("!SRstFdydzrGwJYtVfm:localhost"),
                "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
            )
            .await
            .unwrap()
            .unwrap();

        assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
        assert!(alice_unknown_device.backed_up());
        assert!(alice_unknown_device.has_been_imported());

        let carl_tofu_session = database
            .get_inbound_group_session(
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
                "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
            )
            .await
            .unwrap()
            .unwrap();

        assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
        assert!(carl_tofu_session.backed_up());
        assert!(!carl_tofu_session.has_been_imported());

        // Each of the three rooms should have a stored outbound group session.
        database
            .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
            .await
            .unwrap()
            .unwrap();
        database
            .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
            .await
            .unwrap()
            .unwrap();
        database
            .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
            .await
            .unwrap()
            .unwrap();

        let withheld_info = database
            .get_withheld_info(
                room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
                "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
            )
            .await
            .expect("This session should be withheld")
            .unwrap();

        assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);

        let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
        assert_eq!(backup_keys.backup_version.unwrap(), "6");
        assert!(backup_keys.decryption_key.is_some());
    }

    /// Open (or create) the store named `name` under the shared temp dir,
    /// removing any existing directory first when `clear_data` is set.
    ///
    /// Presumably invoked by the shared integration-test macros below.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let tmpdir_path = TMP_DIR.path().join(name);

        if clear_data {
            // Best-effort: the directory may not exist yet.
            let _ = fs::remove_dir_all(&tmpdir_path).await;
        }

        SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
            .await
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}
1870
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Shared scratch directory for the passphrase-protected stores below.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Open (or create) a passphrase-protected store named `name`.
    ///
    /// A passphrase is always supplied: `default_test_password` is used when
    /// the caller passes `None`. With `clear_data`, any existing directory for
    /// this store is removed first.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);

        if clear_data {
            // Best-effort cleanup; a missing directory is not an error here.
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        let effective_passphrase = passphrase.unwrap_or("default_test_password");
        SqliteCryptoStore::open(store_dir.to_str().unwrap(), Some(effective_passphrase))
            .await
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}