matrix_sdk_crypto/store/memorystore.rs

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    collections::{BTreeMap, HashMap, HashSet},
    convert::Infallible,
    sync::Arc,
};

use async_trait::async_trait;
use matrix_sdk_common::{
    cross_process_lock::memory_store_helper::try_take_leased_lock, locks::RwLock as StdRwLock,
};
use ruma::{
    events::secret::request::SecretName, time::Instant, DeviceId, OwnedDeviceId, OwnedRoomId,
    OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
};
use tokio::sync::{Mutex, RwLock};
use tracing::warn;
use vodozemac::Curve25519PublicKey;

use super::{
    caches::DeviceStore,
    types::{
        BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
        StoredRoomKeyBundleData, TrackedUser,
    },
    Account, CryptoStore, InboundGroupSession, Session,
};
use crate::{
    gossiping::{GossipRequest, GossippedSecret, SecretInfo},
    identities::{DeviceData, UserIdentityData},
    olm::{
        OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
        PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
    },
    store::types::RoomKeyWithheldEntry,
};

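/// Build the `String` used to index an outgoing key/secret request by its
/// content (see the `key_requests_by_info` map below). The fields are
/// concatenated without a separator; the result is only ever used as an
/// opaque map key.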
fn encode_key_info(info: &SecretInfo) -> String {
    match info {
        SecretInfo::KeyRequest(info) => {
            format!("{}{}{}", info.room_id(), info.algorithm(), info.session_id())
        }
        SecretInfo::SecretRequest(i) => i.as_ref().to_owned(),
    }
}

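/// The ID of an inbound group session, used as a map key below.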
type SessionId = String;

/// The "version" of a backup - newtype wrapper around a String.
#[derive(Clone, Debug, PartialEq)]
struct BackupVersion(String);

impl BackupVersion {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }

    fn as_str(&self) -> &str {
        &self.0
    }
}

/// An in-memory-only store that will forget all the E2EE keys once it's
/// dropped.
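///
/// # Examples
///
/// A minimal usage sketch (assuming this type is re-exported as
/// `matrix_sdk_crypto::store::MemoryStore`; all data is kept on the heap and
/// is lost on drop):
///
/// ```no_run
/// use matrix_sdk_crypto::store::MemoryStore;
///
/// let _store = MemoryStore::new();
/// ```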
#[derive(Default, Debug)]
pub struct MemoryStore {
    static_account: Arc<StdRwLock<Option<StaticAccountData>>>,

    account: StdRwLock<Option<String>>,
    // Map of sender_key to map of session_id to serialized pickle
    sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
    inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,

    /// Map of room id -> session id -> the version of the latest backup in
    /// which this session is stored. Equivalent to `backed_up_to` in
    /// [`IndexedDbCryptoStore`].
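    ///
    /// For illustration, an entry here might look like
    /// `!room:example.org -> sessionid1 -> BackupVersion("bkp1")`.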
    inbound_group_sessions_backed_up_to:
        StdRwLock<HashMap<OwnedRoomId, HashMap<SessionId, BackupVersion>>>,

    outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
    private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
    tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
    olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
    devices: DeviceStore,
    identities: StdRwLock<HashMap<OwnedUserId, String>>,
    outgoing_key_requests: StdRwLock<HashMap<OwnedTransactionId, GossipRequest>>,
    key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
    direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEntry>>>,
    custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
    leases: StdRwLock<HashMap<String, (String, Instant)>>,
    secret_inbox: StdRwLock<HashMap<String, Vec<GossippedSecret>>>,
    backup_keys: RwLock<BackupKeys>,
    dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
    next_batch_token: RwLock<Option<String>>,
    room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
    room_key_bundles:
        StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, StoredRoomKeyBundleData>>>,

    save_changes_lock: Arc<Mutex<()>>,
}

impl MemoryStore {
    /// Create a new empty `MemoryStore`.
    pub fn new() -> Self {
        Self::default()
    }

    fn get_static_account(&self) -> Option<StaticAccountData> {
        self.static_account.read().clone()
    }

    pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
        for device in devices {
            let _ = self.devices.add(device);
        }
    }

    fn delete_devices(&self, devices: Vec<DeviceData>) {
        for device in devices {
            let _ = self.devices.remove(device.user_id(), device.device_id());
        }
    }

    fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
        let mut session_store = self.sessions.write();

        for (session_id, pickle) in sessions {
            let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();

            // insert or replace if exists
            entry.insert(
                session_id,
                serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
            );
        }
    }

    fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
        self.outbound_group_sessions
            .write()
            .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
    }

    fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
        *self.private_identity.write() = private_identity;
    }

    /// Return all the [`InboundGroupSession`]s we have, paired with the
    /// `backed_up_to` value for each one (or `None` where it is missing).
    async fn get_inbound_group_sessions_and_backed_up_to(
        &self,
    ) -> Result<Vec<(InboundGroupSession, Option<BackupVersion>)>> {
        let lookup = |s: &InboundGroupSession| {
            self.inbound_group_sessions_backed_up_to
                .read()
                .get(&s.room_id)?
                .get(s.session_id())
                .cloned()
        };

        Ok(self
            .get_inbound_group_sessions()
            .await?
            .into_iter()
            .map(|s| {
                let v = lookup(&s);
                (s, v)
            })
            .collect())
    }
}

type Result<T> = std::result::Result<T, Infallible>;

#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl CryptoStore for MemoryStore {
    type Error = Infallible;

    async fn load_account(&self) -> Result<Option<Account>> {
        let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
            serde_json::from_str(acc)
                .expect("Deserialization failed: invalid pickled account JSON format")
        });

        if let Some(pickle) = pickled_account {
            let account =
                Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");

            *self.static_account.write() = Some(account.static_data().clone());

            Ok(Some(account))
        } else {
            Ok(None)
        }
    }

    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
        Ok(self.private_identity.read().clone())
    }

    async fn next_batch_token(&self) -> Result<Option<String>> {
        Ok(self.next_batch_token.read().await.clone())
    }

    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
        let _guard = self.save_changes_lock.lock().await;

        let pickled_account = if let Some(account) = changes.account {
            *self.static_account.write() = Some(account.static_data().clone());
            Some(account.pickle())
        } else {
            None
        };

        *self.account.write() = pickled_account.map(|pickle| {
            serde_json::to_string(&pickle)
                .expect("Serialization failed: invalid pickled account JSON format")
        });

        Ok(())
    }

    async fn save_changes(&self, changes: Changes) -> Result<()> {
        let _guard = self.save_changes_lock.lock().await;

        let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
        for session in changes.sessions {
            let session_id = session.session_id().to_owned();
            let pickle = session.pickle().await;
            pickled_session.push((session_id.clone(), pickle));
        }
        self.save_sessions(pickled_session);

        self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
        self.save_outbound_group_sessions(changes.outbound_group_sessions);
        self.save_private_identity(changes.private_identity);

        self.save_devices(changes.devices.new);
        self.save_devices(changes.devices.changed);
        self.delete_devices(changes.devices.deleted);

        {
            let mut identities = self.identities.write();
            for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
                identities.insert(
                    identity.user_id().to_owned(),
                    serde_json::to_string(&identity)
                        .expect("UserIdentityData should always serialize to json"),
                );
            }
        }

        {
            let mut olm_hashes = self.olm_hashes.write();
            for hash in changes.message_hashes {
                olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
            }
        }

        {
            let mut outgoing_key_requests = self.outgoing_key_requests.write();
            let mut key_requests_by_info = self.key_requests_by_info.write();

            for key_request in changes.key_requests {
                let id = key_request.request_id.clone();
                let info_string = encode_key_info(&key_request.info);

                outgoing_key_requests.insert(id.clone(), key_request);
                key_requests_by_info.insert(info_string, id);
            }
        }

        if let Some(key) = changes.backup_decryption_key {
            self.backup_keys.write().await.decryption_key = Some(key);
        }

        if let Some(version) = changes.backup_version {
            self.backup_keys.write().await.backup_version = Some(version);
        }

        if let Some(pickle_key) = changes.dehydrated_device_pickle_key {
            let mut lock = self.dehydrated_device_pickle_key.write().await;
            *lock = Some(pickle_key);
        }

        {
            let mut secret_inbox = self.secret_inbox.write();
            for secret in changes.secrets {
                secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret);
            }
        }

        {
            let mut direct_withheld_info = self.direct_withheld_info.write();
            for (room_id, data) in changes.withheld_session_info {
                for (session_id, event) in data {
                    direct_withheld_info
                        .entry(room_id.to_owned())
                        .or_default()
                        .insert(session_id, event);
                }
            }
        }

        if let Some(next_batch_token) = changes.next_batch_token {
            *self.next_batch_token.write().await = Some(next_batch_token);
        }

        if !changes.room_settings.is_empty() {
            let mut settings = self.room_settings.write();
            settings.extend(changes.room_settings);
        }

        if !changes.received_room_key_bundles.is_empty() {
            let mut room_key_bundles = self.room_key_bundles.write();
            for bundle in changes.received_room_key_bundles {
                room_key_bundles
                    .entry(bundle.bundle_data.room_id.clone())
                    .or_default()
                    .insert(bundle.sender_user.clone(), bundle);
            }
        }

        Ok(())
    }

    async fn save_inbound_group_sessions(
        &self,
        sessions: Vec<InboundGroupSession>,
        backed_up_to_version: Option<&str>,
    ) -> Result<()> {
        for session in sessions {
            let room_id = session.room_id();
            let session_id = session.session_id();

            // Sanity-check that the data in the sessions corresponds to
            // `backed_up_to_version`.
            let backed_up = session.backed_up();
            if backed_up != backed_up_to_version.is_some() {
                warn!(
                    backed_up,
                    backed_up_to_version,
                    "Session backed-up flag does not correspond to backup version setting",
                );
            }

            if let Some(backup_version) = backed_up_to_version {
                self.inbound_group_sessions_backed_up_to
                    .write()
                    .entry(room_id.to_owned())
                    .or_default()
                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
            }

            let pickle = session.pickle().await;
            self.inbound_group_sessions
                .write()
                .entry(session.room_id().to_owned())
                .or_default()
                .insert(
                    session.session_id().to_owned(),
                    serde_json::to_string(&pickle)
                        .expect("Pickle data should serialize to json"),
                );
        }
        Ok(())
    }

    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
        let device_keys = self.get_own_device().await?.as_device_keys().clone();

        if let Some(pickles) = self.sessions.read().get(sender_key) {
            let mut sessions: Vec<Session> = Vec::new();
            for serialized_pickle in pickles.values() {
                let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
                    .expect("Pickle deserialization should work");
                let session = Session::from_pickle(device_keys.clone(), pickle)
                    .expect("Expect from pickle to always work");
                sessions.push(session);
            }
            Ok(Some(sessions))
        } else {
            Ok(None)
        }
    }

    async fn get_inbound_group_session(
        &self,
        room_id: &RoomId,
        session_id: &str,
    ) -> Result<Option<InboundGroupSession>> {
        let pickle: Option<PickledInboundGroupSession> = self
            .inbound_group_sessions
            .read()
            .get(room_id)
            .and_then(|m| m.get(session_id))
            .map(|ser| {
                serde_json::from_str(ser).expect("Pickle deserialization should work")
            });

        Ok(pickle.map(|p| {
            InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
        }))
    }

    async fn get_withheld_info(
        &self,
        room_id: &RoomId,
        session_id: &str,
    ) -> Result<Option<RoomKeyWithheldEntry>> {
        Ok(self
            .direct_withheld_info
            .read()
            .get(room_id)
            .and_then(|e| e.get(session_id).cloned()))
    }

    async fn get_withheld_sessions_by_room_id(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<RoomKeyWithheldEntry>> {
        Ok(self
            .direct_withheld_info
            .read()
            .get(room_id)
            .map(|e| e.values().cloned().collect())
            .unwrap_or_default())
    }

    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
        let inbounds = self
            .inbound_group_sessions
            .read()
            .values()
            .flat_map(HashMap::values)
            .map(|ser| {
                let pickle: PickledInboundGroupSession =
                    serde_json::from_str(ser).expect("Pickle deserialization should work");
                InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
            })
            .collect();
        Ok(inbounds)
    }

    async fn inbound_group_session_counts(
        &self,
        backup_version: Option<&str>,
    ) -> Result<RoomKeyCounts> {
        let backed_up = if let Some(backup_version) = backup_version {
            self.get_inbound_group_sessions_and_backed_up_to()
                .await?
                .into_iter()
                // Count the sessions backed up in the required backup
                .filter(|(_, o)| o.as_ref().is_some_and(|o| o.as_str() == backup_version))
                .count()
        } else {
            // No backup version was provided, so no sessions can be counted as
            // backed up in it; the answer is simply zero.
            0
        };

        let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
        Ok(RoomKeyCounts { total, backed_up })
    }

    async fn get_inbound_group_sessions_by_room_id(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<InboundGroupSession>> {
        let inbounds = match self.inbound_group_sessions.read().get(room_id) {
            None => Vec::new(),
            Some(v) => v
                .values()
                .map(|ser| {
                    let pickle: PickledInboundGroupSession =
                        serde_json::from_str(ser).expect("Pickle deserialization should work");
                    InboundGroupSession::from_pickle(pickle)
                        .expect("Expect from pickle to always work")
                })
                .collect(),
        };
        Ok(inbounds)
    }

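    // A hypothetical caller pages through the batch API below by passing the
    // last session ID of each batch as `after_session_id` for the next call,
    // stopping once a batch comes back with fewer than `limit` items.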
    async fn get_inbound_group_sessions_for_device_batch(
        &self,
        sender_key: Curve25519PublicKey,
        sender_data_type: SenderDataType,
        after_session_id: Option<String>,
        limit: usize,
    ) -> Result<Vec<InboundGroupSession>> {
        // First, find all InboundGroupSessions, filtering for those that match the
        // device and sender_data type.
        let mut sessions: Vec<_> = self
            .get_inbound_group_sessions()
            .await?
            .into_iter()
            .filter(|session: &InboundGroupSession| {
                session.creator_info.curve25519_key == sender_key
                    && session.sender_data.to_type() == sender_data_type
            })
            .collect();

        // Then, sort the sessions in order of ascending session ID...
        sessions.sort_by_key(|s| s.session_id().to_owned());

        // Figure out where in the array to start returning results from
        let start_index = {
            match after_session_id {
                None => 0,
                Some(id) => {
                    // We're looking for the first session with a session ID strictly after `id`; if
                    // there are none, the end of the array.
                    sessions
                        .iter()
                        .position(|session| session.session_id() > id.as_str())
                        .unwrap_or(sessions.len())
                }
            }
        };

        // Return up to `limit` items from the array, starting from `start_index`
        Ok(sessions.drain(start_index..).take(limit).collect())
    }

    async fn inbound_group_sessions_for_backup(
        &self,
        backup_version: &str,
        limit: usize,
    ) -> Result<Vec<InboundGroupSession>> {
        Ok(self
            .get_inbound_group_sessions_and_backed_up_to()
            .await?
            .into_iter()
            .filter_map(|(session, backed_up_to)| {
                if let Some(ref existing_version) = backed_up_to {
                    if existing_version.as_str() == backup_version {
                        // This session is already backed up in the required backup
                        return None;
                    }
                }
                // It's not backed up, or it's backed up in a different backup
                Some(session)
            })
            .take(limit)
            .collect())
    }

    async fn mark_inbound_group_sessions_as_backed_up(
        &self,
        backup_version: &str,
        room_and_session_ids: &[(&RoomId, &str)],
    ) -> Result<()> {
        for &(room_id, session_id) in room_and_session_ids {
            let session = self.get_inbound_group_session(room_id, session_id).await?;

            if let Some(session) = session {
                session.mark_as_backed_up();

                self.inbound_group_sessions_backed_up_to
                    .write()
                    .entry(room_id.to_owned())
                    .or_default()
                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));

                // Save it back
                let updated_pickle = session.pickle().await;

                self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
                    session_id.to_owned(),
                    serde_json::to_string(&updated_pickle)
                        .expect("Pickle serialization should work"),
                );
            }
        }

        Ok(())
    }

    async fn reset_backup_state(&self) -> Result<()> {
        // Nothing to do here: `mark_inbound_group_sessions_as_backed_up`
        // records which backup version each session was backed up to, and
        // `inbound_group_sessions_for_backup` receives the required version
        // and compares it against the stored one, so there is no state that
        // needs resetting.

        Ok(())
    }

    async fn load_backup_keys(&self) -> Result<BackupKeys> {
        Ok(self.backup_keys.read().await.to_owned())
    }

    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
        Ok(self.dehydrated_device_pickle_key.read().await.to_owned())
    }

    async fn delete_dehydrated_device_pickle_key(&self) -> Result<()> {
        let mut lock = self.dehydrated_device_pickle_key.write().await;
        *lock = None;
        Ok(())
    }

    async fn get_outbound_group_session(
        &self,
        room_id: &RoomId,
    ) -> Result<Option<OutboundGroupSession>> {
        Ok(self.outbound_group_sessions.read().get(room_id).cloned())
    }

    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
        Ok(self.tracked_users.read().values().cloned().collect())
    }

    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
        self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
            let user_id: OwnedUserId = user_id.to_owned().into();
            (user_id.clone(), TrackedUser { user_id, dirty: *dirty })
        }));
        Ok(())
    }

    async fn get_device(
        &self,
        user_id: &UserId,
        device_id: &DeviceId,
    ) -> Result<Option<DeviceData>> {
        Ok(self.devices.get(user_id, device_id))
    }

    async fn get_user_devices(
        &self,
        user_id: &UserId,
    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
        Ok(self.devices.user_devices(user_id))
    }

    async fn get_own_device(&self) -> Result<DeviceData> {
        let account =
            self.get_static_account().expect("Expect account to exist when getting own device");

        Ok(self
            .devices
            .get(&account.user_id, &account.device_id)
            .expect("Invalid state: Should always have an own device"))
    }

    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
        let serialized = self.identities.read().get(user_id).cloned();
        match serialized {
            None => Ok(None),
            Some(serialized) => {
                let id: UserIdentityData = serde_json::from_str(serialized.as_str())
                    .expect("Only valid serialized identities are saved");
                Ok(Some(id))
            }
        }
    }

    async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
        Ok(self
            .olm_hashes
            .write()
            .entry(message_hash.sender_key.to_owned())
            .or_default()
            .contains(&message_hash.hash))
    }

    async fn get_outgoing_secret_requests(
        &self,
        request_id: &TransactionId,
    ) -> Result<Option<GossipRequest>> {
        Ok(self.outgoing_key_requests.read().get(request_id).cloned())
    }

    async fn get_secret_request_by_info(
        &self,
        key_info: &SecretInfo,
    ) -> Result<Option<GossipRequest>> {
        let key_info_string = encode_key_info(key_info);

        Ok(self
            .key_requests_by_info
            .read()
            .get(&key_info_string)
            .and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
    }

    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
        Ok(self
            .outgoing_key_requests
            .read()
            .values()
            .filter(|req| !req.sent_out)
            .cloned()
            .collect())
    }

    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
        let req = self.outgoing_key_requests.write().remove(request_id);
        if let Some(i) = req {
            let key_info_string = encode_key_info(&i.info);
            self.key_requests_by_info.write().remove(&key_info_string);
        }

        Ok(())
    }

    async fn get_secrets_from_inbox(
        &self,
        secret_name: &SecretName,
    ) -> Result<Vec<GossippedSecret>> {
        Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
    }

    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
        self.secret_inbox.write().remove(secret_name.as_str());

        Ok(())
    }

    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
        Ok(self.room_settings.read().get(room_id).cloned())
    }

    async fn get_received_room_key_bundle_data(
        &self,
        room_id: &RoomId,
        user_id: &UserId,
    ) -> Result<Option<StoredRoomKeyBundleData>> {
        let guard = self.room_key_bundles.read();

        let result = guard.get(room_id).and_then(|bundles| bundles.get(user_id).cloned());

        Ok(result)
    }

    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
        Ok(self.custom_values.read().get(key).cloned())
    }

    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
        self.custom_values.write().insert(key.to_owned(), value);
        Ok(())
    }

    async fn remove_custom_value(&self, key: &str) -> Result<()> {
        self.custom_values.write().remove(key);
        Ok(())
    }

    async fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<bool> {
        Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use matrix_sdk_test::async_test;
    use ruma::{room_id, user_id, RoomId};
    use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};

    use super::SessionId;
    use crate::{
        identities::device::testing::get_device,
        olm::{
            tests::get_account_and_session_test_helper, Account, InboundGroupSession,
            OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
        },
        store::{
            memorystore::MemoryStore,
            types::{Changes, DeviceChanges, PendingChanges},
            CryptoStore,
        },
        DeviceData,
    };

    #[async_test]
    async fn test_session_store() {
        let (account, session) = get_account_and_session_test_helper();
        let own_device = DeviceData::from_account(&account);
        let store = MemoryStore::new();

        assert!(store.load_account().await.unwrap().is_none());

        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![own_device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();

        store
            .save_changes(Changes { sessions: vec![session.clone()], ..Default::default() })
            .await
            .unwrap();

        let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_inbound_group_session_store() {
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";

        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
        let inbound = InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap();

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(vec![inbound.clone()], None).await.unwrap();

        let loaded_session =
            store.get_inbound_group_session(room_id, outbound.session_id()).await.unwrap().unwrap();
        assert_eq!(inbound, loaded_session);
    }

    #[async_test]
    async fn test_backing_up_marks_sessions_as_backed_up() {
        // Given there are 2 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(2, room_id).await;

        // When I mark them as backed up
        mark_backed_up(&store, room_id, "bkp1", &sessions).await;

        // Then their backed_up_to field is set
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_a_second_set_of_sessions_updates_their_backup_order() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark 0 and 1 as backed up in bkp1
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And 1 and 2 as backed up in bkp2
        mark_backed_up(&store, room_id, "bkp2", &sessions[1..]).await;

        // Then 0 is backed up in bkp1, and 1 and 2 are backed up in bkp2
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp2");
        assert_eq!(but[sessions[2].session_id()], "bkp2");
    }

    #[async_test]
    async fn test_backing_up_again_to_the_same_version_has_no_effect() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark the first two as backed up in the first backup
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And the last 2 as backed up in the same backup version
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..]).await;

        // Then they all get the same backed_up_to value
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
        assert_eq!(but[sessions[2].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_can_increase_backed_up_to() {
        // Given we have backed up some sessions to 2 backup versions, an older and a
        // newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "newer_bkp", &sessions[1..2]).await;

        // When I ask to back up the un-backed-up ones to the older backup
        mark_backed_up(&store, room_id, "older_bkp", &sessions[2..]).await;

        // Then each session lists the backup it was most recently included in
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "newer_bkp");
        assert_eq!(but[sessions[2].session_id()], "older_bkp");
        assert_eq!(but[sessions[3].session_id()], "older_bkp");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_overwrites_a_newer_one() {
        // Given we have backed up to 2 backup versions, an older and a newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions).await;
        // Sanity: they are backed up in older_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "older_bkp");
        mark_backed_up(&store, room_id, "newer_bkp", &sessions).await;
        // Sanity: they are now backed up in newer_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "newer_bkp");

        // When I ask to back up some to the older version
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;

        // Then the older backup overwrites: we don't consider the relative age of
        // the versions here at all
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "older_bkp");
        assert_eq!(but[sessions[2].session_id()], "newer_bkp");
        assert_eq!(but[sessions[3].session_id()], "newer_bkp");
    }


    #[async_test]
    async fn test_not_backed_up_sessions_are_eligible_for_backup() {
        // Given there are 4 sessions, 2 of which are already backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told the last 2 only
        assert_eq!(to_backup, &[sessions[2].clone(), sessions[3].clone()]);
    }

    #[async_test]
    async fn test_all_sessions_are_eligible_for_backup_if_version_is_unknown() {
        // Given there are 4 sessions, 2 of which are already backed up in bkp1
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up in an unknown version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("unknown_bkp", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up all of them
        assert_eq!(
            to_backup,
            &[sessions[0].clone(), sessions[1].clone(), sessions[2].clone(), sessions[3].clone()]
        );
    }

    #[async_test]
    async fn test_sessions_backed_up_to_a_later_version_are_eligible_for_backup() {
        // Given there are 4 sessions, some backed up to three different versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp0", &sessions[..1]).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[2..3]).await;

        // When I ask which to back up in the middle version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up everything not in the version I asked about
        assert_eq!(
            to_backup,
            &[
                sessions[0].clone(), // Backed up in bkp0
                // sessions[1] is backed up in bkp1 already, which we asked about
                sessions[2].clone(), // Backed up in bkp2
                sessions[3].clone(), // Not backed up
            ]
        );
    }

    #[async_test]
    async fn test_outbound_group_session_store() {
        // Given an outbound session
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_outbound_group_sessions(vec![outbound.clone()]);

        // Then we can get it out again
        let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
        assert_eq!(
            serde_json::to_string(&outbound.pickle().await).unwrap(),
            serde_json::to_string(&loaded_session.pickle().await).unwrap()
        );
    }

    #[async_test]
    async fn test_tracked_users_are_stored_once_per_user_id() {
        // Given a store containing 2 tracked users, both dirty
        let user1 = user_id!("@user1:s");
        let user2 = user_id!("@user2:s");
        let user3 = user_id!("@user3:s");
        let store = MemoryStore::new();
        store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();

        // When we mark one as clean and add another
        store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();

        // Then we can get them out again and their dirty flags are correct
        let loaded_tracked_users =
            store.load_tracked_users().await.expect("failed to load tracked users");

        let tracked_contains = |user_id, dirty| {
            loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
        };

        assert!(tracked_contains(user1, true));
        assert!(tracked_contains(user2, false));
        assert!(tracked_contains(user3, false));
        assert_eq!(loaded_tracked_users.len(), 3);
    }

    #[async_test]
    async fn test_private_identity_store() {
        // Given a private identity
        let private_identity = PrivateCrossSigningIdentity::empty(user_id!("@u:s"));

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_private_identity(Some(private_identity.clone()));

        // Then we can get it out again
        let loaded_identity =
            store.load_identity().await.expect("failed to load private identity").unwrap();

        assert_eq!(loaded_identity.user_id(), user_id!("@u:s"));
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = MemoryStore::new();

        store.save_devices(vec![device.clone()]);

        let loaded_device =
            store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.get_user_devices(device.user_id()).await.unwrap();

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.delete_devices(vec![device.clone()]);
        assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
    }

    #[async_test]
    async fn test_message_hash() {
        let store = MemoryStore::new();

        let hash =
            OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };

        let mut changes = Changes::default();
        changes.message_hashes.push(hash.clone());

        assert!(!store.is_message_known(&hash).await.unwrap());
        store.save_changes(changes).await.unwrap();
        assert!(store.is_message_known(&hash).await.unwrap());
    }

    #[async_test]
    async fn test_key_counts_of_empty_store_are_zero() {
        // Given an empty store
        let store = MemoryStore::new();

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("")).await.unwrap();

        // Then the answer is zero
        assert_eq!(key_counts.total, 0);
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_sessions_reports_the_number_of_sessions() {
        // Given a store with sessions
        let room_id = room_id!("!test:localhost");
        let (store, _) = store_with_sessions(4, room_id).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 4);
        // And none are backed up
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_reports_the_number_backed_up_in_this_backup() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(5, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 5);
        // And the backed_up count matches how many were backed up
        assert_eq!(key_counts.backed_up, 2);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_for_null_backup_reports_zero() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys, providing None as the backup version
        let key_counts = store.inbound_group_session_counts(None).await.unwrap();

        // Then we ignore everything and just say zero
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_only_reports_sessions_in_the_version_specified() {
        // Given a store with sessions, backed up in several versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[3..]).await;

        // When we count keys for bkp2
        let key_counts = store.inbound_group_session_counts(Some("bkp2")).await.unwrap();

        // Then the backed_up count reflects how many were backed up in bkp2 only
        assert_eq!(key_counts.backed_up, 1);
    }

    /// Mark the supplied sessions as backed up in the supplied backup version
    async fn mark_backed_up(
        store: &MemoryStore,
        room_id: &RoomId,
        backup_version: &str,
        sessions: &[InboundGroupSession],
    ) {
        let rooms_and_ids: Vec<_> = sessions.iter().map(|s| (room_id, s.session_id())).collect();

        store
            .mark_inbound_group_sessions_as_backed_up(backup_version, &rooms_and_ids)
            .await
            .expect("Failed to mark sessions as backed up");
    }

    /// Create a `MemoryStore` containing the supplied number of sessions.
    ///
    /// Sessions are returned in alphabetical order of session id.
    async fn store_with_sessions(
        num_sessions: usize,
        room_id: &RoomId,
    ) -> (MemoryStore, Vec<InboundGroupSession>) {
        let (account, _) = get_account_and_session_test_helper();

        let mut sessions = Vec::with_capacity(num_sessions);
        for _ in 0..num_sessions {
            sessions.push(new_session(&account, room_id).await);
        }
        sessions.sort_by_key(|s| s.session_id().to_owned());

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(sessions.clone(), None).await.unwrap();

        (store, sessions)
    }

    /// Create a new `InboundGroupSession`.
    async fn new_session(account: &Account, room_id: &RoomId) -> InboundGroupSession {
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap()
    }

    /// Find the session_id and backed_up_to value for each of the sessions in
    /// the store.
    async fn backed_up_tos(store: &MemoryStore) -> HashMap<SessionId, String> {
        store
            .get_inbound_group_sessions_and_backed_up_to()
            .await
            .expect("Unable to get inbound group sessions and backed_up_to values")
            .iter()
            .map(|(s, o)| {
                (
                    s.session_id().to_owned(),
                    o.as_ref().map(|v| v.as_str().to_owned()).unwrap_or("".to_owned()),
                )
            })
            .collect()
    }
}

#[cfg(test)]
mod integration_tests {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
    };

    use async_trait::async_trait;
    use ruma::{
        events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId,
    };
    use vodozemac::Curve25519PublicKey;

    use super::MemoryStore;
    use crate::{
        cryptostore_integration_tests, cryptostore_integration_tests_time,
        olm::{
            InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
            SenderDataType, StaticAccountData,
        },
        store::{
            types::{
                BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
                RoomKeyWithheldEntry, RoomSettings, StoredRoomKeyBundleData, TrackedUser,
            },
            CryptoStore,
        },
        Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, Session, UserIdentityData,
    };

    /// Wraps a [`MemoryStore`] and keeps it alive for the duration of a test:
    /// clones share one underlying store via `Arc`, and the `STORES` map in
    /// [`get_store`] holds on to one clone per test.
    #[derive(Clone, Debug)]
    struct PersistentMemoryStore(Arc<MemoryStore>);

    impl PersistentMemoryStore {
        fn new() -> Self {
            Self(Arc::new(MemoryStore::new()))
        }

        fn get_static_account(&self) -> Option<StaticAccountData> {
            self.0.get_static_account()
        }
    }

    /// Return a clone of the store for the test with the supplied name. Note:
    /// dropping this store won't destroy its data, since
    /// [PersistentMemoryStore] is a reference-counted smart pointer
    /// to an underlying [MemoryStore].
    async fn get_store(
        name: &str,
        _passphrase: Option<&str>,
        clear_data: bool,
    ) -> PersistentMemoryStore {
        // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
        // the store, we keep its data alive. This simulates the behaviour of
        // the other stores, which keep their data in a real DB, allowing us to
        // test MemoryStore using the same code.
        static STORES: OnceLock<Mutex<HashMap<String, PersistentMemoryStore>>> = OnceLock::new();
        let stores = STORES.get_or_init(|| Mutex::new(HashMap::new()));

        let mut stores = stores.lock().unwrap();

        if clear_data {
            // Create a new PersistentMemoryStore
            let new_store = PersistentMemoryStore::new();
            stores.insert(name.to_owned(), new_store.clone());
            new_store
        } else {
            stores.entry(name.to_owned()).or_insert_with(PersistentMemoryStore::new).clone()
        }
    }

    /// Forwards all methods to the underlying [MemoryStore].
    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl CryptoStore for PersistentMemoryStore {
        type Error = <MemoryStore as CryptoStore>::Error;

        async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
            self.0.load_account().await
        }

        async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
            self.0.load_identity().await
        }

        async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error> {
            self.0.save_changes(changes).await
        }

        async fn save_pending_changes(&self, changes: PendingChanges) -> Result<(), Self::Error> {
            self.0.save_pending_changes(changes).await
        }

        async fn save_inbound_group_sessions(
            &self,
            sessions: Vec<InboundGroupSession>,
            backed_up_to_version: Option<&str>,
        ) -> Result<(), Self::Error> {
            self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await
        }

        async fn get_sessions(
            &self,
            sender_key: &str,
        ) -> Result<Option<Vec<Session>>, Self::Error> {
            self.0.get_sessions(sender_key).await
        }

        async fn get_inbound_group_session(
            &self,
            room_id: &RoomId,
            session_id: &str,
        ) -> Result<Option<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_session(room_id, session_id).await
        }

        async fn get_withheld_info(
            &self,
            room_id: &RoomId,
            session_id: &str,
        ) -> Result<Option<RoomKeyWithheldEntry>, Self::Error> {
            self.0.get_withheld_info(room_id, session_id).await
        }

        async fn get_withheld_sessions_by_room_id(
            &self,
            room_id: &RoomId,
        ) -> Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
            self.0.get_withheld_sessions_by_room_id(room_id).await
        }

        async fn get_inbound_group_sessions(
            &self,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_sessions().await
        }

        async fn inbound_group_session_counts(
            &self,
            backup_version: Option<&str>,
        ) -> Result<RoomKeyCounts, Self::Error> {
            self.0.inbound_group_session_counts(backup_version).await
        }

        async fn get_inbound_group_sessions_by_room_id(
            &self,
            room_id: &RoomId,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_sessions_by_room_id(room_id).await
        }

        async fn get_inbound_group_sessions_for_device_batch(
            &self,
            sender_key: Curve25519PublicKey,
            sender_data_type: SenderDataType,
            after_session_id: Option<String>,
            limit: usize,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0
                .get_inbound_group_sessions_for_device_batch(
                    sender_key,
                    sender_data_type,
                    after_session_id,
                    limit,
                )
                .await
        }

        async fn inbound_group_sessions_for_backup(
            &self,
            backup_version: &str,
            limit: usize,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.inbound_group_sessions_for_backup(backup_version, limit).await
        }

        async fn mark_inbound_group_sessions_as_backed_up(
            &self,
            backup_version: &str,
            room_and_session_ids: &[(&RoomId, &str)],
        ) -> Result<(), Self::Error> {
            self.0
                .mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
                .await
        }

        async fn reset_backup_state(&self) -> Result<(), Self::Error> {
            self.0.reset_backup_state().await
        }

        async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
            self.0.load_backup_keys().await
        }

        async fn load_dehydrated_device_pickle_key(
            &self,
        ) -> Result<Option<DehydratedDeviceKey>, Self::Error> {
            self.0.load_dehydrated_device_pickle_key().await
        }

        async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
            self.0.delete_dehydrated_device_pickle_key().await
        }

        async fn get_outbound_group_session(
            &self,
            room_id: &RoomId,
        ) -> Result<Option<OutboundGroupSession>, Self::Error> {
            self.0.get_outbound_group_session(room_id).await
        }

        async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
            self.0.load_tracked_users().await
        }

        async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error> {
            self.0.save_tracked_users(users).await
        }

        async fn get_device(
            &self,
            user_id: &UserId,
            device_id: &DeviceId,
        ) -> Result<Option<DeviceData>, Self::Error> {
            self.0.get_device(user_id, device_id).await
        }

        async fn get_user_devices(
            &self,
            user_id: &UserId,
        ) -> Result<HashMap<OwnedDeviceId, DeviceData>, Self::Error> {
            self.0.get_user_devices(user_id).await
        }

        async fn get_own_device(&self) -> Result<DeviceData, Self::Error> {
            self.0.get_own_device().await
        }

        async fn get_user_identity(
            &self,
            user_id: &UserId,
        ) -> Result<Option<UserIdentityData>, Self::Error> {
            self.0.get_user_identity(user_id).await
        }

        async fn is_message_known(
            &self,
            message_hash: &OlmMessageHash,
        ) -> Result<bool, Self::Error> {
            self.0.is_message_known(message_hash).await
        }

        async fn get_outgoing_secret_requests(
            &self,
            request_id: &TransactionId,
        ) -> Result<Option<GossipRequest>, Self::Error> {
            self.0.get_outgoing_secret_requests(request_id).await
        }

        async fn get_secret_request_by_info(
            &self,
            secret_info: &SecretInfo,
        ) -> Result<Option<GossipRequest>, Self::Error> {
            self.0.get_secret_request_by_info(secret_info).await
        }

        async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
            self.0.get_unsent_secret_requests().await
        }

        async fn delete_outgoing_secret_requests(
            &self,
            request_id: &TransactionId,
        ) -> Result<(), Self::Error> {
            self.0.delete_outgoing_secret_requests(request_id).await
        }

        async fn get_secrets_from_inbox(
            &self,
            secret_name: &SecretName,
        ) -> Result<Vec<GossippedSecret>, Self::Error> {
            self.0.get_secrets_from_inbox(secret_name).await
        }

        async fn delete_secrets_from_inbox(
            &self,
            secret_name: &SecretName,
        ) -> Result<(), Self::Error> {
            self.0.delete_secrets_from_inbox(secret_name).await
        }

        async fn get_room_settings(
            &self,
            room_id: &RoomId,
        ) -> Result<Option<RoomSettings>, Self::Error> {
            self.0.get_room_settings(room_id).await
        }

        async fn get_received_room_key_bundle_data(
            &self,
            room_id: &RoomId,
            user_id: &UserId,
        ) -> Result<Option<StoredRoomKeyBundleData>, Self::Error> {
            self.0.get_received_room_key_bundle_data(room_id, user_id).await
        }

        async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
            self.0.get_custom_value(key).await
        }

        async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<(), Self::Error> {
            self.0.set_custom_value(key, value).await
        }

        async fn remove_custom_value(&self, key: &str) -> Result<(), Self::Error> {
            self.0.remove_custom_value(key).await
        }

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<bool, Self::Error> {
            self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
            self.0.next_batch_token().await
        }
    }

    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}