1use std::{
42 collections::{BTreeMap, BTreeSet, HashMap, HashSet},
43 fmt::Debug,
44 ops::Deref,
45 pin::pin,
46 sync::{atomic::Ordering, Arc},
47 time::Duration,
48};
49
50use as_variant::as_variant;
51use futures_core::Stream;
52use futures_util::StreamExt;
53use matrix_sdk_common::locks::RwLock as StdRwLock;
54use ruma::{
55 encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
56 OwnedRoomId, OwnedUserId, RoomId, UserId,
57};
58use serde::{de::DeserializeOwned, Deserialize, Serialize};
59use thiserror::Error;
60use tokio::sync::{Mutex, MutexGuard, Notify, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
61use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
62use tracing::{info, warn};
63use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
64use zeroize::{Zeroize, ZeroizeOnDrop};
65
66#[cfg(doc)]
67use crate::{backups::BackupMachine, identities::OwnUserIdentity};
68use crate::{
69 gossiping::GossippedSecret,
70 identities::{user::UserIdentity, Device, DeviceData, UserDevices, UserIdentityData},
71 olm::{
72 Account, ExportedRoomKey, InboundGroupSession, OlmMessageHash, OutboundGroupSession,
73 PrivateCrossSigningIdentity, Session, StaticAccountData,
74 },
75 types::{
76 events::room_key_withheld::RoomKeyWithheldEvent, BackupSecrets, CrossSigningSecrets,
77 EventEncryptionAlgorithm, MegolmBackupV1Curve25519AesSha2Secrets, SecretsBundle,
78 },
79 verification::VerificationMachine,
80 CrossSigningStatus, OwnUserIdentityData, RoomKeyImportResult,
81};
82
83pub mod caches;
84mod crypto_store_wrapper;
85mod error;
86mod memorystore;
87mod traits;
88
89#[cfg(any(test, feature = "testing"))]
90#[macro_use]
91#[allow(missing_docs)]
92pub mod integration_tests;
93
94use caches::{SequenceNumber, UsersForKeyQuery};
95pub(crate) use crypto_store_wrapper::CryptoStoreWrapper;
96pub use error::{CryptoStoreError, Result};
97use matrix_sdk_common::{
98 deserialized_responses::WithheldCode, store_locks::CrossProcessStoreLock, timeout::timeout,
99};
100pub use memorystore::MemoryStore;
101pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
102
103use crate::types::{
104 events::room_key_withheld::RoomKeyWithheldContent, room_history::RoomKeyBundle,
105};
106pub use crate::{
107 dehydrated_devices::DehydrationError,
108 gossiping::{GossipRequest, SecretInfo},
109};
110
/// Handle to the crypto store together with the data that is cached next to
/// it (tracked users, the account, the private cross-signing identity).
///
/// Cloning is cheap: all state lives behind a shared `Arc<StoreInner>`.
#[derive(Debug, Clone)]
pub struct Store {
    inner: Arc<StoreInner>,
}
121
/// Keeps track of which users need a `/keys/query` and lets tasks wait for
/// a pending query of a given user to finish.
#[derive(Debug, Default)]
pub(crate) struct KeyQueryManager {
    /// Queue of users whose device lists need to be (re-)queried, plus the
    /// tasks currently waiting for such queries to complete.
    users_for_key_query: Mutex<UsersForKeyQuery>,

    /// Notified after a key query response has been processed
    /// (`SyncedKeyQueryManager::mark_tracked_users_as_up_to_date`), so that
    /// tasks parked in `wait_if_user_key_query_pending` re-check their state.
    users_for_key_query_notify: Notify,
}
130
impl KeyQueryManager {
    /// Make sure the tracked users have been loaded from the store into
    /// `cache`, then return a [`SyncedKeyQueryManager`] bound to that cache.
    ///
    /// # Errors
    ///
    /// Propagates store errors raised while loading the tracked users.
    pub async fn synced<'a>(&'a self, cache: &'a StoreCache) -> Result<SyncedKeyQueryManager<'a>> {
        self.ensure_sync_tracked_users(cache).await?;
        Ok(SyncedKeyQueryManager { cache, manager: self })
    }

    /// Load the tracked users from the store into the cache, at most once.
    ///
    /// Every loaded user is added to `cache.tracked_users`. Users persisted as
    /// dirty are also queued for a `/keys/query`.
    async fn ensure_sync_tracked_users(&self, cache: &StoreCache) -> Result<()> {
        // Fast path: a shared read lock is enough to see that loading already
        // happened.
        let loaded = cache.loaded_tracked_users.read().await;

        if *loaded {
            return Ok(());
        }

        drop(loaded);

        let mut loaded = cache.loaded_tracked_users.write().await;

        // Check again: another task may have done the loading between us
        // dropping the read lock and acquiring the write lock.
        if *loaded {
            return Ok(());
        }

        let tracked_users = cache.store.load_tracked_users().await?;

        // Lock order here (key-query queue first, then `tracked_users`) is the
        // same as in the `SyncedKeyQueryManager` methods.
        let mut query_users_lock = self.users_for_key_query.lock().await;
        let mut tracked_users_cache = cache.tracked_users.write();
        for user in tracked_users {
            tracked_users_cache.insert(user.user_id.to_owned());

            if user.dirty {
                query_users_lock.insert_user(&user.user_id);
            }
        }

        *loaded = true;

        Ok(())
    }

    /// If a `/keys/query` for `user` is pending, wait for it to finish, for at
    /// most `timeout_duration`.
    ///
    /// Returns [`UserKeyQueryResult::WasNotPending`] if no query was pending,
    /// [`UserKeyQueryResult::WasPending`] if one was pending and completed in
    /// time, and [`UserKeyQueryResult::TimeoutExpired`] (after logging a
    /// warning) if the timeout elapsed first.
    ///
    /// # Errors
    ///
    /// Propagates store errors raised while loading the tracked users.
    pub async fn wait_if_user_key_query_pending(
        &self,
        cache: StoreCacheGuard,
        timeout_duration: Duration,
        user: &UserId,
    ) -> Result<UserKeyQueryResult> {
        {
            self.ensure_sync_tracked_users(&cache).await?;
            // Release the cache read guard before the potentially long wait.
            // While it is held, `Store::transaction` (which takes the write
            // lock) would be blocked.
            drop(cache);
        }

        let mut users_for_key_query = self.users_for_key_query.lock().await;
        let Some(waiter) = users_for_key_query.maybe_register_waiting_task(user) else {
            return Ok(UserKeyQueryResult::WasNotPending);
        };

        let wait_for_completion = async {
            while !waiter.completed.load(Ordering::Relaxed) {
                let mut notified = pin!(self.users_for_key_query_notify.notified());
                // Register interest *before* releasing the lock, so that a
                // `notify_waiters` call made between dropping the lock and
                // awaiting is not lost.
                notified.as_mut().enable();
                drop(users_for_key_query);

                notified.await;

                // Re-acquire the lock before re-checking the completion flag.
                users_for_key_query = self.users_for_key_query.lock().await;
            }
        };

        match timeout(Box::pin(wait_for_completion), timeout_duration).await {
            Err(_) => {
                warn!(
                    user_id = ?user,
                    "The user has a pending `/keys/query` request which did \
                    not finish yet, some devices might be missing."
                );

                Ok(UserKeyQueryResult::TimeoutExpired)
            }
            _ => Ok(UserKeyQueryResult::WasPending),
        }
    }
}
238
/// A [`KeyQueryManager`] paired with a [`StoreCache`] whose tracked users are
/// known to be loaded. Obtained from [`KeyQueryManager::synced`].
pub(crate) struct SyncedKeyQueryManager<'a> {
    cache: &'a StoreCache,
    manager: &'a KeyQueryManager,
}
243
244impl SyncedKeyQueryManager<'_> {
245 pub async fn update_tracked_users(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
250 let mut store_updates = Vec::new();
251 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
252
253 {
254 let mut tracked_users = self.cache.tracked_users.write();
255 for user_id in users {
256 if tracked_users.insert(user_id.to_owned()) {
257 key_query_lock.insert_user(user_id);
258 store_updates.push((user_id, true))
259 }
260 }
261 }
262
263 self.cache.store.save_tracked_users(&store_updates).await
264 }
265
266 pub async fn mark_tracked_users_as_changed(
273 &self,
274 users: impl Iterator<Item = &UserId>,
275 ) -> Result<()> {
276 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
277 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
278
279 {
280 let tracked_users = &self.cache.tracked_users.read();
281 for user_id in users {
282 if tracked_users.contains(user_id) {
283 key_query_lock.insert_user(user_id);
284 store_updates.push((user_id, true));
285 }
286 }
287 }
288
289 self.cache.store.save_tracked_users(&store_updates).await
290 }
291
292 pub async fn mark_tracked_users_as_up_to_date(
298 &self,
299 users: impl Iterator<Item = &UserId>,
300 sequence_number: SequenceNumber,
301 ) -> Result<()> {
302 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
303 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
304
305 {
306 let tracked_users = self.cache.tracked_users.read();
307 for user_id in users {
308 if tracked_users.contains(user_id) {
309 let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
310 store_updates.push((user_id, !clean));
311 }
312 }
313 }
314
315 self.cache.store.save_tracked_users(&store_updates).await?;
316 self.manager.users_for_key_query_notify.notify_waiters();
318
319 Ok(())
320 }
321
322 pub async fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
334 self.manager.users_for_key_query.lock().await.users_for_key_query()
335 }
336
337 pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
339 self.cache.tracked_users.read().iter().cloned().collect()
340 }
341
342 pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
348 self.manager.users_for_key_query.lock().await.insert_user(user);
349 self.cache.tracked_users.write().insert(user.to_owned());
350
351 self.cache.store.save_tracked_users(&[(user, true)]).await
352 }
353}
354
/// In-memory cache of data loaded from the crypto store.
#[derive(Debug)]
pub(crate) struct StoreCache {
    /// The store wrapper that the cached data is loaded from.
    store: Arc<CryptoStoreWrapper>,
    /// Users whose device lists we track.
    tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
    /// Whether `tracked_users` has been populated from the store yet
    /// (see `KeyQueryManager::ensure_sync_tracked_users`).
    loaded_tracked_users: RwLock<bool>,
    /// Lazily loaded account. `None` until first loaded, and also while a
    /// `StoreTransaction` has moved it into its pending changes.
    account: Mutex<Option<Account>>,
}
362
363impl StoreCache {
364 pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
365 self.store.as_ref()
366 }
367
368 async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
380 let mut guard = self.account.lock().await;
381 if guard.is_some() {
382 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
383 } else {
384 match self.store.load_account().await? {
385 Some(account) => {
386 *guard = Some(account);
387 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
388 }
389 None => Err(CryptoStoreError::AccountUnset),
390 }
391 }
392 }
393}
394
/// Read-locked access to the [`StoreCache`], handed out by `Store::cache`.
///
/// While a guard is alive, `Store::transaction` (which takes the cache's
/// write lock) has to wait.
pub(crate) struct StoreCacheGuard {
    cache: OwnedRwLockReadGuard<StoreCache>,
}
404
405impl StoreCacheGuard {
406 pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
414 self.cache.account().await
415 }
416}
417
418impl Deref for StoreCacheGuard {
419 type Target = StoreCache;
420
421 fn deref(&self) -> &Self::Target {
422 &self.cache
423 }
424}
425
/// A transaction over the store: exclusive (write-locked) access to the
/// [`StoreCache`] plus a set of [`PendingChanges`] that `commit` persists.
#[allow(missing_debug_implementations)]
pub struct StoreTransaction {
    /// The store the transaction was opened on.
    store: Store,
    /// Changes accumulated during the transaction, saved on commit.
    changes: PendingChanges,
    /// Write guard on the cache, held for the transaction's whole lifetime.
    cache: OwnedRwLockWriteGuard<StoreCache>,
}
434
435impl StoreTransaction {
436 async fn new(store: Store) -> Self {
438 let cache = store.inner.cache.clone();
439
440 Self { store, changes: PendingChanges::default(), cache: cache.clone().write_owned().await }
441 }
442
443 pub(crate) fn cache(&self) -> &StoreCache {
444 &self.cache
445 }
446
447 pub fn store(&self) -> &Store {
449 &self.store
450 }
451
452 pub async fn account(&mut self) -> Result<&mut Account> {
459 if self.changes.account.is_none() {
460 let _ = self.cache.account().await?;
462 self.changes.account = self.cache.account.lock().await.take();
463 }
464 Ok(self.changes.account.as_mut().unwrap())
465 }
466
467 pub async fn commit(self) -> Result<()> {
470 if self.changes.is_empty() {
471 return Ok(());
472 }
473
474 let account = self.changes.account.as_ref().map(|acc| acc.deep_clone());
476
477 self.store.save_pending_changes(self.changes).await?;
478
479 if let Some(account) = account {
481 *self.cache.account.lock().await = Some(account);
482 }
483
484 Ok(())
485 }
486}
487
/// Shared state behind every clone of a [`Store`].
#[derive(Debug)]
struct StoreInner {
    /// Our private cross-signing identity, shared behind a mutex.
    identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
    /// The underlying store wrapper (the cache holds a clone of the same `Arc`).
    store: Arc<CryptoStoreWrapper>,

    /// The cache. `Store::cache` takes its read lock; transactions take its
    /// write lock.
    cache: Arc<RwLock<StoreCache>>,

    /// Used when wrapping `DeviceData` into `Device`/`UserIdentity`; also
    /// provides our own device id.
    verification_machine: VerificationMachine,

    /// Account data that doesn't change (e.g. our user id).
    static_account: StaticAccountData,
}
504
/// Changes accumulated by a [`StoreTransaction`] and saved when it commits.
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct PendingChanges {
    /// The account, moved out of the cache while the transaction edits it.
    pub account: Option<Account>,
}
515
516impl PendingChanges {
517 pub fn is_empty(&self) -> bool {
519 self.account.is_none()
520 }
521}
522
/// A batch of changes to persist in the crypto store (see `Store::save_changes`).
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct Changes {
    pub private_identity: Option<PrivateCrossSigningIdentity>,
    pub backup_version: Option<String>,
    pub backup_decryption_key: Option<BackupDecryptionKey>,
    pub dehydrated_device_pickle_key: Option<DehydratedDeviceKey>,
    pub sessions: Vec<Session>,
    pub message_hashes: Vec<OlmMessageHash>,
    pub inbound_group_sessions: Vec<InboundGroupSession>,
    pub outbound_group_sessions: Vec<OutboundGroupSession>,
    pub key_requests: Vec<GossipRequest>,
    pub identities: IdentityChanges,
    pub devices: DeviceChanges,
    /// Withheld-key events keyed by room id, then by a string key
    /// (presumably the session id; confirm with the store implementations).
    pub withheld_session_info: BTreeMap<OwnedRoomId, BTreeMap<String, RoomKeyWithheldEvent>>,
    pub room_settings: HashMap<OwnedRoomId, RoomSettings>,
    pub secrets: Vec<GossippedSecret>,
    pub next_batch_token: Option<String>,
}
545
/// A user whose device list we track, as persisted in the store.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrackedUser {
    /// The tracked user.
    pub user_id: OwnedUserId,
    /// Whether the user's device list is outdated. Dirty users are queued for
    /// a `/keys/query` when the tracked users are loaded from the store.
    pub dirty: bool,
}
557
558impl Changes {
559 pub fn is_empty(&self) -> bool {
561 self.private_identity.is_none()
562 && self.backup_version.is_none()
563 && self.backup_decryption_key.is_none()
564 && self.dehydrated_device_pickle_key.is_none()
565 && self.sessions.is_empty()
566 && self.message_hashes.is_empty()
567 && self.inbound_group_sessions.is_empty()
568 && self.outbound_group_sessions.is_empty()
569 && self.key_requests.is_empty()
570 && self.identities.is_empty()
571 && self.devices.is_empty()
572 && self.withheld_session_info.is_empty()
573 && self.room_settings.is_empty()
574 && self.secrets.is_empty()
575 && self.next_batch_token.is_none()
576 }
577}
578
/// User identities touched by a `/keys/query` response (or similar update).
///
/// `unchanged` identities are not saved (`is_empty` ignores them), but
/// `collect_device_updates` still uses them to look up device owners.
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct IdentityChanges {
    pub new: Vec<UserIdentityData>,
    pub changed: Vec<UserIdentityData>,
    pub unchanged: Vec<UserIdentityData>,
}
596
597impl IdentityChanges {
598 fn is_empty(&self) -> bool {
599 self.new.is_empty() && self.changed.is_empty()
600 }
601
602 fn into_maps(
605 self,
606 ) -> (
607 BTreeMap<OwnedUserId, UserIdentityData>,
608 BTreeMap<OwnedUserId, UserIdentityData>,
609 BTreeMap<OwnedUserId, UserIdentityData>,
610 ) {
611 let new: BTreeMap<_, _> = self
612 .new
613 .into_iter()
614 .map(|identity| (identity.user_id().to_owned(), identity))
615 .collect();
616
617 let changed: BTreeMap<_, _> = self
618 .changed
619 .into_iter()
620 .map(|identity| (identity.user_id().to_owned(), identity))
621 .collect();
622
623 let unchanged: BTreeMap<_, _> = self
624 .unchanged
625 .into_iter()
626 .map(|identity| (identity.user_id().to_owned(), identity))
627 .collect();
628
629 (new, changed, unchanged)
630 }
631}
632
/// Devices that were added, changed or deleted.
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct DeviceChanges {
    pub new: Vec<DeviceData>,
    pub changed: Vec<DeviceData>,
    pub deleted: Vec<DeviceData>,
}
640
641fn collect_device_updates(
647 verification_machine: VerificationMachine,
648 own_identity: Option<OwnUserIdentityData>,
649 identities: IdentityChanges,
650 devices: DeviceChanges,
651) -> DeviceUpdates {
652 let mut new: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
653 let mut changed: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
654
655 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
656
657 let map_device = |device: DeviceData| {
658 let device_owner_identity = new_identities
659 .get(device.user_id())
660 .or_else(|| changed_identities.get(device.user_id()))
661 .or_else(|| unchanged_identities.get(device.user_id()))
662 .cloned();
663
664 Device {
665 inner: device,
666 verification_machine: verification_machine.to_owned(),
667 own_identity: own_identity.to_owned(),
668 device_owner_identity,
669 }
670 };
671
672 for device in devices.new {
673 let device = map_device(device);
674
675 new.entry(device.user_id().to_owned())
676 .or_default()
677 .insert(device.device_id().to_owned(), device);
678 }
679
680 for device in devices.changed {
681 let device = map_device(device);
682
683 changed
684 .entry(device.user_id().to_owned())
685 .or_default()
686 .insert(device.device_id().to_owned(), device.to_owned());
687 }
688
689 DeviceUpdates { new, changed }
690}
691
/// New and changed devices, as produced by `collect_device_updates`.
#[derive(Clone, Debug, Default)]
pub struct DeviceUpdates {
    /// Newly seen devices, keyed by user id, then device id.
    pub new: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
    /// Devices that changed, keyed by user id, then device id.
    pub changed: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
}
705
/// User identities grouped by how they were affected by an update, each keyed
/// by user id.
#[derive(Clone, Debug, Default)]
pub struct IdentityUpdates {
    /// Identities seen for the first time.
    pub new: BTreeMap<OwnedUserId, UserIdentity>,
    /// Identities that changed.
    pub changed: BTreeMap<OwnedUserId, UserIdentity>,
    /// Identities that were seen again but did not change.
    pub unchanged: BTreeMap<OwnedUserId, UserIdentity>,
}
721
/// The private key used to decrypt server-side key backups.
///
/// The raw bytes are boxed, zeroized on drop, and (de)serialized transparently
/// as the inner value. `Debug` output is redacted.
#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
#[serde(transparent)]
pub struct BackupDecryptionKey {
    pub(crate) inner: Box<[u8; BackupDecryptionKey::KEY_SIZE]>,
}
736
737impl BackupDecryptionKey {
738 pub const KEY_SIZE: usize = 32;
740
741 pub fn new() -> Result<Self, rand::Error> {
743 let mut rng = rand::thread_rng();
744
745 let mut key = Box::new([0u8; Self::KEY_SIZE]);
746 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
747
748 Ok(Self { inner: key })
749 }
750
751 pub fn to_base64(&self) -> String {
753 base64_encode(self.inner.as_slice())
754 }
755}
756
757#[cfg(not(tarpaulin_include))]
758impl Debug for BackupDecryptionKey {
759 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
760 f.debug_tuple("BackupDecryptionKey").field(&"...").finish()
761 }
762}
763
/// The pickle key used to encrypt a dehydrated device.
///
/// The raw bytes are boxed, zeroized on drop, and (de)serialized transparently
/// as the inner value. `Debug` output is redacted.
#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
#[serde(transparent)]
pub struct DehydratedDeviceKey {
    pub(crate) inner: Box<[u8; DehydratedDeviceKey::KEY_SIZE]>,
}
773
774impl DehydratedDeviceKey {
775 pub const KEY_SIZE: usize = 32;
777
778 pub fn new() -> Result<Self, rand::Error> {
780 let mut rng = rand::thread_rng();
781
782 let mut key = Box::new([0u8; Self::KEY_SIZE]);
783 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
784
785 Ok(Self { inner: key })
786 }
787
788 pub fn from_slice(slice: &[u8]) -> Result<Self, DehydrationError> {
792 if slice.len() == 32 {
793 let mut key = Box::new([0u8; 32]);
794 key.copy_from_slice(slice);
795 Ok(DehydratedDeviceKey { inner: key })
796 } else {
797 Err(DehydrationError::PickleKeyLength(slice.len()))
798 }
799 }
800
801 pub fn from_bytes(raw_key: &[u8; 32]) -> Self {
803 let mut inner = Box::new([0u8; Self::KEY_SIZE]);
804 inner.copy_from_slice(raw_key);
805
806 Self { inner }
807 }
808
809 pub fn to_base64(&self) -> String {
811 base64_encode(self.inner.as_slice())
812 }
813}
814
815impl From<&[u8; 32]> for DehydratedDeviceKey {
816 fn from(value: &[u8; 32]) -> Self {
817 DehydratedDeviceKey { inner: Box::new(*value) }
818 }
819}
820
821impl From<DehydratedDeviceKey> for Vec<u8> {
822 fn from(key: DehydratedDeviceKey) -> Self {
823 key.inner.to_vec()
824 }
825}
826
827#[cfg(not(tarpaulin_include))]
828impl Debug for DehydratedDeviceKey {
829 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
830 f.debug_tuple("DehydratedDeviceKey").field(&"...").finish()
831 }
832}
833
834impl DeviceChanges {
835 pub fn extend(&mut self, other: DeviceChanges) {
837 self.new.extend(other.new);
838 self.changed.extend(other.changed);
839 self.deleted.extend(other.deleted);
840 }
841
842 fn is_empty(&self) -> bool {
843 self.new.is_empty() && self.changed.is_empty() && self.deleted.is_empty()
844 }
845}
846
/// Counts of room keys in the store.
#[derive(Debug, Clone, Default)]
pub struct RoomKeyCounts {
    /// Total number of room keys (presumably inbound group sessions; confirm).
    pub total: usize,
    /// How many of those room keys have been backed up.
    pub backed_up: usize,
}
855
/// Backup-related secrets stored in the crypto store.
#[derive(Default, Clone, Debug)]
pub struct BackupKeys {
    /// The backup decryption key, if one is stored.
    pub decryption_key: Option<BackupDecryptionKey>,
    /// The backup version the key belongs to, if known.
    pub backup_version: Option<String>,
}
864
/// Exported private cross-signing keys, as produced by
/// `Store::export_cross_signing_keys` (each key as returned by
/// `Store::export_secret`).
///
/// Zeroized on drop. `Debug` only reveals which keys are present.
#[derive(Default, Zeroize, ZeroizeOnDrop)]
pub struct CrossSigningKeyExport {
    /// The exported master key, if available.
    pub master_key: Option<String>,
    /// The exported self-signing key, if available.
    pub self_signing_key: Option<String>,
    /// The exported user-signing key, if available.
    pub user_signing_key: Option<String>,
}
876
877#[cfg(not(tarpaulin_include))]
878impl Debug for CrossSigningKeyExport {
879 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
880 f.debug_struct("CrossSigningKeyExport")
881 .field("master_key", &self.master_key.is_some())
882 .field("self_signing_key", &self.self_signing_key.is_some())
883 .field("user_signing_key", &self.user_signing_key.is_some())
884 .finish_non_exhaustive()
885 }
886}
887
/// Errors that can occur when importing secrets (cross-signing keys, backup
/// keys) into the store.
#[derive(Debug, Error)]
pub enum SecretImportError {
    /// An error reported by vodozemac while handling a key.
    #[error(transparent)]
    Key(#[from] vodozemac::KeyError),
    /// The imported private key does not match the public key that was
    /// uploaded to the server.
    #[error(
        "The public key of the imported private key doesn't match to the \
        public key that was uploaded to the server"
    )]
    MismatchedPublicKeys,
    /// A store error occurred while importing.
    #[error(transparent)]
    Store(#[from] CryptoStoreError),
}
906
/// Errors returned by `Store::export_secrets_bundle`.
#[derive(Debug, Error)]
pub enum SecretsBundleExportError {
    /// A store error occurred while collecting the secrets.
    #[error(transparent)]
    Store(#[from] CryptoStoreError),
    /// Some, but not all, cross-signing keys are present. Carries the usage
    /// of the first missing key found.
    #[error("The store is missing one or multiple cross-signing keys")]
    MissingCrossSigningKey(KeyUsage),
    /// None of the cross-signing keys are present.
    #[error("The store doesn't contain any cross-signing keys")]
    MissingCrossSigningKeys,
    /// A backup decryption key is stored without a matching backup version.
    #[error("The store contains a backup key, but no backup version")]
    MissingBackupVersion,
}
927
/// Outcome of `KeyQueryManager::wait_if_user_key_query_pending`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum UserKeyQueryResult {
    /// A key query was pending and completed while we waited.
    WasPending,
    /// No key query was pending for the user.
    WasNotPending,

    /// A key query was pending but did not complete before the timeout.
    TimeoutExpired,
}
938
/// Per-room encryption settings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct RoomSettings {
    /// The encryption algorithm used in the room.
    pub algorithm: EventEncryptionAlgorithm,

    /// Whether room keys should only be shared with trusted devices.
    pub only_allow_trusted_devices: bool,

    /// Optional time-based session rotation period (presumably the maximum
    /// age of an outbound session before it is rotated; confirm).
    pub session_rotation_period: Option<Duration>,

    /// Optional message-count-based rotation limit (presumably the number of
    /// messages after which an outbound session is rotated; confirm).
    pub session_rotation_period_messages: Option<usize>,
}
957
958impl Default for RoomSettings {
959 fn default() -> Self {
960 Self {
961 algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
962 only_allow_trusted_devices: false,
963 session_rotation_period: None,
964 session_rotation_period_messages: None,
965 }
966 }
967}
968
/// Identifying information about a room key (an inbound group session).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct RoomKeyInfo {
    /// The encryption algorithm of the session.
    pub algorithm: EventEncryptionAlgorithm,

    /// The room the session belongs to.
    pub room_id: OwnedRoomId,

    /// The Curve25519 sender key of the session.
    pub sender_key: Curve25519PublicKey,

    /// The session id.
    pub session_id: String,
}
987
988impl From<&InboundGroupSession> for RoomKeyInfo {
989 fn from(group_session: &InboundGroupSession) -> Self {
990 RoomKeyInfo {
991 algorithm: group_session.algorithm().clone(),
992 room_id: group_session.room_id().to_owned(),
993 sender_key: group_session.sender_key(),
994 session_id: group_session.session_id().to_owned(),
995 }
996 }
997}
998
/// Information about a room key that was withheld from us.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RoomKeyWithheldInfo {
    /// The room of the withheld session.
    pub room_id: OwnedRoomId,

    /// The id of the withheld session.
    pub session_id: String,

    /// The withheld event that was received.
    pub withheld_event: RoomKeyWithheldEvent,
}
1012
1013impl Store {
1014 pub(crate) fn new(
1016 account: StaticAccountData,
1017 identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
1018 store: Arc<CryptoStoreWrapper>,
1019 verification_machine: VerificationMachine,
1020 ) -> Self {
1021 Self {
1022 inner: Arc::new(StoreInner {
1023 static_account: account,
1024 identity,
1025 store: store.clone(),
1026 verification_machine,
1027 cache: Arc::new(RwLock::new(StoreCache {
1028 store,
1029 tracked_users: Default::default(),
1030 loaded_tracked_users: Default::default(),
1031 account: Default::default(),
1032 })),
1033 }),
1034 }
1035 }
1036
1037 pub(crate) fn user_id(&self) -> &UserId {
1039 &self.inner.static_account.user_id
1040 }
1041
1042 pub(crate) fn device_id(&self) -> &DeviceId {
1044 self.inner.verification_machine.own_device_id()
1045 }
1046
1047 pub(crate) fn static_account(&self) -> &StaticAccountData {
1049 &self.inner.static_account
1050 }
1051
1052 pub(crate) async fn cache(&self) -> Result<StoreCacheGuard> {
1053 Ok(StoreCacheGuard { cache: self.inner.cache.clone().read_owned().await })
1058 }
1059
1060 pub(crate) async fn transaction(&self) -> StoreTransaction {
1061 StoreTransaction::new(self.clone()).await
1062 }
1063
1064 pub(crate) async fn with_transaction<
1067 T,
1068 Fut: futures_core::Future<Output = Result<(StoreTransaction, T), crate::OlmError>>,
1069 F: FnOnce(StoreTransaction) -> Fut,
1070 >(
1071 &self,
1072 func: F,
1073 ) -> Result<T, crate::OlmError> {
1074 let tr = self.transaction().await;
1075 let (tr, res) = func(tr).await?;
1076 tr.commit().await?;
1077 Ok(res)
1078 }
1079
1080 #[cfg(test)]
1081 pub(crate) async fn reset_cross_signing_identity(&self) {
1083 self.inner.identity.lock().await.reset();
1084 }
1085
1086 pub(crate) fn private_identity(&self) -> Arc<Mutex<PrivateCrossSigningIdentity>> {
1088 self.inner.identity.clone()
1089 }
1090
1091 pub(crate) async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
1093 let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
1094
1095 self.save_changes(changes).await
1096 }
1097
1098 pub(crate) async fn get_sessions(
1099 &self,
1100 sender_key: &str,
1101 ) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
1102 self.inner.store.get_sessions(sender_key).await
1103 }
1104
1105 pub(crate) async fn save_changes(&self, changes: Changes) -> Result<()> {
1106 self.inner.store.save_changes(changes).await
1107 }
1108
1109 pub(crate) async fn compare_group_session(
1116 &self,
1117 session: &InboundGroupSession,
1118 ) -> Result<SessionOrdering> {
1119 let old_session = self
1120 .inner
1121 .store
1122 .get_inbound_group_session(session.room_id(), session.session_id())
1123 .await?;
1124
1125 Ok(if let Some(old_session) = old_session {
1126 session.compare(&old_session).await
1127 } else {
1128 SessionOrdering::Better
1129 })
1130 }
1131
1132 #[cfg(test)]
1133 pub(crate) async fn save_device_data(&self, devices: &[DeviceData]) -> Result<()> {
1135 let changes = Changes {
1136 devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
1137 ..Default::default()
1138 };
1139
1140 self.save_changes(changes).await
1141 }
1142
1143 pub(crate) async fn save_inbound_group_sessions(
1145 &self,
1146 sessions: &[InboundGroupSession],
1147 ) -> Result<()> {
1148 let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
1149
1150 self.save_changes(changes).await
1151 }
1152
1153 pub(crate) async fn device_display_name(&self) -> Result<Option<String>, CryptoStoreError> {
1155 Ok(self
1156 .inner
1157 .store
1158 .get_device(self.user_id(), self.device_id())
1159 .await?
1160 .and_then(|d| d.display_name().map(|d| d.to_owned())))
1161 }
1162
1163 pub(crate) async fn get_device_data(
1168 &self,
1169 user_id: &UserId,
1170 device_id: &DeviceId,
1171 ) -> Result<Option<DeviceData>> {
1172 self.inner.store.get_device(user_id, device_id).await
1173 }
1174
1175 pub(crate) async fn get_device_data_for_user_filtered(
1183 &self,
1184 user_id: &UserId,
1185 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1186 self.inner.store.get_user_devices(user_id).await.map(|mut d| {
1187 if user_id == self.user_id() {
1188 d.remove(self.device_id());
1189 }
1190 d
1191 })
1192 }
1193
1194 pub(crate) async fn get_device_data_for_user(
1203 &self,
1204 user_id: &UserId,
1205 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1206 self.inner.store.get_user_devices(user_id).await
1207 }
1208
1209 pub(crate) async fn get_device_from_curve_key(
1215 &self,
1216 user_id: &UserId,
1217 curve_key: Curve25519PublicKey,
1218 ) -> Result<Option<Device>> {
1219 self.get_user_devices(user_id)
1220 .await
1221 .map(|d| d.devices().find(|d| d.curve25519_key() == Some(curve_key)))
1222 }
1223
1224 pub(crate) async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
1234 let devices = self.get_device_data_for_user(user_id).await?;
1235
1236 let own_identity = self
1237 .inner
1238 .store
1239 .get_user_identity(self.user_id())
1240 .await?
1241 .and_then(|i| i.own().cloned());
1242 let device_owner_identity = self.inner.store.get_user_identity(user_id).await?;
1243
1244 Ok(UserDevices {
1245 inner: devices,
1246 verification_machine: self.inner.verification_machine.clone(),
1247 own_identity,
1248 device_owner_identity,
1249 })
1250 }
1251
1252 pub(crate) async fn get_device(
1262 &self,
1263 user_id: &UserId,
1264 device_id: &DeviceId,
1265 ) -> Result<Option<Device>> {
1266 if let Some(device_data) = self.inner.store.get_device(user_id, device_id).await? {
1267 Ok(Some(self.wrap_device_data(device_data).await?))
1268 } else {
1269 Ok(None)
1270 }
1271 }
1272
1273 pub(crate) async fn wrap_device_data(&self, device_data: DeviceData) -> Result<Device> {
1278 let own_identity = self
1279 .inner
1280 .store
1281 .get_user_identity(self.user_id())
1282 .await?
1283 .and_then(|i| i.own().cloned());
1284
1285 let device_owner_identity =
1286 self.inner.store.get_user_identity(device_data.user_id()).await?;
1287
1288 Ok(Device {
1289 inner: device_data,
1290 verification_machine: self.inner.verification_machine.clone(),
1291 own_identity,
1292 device_owner_identity,
1293 })
1294 }
1295
1296 pub(crate) async fn get_identity(&self, user_id: &UserId) -> Result<Option<UserIdentity>> {
1298 let own_identity = self
1299 .inner
1300 .store
1301 .get_user_identity(self.user_id())
1302 .await?
1303 .and_then(as_variant!(UserIdentityData::Own));
1304
1305 Ok(self.inner.store.get_user_identity(user_id).await?.map(|i| {
1306 UserIdentity::new(
1307 self.clone(),
1308 i,
1309 self.inner.verification_machine.to_owned(),
1310 own_identity,
1311 )
1312 }))
1313 }
1314
1315 pub async fn export_secret(
1324 &self,
1325 secret_name: &SecretName,
1326 ) -> Result<Option<String>, CryptoStoreError> {
1327 Ok(match secret_name {
1328 SecretName::CrossSigningMasterKey
1329 | SecretName::CrossSigningUserSigningKey
1330 | SecretName::CrossSigningSelfSigningKey => {
1331 self.inner.identity.lock().await.export_secret(secret_name).await
1332 }
1333 SecretName::RecoveryKey => {
1334 if let Some(key) = self.load_backup_keys().await?.decryption_key {
1335 let exported = key.to_base64();
1336 Some(exported)
1337 } else {
1338 None
1339 }
1340 }
1341 name => {
1342 warn!(secret = ?name, "Unknown secret was requested");
1343 None
1344 }
1345 })
1346 }
1347
1348 pub async fn export_cross_signing_keys(
1356 &self,
1357 ) -> Result<Option<CrossSigningKeyExport>, CryptoStoreError> {
1358 let master_key = self.export_secret(&SecretName::CrossSigningMasterKey).await?;
1359 let self_signing_key = self.export_secret(&SecretName::CrossSigningSelfSigningKey).await?;
1360 let user_signing_key = self.export_secret(&SecretName::CrossSigningUserSigningKey).await?;
1361
1362 Ok(if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() {
1363 None
1364 } else {
1365 Some(CrossSigningKeyExport { master_key, self_signing_key, user_signing_key })
1366 })
1367 }
1368
1369 pub async fn import_cross_signing_keys(
1374 &self,
1375 export: CrossSigningKeyExport,
1376 ) -> Result<CrossSigningStatus, SecretImportError> {
1377 if let Some(public_identity) =
1378 self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1379 {
1380 let identity = self.inner.identity.lock().await;
1381
1382 identity
1383 .import_secrets(
1384 public_identity.to_owned(),
1385 export.master_key.as_deref(),
1386 export.self_signing_key.as_deref(),
1387 export.user_signing_key.as_deref(),
1388 )
1389 .await?;
1390
1391 let status = identity.status().await;
1392
1393 let diff = identity.get_public_identity_diff(&public_identity.inner).await;
1394
1395 let mut changes =
1396 Changes { private_identity: Some(identity.clone()), ..Default::default() };
1397
1398 if diff.none_differ() {
1399 public_identity.mark_as_verified();
1400 changes.identities.changed.push(UserIdentityData::Own(public_identity.inner));
1401 }
1402
1403 info!(?status, "Successfully imported the private cross-signing keys");
1404
1405 self.save_changes(changes).await?;
1406 } else {
1407 warn!("No public identity found while importing cross-signing keys, a /keys/query needs to be done");
1408 }
1409
1410 Ok(self.inner.identity.lock().await.status().await)
1411 }
1412
1413 pub async fn export_secrets_bundle(&self) -> Result<SecretsBundle, SecretsBundleExportError> {
1425 let Some(cross_signing) = self.export_cross_signing_keys().await? else {
1426 return Err(SecretsBundleExportError::MissingCrossSigningKeys);
1427 };
1428
1429 let Some(master_key) = cross_signing.master_key.clone() else {
1430 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::Master));
1431 };
1432
1433 let Some(user_signing_key) = cross_signing.user_signing_key.clone() else {
1434 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::UserSigning));
1435 };
1436
1437 let Some(self_signing_key) = cross_signing.self_signing_key.clone() else {
1438 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::SelfSigning));
1439 };
1440
1441 let backup_keys = self.load_backup_keys().await?;
1442
1443 let backup = if let Some(key) = backup_keys.decryption_key {
1444 if let Some(backup_version) = backup_keys.backup_version {
1445 Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(
1446 MegolmBackupV1Curve25519AesSha2Secrets { key, backup_version },
1447 ))
1448 } else {
1449 return Err(SecretsBundleExportError::MissingBackupVersion);
1450 }
1451 } else {
1452 None
1453 };
1454
1455 Ok(SecretsBundle {
1456 cross_signing: CrossSigningSecrets { master_key, user_signing_key, self_signing_key },
1457 backup,
1458 })
1459 }
1460
1461 pub async fn import_secrets_bundle(
1474 &self,
1475 bundle: &SecretsBundle,
1476 ) -> Result<(), SecretImportError> {
1477 let mut changes = Changes::default();
1478
1479 if let Some(backup_bundle) = &bundle.backup {
1480 match backup_bundle {
1481 BackupSecrets::MegolmBackupV1Curve25519AesSha2(bundle) => {
1482 changes.backup_decryption_key = Some(bundle.key.clone());
1483 changes.backup_version = Some(bundle.backup_version.clone());
1484 }
1485 }
1486 }
1487
1488 let identity = self.inner.identity.lock().await;
1489
1490 identity
1491 .import_secrets_unchecked(
1492 Some(&bundle.cross_signing.master_key),
1493 Some(&bundle.cross_signing.self_signing_key),
1494 Some(&bundle.cross_signing.user_signing_key),
1495 )
1496 .await?;
1497
1498 let public_identity = identity.to_public_identity().await.expect(
1499 "We should be able to create a new public identity since we just imported \
1500 all the private cross-signing keys",
1501 );
1502
1503 changes.private_identity = Some(identity.clone());
1504 changes.identities.new.push(UserIdentityData::Own(public_identity));
1505
1506 Ok(self.save_changes(changes).await?)
1507 }
1508
1509 pub async fn import_secret(&self, secret: &GossippedSecret) -> Result<(), SecretImportError> {
1511 match &secret.secret_name {
1512 SecretName::CrossSigningMasterKey
1513 | SecretName::CrossSigningUserSigningKey
1514 | SecretName::CrossSigningSelfSigningKey => {
1515 if let Some(public_identity) =
1516 self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1517 {
1518 let identity = self.inner.identity.lock().await;
1519
1520 identity
1521 .import_secret(
1522 public_identity,
1523 &secret.secret_name,
1524 &secret.event.content.secret,
1525 )
1526 .await?;
1527 info!(
1528 secret_name = ?secret.secret_name,
1529 "Successfully imported a private cross signing key"
1530 );
1531
1532 let changes =
1533 Changes { private_identity: Some(identity.clone()), ..Default::default() };
1534
1535 self.save_changes(changes).await?;
1536 }
1537 }
1538 SecretName::RecoveryKey => {
1539 }
1545 name => {
1546 warn!(secret = ?name, "Tried to import an unknown secret");
1547 }
1548 }
1549
1550 Ok(())
1551 }
1552
1553 pub async fn get_only_allow_trusted_devices(&self) -> Result<bool> {
1556 let value = self.get_value("only_allow_trusted_devices").await?.unwrap_or_default();
1557 Ok(value)
1558 }
1559
1560 pub async fn set_only_allow_trusted_devices(
1563 &self,
1564 block_untrusted_devices: bool,
1565 ) -> Result<()> {
1566 self.set_value("only_allow_trusted_devices", &block_untrusted_devices).await
1567 }
1568
1569 pub async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
1571 let Some(value) = self.get_custom_value(key).await? else {
1572 return Ok(None);
1573 };
1574 let deserialized = self.deserialize_value(&value)?;
1575 Ok(Some(deserialized))
1576 }
1577
1578 pub async fn set_value(&self, key: &str, value: &impl Serialize) -> Result<()> {
1580 let serialized = self.serialize_value(value)?;
1581 self.set_custom_value(key, serialized).await?;
1582 Ok(())
1583 }
1584
1585 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
1586 let serialized =
1587 rmp_serde::to_vec_named(value).map_err(|x| CryptoStoreError::Backend(x.into()))?;
1588 Ok(serialized)
1589 }
1590
1591 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
1592 let deserialized =
1593 rmp_serde::from_slice(value).map_err(|e| CryptoStoreError::Backend(e.into()))?;
1594 Ok(deserialized)
1595 }
1596
1597 pub fn room_keys_received_stream(
1609 &self,
1610 ) -> impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>> {
1611 self.inner.store.room_keys_received_stream()
1612 }
1613
1614 pub fn room_keys_withheld_received_stream(
1623 &self,
1624 ) -> impl Stream<Item = Vec<RoomKeyWithheldInfo>> {
1625 self.inner.store.room_keys_withheld_received_stream()
1626 }
1627
1628 pub fn user_identities_stream(&self) -> impl Stream<Item = IdentityUpdates> {
1659 let verification_machine = self.inner.verification_machine.to_owned();
1660
1661 let this = self.clone();
1662 self.inner.store.identities_stream().map(move |(own_identity, identities, _)| {
1663 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
1664
1665 let map_identity = |(user_id, identity)| {
1666 (
1667 user_id,
1668 UserIdentity::new(
1669 this.clone(),
1670 identity,
1671 verification_machine.to_owned(),
1672 own_identity.to_owned(),
1673 ),
1674 )
1675 };
1676
1677 let new = new_identities.into_iter().map(map_identity).collect();
1678 let changed = changed_identities.into_iter().map(map_identity).collect();
1679 let unchanged = unchanged_identities.into_iter().map(map_identity).collect();
1680
1681 IdentityUpdates { new, changed, unchanged }
1682 })
1683 }
1684
1685 pub fn devices_stream(&self) -> impl Stream<Item = DeviceUpdates> {
1717 let verification_machine = self.inner.verification_machine.to_owned();
1718
1719 self.inner.store.identities_stream().map(move |(own_identity, identities, devices)| {
1720 collect_device_updates(
1721 verification_machine.to_owned(),
1722 own_identity,
1723 identities,
1724 devices,
1725 )
1726 })
1727 }
1728
1729 pub fn identities_stream_raw(&self) -> impl Stream<Item = (IdentityChanges, DeviceChanges)> {
1739 self.inner.store.identities_stream().map(|(_, identities, devices)| (identities, devices))
1740 }
1741
1742 pub fn create_store_lock(
1745 &self,
1746 lock_key: String,
1747 lock_value: String,
1748 ) -> CrossProcessStoreLock<LockableCryptoStore> {
1749 self.inner.store.create_store_lock(lock_key, lock_value)
1750 }
1751
1752 pub fn secrets_stream(&self) -> impl Stream<Item = GossippedSecret> {
1792 self.inner.store.secrets_stream()
1793 }
1794
1795 pub async fn import_room_keys(
1808 &self,
1809 exported_keys: Vec<ExportedRoomKey>,
1810 from_backup_version: Option<&str>,
1811 progress_listener: impl Fn(usize, usize),
1812 ) -> Result<RoomKeyImportResult> {
1813 let mut sessions = Vec::new();
1814
1815 async fn new_session_better(
1816 session: &InboundGroupSession,
1817 old_session: Option<InboundGroupSession>,
1818 ) -> bool {
1819 if let Some(old_session) = &old_session {
1820 session.compare(old_session).await == SessionOrdering::Better
1821 } else {
1822 true
1823 }
1824 }
1825
1826 let total_count = exported_keys.len();
1827 let mut keys = BTreeMap::new();
1828
1829 for (i, key) in exported_keys.into_iter().enumerate() {
1830 match InboundGroupSession::from_export(&key) {
1831 Ok(session) => {
1832 let old_session = self
1833 .inner
1834 .store
1835 .get_inbound_group_session(session.room_id(), session.session_id())
1836 .await?;
1837
1838 if new_session_better(&session, old_session).await {
1841 if from_backup_version.is_some() {
1842 session.mark_as_backed_up();
1843 }
1844
1845 keys.entry(session.room_id().to_owned())
1846 .or_insert_with(BTreeMap::new)
1847 .entry(session.sender_key().to_base64())
1848 .or_insert_with(BTreeSet::new)
1849 .insert(session.session_id().to_owned());
1850
1851 sessions.push(session);
1852 }
1853 }
1854 Err(e) => {
1855 warn!(
1856 sender_key= key.sender_key.to_base64(),
1857 room_id = ?key.room_id,
1858 session_id = key.session_id,
1859 error = ?e,
1860 "Couldn't import a room key from a file export."
1861 );
1862 }
1863 }
1864
1865 progress_listener(i, total_count);
1866 }
1867
1868 let imported_count = sessions.len();
1869
1870 self.inner.store.save_inbound_group_sessions(sessions, from_backup_version).await?;
1871
1872 info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys");
1873
1874 Ok(RoomKeyImportResult::new(imported_count, total_count, keys))
1875 }
1876
1877 pub async fn import_exported_room_keys(
1904 &self,
1905 exported_keys: Vec<ExportedRoomKey>,
1906 progress_listener: impl Fn(usize, usize),
1907 ) -> Result<RoomKeyImportResult> {
1908 self.import_room_keys(exported_keys, None, progress_listener).await
1909 }
1910
1911 pub(crate) fn crypto_store(&self) -> Arc<CryptoStoreWrapper> {
1912 self.inner.store.clone()
1913 }
1914
1915 pub async fn export_room_keys(
1938 &self,
1939 predicate: impl FnMut(&InboundGroupSession) -> bool,
1940 ) -> Result<Vec<ExportedRoomKey>> {
1941 let mut exported = Vec::new();
1942
1943 let mut sessions = self.get_inbound_group_sessions().await?;
1944 sessions.retain(predicate);
1945
1946 for session in sessions {
1947 let export = session.export().await;
1948 exported.push(export);
1949 }
1950
1951 Ok(exported)
1952 }
1953
1954 pub async fn export_room_keys_stream(
1987 &self,
1988 predicate: impl FnMut(&InboundGroupSession) -> bool,
1989 ) -> Result<impl Stream<Item = ExportedRoomKey>> {
1990 let sessions = self.get_inbound_group_sessions().await?;
1992 Ok(futures_util::stream::iter(sessions.into_iter().filter(predicate))
1993 .then(|session| async move { session.export().await }))
1994 }
1995
1996 pub async fn build_room_key_bundle(
2001 &self,
2002 room_id: &RoomId,
2003 ) -> std::result::Result<RoomKeyBundle, CryptoStoreError> {
2004 let mut sessions = self.get_inbound_group_sessions().await?;
2007 sessions.retain(|session| session.room_id == room_id);
2008
2009 let mut bundle = RoomKeyBundle::default();
2010 for session in sessions {
2011 if session.shared_history() {
2012 bundle.room_keys.push(session.export().await.into());
2013 } else {
2014 bundle.withheld.push(RoomKeyWithheldContent::new(
2015 session.algorithm().to_owned(),
2016 WithheldCode::Unauthorised,
2017 session.room_id().to_owned(),
2018 session.session_id().to_owned(),
2019 session.sender_key().to_owned(),
2020 self.device_id().to_owned(),
2021 ));
2022 }
2023 }
2024
2025 Ok(bundle)
2026 }
2027}
2028
2029impl Deref for Store {
2030 type Target = DynCryptoStore;
2031
2032 fn deref(&self) -> &Self::Target {
2033 self.inner.store.deref().deref()
2034 }
2035}
2036
/// A shared handle to a crypto store trait object that implements
/// `matrix_sdk_common::store_locks::BackingStore`, so that it can back the
/// [`CrossProcessStoreLock`] returned by `Store::create_store_lock`.
#[derive(Clone, Debug)]
pub struct LockableCryptoStore(Arc<dyn CryptoStore<Error = CryptoStoreError>>);
2040
2041#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
2042#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
2043impl matrix_sdk_common::store_locks::BackingStore for LockableCryptoStore {
2044 type LockError = CryptoStoreError;
2045
2046 async fn try_lock(
2047 &self,
2048 lease_duration_ms: u32,
2049 key: &str,
2050 holder: &str,
2051 ) -> std::result::Result<bool, Self::LockError> {
2052 self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
2053 }
2054}
2055
// Tests for the room-key import/export helpers, the secrets bundle round
// trip, dehydrated device keys and room key bundles of `Store`.
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_util::StreamExt;
    use insta::{_macro_support::Content, assert_json_snapshot, internals::ContentPath};
    use matrix_sdk_test::async_test;
    use ruma::{device_id, room_id, user_id, RoomId};
    use vodozemac::megolm::SessionKey;

    use crate::{
        machine::test_helpers::get_machine_pair,
        olm::{InboundGroupSession, SenderData},
        store::DehydratedDeviceKey,
        types::EventEncryptionAlgorithm,
        OlmMachine,
    };

    /// Importing room keys should emit an update on the importing store's
    /// `room_keys_received_stream`.
    #[async_test]
    async fn test_import_room_keys_notifies_stream() {
        use futures_util::FutureExt;

        let (alice, bob, _) =
            get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;

        let room1_id = room_id!("!room1:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        let exported_sessions = alice.store().export_room_keys(|_| true).await.unwrap();

        // The stream is created before the import, presumably so that the
        // update is not missed.
        let mut room_keys_received_stream = Box::pin(bob.store().room_keys_received_stream());
        bob.store().import_room_keys(exported_sessions, None, |_, _| {}).await.unwrap();

        // `now_or_never` requires the update to be available immediately.
        let room_keys = room_keys_received_stream
            .next()
            .now_or_never()
            .flatten()
            .expect("We should have received an update of room key infos")
            .unwrap();
        assert_eq!(room_keys.len(), 1);
        assert_eq!(room_keys[0].room_id, "!room1:localhost");
    }

    /// `export_room_keys` should only return the sessions accepted by the
    /// predicate.
    #[async_test]
    async fn test_export_room_keys_provides_selected_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        // Create sessions in three rooms, then export only rooms 2 and 3.
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        let room3_id = room_id!("!room3:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room3_id).await.unwrap();

        let keys = alice
            .store()
            .export_room_keys(|s| s.room_id() == room2_id || s.room_id() == room3_id)
            .await
            .unwrap();

        assert_eq!(keys.len(), 2);
        assert_eq!(keys[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(keys[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(keys[0].room_id, "!room2:localhost");
        assert_eq!(keys[1].room_id, "!room3:localhost");
        assert_eq!(keys[0].session_key.to_base64().len(), 220);
        assert_eq!(keys[1].session_key.to_base64().len(), 220);
    }

    /// With an always-true predicate, `export_room_keys_stream` should yield
    /// every session.
    #[async_test]
    async fn test_export_room_keys_stream_can_provide_all_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();

        // The returned stream is pinned so that `next()` can be called on it.
        let mut keys = pin!(alice.store().export_room_keys_stream(|_| true).await.unwrap());

        let mut collected = vec![];
        while let Some(key) = keys.next().await {
            collected.push(key);
        }

        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[0].room_id, "!room1:localhost");
        assert_eq!(collected[1].room_id, "!room2:localhost");
        assert_eq!(collected[0].session_key.to_base64().len(), 220);
        assert_eq!(collected[1].session_key.to_base64().len(), 220);
    }

    /// `export_room_keys_stream` should only yield the sessions accepted by
    /// the predicate.
    #[async_test]
    async fn test_export_room_keys_stream_can_provide_a_subset_of_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();

        // Only select the session of room 1.
        let mut keys =
            pin!(alice.store().export_room_keys_stream(|s| s.room_id() == room1_id).await.unwrap());

        let mut collected = vec![];
        while let Some(key) = keys.next().await {
            collected.push(key);
        }

        assert_eq!(collected.len(), 1);
        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[0].room_id, "!room1:localhost");
        assert_eq!(collected[0].session_key.to_base64().len(), 220);
    }

    /// A secrets bundle exported from one machine and imported into another
    /// machine of the same user should give the second machine complete,
    /// verified cross-signing.
    #[async_test]
    async fn test_export_secrets_bundle() {
        let user_id = user_id!("@alice:example.com");
        // Two machines with the same user ID.
        let (first, second, _) = get_machine_pair(user_id, user_id, false).await;

        let _ = first
            .bootstrap_cross_signing(false)
            .await
            .expect("We should be able to bootstrap cross-signing");

        let bundle = first.store().export_secrets_bundle().await.expect(
            "We should be able to export the secrets bundle, now that we \
             have the cross-signing keys",
        );

        // No backup was set up, so the bundle only carries cross-signing keys.
        assert!(bundle.backup.is_none(), "The bundle should not contain a backup key");

        second
            .store()
            .import_secrets_bundle(&bundle)
            .await
            .expect("We should be able to import the secrets bundle");

        let status = second.cross_signing_status().await;
        let identity = second.get_identity(user_id, None).await.unwrap().unwrap().own().unwrap();

        assert!(identity.is_verified(), "The public identity should be marked as verified.");

        assert!(status.is_complete(), "We should have imported all the cross-signing keys");
    }

    /// A dehydrated device key should survive a round trip through its raw
    /// bytes.
    #[async_test]
    async fn test_create_dehydrated_device_key() {
        let pickle_key = DehydratedDeviceKey::new()
            .expect("Should be able to create a random dehydrated device key");

        let to_vec = pickle_key.inner.to_vec();
        let pickle_key_from_slice = DehydratedDeviceKey::from_slice(to_vec.as_slice())
            .expect("Should be able to create a dehydrated device key from slice");

        assert_eq!(pickle_key_from_slice.to_base64(), pickle_key.to_base64());
    }

    /// Slices that are too short or too long should be rejected by
    /// `DehydratedDeviceKey::from_slice`.
    #[async_test]
    async fn test_create_dehydrated_errors() {
        let too_small = [0u8; 22];
        let pickle_key = DehydratedDeviceKey::from_slice(&too_small);

        assert!(pickle_key.is_err());

        let too_big = [0u8; 40];
        let pickle_key = DehydratedDeviceKey::from_slice(&too_big);

        assert!(pickle_key.is_err());
    }

    /// `build_room_key_bundle` should export the shared-history sessions of
    /// the requested room, list its other sessions as withheld, and ignore
    /// sessions of other rooms. The result is compared against a snapshot.
    #[async_test]
    async fn test_build_room_key_bundle() {
        let alice = OlmMachine::new(user_id!("@a:s.co"), device_id!("ALICE")).await;
        let bob = OlmMachine::new(user_id!("@b:s.co"), device_id!("BOB")).await;

        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");

        // Fixed session keys, so the bundle contents stay stable for the snapshot.
        let session_key1 = "AgAAAAC2XHVzsMBKs4QCRElJ92CJKyGtknCSC8HY7cQ7UYwndMKLQAejXLh5UA0l6s736mgctcUMNvELScUWrObdflrHo+vth/gWreXOaCnaSxmyjjKErQwyIYTkUfqbHy40RJfEesLwnN23on9XAkch/iy8R2+Jz7B8zfG01f2Ow2SxPQFnAndcO1ZSD2GmXgedy6n4B20MWI1jGP2wiexOWbFSya8DO/VxC9m5+/mF+WwYqdpKn9g4Y05Yw4uz7cdjTc3rXm7xK+8E7hI//5QD1nHPvuKYbjjM9u2JSL+Bzp61Cw";
        let session_key2 = "AgAAAAC1BXreFTUQQSBGekTEuYxhdytRKyv4JgDGcG+VOBYdPNGgs807SdibCGJky4lJ3I+7ZDGHoUzZPZP/4ogGu4kxni0PWdtWuN7+5zsuamgoFF/BkaGeUUGv6kgIkx8pyPpM5SASTUEP9bN2loDSpUPYwfiIqz74DgC4WQ4435sTBctYvKz8n+TDJwdLXpyT6zKljuqADAioud+s/iqx9LYn9HpbBfezZcvbg67GtE113pLrvde3IcPI5s6dNHK2onGO2B2eoaobcen18bbEDnlUGPeIivArLya7Da6us14jBQ";
        let session_key3 = "AgAAAAAM9KFsliaUUhGSXgwOzM5UemjkNH4n8NHgvC/y8hhw13zTF+ooGD4uIYEXYX630oNvQm/EvgZo+dkoc0re+vsqsx4sQeNODdSjcBsWOa0oDF+irQn9oYoLUDPI1IBtY1rX+FV99Zm/xnG7uFOX7aTVlko2GSdejy1w9mfobmfxu5aUc04A9zaKJP1pOthZvRAlhpymGYHgsDtWPrrjyc/yypMflE4kIUEEEtu1kT6mrAmcl615XYRAHYK9G2+fZsGvokwzbkl4nulGwcZMpQEoM0nD2o3GWgX81HW3nGfKBg";
        let session_key4 = "AgAAAAA4Kkesxq2h4v9PLD6Sm3Smxspz1PXTqytQPCMQMkkrHNmzV2bHlJ+6/Al9cu8vh1Oj69AK0WUAeJOJuaiskEeg/PI3P03+UYLeC379RzgqwSHdBgdQ41G2vD6zpgmE/8vYToe+qpCZACtPOswZxyqxHH+T/Iq0nv13JmlFGIeA6fEPfr5Y28B49viG74Fs9rxV9EH5PfjbuPM/p+Sz5obShuaBPKQBX1jT913nEXPoIJ06exNZGr0285nw/LgVvNlmWmbqNnbzO2cNZjQWA+xZYz5FSfyCxwqEBbEdUCuRCQ";

        // Three sessions in room 1 (two with shared history, one without) and
        // one shared-history session in room 2, all created from Alice's keys.
        let sessions = [
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key1).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key2).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key3).unwrap(),
                false,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room2_id,
                &SessionKey::from_base64(session_key4).unwrap(),
                true,
            ),
        ];
        bob.store().save_inbound_group_sessions(&sessions).await.unwrap();

        let mut bundle = bob.store().build_room_key_bundle(room1_id).await.unwrap();

        // Sort the exported keys by session ID to make the snapshot order
        // deterministic.
        bundle.room_keys.sort_by_key(|session| session.session_id.clone());

        // Alice's identity keys presumably differ between runs. These
        // redactions check that the bundle contains them and then replace
        // them with fixed placeholders in the snapshot.
        let alice_curve_key = alice.identity_keys().curve25519.to_base64();
        let map_alice_curve_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_curve_key);
            "[alice curve key]"
        };
        let alice_ed25519_key = alice.identity_keys().ed25519.to_base64();
        let map_alice_ed25519_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_ed25519_key);
            "[alice ed25519 key]"
        };

        insta::with_settings!({ sort_maps => true }, {
            assert_json_snapshot!(bundle, {
                ".room_keys[].sender_key" => insta::dynamic_redaction(map_alice_curve_key.clone()),
                ".withheld[].sender_key" => insta::dynamic_redaction(map_alice_curve_key),
                ".room_keys[].sender_claimed_keys.ed25519" => insta::dynamic_redaction(map_alice_ed25519_key),
            });
        });
    }

    /// Create an [`InboundGroupSession`] for `room_id` from `session_key`,
    /// using `olm_machine`'s identity keys as sender keys, unknown sender
    /// data, the `MegolmV1AesSha2` algorithm, and the given `shared_history`
    /// flag.
    ///
    /// Panics if the session cannot be created.
    fn create_inbound_group_session_with_visibility(
        olm_machine: &OlmMachine,
        room_id: &RoomId,
        session_key: &SessionKey,
        shared_history: bool,
    ) -> InboundGroupSession {
        let identity_keys = &olm_machine.store().static_account().identity_keys;
        InboundGroupSession::new(
            identity_keys.curve25519,
            identity_keys.ed25519,
            room_id,
            session_key,
            SenderData::unknown(),
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            None,
            shared_history,
        )
        .unwrap()
    }
}