1use std::{
42 collections::{BTreeMap, BTreeSet, HashMap, HashSet},
43 fmt::Debug,
44 ops::Deref,
45 pin::pin,
46 sync::{atomic::Ordering, Arc},
47 time::Duration,
48};
49
50use as_variant::as_variant;
51use futures_core::Stream;
52use futures_util::StreamExt;
53use matrix_sdk_common::locks::RwLock as StdRwLock;
54use ruma::{
55 encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
56 OwnedRoomId, OwnedUserId, UserId,
57};
58use serde::{de::DeserializeOwned, Deserialize, Serialize};
59use thiserror::Error;
60use tokio::sync::{Mutex, MutexGuard, Notify, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
61use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
62use tracing::{info, warn};
63use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
64use zeroize::{Zeroize, ZeroizeOnDrop};
65
66#[cfg(doc)]
67use crate::{backups::BackupMachine, identities::OwnUserIdentity};
68use crate::{
69 gossiping::GossippedSecret,
70 identities::{user::UserIdentity, Device, DeviceData, UserDevices, UserIdentityData},
71 olm::{
72 Account, ExportedRoomKey, InboundGroupSession, OlmMessageHash, OutboundGroupSession,
73 PrivateCrossSigningIdentity, Session, StaticAccountData,
74 },
75 types::{
76 events::room_key_withheld::RoomKeyWithheldEvent, BackupSecrets, CrossSigningSecrets,
77 EventEncryptionAlgorithm, MegolmBackupV1Curve25519AesSha2Secrets, SecretsBundle,
78 },
79 verification::VerificationMachine,
80 CrossSigningStatus, OwnUserIdentityData, RoomKeyImportResult,
81};
82
83pub mod caches;
84mod crypto_store_wrapper;
85mod error;
86mod memorystore;
87mod traits;
88
89#[cfg(any(test, feature = "testing"))]
90#[macro_use]
91#[allow(missing_docs)]
92pub mod integration_tests;
93
94use caches::{SequenceNumber, UsersForKeyQuery};
95pub(crate) use crypto_store_wrapper::CryptoStoreWrapper;
96pub use error::{CryptoStoreError, Result};
97use matrix_sdk_common::{store_locks::CrossProcessStoreLock, timeout::timeout};
98pub use memorystore::MemoryStore;
99pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
100
101pub use crate::{
102 dehydrated_devices::DehydrationError,
103 gossiping::{GossipRequest, SecretInfo},
104};
105
106#[derive(Debug, Clone)]
113pub struct Store {
114 inner: Arc<StoreInner>,
115}
116
117#[derive(Debug, Default)]
118pub(crate) struct KeyQueryManager {
119 users_for_key_query: Mutex<UsersForKeyQuery>,
121
122 users_for_key_query_notify: Notify,
124}
125
impl KeyQueryManager {
    /// Make sure the tracked users have been loaded from the store into
    /// `cache`, then return a [`SyncedKeyQueryManager`] pairing this manager
    /// with that cache.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while loading the tracked users.
    pub async fn synced<'a>(&'a self, cache: &'a StoreCache) -> Result<SyncedKeyQueryManager<'a>> {
        self.ensure_sync_tracked_users(cache).await?;
        Ok(SyncedKeyQueryManager { cache, manager: self })
    }

    /// Load the tracked users from the store into the cache, at most once.
    ///
    /// Users persisted as dirty are additionally queued for a `/keys/query`.
    /// Once the `loaded_tracked_users` flag is set, later calls return early.
    async fn ensure_sync_tracked_users(&self, cache: &StoreCache) -> Result<()> {
        // Fast path: check the flag under the shared (read) lock only.
        let loaded = cache.loaded_tracked_users.read().await;
        if *loaded {
            return Ok(());
        }

        // Release the read lock before asking for the write lock.
        drop(loaded);
        let mut loaded = cache.loaded_tracked_users.write().await;

        // Re-check: another task may have done the loading between dropping
        // the read lock and acquiring the write lock.
        if *loaded {
            return Ok(());
        }

        let tracked_users = cache.store.load_tracked_users().await?;

        let mut query_users_lock = self.users_for_key_query.lock().await;
        let mut tracked_users_cache = cache.tracked_users.write();
        for user in tracked_users {
            tracked_users_cache.insert(user.user_id.to_owned());

            // A dirty user still needs its device list to be (re-)queried.
            if user.dirty {
                query_users_lock.insert_user(&user.user_id);
            }
        }

        *loaded = true;

        Ok(())
    }

    /// Wait, for at most `timeout_duration`, for a pending `/keys/query` that
    /// covers `user`.
    ///
    /// The `cache` guard is consumed and dropped before waiting. It is a read
    /// lock on the store cache, and a `StoreTransaction` needs the write lock
    /// on that same cache.
    ///
    /// Returns [`UserKeyQueryResult::WasNotPending`] if no query was pending
    /// for the user, [`UserKeyQueryResult::WasPending`] if the pending query
    /// completed in time, and [`UserKeyQueryResult::TimeoutExpired`] (after
    /// logging a warning) otherwise.
    pub async fn wait_if_user_key_query_pending(
        &self,
        cache: StoreCacheGuard,
        timeout_duration: Duration,
        user: &UserId,
    ) -> Result<UserKeyQueryResult> {
        {
            self.ensure_sync_tracked_users(&cache).await?;
            // Don't hold the cache read lock for the whole wait.
            drop(cache);
        }

        let mut users_for_key_query = self.users_for_key_query.lock().await;
        let Some(waiter) = users_for_key_query.maybe_register_waiting_task(user) else {
            return Ok(UserKeyQueryResult::WasNotPending);
        };

        let wait_for_completion = async {
            while !waiter.completed.load(Ordering::Relaxed) {
                // Register interest in the notification *before* releasing the
                // lock. A `notify_waiters()` issued between the unlock and the
                // `.await` below is then not lost.
                let mut notified = pin!(self.users_for_key_query_notify.notified());
                notified.as_mut().enable();
                drop(users_for_key_query);

                notified.await;

                // Re-acquire the lock before re-checking the completion flag.
                users_for_key_query = self.users_for_key_query.lock().await;
            }
        };

        match timeout(Box::pin(wait_for_completion), timeout_duration).await {
            Err(_) => {
                warn!(
                    user_id = ?user,
                    "The user has a pending `/keys/query` request which did \
                    not finish yet, some devices might be missing."
                );

                Ok(UserKeyQueryResult::TimeoutExpired)
            }
            _ => Ok(UserKeyQueryResult::WasPending),
        }
    }
}
233
234pub(crate) struct SyncedKeyQueryManager<'a> {
235 cache: &'a StoreCache,
236 manager: &'a KeyQueryManager,
237}
238
239impl SyncedKeyQueryManager<'_> {
240 pub async fn update_tracked_users(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
245 let mut store_updates = Vec::new();
246 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
247
248 {
249 let mut tracked_users = self.cache.tracked_users.write();
250 for user_id in users {
251 if tracked_users.insert(user_id.to_owned()) {
252 key_query_lock.insert_user(user_id);
253 store_updates.push((user_id, true))
254 }
255 }
256 }
257
258 self.cache.store.save_tracked_users(&store_updates).await
259 }
260
261 pub async fn mark_tracked_users_as_changed(
268 &self,
269 users: impl Iterator<Item = &UserId>,
270 ) -> Result<()> {
271 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
272 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
273
274 {
275 let tracked_users = &self.cache.tracked_users.read();
276 for user_id in users {
277 if tracked_users.contains(user_id) {
278 key_query_lock.insert_user(user_id);
279 store_updates.push((user_id, true));
280 }
281 }
282 }
283
284 self.cache.store.save_tracked_users(&store_updates).await
285 }
286
287 pub async fn mark_tracked_users_as_up_to_date(
293 &self,
294 users: impl Iterator<Item = &UserId>,
295 sequence_number: SequenceNumber,
296 ) -> Result<()> {
297 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
298 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
299
300 {
301 let tracked_users = self.cache.tracked_users.read();
302 for user_id in users {
303 if tracked_users.contains(user_id) {
304 let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
305 store_updates.push((user_id, !clean));
306 }
307 }
308 }
309
310 self.cache.store.save_tracked_users(&store_updates).await?;
311 self.manager.users_for_key_query_notify.notify_waiters();
313
314 Ok(())
315 }
316
317 pub async fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
329 self.manager.users_for_key_query.lock().await.users_for_key_query()
330 }
331
332 pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
334 self.cache.tracked_users.read().iter().cloned().collect()
335 }
336
337 pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
343 self.manager.users_for_key_query.lock().await.insert_user(user);
344 self.cache.tracked_users.write().insert(user.to_owned());
345
346 self.cache.store.save_tracked_users(&[(user, true)]).await
347 }
348}
349
350#[derive(Debug)]
351pub(crate) struct StoreCache {
352 store: Arc<CryptoStoreWrapper>,
353 tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
354 loaded_tracked_users: RwLock<bool>,
355 account: Mutex<Option<Account>>,
356}
357
358impl StoreCache {
359 pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
360 self.store.as_ref()
361 }
362
363 async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
375 let mut guard = self.account.lock().await;
376 if guard.is_some() {
377 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
378 } else {
379 match self.store.load_account().await? {
380 Some(account) => {
381 *guard = Some(account);
382 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
383 }
384 None => Err(CryptoStoreError::AccountUnset),
385 }
386 }
387 }
388}
389
390pub(crate) struct StoreCacheGuard {
396 cache: OwnedRwLockReadGuard<StoreCache>,
397 }
399
400impl StoreCacheGuard {
401 pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
409 self.cache.account().await
410 }
411}
412
413impl Deref for StoreCacheGuard {
414 type Target = StoreCache;
415
416 fn deref(&self) -> &Self::Target {
417 &self.cache
418 }
419}
420
421#[allow(missing_debug_implementations)]
423pub struct StoreTransaction {
424 store: Store,
425 changes: PendingChanges,
426 cache: OwnedRwLockWriteGuard<StoreCache>,
428}
429
430impl StoreTransaction {
431 async fn new(store: Store) -> Self {
433 let cache = store.inner.cache.clone();
434
435 Self { store, changes: PendingChanges::default(), cache: cache.clone().write_owned().await }
436 }
437
438 pub(crate) fn cache(&self) -> &StoreCache {
439 &self.cache
440 }
441
442 pub fn store(&self) -> &Store {
444 &self.store
445 }
446
447 pub async fn account(&mut self) -> Result<&mut Account> {
454 if self.changes.account.is_none() {
455 let _ = self.cache.account().await?;
457 self.changes.account = self.cache.account.lock().await.take();
458 }
459 Ok(self.changes.account.as_mut().unwrap())
460 }
461
462 pub async fn commit(self) -> Result<()> {
465 if self.changes.is_empty() {
466 return Ok(());
467 }
468
469 let account = self.changes.account.as_ref().map(|acc| acc.deep_clone());
471
472 self.store.save_pending_changes(self.changes).await?;
473
474 if let Some(account) = account {
476 *self.cache.account.lock().await = Some(account);
477 }
478
479 Ok(())
480 }
481}
482
483#[derive(Debug)]
484struct StoreInner {
485 identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
486 store: Arc<CryptoStoreWrapper>,
487
488 cache: Arc<RwLock<StoreCache>>,
492
493 verification_machine: VerificationMachine,
494
495 static_account: StaticAccountData,
498}
499
500#[derive(Default, Debug)]
506#[allow(missing_docs)]
507pub struct PendingChanges {
508 pub account: Option<Account>,
509}
510
511impl PendingChanges {
512 pub fn is_empty(&self) -> bool {
514 self.account.is_none()
515 }
516}
517
518#[derive(Default, Debug)]
521#[allow(missing_docs)]
522pub struct Changes {
523 pub private_identity: Option<PrivateCrossSigningIdentity>,
524 pub backup_version: Option<String>,
525 pub backup_decryption_key: Option<BackupDecryptionKey>,
526 pub dehydrated_device_pickle_key: Option<DehydratedDeviceKey>,
527 pub sessions: Vec<Session>,
528 pub message_hashes: Vec<OlmMessageHash>,
529 pub inbound_group_sessions: Vec<InboundGroupSession>,
530 pub outbound_group_sessions: Vec<OutboundGroupSession>,
531 pub key_requests: Vec<GossipRequest>,
532 pub identities: IdentityChanges,
533 pub devices: DeviceChanges,
534 pub withheld_session_info: BTreeMap<OwnedRoomId, BTreeMap<String, RoomKeyWithheldEvent>>,
536 pub room_settings: HashMap<OwnedRoomId, RoomSettings>,
537 pub secrets: Vec<GossippedSecret>,
538 pub next_batch_token: Option<String>,
539}
540
541#[derive(Clone, Debug, Serialize, Deserialize)]
543pub struct TrackedUser {
544 pub user_id: OwnedUserId,
546 pub dirty: bool,
551}
552
553impl Changes {
554 pub fn is_empty(&self) -> bool {
556 self.private_identity.is_none()
557 && self.backup_version.is_none()
558 && self.backup_decryption_key.is_none()
559 && self.dehydrated_device_pickle_key.is_none()
560 && self.sessions.is_empty()
561 && self.message_hashes.is_empty()
562 && self.inbound_group_sessions.is_empty()
563 && self.outbound_group_sessions.is_empty()
564 && self.key_requests.is_empty()
565 && self.identities.is_empty()
566 && self.devices.is_empty()
567 && self.withheld_session_info.is_empty()
568 && self.room_settings.is_empty()
569 && self.secrets.is_empty()
570 && self.next_batch_token.is_none()
571 }
572}
573
574#[derive(Debug, Clone, Default)]
585#[allow(missing_docs)]
586pub struct IdentityChanges {
587 pub new: Vec<UserIdentityData>,
588 pub changed: Vec<UserIdentityData>,
589 pub unchanged: Vec<UserIdentityData>,
590}
591
592impl IdentityChanges {
593 fn is_empty(&self) -> bool {
594 self.new.is_empty() && self.changed.is_empty()
595 }
596
597 fn into_maps(
600 self,
601 ) -> (
602 BTreeMap<OwnedUserId, UserIdentityData>,
603 BTreeMap<OwnedUserId, UserIdentityData>,
604 BTreeMap<OwnedUserId, UserIdentityData>,
605 ) {
606 let new: BTreeMap<_, _> = self
607 .new
608 .into_iter()
609 .map(|identity| (identity.user_id().to_owned(), identity))
610 .collect();
611
612 let changed: BTreeMap<_, _> = self
613 .changed
614 .into_iter()
615 .map(|identity| (identity.user_id().to_owned(), identity))
616 .collect();
617
618 let unchanged: BTreeMap<_, _> = self
619 .unchanged
620 .into_iter()
621 .map(|identity| (identity.user_id().to_owned(), identity))
622 .collect();
623
624 (new, changed, unchanged)
625 }
626}
627
628#[derive(Debug, Clone, Default)]
629#[allow(missing_docs)]
630pub struct DeviceChanges {
631 pub new: Vec<DeviceData>,
632 pub changed: Vec<DeviceData>,
633 pub deleted: Vec<DeviceData>,
634}
635
636fn collect_device_updates(
642 verification_machine: VerificationMachine,
643 own_identity: Option<OwnUserIdentityData>,
644 identities: IdentityChanges,
645 devices: DeviceChanges,
646) -> DeviceUpdates {
647 let mut new: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
648 let mut changed: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
649
650 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
651
652 let map_device = |device: DeviceData| {
653 let device_owner_identity = new_identities
654 .get(device.user_id())
655 .or_else(|| changed_identities.get(device.user_id()))
656 .or_else(|| unchanged_identities.get(device.user_id()))
657 .cloned();
658
659 Device {
660 inner: device,
661 verification_machine: verification_machine.to_owned(),
662 own_identity: own_identity.to_owned(),
663 device_owner_identity,
664 }
665 };
666
667 for device in devices.new {
668 let device = map_device(device);
669
670 new.entry(device.user_id().to_owned())
671 .or_default()
672 .insert(device.device_id().to_owned(), device);
673 }
674
675 for device in devices.changed {
676 let device = map_device(device);
677
678 changed
679 .entry(device.user_id().to_owned())
680 .or_default()
681 .insert(device.device_id().to_owned(), device.to_owned());
682 }
683
684 DeviceUpdates { new, changed }
685}
686
687#[derive(Clone, Debug, Default)]
690pub struct DeviceUpdates {
691 pub new: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
697 pub changed: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
699}
700
701#[derive(Clone, Debug, Default)]
704pub struct IdentityUpdates {
705 pub new: BTreeMap<OwnedUserId, UserIdentity>,
711 pub changed: BTreeMap<OwnedUserId, UserIdentity>,
713 pub unchanged: BTreeMap<OwnedUserId, UserIdentity>,
715}
716
717#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
727#[serde(transparent)]
728pub struct BackupDecryptionKey {
729 pub(crate) inner: Box<[u8; BackupDecryptionKey::KEY_SIZE]>,
730}
731
732impl BackupDecryptionKey {
733 pub const KEY_SIZE: usize = 32;
735
736 pub fn new() -> Result<Self, rand::Error> {
738 let mut rng = rand::thread_rng();
739
740 let mut key = Box::new([0u8; Self::KEY_SIZE]);
741 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
742
743 Ok(Self { inner: key })
744 }
745
746 pub fn to_base64(&self) -> String {
748 base64_encode(self.inner.as_slice())
749 }
750}
751
752#[cfg(not(tarpaulin_include))]
753impl Debug for BackupDecryptionKey {
754 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
755 f.debug_tuple("BackupDecryptionKey").field(&"...").finish()
756 }
757}
758
759#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
764#[serde(transparent)]
765pub struct DehydratedDeviceKey {
766 pub(crate) inner: Box<[u8; DehydratedDeviceKey::KEY_SIZE]>,
767}
768
769impl DehydratedDeviceKey {
770 pub const KEY_SIZE: usize = 32;
772
773 pub fn new() -> Result<Self, rand::Error> {
775 let mut rng = rand::thread_rng();
776
777 let mut key = Box::new([0u8; Self::KEY_SIZE]);
778 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
779
780 Ok(Self { inner: key })
781 }
782
783 pub fn from_slice(slice: &[u8]) -> Result<Self, DehydrationError> {
787 if slice.len() == 32 {
788 let mut key = Box::new([0u8; 32]);
789 key.copy_from_slice(slice);
790 Ok(DehydratedDeviceKey { inner: key })
791 } else {
792 Err(DehydrationError::PickleKeyLength(slice.len()))
793 }
794 }
795
796 pub fn from_bytes(raw_key: &[u8; 32]) -> Self {
798 let mut inner = Box::new([0u8; Self::KEY_SIZE]);
799 inner.copy_from_slice(raw_key);
800
801 Self { inner }
802 }
803
804 pub fn to_base64(&self) -> String {
806 base64_encode(self.inner.as_slice())
807 }
808}
809
810impl From<&[u8; 32]> for DehydratedDeviceKey {
811 fn from(value: &[u8; 32]) -> Self {
812 DehydratedDeviceKey { inner: Box::new(*value) }
813 }
814}
815
816impl From<DehydratedDeviceKey> for Vec<u8> {
817 fn from(key: DehydratedDeviceKey) -> Self {
818 key.inner.to_vec()
819 }
820}
821
822#[cfg(not(tarpaulin_include))]
823impl Debug for DehydratedDeviceKey {
824 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
825 f.debug_tuple("DehydratedDeviceKey").field(&"...").finish()
826 }
827}
828
829impl DeviceChanges {
830 pub fn extend(&mut self, other: DeviceChanges) {
832 self.new.extend(other.new);
833 self.changed.extend(other.changed);
834 self.deleted.extend(other.deleted);
835 }
836
837 fn is_empty(&self) -> bool {
838 self.new.is_empty() && self.changed.is_empty() && self.deleted.is_empty()
839 }
840}
841
/// Room key counters: the total number of keys and how many are backed up.
#[derive(Clone, Debug, Default)]
pub struct RoomKeyCounts {
    /// The total number of room keys.
    pub total: usize,
    /// The number of room keys that are backed up.
    pub backed_up: usize,
}
850
851#[derive(Default, Clone, Debug)]
853pub struct BackupKeys {
854 pub decryption_key: Option<BackupDecryptionKey>,
856 pub backup_version: Option<String>,
858}
859
860#[derive(Default, Zeroize, ZeroizeOnDrop)]
863pub struct CrossSigningKeyExport {
864 pub master_key: Option<String>,
866 pub self_signing_key: Option<String>,
868 pub user_signing_key: Option<String>,
870}
871
872#[cfg(not(tarpaulin_include))]
873impl Debug for CrossSigningKeyExport {
874 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875 f.debug_struct("CrossSigningKeyExport")
876 .field("master_key", &self.master_key.is_some())
877 .field("self_signing_key", &self.self_signing_key.is_some())
878 .field("user_signing_key", &self.user_signing_key.is_some())
879 .finish_non_exhaustive()
880 }
881}
882
883#[derive(Debug, Error)]
886pub enum SecretImportError {
887 #[error(transparent)]
889 Key(#[from] vodozemac::KeyError),
890 #[error(
893 "The public key of the imported private key doesn't match to the \
894 public key that was uploaded to the server"
895 )]
896 MismatchedPublicKeys,
897 #[error(transparent)]
899 Store(#[from] CryptoStoreError),
900}
901
902#[derive(Debug, Error)]
907pub enum SecretsBundleExportError {
908 #[error(transparent)]
910 Store(#[from] CryptoStoreError),
911 #[error("The store is missing one or multiple cross-signing keys")]
913 MissingCrossSigningKey(KeyUsage),
914 #[error("The store doesn't contain any cross-signing keys")]
916 MissingCrossSigningKeys,
917 #[error("The store contains a backup key, but no backup version")]
920 MissingBackupVersion,
921}
922
/// Outcome of waiting for a pending `/keys/query` for a user.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum UserKeyQueryResult {
    /// A query was pending and it finished before the timeout.
    WasPending,
    /// No query was pending for the user.
    WasNotPending,

    /// A query was pending but did not finish before the timeout.
    TimeoutExpired,
}
933
934#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
936pub struct RoomSettings {
937 pub algorithm: EventEncryptionAlgorithm,
939
940 pub only_allow_trusted_devices: bool,
943
944 pub session_rotation_period: Option<Duration>,
947
948 pub session_rotation_period_messages: Option<usize>,
951}
952
953impl Default for RoomSettings {
954 fn default() -> Self {
955 Self {
956 algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
957 only_allow_trusted_devices: false,
958 session_rotation_period: None,
959 session_rotation_period_messages: None,
960 }
961 }
962}
963
964#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
966pub struct RoomKeyInfo {
967 pub algorithm: EventEncryptionAlgorithm,
972
973 pub room_id: OwnedRoomId,
975
976 pub sender_key: Curve25519PublicKey,
978
979 pub session_id: String,
981}
982
983impl From<&InboundGroupSession> for RoomKeyInfo {
984 fn from(group_session: &InboundGroupSession) -> Self {
985 RoomKeyInfo {
986 algorithm: group_session.algorithm().clone(),
987 room_id: group_session.room_id().to_owned(),
988 sender_key: group_session.sender_key(),
989 session_id: group_session.session_id().to_owned(),
990 }
991 }
992}
993
994#[derive(Clone, Debug, Deserialize, Serialize)]
996pub struct RoomKeyWithheldInfo {
997 pub room_id: OwnedRoomId,
999
1000 pub session_id: String,
1002
1003 pub withheld_event: RoomKeyWithheldEvent,
1006}
1007
1008impl Store {
1009 pub(crate) fn new(
1011 account: StaticAccountData,
1012 identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
1013 store: Arc<CryptoStoreWrapper>,
1014 verification_machine: VerificationMachine,
1015 ) -> Self {
1016 Self {
1017 inner: Arc::new(StoreInner {
1018 static_account: account,
1019 identity,
1020 store: store.clone(),
1021 verification_machine,
1022 cache: Arc::new(RwLock::new(StoreCache {
1023 store,
1024 tracked_users: Default::default(),
1025 loaded_tracked_users: Default::default(),
1026 account: Default::default(),
1027 })),
1028 }),
1029 }
1030 }
1031
1032 pub(crate) fn user_id(&self) -> &UserId {
1034 &self.inner.static_account.user_id
1035 }
1036
1037 pub(crate) fn device_id(&self) -> &DeviceId {
1039 self.inner.verification_machine.own_device_id()
1040 }
1041
1042 pub(crate) fn static_account(&self) -> &StaticAccountData {
1044 &self.inner.static_account
1045 }
1046
1047 pub(crate) async fn cache(&self) -> Result<StoreCacheGuard> {
1048 Ok(StoreCacheGuard { cache: self.inner.cache.clone().read_owned().await })
1053 }
1054
1055 pub(crate) async fn transaction(&self) -> StoreTransaction {
1056 StoreTransaction::new(self.clone()).await
1057 }
1058
1059 pub(crate) async fn with_transaction<
1062 T,
1063 Fut: futures_core::Future<Output = Result<(StoreTransaction, T), crate::OlmError>>,
1064 F: FnOnce(StoreTransaction) -> Fut,
1065 >(
1066 &self,
1067 func: F,
1068 ) -> Result<T, crate::OlmError> {
1069 let tr = self.transaction().await;
1070 let (tr, res) = func(tr).await?;
1071 tr.commit().await?;
1072 Ok(res)
1073 }
1074
/// Test helper that resets our private cross-signing identity.
#[cfg(test)]
pub(crate) async fn reset_cross_signing_identity(&self) {
    let identity = &self.inner.identity;
    identity.lock().await.reset();
}
1080
1081 pub(crate) fn private_identity(&self) -> Arc<Mutex<PrivateCrossSigningIdentity>> {
1083 self.inner.identity.clone()
1084 }
1085
1086 pub(crate) async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
1088 let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
1089
1090 self.save_changes(changes).await
1091 }
1092
1093 pub(crate) async fn get_sessions(
1094 &self,
1095 sender_key: &str,
1096 ) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
1097 self.inner.store.get_sessions(sender_key).await
1098 }
1099
1100 pub(crate) async fn save_changes(&self, changes: Changes) -> Result<()> {
1101 self.inner.store.save_changes(changes).await
1102 }
1103
1104 pub(crate) async fn compare_group_session(
1111 &self,
1112 session: &InboundGroupSession,
1113 ) -> Result<SessionOrdering> {
1114 let old_session = self
1115 .inner
1116 .store
1117 .get_inbound_group_session(session.room_id(), session.session_id())
1118 .await?;
1119
1120 Ok(if let Some(old_session) = old_session {
1121 session.compare(&old_session).await
1122 } else {
1123 SessionOrdering::Better
1124 })
1125 }
1126
/// Test helper that saves the given devices as changed devices.
#[cfg(test)]
pub(crate) async fn save_device_data(&self, devices: &[DeviceData]) -> Result<()> {
    let mut changes = Changes::default();
    changes.devices.changed.extend_from_slice(devices);

    self.save_changes(changes).await
}
1137
1138 pub(crate) async fn save_inbound_group_sessions(
1140 &self,
1141 sessions: &[InboundGroupSession],
1142 ) -> Result<()> {
1143 let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
1144
1145 self.save_changes(changes).await
1146 }
1147
1148 pub(crate) async fn device_display_name(&self) -> Result<Option<String>, CryptoStoreError> {
1150 Ok(self
1151 .inner
1152 .store
1153 .get_device(self.user_id(), self.device_id())
1154 .await?
1155 .and_then(|d| d.display_name().map(|d| d.to_owned())))
1156 }
1157
1158 pub(crate) async fn get_device_data(
1163 &self,
1164 user_id: &UserId,
1165 device_id: &DeviceId,
1166 ) -> Result<Option<DeviceData>> {
1167 self.inner.store.get_device(user_id, device_id).await
1168 }
1169
1170 pub(crate) async fn get_device_data_for_user_filtered(
1178 &self,
1179 user_id: &UserId,
1180 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1181 self.inner.store.get_user_devices(user_id).await.map(|mut d| {
1182 if user_id == self.user_id() {
1183 d.remove(self.device_id());
1184 }
1185 d
1186 })
1187 }
1188
1189 pub(crate) async fn get_device_data_for_user(
1198 &self,
1199 user_id: &UserId,
1200 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1201 self.inner.store.get_user_devices(user_id).await
1202 }
1203
1204 pub(crate) async fn get_device_from_curve_key(
1210 &self,
1211 user_id: &UserId,
1212 curve_key: Curve25519PublicKey,
1213 ) -> Result<Option<Device>> {
1214 self.get_user_devices(user_id)
1215 .await
1216 .map(|d| d.devices().find(|d| d.curve25519_key() == Some(curve_key)))
1217 }
1218
1219 pub(crate) async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
1229 let devices = self.get_device_data_for_user(user_id).await?;
1230
1231 let own_identity = self
1232 .inner
1233 .store
1234 .get_user_identity(self.user_id())
1235 .await?
1236 .and_then(|i| i.own().cloned());
1237 let device_owner_identity = self.inner.store.get_user_identity(user_id).await?;
1238
1239 Ok(UserDevices {
1240 inner: devices,
1241 verification_machine: self.inner.verification_machine.clone(),
1242 own_identity,
1243 device_owner_identity,
1244 })
1245 }
1246
1247 pub(crate) async fn get_device(
1257 &self,
1258 user_id: &UserId,
1259 device_id: &DeviceId,
1260 ) -> Result<Option<Device>> {
1261 if let Some(device_data) = self.inner.store.get_device(user_id, device_id).await? {
1262 Ok(Some(self.wrap_device_data(device_data).await?))
1263 } else {
1264 Ok(None)
1265 }
1266 }
1267
1268 pub(crate) async fn wrap_device_data(&self, device_data: DeviceData) -> Result<Device> {
1273 let own_identity = self
1274 .inner
1275 .store
1276 .get_user_identity(self.user_id())
1277 .await?
1278 .and_then(|i| i.own().cloned());
1279
1280 let device_owner_identity =
1281 self.inner.store.get_user_identity(device_data.user_id()).await?;
1282
1283 Ok(Device {
1284 inner: device_data,
1285 verification_machine: self.inner.verification_machine.clone(),
1286 own_identity,
1287 device_owner_identity,
1288 })
1289 }
1290
1291 pub(crate) async fn get_identity(&self, user_id: &UserId) -> Result<Option<UserIdentity>> {
1293 let own_identity = self
1294 .inner
1295 .store
1296 .get_user_identity(self.user_id())
1297 .await?
1298 .and_then(as_variant!(UserIdentityData::Own));
1299
1300 Ok(self.inner.store.get_user_identity(user_id).await?.map(|i| {
1301 UserIdentity::new(
1302 self.clone(),
1303 i,
1304 self.inner.verification_machine.to_owned(),
1305 own_identity,
1306 )
1307 }))
1308 }
1309
1310 pub async fn export_secret(
1319 &self,
1320 secret_name: &SecretName,
1321 ) -> Result<Option<String>, CryptoStoreError> {
1322 Ok(match secret_name {
1323 SecretName::CrossSigningMasterKey
1324 | SecretName::CrossSigningUserSigningKey
1325 | SecretName::CrossSigningSelfSigningKey => {
1326 self.inner.identity.lock().await.export_secret(secret_name).await
1327 }
1328 SecretName::RecoveryKey => {
1329 if let Some(key) = self.load_backup_keys().await?.decryption_key {
1330 let exported = key.to_base64();
1331 Some(exported)
1332 } else {
1333 None
1334 }
1335 }
1336 name => {
1337 warn!(secret = ?name, "Unknown secret was requested");
1338 None
1339 }
1340 })
1341 }
1342
1343 pub async fn export_cross_signing_keys(
1351 &self,
1352 ) -> Result<Option<CrossSigningKeyExport>, CryptoStoreError> {
1353 let master_key = self.export_secret(&SecretName::CrossSigningMasterKey).await?;
1354 let self_signing_key = self.export_secret(&SecretName::CrossSigningSelfSigningKey).await?;
1355 let user_signing_key = self.export_secret(&SecretName::CrossSigningUserSigningKey).await?;
1356
1357 Ok(if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() {
1358 None
1359 } else {
1360 Some(CrossSigningKeyExport { master_key, self_signing_key, user_signing_key })
1361 })
1362 }
1363
1364 pub async fn import_cross_signing_keys(
1369 &self,
1370 export: CrossSigningKeyExport,
1371 ) -> Result<CrossSigningStatus, SecretImportError> {
1372 if let Some(public_identity) =
1373 self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1374 {
1375 let identity = self.inner.identity.lock().await;
1376
1377 identity
1378 .import_secrets(
1379 public_identity.to_owned(),
1380 export.master_key.as_deref(),
1381 export.self_signing_key.as_deref(),
1382 export.user_signing_key.as_deref(),
1383 )
1384 .await?;
1385
1386 let status = identity.status().await;
1387
1388 let diff = identity.get_public_identity_diff(&public_identity.inner).await;
1389
1390 let mut changes =
1391 Changes { private_identity: Some(identity.clone()), ..Default::default() };
1392
1393 if diff.none_differ() {
1394 public_identity.mark_as_verified();
1395 changes.identities.changed.push(UserIdentityData::Own(public_identity.inner));
1396 }
1397
1398 info!(?status, "Successfully imported the private cross-signing keys");
1399
1400 self.save_changes(changes).await?;
1401 } else {
1402 warn!("No public identity found while importing cross-signing keys, a /keys/query needs to be done");
1403 }
1404
1405 Ok(self.inner.identity.lock().await.status().await)
1406 }
1407
1408 pub async fn export_secrets_bundle(&self) -> Result<SecretsBundle, SecretsBundleExportError> {
1420 let Some(cross_signing) = self.export_cross_signing_keys().await? else {
1421 return Err(SecretsBundleExportError::MissingCrossSigningKeys);
1422 };
1423
1424 let Some(master_key) = cross_signing.master_key.clone() else {
1425 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::Master));
1426 };
1427
1428 let Some(user_signing_key) = cross_signing.user_signing_key.clone() else {
1429 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::UserSigning));
1430 };
1431
1432 let Some(self_signing_key) = cross_signing.self_signing_key.clone() else {
1433 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::SelfSigning));
1434 };
1435
1436 let backup_keys = self.load_backup_keys().await?;
1437
1438 let backup = if let Some(key) = backup_keys.decryption_key {
1439 if let Some(backup_version) = backup_keys.backup_version {
1440 Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(
1441 MegolmBackupV1Curve25519AesSha2Secrets { key, backup_version },
1442 ))
1443 } else {
1444 return Err(SecretsBundleExportError::MissingBackupVersion);
1445 }
1446 } else {
1447 None
1448 };
1449
1450 Ok(SecretsBundle {
1451 cross_signing: CrossSigningSecrets { master_key, user_signing_key, self_signing_key },
1452 backup,
1453 })
1454 }
1455
1456 pub async fn import_secrets_bundle(
1469 &self,
1470 bundle: &SecretsBundle,
1471 ) -> Result<(), SecretImportError> {
1472 let mut changes = Changes::default();
1473
1474 if let Some(backup_bundle) = &bundle.backup {
1475 match backup_bundle {
1476 BackupSecrets::MegolmBackupV1Curve25519AesSha2(bundle) => {
1477 changes.backup_decryption_key = Some(bundle.key.clone());
1478 changes.backup_version = Some(bundle.backup_version.clone());
1479 }
1480 }
1481 }
1482
1483 let identity = self.inner.identity.lock().await;
1484
1485 identity
1486 .import_secrets_unchecked(
1487 Some(&bundle.cross_signing.master_key),
1488 Some(&bundle.cross_signing.self_signing_key),
1489 Some(&bundle.cross_signing.user_signing_key),
1490 )
1491 .await?;
1492
1493 let public_identity = identity.to_public_identity().await.expect(
1494 "We should be able to create a new public identity since we just imported \
1495 all the private cross-signing keys",
1496 );
1497
1498 changes.private_identity = Some(identity.clone());
1499 changes.identities.new.push(UserIdentityData::Own(public_identity));
1500
1501 Ok(self.save_changes(changes).await?)
1502 }
1503
1504 pub async fn import_secret(&self, secret: &GossippedSecret) -> Result<(), SecretImportError> {
1506 match &secret.secret_name {
1507 SecretName::CrossSigningMasterKey
1508 | SecretName::CrossSigningUserSigningKey
1509 | SecretName::CrossSigningSelfSigningKey => {
1510 if let Some(public_identity) =
1511 self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1512 {
1513 let identity = self.inner.identity.lock().await;
1514
1515 identity
1516 .import_secret(
1517 public_identity,
1518 &secret.secret_name,
1519 &secret.event.content.secret,
1520 )
1521 .await?;
1522 info!(
1523 secret_name = ?secret.secret_name,
1524 "Successfully imported a private cross signing key"
1525 );
1526
1527 let changes =
1528 Changes { private_identity: Some(identity.clone()), ..Default::default() };
1529
1530 self.save_changes(changes).await?;
1531 }
1532 }
1533 SecretName::RecoveryKey => {
1534 }
1540 name => {
1541 warn!(secret = ?name, "Tried to import an unknown secret");
1542 }
1543 }
1544
1545 Ok(())
1546 }
1547
1548 pub async fn get_only_allow_trusted_devices(&self) -> Result<bool> {
1551 let value = self.get_value("only_allow_trusted_devices").await?.unwrap_or_default();
1552 Ok(value)
1553 }
1554
1555 pub async fn set_only_allow_trusted_devices(
1558 &self,
1559 block_untrusted_devices: bool,
1560 ) -> Result<()> {
1561 self.set_value("only_allow_trusted_devices", &block_untrusted_devices).await
1562 }
1563
1564 pub async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
1566 let Some(value) = self.get_custom_value(key).await? else {
1567 return Ok(None);
1568 };
1569 let deserialized = self.deserialize_value(&value)?;
1570 Ok(Some(deserialized))
1571 }
1572
1573 pub async fn set_value(&self, key: &str, value: &impl Serialize) -> Result<()> {
1575 let serialized = self.serialize_value(value)?;
1576 self.set_custom_value(key, serialized).await?;
1577 Ok(())
1578 }
1579
1580 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
1581 let serialized =
1582 rmp_serde::to_vec_named(value).map_err(|x| CryptoStoreError::Backend(x.into()))?;
1583 Ok(serialized)
1584 }
1585
1586 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
1587 let deserialized =
1588 rmp_serde::from_slice(value).map_err(|e| CryptoStoreError::Backend(e.into()))?;
1589 Ok(deserialized)
1590 }
1591
1592 pub fn room_keys_received_stream(
1604 &self,
1605 ) -> impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>> {
1606 self.inner.store.room_keys_received_stream()
1607 }
1608
1609 pub fn room_keys_withheld_received_stream(
1618 &self,
1619 ) -> impl Stream<Item = Vec<RoomKeyWithheldInfo>> {
1620 self.inner.store.room_keys_withheld_received_stream()
1621 }
1622
1623 pub fn user_identities_stream(&self) -> impl Stream<Item = IdentityUpdates> {
1654 let verification_machine = self.inner.verification_machine.to_owned();
1655
1656 let this = self.clone();
1657 self.inner.store.identities_stream().map(move |(own_identity, identities, _)| {
1658 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
1659
1660 let map_identity = |(user_id, identity)| {
1661 (
1662 user_id,
1663 UserIdentity::new(
1664 this.clone(),
1665 identity,
1666 verification_machine.to_owned(),
1667 own_identity.to_owned(),
1668 ),
1669 )
1670 };
1671
1672 let new = new_identities.into_iter().map(map_identity).collect();
1673 let changed = changed_identities.into_iter().map(map_identity).collect();
1674 let unchanged = unchanged_identities.into_iter().map(map_identity).collect();
1675
1676 IdentityUpdates { new, changed, unchanged }
1677 })
1678 }
1679
1680 pub fn devices_stream(&self) -> impl Stream<Item = DeviceUpdates> {
1712 let verification_machine = self.inner.verification_machine.to_owned();
1713
1714 self.inner.store.identities_stream().map(move |(own_identity, identities, devices)| {
1715 collect_device_updates(
1716 verification_machine.to_owned(),
1717 own_identity,
1718 identities,
1719 devices,
1720 )
1721 })
1722 }
1723
1724 pub fn identities_stream_raw(&self) -> impl Stream<Item = (IdentityChanges, DeviceChanges)> {
1734 self.inner.store.identities_stream().map(|(_, identities, devices)| (identities, devices))
1735 }
1736
1737 pub fn create_store_lock(
1740 &self,
1741 lock_key: String,
1742 lock_value: String,
1743 ) -> CrossProcessStoreLock<LockableCryptoStore> {
1744 self.inner.store.create_store_lock(lock_key, lock_value)
1745 }
1746
1747 pub fn secrets_stream(&self) -> impl Stream<Item = GossippedSecret> {
1787 self.inner.store.secrets_stream()
1788 }
1789
1790 pub async fn import_room_keys(
1803 &self,
1804 exported_keys: Vec<ExportedRoomKey>,
1805 from_backup_version: Option<&str>,
1806 progress_listener: impl Fn(usize, usize),
1807 ) -> Result<RoomKeyImportResult> {
1808 let mut sessions = Vec::new();
1809
1810 async fn new_session_better(
1811 session: &InboundGroupSession,
1812 old_session: Option<InboundGroupSession>,
1813 ) -> bool {
1814 if let Some(old_session) = &old_session {
1815 session.compare(old_session).await == SessionOrdering::Better
1816 } else {
1817 true
1818 }
1819 }
1820
1821 let total_count = exported_keys.len();
1822 let mut keys = BTreeMap::new();
1823
1824 for (i, key) in exported_keys.into_iter().enumerate() {
1825 match InboundGroupSession::from_export(&key) {
1826 Ok(session) => {
1827 let old_session = self
1828 .inner
1829 .store
1830 .get_inbound_group_session(session.room_id(), session.session_id())
1831 .await?;
1832
1833 if new_session_better(&session, old_session).await {
1836 if from_backup_version.is_some() {
1837 session.mark_as_backed_up();
1838 }
1839
1840 keys.entry(session.room_id().to_owned())
1841 .or_insert_with(BTreeMap::new)
1842 .entry(session.sender_key().to_base64())
1843 .or_insert_with(BTreeSet::new)
1844 .insert(session.session_id().to_owned());
1845
1846 sessions.push(session);
1847 }
1848 }
1849 Err(e) => {
1850 warn!(
1851 sender_key= key.sender_key.to_base64(),
1852 room_id = ?key.room_id,
1853 session_id = key.session_id,
1854 error = ?e,
1855 "Couldn't import a room key from a file export."
1856 );
1857 }
1858 }
1859
1860 progress_listener(i, total_count);
1861 }
1862
1863 let imported_count = sessions.len();
1864
1865 self.inner.store.save_inbound_group_sessions(sessions, from_backup_version).await?;
1866
1867 info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys");
1868
1869 Ok(RoomKeyImportResult::new(imported_count, total_count, keys))
1870 }
1871
1872 pub async fn import_exported_room_keys(
1899 &self,
1900 exported_keys: Vec<ExportedRoomKey>,
1901 progress_listener: impl Fn(usize, usize),
1902 ) -> Result<RoomKeyImportResult> {
1903 self.import_room_keys(exported_keys, None, progress_listener).await
1904 }
1905
1906 pub(crate) fn crypto_store(&self) -> Arc<CryptoStoreWrapper> {
1907 self.inner.store.clone()
1908 }
1909
1910 pub async fn export_room_keys(
1933 &self,
1934 predicate: impl FnMut(&InboundGroupSession) -> bool,
1935 ) -> Result<Vec<ExportedRoomKey>> {
1936 let mut exported = Vec::new();
1937
1938 let mut sessions = self.get_inbound_group_sessions().await?;
1939 sessions.retain(predicate);
1940
1941 for session in sessions {
1942 let export = session.export().await;
1943 exported.push(export);
1944 }
1945
1946 Ok(exported)
1947 }
1948
1949 pub async fn export_room_keys_stream(
1982 &self,
1983 predicate: impl FnMut(&InboundGroupSession) -> bool,
1984 ) -> Result<impl Stream<Item = ExportedRoomKey>> {
1985 let sessions = self.get_inbound_group_sessions().await?;
1987 Ok(futures_util::stream::iter(sessions.into_iter().filter(predicate))
1988 .then(|session| async move { session.export().await }))
1989 }
1990}
1991
1992impl Deref for Store {
1993 type Target = DynCryptoStore;
1994
1995 fn deref(&self) -> &Self::Target {
1996 self.inner.store.deref().deref()
1997 }
1998}
1999
2000#[derive(Clone, Debug)]
2002pub struct LockableCryptoStore(Arc<dyn CryptoStore<Error = CryptoStoreError>>);
2003
2004#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
2005#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
2006impl matrix_sdk_common::store_locks::BackingStore for LockableCryptoStore {
2007 type LockError = CryptoStoreError;
2008
2009 async fn try_lock(
2010 &self,
2011 lease_duration_ms: u32,
2012 key: &str,
2013 holder: &str,
2014 ) -> std::result::Result<bool, Self::LockError> {
2015 self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
2016 }
2017}
2018
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use matrix_sdk_test::async_test;
    use ruma::{room_id, user_id};

    use crate::{
        machine::test_helpers::get_machine_pair, olm::ExportedRoomKey,
        store::DehydratedDeviceKey, types::EventEncryptionAlgorithm,
    };

    /// Check the properties that every freshly created Megolm v1 export shares.
    fn assert_megolm_v1_keys(keys: &[ExportedRoomKey]) {
        for key in keys {
            assert_eq!(key.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
            assert_eq!(key.session_key.to_base64().len(), 220);
        }
    }

    #[async_test]
    async fn test_import_room_keys_notifies_stream() {
        use futures_util::FutureExt;

        let (alice, bob, _) =
            get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;

        let room = room_id!("!room1:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room).await.unwrap();
        let exported = alice.store().export_room_keys(|_| true).await.unwrap();

        // Subscribe before importing so the notification is not missed.
        let mut updates = Box::pin(bob.store().room_keys_received_stream());
        bob.store().import_room_keys(exported, None, |_, _| {}).await.unwrap();

        let room_keys = updates
            .next()
            .now_or_never()
            .flatten()
            .expect("We should have received an update of room key infos")
            .unwrap();

        assert_eq!(room_keys.len(), 1);
        assert_eq!(room_keys[0].room_id, "!room1:localhost");
    }

    #[async_test]
    async fn test_export_room_keys_provides_selected_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        let room3_id = room_id!("!room3:localhost");

        for room in [room1_id, room2_id, room3_id] {
            alice.create_outbound_group_session_with_defaults_test_helper(room).await.unwrap();
        }

        let keys = alice
            .store()
            .export_room_keys(|session| {
                session.room_id() == room2_id || session.room_id() == room3_id
            })
            .await
            .unwrap();

        assert_eq!(keys.len(), 2);
        assert_megolm_v1_keys(&keys);

        let rooms: Vec<&str> = keys.iter().map(|key| key.room_id.as_str()).collect();
        assert_eq!(rooms, ["!room2:localhost", "!room3:localhost"]);
    }

    #[async_test]
    async fn test_export_room_keys_stream_can_provide_all_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");

        for room in [room1_id, room2_id] {
            alice.create_outbound_group_session_with_defaults_test_helper(room).await.unwrap();
        }

        let collected = alice
            .store()
            .export_room_keys_stream(|_| true)
            .await
            .unwrap()
            .collect::<Vec<_>>()
            .await;

        assert_eq!(collected.len(), 2);
        assert_megolm_v1_keys(&collected);

        let rooms: Vec<&str> = collected.iter().map(|key| key.room_id.as_str()).collect();
        assert_eq!(rooms, ["!room1:localhost", "!room2:localhost"]);
    }

    #[async_test]
    async fn test_export_room_keys_stream_can_provide_a_subset_of_keys() {
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");

        for room in [room1_id, room2_id] {
            alice.create_outbound_group_session_with_defaults_test_helper(room).await.unwrap();
        }

        let collected = alice
            .store()
            .export_room_keys_stream(|session| session.room_id() == room1_id)
            .await
            .unwrap()
            .collect::<Vec<_>>()
            .await;

        assert_eq!(collected.len(), 1);
        assert_megolm_v1_keys(&collected);
        assert_eq!(collected[0].room_id, "!room1:localhost");
    }

    #[async_test]
    async fn test_export_secrets_bundle() {
        let user_id = user_id!("@alice:example.com");
        let (exporter, importer, _) = get_machine_pair(user_id, user_id, false).await;

        let _ = exporter
            .bootstrap_cross_signing(false)
            .await
            .expect("We should be able to bootstrap cross-signing");

        let bundle = exporter.store().export_secrets_bundle().await.expect(
            "We should be able to export the secrets bundle, now that we \
             have the cross-signing keys",
        );

        assert!(bundle.backup.is_none(), "The bundle should not contain a backup key");

        importer
            .store()
            .import_secrets_bundle(&bundle)
            .await
            .expect("We should be able to import the secrets bundle");

        let status = importer.cross_signing_status().await;
        let identity = importer.get_identity(user_id, None).await.unwrap().unwrap().own().unwrap();

        assert!(identity.is_verified(), "The public identity should be marked as verified.");
        assert!(status.is_complete(), "We should have imported all the cross-signing keys");
    }

    #[async_test]
    async fn test_create_dehydrated_device_key() {
        let generated = DehydratedDeviceKey::new()
            .expect("Should be able to create a random dehydrated device key");

        let raw_bytes = generated.inner.to_vec();
        let restored = DehydratedDeviceKey::from_slice(raw_bytes.as_slice())
            .expect("Should be able to create a dehydrated device key from slice");

        assert_eq!(restored.to_base64(), generated.to_base64());
    }

    #[async_test]
    async fn test_create_dehydrated_errors() {
        let too_small = [0u8; 22];
        let too_big = [0u8; 40];

        for wrong_length in [&too_small[..], &too_big[..]] {
            assert!(DehydratedDeviceKey::from_slice(wrong_length).is_err());
        }
    }
}