1use std::{
16 collections::HashMap,
17 fmt,
18 ops::Deref,
19 path::{Path, PathBuf},
20 sync::{Arc, RwLock},
21};
22
23use async_trait::async_trait;
24use deadpool::managed::PoolConfig;
25use matrix_sdk_base::cross_process_lock::CrossProcessLockGeneration;
26use matrix_sdk_crypto::{
27 Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
28 olm::{
29 InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
30 PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
31 },
32 store::{
33 CryptoStore,
34 types::{
35 BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
36 RoomKeyWithheldEntry, RoomPendingKeyBundleDetails, RoomSettings,
37 StoredRoomKeyBundleData,
38 },
39 },
40};
41use matrix_sdk_store_encryption::StoreCipher;
42use ruma::{
43 DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, RoomId, TransactionId, UserId,
44 events::secret::request::SecretName,
45};
46use rusqlite::{OptionalExtension, named_params, params_from_iter};
47use tokio::{
48 fs,
49 sync::{Mutex, OwnedMutexGuard},
50};
51use tracing::{debug, instrument, warn};
52use vodozemac::Curve25519PublicKey;
53use zeroize::Zeroizing;
54
55use crate::{
56 OpenStoreError, RuntimeConfig, Secret, SqliteStoreConfig,
57 connection::{self, Connection as SqliteAsyncConn, Pool as SqlitePool, SqliteConnections},
58 error::{Error, Result},
59 utils::{
60 EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
61 SqliteKeyValueStoreConnExt, repeat_vars,
62 },
63};
64
/// File name of the crypto store's SQLite database; passed to
/// `SqliteStoreConfig::build_pool_of_connections` when the store is opened.
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
67
/// A SQLite-backed implementation of the crypto store.
///
/// Clones share the same connections, cached static account data and
/// save-changes lock, since those fields are held behind an `Arc`.
#[derive(Clone)]
pub struct SqliteCryptoStore {
    /// Cipher exposed through `EncryptableStore::get_cypher`; `Some` only when
    /// the store was opened with a secret (see `create_raw`).
    store_cipher: Option<Arc<StoreCipher>>,

    /// Read pool plus the dedicated write connection. `None` means the store
    /// is closed: `read()` and `write()` then fail with `Error::StoreClosed`.
    connections: Arc<Mutex<Option<SqliteConnections>>>,

    /// Path of the database file, taken from the pool's manager when the
    /// store is created.
    db_path: PathBuf,

    /// Pool configuration the store was opened with.
    pool_config: PoolConfig,

    /// Runtime configuration the store was opened with.
    runtime_config: RuntimeConfig,

    /// Cached copy of the account's static data, filled in when the account
    /// is loaded or saved.
    static_account: Arc<RwLock<Option<StaticAccountData>>>,

    /// Serializes `save_changes` and `save_pending_changes` calls.
    save_changes_lock: Arc<Mutex<()>>,
}
90
91#[cfg(not(tarpaulin_include))]
92impl fmt::Debug for SqliteCryptoStore {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
95 }
96}
97
98impl EncryptableStore for SqliteCryptoStore {
99 fn get_cypher(&self) -> Option<&StoreCipher> {
100 self.store_cipher.as_deref()
101 }
102}
103
104impl SqliteCryptoStore {
105 pub(crate) async fn create_raw(
117 secret: Option<Secret>,
118 pool: SqlitePool,
119 conn: SqliteAsyncConn,
120 pool_config: PoolConfig,
121 runtime_config: RuntimeConfig,
122 ) -> Result<Self, OpenStoreError> {
123 let store_cipher = match secret {
124 Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s).await?)),
125 None => None,
126 };
127
128 let db_path = pool.manager().database_path.clone();
129
130 Ok(Self {
131 store_cipher,
132 connections: Arc::new(Mutex::new(Some(SqliteConnections {
133 pool,
134 write_connection: Arc::new(Mutex::new(conn)),
135 }))),
136 db_path,
137 pool_config,
138 runtime_config,
139 static_account: Arc::new(RwLock::new(None)),
140 save_changes_lock: Default::default(),
141 })
142 }
143
144 pub async fn open(
147 path: impl AsRef<Path>,
148 passphrase: Option<&str>,
149 ) -> Result<Self, OpenStoreError> {
150 Self::open_with_config(&SqliteStoreConfig::new(path).passphrase(passphrase)).await
151 }
152
153 pub async fn open_with_key(
156 path: impl AsRef<Path>,
157 key: Option<&[u8; 32]>,
158 ) -> Result<Self, OpenStoreError> {
159 Self::open_with_config(&SqliteStoreConfig::new(path).key(key)).await
160 }
161
162 pub async fn open_with_config(config: &SqliteStoreConfig) -> Result<Self, OpenStoreError> {
164 fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir)?;
165
166 let pool = config.build_pool_of_connections(DATABASE_NAME)?;
167 let pool_config = config.pool_config();
168 let runtime_config = config.runtime_config();
169
170 let this =
171 Self::open_with_pool(pool, config.secret.clone(), pool_config, runtime_config).await?;
172 this.read().await?.apply_runtime_config(runtime_config).await?;
173
174 Ok(this)
175 }
176
177 async fn open_with_pool(
180 pool: SqlitePool,
181 secret: Option<Secret>,
182 pool_config: PoolConfig,
183 runtime_config: RuntimeConfig,
184 ) -> Result<Self, OpenStoreError> {
185 let conn = pool.get().await?;
186
187 let version = conn.db_version().await?;
188 debug!("Opened sqlite store with version {}", version);
189
190 let version = initialize_store(&conn, version).await?;
191
192 let store = Self::create_raw(secret, pool, conn, pool_config, runtime_config).await?;
193
194 run_migrations(&store, version, None).await?;
195
196 store.write().await?.wal_checkpoint().await;
197
198 Ok(store)
199 }
200
201 fn deserialize_and_unpickle_inbound_group_session(
202 &self,
203 value: Vec<u8>,
204 backed_up: bool,
205 ) -> Result<InboundGroupSession> {
206 let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
207
208 pickle.backed_up = backed_up;
213
214 Ok(InboundGroupSession::from_pickle(pickle)?)
215 }
216
217 fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
218 let mut request: GossipRequest = self.deserialize_value(value)?;
219 request.sent_out = sent_out;
222 Ok(request)
223 }
224
225 fn get_static_account(&self) -> Option<StaticAccountData> {
226 self.static_account.read().unwrap().clone()
227 }
228
229 #[instrument(skip_all)]
231 async fn read(&self) -> Result<SqliteAsyncConn> {
232 let pool = {
233 let guard = self.connections.lock().await;
234 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
235 conns.pool.clone()
236 };
237 Ok(pool.get().await?)
238 }
239
240 #[instrument(skip_all)]
242 pub(crate) async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
243 let write_connection = {
244 let guard = self.connections.lock().await;
245 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
246 conns.write_connection.clone()
247 };
248 Ok(write_connection.lock_owned().await)
249 }
250}
251
/// Schema version of a fully migrated database.
///
/// This must equal the target version of the last migration applied by
/// `run_migrations`, which upgrades the schema up to version 17.
const DATABASE_VERSION: u8 = 17;

/// Key-value entry under which the dehydrated device pickle key is stored.
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
256
257pub(crate) async fn initialize_store(conn: &SqliteAsyncConn, version: u8) -> Result<u8> {
268 if version == 0 {
269 debug!("Creating database");
270 } else if version < DATABASE_VERSION {
271 debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
272 } else {
273 return Ok(version);
274 }
275
276 if version < 1 {
277 debug!("Creating database");
278 conn.execute_batch("PRAGMA journal_mode = wal;").await?;
281 conn.with_transaction(|txn| {
282 txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
283 txn.set_db_version(1)
284 })
285 .await?;
286 return Ok(1);
287 }
288
289 Ok(version)
290}
291
292pub(crate) async fn run_migrations(
304 store: &SqliteCryptoStore,
305 version: u8,
306 max_version: Option<u8>,
307) -> Result<()> {
308 let conn = store.write().await?;
309
310 if version < 2 {
311 debug!("Upgrading database to version 2");
312 conn.with_transaction(|txn| {
313 txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
314 txn.set_db_version(2)
315 })
316 .await?;
317 }
318
319 if version < 3 {
320 debug!("Upgrading database to version 3");
321 conn.with_transaction(|txn| {
322 txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
323 txn.set_db_version(3)
324 })
325 .await?;
326 }
327
328 if version < 4 {
329 debug!("Upgrading database to version 4");
330 conn.with_transaction(|txn| {
331 txn.execute_batch(include_str!(
332 "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
333 ))?;
334 txn.set_db_version(4)
335 })
336 .await?;
337 }
338
339 if version < 5 {
340 debug!("Upgrading database to version 5");
341 conn.with_transaction(|txn| {
342 txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
343 txn.set_db_version(5)
344 })
345 .await?;
346 }
347
348 if version < 6 {
349 debug!("Upgrading database to version 6");
350 conn.with_transaction(|txn| {
351 txn.execute_batch(include_str!(
352 "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
353 ))?;
354 txn.set_db_version(6)
355 })
356 .await?;
357 }
358
359 if version < 7 {
360 debug!("Upgrading database to version 7");
361 conn.with_transaction(|txn| {
362 txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
363 txn.set_db_version(7)
364 })
365 .await?;
366 }
367
368 if version < 8 {
369 debug!("Upgrading database to version 8");
370 conn.with_transaction(|txn| {
371 txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
372 txn.set_db_version(8)
373 })
374 .await?;
375 }
376
377 if version < 9 {
378 debug!("Upgrading database to version 9");
379 conn.with_transaction(|txn| {
380 txn.execute_batch(include_str!(
381 "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
382 ))?;
383 txn.set_db_version(9)
384 })
385 .await?;
386 }
387
388 if version < 10 {
389 debug!("Upgrading database to version 10");
390 conn.with_transaction(|txn| {
391 txn.execute_batch(include_str!(
392 "../migrations/crypto_store/010_received_room_key_bundles.sql"
393 ))?;
394 txn.set_db_version(10)
395 })
396 .await?;
397 }
398
399 if version < 11 {
400 debug!("Upgrading database to version 11");
401 conn.with_transaction(|txn| {
402 txn.execute_batch(include_str!(
403 "../migrations/crypto_store/011_received_room_key_bundles_with_curve_key.sql"
404 ))?;
405 txn.set_db_version(11)
406 })
407 .await?;
408 }
409
410 if version < 12 {
411 debug!("Upgrading database to version 12");
412 conn.with_transaction(|txn| {
413 txn.execute_batch(include_str!(
414 "../migrations/crypto_store/012_withheld_code_by_room.sql"
415 ))?;
416 txn.set_db_version(12)
417 })
418 .await?;
419 }
420
421 if version < 13 {
422 debug!("Upgrading database to version 13");
423 conn.with_transaction(|txn| {
424 txn.execute_batch(include_str!(
425 "../migrations/crypto_store/013_lease_locks_with_generation.sql"
426 ))?;
427 txn.set_db_version(13)
428 })
429 .await?;
430 }
431
432 if version < 14 {
433 debug!("Upgrading database to version 14");
434 conn.with_transaction(|txn| {
435 txn.execute_batch(include_str!(
436 "../migrations/crypto_store/014_room_key_backups_fully_downloaded.sql"
437 ))?;
438 txn.set_db_version(14)
439 })
440 .await?;
441 }
442
443 if version < 15 {
444 debug!("Upgrading database to version 15");
445 conn.with_transaction(|txn| {
446 txn.execute_batch(include_str!(
447 "../migrations/crypto_store/015_rooms_pending_key_bundle.sql"
448 ))?;
449 txn.set_db_version(15)
450 })
451 .await?;
452 }
453
454 if version < 16 {
455 debug!("Upgrading database to version 16");
456 conn.with_transaction(|txn| {
457 txn.execute_batch(include_str!(
458 "../migrations/crypto_store/016_remove_old_generation_counter.sql"
459 ))?;
460 txn.set_db_version(16)
461 })
462 .await?;
463 }
464
465 if max_version.is_some_and(|max_version| max_version < 17) {
466 return Ok(());
467 }
468
469 if version < 17 {
470 let store = store.clone();
471 conn.with_transaction(move |txn| {
472 txn.execute_batch(include_str!(
473 "../migrations/crypto_store/017_add_new_secrets_inbox.sql"
474 ))?;
475 let mut select_query = txn.prepare("SELECT data FROM secrets")?;
476 let mut secrets = select_query.query([])?;
477 let mut insert_query = txn.prepare(
478 "INSERT OR IGNORE INTO secrets_inbox (secret_name, secret)
479 VALUES (?1, ?2)",
480 )?;
481 while let Some(row) = secrets.next()? {
482 let Ok(secret) =
483 store.deserialize_json::<GossippedSecret>(row.get::<_, Vec<u8>>(0)?.as_ref())
484 else {
485 continue;
486 };
487 let Ok(encoded_secret) = store.serialize_json(&secret.event.content.secret) else {
488 continue;
489 };
490 insert_query.execute((
491 store.encode_key("secrets_inbox", secret.secret_name.to_string()),
492 &encoded_secret,
493 ))?;
494 }
495 txn.execute_batch(include_str!(
496 "../migrations/crypto_store/017_drop_old_secrets_inbox.sql"
497 ))?;
498 txn.set_db_version(17)
499 })
500 .await?;
501 }
502
503 Ok(())
504}
505
506trait SqliteConnectionExt {
507 fn set_session(
508 &self,
509 session_id: &[u8],
510 sender_key: &[u8],
511 data: &[u8],
512 ) -> rusqlite::Result<()>;
513
514 fn set_inbound_group_session(
515 &self,
516 room_id: &[u8],
517 session_id: &[u8],
518 data: &[u8],
519 backed_up: bool,
520 sender_key: Option<&[u8]>,
521 sender_data_type: Option<u8>,
522 ) -> rusqlite::Result<()>;
523
524 fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
525
526 fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
527 fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
528
529 fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
530
531 fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
532
533 fn set_key_request(
534 &self,
535 request_id: &[u8],
536 sent_out: bool,
537 data: &[u8],
538 ) -> rusqlite::Result<()>;
539
540 fn set_direct_withheld(
541 &self,
542 session_id: &[u8],
543 room_id: &[u8],
544 data: &[u8],
545 ) -> rusqlite::Result<()>;
546
547 fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
548
549 fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
550
551 fn set_received_room_key_bundle(
552 &self,
553 room_id: &[u8],
554 user_id: &[u8],
555 data: &[u8],
556 ) -> rusqlite::Result<()>;
557
558 fn set_has_downloaded_all_room_keys(&self, room_id: &[u8]) -> rusqlite::Result<()>;
559
560 fn set_room_pending_key_bundle(
561 &self,
562 room_id: &[u8],
563 details: Option<&[u8]>,
564 ) -> rusqlite::Result<()>;
565}
566
567impl SqliteConnectionExt for rusqlite::Connection {
568 fn set_session(
569 &self,
570 session_id: &[u8],
571 sender_key: &[u8],
572 data: &[u8],
573 ) -> rusqlite::Result<()> {
574 self.execute(
575 "INSERT INTO session (session_id, sender_key, data)
576 VALUES (?1, ?2, ?3)
577 ON CONFLICT (session_id) DO UPDATE SET data = ?3",
578 (session_id, sender_key, data),
579 )?;
580 Ok(())
581 }
582
583 fn set_inbound_group_session(
584 &self,
585 room_id: &[u8],
586 session_id: &[u8],
587 data: &[u8],
588 backed_up: bool,
589 sender_key: Option<&[u8]>,
590 sender_data_type: Option<u8>,
591 ) -> rusqlite::Result<()> {
592 self.execute(
593 "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
594 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
595 ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
596 (session_id, room_id, data, backed_up, sender_key, sender_data_type),
597 )?;
598 Ok(())
599 }
600
601 fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
602 self.execute(
603 "INSERT INTO outbound_group_session (room_id, data) \
604 VALUES (?1, ?2)
605 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
606 (room_id, data),
607 )?;
608 Ok(())
609 }
610
611 fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
612 self.execute(
613 "INSERT INTO device (user_id, device_id, data) \
614 VALUES (?1, ?2, ?3)
615 ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
616 (user_id, device_id, data),
617 )?;
618 Ok(())
619 }
620
621 fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
622 self.execute(
623 "DELETE FROM device WHERE user_id = ? AND device_id = ?",
624 (user_id, device_id),
625 )?;
626 Ok(())
627 }
628
629 fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
630 self.execute(
631 "INSERT INTO identity (user_id, data) \
632 VALUES (?1, ?2)
633 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
634 (user_id, data),
635 )?;
636 Ok(())
637 }
638
639 fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
640 self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
641 Ok(())
642 }
643
644 fn set_key_request(
645 &self,
646 request_id: &[u8],
647 sent_out: bool,
648 data: &[u8],
649 ) -> rusqlite::Result<()> {
650 self.execute(
651 "INSERT INTO key_requests (request_id, sent_out, data)
652 VALUES (?1, ?2, ?3)
653 ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
654 (request_id, sent_out, data),
655 )?;
656 Ok(())
657 }
658
659 fn set_direct_withheld(
660 &self,
661 session_id: &[u8],
662 room_id: &[u8],
663 data: &[u8],
664 ) -> rusqlite::Result<()> {
665 self.execute(
666 "INSERT INTO direct_withheld_info (session_id, room_id, data)
667 VALUES (?1, ?2, ?3)
668 ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
669 (session_id, room_id, data),
670 )?;
671 Ok(())
672 }
673
674 fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
675 self.execute(
676 "INSERT INTO room_settings (room_id, data)
677 VALUES (?1, ?2)
678 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
679 (room_id, data),
680 )?;
681 Ok(())
682 }
683
684 fn set_secret(&self, secret_name: &[u8], secret: &[u8]) -> rusqlite::Result<()> {
685 self.execute(
688 "INSERT OR IGNORE INTO secrets_inbox (secret_name, secret)
689 VALUES (?1, ?2)",
690 (secret_name, secret),
691 )?;
692
693 Ok(())
694 }
695
696 fn set_received_room_key_bundle(
697 &self,
698 room_id: &[u8],
699 sender_user_id: &[u8],
700 data: &[u8],
701 ) -> rusqlite::Result<()> {
702 self.execute(
703 "INSERT INTO received_room_key_bundle(room_id, sender_user_id, bundle_data)
704 VALUES (?1, ?2, ?3)
705 ON CONFLICT (room_id, sender_user_id) DO UPDATE SET bundle_data = ?3",
706 (room_id, sender_user_id, data),
707 )?;
708 Ok(())
709 }
710
711 fn set_room_pending_key_bundle(
712 &self,
713 room_id: &[u8],
714 data: Option<&[u8]>,
715 ) -> rusqlite::Result<()> {
716 if let Some(data) = data {
717 self.execute(
718 "INSERT INTO rooms_pending_key_bundle (room_id, data)
719 VALUES (?1, ?2)
720 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
721 (room_id, data),
722 )?;
723 } else {
724 self.execute("DELETE FROM rooms_pending_key_bundle WHERE room_id = ?1", (room_id,))?;
725 }
726 Ok(())
727 }
728
729 fn set_has_downloaded_all_room_keys(&self, room_id: &[u8]) -> rusqlite::Result<()> {
730 self.execute(
731 "INSERT INTO room_key_backups_fully_downloaded(room_id)
732 VALUES (?1)
733 ON CONFLICT(room_id) DO NOTHING",
734 (room_id,),
735 )?;
736 Ok(())
737 }
738}
739
740#[async_trait]
741trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
742 async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
743 Ok(self
744 .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
745 stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
746 })
747 .await?)
748 }
749
750 async fn get_inbound_group_session(
751 &self,
752 session_id: Key,
753 ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
754 Ok(self
755 .query_row(
756 "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
757 (session_id,),
758 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
759 )
760 .await
761 .optional()?)
762 }
763
764 async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
765 Ok(self
766 .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
767 stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
768 })
769 .await?)
770 }
771
772 async fn get_inbound_group_session_counts(
773 &self,
774 _backup_version: Option<&str>,
775 ) -> Result<RoomKeyCounts> {
776 let total = self
777 .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
778 .await?;
779 let backed_up = self
780 .query_row(
781 "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
782 (),
783 |row| row.get(0),
784 )
785 .await?;
786 Ok(RoomKeyCounts { total, backed_up })
787 }
788
789 async fn get_inbound_group_sessions_by_room_id(
790 &self,
791 room_id: Key,
792 ) -> Result<Vec<(Vec<u8>, bool)>> {
793 Ok(self
794 .prepare(
795 "SELECT data, backed_up FROM inbound_group_session WHERE room_id = :room_id",
796 move |mut stmt| {
797 stmt.query(named_params! {
798 ":room_id": room_id,
799 })?
800 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
801 .collect()
802 },
803 )
804 .await?)
805 }
806
807 async fn get_inbound_group_sessions_for_device_batch(
808 &self,
809 sender_key: Key,
810 sender_data_type: SenderDataType,
811 after_session_id: Option<Key>,
812 limit: usize,
813 ) -> Result<Vec<(Vec<u8>, bool)>> {
814 Ok(self
815 .prepare(
816 "
817 SELECT data, backed_up
818 FROM inbound_group_session
819 WHERE sender_key = :sender_key
820 AND sender_data_type = :sender_data_type
821 AND session_id > :after_session_id
822 ORDER BY session_id
823 LIMIT :limit
824 ",
825 move |mut stmt| {
826 let sender_data_type = sender_data_type as u8;
827
828 let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
831
832 stmt.query(named_params! {
833 ":sender_key": sender_key,
834 ":sender_data_type": sender_data_type,
835 ":after_session_id": after_session_id,
836 ":limit": limit,
837 })?
838 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
839 .collect()
840 },
841 )
842 .await?)
843 }
844
845 async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
846 Ok(self
847 .prepare(
848 "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
849 move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
850 )
851 .await?)
852 }
853
854 async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
855 if session_ids.is_empty() {
856 warn!("No sessions to mark as backed up!");
858 return Ok(());
859 }
860
861 let session_ids_len = session_ids.len();
862
863 self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
864 let sql_params = repeat_vars(session_ids_len);
867 let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
868 txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
869 Ok(Vec::<()>::new())
870 }).await?;
871
872 Ok(())
873 }
874
875 async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
876 self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
877 Ok(())
878 }
879
880 async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
881 Ok(self
882 .query_row(
883 "SELECT data FROM outbound_group_session WHERE room_id = ?",
884 (room_id,),
885 |row| row.get(0),
886 )
887 .await
888 .optional()?)
889 }
890
891 async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
892 Ok(self
893 .query_row(
894 "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
895 (user_id, device_id),
896 |row| row.get(0),
897 )
898 .await
899 .optional()?)
900 }
901
902 async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
903 Ok(self
904 .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
905 stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
906 })
907 .await?)
908 }
909
910 async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
911 Ok(self
912 .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
913 .await
914 .optional()?)
915 }
916
917 async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
918 Ok(self
919 .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
920 row.get::<_, i32>(0)
921 })
922 .await?
923 > 0)
924 }
925
926 async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
927 Ok(self
928 .prepare("SELECT data FROM tracked_user", |mut stmt| {
929 stmt.query(())?.mapped(|row| row.get(0)).collect()
930 })
931 .await?)
932 }
933
934 async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
935 Ok(self
936 .prepare(
937 "INSERT INTO tracked_user (user_id, data) \
938 VALUES (?1, ?2) \
939 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
940 |mut stmt| {
941 for (user_id, data) in users {
942 stmt.execute((user_id, data))?;
943 }
944
945 Ok(())
946 },
947 )
948 .await?)
949 }
950
951 async fn get_outgoing_secret_request(
952 &self,
953 request_id: Key,
954 ) -> Result<Option<(Vec<u8>, bool)>> {
955 Ok(self
956 .query_row(
957 "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
958 (request_id,),
959 |row| Ok((row.get(0)?, row.get(1)?)),
960 )
961 .await
962 .optional()?)
963 }
964
965 async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
966 Ok(self
967 .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
968 stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
969 })
970 .await?)
971 }
972
973 async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
974 Ok(self
975 .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
976 stmt.query(())?.mapped(|row| row.get(0)).collect()
977 })
978 .await?)
979 }
980
981 async fn delete_key_request(&self, request_id: Key) -> Result<()> {
982 self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
983 Ok(())
984 }
985
986 async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
987 Ok(self
988 .prepare("SELECT secret FROM secrets_inbox WHERE secret_name = ?", |mut stmt| {
989 stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
990 })
991 .await?)
992 }
993
994 async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
995 self.execute("DELETE FROM secrets_inbox WHERE secret_name = ?", (secret_name,)).await?;
996 Ok(())
997 }
998
999 async fn get_direct_withheld_info(
1000 &self,
1001 session_id: Key,
1002 room_id: Key,
1003 ) -> Result<Option<Vec<u8>>> {
1004 Ok(self
1005 .query_row(
1006 "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
1007 (session_id, room_id),
1008 |row| row.get(0),
1009 )
1010 .await
1011 .optional()?)
1012 }
1013
1014 async fn get_withheld_sessions_by_room_id(&self, room_id: Key) -> Result<Vec<Vec<u8>>> {
1015 Ok(self
1016 .prepare("SELECT data FROM direct_withheld_info WHERE room_id = ?1", |mut stmt| {
1017 stmt.query((room_id,))?.mapped(|row| row.get(0)).collect()
1018 })
1019 .await?)
1020 }
1021
1022 async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
1023 Ok(self
1024 .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
1025 row.get(0)
1026 })
1027 .await
1028 .optional()?)
1029 }
1030
1031 async fn get_received_room_key_bundle(
1032 &self,
1033 room_id: Key,
1034 sender_user: Key,
1035 ) -> Result<Option<Vec<u8>>> {
1036 Ok(self
1037 .query_row(
1038 "SELECT bundle_data FROM received_room_key_bundle WHERE room_id = ? AND sender_user_id = ?",
1039 (room_id, sender_user),
1040 |row| { row.get(0) },
1041 )
1042 .await
1043 .optional()?)
1044 }
1045
1046 async fn get_room_pending_key_bundle(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
1047 Ok(self
1048 .query_row(
1049 "SELECT data FROM rooms_pending_key_bundle WHERE room_id = ?",
1050 (room_id,),
1051 |row| row.get(0),
1052 )
1053 .await
1054 .optional()?)
1055 }
1056
1057 async fn get_all_rooms_pending_key_bundle(&self) -> Result<Vec<Vec<u8>>> {
1058 Ok(self
1059 .query_many("SELECT data FROM rooms_pending_key_bundle", (), |row| row.get(0))
1060 .await?)
1061 }
1062
1063 async fn has_downloaded_all_room_keys(&self, room_id: Key) -> Result<bool> {
1064 Ok(self
1065 .query_row(
1066 "SELECT EXISTS (SELECT 1 FROM room_key_backups_fully_downloaded WHERE room_id = ?)",
1067 (room_id,),
1068 |row| row.get(0),
1069 )
1070 .await?)
1071 }
1072}
1073
1074#[async_trait]
1075impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
1076
1077#[async_trait]
1078impl CryptoStore for SqliteCryptoStore {
1079 type Error = Error;
1080
1081 async fn load_account(&self) -> Result<Option<Account>> {
1082 let conn = self.read().await?;
1083 if let Some(pickle) = conn.get_kv("account").await? {
1084 let pickle = self.deserialize_value(&pickle)?;
1085
1086 let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
1087
1088 *self.static_account.write().unwrap() = Some(account.static_data().clone());
1089
1090 Ok(Some(account))
1091 } else {
1092 Ok(None)
1093 }
1094 }
1095
1096 async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
1097 let conn = self.read().await?;
1098 if let Some(i) = conn.get_kv("identity").await? {
1099 let pickle = self.deserialize_value(&i)?;
1100 Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
1101 } else {
1102 Ok(None)
1103 }
1104 }
1105
1106 async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
1107 let _guard = self.save_changes_lock.lock().await;
1112
1113 let pickled_account = if let Some(account) = changes.account {
1114 *self.static_account.write().unwrap() = Some(account.static_data().clone());
1115 Some(account.pickle())
1116 } else {
1117 None
1118 };
1119
1120 let this = self.clone();
1121 self.write()
1122 .await?
1123 .with_transaction(move |txn| {
1124 if let Some(pickled_account) = pickled_account {
1125 let serialized_account = this.serialize_value(&pickled_account)?;
1126 txn.set_kv("account", &serialized_account)?;
1127 }
1128
1129 Ok::<_, Error>(())
1130 })
1131 .await?;
1132
1133 Ok(())
1134 }
1135
1136 async fn save_changes(&self, changes: Changes) -> Result<()> {
1137 let _guard = self.save_changes_lock.lock().await;
1142
1143 let pickled_private_identity =
1144 if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
1145
1146 let mut session_changes = Vec::new();
1147
1148 for session in changes.sessions {
1149 let session_id = self.encode_key("session", session.session_id());
1150 let sender_key = self.encode_key("session", session.sender_key().to_base64());
1151 let pickle = session.pickle().await;
1152 session_changes.push((session_id, sender_key, pickle));
1153 }
1154
1155 let mut inbound_session_changes = Vec::new();
1156 for session in changes.inbound_group_sessions {
1157 let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
1158 let session_id = self.encode_key("inbound_group_session", session.session_id());
1159 let pickle = session.pickle().await;
1160 let sender_key =
1161 self.encode_key("inbound_group_session", session.sender_key().to_base64());
1162 inbound_session_changes.push((room_id, session_id, pickle, sender_key));
1163 }
1164
1165 let mut outbound_session_changes = Vec::new();
1166 for session in changes.outbound_group_sessions {
1167 let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
1168 let pickle = session.pickle().await;
1169 outbound_session_changes.push((room_id, pickle));
1170 }
1171
1172 let this = self.clone();
1173 self.write()
1174 .await?
1175 .with_transaction(move |txn| {
1176 if let Some(pickled_private_identity) = &pickled_private_identity {
1177 let serialized_private_identity =
1178 this.serialize_value(pickled_private_identity)?;
1179 txn.set_kv("identity", &serialized_private_identity)?;
1180 }
1181
1182 if let Some(token) = &changes.next_batch_token {
1183 let serialized_token = this.serialize_value(token)?;
1184 txn.set_kv("next_batch_token", &serialized_token)?;
1185 }
1186
1187 if let Some(decryption_key) = &changes.backup_decryption_key {
1188 let serialized_decryption_key = this.serialize_value(decryption_key)?;
1189 txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
1190 }
1191
1192 if let Some(backup_version) = &changes.backup_version {
1193 let serialized_backup_version = this.serialize_value(backup_version)?;
1194 txn.set_kv("backup_version_v1", &serialized_backup_version)?;
1195 }
1196
1197 if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
1198 let serialized_pickle_key = this.serialize_value(pickle_key)?;
1199 txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
1200 }
1201
1202 for device in changes.devices.new.iter().chain(&changes.devices.changed) {
1203 let user_id = this.encode_key("device", device.user_id().as_bytes());
1204 let device_id = this.encode_key("device", device.device_id().as_bytes());
1205 let data = this.serialize_value(&device)?;
1206 txn.set_device(&user_id, &device_id, &data)?;
1207 }
1208
1209 for device in &changes.devices.deleted {
1210 let user_id = this.encode_key("device", device.user_id().as_bytes());
1211 let device_id = this.encode_key("device", device.device_id().as_bytes());
1212 txn.delete_device(&user_id, &device_id)?;
1213 }
1214
1215 for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
1216 let user_id = this.encode_key("identity", identity.user_id().as_bytes());
1217 let data = this.serialize_value(&identity)?;
1218 txn.set_identity(&user_id, &data)?;
1219 }
1220
1221 for (session_id, sender_key, pickle) in &session_changes {
1222 let serialized_session = this.serialize_value(&pickle)?;
1223 txn.set_session(session_id, sender_key, &serialized_session)?;
1224 }
1225
1226 for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
1227 let serialized_session = this.serialize_value(&pickle)?;
1228 txn.set_inbound_group_session(
1229 room_id,
1230 session_id,
1231 &serialized_session,
1232 pickle.backed_up,
1233 Some(sender_key),
1234 Some(pickle.sender_data.to_type() as u8),
1235 )?;
1236 }
1237
1238 for (room_id, pickle) in &outbound_session_changes {
1239 let serialized_session = this.serialize_json(&pickle)?;
1240 txn.set_outbound_group_session(room_id, &serialized_session)?;
1241 }
1242
1243 for hash in &changes.message_hashes {
1244 let hash = rmp_serde::to_vec(hash)?;
1245 txn.add_olm_hash(&hash)?;
1246 }
1247
1248 for request in changes.key_requests {
1249 let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
1250 let serialized_request = this.serialize_value(&request)?;
1251 txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
1252 }
1253
1254 for (room_id, data) in changes.withheld_session_info {
1255 for (session_id, event) in data {
1256 let session_id = this.encode_key("direct_withheld_info", session_id);
1257 let room_id = this.encode_key("direct_withheld_info", &room_id);
1258 let serialized_info = this.serialize_json(&event)?;
1259 txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
1260 }
1261 }
1262
1263 for (room_id, settings) in changes.room_settings {
1264 let room_id = this.encode_key("room_settings", room_id.as_bytes());
1265 let value = this.serialize_value(&settings)?;
1266 txn.set_room_settings(&room_id, &value)?;
1267 }
1268
1269 for secret in changes.secrets {
1270 let secret_name =
1271 this.encode_key("secrets_inbox", secret.secret_name.to_string());
1272 let value = this.serialize_json(secret.secret.deref())?;
1273 txn.set_secret(&secret_name, &value)?;
1274 }
1275
1276 for bundle in changes.received_room_key_bundles {
1277 let room_id =
1278 this.encode_key("received_room_key_bundle", &bundle.bundle_data.room_id);
1279 let user_id = this.encode_key("received_room_key_bundle", &bundle.sender_user);
1280 let value = this.serialize_value(&bundle)?;
1281 txn.set_received_room_key_bundle(&room_id, &user_id, &value)?;
1282 }
1283
1284 for room in changes.room_key_backups_fully_downloaded {
1285 let room_id = this.encode_key("room_key_backups_fully_downloaded", &room);
1286 txn.set_has_downloaded_all_room_keys(&room_id)?;
1287 }
1288
1289 for (room, details) in changes.rooms_pending_key_bundle {
1290 let room_id = this.encode_key("rooms_pending_key_bundle", &room);
1291 let value = details.as_ref().map(|d| this.serialize_value(d)).transpose()?;
1292 txn.set_room_pending_key_bundle(&room_id, value.as_deref())?;
1293 }
1294
1295 Ok::<_, Error>(())
1296 })
1297 .await?;
1298
1299 Ok(())
1300 }
1301
1302 async fn save_inbound_group_sessions(
1303 &self,
1304 sessions: Vec<InboundGroupSession>,
1305 backed_up_to_version: Option<&str>,
1306 ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
1307 sessions.iter().for_each(|s| {
1309 let backed_up = s.backed_up();
1310 if backed_up != backed_up_to_version.is_some() {
1311 warn!(
1312 backed_up,
1313 backed_up_to_version,
1314 "Session backed-up flag does not correspond to backup version setting",
1315 );
1316 }
1317 });
1318
1319 self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
1322 }
1323
1324 async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
1325 let device_keys = self.get_own_device().await?.as_device_keys().clone();
1326
1327 let sessions: Vec<_> = self
1328 .read()
1329 .await?
1330 .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
1331 .await?
1332 .into_iter()
1333 .map(|bytes| {
1334 let pickle = self.deserialize_value(&bytes)?;
1335 Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
1336 })
1337 .collect::<Result<_>>()?;
1338
1339 if sessions.is_empty() { Ok(None) } else { Ok(Some(sessions)) }
1340 }
1341
1342 #[instrument(skip(self))]
1343 async fn get_inbound_group_session(
1344 &self,
1345 room_id: &RoomId,
1346 session_id: &str,
1347 ) -> Result<Option<InboundGroupSession>> {
1348 let session_id = self.encode_key("inbound_group_session", session_id);
1349 let Some((room_id_from_db, value, backed_up)) =
1350 self.read().await?.get_inbound_group_session(session_id).await?
1351 else {
1352 return Ok(None);
1353 };
1354
1355 let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1356 if *room_id != room_id_from_db {
1357 warn!("expected room_id for session_id doesn't match what's in the DB");
1358 return Ok(None);
1359 }
1360
1361 Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1362 }
1363
1364 async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1365 self.read()
1366 .await?
1367 .get_inbound_group_sessions()
1368 .await?
1369 .into_iter()
1370 .map(|(value, backed_up)| {
1371 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1372 })
1373 .collect()
1374 }
1375
1376 async fn get_inbound_group_sessions_by_room_id(
1377 &self,
1378 room_id: &RoomId,
1379 ) -> Result<Vec<InboundGroupSession>> {
1380 let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1381 self.read()
1382 .await?
1383 .get_inbound_group_sessions_by_room_id(room_id)
1384 .await?
1385 .into_iter()
1386 .map(|(value, backed_up)| {
1387 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1388 })
1389 .collect()
1390 }
1391
1392 async fn get_inbound_group_sessions_for_device_batch(
1393 &self,
1394 sender_key: Curve25519PublicKey,
1395 sender_data_type: SenderDataType,
1396 after_session_id: Option<String>,
1397 limit: usize,
1398 ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1399 let after_session_id =
1400 after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1401 let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1402
1403 self.read()
1404 .await?
1405 .get_inbound_group_sessions_for_device_batch(
1406 sender_key,
1407 sender_data_type,
1408 after_session_id,
1409 limit,
1410 )
1411 .await?
1412 .into_iter()
1413 .map(|(value, backed_up)| {
1414 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1415 })
1416 .collect()
1417 }
1418
1419 async fn inbound_group_session_counts(
1420 &self,
1421 backup_version: Option<&str>,
1422 ) -> Result<RoomKeyCounts> {
1423 Ok(self.read().await?.get_inbound_group_session_counts(backup_version).await?)
1424 }
1425
1426 async fn inbound_group_sessions_for_backup(
1427 &self,
1428 _backup_version: &str,
1429 limit: usize,
1430 ) -> Result<Vec<InboundGroupSession>> {
1431 self.read()
1432 .await?
1433 .get_inbound_group_sessions_for_backup(limit)
1434 .await?
1435 .into_iter()
1436 .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1437 .collect()
1438 }
1439
1440 async fn mark_inbound_group_sessions_as_backed_up(
1441 &self,
1442 _backup_version: &str,
1443 session_ids: &[(&RoomId, &str)],
1444 ) -> Result<()> {
1445 Ok(self
1446 .write()
1447 .await?
1448 .mark_inbound_group_sessions_as_backed_up(
1449 session_ids
1450 .iter()
1451 .map(|(_, s)| self.encode_key("inbound_group_session", s))
1452 .collect(),
1453 )
1454 .await?)
1455 }
1456
1457 async fn reset_backup_state(&self) -> Result<()> {
1458 Ok(self.write().await?.reset_inbound_group_session_backup_state().await?)
1459 }
1460
1461 async fn load_backup_keys(&self) -> Result<BackupKeys> {
1462 let conn = self.read().await?;
1463
1464 let backup_version = conn
1465 .get_kv("backup_version_v1")
1466 .await?
1467 .map(|value| self.deserialize_value(&value))
1468 .transpose()?;
1469
1470 let decryption_key = conn
1471 .get_kv("recovery_key_v1")
1472 .await?
1473 .map(|value| self.deserialize_value(&value))
1474 .transpose()?;
1475
1476 Ok(BackupKeys { backup_version, decryption_key })
1477 }
1478
1479 async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1480 let conn = self.read().await?;
1481
1482 conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1483 .await?
1484 .map(|value| self.deserialize_value(&value))
1485 .transpose()
1486 }
1487
1488 async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1489 Ok(self.write().await?.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?)
1490 }
1491 async fn get_outbound_group_session(
1492 &self,
1493 room_id: &RoomId,
1494 ) -> Result<Option<OutboundGroupSession>> {
1495 let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1496 let Some(value) = self.read().await?.get_outbound_group_session(room_id).await? else {
1497 return Ok(None);
1498 };
1499
1500 let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1501
1502 let pickle = self.deserialize_json(&value)?;
1503 let session = OutboundGroupSession::from_pickle(
1504 account_info.device_id,
1505 account_info.identity_keys,
1506 pickle,
1507 )
1508 .map_err(|_| Error::Unpickle)?;
1509
1510 return Ok(Some(session));
1511 }
1512
1513 async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1514 self.read()
1515 .await?
1516 .get_tracked_users()
1517 .await?
1518 .iter()
1519 .map(|value| self.deserialize_value(value))
1520 .collect()
1521 }
1522
1523 async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1524 let users: Vec<(Key, Vec<u8>)> = tracked_users
1525 .iter()
1526 .map(|(u, d)| {
1527 let user_id = self.encode_key("tracked_users", u.as_bytes());
1528 let data =
1529 self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1530 Ok((user_id, data))
1531 })
1532 .collect::<Result<_>>()?;
1533
1534 Ok(self.write().await?.add_tracked_users(users).await?)
1535 }
1536
1537 async fn get_device(
1538 &self,
1539 user_id: &UserId,
1540 device_id: &DeviceId,
1541 ) -> Result<Option<DeviceData>> {
1542 let user_id = self.encode_key("device", user_id.as_bytes());
1543 let device_id = self.encode_key("device", device_id.as_bytes());
1544 Ok(self
1545 .read()
1546 .await?
1547 .get_device(user_id, device_id)
1548 .await?
1549 .map(|value| self.deserialize_value(&value))
1550 .transpose()?)
1551 }
1552
1553 async fn get_user_devices(
1554 &self,
1555 user_id: &UserId,
1556 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1557 let user_id = self.encode_key("device", user_id.as_bytes());
1558 self.read()
1559 .await?
1560 .get_user_devices(user_id)
1561 .await?
1562 .into_iter()
1563 .map(|value| {
1564 let device: DeviceData = self.deserialize_value(&value)?;
1565 Ok((device.device_id().to_owned(), device))
1566 })
1567 .collect()
1568 }
1569
1570 async fn get_own_device(&self) -> Result<DeviceData> {
1571 let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1572
1573 Ok(self
1574 .get_device(&account_info.user_id, &account_info.device_id)
1575 .await?
1576 .expect("We should be able to find our own device."))
1577 }
1578
1579 async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1580 let user_id = self.encode_key("identity", user_id.as_bytes());
1581 Ok(self
1582 .read()
1583 .await?
1584 .get_user_identity(user_id)
1585 .await?
1586 .map(|value| self.deserialize_value(&value))
1587 .transpose()?)
1588 }
1589
1590 async fn is_message_known(
1591 &self,
1592 message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1593 ) -> Result<bool> {
1594 let value = rmp_serde::to_vec(message_hash)?;
1595 Ok(self.read().await?.has_olm_hash(value).await?)
1596 }
1597
1598 async fn get_outgoing_secret_requests(
1599 &self,
1600 request_id: &TransactionId,
1601 ) -> Result<Option<GossipRequest>> {
1602 let request_id = self.encode_key("key_requests", request_id.as_bytes());
1603 Ok(self
1604 .read()
1605 .await?
1606 .get_outgoing_secret_request(request_id)
1607 .await?
1608 .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1609 .transpose()?)
1610 }
1611
1612 async fn get_secret_request_by_info(
1613 &self,
1614 key_info: &SecretInfo,
1615 ) -> Result<Option<GossipRequest>> {
1616 let requests = self.read().await?.get_outgoing_secret_requests().await?;
1617 for (request, sent_out) in requests {
1618 let request = self.deserialize_key_request(&request, sent_out)?;
1619 if request.info == *key_info {
1620 return Ok(Some(request));
1621 }
1622 }
1623 Ok(None)
1624 }
1625
1626 async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1627 self.read()
1628 .await?
1629 .get_unsent_secret_requests()
1630 .await?
1631 .iter()
1632 .map(|value| {
1633 let request = self.deserialize_key_request(value, false)?;
1634 Ok(request)
1635 })
1636 .collect()
1637 }
1638
1639 async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1640 let request_id = self.encode_key("key_requests", request_id.as_bytes());
1641 Ok(self.write().await?.delete_key_request(request_id).await?)
1642 }
1643
1644 async fn get_secrets_from_inbox(
1645 &self,
1646 secret_name: &SecretName,
1647 ) -> Result<Vec<Zeroizing<String>>> {
1648 let secret_name = self.encode_key("secrets_inbox", secret_name.to_string());
1649
1650 self.read()
1651 .await?
1652 .get_secrets_from_inbox(secret_name)
1653 .await?
1654 .into_iter()
1655 .map(|value| self.deserialize_json(value.as_ref()).map(|value: String| value.into()))
1656 .collect()
1657 }
1658
1659 async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1660 let secret_name = self.encode_key("secrets_inbox", secret_name.to_string());
1661 self.write().await?.delete_secrets_from_inbox(secret_name).await
1662 }
1663
1664 async fn get_withheld_info(
1665 &self,
1666 room_id: &RoomId,
1667 session_id: &str,
1668 ) -> Result<Option<RoomKeyWithheldEntry>> {
1669 let room_id = self.encode_key("direct_withheld_info", room_id);
1670 let session_id = self.encode_key("direct_withheld_info", session_id);
1671
1672 self.read()
1673 .await?
1674 .get_direct_withheld_info(session_id, room_id)
1675 .await?
1676 .map(|value| {
1677 let info = self.deserialize_json::<RoomKeyWithheldEntry>(&value)?;
1678 Ok(info)
1679 })
1680 .transpose()
1681 }
1682
1683 async fn get_withheld_sessions_by_room_id(
1684 &self,
1685 room_id: &RoomId,
1686 ) -> matrix_sdk_crypto::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
1687 let room_id = self.encode_key("direct_withheld_info", room_id);
1688
1689 self.read()
1690 .await?
1691 .get_withheld_sessions_by_room_id(room_id)
1692 .await?
1693 .into_iter()
1694 .map(|value| self.deserialize_json(&value))
1695 .collect()
1696 }
1697
1698 async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1699 let room_id = self.encode_key("room_settings", room_id.as_bytes());
1700 let Some(value) = self.read().await?.get_room_settings(room_id).await? else {
1701 return Ok(None);
1702 };
1703
1704 let settings = self.deserialize_value(&value)?;
1705
1706 return Ok(Some(settings));
1707 }
1708
1709 async fn get_received_room_key_bundle_data(
1710 &self,
1711 room_id: &RoomId,
1712 user_id: &UserId,
1713 ) -> Result<Option<StoredRoomKeyBundleData>> {
1714 let room_id = self.encode_key("received_room_key_bundle", room_id);
1715 let user_id = self.encode_key("received_room_key_bundle", user_id);
1716 self.read()
1717 .await?
1718 .get_received_room_key_bundle(room_id, user_id)
1719 .await?
1720 .map(|value| self.deserialize_value(&value))
1721 .transpose()
1722 }
1723
1724 async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
1725 let room_id = self.encode_key("room_key_backups_fully_downloaded", room_id);
1726 self.read().await?.has_downloaded_all_room_keys(room_id).await
1727 }
1728
1729 async fn get_pending_key_bundle_details_for_room(
1730 &self,
1731 room_id: &RoomId,
1732 ) -> Result<Option<RoomPendingKeyBundleDetails>> {
1733 let room_id = self.encode_key("rooms_pending_key_bundle", room_id.as_bytes());
1734 let Some(value) = self.read().await?.get_room_pending_key_bundle(room_id).await? else {
1735 return Ok(None);
1736 };
1737
1738 let details = self.deserialize_value(&value)?;
1739 Ok(Some(details))
1740 }
1741
1742 async fn get_all_rooms_pending_key_bundles(&self) -> Result<Vec<RoomPendingKeyBundleDetails>> {
1743 let details = self.read().await?.get_all_rooms_pending_key_bundle().await?;
1744 let room_ids = details
1745 .into_iter()
1746 .map(|value| self.deserialize_value(&value))
1747 .collect::<Result<_, _>>()?;
1748 Ok(room_ids)
1749 }
1750
1751 async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1752 let Some(serialized) = self.read().await?.get_kv(key).await? else {
1753 return Ok(None);
1754 };
1755 let value = if let Some(cipher) = &self.store_cipher {
1756 let encrypted = rmp_serde::from_slice(&serialized)?;
1757 cipher.decrypt_value_data(encrypted)?
1758 } else {
1759 serialized
1760 };
1761
1762 Ok(Some(value))
1763 }
1764
1765 async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1766 let serialized = if let Some(cipher) = &self.store_cipher {
1767 let encrypted = cipher.encrypt_value_data(value)?;
1768 rmp_serde::to_vec_named(&encrypted)?
1769 } else {
1770 value
1771 };
1772
1773 self.write().await?.set_kv(key, serialized).await?;
1774 Ok(())
1775 }
1776
1777 async fn remove_custom_value(&self, key: &str) -> Result<()> {
1778 let key = key.to_owned();
1779 self.write()
1780 .await?
1781 .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1782 .await
1783 .unwrap()?;
1784 Ok(())
1785 }
1786
    /// Try to acquire, or renew, the cross-process lease lock named `key` for `holder`.
    ///
    /// The lease is (re)written with an expiration of `now + lease_duration_ms` when:
    /// - no row exists yet for `key`; the inserted row's initial generation comes from the
    ///   table's column default (not visible here, so confirm in the schema); or
    /// - `holder` already owns the lock; or
    /// - the current lease has expired (`expiration < now`).
    ///
    /// On success the row's generation is returned. It stays the same when the same holder
    /// renews, and is incremented when ownership changes hands. When another holder still owns
    /// an unexpired lease, the conditional `DO UPDATE` does not apply, `RETURNING` yields no
    /// row, and the result is `Ok(None)`.
    #[instrument(skip(self))]
    async fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<Option<CrossProcessLockGeneration>> {
        // Owned copies are needed because they are moved into the transaction closure.
        let key = key.to_owned();
        let holder = holder.to_owned();

        let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
        let expiration = now + lease_duration_ms as u64;

        let generation = self
            .write()
            .await?
            .with_transaction(move |txn| {
                // Upsert; `?4` (now) lets an expired lease be taken over by a new holder.
                txn.query_row(
                    "INSERT INTO lease_locks (key, holder, expiration)
                    VALUES (?1, ?2, ?3)
                    ON CONFLICT (key)
                    DO
                        UPDATE SET
                            holder = excluded.holder,
                            expiration = excluded.expiration,
                            generation =
                                CASE holder
                                    WHEN excluded.holder THEN generation
                                    ELSE generation + 1
                                END
                        WHERE
                            holder = excluded.holder
                            OR expiration < ?4
                    RETURNING generation
                    ",
                    (key, holder, expiration, now),
                    |row| row.get(0),
                )
                // No returned row means the lock is held by someone else: map to `None`.
                .optional()
            })
            .await?;

        Ok(generation)
    }
1832
1833 async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1834 let conn = self.read().await?;
1835 if let Some(token) = conn.get_kv("next_batch_token").await? {
1836 let maybe_token: Option<String> = self.deserialize_value(&token)?;
1837 Ok(maybe_token)
1838 } else {
1839 Ok(None)
1840 }
1841 }
1842
1843 async fn close(&self) -> Result<()> {
1844 connection::close_connections(&self.connections, "Crypto store").await;
1845 Ok(())
1846 }
1847
1848 async fn reopen(&self) -> Result<()> {
1849 connection::reopen_connections(
1850 &self.connections,
1851 self.db_path.clone(),
1852 self.pool_config,
1853 self.runtime_config,
1854 )
1855 .await?;
1856 Ok(())
1857 }
1858
1859 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1860 Ok(Some(self.read().await?.get_db_size().await?))
1861 }
1862}
1863
1864#[cfg(test)]
1865mod tests {
1866 use std::{path::Path, sync::LazyLock};
1867
1868 use matrix_sdk_common::deserialized_responses::WithheldCode;
1869 use matrix_sdk_crypto::{
1870 cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1871 store::CryptoStore,
1872 };
1873 use matrix_sdk_test::async_test;
1874 use ruma::{device_id, room_id, user_id};
1875 use similar_asserts::assert_eq;
1876 use tempfile::{TempDir, tempdir};
1877 use tokio::fs;
1878
1879 use super::SqliteCryptoStore;
1880 use crate::SqliteStoreConfig;
1881
    /// Lazily created temporary directory shared by tests that only need a scratch location.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// An opened test store together with the temporary directory that holds its database copy.
    struct TestDb {
        // Held only to keep the temporary directory alive as long as the store is in use.
        _dir: TempDir,
        database: SqliteCryptoStore,
    }
1890
1891 fn copy_db(data_path: &str) -> TempDir {
1892 let db_name = super::DATABASE_NAME;
1893
1894 let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1895 let database_path = manifest_path.join(data_path).join(db_name);
1896
1897 let tmpdir = tempdir().unwrap();
1898 let destination = tmpdir.path().join(db_name);
1899
1900 std::fs::copy(&database_path, destination).unwrap();
1902
1903 tmpdir
1904 }
1905
1906 async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1907 let tmpdir = copy_db(data_path);
1908
1909 let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1910 .await
1911 .expect("Can't open the test store");
1912
1913 TestDb { _dir: tmpdir, database }
1914 }
1915
1916 #[async_test]
1917 async fn test_pool_size() {
1918 let store_open_config =
1919 SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);
1920
1921 let store = SqliteCryptoStore::open_with_config(&store_open_config).await.unwrap();
1922
1923 let guard = store.connections.lock().await;
1924 let conns = guard.as_ref().unwrap();
1925 assert_eq!(conns.pool.status().max_size, 42);
1926 }
1927
1928 #[async_test]
1931 async fn test_open_test_vector_store() {
1932 let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;
1933
1934 let account = database
1935 .load_account()
1936 .await
1937 .unwrap()
1938 .expect("The test database is prefilled with data, we should find an account");
1939
1940 let user_id = account.user_id();
1941 let device_id = account.device_id();
1942
1943 assert_eq!(
1944 user_id.as_str(),
1945 "@pjtest:synapse-oidc.element.dev",
1946 "The user ID should match to the one we expect."
1947 );
1948
1949 assert_eq!(
1950 device_id.as_str(),
1951 "v4TqgcuIH6",
1952 "The device ID should match to the one we expect."
1953 );
1954
1955 let device = database
1956 .get_device(user_id, device_id)
1957 .await
1958 .unwrap()
1959 .expect("Our own device should be found in the store.");
1960
1961 assert_eq!(device.device_id(), device_id);
1962 assert_eq!(device.user_id(), user_id);
1963
1964 assert_eq!(
1965 device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
1966 "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
1967 );
1968
1969 assert_eq!(
1970 device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
1971 "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
1972 );
1973
1974 let identity = database
1975 .get_user_identity(user_id)
1976 .await
1977 .unwrap()
1978 .expect("The store should contain an identity.");
1979
1980 assert_eq!(identity.user_id(), user_id);
1981
1982 let identity = identity
1983 .own()
1984 .expect("The identity should be of the correct type, it should be our own identity.");
1985
1986 let master_key = identity
1987 .master_key()
1988 .get_first_key()
1989 .expect("Our own identity should have a master key");
1990
1991 assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
1992 }
1993
1994 #[async_test]
1997 async fn test_open_test_vector_encrypted_store() {
1998 let TestDb { _dir: _, database } = get_test_db(
1999 "testing/data/storage/alice",
2000 Some(concat!(
2001 "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
2002 "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
2003 "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
2004 "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
2005 "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
2006 )),
2007 )
2008 .await;
2009
2010 let account = database
2011 .load_account()
2012 .await
2013 .unwrap()
2014 .expect("The test database is prefilled with data, we should find an account");
2015
2016 let user_id = account.user_id();
2017 let device_id = account.device_id();
2018
2019 assert_eq!(
2020 user_id.as_str(),
2021 "@alice:localhost",
2022 "The user ID should match to the one we expect."
2023 );
2024
2025 assert_eq!(
2026 device_id.as_str(),
2027 "JVVORTHFXY",
2028 "The device ID should match to the one we expect."
2029 );
2030
2031 let tracked_users =
2032 database.load_tracked_users().await.expect("Should be tracking some users");
2033
2034 assert_eq!(tracked_users.len(), 6);
2035
2036 let known_users = vec![
2037 user_id!("@alice:localhost"),
2038 user_id!("@dehydration3:localhost"),
2039 user_id!("@eve:localhost"),
2040 user_id!("@bob:localhost"),
2041 user_id!("@malo:localhost"),
2042 user_id!("@carl:localhost"),
2043 ];
2044
2045 for user_id in known_users {
2047 database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
2048 }
2049
2050 let carl_identity =
2051 database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
2052
2053 assert_eq!(
2054 carl_identity.master_key().get_first_key().unwrap().to_base64(),
2055 "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
2056 );
2057 assert!(!carl_identity.was_previously_verified());
2058
2059 let bob_identity =
2060 database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
2061
2062 assert_eq!(
2063 bob_identity.master_key().get_first_key().unwrap().to_base64(),
2064 "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
2065 );
2066 assert!(bob_identity.was_previously_verified());
2068
2069 let known_devices = vec![
2070 (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
2071 (
2073 device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
2074 user_id!("@alice:localhost"),
2075 ),
2076 (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
2077 (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
2078 (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
2079 (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
2080 (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
2081 (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
2082 (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
2083 (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
2084 (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
2085 (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
2086 (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
2087 (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
2088 ];
2089
2090 for (device_id, user_id) in known_devices {
2091 database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
2092 }
2093
2094 let known_sender_key_to_session_count = vec![
2095 ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
2096 ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
2097 ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
2098 ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
2099 ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
2100 ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
2101 ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
2102 ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
2103 ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
2104 ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
2105 ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
2106 ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
2107 ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
2108 ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
2109 ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
2110 ];
2111
2112 for (id, count) in known_sender_key_to_session_count {
2113 let olm_sessions =
2114 database.get_sessions(id).await.expect("Should have some olm sessions");
2115
2116 println!("### Session id: {id:?}");
2117 assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
2118 }
2119
2120 let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
2121 assert_eq!(inbound_group_sessions.len(), 15);
2122 let known_inbound_group_sessions = vec![
2123 (
2124 "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
2125 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2126 ),
2127 (
2128 "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
2129 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2130 ),
2131 (
2132 "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
2133 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2134 ),
2135 (
2136 "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
2137 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2138 ),
2139 (
2140 "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
2141 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2142 ),
2143 (
2144 "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
2145 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2146 ),
2147 (
2148 "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
2149 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2150 ),
2151 (
2152 "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
2153 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2154 ),
2155 (
2156 "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
2157 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2158 ),
2159 (
2160 "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
2161 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2162 ),
2163 (
2164 "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
2165 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2166 ),
2167 (
2168 "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
2169 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2170 ),
2171 (
2172 "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
2173 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2174 ),
2175 (
2176 "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
2177 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2178 ),
2179 (
2180 "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
2181 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2182 ),
2183 ];
2184
2185 for (session_id, room_id) in &known_inbound_group_sessions {
2187 database
2188 .get_inbound_group_session(room_id, session_id)
2189 .await
2190 .expect("Should be able to load inbound group session")
2191 .unwrap();
2192 }
2193
2194 let bob_sender_verified = database
2195 .get_inbound_group_session(
2196 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
2197 "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
2198 )
2199 .await
2200 .unwrap()
2201 .unwrap();
2202
2203 assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
2204 assert!(bob_sender_verified.backed_up());
2205 assert!(!bob_sender_verified.has_been_imported());
2206
2207 let alice_unknown_device = database
2208 .get_inbound_group_session(
2209 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
2210 "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
2211 )
2212 .await
2213 .unwrap()
2214 .unwrap();
2215
2216 assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
2217 assert!(alice_unknown_device.backed_up());
2218 assert!(alice_unknown_device.has_been_imported());
2219
2220 let carl_tofu_session = database
2221 .get_inbound_group_session(
2222 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2223 "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
2224 )
2225 .await
2226 .unwrap()
2227 .unwrap();
2228
2229 assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
2230 assert!(carl_tofu_session.backed_up());
2231 assert!(!carl_tofu_session.has_been_imported());
2232
2233 database
2235 .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
2236 .await
2237 .unwrap()
2238 .unwrap();
2239 database
2240 .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
2241 .await
2242 .unwrap()
2243 .unwrap();
2244 database
2245 .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
2246 .await
2247 .unwrap()
2248 .unwrap();
2249
2250 let withheld_info = database
2251 .get_withheld_info(
2252 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
2253 "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
2254 )
2255 .await
2256 .expect("This session should be withheld")
2257 .unwrap();
2258
2259 assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
2260
2261 let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
2262 assert_eq!(backup_keys.backup_version.unwrap(), "6");
2263 assert!(backup_keys.decryption_key.is_some());
2264 }
2265
2266 #[async_test]
2272 async fn test_secrets_inbox_migration() {
2273 use std::ops::Deref;
2274
2275 use matrix_sdk_crypto::{
2276 GossipRequest, GossippedSecret, SecretInfo,
2277 types::events::{
2278 olm_v1::{DecryptedSecretSendEvent, OlmV1Keys},
2279 secret_send::SecretSendContent,
2280 },
2281 vodozemac::Ed25519SecretKey,
2282 };
2283 use ruma::{TransactionId, events::secret::request::SecretName, owned_user_id};
2284
2285 use crate::utils::{EncryptableStore, SqliteAsyncConnExt};
2286
2287 let tmpdir = tempdir().unwrap();
2289 let config = SqliteStoreConfig::new(tmpdir.path());
2290 let pool = config.build_pool_of_connections(super::DATABASE_NAME).unwrap();
2291 let conn = pool.get().await.unwrap();
2292 let version = super::initialize_store(&conn, 0).await.unwrap();
2293 let old_data_store = SqliteCryptoStore::create_raw(
2294 config.secret.clone(),
2295 pool,
2296 conn,
2297 config.pool_config(),
2298 config.runtime_config(),
2299 )
2300 .await
2301 .unwrap();
2302 super::run_migrations(&old_data_store, version, Some(16)).await.unwrap();
2303 old_data_store.write().await.unwrap().wal_checkpoint().await;
2304
2305 let secret = GossippedSecret {
2307 secret_name: SecretName::CrossSigningMasterKey,
2308 gossip_request: GossipRequest {
2309 request_recipient: owned_user_id!("@alice:example.com"),
2310 request_id: TransactionId::new(),
2311 info: SecretInfo::SecretRequest(SecretName::CrossSigningMasterKey),
2312 sent_out: true,
2313 },
2314 event: DecryptedSecretSendEvent {
2315 sender: owned_user_id!("@alice:example.com"),
2316 recipient: owned_user_id!("@alice:example.com"),
2317 keys: OlmV1Keys { ed25519: Ed25519SecretKey::new().public_key() },
2318 recipient_keys: OlmV1Keys { ed25519: Ed25519SecretKey::new().public_key() },
2319 sender_device_keys: None,
2320 content: SecretSendContent::new(
2321 "abc".into(),
2322 "It is a secret to everybody".to_owned(),
2323 ),
2324 },
2325 };
2326 let value = old_data_store.serialize_json(&secret).unwrap();
2327 old_data_store
2328 .write()
2329 .await
2330 .unwrap()
2331 .prepare("INSERT INTO secrets (secret_name, data) VALUES (?1, ?2)", |mut stmt| {
2332 stmt.execute((SecretName::CrossSigningMasterKey.to_string(), value))
2333 })
2334 .await
2335 .unwrap();
2336
2337 let store = SqliteCryptoStore::open_with_config(&config).await.unwrap();
2339
2340 let secrets =
2342 store.get_secrets_from_inbox(&SecretName::CrossSigningMasterKey).await.unwrap();
2343 assert_eq!(secrets.len(), 1);
2344 assert_eq!(secrets[0].deref(), "It is a secret to everybody");
2345 }
2346
2347 async fn get_store(
2348 name: &str,
2349 passphrase: Option<&str>,
2350 clear_data: bool,
2351 ) -> SqliteCryptoStore {
2352 let tmpdir_path = TMP_DIR.path().join(name);
2353
2354 if clear_data {
2355 let _ = fs::remove_dir_all(&tmpdir_path).await;
2356 }
2357
2358 SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
2359 .await
2360 .expect("Can't create a secret protected store")
2361 }
2362
    // Expand the shared `CryptoStore` test suites from `matrix_sdk_crypto` for
    // this backend. NOTE(review): the macros appear to rely on a local
    // `get_store(name, passphrase, clear_data)` helper being in scope; confirm
    // against the macro definitions.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
2365}
2366
#[cfg(test)]
mod encrypted_tests {
    //! Runs the shared crypto-store integration suites against a store that is
    //! always opened with a passphrase, so every value goes through the store
    //! cipher.

    use std::sync::LazyLock;

    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use tempfile::{TempDir, tempdir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Temporary directory shared by every test in this module; each test uses
    /// its own sub-directory.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// Opens (or creates) the passphrase-protected store in the `name`
    /// sub-directory of [`TMP_DIR`].
    ///
    /// If `passphrase` is `None`, the fixed `"default_test_password"` is used,
    /// so the store is never unencrypted. When `clear_data` is `true`, the
    /// sub-directory is removed first. Removal is best-effort and its error is
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if the store cannot be opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let tmpdir_path = TMP_DIR.path().join(name);
        let pass = passphrase.unwrap_or("default_test_password");

        if clear_data {
            let _ = fs::remove_dir_all(&tmpdir_path).await;
        }

        // Pass the `PathBuf` directly instead of `to_str().unwrap()`, which
        // would panic on a non-UTF-8 temporary directory path.
        SqliteCryptoStore::open(tmpdir_path, Some(pass))
            .await
            .expect("Can't create a secret protected store")
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}
2399
#[cfg(test)]
mod close_reopen_tests {
    //! Tests for the `close()` / `reopen()` lifecycle of `SqliteCryptoStore`.
    //! They cover three properties:
    //! - closing releases the connection pool (`connections` becomes `None`);
    //! - reopening restores it;
    //! - both operations can be repeated safely.

    use std::{
        sync::LazyLock,
        time::{Duration, Instant},
    };

    use matrix_sdk_crypto::store::CryptoStore;
    use matrix_sdk_test::async_test;
    use tempfile::{TempDir, tempdir};

    use super::SqliteCryptoStore;

    /// Temporary directory shared by every test in this module; each test uses
    /// its own uniquely named sub-directory.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());

    /// Opens an unencrypted store in the `name` sub-directory of [`TMP_DIR`].
    async fn new_store(name: &str) -> SqliteCryptoStore {
        let tmpdir_path = TMP_DIR.path().join(name);
        SqliteCryptoStore::open(tmpdir_path, None).await.unwrap()
    }

    /// An idle store closes promptly and drops its connections.
    #[async_test]
    async fn test_close_completes_without_timeout() {
        let store = new_store("close_no_timeout").await;

        let start = Instant::now();
        store.close().await.unwrap();
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_secs(2),
            "close() took {elapsed:?}, expected < 2s (no timeout)"
        );

        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }

    /// `reopen()` after `close()` puts the connections back.
    #[async_test]
    async fn test_reopen_restores_connections() {
        let store = new_store("reopen_restores").await;

        store.close().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_none());
        }

        store.reopen().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_some(), "connections should be Some after reopen");
        }
    }

    /// Closing an already-closed store succeeds and leaves it closed.
    #[async_test]
    async fn test_close_is_idempotent() {
        let store = new_store("close_idempotent").await;

        store.close().await.unwrap();
        store.close().await.unwrap();

        let guard = store.connections.lock().await;
        assert!(guard.is_none());
    }

    /// Reopening an already-open store, more than once, succeeds and leaves it
    /// open and usable.
    #[async_test]
    async fn test_reopen_is_idempotent() {
        let store = new_store("reopen_idempotent").await;

        // A single call cannot demonstrate idempotence; repeat it.
        store.reopen().await.unwrap();
        store.reopen().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_some());
        }

        store.load_account().await.expect("store should still work after repeated reopen");
    }

    /// Reads on a closed store fail with an error that mentions "closed".
    #[async_test]
    async fn test_read_fails_when_closed() {
        let store = new_store("read_fails_closed").await;
        store.close().await.unwrap();

        let Err(err) = store.load_account().await else {
            panic!("read should fail when closed");
        };

        let err_msg = err.to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    /// A close/reopen round-trip leaves the (empty) store readable.
    #[async_test]
    async fn test_operations_work_after_reopen() {
        let store = new_store("ops_after_reopen").await;

        store.close().await.unwrap();
        store.reopen().await.unwrap();

        let account =
            store.load_account().await.expect("load_account should succeed after reopen");
        assert!(account.is_none());
    }

    /// Repeated close/reopen cycles keep the store usable.
    #[async_test]
    async fn test_multiple_close_reopen_cycles() {
        let store = new_store("multi_cycles").await;

        for _ in 0..5 {
            store.close().await.unwrap();
            store.reopen().await.unwrap();

            store.load_account().await.expect("store should work after close/reopen cycle");
        }
    }

    /// Connections checked out and returned before `close()` are all released.
    #[async_test]
    async fn test_pool_is_fully_drained_after_close() {
        let store = new_store("pool_drained").await;

        // Exercise the pool first. Errors are not ignored, so a failing
        // warm-up cannot make this test pass vacuously.
        store.load_account().await.expect("first warm-up read should succeed");
        store.load_account().await.expect("second warm-up read should succeed");

        store.close().await.unwrap();

        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "all connections should be released after close");
    }

    /// `close()` blocks while a read connection is checked out and completes
    /// once it is returned.
    #[async_test]
    async fn test_close_waits_for_held_read_connection_to_drain() {
        let store = new_store("held_read_drain").await;

        let held_conn = store.read().await.unwrap();

        let store_clone = store.clone();
        let close_handle = tokio::spawn(async move {
            store_clone.close().await.unwrap();
        });

        // Give the close task time to start and block on the held connection.
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert!(!close_handle.is_finished(), "close should be waiting for the held connection");

        drop(held_conn);

        tokio::time::timeout(Duration::from_secs(3), close_handle)
            .await
            .expect("close should complete after the held connection is released")
            .expect("close task should not panic");

        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }
}