Skip to main content

matrix_sdk_crypto/store/
memorystore.rs

1// Copyright 2020, 2026 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, HashMap, HashSet},
17    convert::Infallible,
18    sync::Arc,
19};
20
21use async_trait::async_trait;
22use matrix_sdk_common::{
23    cross_process_lock::{
24        CrossProcessLockGeneration,
25        memory_store_helper::{Lease, try_take_leased_lock},
26    },
27    locks::RwLock as StdRwLock,
28};
29use ruma::{
30    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
31    UserId, events::secret::request::SecretName,
32};
33use tokio::sync::{Mutex, RwLock};
34use tracing::warn;
35use vodozemac::Curve25519PublicKey;
36use zeroize::Zeroizing;
37
38use super::{
39    Account, CryptoStore, InboundGroupSession, Session,
40    caches::DeviceStore,
41    types::{
42        BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
43        StoredRoomKeyBundleData, TrackedUser,
44    },
45};
46use crate::{
47    gossiping::{GossipRequest, SecretInfo},
48    identities::{DeviceData, UserIdentityData},
49    olm::{
50        OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
51        PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
52    },
53    store::types::{RoomKeyWithheldEntry, RoomPendingKeyBundleDetails},
54};
55
56fn encode_key_info(info: &SecretInfo) -> String {
57    match info {
58        SecretInfo::KeyRequest(info) => {
59            format!("{}{}{}", info.room_id(), info.algorithm(), info.session_id())
60        }
61        SecretInfo::SecretRequest(i) => i.as_ref().to_owned(),
62    }
63}
64
/// Session ID of an inbound group session; used as the inner key of
/// `MemoryStore::inbound_group_sessions_backed_up_to`.
type SessionId = String;
66
/// The "version" of a backup - newtype wrapper around a String.
///
/// Construct it with `BackupVersion::from(&str)` (or `.into()`), and read it
/// back with [`BackupVersion::as_str`].
#[derive(Clone, Debug, PartialEq, Eq)]
struct BackupVersion(String);

// Implementing the standard `From` trait (instead of an inherent `from`
// method) keeps `BackupVersion::from(..)` call sites working and also enables
// `.into()` conversions.
impl From<&str> for BackupVersion {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}

impl BackupVersion {
    /// Borrow the backup version as a string slice.
    fn as_str(&self) -> &str {
        &self.0
    }
}
80
/// An in-memory only store that will forget all the E2EE keys once it's dropped.
#[derive(Default, Debug)]
pub struct MemoryStore {
    /// Cached static data of the account; refreshed whenever the account is
    /// loaded (`load_account`) or saved (`save_pending_changes`).
    static_account: Arc<StdRwLock<Option<StaticAccountData>>>,

    /// The account, stored as a JSON-serialized `PickledAccount`.
    account: StdRwLock<Option<String>>,
    /// Map of sender_key (base64) to map of session_id to JSON-serialized
    /// `PickledSession`.
    sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
    /// Map of room id to map of session id to JSON-serialized
    /// `PickledInboundGroupSession`.
    inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,

    /// Map room id -> session id -> backup version.
    /// The latest backup in which this session is stored. Equivalent to
    /// `backed_up_to` in `IndexedDbCryptoStore`.
    inbound_group_sessions_backed_up_to:
        StdRwLock<HashMap<OwnedRoomId, HashMap<SessionId, BackupVersion>>>,

    /// At most one outbound group session per room.
    outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
    /// Our own private cross-signing identity, if any.
    private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
    /// Users whose devices we track, with their "dirty" flag.
    tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
    /// Map of sender key to the set of Olm message hashes seen from it.
    olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
    /// Cache of known devices.
    devices: DeviceStore,
    /// Map of user id to JSON-serialized `UserIdentityData`.
    identities: StdRwLock<HashMap<OwnedUserId, String>>,
    /// Outgoing secret/key requests, keyed by request (transaction) id.
    outgoing_key_requests: StdRwLock<HashMap<OwnedTransactionId, GossipRequest>>,
    /// Secondary index: `encode_key_info(&request.info)` -> request id.
    key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
    /// Map of room id to map of session id to withheld-key information.
    direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEntry>>>,
    /// Arbitrary key/value byte storage for `{get,set,remove}_custom_value`.
    custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
    /// Leases backing `try_take_leased_lock`.
    leases: StdRwLock<HashMap<String, Lease>>,
    /// Received secrets, keyed by secret name.
    secret_inbox: StdRwLock<HashMap<String, Vec<Zeroizing<String>>>>,
    /// Backup decryption key and backup version (async tokio lock).
    backup_keys: RwLock<BackupKeys>,
    /// Pickle key for a dehydrated device, if any (async tokio lock).
    dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
    /// Last stored sync `next_batch` token (async tokio lock).
    next_batch_token: RwLock<Option<String>>,
    /// Per-room settings.
    room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
    /// Received room key bundles, by room id and then by sending user.
    room_key_bundles:
        StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, StoredRoomKeyBundleData>>>,
    /// Rooms for which all room keys have been downloaded from the backup.
    room_key_backups_fully_downloaded: StdRwLock<HashSet<OwnedRoomId>>,
    /// Rooms that are waiting for a key bundle, with the associated details.
    rooms_pending_key_bundle: StdRwLock<HashMap<OwnedRoomId, RoomPendingKeyBundleDetails>>,

    /// Serializes calls to `save_changes` and `save_pending_changes`.
    save_changes_lock: Arc<Mutex<()>>,
}
120
121impl MemoryStore {
122    /// Create a new empty `MemoryStore`.
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    fn get_static_account(&self) -> Option<StaticAccountData> {
128        self.static_account.read().clone()
129    }
130
131    pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
132        for device in devices {
133            let _ = self.devices.add(device);
134        }
135    }
136
137    fn delete_devices(&self, devices: Vec<DeviceData>) {
138        for device in devices {
139            let _ = self.devices.remove(device.user_id(), device.device_id());
140        }
141    }
142
143    fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
144        let mut session_store = self.sessions.write();
145
146        for (session_id, pickle) in sessions {
147            let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();
148
149            // insert or replace if exists
150            entry.insert(
151                session_id,
152                serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
153            );
154        }
155    }
156
157    fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
158        self.outbound_group_sessions
159            .write()
160            .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
161    }
162
163    fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
164        *self.private_identity.write() = private_identity;
165    }
166
167    /// Return all the [`InboundGroupSession`]s we have, paired with the
168    /// `backed_up_to` value for each one (or "" where it is missing, which
169    /// should never happen).
170    async fn get_inbound_group_sessions_and_backed_up_to(
171        &self,
172    ) -> Result<Vec<(InboundGroupSession, Option<BackupVersion>)>> {
173        let lookup = |s: &InboundGroupSession| {
174            self.inbound_group_sessions_backed_up_to
175                .read()
176                .get(&s.room_id)?
177                .get(s.session_id())
178                .cloned()
179        };
180
181        Ok(self
182            .get_inbound_group_sessions()
183            .await?
184            .into_iter()
185            .map(|s| {
186                let v = lookup(&s);
187                (s, v)
188            })
189            .collect())
190    }
191}
192
/// Result type used by this store. The error type is [`Infallible`]: this store
/// never returns errors, and instead panics (via `expect`) if its own stored data
/// turns out to be invalid.
type Result<T> = std::result::Result<T, Infallible>;
194
195#[cfg_attr(target_family = "wasm", async_trait(?Send))]
196#[cfg_attr(not(target_family = "wasm"), async_trait)]
197impl CryptoStore for MemoryStore {
198    type Error = Infallible;
199
200    async fn close(&self) -> Result<()> {
201        Ok(())
202    }
203
204    async fn reopen(&self) -> Result<()> {
205        Ok(())
206    }
207
208    async fn load_account(&self) -> Result<Option<Account>> {
209        let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
210            serde_json::from_str(acc)
211                .expect("Deserialization failed: invalid pickled account JSON format")
212        });
213
214        if let Some(pickle) = pickled_account {
215            let account =
216                Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");
217
218            *self.static_account.write() = Some(account.static_data().clone());
219
220            Ok(Some(account))
221        } else {
222            Ok(None)
223        }
224    }
225
226    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
227        Ok(self.private_identity.read().clone())
228    }
229
230    async fn next_batch_token(&self) -> Result<Option<String>> {
231        Ok(self.next_batch_token.read().await.clone())
232    }
233
234    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
235        let _guard = self.save_changes_lock.lock().await;
236
237        let pickled_account = if let Some(account) = changes.account {
238            *self.static_account.write() = Some(account.static_data().clone());
239            Some(account.pickle())
240        } else {
241            None
242        };
243
244        *self.account.write() = pickled_account.map(|pickle| {
245            serde_json::to_string(&pickle)
246                .expect("Serialization failed: invalid pickled account JSON format")
247        });
248
249        Ok(())
250    }
251
252    async fn save_changes(&self, changes: Changes) -> Result<()> {
253        let _guard = self.save_changes_lock.lock().await;
254
255        let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
256        for session in changes.sessions {
257            let session_id = session.session_id().to_owned();
258            let pickle = session.pickle().await;
259            pickled_session.push((session_id.clone(), pickle));
260        }
261        self.save_sessions(pickled_session);
262
263        self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
264        self.save_outbound_group_sessions(changes.outbound_group_sessions);
265        self.save_private_identity(changes.private_identity);
266
267        self.save_devices(changes.devices.new);
268        self.save_devices(changes.devices.changed);
269        self.delete_devices(changes.devices.deleted);
270
271        {
272            let mut identities = self.identities.write();
273            for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
274                identities.insert(
275                    identity.user_id().to_owned(),
276                    serde_json::to_string(&identity)
277                        .expect("UserIdentityData should always serialize to json"),
278                );
279            }
280        }
281
282        {
283            let mut olm_hashes = self.olm_hashes.write();
284            for hash in changes.message_hashes {
285                olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
286            }
287        }
288
289        {
290            let mut outgoing_key_requests = self.outgoing_key_requests.write();
291            let mut key_requests_by_info = self.key_requests_by_info.write();
292
293            for key_request in changes.key_requests {
294                let id = key_request.request_id.clone();
295                let info_string = encode_key_info(&key_request.info);
296
297                outgoing_key_requests.insert(id.clone(), key_request);
298                key_requests_by_info.insert(info_string, id);
299            }
300        }
301
302        if let Some(key) = changes.backup_decryption_key {
303            self.backup_keys.write().await.decryption_key = Some(key);
304        }
305
306        if let Some(version) = changes.backup_version {
307            self.backup_keys.write().await.backup_version = Some(version);
308        }
309
310        if let Some(pickle_key) = changes.dehydrated_device_pickle_key {
311            let mut lock = self.dehydrated_device_pickle_key.write().await;
312            *lock = Some(pickle_key);
313        }
314
315        {
316            let mut secret_inbox = self.secret_inbox.write();
317            for secret in changes.secrets {
318                secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret.secret);
319            }
320        }
321
322        {
323            let mut direct_withheld_info = self.direct_withheld_info.write();
324            for (room_id, data) in changes.withheld_session_info {
325                for (session_id, event) in data {
326                    direct_withheld_info
327                        .entry(room_id.to_owned())
328                        .or_default()
329                        .insert(session_id, event);
330                }
331            }
332        }
333
334        if let Some(next_batch_token) = changes.next_batch_token {
335            *self.next_batch_token.write().await = Some(next_batch_token);
336        }
337
338        if !changes.room_settings.is_empty() {
339            let mut settings = self.room_settings.write();
340            settings.extend(changes.room_settings);
341        }
342
343        if !changes.received_room_key_bundles.is_empty() {
344            let mut room_key_bundles = self.room_key_bundles.write();
345            for bundle in changes.received_room_key_bundles {
346                room_key_bundles
347                    .entry(bundle.bundle_data.room_id.clone())
348                    .or_default()
349                    .insert(bundle.sender_user.clone(), bundle);
350            }
351        }
352
353        if !changes.room_key_backups_fully_downloaded.is_empty() {
354            let mut room_key_backups_fully_downloaded =
355                self.room_key_backups_fully_downloaded.write();
356            for room in changes.room_key_backups_fully_downloaded {
357                room_key_backups_fully_downloaded.insert(room);
358            }
359        }
360
361        if !changes.rooms_pending_key_bundle.is_empty() {
362            let mut lock = self.rooms_pending_key_bundle.write();
363            for (room, details) in changes.rooms_pending_key_bundle {
364                if let Some(details) = details {
365                    lock.insert(room, details);
366                } else {
367                    lock.remove(&room);
368                }
369            }
370        }
371
372        Ok(())
373    }
374
375    async fn save_inbound_group_sessions(
376        &self,
377        sessions: Vec<InboundGroupSession>,
378        backed_up_to_version: Option<&str>,
379    ) -> Result<()> {
380        for session in sessions {
381            let room_id = session.room_id();
382            let session_id = session.session_id();
383
384            // Sanity-check that the data in the sessions corresponds to backed_up_version
385            let backed_up = session.backed_up();
386            if backed_up != backed_up_to_version.is_some() {
387                warn!(
388                    backed_up,
389                    backed_up_to_version,
390                    "Session backed-up flag does not correspond to backup version setting",
391                );
392            }
393
394            if let Some(backup_version) = backed_up_to_version {
395                self.inbound_group_sessions_backed_up_to
396                    .write()
397                    .entry(room_id.to_owned())
398                    .or_default()
399                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
400            }
401
402            let pickle = session.pickle().await;
403            self.inbound_group_sessions
404                .write()
405                .entry(session.room_id().to_owned())
406                .or_default()
407                .insert(
408                    session.session_id().to_owned(),
409                    serde_json::to_string(&pickle)
410                        .expect("Pickle pickle data should serialize to json"),
411                );
412        }
413        Ok(())
414    }
415
416    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
417        let device_keys = self.get_own_device().await?.as_device_keys().clone();
418
419        if let Some(pickles) = self.sessions.read().get(sender_key) {
420            let mut sessions: Vec<Session> = Vec::new();
421            for serialized_pickle in pickles.values() {
422                let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
423                    .expect("Pickle pickle deserialization should work");
424                let session = Session::from_pickle(device_keys.clone(), pickle)
425                    .expect("Expect from pickle to always work");
426                sessions.push(session);
427            }
428            Ok(Some(sessions))
429        } else {
430            Ok(None)
431        }
432    }
433
434    async fn get_inbound_group_session(
435        &self,
436        room_id: &RoomId,
437        session_id: &str,
438    ) -> Result<Option<InboundGroupSession>> {
439        let pickle: Option<PickledInboundGroupSession> = self
440            .inbound_group_sessions
441            .read()
442            .get(room_id)
443            .and_then(|m| m.get(session_id))
444            .and_then(|ser| {
445                serde_json::from_str(ser).expect("Pickle pickle deserialization should work")
446            });
447
448        Ok(pickle.map(|p| {
449            InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
450        }))
451    }
452
453    async fn get_withheld_info(
454        &self,
455        room_id: &RoomId,
456        session_id: &str,
457    ) -> Result<Option<RoomKeyWithheldEntry>> {
458        Ok(self
459            .direct_withheld_info
460            .read()
461            .get(room_id)
462            .and_then(|e| Some(e.get(session_id)?.to_owned())))
463    }
464
465    async fn get_withheld_sessions_by_room_id(
466        &self,
467        room_id: &RoomId,
468    ) -> crate::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
469        Ok(self
470            .direct_withheld_info
471            .read()
472            .get(room_id)
473            .map(|e| e.values().cloned().collect())
474            .unwrap_or_default())
475    }
476
477    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
478        let inbounds = self
479            .inbound_group_sessions
480            .read()
481            .values()
482            .flat_map(HashMap::values)
483            .map(|ser| {
484                let pickle: PickledInboundGroupSession =
485                    serde_json::from_str(ser).expect("Pickle deserialization should work");
486                InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
487            })
488            .collect();
489        Ok(inbounds)
490    }
491
492    async fn inbound_group_session_counts(
493        &self,
494        backup_version: Option<&str>,
495    ) -> Result<RoomKeyCounts> {
496        let backed_up = if let Some(backup_version) = backup_version {
497            self.get_inbound_group_sessions_and_backed_up_to()
498                .await?
499                .into_iter()
500                // Count the sessions backed up in the required backup
501                .filter(|(_, o)| o.as_ref().is_some_and(|o| o.as_str() == backup_version))
502                .count()
503        } else {
504            // We asked about a nonexistent backup version - this doesn't make much sense,
505            // but we can easily answer that nothing is backed up in this
506            // nonexistent backup.
507            0
508        };
509
510        let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
511        Ok(RoomKeyCounts { total, backed_up })
512    }
513
514    async fn get_inbound_group_sessions_by_room_id(
515        &self,
516        room_id: &RoomId,
517    ) -> Result<Vec<InboundGroupSession>> {
518        let inbounds = match self.inbound_group_sessions.read().get(room_id) {
519            None => Vec::new(),
520            Some(v) => v
521                .values()
522                .map(|ser| {
523                    let pickle: PickledInboundGroupSession =
524                        serde_json::from_str(ser).expect("Pickle deserialization should work");
525                    InboundGroupSession::from_pickle(pickle)
526                        .expect("Expect from pickle to always work")
527                })
528                .collect(),
529        };
530        Ok(inbounds)
531    }
532
533    async fn get_inbound_group_sessions_for_device_batch(
534        &self,
535        sender_key: Curve25519PublicKey,
536        sender_data_type: SenderDataType,
537        after_session_id: Option<String>,
538        limit: usize,
539    ) -> Result<Vec<InboundGroupSession>> {
540        // First, find all InboundGroupSessions, filtering for those that match the
541        // device and sender_data type.
542        let mut sessions: Vec<_> = self
543            .get_inbound_group_sessions()
544            .await?
545            .into_iter()
546            .filter(|session: &InboundGroupSession| {
547                session.creator_info.curve25519_key == sender_key
548                    && session.sender_data.to_type() == sender_data_type
549            })
550            .collect();
551
552        // Then, sort the sessions in order of ascending session ID...
553        sessions.sort_by_key(|s| s.session_id().to_owned());
554
555        // Figure out where in the array to start returning results from
556        let start_index = {
557            match after_session_id {
558                None => 0,
559                Some(id) => {
560                    // We're looking for the first session with a session ID strictly after `id`; if
561                    // there are none, the end of the array.
562                    sessions
563                        .iter()
564                        .position(|session| session.session_id() > id.as_str())
565                        .unwrap_or(sessions.len())
566                }
567            }
568        };
569
570        // Return up to `limit` items from the array, starting from `start_index`
571        Ok(sessions.drain(start_index..).take(limit).collect())
572    }
573
574    async fn inbound_group_sessions_for_backup(
575        &self,
576        backup_version: &str,
577        limit: usize,
578    ) -> Result<Vec<InboundGroupSession>> {
579        Ok(self
580            .get_inbound_group_sessions_and_backed_up_to()
581            .await?
582            .into_iter()
583            .filter_map(|(session, backed_up_to)| {
584                if let Some(ref existing_version) = backed_up_to
585                    && existing_version.as_str() == backup_version
586                {
587                    // This session is already backed up in the required backup
588                    None
589                } else {
590                    // It's not backed up, or it's backed up in a different backup
591                    Some(session)
592                }
593            })
594            .take(limit)
595            .collect())
596    }
597
598    async fn mark_inbound_group_sessions_as_backed_up(
599        &self,
600        backup_version: &str,
601        room_and_session_ids: &[(&RoomId, &str)],
602    ) -> Result<()> {
603        for &(room_id, session_id) in room_and_session_ids {
604            let session = self.get_inbound_group_session(room_id, session_id).await?;
605
606            if let Some(session) = session {
607                session.mark_as_backed_up();
608
609                self.inbound_group_sessions_backed_up_to
610                    .write()
611                    .entry(room_id.to_owned())
612                    .or_default()
613                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
614
615                // Save it back
616                let updated_pickle = session.pickle().await;
617
618                self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
619                    session_id.to_owned(),
620                    serde_json::to_string(&updated_pickle)
621                        .expect("Pickle serialization should work"),
622                );
623            }
624        }
625
626        Ok(())
627    }
628
629    async fn reset_backup_state(&self) -> Result<()> {
630        // Nothing to do here, because we remember which backup versions we backed up to
631        // in `mark_inbound_group_sessions_as_backed_up`, so we don't need to
632        // reset anything here because the required version is passed in to
633        // `inbound_group_sessions_for_backup`, and we can compare against the
634        // version we stored.
635
636        Ok(())
637    }
638
639    async fn load_backup_keys(&self) -> Result<BackupKeys> {
640        Ok(self.backup_keys.read().await.to_owned())
641    }
642
643    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
644        Ok(self.dehydrated_device_pickle_key.read().await.to_owned())
645    }
646
647    async fn delete_dehydrated_device_pickle_key(&self) -> Result<()> {
648        let mut lock = self.dehydrated_device_pickle_key.write().await;
649        *lock = None;
650        Ok(())
651    }
652
653    async fn get_outbound_group_session(
654        &self,
655        room_id: &RoomId,
656    ) -> Result<Option<OutboundGroupSession>> {
657        Ok(self.outbound_group_sessions.read().get(room_id).cloned())
658    }
659
660    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
661        Ok(self.tracked_users.read().values().cloned().collect())
662    }
663
664    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
665        self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
666            let user_id: OwnedUserId = user_id.to_owned().into();
667            (user_id.clone(), TrackedUser { user_id, dirty: *dirty })
668        }));
669        Ok(())
670    }
671
672    async fn get_device(
673        &self,
674        user_id: &UserId,
675        device_id: &DeviceId,
676    ) -> Result<Option<DeviceData>> {
677        Ok(self.devices.get(user_id, device_id))
678    }
679
680    async fn get_user_devices(
681        &self,
682        user_id: &UserId,
683    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
684        Ok(self.devices.user_devices(user_id))
685    }
686
687    async fn get_own_device(&self) -> Result<DeviceData> {
688        let account =
689            self.get_static_account().expect("Expect account to exist when getting own device");
690
691        Ok(self
692            .devices
693            .get(&account.user_id, &account.device_id)
694            .expect("Invalid state: Should always have a own device"))
695    }
696
697    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
698        let serialized = self.identities.read().get(user_id).cloned();
699        match serialized {
700            None => Ok(None),
701            Some(serialized) => {
702                let id: UserIdentityData = serde_json::from_str(serialized.as_str())
703                    .expect("Only valid serialized identity are saved");
704                Ok(Some(id))
705            }
706        }
707    }
708
709    async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
710        Ok(self
711            .olm_hashes
712            .write()
713            .entry(message_hash.sender_key.to_owned())
714            .or_default()
715            .contains(&message_hash.hash))
716    }
717
718    async fn get_outgoing_secret_requests(
719        &self,
720        request_id: &TransactionId,
721    ) -> Result<Option<GossipRequest>> {
722        Ok(self.outgoing_key_requests.read().get(request_id).cloned())
723    }
724
725    async fn get_secret_request_by_info(
726        &self,
727        key_info: &SecretInfo,
728    ) -> Result<Option<GossipRequest>> {
729        let key_info_string = encode_key_info(key_info);
730
731        Ok(self
732            .key_requests_by_info
733            .read()
734            .get(&key_info_string)
735            .and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
736    }
737
738    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
739        Ok(self
740            .outgoing_key_requests
741            .read()
742            .values()
743            .filter(|req| !req.sent_out)
744            .cloned()
745            .collect())
746    }
747
748    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
749        let req = self.outgoing_key_requests.write().remove(request_id);
750        if let Some(i) = req {
751            let key_info_string = encode_key_info(&i.info);
752            self.key_requests_by_info.write().remove(&key_info_string);
753        }
754
755        Ok(())
756    }
757
758    async fn get_secrets_from_inbox(
759        &self,
760        secret_name: &SecretName,
761    ) -> Result<Vec<Zeroizing<String>>> {
762        Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
763    }
764
765    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
766        self.secret_inbox.write().remove(secret_name.as_str());
767
768        Ok(())
769    }
770
771    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
772        Ok(self.room_settings.read().get(room_id).cloned())
773    }
774
775    async fn get_received_room_key_bundle_data(
776        &self,
777        room_id: &RoomId,
778        user_id: &UserId,
779    ) -> Result<Option<StoredRoomKeyBundleData>> {
780        let guard = self.room_key_bundles.read();
781
782        let result = guard.get(room_id).and_then(|bundles| bundles.get(user_id).cloned());
783
784        Ok(result)
785    }
786
787    async fn get_pending_key_bundle_details_for_room(
788        &self,
789        room_id: &RoomId,
790    ) -> Result<Option<RoomPendingKeyBundleDetails>> {
791        Ok(self.rooms_pending_key_bundle.read().get(room_id).cloned())
792    }
793
794    async fn get_all_rooms_pending_key_bundles(&self) -> Result<Vec<RoomPendingKeyBundleDetails>> {
795        Ok(self.rooms_pending_key_bundle.read().values().cloned().collect())
796    }
797
798    async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
799        let guard = self.room_key_backups_fully_downloaded.read();
800        Ok(guard.contains(room_id))
801    }
802
803    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
804        Ok(self.custom_values.read().get(key).cloned())
805    }
806
807    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
808        self.custom_values.write().insert(key.to_owned(), value);
809        Ok(())
810    }
811
812    async fn remove_custom_value(&self, key: &str) -> Result<()> {
813        self.custom_values.write().remove(key);
814        Ok(())
815    }
816
817    async fn try_take_leased_lock(
818        &self,
819        lease_duration_ms: u32,
820        key: &str,
821        holder: &str,
822    ) -> Result<Option<CrossProcessLockGeneration>> {
823        Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
824    }
825
826    async fn get_size(&self) -> Result<Option<usize>> {
827        Ok(None)
828    }
829}
830
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use matrix_sdk_test::async_test;
    use ruma::{RoomId, room_id, user_id};
    use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};

    use super::SessionId;
    use crate::{
        DeviceData,
        identities::device::testing::get_device,
        olm::{
            Account, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
            tests::get_account_and_session_test_helper,
        },
        store::{
            CryptoStore,
            memorystore::MemoryStore,
            types::{Changes, DeviceChanges, PendingChanges},
        },
    };

    #[async_test]
    async fn test_session_store() {
        let (account, session) = get_account_and_session_test_helper();
        let own_device = DeviceData::from_account(&account);
        let store = MemoryStore::new();

        assert!(store.load_account().await.unwrap().is_none());

        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![own_device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();

        store
            .save_changes(Changes { sessions: vec![session.clone()], ..Default::default() })
            .await
            .unwrap();

        let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_inbound_group_session_store() {
        // Given an inbound group session
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let inbound = new_session(&account, room_id).await;

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_inbound_group_sessions(vec![inbound.clone()], None).await.unwrap();

        // Then we can get it out again by room and session id
        let loaded_session =
            store.get_inbound_group_session(room_id, inbound.session_id()).await.unwrap().unwrap();
        assert_eq!(inbound, loaded_session);
    }

    #[async_test]
    async fn test_backing_up_marks_sessions_as_backed_up() {
        // Given there are 2 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(2, room_id).await;

        // When I mark them as backed up
        mark_backed_up(&store, room_id, "bkp1", &sessions).await;

        // Then their backed_up_to field is set
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_a_second_set_of_sessions_updates_their_backup_order() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark 0 and 1 as backed up in bkp1
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And 1 and 2 as backed up in bkp2
        mark_backed_up(&store, room_id, "bkp2", &sessions[1..]).await;

        // Then 0 is backed up in bkp1 and the 1 and 2 are backed up in bkp2
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp2");
        assert_eq!(but[sessions[2].session_id()], "bkp2");
    }

    #[async_test]
    async fn test_backing_up_again_to_the_same_version_has_no_effect() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark the first two as backed up in the first backup
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And the last 2 as backed up in the same backup version
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..]).await;

        // Then they all get the same backed_up_to value
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
        assert_eq!(but[sessions[2].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_can_increase_backed_up_to() {
        // Given we have backed up some sessions to 2 backup versions, an older and a
        // newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "newer_bkp", &sessions[1..2]).await;

        // When I ask to back up the un-backed-up ones to the older backup
        mark_backed_up(&store, room_id, "older_bkp", &sessions[2..]).await;

        // Then each session lists the backup it was most recently included in
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "newer_bkp");
        assert_eq!(but[sessions[2].session_id()], "older_bkp");
        assert_eq!(but[sessions[3].session_id()], "older_bkp");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_overwrites_a_newer_one() {
        // Given we have backed up to 2 backup versions, an older and a newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions).await;
        // Sanity: they are all backed up to older_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "older_bkp");
        mark_backed_up(&store, room_id, "newer_bkp", &sessions).await;
        // Sanity: they are all backed up to newer_bkp now
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "newer_bkp");

        // When I ask to back up some to the older version
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;

        // Then older backup overwrites: we don't consider the order here at all
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "older_bkp");
        assert_eq!(but[sessions[2].session_id()], "newer_bkp");
        assert_eq!(but[sessions[3].session_id()], "newer_bkp");
    }

    #[async_test]
    async fn test_not_backed_up_sessions_are_eligible_for_backup() {
        // Given there are 4 sessions, 2 of which are already backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told the last 2 only
        assert_eq!(to_backup, &[sessions[2].clone(), sessions[3].clone()]);
    }

    #[async_test]
    async fn test_all_sessions_are_eligible_for_backup_if_version_is_unknown() {
        // Given there are 4 sessions, 2 of which are already backed up in bkp1
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up in an unknown version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("unknown_bkp", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up all of them
        assert_eq!(
            to_backup,
            &[sessions[0].clone(), sessions[1].clone(), sessions[2].clone(), sessions[3].clone()]
        );
    }

    #[async_test]
    async fn test_sessions_backed_up_to_a_later_version_are_eligible_for_backup() {
        // Given there are 4 sessions, some backed up to three different versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp0", &sessions[..1]).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[2..3]).await;

        // When I ask which to back up in the middle version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up everything not in the version I asked about
        assert_eq!(
            to_backup,
            &[
                sessions[0].clone(), // Backed up in bkp0
                // sessions[1] is backed up in bkp1 already, which we asked about
                sessions[2].clone(), // Backed up in bkp2
                sessions[3].clone(), // Not backed up
            ]
        );
    }

    #[async_test]
    async fn test_outbound_group_session_store() {
        // Given an outbound session
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_outbound_group_sessions(vec![outbound.clone()]);

        // Then we can get it out again
        let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
        assert_eq!(
            serde_json::to_string(&outbound.pickle().await).unwrap(),
            serde_json::to_string(&loaded_session.pickle().await).unwrap()
        );
    }

    #[async_test]
    async fn test_tracked_users_are_stored_once_per_user_id() {
        // Given a store containing 2 tracked users, both dirty
        let user1 = user_id!("@user1:s");
        let user2 = user_id!("@user2:s");
        let user3 = user_id!("@user3:s");
        let store = MemoryStore::new();
        store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();

        // When we mark one as clean and add another
        store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();

        // Then we can get them out again and their dirty flags are correct
        let loaded_tracked_users =
            store.load_tracked_users().await.expect("failed to load tracked users");

        let tracked_contains = |user_id, dirty| {
            loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
        };

        assert!(tracked_contains(user1, true));
        assert!(tracked_contains(user2, false));
        assert!(tracked_contains(user3, false));
        assert_eq!(loaded_tracked_users.len(), 3);
    }

    #[async_test]
    async fn test_private_identity_store() {
        // Given a private identity
        let private_identity = PrivateCrossSigningIdentity::empty(user_id!("@u:s"));

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_private_identity(Some(private_identity.clone()));

        // Then we can get it out again
        let loaded_identity =
            store.load_identity().await.expect("failed to load private identity").unwrap();

        assert_eq!(loaded_identity.user_id(), user_id!("@u:s"));
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = MemoryStore::new();

        store.save_devices(vec![device.clone()]);

        let loaded_device =
            store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.get_user_devices(device.user_id()).await.unwrap();

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.delete_devices(vec![device.clone()]);
        assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
    }

    #[async_test]
    async fn test_message_hash() {
        let store = MemoryStore::new();

        let hash =
            OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };

        let mut changes = Changes::default();
        changes.message_hashes.push(hash.clone());

        assert!(!store.is_message_known(&hash).await.unwrap());
        store.save_changes(changes).await.unwrap();
        assert!(store.is_message_known(&hash).await.unwrap());
    }

    #[async_test]
    async fn test_key_counts_of_empty_store_are_zero() {
        // Given an empty store
        let store = MemoryStore::new();

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("")).await.unwrap();

        // Then the answer is zero
        assert_eq!(key_counts.total, 0);
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_sessions_reports_the_number_of_sessions() {
        // Given a store with sessions
        let room_id = room_id!("!test:localhost");
        let (store, _) = store_with_sessions(4, room_id).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 4);
        // And none are backed up
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_reports_the_number_backed_up_in_this_backup() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(5, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the total equals the number of sessions we created
        assert_eq!(key_counts.total, 5);
        // And the backed_up count matches how many were backed up
        assert_eq!(key_counts.backed_up, 2);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_for_null_backup_reports_zero() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys, providing None as the backup version
        let key_counts = store.inbound_group_session_counts(None).await.unwrap();

        // Then we ignore everything and just say zero
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_only_reports_sessions_in_the_version_specified() {
        // Given a store with sessions, backed up in several versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[3..]).await;

        // When we count keys for bkp2
        let key_counts = store.inbound_group_session_counts(Some("bkp2")).await.unwrap();

        // Then the backed_up count reflects how many were backed up in bkp2 only
        assert_eq!(key_counts.backed_up, 1);
    }

    /// Mark the supplied sessions (all in `room_id`) as backed up in
    /// `backup_version`, panicking if the store reports an error.
    async fn mark_backed_up(
        store: &MemoryStore,
        room_id: &RoomId,
        backup_version: &str,
        sessions: &[InboundGroupSession],
    ) {
        let rooms_and_ids: Vec<_> = sessions.iter().map(|s| (room_id, s.session_id())).collect();

        store
            .mark_inbound_group_sessions_as_backed_up(backup_version, &rooms_and_ids)
            .await
            .expect("Failed to mark sessions as backed up");
    }

    /// Create a [MemoryStore] containing `num_sessions` new inbound group
    /// sessions in `room_id`, none of them backed up.
    ///
    /// Sessions are returned in alphabetical order of session id.
    async fn store_with_sessions(
        num_sessions: usize,
        room_id: &RoomId,
    ) -> (MemoryStore, Vec<InboundGroupSession>) {
        let (account, _) = get_account_and_session_test_helper();

        let mut sessions = Vec::with_capacity(num_sessions);
        for _ in 0..num_sessions {
            sessions.push(new_session(&account, room_id).await);
        }
        sessions.sort_by_key(|s| s.session_id().to_owned());

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(sessions.clone(), None).await.unwrap();

        (store, sessions)
    }

    /// Create a new [InboundGroupSession] for `room_id` from the session key of
    /// a freshly created outbound session of `account`, using fixed sender keys.
    async fn new_session(account: &Account, room_id: &RoomId) -> InboundGroupSession {
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap()
    }

    /// Map the session id of every session in the store to the backup version
    /// it is recorded as backed up to, or to `""` if it has no backup version.
    async fn backed_up_tos(store: &MemoryStore) -> HashMap<SessionId, String> {
        store
            .get_inbound_group_sessions_and_backed_up_to()
            .await
            .expect("Unable to get inbound group sessions and backup order")
            .iter()
            .map(|(s, o)| {
                (
                    s.session_id().to_owned(),
                    // `unwrap_or_default` avoids allocating the fallback eagerly.
                    o.as_ref().map(|v| v.as_str().to_owned()).unwrap_or_default(),
                )
            })
            .collect()
    }
}
1320
1321#[cfg(test)]
1322mod integration_tests {
1323    use std::{
1324        collections::HashMap,
1325        sync::{Arc, Mutex, OnceLock},
1326    };
1327
1328    use async_trait::async_trait;
1329    use matrix_sdk_common::cross_process_lock::CrossProcessLockGeneration;
1330    use ruma::{
1331        DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, events::secret::request::SecretName,
1332    };
1333    use vodozemac::Curve25519PublicKey;
1334    use zeroize::Zeroizing;
1335
1336    use super::MemoryStore;
1337    use crate::{
1338        Account, DeviceData, GossipRequest, SecretInfo, Session, UserIdentityData,
1339        cryptostore_integration_tests, cryptostore_integration_tests_time,
1340        olm::{
1341            InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
1342            SenderDataType, StaticAccountData,
1343        },
1344        store::{
1345            CryptoStore,
1346            types::{
1347                BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
1348                RoomKeyWithheldEntry, RoomPendingKeyBundleDetails, RoomSettings,
1349                StoredRoomKeyBundleData, TrackedUser,
1350            },
1351        },
1352    };
1353
1354    /// Holds on to a MemoryStore during a test, and moves it back into STORES
1355    /// when this is dropped
1356    #[derive(Clone, Debug)]
1357    struct PersistentMemoryStore(Arc<MemoryStore>);
1358
1359    impl PersistentMemoryStore {
1360        fn new() -> Self {
1361            Self(Arc::new(MemoryStore::new()))
1362        }
1363
1364        fn get_static_account(&self) -> Option<StaticAccountData> {
1365            self.0.get_static_account()
1366        }
1367    }
1368
1369    /// Return a clone of the store for the test with the supplied name. Note:
1370    /// dropping this store won't destroy its data, since
1371    /// [PersistentMemoryStore] is a reference-counted smart pointer
1372    /// to an underlying [MemoryStore].
1373    async fn get_store(
1374        name: &str,
1375        _passphrase: Option<&str>,
1376        clear_data: bool,
1377    ) -> PersistentMemoryStore {
1378        // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
1379        // the store, we keep its data alive. This simulates the behaviour of
1380        // the other stores, which keep their data in a real DB, allowing us to
1381        // test MemoryStore using the same code.
1382        static STORES: OnceLock<Mutex<HashMap<String, PersistentMemoryStore>>> = OnceLock::new();
1383        let stores = STORES.get_or_init(|| Mutex::new(HashMap::new()));
1384
1385        let mut stores = stores.lock().unwrap();
1386
1387        if clear_data {
1388            // Create a new PersistentMemoryStore
1389            let new_store = PersistentMemoryStore::new();
1390            stores.insert(name.to_owned(), new_store.clone());
1391            new_store
1392        } else {
1393            stores.entry(name.to_owned()).or_insert_with(PersistentMemoryStore::new).clone()
1394        }
1395    }
1396
1397    /// Forwards all methods to the underlying [MemoryStore].
1398    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
1399    #[cfg_attr(not(target_family = "wasm"), async_trait)]
1400    impl CryptoStore for PersistentMemoryStore {
        /// Reuse the wrapped [MemoryStore]'s error type so forwarded results pass
        /// through without conversion.
        type Error = <MemoryStore as CryptoStore>::Error;
1402
1403        async fn close(&self) -> Result<(), Self::Error> {
1404            self.0.close().await
1405        }
1406
1407        async fn reopen(&self) -> Result<(), Self::Error> {
1408            self.0.reopen().await
1409        }
1410
1411        async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
1412            self.0.load_account().await
1413        }
1414
1415        async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
1416            self.0.load_identity().await
1417        }
1418
1419        async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error> {
1420            self.0.save_changes(changes).await
1421        }
1422
1423        async fn save_pending_changes(&self, changes: PendingChanges) -> Result<(), Self::Error> {
1424            self.0.save_pending_changes(changes).await
1425        }
1426
1427        async fn save_inbound_group_sessions(
1428            &self,
1429            sessions: Vec<InboundGroupSession>,
1430            backed_up_to_version: Option<&str>,
1431        ) -> Result<(), Self::Error> {
1432            self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await
1433        }
1434
1435        async fn get_sessions(
1436            &self,
1437            sender_key: &str,
1438        ) -> Result<Option<Vec<Session>>, Self::Error> {
1439            self.0.get_sessions(sender_key).await
1440        }
1441
1442        async fn get_inbound_group_session(
1443            &self,
1444            room_id: &RoomId,
1445            session_id: &str,
1446        ) -> Result<Option<InboundGroupSession>, Self::Error> {
1447            self.0.get_inbound_group_session(room_id, session_id).await
1448        }
1449
1450        async fn get_withheld_info(
1451            &self,
1452            room_id: &RoomId,
1453            session_id: &str,
1454        ) -> Result<Option<RoomKeyWithheldEntry>, Self::Error> {
1455            self.0.get_withheld_info(room_id, session_id).await
1456        }
1457
1458        async fn get_withheld_sessions_by_room_id(
1459            &self,
1460            room_id: &RoomId,
1461        ) -> Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
1462            self.0.get_withheld_sessions_by_room_id(room_id).await
1463        }
1464
1465        async fn get_inbound_group_sessions(
1466            &self,
1467        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1468            self.0.get_inbound_group_sessions().await
1469        }
1470
1471        async fn inbound_group_session_counts(
1472            &self,
1473            backup_version: Option<&str>,
1474        ) -> Result<RoomKeyCounts, Self::Error> {
1475            self.0.inbound_group_session_counts(backup_version).await
1476        }
1477
1478        async fn get_inbound_group_sessions_by_room_id(
1479            &self,
1480            room_id: &RoomId,
1481        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1482            self.0.get_inbound_group_sessions_by_room_id(room_id).await
1483        }
1484
1485        async fn get_inbound_group_sessions_for_device_batch(
1486            &self,
1487            sender_key: Curve25519PublicKey,
1488            sender_data_type: SenderDataType,
1489            after_session_id: Option<String>,
1490            limit: usize,
1491        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1492            self.0
1493                .get_inbound_group_sessions_for_device_batch(
1494                    sender_key,
1495                    sender_data_type,
1496                    after_session_id,
1497                    limit,
1498                )
1499                .await
1500        }
1501
1502        async fn inbound_group_sessions_for_backup(
1503            &self,
1504            backup_version: &str,
1505            limit: usize,
1506        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1507            self.0.inbound_group_sessions_for_backup(backup_version, limit).await
1508        }
1509
1510        async fn mark_inbound_group_sessions_as_backed_up(
1511            &self,
1512            backup_version: &str,
1513            room_and_session_ids: &[(&RoomId, &str)],
1514        ) -> Result<(), Self::Error> {
1515            self.0
1516                .mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
1517                .await
1518        }
1519
1520        async fn reset_backup_state(&self) -> Result<(), Self::Error> {
1521            self.0.reset_backup_state().await
1522        }
1523
1524        async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
1525            self.0.load_backup_keys().await
1526        }
1527
1528        async fn load_dehydrated_device_pickle_key(
1529            &self,
1530        ) -> Result<Option<DehydratedDeviceKey>, Self::Error> {
1531            self.0.load_dehydrated_device_pickle_key().await
1532        }
1533
1534        async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1535            self.0.delete_dehydrated_device_pickle_key().await
1536        }
1537
1538        async fn get_outbound_group_session(
1539            &self,
1540            room_id: &RoomId,
1541        ) -> Result<Option<OutboundGroupSession>, Self::Error> {
1542            self.0.get_outbound_group_session(room_id).await
1543        }
1544
1545        async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
1546            self.0.load_tracked_users().await
1547        }
1548
1549        async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error> {
1550            self.0.save_tracked_users(users).await
1551        }
1552
1553        async fn get_device(
1554            &self,
1555            user_id: &UserId,
1556            device_id: &DeviceId,
1557        ) -> Result<Option<DeviceData>, Self::Error> {
1558            self.0.get_device(user_id, device_id).await
1559        }
1560
1561        async fn get_user_devices(
1562            &self,
1563            user_id: &UserId,
1564        ) -> Result<HashMap<OwnedDeviceId, DeviceData>, Self::Error> {
1565            self.0.get_user_devices(user_id).await
1566        }
1567
1568        async fn get_own_device(&self) -> Result<DeviceData, Self::Error> {
1569            self.0.get_own_device().await
1570        }
1571
1572        async fn get_user_identity(
1573            &self,
1574            user_id: &UserId,
1575        ) -> Result<Option<UserIdentityData>, Self::Error> {
1576            self.0.get_user_identity(user_id).await
1577        }
1578
1579        async fn is_message_known(
1580            &self,
1581            message_hash: &OlmMessageHash,
1582        ) -> Result<bool, Self::Error> {
1583            self.0.is_message_known(message_hash).await
1584        }
1585
1586        async fn get_outgoing_secret_requests(
1587            &self,
1588            request_id: &TransactionId,
1589        ) -> Result<Option<GossipRequest>, Self::Error> {
1590            self.0.get_outgoing_secret_requests(request_id).await
1591        }
1592
1593        async fn get_secret_request_by_info(
1594            &self,
1595            secret_info: &SecretInfo,
1596        ) -> Result<Option<GossipRequest>, Self::Error> {
1597            self.0.get_secret_request_by_info(secret_info).await
1598        }
1599
1600        async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
1601            self.0.get_unsent_secret_requests().await
1602        }
1603
1604        async fn delete_outgoing_secret_requests(
1605            &self,
1606            request_id: &TransactionId,
1607        ) -> Result<(), Self::Error> {
1608            self.0.delete_outgoing_secret_requests(request_id).await
1609        }
1610
1611        async fn get_secrets_from_inbox(
1612            &self,
1613            secret_name: &SecretName,
1614        ) -> Result<Vec<Zeroizing<String>>, Self::Error> {
1615            self.0.get_secrets_from_inbox(secret_name).await
1616        }
1617
1618        async fn delete_secrets_from_inbox(
1619            &self,
1620            secret_name: &SecretName,
1621        ) -> Result<(), Self::Error> {
1622            self.0.delete_secrets_from_inbox(secret_name).await
1623        }
1624
1625        async fn get_room_settings(
1626            &self,
1627            room_id: &RoomId,
1628        ) -> Result<Option<RoomSettings>, Self::Error> {
1629            self.0.get_room_settings(room_id).await
1630        }
1631
1632        async fn get_received_room_key_bundle_data(
1633            &self,
1634            room_id: &RoomId,
1635            user_id: &UserId,
1636        ) -> crate::store::Result<Option<StoredRoomKeyBundleData>, Self::Error> {
1637            self.0.get_received_room_key_bundle_data(room_id, user_id).await
1638        }
1639
1640        async fn has_downloaded_all_room_keys(
1641            &self,
1642            room_id: &RoomId,
1643        ) -> Result<bool, Self::Error> {
1644            self.0.has_downloaded_all_room_keys(room_id).await
1645        }
1646
1647        async fn get_pending_key_bundle_details_for_room(
1648            &self,
1649            room_id: &RoomId,
1650        ) -> Result<Option<RoomPendingKeyBundleDetails>, Self::Error> {
1651            self.0.get_pending_key_bundle_details_for_room(room_id).await
1652        }
1653
1654        async fn get_all_rooms_pending_key_bundles(
1655            &self,
1656        ) -> Result<Vec<RoomPendingKeyBundleDetails>, Self::Error> {
1657            self.0.get_all_rooms_pending_key_bundles().await
1658        }
1659
1660        async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
1661            self.0.get_custom_value(key).await
1662        }
1663
1664        async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<(), Self::Error> {
1665            self.0.set_custom_value(key, value).await
1666        }
1667
1668        async fn remove_custom_value(&self, key: &str) -> Result<(), Self::Error> {
1669            self.0.remove_custom_value(key).await
1670        }
1671
1672        async fn try_take_leased_lock(
1673            &self,
1674            lease_duration_ms: u32,
1675            key: &str,
1676            holder: &str,
1677        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
1678            self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
1679        }
1680
1681        async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1682            self.0.next_batch_token().await
1683        }
1684
1685        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1686            self.0.get_size().await
1687        }
1688    }
1689
    // Expand the shared crypto-store integration test suites inside this module.
    // NOTE(review): the store type these tests exercise is decided by the macro
    // definitions, which are not visible here. It is presumably the wrapper whose
    // `CryptoStore` impl ends just above; confirm against the macros.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
1692}