matrix_sdk_crypto/store/
memorystore.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, HashMap, HashSet},
17    convert::Infallible,
18    sync::Arc,
19};
20
21use async_trait::async_trait;
22use matrix_sdk_common::{
23    cross_process_lock::{
24        CrossProcessLockGeneration,
25        memory_store_helper::{Lease, try_take_leased_lock},
26    },
27    locks::RwLock as StdRwLock,
28};
29use ruma::{
30    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
31    UserId, events::secret::request::SecretName,
32};
33use tokio::sync::{Mutex, RwLock};
34use tracing::warn;
35use vodozemac::Curve25519PublicKey;
36
37use super::{
38    Account, CryptoStore, InboundGroupSession, Session,
39    caches::DeviceStore,
40    types::{
41        BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
42        StoredRoomKeyBundleData, TrackedUser,
43    },
44};
45use crate::{
46    gossiping::{GossipRequest, GossippedSecret, SecretInfo},
47    identities::{DeviceData, UserIdentityData},
48    olm::{
49        OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
50        PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
51    },
52    store::types::RoomKeyWithheldEntry,
53};
54
55fn encode_key_info(info: &SecretInfo) -> String {
56    match info {
57        SecretInfo::KeyRequest(info) => {
58            format!("{}{}{}", info.room_id(), info.algorithm(), info.session_id())
59        }
60        SecretInfo::SecretRequest(i) => i.as_ref().to_owned(),
61    }
62}
63
/// Session ID string, used as a map key for inbound group sessions.
type SessionId = String;

/// Identifies a key backup by its version string.
///
/// A thin newtype over `String`, so that backup versions are not mixed up with
/// the other string-typed identifiers kept in this store.
#[derive(Clone, Debug, PartialEq)]
struct BackupVersion(String);

impl BackupVersion {
    /// Creates a `BackupVersion` holding an owned copy of `s`.
    fn from(s: &str) -> Self {
        BackupVersion(String::from(s))
    }

    /// Borrows the underlying version string.
    fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
79
/// An in-memory only store that will forget all the E2EE keys once it's dropped.
#[derive(Default, Debug)]
pub struct MemoryStore {
    /// Static data of the account, cached whenever an account is loaded or saved.
    static_account: Arc<StdRwLock<Option<StaticAccountData>>>,

    /// JSON-serialized pickle of the account, if one has been saved.
    account: StdRwLock<Option<String>>,
    /// Map of sender_key (base64) to map of session_id to serialized pickle.
    sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
    /// Map room id -> session id -> JSON-serialized inbound group session pickle.
    inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,

    /// Map room id -> session id -> backup version.
    /// The latest backup in which this session is stored. Equivalent to
    /// `backed_up_to` in the IndexedDB crypto store (`IndexedDbCryptoStore`).
    inbound_group_sessions_backed_up_to:
        StdRwLock<HashMap<OwnedRoomId, HashMap<SessionId, BackupVersion>>>,

    /// Outbound group sessions, keyed by room id.
    outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
    /// Our private cross-signing identity, if one is stored.
    private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
    /// Tracked users, keyed by user id.
    tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
    /// Map of sender key -> set of Olm message hashes seen from that key.
    olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
    /// Cache of device data.
    devices: DeviceStore,
    /// Map user id -> JSON-serialized `UserIdentityData`.
    identities: StdRwLock<HashMap<OwnedUserId, String>>,
    /// Outgoing secret/key requests, keyed by request (transaction) id.
    outgoing_key_requests: StdRwLock<HashMap<OwnedTransactionId, GossipRequest>>,
    /// Secondary index: `encode_key_info` string -> request id in
    /// `outgoing_key_requests`.
    key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
    /// Map room id -> session id -> withheld entry.
    direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEntry>>>,
    /// Arbitrary key/value bytes set through the custom value API.
    custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
    /// Leases consulted by `try_take_leased_lock`, keyed by lock key.
    leases: StdRwLock<HashMap<String, Lease>>,
    /// Received secrets, grouped by secret name.
    secret_inbox: StdRwLock<HashMap<String, Vec<GossippedSecret>>>,
    /// Backup decryption key and backup version.
    backup_keys: RwLock<BackupKeys>,
    /// Pickle key for the dehydrated device, if one is stored.
    dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
    /// The last stored sync `next_batch` token.
    next_batch_token: RwLock<Option<String>>,
    /// Per-room settings, keyed by room id.
    room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
    /// Map room id -> sending user id -> received room key bundle data.
    room_key_bundles:
        StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, StoredRoomKeyBundleData>>>,
    /// Rooms whose key backup has been marked as fully downloaded.
    room_key_backups_fully_downloaded: StdRwLock<HashSet<OwnedRoomId>>,

    /// Serializes `save_changes` and `save_pending_changes`, which both hold it
    /// for their whole duration.
    save_changes_lock: Arc<Mutex<()>>,
}
118
119impl MemoryStore {
120    /// Create a new empty `MemoryStore`.
121    pub fn new() -> Self {
122        Self::default()
123    }
124
125    fn get_static_account(&self) -> Option<StaticAccountData> {
126        self.static_account.read().clone()
127    }
128
129    pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
130        for device in devices {
131            let _ = self.devices.add(device);
132        }
133    }
134
135    fn delete_devices(&self, devices: Vec<DeviceData>) {
136        for device in devices {
137            let _ = self.devices.remove(device.user_id(), device.device_id());
138        }
139    }
140
141    fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
142        let mut session_store = self.sessions.write();
143
144        for (session_id, pickle) in sessions {
145            let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();
146
147            // insert or replace if exists
148            entry.insert(
149                session_id,
150                serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
151            );
152        }
153    }
154
155    fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
156        self.outbound_group_sessions
157            .write()
158            .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
159    }
160
161    fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
162        *self.private_identity.write() = private_identity;
163    }
164
165    /// Return all the [`InboundGroupSession`]s we have, paired with the
166    /// `backed_up_to` value for each one (or "" where it is missing, which
167    /// should never happen).
168    async fn get_inbound_group_sessions_and_backed_up_to(
169        &self,
170    ) -> Result<Vec<(InboundGroupSession, Option<BackupVersion>)>> {
171        let lookup = |s: &InboundGroupSession| {
172            self.inbound_group_sessions_backed_up_to
173                .read()
174                .get(&s.room_id)?
175                .get(s.session_id())
176                .cloned()
177        };
178
179        Ok(self
180            .get_inbound_group_sessions()
181            .await?
182            .into_iter()
183            .map(|s| {
184                let v = lookup(&s);
185                (s, v)
186            })
187            .collect())
188    }
189}
190
/// Result type of every store operation in this file: the error type is
/// uninhabited, because the in-memory store never reports an error.
type Result<T> = std::result::Result<T, std::convert::Infallible>;
192
193#[cfg_attr(target_family = "wasm", async_trait(?Send))]
194#[cfg_attr(not(target_family = "wasm"), async_trait)]
195impl CryptoStore for MemoryStore {
196    type Error = Infallible;
197
198    async fn load_account(&self) -> Result<Option<Account>> {
199        let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
200            serde_json::from_str(acc)
201                .expect("Deserialization failed: invalid pickled account JSON format")
202        });
203
204        if let Some(pickle) = pickled_account {
205            let account =
206                Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");
207
208            *self.static_account.write() = Some(account.static_data().clone());
209
210            Ok(Some(account))
211        } else {
212            Ok(None)
213        }
214    }
215
216    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
217        Ok(self.private_identity.read().clone())
218    }
219
220    async fn next_batch_token(&self) -> Result<Option<String>> {
221        Ok(self.next_batch_token.read().await.clone())
222    }
223
224    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
225        let _guard = self.save_changes_lock.lock().await;
226
227        let pickled_account = if let Some(account) = changes.account {
228            *self.static_account.write() = Some(account.static_data().clone());
229            Some(account.pickle())
230        } else {
231            None
232        };
233
234        *self.account.write() = pickled_account.map(|pickle| {
235            serde_json::to_string(&pickle)
236                .expect("Serialization failed: invalid pickled account JSON format")
237        });
238
239        Ok(())
240    }
241
242    async fn save_changes(&self, changes: Changes) -> Result<()> {
243        let _guard = self.save_changes_lock.lock().await;
244
245        let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
246        for session in changes.sessions {
247            let session_id = session.session_id().to_owned();
248            let pickle = session.pickle().await;
249            pickled_session.push((session_id.clone(), pickle));
250        }
251        self.save_sessions(pickled_session);
252
253        self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
254        self.save_outbound_group_sessions(changes.outbound_group_sessions);
255        self.save_private_identity(changes.private_identity);
256
257        self.save_devices(changes.devices.new);
258        self.save_devices(changes.devices.changed);
259        self.delete_devices(changes.devices.deleted);
260
261        {
262            let mut identities = self.identities.write();
263            for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
264                identities.insert(
265                    identity.user_id().to_owned(),
266                    serde_json::to_string(&identity)
267                        .expect("UserIdentityData should always serialize to json"),
268                );
269            }
270        }
271
272        {
273            let mut olm_hashes = self.olm_hashes.write();
274            for hash in changes.message_hashes {
275                olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
276            }
277        }
278
279        {
280            let mut outgoing_key_requests = self.outgoing_key_requests.write();
281            let mut key_requests_by_info = self.key_requests_by_info.write();
282
283            for key_request in changes.key_requests {
284                let id = key_request.request_id.clone();
285                let info_string = encode_key_info(&key_request.info);
286
287                outgoing_key_requests.insert(id.clone(), key_request);
288                key_requests_by_info.insert(info_string, id);
289            }
290        }
291
292        if let Some(key) = changes.backup_decryption_key {
293            self.backup_keys.write().await.decryption_key = Some(key);
294        }
295
296        if let Some(version) = changes.backup_version {
297            self.backup_keys.write().await.backup_version = Some(version);
298        }
299
300        if let Some(pickle_key) = changes.dehydrated_device_pickle_key {
301            let mut lock = self.dehydrated_device_pickle_key.write().await;
302            *lock = Some(pickle_key);
303        }
304
305        {
306            let mut secret_inbox = self.secret_inbox.write();
307            for secret in changes.secrets {
308                secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret);
309            }
310        }
311
312        {
313            let mut direct_withheld_info = self.direct_withheld_info.write();
314            for (room_id, data) in changes.withheld_session_info {
315                for (session_id, event) in data {
316                    direct_withheld_info
317                        .entry(room_id.to_owned())
318                        .or_default()
319                        .insert(session_id, event);
320                }
321            }
322        }
323
324        if let Some(next_batch_token) = changes.next_batch_token {
325            *self.next_batch_token.write().await = Some(next_batch_token);
326        }
327
328        if !changes.room_settings.is_empty() {
329            let mut settings = self.room_settings.write();
330            settings.extend(changes.room_settings);
331        }
332
333        if !changes.received_room_key_bundles.is_empty() {
334            let mut room_key_bundles = self.room_key_bundles.write();
335            for bundle in changes.received_room_key_bundles {
336                room_key_bundles
337                    .entry(bundle.bundle_data.room_id.clone())
338                    .or_default()
339                    .insert(bundle.sender_user.clone(), bundle);
340            }
341        }
342
343        if !changes.room_key_backups_fully_downloaded.is_empty() {
344            let mut room_key_backups_fully_downloaded =
345                self.room_key_backups_fully_downloaded.write();
346            for room in changes.room_key_backups_fully_downloaded {
347                room_key_backups_fully_downloaded.insert(room);
348            }
349        }
350
351        Ok(())
352    }
353
354    async fn save_inbound_group_sessions(
355        &self,
356        sessions: Vec<InboundGroupSession>,
357        backed_up_to_version: Option<&str>,
358    ) -> Result<()> {
359        for session in sessions {
360            let room_id = session.room_id();
361            let session_id = session.session_id();
362
363            // Sanity-check that the data in the sessions corresponds to backed_up_version
364            let backed_up = session.backed_up();
365            if backed_up != backed_up_to_version.is_some() {
366                warn!(
367                    backed_up,
368                    backed_up_to_version,
369                    "Session backed-up flag does not correspond to backup version setting",
370                );
371            }
372
373            if let Some(backup_version) = backed_up_to_version {
374                self.inbound_group_sessions_backed_up_to
375                    .write()
376                    .entry(room_id.to_owned())
377                    .or_default()
378                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
379            }
380
381            let pickle = session.pickle().await;
382            self.inbound_group_sessions
383                .write()
384                .entry(session.room_id().to_owned())
385                .or_default()
386                .insert(
387                    session.session_id().to_owned(),
388                    serde_json::to_string(&pickle)
389                        .expect("Pickle pickle data should serialize to json"),
390                );
391        }
392        Ok(())
393    }
394
395    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
396        let device_keys = self.get_own_device().await?.as_device_keys().clone();
397
398        if let Some(pickles) = self.sessions.read().get(sender_key) {
399            let mut sessions: Vec<Session> = Vec::new();
400            for serialized_pickle in pickles.values() {
401                let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
402                    .expect("Pickle pickle deserialization should work");
403                let session = Session::from_pickle(device_keys.clone(), pickle)
404                    .expect("Expect from pickle to always work");
405                sessions.push(session);
406            }
407            Ok(Some(sessions))
408        } else {
409            Ok(None)
410        }
411    }
412
413    async fn get_inbound_group_session(
414        &self,
415        room_id: &RoomId,
416        session_id: &str,
417    ) -> Result<Option<InboundGroupSession>> {
418        let pickle: Option<PickledInboundGroupSession> = self
419            .inbound_group_sessions
420            .read()
421            .get(room_id)
422            .and_then(|m| m.get(session_id))
423            .and_then(|ser| {
424                serde_json::from_str(ser).expect("Pickle pickle deserialization should work")
425            });
426
427        Ok(pickle.map(|p| {
428            InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
429        }))
430    }
431
432    async fn get_withheld_info(
433        &self,
434        room_id: &RoomId,
435        session_id: &str,
436    ) -> Result<Option<RoomKeyWithheldEntry>> {
437        Ok(self
438            .direct_withheld_info
439            .read()
440            .get(room_id)
441            .and_then(|e| Some(e.get(session_id)?.to_owned())))
442    }
443
444    async fn get_withheld_sessions_by_room_id(
445        &self,
446        room_id: &RoomId,
447    ) -> crate::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
448        Ok(self
449            .direct_withheld_info
450            .read()
451            .get(room_id)
452            .map(|e| e.values().cloned().collect())
453            .unwrap_or_default())
454    }
455
456    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
457        let inbounds = self
458            .inbound_group_sessions
459            .read()
460            .values()
461            .flat_map(HashMap::values)
462            .map(|ser| {
463                let pickle: PickledInboundGroupSession =
464                    serde_json::from_str(ser).expect("Pickle deserialization should work");
465                InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
466            })
467            .collect();
468        Ok(inbounds)
469    }
470
471    async fn inbound_group_session_counts(
472        &self,
473        backup_version: Option<&str>,
474    ) -> Result<RoomKeyCounts> {
475        let backed_up = if let Some(backup_version) = backup_version {
476            self.get_inbound_group_sessions_and_backed_up_to()
477                .await?
478                .into_iter()
479                // Count the sessions backed up in the required backup
480                .filter(|(_, o)| o.as_ref().is_some_and(|o| o.as_str() == backup_version))
481                .count()
482        } else {
483            // We asked about a nonexistent backup version - this doesn't make much sense,
484            // but we can easily answer that nothing is backed up in this
485            // nonexistent backup.
486            0
487        };
488
489        let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
490        Ok(RoomKeyCounts { total, backed_up })
491    }
492
493    async fn get_inbound_group_sessions_by_room_id(
494        &self,
495        room_id: &RoomId,
496    ) -> Result<Vec<InboundGroupSession>> {
497        let inbounds = match self.inbound_group_sessions.read().get(room_id) {
498            None => Vec::new(),
499            Some(v) => v
500                .values()
501                .map(|ser| {
502                    let pickle: PickledInboundGroupSession =
503                        serde_json::from_str(ser).expect("Pickle deserialization should work");
504                    InboundGroupSession::from_pickle(pickle)
505                        .expect("Expect from pickle to always work")
506                })
507                .collect(),
508        };
509        Ok(inbounds)
510    }
511
512    async fn get_inbound_group_sessions_for_device_batch(
513        &self,
514        sender_key: Curve25519PublicKey,
515        sender_data_type: SenderDataType,
516        after_session_id: Option<String>,
517        limit: usize,
518    ) -> Result<Vec<InboundGroupSession>> {
519        // First, find all InboundGroupSessions, filtering for those that match the
520        // device and sender_data type.
521        let mut sessions: Vec<_> = self
522            .get_inbound_group_sessions()
523            .await?
524            .into_iter()
525            .filter(|session: &InboundGroupSession| {
526                session.creator_info.curve25519_key == sender_key
527                    && session.sender_data.to_type() == sender_data_type
528            })
529            .collect();
530
531        // Then, sort the sessions in order of ascending session ID...
532        sessions.sort_by_key(|s| s.session_id().to_owned());
533
534        // Figure out where in the array to start returning results from
535        let start_index = {
536            match after_session_id {
537                None => 0,
538                Some(id) => {
539                    // We're looking for the first session with a session ID strictly after `id`; if
540                    // there are none, the end of the array.
541                    sessions
542                        .iter()
543                        .position(|session| session.session_id() > id.as_str())
544                        .unwrap_or(sessions.len())
545                }
546            }
547        };
548
549        // Return up to `limit` items from the array, starting from `start_index`
550        Ok(sessions.drain(start_index..).take(limit).collect())
551    }
552
553    async fn inbound_group_sessions_for_backup(
554        &self,
555        backup_version: &str,
556        limit: usize,
557    ) -> Result<Vec<InboundGroupSession>> {
558        Ok(self
559            .get_inbound_group_sessions_and_backed_up_to()
560            .await?
561            .into_iter()
562            .filter_map(|(session, backed_up_to)| {
563                if let Some(ref existing_version) = backed_up_to
564                    && existing_version.as_str() == backup_version
565                {
566                    // This session is already backed up in the required backup
567                    None
568                } else {
569                    // It's not backed up, or it's backed up in a different backup
570                    Some(session)
571                }
572            })
573            .take(limit)
574            .collect())
575    }
576
577    async fn mark_inbound_group_sessions_as_backed_up(
578        &self,
579        backup_version: &str,
580        room_and_session_ids: &[(&RoomId, &str)],
581    ) -> Result<()> {
582        for &(room_id, session_id) in room_and_session_ids {
583            let session = self.get_inbound_group_session(room_id, session_id).await?;
584
585            if let Some(session) = session {
586                session.mark_as_backed_up();
587
588                self.inbound_group_sessions_backed_up_to
589                    .write()
590                    .entry(room_id.to_owned())
591                    .or_default()
592                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
593
594                // Save it back
595                let updated_pickle = session.pickle().await;
596
597                self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
598                    session_id.to_owned(),
599                    serde_json::to_string(&updated_pickle)
600                        .expect("Pickle serialization should work"),
601                );
602            }
603        }
604
605        Ok(())
606    }
607
608    async fn reset_backup_state(&self) -> Result<()> {
609        // Nothing to do here, because we remember which backup versions we backed up to
610        // in `mark_inbound_group_sessions_as_backed_up`, so we don't need to
611        // reset anything here because the required version is passed in to
612        // `inbound_group_sessions_for_backup`, and we can compare against the
613        // version we stored.
614
615        Ok(())
616    }
617
618    async fn load_backup_keys(&self) -> Result<BackupKeys> {
619        Ok(self.backup_keys.read().await.to_owned())
620    }
621
622    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
623        Ok(self.dehydrated_device_pickle_key.read().await.to_owned())
624    }
625
626    async fn delete_dehydrated_device_pickle_key(&self) -> Result<()> {
627        let mut lock = self.dehydrated_device_pickle_key.write().await;
628        *lock = None;
629        Ok(())
630    }
631
632    async fn get_outbound_group_session(
633        &self,
634        room_id: &RoomId,
635    ) -> Result<Option<OutboundGroupSession>> {
636        Ok(self.outbound_group_sessions.read().get(room_id).cloned())
637    }
638
639    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
640        Ok(self.tracked_users.read().values().cloned().collect())
641    }
642
643    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
644        self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
645            let user_id: OwnedUserId = user_id.to_owned().into();
646            (user_id.clone(), TrackedUser { user_id, dirty: *dirty })
647        }));
648        Ok(())
649    }
650
651    async fn get_device(
652        &self,
653        user_id: &UserId,
654        device_id: &DeviceId,
655    ) -> Result<Option<DeviceData>> {
656        Ok(self.devices.get(user_id, device_id))
657    }
658
659    async fn get_user_devices(
660        &self,
661        user_id: &UserId,
662    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
663        Ok(self.devices.user_devices(user_id))
664    }
665
666    async fn get_own_device(&self) -> Result<DeviceData> {
667        let account =
668            self.get_static_account().expect("Expect account to exist when getting own device");
669
670        Ok(self
671            .devices
672            .get(&account.user_id, &account.device_id)
673            .expect("Invalid state: Should always have a own device"))
674    }
675
676    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
677        let serialized = self.identities.read().get(user_id).cloned();
678        match serialized {
679            None => Ok(None),
680            Some(serialized) => {
681                let id: UserIdentityData = serde_json::from_str(serialized.as_str())
682                    .expect("Only valid serialized identity are saved");
683                Ok(Some(id))
684            }
685        }
686    }
687
688    async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
689        Ok(self
690            .olm_hashes
691            .write()
692            .entry(message_hash.sender_key.to_owned())
693            .or_default()
694            .contains(&message_hash.hash))
695    }
696
697    async fn get_outgoing_secret_requests(
698        &self,
699        request_id: &TransactionId,
700    ) -> Result<Option<GossipRequest>> {
701        Ok(self.outgoing_key_requests.read().get(request_id).cloned())
702    }
703
704    async fn get_secret_request_by_info(
705        &self,
706        key_info: &SecretInfo,
707    ) -> Result<Option<GossipRequest>> {
708        let key_info_string = encode_key_info(key_info);
709
710        Ok(self
711            .key_requests_by_info
712            .read()
713            .get(&key_info_string)
714            .and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
715    }
716
717    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
718        Ok(self
719            .outgoing_key_requests
720            .read()
721            .values()
722            .filter(|req| !req.sent_out)
723            .cloned()
724            .collect())
725    }
726
727    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
728        let req = self.outgoing_key_requests.write().remove(request_id);
729        if let Some(i) = req {
730            let key_info_string = encode_key_info(&i.info);
731            self.key_requests_by_info.write().remove(&key_info_string);
732        }
733
734        Ok(())
735    }
736
737    async fn get_secrets_from_inbox(
738        &self,
739        secret_name: &SecretName,
740    ) -> Result<Vec<GossippedSecret>> {
741        Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
742    }
743
744    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
745        self.secret_inbox.write().remove(secret_name.as_str());
746
747        Ok(())
748    }
749
750    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
751        Ok(self.room_settings.read().get(room_id).cloned())
752    }
753
754    async fn get_received_room_key_bundle_data(
755        &self,
756        room_id: &RoomId,
757        user_id: &UserId,
758    ) -> Result<Option<StoredRoomKeyBundleData>> {
759        let guard = self.room_key_bundles.read();
760
761        let result = guard.get(room_id).and_then(|bundles| bundles.get(user_id).cloned());
762
763        Ok(result)
764    }
765
766    async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
767        let guard = self.room_key_backups_fully_downloaded.read();
768        Ok(guard.contains(room_id))
769    }
770
771    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
772        Ok(self.custom_values.read().get(key).cloned())
773    }
774
775    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
776        self.custom_values.write().insert(key.to_owned(), value);
777        Ok(())
778    }
779
780    async fn remove_custom_value(&self, key: &str) -> Result<()> {
781        self.custom_values.write().remove(key);
782        Ok(())
783    }
784
785    async fn try_take_leased_lock(
786        &self,
787        lease_duration_ms: u32,
788        key: &str,
789        holder: &str,
790    ) -> Result<Option<CrossProcessLockGeneration>> {
791        Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
792    }
793
794    async fn get_size(&self) -> Result<Option<usize>> {
795        Ok(None)
796    }
797}
798
#[cfg(test)]
mod tests {
    //! Unit tests for `MemoryStore`: persistence of Olm sessions, inbound and
    //! outbound group sessions, devices, private identities, tracked users and
    //! message hashes, plus the backup bookkeeping of inbound group sessions.

    use std::collections::HashMap;

    use matrix_sdk_test::async_test;
    use ruma::{RoomId, room_id, user_id};
    use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};

    use super::SessionId;
    use crate::{
        DeviceData,
        identities::device::testing::get_device,
        olm::{
            Account, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
            tests::get_account_and_session_test_helper,
        },
        store::{
            CryptoStore,
            memorystore::MemoryStore,
            types::{Changes, DeviceChanges, PendingChanges},
        },
    };

    /// Olm sessions saved through `save_changes` can be read back by sender key.
    #[async_test]
    async fn test_session_store() {
        let (account, session) = get_account_and_session_test_helper();
        let own_device = DeviceData::from_account(&account);
        let store = MemoryStore::new();

        assert!(store.load_account().await.unwrap().is_none());

        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![own_device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();

        store
            .save_changes(Changes { sessions: vec![session.clone()], ..Default::default() })
            .await
            .unwrap();

        let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    /// An inbound group session can be saved and loaded back by room and
    /// session id.
    #[async_test]
    async fn test_inbound_group_session_store() {
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";

        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
        let inbound = InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap();

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(vec![inbound.clone()], None).await.unwrap();

        let loaded_session =
            store.get_inbound_group_session(room_id, outbound.session_id()).await.unwrap().unwrap();
        assert_eq!(inbound, loaded_session);
    }

    #[async_test]
    async fn test_backing_up_marks_sessions_as_backed_up() {
        // Given there are 2 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(2, room_id).await;

        // When I mark them as backed up
        mark_backed_up(&store, room_id, "bkp1", &sessions).await;

        // Then their backed_up_to field is set
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_a_second_set_of_sessions_updates_their_backup_order() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark 0 and 1 as backed up in bkp1
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And 1 and 2 as backed up in bkp2
        mark_backed_up(&store, room_id, "bkp2", &sessions[1..]).await;

        // Then 0 is backed up in bkp1 and the 1 and 2 are backed up in bkp2
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp2");
        assert_eq!(but[sessions[2].session_id()], "bkp2");
    }

    #[async_test]
    async fn test_backing_up_again_to_the_same_version_has_no_effect() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark the first two as backed up in the first backup
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And the last 2 as backed up in the same backup version
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..]).await;

        // Then they all get the same backed_up_to value
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
        assert_eq!(but[sessions[2].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_can_increase_backed_up_to() {
        // Given we have backed up some sessions to 2 backup versions, an older and a
        // newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "newer_bkp", &sessions[1..2]).await;

        // When I ask to back up the un-backed-up ones to the older backup
        mark_backed_up(&store, room_id, "older_bkp", &sessions[2..]).await;

        // Then each session lists the backup it was most recently included in
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "newer_bkp");
        assert_eq!(but[sessions[2].session_id()], "older_bkp");
        assert_eq!(but[sessions[3].session_id()], "older_bkp");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_overwrites_a_newer_one() {
        // Given we have backed up to 2 backup versions, an older and a newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions).await;
        // Sanity: the first session is recorded as backed up to older_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "older_bkp");
        mark_backed_up(&store, room_id, "newer_bkp", &sessions).await;
        // Sanity: the first session is now recorded as backed up to newer_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "newer_bkp");

        // When I ask to back up some to the older version
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;

        // Then older backup overwrites: we don't consider the order here at all
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "older_bkp");
        assert_eq!(but[sessions[2].session_id()], "newer_bkp");
        assert_eq!(but[sessions[3].session_id()], "newer_bkp");
    }

    #[async_test]
    async fn test_not_backed_up_sessions_are_eligible_for_backup() {
        // Given there are 4 sessions, 2 of which are already backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told the last 2 only
        assert_eq!(to_backup, &[sessions[2].clone(), sessions[3].clone()]);
    }

    #[async_test]
    async fn test_all_sessions_are_eligible_for_backup_if_version_is_unknown() {
        // Given there are 4 sessions, 2 of which are already backed up in bkp1
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up in an unknown version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("unknown_bkp", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up all of them
        assert_eq!(
            to_backup,
            &[sessions[0].clone(), sessions[1].clone(), sessions[2].clone(), sessions[3].clone()]
        );
    }

    #[async_test]
    async fn test_sessions_backed_up_to_a_later_version_are_eligible_for_backup() {
        // Given there are 4 sessions, some backed up to three different versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp0", &sessions[..1]).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[2..3]).await;

        // When I ask which to back up in the middle version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up everything not in the version I asked about
        assert_eq!(
            to_backup,
            &[
                sessions[0].clone(), // Backed up in bkp0
                // sessions[1] is backed up in bkp1 already, which we asked about
                sessions[2].clone(), // Backed up in bkp2
                sessions[3].clone(), // Not backed up
            ]
        );
    }

    #[async_test]
    async fn test_outbound_group_session_store() {
        // Given an outbound session
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_outbound_group_sessions(vec![outbound.clone()]);

        // Then we can get it out again
        let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
        assert_eq!(
            serde_json::to_string(&outbound.pickle().await).unwrap(),
            serde_json::to_string(&loaded_session.pickle().await).unwrap()
        );
    }

    #[async_test]
    async fn test_tracked_users_are_stored_once_per_user_id() {
        // Given a store containing 2 tracked users, both dirty
        let user1 = user_id!("@user1:s");
        let user2 = user_id!("@user2:s");
        let user3 = user_id!("@user3:s");
        let store = MemoryStore::new();
        store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();

        // When we mark one as clean and add another
        store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();

        // Then we can get them out again and their dirty flags are correct
        let loaded_tracked_users =
            store.load_tracked_users().await.expect("failed to load tracked users");

        let tracked_contains = |user_id, dirty| {
            loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
        };

        assert!(tracked_contains(user1, true));
        assert!(tracked_contains(user2, false));
        assert!(tracked_contains(user3, false));
        assert_eq!(loaded_tracked_users.len(), 3);
    }

    #[async_test]
    async fn test_private_identity_store() {
        // Given a private identity
        let private_identity = PrivateCrossSigningIdentity::empty(user_id!("@u:s"));

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_private_identity(Some(private_identity.clone()));

        // Then we can get it out again
        let loaded_identity =
            store.load_identity().await.expect("failed to load private identity").unwrap();

        assert_eq!(loaded_identity.user_id(), user_id!("@u:s"));
    }

    /// Devices can be saved, looked up individually or per user, and deleted.
    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = MemoryStore::new();

        store.save_devices(vec![device.clone()]);

        let loaded_device =
            store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.get_user_devices(device.user_id()).await.unwrap();

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.delete_devices(vec![device.clone()]);
        assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
    }

    /// A message hash is unknown until it has been saved via `save_changes`.
    #[async_test]
    async fn test_message_hash() {
        let store = MemoryStore::new();

        let hash =
            OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };

        let mut changes = Changes::default();
        changes.message_hashes.push(hash.clone());

        assert!(!store.is_message_known(&hash).await.unwrap());
        store.save_changes(changes).await.unwrap();
        assert!(store.is_message_known(&hash).await.unwrap());
    }

    #[async_test]
    async fn test_key_counts_of_empty_store_are_zero() {
        // Given an empty store
        let store = MemoryStore::new();

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("")).await.unwrap();

        // Then the answer is zero
        assert_eq!(key_counts.total, 0);
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_sessions_reports_the_number_of_sessions() {
        // Given a store with sessions
        let room_id = room_id!("!test:localhost");
        let (store, _) = store_with_sessions(4, room_id).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 4);
        // And none are backed up
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_reports_the_number_backed_up_in_this_backup() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(5, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 5);
        // And the backed_up count matches how many were backed up
        assert_eq!(key_counts.backed_up, 2);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_for_null_backup_reports_zero() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys, providing None as the backup version
        let key_counts = store.inbound_group_session_counts(None).await.unwrap();

        // Then we ignore everything and just say zero
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_only_reports_sessions_in_the_version_specified() {
        // Given a store with sessions, backed up in several versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[3..]).await;

        // When we count keys for bkp2
        let key_counts = store.inbound_group_session_counts(Some("bkp2")).await.unwrap();

        // Then the backed_up count reflects how many were backed up in bkp2 only
        assert_eq!(key_counts.backed_up, 1);
    }

    /// Mark the supplied sessions (all assumed to belong to `room_id`) as backed
    /// up in the supplied backup version.
    ///
    /// Panics if the store reports an error.
    async fn mark_backed_up(
        store: &MemoryStore,
        room_id: &RoomId,
        backup_version: &str,
        sessions: &[InboundGroupSession],
    ) {
        let rooms_and_ids: Vec<_> = sessions.iter().map(|s| (room_id, s.session_id())).collect();

        store
            .mark_inbound_group_sessions_as_backed_up(backup_version, &rooms_and_ids)
            .await
            .expect("Failed to mark sessions as backed up");
    }

    /// Create a `MemoryStore` containing the supplied number of freshly created
    /// inbound group sessions for `room_id`, none of them backed up.
    ///
    /// Sessions are returned in alphabetical order of session id.
    async fn store_with_sessions(
        num_sessions: usize,
        room_id: &RoomId,
    ) -> (MemoryStore, Vec<InboundGroupSession>) {
        let (account, _) = get_account_and_session_test_helper();

        let mut sessions = Vec::with_capacity(num_sessions);
        for _ in 0..num_sessions {
            sessions.push(new_session(&account, room_id).await);
        }
        sessions.sort_by_key(|s| s.session_id().to_owned());

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(sessions.clone(), None).await.unwrap();

        (store, sessions)
    }

    /// Create a new `InboundGroupSession` for `room_id`, built from the session
    /// key of a fresh outbound group session created by `account`.
    async fn new_session(account: &Account, room_id: &RoomId) -> InboundGroupSession {
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap()
    }

    /// Map the session_id of each inbound group session in the store to the
    /// backup version it is recorded as backed up to, or to an empty string
    /// when it is not backed up at all.
    async fn backed_up_tos(store: &MemoryStore) -> HashMap<SessionId, String> {
        store
            .get_inbound_group_sessions_and_backed_up_to()
            .await
            .expect("Unable to get inbound group sessions and backup order")
            .iter()
            .map(|(session, backed_up_to)| {
                let version =
                    backed_up_to.as_ref().map(|v| v.as_str().to_owned()).unwrap_or_default();
                (session.session_id().to_owned(), version)
            })
            .collect()
    }
}
1288
#[cfg(test)]
mod integration_tests {
    //! Runs the shared `CryptoStore` integration test suites against
    //! `MemoryStore`, wrapped so that data outlives individual store handles.

    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
    };

    use async_trait::async_trait;
    use matrix_sdk_common::cross_process_lock::CrossProcessLockGeneration;
    use ruma::{
        DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, events::secret::request::SecretName,
    };
    use vodozemac::Curve25519PublicKey;

    use super::MemoryStore;
    use crate::{
        Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, Session, UserIdentityData,
        cryptostore_integration_tests, cryptostore_integration_tests_time,
        olm::{
            InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
            SenderDataType, StaticAccountData,
        },
        store::{
            CryptoStore,
            types::{
                BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
                RoomKeyWithheldEntry, RoomSettings, StoredRoomKeyBundleData, TrackedUser,
            },
        },
    };

    /// A cheaply cloneable handle to a [MemoryStore], shared through an [Arc].
    ///
    /// `get_store` keeps one handle per test name in a static map, so the
    /// underlying data survives when a test drops its own handle and later
    /// re-opens the store under the same name. There is no special `Drop`
    /// behaviour: persistence comes purely from that retained clone.
    #[derive(Clone, Debug)]
    struct PersistentMemoryStore(Arc<MemoryStore>);

    impl PersistentMemoryStore {
        /// Create a handle to a brand-new, empty [MemoryStore].
        fn new() -> Self {
            Self(Arc::new(MemoryStore::new()))
        }

        /// Forward to [MemoryStore]'s inherent `get_static_account`.
        fn get_static_account(&self) -> Option<StaticAccountData> {
            self.0.get_static_account()
        }
    }

    /// Return a handle to the store registered for the test called `name`.
    ///
    /// If `clear_data` is true, any store previously registered under `name`
    /// is replaced with a fresh, empty one; otherwise the existing store is
    /// reused (or created on first use). `_passphrase` is ignored, since
    /// [MemoryStore] is constructed without one.
    ///
    /// Dropping the returned handle does not destroy the data: the registry
    /// keeps its own clone of the reference-counted [PersistentMemoryStore].
    async fn get_store(
        name: &str,
        _passphrase: Option<&str>,
        clear_data: bool,
    ) -> PersistentMemoryStore {
        // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
        // the store, we keep its data alive. This simulates the behaviour of
        // the other stores, which keep their data in a real DB, allowing us to
        // test MemoryStore using the same code.
        static STORES: OnceLock<Mutex<HashMap<String, PersistentMemoryStore>>> = OnceLock::new();
        let stores = STORES.get_or_init(|| Mutex::new(HashMap::new()));

        let mut stores = stores.lock().unwrap();

        if clear_data {
            // Create a new PersistentMemoryStore, replacing any previous one
            let new_store = PersistentMemoryStore::new();
            stores.insert(name.to_owned(), new_store.clone());
            new_store
        } else {
            stores.entry(name.to_owned()).or_insert_with(PersistentMemoryStore::new).clone()
        }
    }

    /// Forwards all methods to the underlying [MemoryStore].
    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl CryptoStore for PersistentMemoryStore {
        type Error = <MemoryStore as CryptoStore>::Error;

        async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
            self.0.load_account().await
        }

        async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
            self.0.load_identity().await
        }

        async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error> {
            self.0.save_changes(changes).await
        }

        async fn save_pending_changes(&self, changes: PendingChanges) -> Result<(), Self::Error> {
            self.0.save_pending_changes(changes).await
        }

        async fn save_inbound_group_sessions(
            &self,
            sessions: Vec<InboundGroupSession>,
            backed_up_to_version: Option<&str>,
        ) -> Result<(), Self::Error> {
            self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await
        }

        async fn get_sessions(
            &self,
            sender_key: &str,
        ) -> Result<Option<Vec<Session>>, Self::Error> {
            self.0.get_sessions(sender_key).await
        }

        async fn get_inbound_group_session(
            &self,
            room_id: &RoomId,
            session_id: &str,
        ) -> Result<Option<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_session(room_id, session_id).await
        }

        async fn get_withheld_info(
            &self,
            room_id: &RoomId,
            session_id: &str,
        ) -> Result<Option<RoomKeyWithheldEntry>, Self::Error> {
            self.0.get_withheld_info(room_id, session_id).await
        }

        async fn get_withheld_sessions_by_room_id(
            &self,
            room_id: &RoomId,
        ) -> Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
            self.0.get_withheld_sessions_by_room_id(room_id).await
        }

        async fn get_inbound_group_sessions(
            &self,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_sessions().await
        }

        async fn inbound_group_session_counts(
            &self,
            backup_version: Option<&str>,
        ) -> Result<RoomKeyCounts, Self::Error> {
            self.0.inbound_group_session_counts(backup_version).await
        }

        async fn get_inbound_group_sessions_by_room_id(
            &self,
            room_id: &RoomId,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_sessions_by_room_id(room_id).await
        }

        async fn get_inbound_group_sessions_for_device_batch(
            &self,
            sender_key: Curve25519PublicKey,
            sender_data_type: SenderDataType,
            after_session_id: Option<String>,
            limit: usize,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0
                .get_inbound_group_sessions_for_device_batch(
                    sender_key,
                    sender_data_type,
                    after_session_id,
                    limit,
                )
                .await
        }

        async fn inbound_group_sessions_for_backup(
            &self,
            backup_version: &str,
            limit: usize,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.inbound_group_sessions_for_backup(backup_version, limit).await
        }

        async fn mark_inbound_group_sessions_as_backed_up(
            &self,
            backup_version: &str,
            room_and_session_ids: &[(&RoomId, &str)],
        ) -> Result<(), Self::Error> {
            self.0
                .mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
                .await
        }

        async fn reset_backup_state(&self) -> Result<(), Self::Error> {
            self.0.reset_backup_state().await
        }

        async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
            self.0.load_backup_keys().await
        }

        async fn load_dehydrated_device_pickle_key(
            &self,
        ) -> Result<Option<DehydratedDeviceKey>, Self::Error> {
            self.0.load_dehydrated_device_pickle_key().await
        }

        async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
            self.0.delete_dehydrated_device_pickle_key().await
        }

        async fn get_outbound_group_session(
            &self,
            room_id: &RoomId,
        ) -> Result<Option<OutboundGroupSession>, Self::Error> {
            self.0.get_outbound_group_session(room_id).await
        }

        async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
            self.0.load_tracked_users().await
        }

        async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error> {
            self.0.save_tracked_users(users).await
        }

        async fn get_device(
            &self,
            user_id: &UserId,
            device_id: &DeviceId,
        ) -> Result<Option<DeviceData>, Self::Error> {
            self.0.get_device(user_id, device_id).await
        }

        async fn get_user_devices(
            &self,
            user_id: &UserId,
        ) -> Result<HashMap<OwnedDeviceId, DeviceData>, Self::Error> {
            self.0.get_user_devices(user_id).await
        }

        async fn get_own_device(&self) -> Result<DeviceData, Self::Error> {
            self.0.get_own_device().await
        }

        async fn get_user_identity(
            &self,
            user_id: &UserId,
        ) -> Result<Option<UserIdentityData>, Self::Error> {
            self.0.get_user_identity(user_id).await
        }

        async fn is_message_known(
            &self,
            message_hash: &OlmMessageHash,
        ) -> Result<bool, Self::Error> {
            self.0.is_message_known(message_hash).await
        }

        async fn get_outgoing_secret_requests(
            &self,
            request_id: &TransactionId,
        ) -> Result<Option<GossipRequest>, Self::Error> {
            self.0.get_outgoing_secret_requests(request_id).await
        }

        async fn get_secret_request_by_info(
            &self,
            secret_info: &SecretInfo,
        ) -> Result<Option<GossipRequest>, Self::Error> {
            self.0.get_secret_request_by_info(secret_info).await
        }

        async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
            self.0.get_unsent_secret_requests().await
        }

        async fn delete_outgoing_secret_requests(
            &self,
            request_id: &TransactionId,
        ) -> Result<(), Self::Error> {
            self.0.delete_outgoing_secret_requests(request_id).await
        }

        async fn get_secrets_from_inbox(
            &self,
            secret_name: &SecretName,
        ) -> Result<Vec<GossippedSecret>, Self::Error> {
            self.0.get_secrets_from_inbox(secret_name).await
        }

        async fn delete_secrets_from_inbox(
            &self,
            secret_name: &SecretName,
        ) -> Result<(), Self::Error> {
            self.0.delete_secrets_from_inbox(secret_name).await
        }

        async fn get_room_settings(
            &self,
            room_id: &RoomId,
        ) -> Result<Option<RoomSettings>, Self::Error> {
            self.0.get_room_settings(room_id).await
        }

        async fn get_received_room_key_bundle_data(
            &self,
            room_id: &RoomId,
            user_id: &UserId,
        ) -> Result<Option<StoredRoomKeyBundleData>, Self::Error> {
            self.0.get_received_room_key_bundle_data(room_id, user_id).await
        }

        async fn has_downloaded_all_room_keys(
            &self,
            room_id: &RoomId,
        ) -> Result<bool, Self::Error> {
            self.0.has_downloaded_all_room_keys(room_id).await
        }

        async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
            self.0.get_custom_value(key).await
        }

        async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<(), Self::Error> {
            self.0.set_custom_value(key, value).await
        }

        async fn remove_custom_value(&self, key: &str) -> Result<(), Self::Error> {
            self.0.remove_custom_value(key).await
        }

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
            self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
            self.0.next_batch_token().await
        }

        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
            self.0.get_size().await
        }
    }

    // Instantiate the shared store test suites; they obtain stores through
    // `get_store` defined above.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}