1use std::{
16 borrow::Cow,
17 collections::HashMap,
18 fmt,
19 path::Path,
20 sync::{Arc, RwLock},
21};
22
23use async_trait::async_trait;
24use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
25use matrix_sdk_crypto::{
26 olm::{
27 InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
28 PrivateCrossSigningIdentity, SenderDataType, Session, StaticAccountData,
29 },
30 store::{
31 BackupKeys, Changes, CryptoStore, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
32 RoomSettings,
33 },
34 types::events::room_key_withheld::RoomKeyWithheldEvent,
35 Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
36};
37use matrix_sdk_store_encryption::StoreCipher;
38use ruma::{
39 events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
40 RoomId, TransactionId, UserId,
41};
42use rusqlite::{named_params, params_from_iter, OptionalExtension};
43use serde::{de::DeserializeOwned, Serialize};
44use tokio::{fs, sync::Mutex};
45use tracing::{debug, instrument, warn};
46use vodozemac::Curve25519PublicKey;
47
48use crate::{
49 error::{Error, Result},
50 utils::{
51 repeat_vars, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
52 SqliteKeyValueStoreConnExt,
53 },
54 OpenStoreError, SqliteStoreConfig,
55};
56
/// File name of the SQLite database; it is joined onto the directory given as
/// the store path when the store is opened.
const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
59
/// A crypto store backed by an SQLite database.
#[derive(Clone)]
pub struct SqliteCryptoStore {
    /// Cipher used to encrypt stored values and hash table keys. `None` when
    /// the store was opened without a passphrase, in which case values and
    /// keys are stored as-is.
    store_cipher: Option<Arc<StoreCipher>>,
    /// Pool of connections to the SQLite database.
    pool: SqlitePool,

    /// Static account data, cached whenever an account is loaded or saved.
    static_account: Arc<RwLock<Option<StaticAccountData>>>,
    /// Lock held for the whole duration of `save_changes` and
    /// `save_pending_changes`.
    save_changes_lock: Arc<Mutex<()>>,
}
70
71#[cfg(not(tarpaulin_include))]
72impl fmt::Debug for SqliteCryptoStore {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 f.debug_struct("SqliteCryptoStore").finish_non_exhaustive()
75 }
76}
77
78impl SqliteCryptoStore {
79 pub async fn open(
82 path: impl AsRef<Path>,
83 passphrase: Option<&str>,
84 ) -> Result<Self, OpenStoreError> {
85 Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await
86 }
87
88 pub async fn open_with_config(config: SqliteStoreConfig) -> Result<Self, OpenStoreError> {
90 let SqliteStoreConfig { path, passphrase, pool_config, runtime_config } = config;
91
92 fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?;
93
94 let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME));
95 config.pool = Some(pool_config);
96
97 let pool = config.create_pool(Runtime::Tokio1)?;
98
99 let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
100 this.pool.get().await?.apply_runtime_config(runtime_config).await?;
101
102 Ok(this)
103 }
104
105 async fn open_with_pool(
108 pool: SqlitePool,
109 passphrase: Option<&str>,
110 ) -> Result<Self, OpenStoreError> {
111 let conn = pool.get().await?;
112
113 let version = conn.db_version().await?;
114 run_migrations(&conn, version).await?;
115
116 let store_cipher = match passphrase {
117 Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
118 None => None,
119 };
120
121 Ok(SqliteCryptoStore {
122 store_cipher,
123 pool,
124 static_account: Arc::new(RwLock::new(None)),
125 save_changes_lock: Default::default(),
126 })
127 }
128
129 fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
130 if let Some(key) = &self.store_cipher {
131 let encrypted = key.encrypt_value_data(value)?;
132 Ok(rmp_serde::to_vec_named(&encrypted)?)
133 } else {
134 Ok(value)
135 }
136 }
137
138 fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
139 if let Some(key) = &self.store_cipher {
140 let encrypted = rmp_serde::from_slice(value)?;
141 let decrypted = key.decrypt_value_data(encrypted)?;
142 Ok(Cow::Owned(decrypted))
143 } else {
144 Ok(Cow::Borrowed(value))
145 }
146 }
147
148 fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
149 let serialized = serde_json::to_vec(value)?;
150 self.encode_value(serialized)
151 }
152
153 fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
154 let decoded = self.decode_value(data)?;
155 Ok(serde_json::from_slice(&decoded)?)
156 }
157
158 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
159 let serialized = rmp_serde::to_vec_named(value)?;
160 self.encode_value(serialized)
161 }
162
163 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
164 let decoded = self.decode_value(value)?;
165 Ok(rmp_serde::from_slice(&decoded)?)
166 }
167
168 fn deserialize_and_unpickle_inbound_group_session(
169 &self,
170 value: Vec<u8>,
171 backed_up: bool,
172 ) -> Result<InboundGroupSession> {
173 let mut pickle: PickledInboundGroupSession = self.deserialize_value(&value)?;
174
175 pickle.backed_up = backed_up;
180
181 Ok(InboundGroupSession::from_pickle(pickle)?)
182 }
183
184 fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result<GossipRequest> {
185 let mut request: GossipRequest = self.deserialize_value(value)?;
186 request.sent_out = sent_out;
189 Ok(request)
190 }
191
192 fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
193 let bytes = key.as_ref();
194 if let Some(store_cipher) = &self.store_cipher {
195 Key::Hashed(store_cipher.hash_key(table_name, bytes))
196 } else {
197 Key::Plain(bytes.to_owned())
198 }
199 }
200
201 fn get_static_account(&self) -> Option<StaticAccountData> {
202 self.static_account.read().unwrap().clone()
203 }
204
205 async fn acquire(&self) -> Result<SqliteAsyncConn> {
206 Ok(self.pool.get().await?)
207 }
208}
209
/// Schema version that `run_migrations` upgrades the database to.
const DATABASE_VERSION: u8 = 9;

/// Key under which the dehydrated device pickle key is kept in the key/value
/// storage (see the `set_kv` / `get_kv` / `clear_kv` calls).
const DEHYDRATED_DEVICE_PICKLE_KEY: &str = "dehydrated_device_pickle_key";
214
215async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
217 if version == 0 {
218 debug!("Creating database");
219 } else if version < DATABASE_VERSION {
220 debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
221 } else {
222 return Ok(());
223 }
224
225 if version < 1 {
226 conn.execute_batch("PRAGMA journal_mode = wal;").await?;
229 conn.with_transaction(|txn| {
230 txn.execute_batch(include_str!("../migrations/crypto_store/001_init.sql"))?;
231 txn.set_db_version(1)
232 })
233 .await?;
234 }
235
236 if version < 2 {
237 conn.with_transaction(|txn| {
238 txn.execute_batch(include_str!("../migrations/crypto_store/002_reset_olm_hash.sql"))?;
239 txn.set_db_version(2)
240 })
241 .await?;
242 }
243
244 if version < 3 {
245 conn.with_transaction(|txn| {
246 txn.execute_batch(include_str!("../migrations/crypto_store/003_room_settings.sql"))?;
247 txn.set_db_version(3)
248 })
249 .await?;
250 }
251
252 if version < 4 {
253 conn.with_transaction(|txn| {
254 txn.execute_batch(include_str!(
255 "../migrations/crypto_store/004_drop_outbound_group_sessions.sql"
256 ))?;
257 txn.set_db_version(4)
258 })
259 .await?;
260 }
261
262 if version < 5 {
263 conn.with_transaction(|txn| {
264 txn.execute_batch(include_str!("../migrations/crypto_store/005_withheld_code.sql"))?;
265 txn.set_db_version(5)
266 })
267 .await?;
268 }
269
270 if version < 6 {
271 conn.with_transaction(|txn| {
272 txn.execute_batch(include_str!(
273 "../migrations/crypto_store/006_drop_outbound_group_sessions.sql"
274 ))?;
275 txn.set_db_version(6)
276 })
277 .await?;
278 }
279
280 if version < 7 {
281 conn.with_transaction(|txn| {
282 txn.execute_batch(include_str!("../migrations/crypto_store/007_lock_leases.sql"))?;
283 txn.set_db_version(7)
284 })
285 .await?;
286 }
287
288 if version < 8 {
289 conn.with_transaction(|txn| {
290 txn.execute_batch(include_str!("../migrations/crypto_store/008_secret_inbox.sql"))?;
291 txn.set_db_version(8)
292 })
293 .await?;
294 }
295
296 if version < 9 {
297 conn.with_transaction(|txn| {
298 txn.execute_batch(include_str!(
299 "../migrations/crypto_store/009_inbound_group_session_sender_key_sender_data_type.sql"
300 ))?;
301 txn.set_db_version(9)
302 })
303 .await?;
304 }
305
306 Ok(())
307}
308
309trait SqliteConnectionExt {
310 fn set_session(
311 &self,
312 session_id: &[u8],
313 sender_key: &[u8],
314 data: &[u8],
315 ) -> rusqlite::Result<()>;
316
317 fn set_inbound_group_session(
318 &self,
319 room_id: &[u8],
320 session_id: &[u8],
321 data: &[u8],
322 backed_up: bool,
323 sender_key: Option<&[u8]>,
324 sender_data_type: Option<u8>,
325 ) -> rusqlite::Result<()>;
326
327 fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
328
329 fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
330 fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()>;
331
332 fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
333
334 fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()>;
335
336 fn set_key_request(
337 &self,
338 request_id: &[u8],
339 sent_out: bool,
340 data: &[u8],
341 ) -> rusqlite::Result<()>;
342
343 fn set_direct_withheld(
344 &self,
345 session_id: &[u8],
346 room_id: &[u8],
347 data: &[u8],
348 ) -> rusqlite::Result<()>;
349
350 fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
351
352 fn set_secret(&self, request_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
353}
354
355impl SqliteConnectionExt for rusqlite::Connection {
356 fn set_session(
357 &self,
358 session_id: &[u8],
359 sender_key: &[u8],
360 data: &[u8],
361 ) -> rusqlite::Result<()> {
362 self.execute(
363 "INSERT INTO session (session_id, sender_key, data)
364 VALUES (?1, ?2, ?3)
365 ON CONFLICT (session_id) DO UPDATE SET data = ?3",
366 (session_id, sender_key, data),
367 )?;
368 Ok(())
369 }
370
371 fn set_inbound_group_session(
372 &self,
373 room_id: &[u8],
374 session_id: &[u8],
375 data: &[u8],
376 backed_up: bool,
377 sender_key: Option<&[u8]>,
378 sender_data_type: Option<u8>,
379 ) -> rusqlite::Result<()> {
380 self.execute(
381 "INSERT INTO inbound_group_session (session_id, room_id, data, backed_up, sender_key, sender_data_type) \
382 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
383 ON CONFLICT (session_id) DO UPDATE SET data = ?3, backed_up = ?4, sender_key = ?5, sender_data_type = ?6",
384 (session_id, room_id, data, backed_up, sender_key, sender_data_type),
385 )?;
386 Ok(())
387 }
388
389 fn set_outbound_group_session(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
390 self.execute(
391 "INSERT INTO outbound_group_session (room_id, data) \
392 VALUES (?1, ?2)
393 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
394 (room_id, data),
395 )?;
396 Ok(())
397 }
398
399 fn set_device(&self, user_id: &[u8], device_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
400 self.execute(
401 "INSERT INTO device (user_id, device_id, data) \
402 VALUES (?1, ?2, ?3)
403 ON CONFLICT (user_id, device_id) DO UPDATE SET data = ?3",
404 (user_id, device_id, data),
405 )?;
406 Ok(())
407 }
408
409 fn delete_device(&self, user_id: &[u8], device_id: &[u8]) -> rusqlite::Result<()> {
410 self.execute(
411 "DELETE FROM device WHERE user_id = ? AND device_id = ?",
412 (user_id, device_id),
413 )?;
414 Ok(())
415 }
416
417 fn set_identity(&self, user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
418 self.execute(
419 "INSERT INTO identity (user_id, data) \
420 VALUES (?1, ?2)
421 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
422 (user_id, data),
423 )?;
424 Ok(())
425 }
426
427 fn add_olm_hash(&self, data: &[u8]) -> rusqlite::Result<()> {
428 self.execute("INSERT INTO olm_hash (data) VALUES (?) ON CONFLICT DO NOTHING", (data,))?;
429 Ok(())
430 }
431
432 fn set_key_request(
433 &self,
434 request_id: &[u8],
435 sent_out: bool,
436 data: &[u8],
437 ) -> rusqlite::Result<()> {
438 self.execute(
439 "INSERT INTO key_requests (request_id, sent_out, data)
440 VALUES (?1, ?2, ?3)
441 ON CONFLICT (request_id) DO UPDATE SET sent_out = ?2, data = ?3",
442 (request_id, sent_out, data),
443 )?;
444 Ok(())
445 }
446
447 fn set_direct_withheld(
448 &self,
449 session_id: &[u8],
450 room_id: &[u8],
451 data: &[u8],
452 ) -> rusqlite::Result<()> {
453 self.execute(
454 "INSERT INTO direct_withheld_info (session_id, room_id, data)
455 VALUES (?1, ?2, ?3)
456 ON CONFLICT (session_id) DO UPDATE SET room_id = ?2, data = ?3",
457 (session_id, room_id, data),
458 )?;
459 Ok(())
460 }
461
462 fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
463 self.execute(
464 "INSERT INTO room_settings (room_id, data)
465 VALUES (?1, ?2)
466 ON CONFLICT (room_id) DO UPDATE SET data = ?2",
467 (room_id, data),
468 )?;
469 Ok(())
470 }
471
472 fn set_secret(&self, secret_name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
473 self.execute(
474 "INSERT INTO secrets (secret_name, data)
475 VALUES (?1, ?2)",
476 (secret_name, data),
477 )?;
478
479 Ok(())
480 }
481}
482
483#[async_trait]
484trait SqliteObjectCryptoStoreExt: SqliteAsyncConnExt {
485 async fn get_sessions_for_sender_key(&self, sender_key: Key) -> Result<Vec<Vec<u8>>> {
486 Ok(self
487 .prepare("SELECT data FROM session WHERE sender_key = ?", |mut stmt| {
488 stmt.query((sender_key,))?.mapped(|row| row.get(0)).collect()
489 })
490 .await?)
491 }
492
493 async fn get_inbound_group_session(
494 &self,
495 session_id: Key,
496 ) -> Result<Option<(Vec<u8>, Vec<u8>, bool)>> {
497 Ok(self
498 .query_row(
499 "SELECT room_id, data, backed_up FROM inbound_group_session WHERE session_id = ?",
500 (session_id,),
501 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
502 )
503 .await
504 .optional()?)
505 }
506
507 async fn get_inbound_group_sessions(&self) -> Result<Vec<(Vec<u8>, bool)>> {
508 Ok(self
509 .prepare("SELECT data, backed_up FROM inbound_group_session", |mut stmt| {
510 stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
511 })
512 .await?)
513 }
514
515 async fn get_inbound_group_session_counts(
516 &self,
517 _backup_version: Option<&str>,
518 ) -> Result<RoomKeyCounts> {
519 let total = self
520 .query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
521 .await?;
522 let backed_up = self
523 .query_row(
524 "SELECT count(*) FROM inbound_group_session WHERE backed_up = TRUE",
525 (),
526 |row| row.get(0),
527 )
528 .await?;
529 Ok(RoomKeyCounts { total, backed_up })
530 }
531
532 async fn get_inbound_group_sessions_for_device_batch(
533 &self,
534 sender_key: Key,
535 sender_data_type: SenderDataType,
536 after_session_id: Option<Key>,
537 limit: usize,
538 ) -> Result<Vec<(Vec<u8>, bool)>> {
539 Ok(self
540 .prepare(
541 "
542 SELECT data, backed_up
543 FROM inbound_group_session
544 WHERE sender_key = :sender_key
545 AND sender_data_type = :sender_data_type
546 AND session_id > :after_session_id
547 ORDER BY session_id
548 LIMIT :limit
549 ",
550 move |mut stmt| {
551 let sender_data_type = sender_data_type as u8;
552
553 let after_session_id = after_session_id.unwrap_or(Key::Plain(Vec::new()));
556
557 stmt.query(named_params! {
558 ":sender_key": sender_key,
559 ":sender_data_type": sender_data_type,
560 ":after_session_id": after_session_id,
561 ":limit": limit,
562 })?
563 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
564 .collect()
565 },
566 )
567 .await?)
568 }
569
570 async fn get_inbound_group_sessions_for_backup(&self, limit: usize) -> Result<Vec<Vec<u8>>> {
571 Ok(self
572 .prepare(
573 "SELECT data FROM inbound_group_session WHERE backed_up = FALSE LIMIT ?",
574 move |mut stmt| stmt.query((limit,))?.mapped(|row| row.get(0)).collect(),
575 )
576 .await?)
577 }
578
579 async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
580 if session_ids.is_empty() {
581 warn!("No sessions to mark as backed up!");
583 return Ok(());
584 }
585
586 let session_ids_len = session_ids.len();
587
588 self.chunk_large_query_over(session_ids, None, move |txn, session_ids| {
589 let sql_params = repeat_vars(session_ids_len);
592 let query = format!("UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({sql_params})");
593 txn.prepare(&query)?.execute(params_from_iter(session_ids.iter()))?;
594 Ok(Vec::<()>::new())
595 }).await?;
596
597 Ok(())
598 }
599
600 async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
601 self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
602 Ok(())
603 }
604
605 async fn get_outbound_group_session(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
606 Ok(self
607 .query_row(
608 "SELECT data FROM outbound_group_session WHERE room_id = ?",
609 (room_id,),
610 |row| row.get(0),
611 )
612 .await
613 .optional()?)
614 }
615
616 async fn get_device(&self, user_id: Key, device_id: Key) -> Result<Option<Vec<u8>>> {
617 Ok(self
618 .query_row(
619 "SELECT data FROM device WHERE user_id = ? AND device_id = ?",
620 (user_id, device_id),
621 |row| row.get(0),
622 )
623 .await
624 .optional()?)
625 }
626
627 async fn get_user_devices(&self, user_id: Key) -> Result<Vec<Vec<u8>>> {
628 Ok(self
629 .prepare("SELECT data FROM device WHERE user_id = ?", |mut stmt| {
630 stmt.query((user_id,))?.mapped(|row| row.get(0)).collect()
631 })
632 .await?)
633 }
634
635 async fn get_user_identity(&self, user_id: Key) -> Result<Option<Vec<u8>>> {
636 Ok(self
637 .query_row("SELECT data FROM identity WHERE user_id = ?", (user_id,), |row| row.get(0))
638 .await
639 .optional()?)
640 }
641
642 async fn has_olm_hash(&self, data: Vec<u8>) -> Result<bool> {
643 Ok(self
644 .query_row("SELECT count(*) FROM olm_hash WHERE data = ?", (data,), |row| {
645 row.get::<_, i32>(0)
646 })
647 .await?
648 > 0)
649 }
650
651 async fn get_tracked_users(&self) -> Result<Vec<Vec<u8>>> {
652 Ok(self
653 .prepare("SELECT data FROM tracked_user", |mut stmt| {
654 stmt.query(())?.mapped(|row| row.get(0)).collect()
655 })
656 .await?)
657 }
658
659 async fn add_tracked_users(&self, users: Vec<(Key, Vec<u8>)>) -> Result<()> {
660 Ok(self
661 .prepare(
662 "INSERT INTO tracked_user (user_id, data) \
663 VALUES (?1, ?2) \
664 ON CONFLICT (user_id) DO UPDATE SET data = ?2",
665 |mut stmt| {
666 for (user_id, data) in users {
667 stmt.execute((user_id, data))?;
668 }
669
670 Ok(())
671 },
672 )
673 .await?)
674 }
675
676 async fn get_outgoing_secret_request(
677 &self,
678 request_id: Key,
679 ) -> Result<Option<(Vec<u8>, bool)>> {
680 Ok(self
681 .query_row(
682 "SELECT data, sent_out FROM key_requests WHERE request_id = ?",
683 (request_id,),
684 |row| Ok((row.get(0)?, row.get(1)?)),
685 )
686 .await
687 .optional()?)
688 }
689
690 async fn get_outgoing_secret_requests(&self) -> Result<Vec<(Vec<u8>, bool)>> {
691 Ok(self
692 .prepare("SELECT data, sent_out FROM key_requests", |mut stmt| {
693 stmt.query(())?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
694 })
695 .await?)
696 }
697
698 async fn get_unsent_secret_requests(&self) -> Result<Vec<Vec<u8>>> {
699 Ok(self
700 .prepare("SELECT data FROM key_requests WHERE sent_out = FALSE", |mut stmt| {
701 stmt.query(())?.mapped(|row| row.get(0)).collect()
702 })
703 .await?)
704 }
705
706 async fn delete_key_request(&self, request_id: Key) -> Result<()> {
707 self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?;
708 Ok(())
709 }
710
711 async fn get_secrets_from_inbox(&self, secret_name: Key) -> Result<Vec<Vec<u8>>> {
712 Ok(self
713 .prepare("SELECT data FROM secrets WHERE secret_name = ?", |mut stmt| {
714 stmt.query((secret_name,))?.mapped(|row| row.get(0)).collect()
715 })
716 .await?)
717 }
718
719 async fn delete_secrets_from_inbox(&self, secret_name: Key) -> Result<()> {
720 self.execute("DELETE FROM secrets WHERE secret_name = ?", (secret_name,)).await?;
721 Ok(())
722 }
723
724 async fn get_direct_withheld_info(
725 &self,
726 session_id: Key,
727 room_id: Key,
728 ) -> Result<Option<Vec<u8>>> {
729 Ok(self
730 .query_row(
731 "SELECT data FROM direct_withheld_info WHERE session_id = ?1 AND room_id = ?2",
732 (session_id, room_id),
733 |row| row.get(0),
734 )
735 .await
736 .optional()?)
737 }
738
739 async fn get_room_settings(&self, room_id: Key) -> Result<Option<Vec<u8>>> {
740 Ok(self
741 .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| {
742 row.get(0)
743 })
744 .await
745 .optional()?)
746 }
747}
748
// Every method of the trait has a default body, so the impl itself is empty.
#[async_trait]
impl SqliteObjectCryptoStoreExt for SqliteAsyncConn {}
751
752#[async_trait]
753impl CryptoStore for SqliteCryptoStore {
754 type Error = Error;
755
756 async fn load_account(&self) -> Result<Option<Account>> {
757 let conn = self.acquire().await?;
758 if let Some(pickle) = conn.get_kv("account").await? {
759 let pickle = self.deserialize_value(&pickle)?;
760
761 let account = Account::from_pickle(pickle).map_err(|_| Error::Unpickle)?;
762
763 *self.static_account.write().unwrap() = Some(account.static_data().clone());
764
765 Ok(Some(account))
766 } else {
767 Ok(None)
768 }
769 }
770
771 async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
772 let conn = self.acquire().await?;
773 if let Some(i) = conn.get_kv("identity").await? {
774 let pickle = self.deserialize_value(&i)?;
775 Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
776 } else {
777 Ok(None)
778 }
779 }
780
781 async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
782 let _guard = self.save_changes_lock.lock().await;
787
788 let pickled_account = if let Some(account) = changes.account {
789 *self.static_account.write().unwrap() = Some(account.static_data().clone());
790 Some(account.pickle())
791 } else {
792 None
793 };
794
795 let this = self.clone();
796 self.acquire()
797 .await?
798 .with_transaction(move |txn| {
799 if let Some(pickled_account) = pickled_account {
800 let serialized_account = this.serialize_value(&pickled_account)?;
801 txn.set_kv("account", &serialized_account)?;
802 }
803
804 Ok::<_, Error>(())
805 })
806 .await?;
807
808 Ok(())
809 }
810
811 async fn save_changes(&self, changes: Changes) -> Result<()> {
812 let _guard = self.save_changes_lock.lock().await;
817
818 let pickled_private_identity =
819 if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
820
821 let mut session_changes = Vec::new();
822
823 for session in changes.sessions {
824 let session_id = self.encode_key("session", session.session_id());
825 let sender_key = self.encode_key("session", session.sender_key().to_base64());
826 let pickle = session.pickle().await;
827 session_changes.push((session_id, sender_key, pickle));
828 }
829
830 let mut inbound_session_changes = Vec::new();
831 for session in changes.inbound_group_sessions {
832 let room_id = self.encode_key("inbound_group_session", session.room_id().as_bytes());
833 let session_id = self.encode_key("inbound_group_session", session.session_id());
834 let pickle = session.pickle().await;
835 let sender_key =
836 self.encode_key("inbound_group_session", session.sender_key().to_base64());
837 inbound_session_changes.push((room_id, session_id, pickle, sender_key));
838 }
839
840 let mut outbound_session_changes = Vec::new();
841 for session in changes.outbound_group_sessions {
842 let room_id = self.encode_key("outbound_group_session", session.room_id().as_bytes());
843 let pickle = session.pickle().await;
844 outbound_session_changes.push((room_id, pickle));
845 }
846
847 let this = self.clone();
848 self.acquire()
849 .await?
850 .with_transaction(move |txn| {
851 if let Some(pickled_private_identity) = &pickled_private_identity {
852 let serialized_private_identity =
853 this.serialize_value(pickled_private_identity)?;
854 txn.set_kv("identity", &serialized_private_identity)?;
855 }
856
857 if let Some(token) = &changes.next_batch_token {
858 let serialized_token = this.serialize_value(token)?;
859 txn.set_kv("next_batch_token", &serialized_token)?;
860 }
861
862 if let Some(decryption_key) = &changes.backup_decryption_key {
863 let serialized_decryption_key = this.serialize_value(decryption_key)?;
864 txn.set_kv("recovery_key_v1", &serialized_decryption_key)?;
865 }
866
867 if let Some(backup_version) = &changes.backup_version {
868 let serialized_backup_version = this.serialize_value(backup_version)?;
869 txn.set_kv("backup_version_v1", &serialized_backup_version)?;
870 }
871
872 if let Some(pickle_key) = &changes.dehydrated_device_pickle_key {
873 let serialized_pickle_key = this.serialize_value(pickle_key)?;
874 txn.set_kv(DEHYDRATED_DEVICE_PICKLE_KEY, &serialized_pickle_key)?;
875 }
876
877 for device in changes.devices.new.iter().chain(&changes.devices.changed) {
878 let user_id = this.encode_key("device", device.user_id().as_bytes());
879 let device_id = this.encode_key("device", device.device_id().as_bytes());
880 let data = this.serialize_value(&device)?;
881 txn.set_device(&user_id, &device_id, &data)?;
882 }
883
884 for device in &changes.devices.deleted {
885 let user_id = this.encode_key("device", device.user_id().as_bytes());
886 let device_id = this.encode_key("device", device.device_id().as_bytes());
887 txn.delete_device(&user_id, &device_id)?;
888 }
889
890 for identity in changes.identities.changed.iter().chain(&changes.identities.new) {
891 let user_id = this.encode_key("identity", identity.user_id().as_bytes());
892 let data = this.serialize_value(&identity)?;
893 txn.set_identity(&user_id, &data)?;
894 }
895
896 for (session_id, sender_key, pickle) in &session_changes {
897 let serialized_session = this.serialize_value(&pickle)?;
898 txn.set_session(session_id, sender_key, &serialized_session)?;
899 }
900
901 for (room_id, session_id, pickle, sender_key) in &inbound_session_changes {
902 let serialized_session = this.serialize_value(&pickle)?;
903 txn.set_inbound_group_session(
904 room_id,
905 session_id,
906 &serialized_session,
907 pickle.backed_up,
908 Some(sender_key),
909 Some(pickle.sender_data.to_type() as u8),
910 )?;
911 }
912
913 for (room_id, pickle) in &outbound_session_changes {
914 let serialized_session = this.serialize_json(&pickle)?;
915 txn.set_outbound_group_session(room_id, &serialized_session)?;
916 }
917
918 for hash in &changes.message_hashes {
919 let hash = rmp_serde::to_vec(hash)?;
920 txn.add_olm_hash(&hash)?;
921 }
922
923 for request in changes.key_requests {
924 let request_id = this.encode_key("key_requests", request.request_id.as_bytes());
925 let serialized_request = this.serialize_value(&request)?;
926 txn.set_key_request(&request_id, request.sent_out, &serialized_request)?;
927 }
928
929 for (room_id, data) in changes.withheld_session_info {
930 for (session_id, event) in data {
931 let session_id = this.encode_key("direct_withheld_info", session_id);
932 let room_id = this.encode_key("direct_withheld_info", &room_id);
933 let serialized_info = this.serialize_json(&event)?;
934 txn.set_direct_withheld(&session_id, &room_id, &serialized_info)?;
935 }
936 }
937
938 for (room_id, settings) in changes.room_settings {
939 let room_id = this.encode_key("room_settings", room_id.as_bytes());
940 let value = this.serialize_value(&settings)?;
941 txn.set_room_settings(&room_id, &value)?;
942 }
943
944 for secret in changes.secrets {
945 let secret_name = this.encode_key("secrets", secret.secret_name.to_string());
946 let value = this.serialize_json(&secret)?;
947 txn.set_secret(&secret_name, &value)?;
948 }
949
950 Ok::<_, Error>(())
951 })
952 .await?;
953
954 Ok(())
955 }
956
957 async fn save_inbound_group_sessions(
958 &self,
959 sessions: Vec<InboundGroupSession>,
960 backed_up_to_version: Option<&str>,
961 ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
962 sessions.iter().for_each(|s| {
964 let backed_up = s.backed_up();
965 if backed_up != backed_up_to_version.is_some() {
966 warn!(
967 backed_up,
968 backed_up_to_version,
969 "Session backed-up flag does not correspond to backup version setting",
970 );
971 }
972 });
973
974 self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
977 }
978
979 async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
980 let device_keys = self.get_own_device().await?.as_device_keys().clone();
981
982 let sessions: Vec<_> = self
983 .acquire()
984 .await?
985 .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
986 .await?
987 .into_iter()
988 .map(|bytes| {
989 let pickle = self.deserialize_value(&bytes)?;
990 Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
991 })
992 .collect::<Result<_>>()?;
993
994 if sessions.is_empty() {
995 Ok(None)
996 } else {
997 Ok(Some(sessions))
998 }
999 }
1000
1001 #[instrument(skip(self))]
1002 async fn get_inbound_group_session(
1003 &self,
1004 room_id: &RoomId,
1005 session_id: &str,
1006 ) -> Result<Option<InboundGroupSession>> {
1007 let session_id = self.encode_key("inbound_group_session", session_id);
1008 let Some((room_id_from_db, value, backed_up)) =
1009 self.acquire().await?.get_inbound_group_session(session_id).await?
1010 else {
1011 return Ok(None);
1012 };
1013
1014 let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
1015 if *room_id != room_id_from_db {
1016 warn!("expected room_id for session_id doesn't match what's in the DB");
1017 return Ok(None);
1018 }
1019
1020 Ok(Some(self.deserialize_and_unpickle_inbound_group_session(value, backed_up)?))
1021 }
1022
1023 async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
1024 self.acquire()
1025 .await?
1026 .get_inbound_group_sessions()
1027 .await?
1028 .into_iter()
1029 .map(|(value, backed_up)| {
1030 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1031 })
1032 .collect()
1033 }
1034
1035 async fn get_inbound_group_sessions_for_device_batch(
1036 &self,
1037 sender_key: Curve25519PublicKey,
1038 sender_data_type: SenderDataType,
1039 after_session_id: Option<String>,
1040 limit: usize,
1041 ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1042 let after_session_id =
1043 after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
1044 let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
1045
1046 self.acquire()
1047 .await?
1048 .get_inbound_group_sessions_for_device_batch(
1049 sender_key,
1050 sender_data_type,
1051 after_session_id,
1052 limit,
1053 )
1054 .await?
1055 .into_iter()
1056 .map(|(value, backed_up)| {
1057 self.deserialize_and_unpickle_inbound_group_session(value, backed_up)
1058 })
1059 .collect()
1060 }
1061
1062 async fn inbound_group_session_counts(
1063 &self,
1064 backup_version: Option<&str>,
1065 ) -> Result<RoomKeyCounts> {
1066 Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
1067 }
1068
1069 async fn inbound_group_sessions_for_backup(
1070 &self,
1071 _backup_version: &str,
1072 limit: usize,
1073 ) -> Result<Vec<InboundGroupSession>> {
1074 self.acquire()
1075 .await?
1076 .get_inbound_group_sessions_for_backup(limit)
1077 .await?
1078 .into_iter()
1079 .map(|value| self.deserialize_and_unpickle_inbound_group_session(value, false))
1080 .collect()
1081 }
1082
1083 async fn mark_inbound_group_sessions_as_backed_up(
1084 &self,
1085 _backup_version: &str,
1086 session_ids: &[(&RoomId, &str)],
1087 ) -> Result<()> {
1088 Ok(self
1089 .acquire()
1090 .await?
1091 .mark_inbound_group_sessions_as_backed_up(
1092 session_ids
1093 .iter()
1094 .map(|(_, s)| self.encode_key("inbound_group_session", s))
1095 .collect(),
1096 )
1097 .await?)
1098 }
1099
1100 async fn reset_backup_state(&self) -> Result<()> {
1101 Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
1102 }
1103
1104 async fn load_backup_keys(&self) -> Result<BackupKeys> {
1105 let conn = self.acquire().await?;
1106
1107 let backup_version = conn
1108 .get_kv("backup_version_v1")
1109 .await?
1110 .map(|value| self.deserialize_value(&value))
1111 .transpose()?;
1112
1113 let decryption_key = conn
1114 .get_kv("recovery_key_v1")
1115 .await?
1116 .map(|value| self.deserialize_value(&value))
1117 .transpose()?;
1118
1119 Ok(BackupKeys { backup_version, decryption_key })
1120 }
1121
1122 async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
1123 let conn = self.acquire().await?;
1124
1125 conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
1126 .await?
1127 .map(|value| self.deserialize_value(&value))
1128 .transpose()
1129 }
1130
1131 async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1132 let conn = self.acquire().await?;
1133 conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
1134
1135 Ok(())
1136 }
1137 async fn get_outbound_group_session(
1138 &self,
1139 room_id: &RoomId,
1140 ) -> Result<Option<OutboundGroupSession>> {
1141 let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
1142 let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
1143 return Ok(None);
1144 };
1145
1146 let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1147
1148 let pickle = self.deserialize_json(&value)?;
1149 let session = OutboundGroupSession::from_pickle(
1150 account_info.device_id,
1151 account_info.identity_keys,
1152 pickle,
1153 )
1154 .map_err(|_| Error::Unpickle)?;
1155
1156 return Ok(Some(session));
1157 }
1158
1159 async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
1160 self.acquire()
1161 .await?
1162 .get_tracked_users()
1163 .await?
1164 .iter()
1165 .map(|value| self.deserialize_value(value))
1166 .collect()
1167 }
1168
1169 async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
1170 let users: Vec<(Key, Vec<u8>)> = tracked_users
1171 .iter()
1172 .map(|(u, d)| {
1173 let user_id = self.encode_key("tracked_users", u.as_bytes());
1174 let data =
1175 self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?;
1176 Ok((user_id, data))
1177 })
1178 .collect::<Result<_>>()?;
1179
1180 Ok(self.acquire().await?.add_tracked_users(users).await?)
1181 }
1182
1183 async fn get_device(
1184 &self,
1185 user_id: &UserId,
1186 device_id: &DeviceId,
1187 ) -> Result<Option<DeviceData>> {
1188 let user_id = self.encode_key("device", user_id.as_bytes());
1189 let device_id = self.encode_key("device", device_id.as_bytes());
1190 Ok(self
1191 .acquire()
1192 .await?
1193 .get_device(user_id, device_id)
1194 .await?
1195 .map(|value| self.deserialize_value(&value))
1196 .transpose()?)
1197 }
1198
1199 async fn get_user_devices(
1200 &self,
1201 user_id: &UserId,
1202 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1203 let user_id = self.encode_key("device", user_id.as_bytes());
1204 self.acquire()
1205 .await?
1206 .get_user_devices(user_id)
1207 .await?
1208 .into_iter()
1209 .map(|value| {
1210 let device: DeviceData = self.deserialize_value(&value)?;
1211 Ok((device.device_id().to_owned(), device))
1212 })
1213 .collect()
1214 }
1215
1216 async fn get_own_device(&self) -> Result<DeviceData> {
1217 let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
1218
1219 Ok(self
1220 .get_device(&account_info.user_id, &account_info.device_id)
1221 .await?
1222 .expect("We should be able to find our own device."))
1223 }
1224
1225 async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
1226 let user_id = self.encode_key("identity", user_id.as_bytes());
1227 Ok(self
1228 .acquire()
1229 .await?
1230 .get_user_identity(user_id)
1231 .await?
1232 .map(|value| self.deserialize_value(&value))
1233 .transpose()?)
1234 }
1235
1236 async fn is_message_known(
1237 &self,
1238 message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
1239 ) -> Result<bool> {
1240 let value = rmp_serde::to_vec(message_hash)?;
1241 Ok(self.acquire().await?.has_olm_hash(value).await?)
1242 }
1243
1244 async fn get_outgoing_secret_requests(
1245 &self,
1246 request_id: &TransactionId,
1247 ) -> Result<Option<GossipRequest>> {
1248 let request_id = self.encode_key("key_requests", request_id.as_bytes());
1249 Ok(self
1250 .acquire()
1251 .await?
1252 .get_outgoing_secret_request(request_id)
1253 .await?
1254 .map(|(value, sent_out)| self.deserialize_key_request(&value, sent_out))
1255 .transpose()?)
1256 }
1257
1258 async fn get_secret_request_by_info(
1259 &self,
1260 key_info: &SecretInfo,
1261 ) -> Result<Option<GossipRequest>> {
1262 let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
1263 for (request, sent_out) in requests {
1264 let request = self.deserialize_key_request(&request, sent_out)?;
1265 if request.info == *key_info {
1266 return Ok(Some(request));
1267 }
1268 }
1269 Ok(None)
1270 }
1271
1272 async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1273 self.acquire()
1274 .await?
1275 .get_unsent_secret_requests()
1276 .await?
1277 .iter()
1278 .map(|value| {
1279 let request = self.deserialize_key_request(value, false)?;
1280 Ok(request)
1281 })
1282 .collect()
1283 }
1284
1285 async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
1286 let request_id = self.encode_key("key_requests", request_id.as_bytes());
1287 Ok(self.acquire().await?.delete_key_request(request_id).await?)
1288 }
1289
1290 async fn get_secrets_from_inbox(
1291 &self,
1292 secret_name: &SecretName,
1293 ) -> Result<Vec<GossippedSecret>> {
1294 let secret_name = self.encode_key("secrets", secret_name.to_string());
1295
1296 self.acquire()
1297 .await?
1298 .get_secrets_from_inbox(secret_name)
1299 .await?
1300 .into_iter()
1301 .map(|value| self.deserialize_json(value.as_ref()))
1302 .collect()
1303 }
1304
1305 async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
1306 let secret_name = self.encode_key("secrets", secret_name.to_string());
1307 self.acquire().await?.delete_secrets_from_inbox(secret_name).await
1308 }
1309
1310 async fn get_withheld_info(
1311 &self,
1312 room_id: &RoomId,
1313 session_id: &str,
1314 ) -> Result<Option<RoomKeyWithheldEvent>> {
1315 let room_id = self.encode_key("direct_withheld_info", room_id);
1316 let session_id = self.encode_key("direct_withheld_info", session_id);
1317
1318 self.acquire()
1319 .await?
1320 .get_direct_withheld_info(session_id, room_id)
1321 .await?
1322 .map(|value| {
1323 let info = self.deserialize_json::<RoomKeyWithheldEvent>(&value)?;
1324 Ok(info)
1325 })
1326 .transpose()
1327 }
1328
1329 async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
1330 let room_id = self.encode_key("room_settings", room_id.as_bytes());
1331 let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
1332 return Ok(None);
1333 };
1334
1335 let settings = self.deserialize_value(&value)?;
1336
1337 return Ok(Some(settings));
1338 }
1339
1340 async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1341 let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
1342 return Ok(None);
1343 };
1344 let value = if let Some(cipher) = &self.store_cipher {
1345 let encrypted = rmp_serde::from_slice(&serialized)?;
1346 cipher.decrypt_value_data(encrypted)?
1347 } else {
1348 serialized
1349 };
1350
1351 Ok(Some(value))
1352 }
1353
1354 async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1355 let serialized = if let Some(cipher) = &self.store_cipher {
1356 let encrypted = cipher.encrypt_value_data(value)?;
1357 rmp_serde::to_vec_named(&encrypted)?
1358 } else {
1359 value
1360 };
1361
1362 self.acquire().await?.set_kv(key, serialized).await?;
1363 Ok(())
1364 }
1365
1366 async fn remove_custom_value(&self, key: &str) -> Result<()> {
1367 let key = key.to_owned();
1368 self.acquire()
1369 .await?
1370 .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
1371 .await
1372 .unwrap()?;
1373 Ok(())
1374 }
1375
1376 async fn try_take_leased_lock(
1377 &self,
1378 lease_duration_ms: u32,
1379 key: &str,
1380 holder: &str,
1381 ) -> Result<bool> {
1382 let key = key.to_owned();
1383 let holder = holder.to_owned();
1384
1385 let now_ts: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
1386 let expiration_ts = now_ts + lease_duration_ms as u64;
1387
1388 let num_touched = self
1389 .acquire()
1390 .await?
1391 .with_transaction(move |txn| {
1392 txn.execute(
1393 "INSERT INTO lease_locks (key, holder, expiration_ts)
1394 VALUES (?1, ?2, ?3)
1395 ON CONFLICT (key)
1396 DO
1397 UPDATE SET holder = ?2, expiration_ts = ?3
1398 WHERE holder = ?2
1399 OR expiration_ts < ?4
1400 ",
1401 (key, holder, expiration_ts, now_ts),
1402 )
1403 })
1404 .await?;
1405
1406 Ok(num_touched == 1)
1407 }
1408
1409 async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1410 let conn = self.acquire().await?;
1411 if let Some(token) = conn.get_kv("next_batch_token").await? {
1412 let maybe_token: Option<String> = self.deserialize_value(&token)?;
1413 Ok(maybe_token)
1414 } else {
1415 Ok(None)
1416 }
1417 }
1418}
1419
1420#[cfg(test)]
1421mod tests {
1422 use std::path::Path;
1423
1424 use matrix_sdk_common::deserialized_responses::WithheldCode;
1425 use matrix_sdk_crypto::{
1426 cryptostore_integration_tests, cryptostore_integration_tests_time, olm::SenderDataType,
1427 store::CryptoStore,
1428 };
1429 use matrix_sdk_test::async_test;
1430 use once_cell::sync::Lazy;
1431 use ruma::{device_id, room_id, user_id};
1432 use similar_asserts::assert_eq;
1433 use tempfile::{tempdir, TempDir};
1434 use tokio::fs;
1435
1436 use super::SqliteCryptoStore;
1437 use crate::SqliteStoreConfig;
1438
1439 static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());
1440
1441 struct TestDb {
1442 _dir: TempDir,
1445 database: SqliteCryptoStore,
1446 }
1447
1448 fn copy_db(data_path: &str) -> TempDir {
1449 let db_name = super::DATABASE_NAME;
1450
1451 let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
1452 let database_path = manifest_path.join(data_path).join(db_name);
1453
1454 let tmpdir = tempdir().unwrap();
1455 let destination = tmpdir.path().join(db_name);
1456
1457 std::fs::copy(&database_path, destination).unwrap();
1459
1460 tmpdir
1461 }
1462
1463 async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb {
1464 let tmpdir = copy_db(data_path);
1465
1466 let database = SqliteCryptoStore::open(tmpdir.path(), passphrase)
1467 .await
1468 .expect("Can't open the test store");
1469
1470 TestDb { _dir: tmpdir, database }
1471 }
1472
1473 #[async_test]
1474 async fn test_pool_size() {
1475 let store_open_config =
1476 SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42);
1477
1478 let store = SqliteCryptoStore::open_with_config(store_open_config).await.unwrap();
1479
1480 assert_eq!(store.pool.status().max_size, 42);
1481 }
1482
1483 #[async_test]
1486 async fn test_open_test_vector_store() {
1487 let TestDb { _dir: _, database } = get_test_db("testing/data/storage", None).await;
1488
1489 let account = database
1490 .load_account()
1491 .await
1492 .unwrap()
1493 .expect("The test database is prefilled with data, we should find an account");
1494
1495 let user_id = account.user_id();
1496 let device_id = account.device_id();
1497
1498 assert_eq!(
1499 user_id.as_str(),
1500 "@pjtest:synapse-oidc.element.dev",
1501 "The user ID should match to the one we expect."
1502 );
1503
1504 assert_eq!(
1505 device_id.as_str(),
1506 "v4TqgcuIH6",
1507 "The device ID should match to the one we expect."
1508 );
1509
1510 let device = database
1511 .get_device(user_id, device_id)
1512 .await
1513 .unwrap()
1514 .expect("Our own device should be found in the store.");
1515
1516 assert_eq!(device.device_id(), device_id);
1517 assert_eq!(device.user_id(), user_id);
1518
1519 assert_eq!(
1520 device.ed25519_key().expect("The device should have a Ed25519 key.").to_base64(),
1521 "+cxl1Gl3du5i7UJwfWnoRDdnafFF+xYdAiTYYhYLr8s"
1522 );
1523
1524 assert_eq!(
1525 device.curve25519_key().expect("The device should have a Curve25519 key.").to_base64(),
1526 "4SL9eEUlpyWSUvjljC5oMjknHQQJY7WZKo5S1KL/5VU"
1527 );
1528
1529 let identity = database
1530 .get_user_identity(user_id)
1531 .await
1532 .unwrap()
1533 .expect("The store should contain an identity.");
1534
1535 assert_eq!(identity.user_id(), user_id);
1536
1537 let identity = identity
1538 .own()
1539 .expect("The identity should be of the correct type, it should be our own identity.");
1540
1541 let master_key = identity
1542 .master_key()
1543 .get_first_key()
1544 .expect("Our own identity should have a master key");
1545
1546 assert_eq!(master_key.to_base64(), "iCUEtB1RwANeqRa5epDrblLk4mer/36sylwQ5hYY3oE");
1547 }
1548
1549 #[async_test]
1552 async fn test_open_test_vector_encrypted_store() {
1553 let TestDb { _dir: _, database } = get_test_db(
1554 "testing/data/storage/alice",
1555 Some(concat!(
1556 "/rCia2fYAJ+twCZ1Xm2mxFCYcmJdyzkdJjwtgXsziWpYS/UeNxnixuSieuwZXm+x1VsJHmWpl",
1557 "H+QIQBZpEGZtC9/S/l8xK+WOCesmET0o6yJ/KP73ofDtjBlnNpPwuHLKFpyTbyicpCgQ4UT+5E",
1558 "UBuJ08TY9Ujdf1D13k5kr5tSZUefDKKCuG1fCRqlU8ByRas1PMQsZxT2W8t7QgBrQiiGmhpo/O",
1559 "Ti4hfx97GOxncKcxTzppiYQNoHs/f15+XXQD7/oiCcqRIuUlXNsU6hRpFGmbYx2Pi1eyQViQCt",
1560 "B5dAEiSD0N8U81wXYnpynuTPtnL+hfnOJIn7Sy7mkERQeKg"
1561 )),
1562 )
1563 .await;
1564
1565 let account = database
1566 .load_account()
1567 .await
1568 .unwrap()
1569 .expect("The test database is prefilled with data, we should find an account");
1570
1571 let user_id = account.user_id();
1572 let device_id = account.device_id();
1573
1574 assert_eq!(
1575 user_id.as_str(),
1576 "@alice:localhost",
1577 "The user ID should match to the one we expect."
1578 );
1579
1580 assert_eq!(
1581 device_id.as_str(),
1582 "JVVORTHFXY",
1583 "The device ID should match to the one we expect."
1584 );
1585
1586 let tracked_users =
1587 database.load_tracked_users().await.expect("Should be tracking some users");
1588
1589 assert_eq!(tracked_users.len(), 6);
1590
1591 let known_users = vec![
1592 user_id!("@alice:localhost"),
1593 user_id!("@dehydration3:localhost"),
1594 user_id!("@eve:localhost"),
1595 user_id!("@bob:localhost"),
1596 user_id!("@malo:localhost"),
1597 user_id!("@carl:localhost"),
1598 ];
1599
1600 for user_id in known_users {
1602 database.get_user_identity(user_id).await.expect("Should load this identity").unwrap();
1603 }
1604
1605 let carl_identity =
1606 database.get_user_identity(user_id!("@carl:localhost")).await.unwrap().unwrap();
1607
1608 assert_eq!(
1609 carl_identity.master_key().get_first_key().unwrap().to_base64(),
1610 "CdhKYYDeBDQveOioXEGWhTPCyzc63Irpar3CNyfun2Q"
1611 );
1612 assert!(!carl_identity.was_previously_verified());
1613
1614 let bob_identity =
1615 database.get_user_identity(user_id!("@bob:localhost")).await.unwrap().unwrap();
1616
1617 assert_eq!(
1618 bob_identity.master_key().get_first_key().unwrap().to_base64(),
1619 "COh2GYOJWSjem5QPRCaGp9iWV83IELG1IzLKW2S3pFY"
1620 );
1621 assert!(bob_identity.was_previously_verified());
1623
1624 let known_devices = vec![
1625 (device_id!("OPXQHCZSKW"), user_id!("@alice:localhost")),
1626 (
1628 device_id!("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho"),
1629 user_id!("@alice:localhost"),
1630 ),
1631 (device_id!("HEEFRFQENV"), user_id!("@alice:localhost")),
1632 (device_id!("JVVORTHFXY"), user_id!("@alice:localhost")),
1633 (device_id!("NQUWWSKKHS"), user_id!("@alice:localhost")),
1634 (device_id!("ORBLPFYCPG"), user_id!("@alice:localhost")),
1635 (device_id!("YXOWENSEGM"), user_id!("@dehydration3:localhost")),
1636 (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1637 (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1638 (device_id!("VXLFMYCHXC"), user_id!("@bob:localhost")),
1639 (device_id!("FDGDQAEWOW"), user_id!("@bob:localhost")),
1640 (device_id!("QKUKWJTTQC"), user_id!("@malo:localhost")),
1641 (device_id!("LOUXJECTFG"), user_id!("@malo:localhost")),
1642 (device_id!("MKKMAEVLPB"), user_id!("@carl:localhost")),
1643 ];
1644
1645 for (device_id, user_id) in known_devices {
1646 database.get_device(user_id, device_id).await.expect("Should load the device").unwrap();
1647 }
1648
1649 let known_sender_key_to_session_count = vec![
1650 ("FfYcYfDF4nWy+LHdK6CEpIMlFAQDORc30WUkghL06kM", 1),
1651 ("EvW+9IrGR10KVgVeZP25/KaPfx4R86FofVMcaz7VOho", 1),
1652 ("hAGsoA4a9M6wwEUX5Q1jux1i+tUngLi01n5AmhDoHTY", 1),
1653 ("aKqtSJymLzuoglWFwPGk1r/Vm2LE2hFESzXxn4RNjRM", 0),
1654 ("zHK1psCrgeMn0kaz8hcdvA3INyar9jg1yfrSp0p1pHo", 1),
1655 ("1QmBA316Wj5jIFRwNOti6N6Xh/vW0bsYCcR4uPfy8VQ", 1),
1656 ("g5ef2vZF3VXgSPyODIeXpyHIRkuthvLhGvd6uwYggWU", 1),
1657 ("o7hfupPd1VsNkRIvdlH6ujrEJFSKjFCGbxhAd31XxjI", 1),
1658 ("Z3RxKQLxY7xpP+ZdOGR2SiNE37SrvmRhW7GPu1UGdm8", 1),
1659 ("GDomaav8NiY3J+dNEeApJm+O0FooJ3IpVaIyJzCN4w4", 1),
1660 ("7m7fqkHyEr47V5s/KjaxtJMOr3pSHrrns2q2lWpAQi8", 0),
1661 ("9psAkPUIF8vNbWbnviX3PlwRcaeO53EHJdNtKpTY1X0", 0),
1662 ("mqanh+ztw5oRtpqYQgLGW864i6NY2zpoKMIlrcyC+Aw", 0),
1663 ("fJU/TJdbsv7tVbbpHw1Ke73ziElnM32cNhP2WIg4T10", 0),
1664 ("sUIeFeFcCZoa5IC6nJ6Vrbvztcyx09m8BBg57XKRClg", 1),
1665 ];
1666
1667 for (id, count) in known_sender_key_to_session_count {
1668 let olm_sessions =
1669 database.get_sessions(id).await.expect("Should have some olm sessions");
1670
1671 println!("### Session id: {:?}", id);
1672 assert_eq!(olm_sessions.map_or(0, |v| v.len()), count);
1673 }
1674
1675 let inbound_group_sessions = database.get_inbound_group_sessions().await.unwrap();
1676 assert_eq!(inbound_group_sessions.len(), 15);
1677 let known_inbound_group_sessions = vec![
1678 (
1679 "5hNAxrLai3VI0LKBwfh3wLfksfBFWds0W1a5X5/vSXA",
1680 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1681 ),
1682 (
1683 "M6d2eU3y54gaYTbvGSlqa/xc1Az35l56Cp9sxzHWO4g",
1684 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1685 ),
1686 (
1687 "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1688 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1689 ),
1690 (
1691 "Y74+l9jTo7N5UF+GQwdpgJGe4sn1+QtWITq7BxulHIE",
1692 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1693 ),
1694 (
1695 "HpJxQR57WbQGdY6w2Q+C16znVvbXGa+JvQdRoMpWbXg",
1696 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1697 ),
1698 (
1699 "Xetvi+ydFkZt8dpONGFbEusQb/Chc2V0XlLByZhsbgE",
1700 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1701 ),
1702 (
1703 "wv/WN/39akyerIXczTaIpjAuLnwgXKRtbXFSEHiJqxo",
1704 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1705 ),
1706 (
1707 "nA4gQwL//Cm8OdlyjABl/jChbPT/cP5V4Sd8iuE6H0s",
1708 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1709 ),
1710 (
1711 "bAAgqFeRDTjfEqL6Qf/c9mk55zoNDCSlboAIRd6b0hw",
1712 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1713 ),
1714 (
1715 "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1716 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1717 ),
1718 (
1719 "h+om7oSw/ZV94fcKaoe8FGXJwQXWOfKQfzbGgNWQILI",
1720 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1721 ),
1722 (
1723 "ul3VXonpgk4lO2L3fEWubP/nxsTmLHqu5v8ZM9vHEcw",
1724 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1725 ),
1726 (
1727 "JXY15UxC3az2mwg8uX4qwgxfvCM4aygiIWMcdNiVQoc",
1728 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1729 ),
1730 (
1731 "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1732 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1733 ),
1734 (
1735 "SFkHcbxjUOYF7mUAYI/oEMDZFaXszQbCN6Jza7iemj0",
1736 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1737 ),
1738 ];
1739
1740 for (session_id, room_id) in &known_inbound_group_sessions {
1742 database
1743 .get_inbound_group_session(room_id, session_id)
1744 .await
1745 .expect("Should be able to load inbound group session")
1746 .unwrap();
1747 }
1748
1749 let bob_sender_verified = database
1750 .get_inbound_group_session(
1751 room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"),
1752 "exPbsMMdGfAG2qmDdFtpAn+koVprfzS0Zip/RA9QRCE",
1753 )
1754 .await
1755 .unwrap()
1756 .unwrap();
1757
1758 assert_eq!(bob_sender_verified.sender_data.to_type(), SenderDataType::SenderVerified);
1759 assert!(bob_sender_verified.backed_up());
1760 assert!(!bob_sender_verified.has_been_imported());
1761
1762 let alice_unknown_device = database
1763 .get_inbound_group_session(
1764 room_id!("!SRstFdydzrGwJYtVfm:localhost"),
1765 "IrydwXkRk2N2AqUMIVmLL3oJgMq14R9KId0P/uSD100",
1766 )
1767 .await
1768 .unwrap()
1769 .unwrap();
1770
1771 assert_eq!(alice_unknown_device.sender_data.to_type(), SenderDataType::UnknownDevice);
1772 assert!(alice_unknown_device.backed_up());
1773 assert!(alice_unknown_device.has_been_imported());
1774
1775 let carl_tofu_session = database
1776 .get_inbound_group_session(
1777 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1778 "OGB9lObr9kWUvha9tB5sMfOF/Mztk24JwQz/nwg3iFQ",
1779 )
1780 .await
1781 .unwrap()
1782 .unwrap();
1783
1784 assert_eq!(carl_tofu_session.sender_data.to_type(), SenderDataType::SenderUnverified);
1785 assert!(carl_tofu_session.backed_up());
1786 assert!(!carl_tofu_session.has_been_imported());
1787
1788 database
1790 .get_outbound_group_session(room_id!("!OgRiTRMaUzLdpCeDBM:localhost"))
1791 .await
1792 .unwrap()
1793 .unwrap();
1794 database
1795 .get_outbound_group_session(room_id!("!ZIwZcFqZVAYLAqVjfV:localhost"))
1796 .await
1797 .unwrap()
1798 .unwrap();
1799 database
1800 .get_outbound_group_session(room_id!("!SRstFdydzrGwJYtVfm:localhost"))
1801 .await
1802 .unwrap()
1803 .unwrap();
1804
1805 let withheld_info = database
1806 .get_withheld_info(
1807 room_id!("!OgRiTRMaUzLdpCeDBM:localhost"),
1808 "SASgZ+EklvAF4QxJclMlDRlmL0fAMjAJJIKFMdb4Ht0",
1809 )
1810 .await
1811 .expect("This session should be withheld")
1812 .unwrap();
1813
1814 assert_eq!(withheld_info.content.withheld_code(), WithheldCode::Unverified);
1815
1816 let backup_keys = database.load_backup_keys().await.expect("backup key should be cached");
1817 assert_eq!(backup_keys.backup_version.unwrap(), "6");
1818 assert!(backup_keys.decryption_key.is_some());
1819 }
1820
1821 async fn get_store(
1822 name: &str,
1823 passphrase: Option<&str>,
1824 clear_data: bool,
1825 ) -> SqliteCryptoStore {
1826 let tmpdir_path = TMP_DIR.path().join(name);
1827
1828 if clear_data {
1829 let _ = fs::remove_dir_all(&tmpdir_path).await;
1830 }
1831
1832 SqliteCryptoStore::open(tmpdir_path.to_str().unwrap(), passphrase)
1833 .await
1834 .expect("Can't create a passphrase protected store")
1835 }
1836
1837 cryptostore_integration_tests!();
1838 cryptostore_integration_tests_time!();
1839}
1840
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};
    use tokio::fs;

    use super::SqliteCryptoStore;

    /// Temporary directory shared by the tests of this module; created
    /// lazily on first access.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Opens a passphrase-protected store in the sub-directory `name` of
    /// the shared temporary directory.
    ///
    /// When no `passphrase` is given, `"default_test_password"` is used, so
    /// the store is always opened with a passphrase. When `clear_data` is
    /// set, the sub-directory is removed first; a failure to remove it is
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or the store cannot be opened.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        clear_data: bool,
    ) -> SqliteCryptoStore {
        let store_dir = TMP_DIR.path().join(name);
        let effective_passphrase = Some(passphrase.unwrap_or("default_test_password"));

        if clear_data {
            let _ = fs::remove_dir_all(&store_dir).await;
        }

        let store_path = store_dir.to_str().unwrap();
        let opened = SqliteCryptoStore::open(store_path, effective_passphrase).await;

        opened.expect("Can't create a passphrase protected store")
    }

    // Instantiate the shared crypto store test suites against this store.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}