matrix_sdk_crypto/store/memorystore.rs

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    collections::{BTreeMap, HashMap, HashSet},
    convert::Infallible,
    sync::Arc,
};

use async_trait::async_trait;
use matrix_sdk_common::{
    cross_process_lock::{
        CrossProcessLockGeneration,
        memory_store_helper::{Lease, try_take_leased_lock},
    },
    locks::RwLock as StdRwLock,
};
use ruma::{
    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
    UserId, events::secret::request::SecretName,
};
use tokio::sync::{Mutex, RwLock};
use tracing::warn;
use vodozemac::Curve25519PublicKey;

use super::{
    Account, CryptoStore, InboundGroupSession, Session,
    caches::DeviceStore,
    types::{
        BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
        StoredRoomKeyBundleData, TrackedUser,
    },
};
use crate::{
    gossiping::{GossipRequest, GossippedSecret, SecretInfo},
    identities::{DeviceData, UserIdentityData},
    olm::{
        OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
        PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
    },
    store::types::RoomKeyWithheldEntry,
};

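/// Flatten a [`SecretInfo`] into a single string, used as the key for the
/// `key_requests_by_info` lookup map.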
fn encode_key_info(info: &SecretInfo) -> String {
    match info {
        SecretInfo::KeyRequest(info) => {
            format!("{}{}{}", info.room_id(), info.algorithm(), info.session_id())
        }
        SecretInfo::SecretRequest(i) => i.as_ref().to_owned(),
    }
}

type SessionId = String;

/// The "version" of a backup - newtype wrapper around a String.
#[derive(Clone, Debug, PartialEq)]
struct BackupVersion(String);

impl BackupVersion {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }

    fn as_str(&self) -> &str {
        &self.0
    }
}

/// An in-memory only store that will forget all the E2EE keys once it's dropped.
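///
/// # Examples
///
/// A minimal sketch of creating the store (any [`CryptoStore`] consumer can
/// then use it; the surrounding machinery is elided here):
///
/// ```ignore
/// use matrix_sdk_crypto::store::MemoryStore;
///
/// // All data lives on the heap and disappears when the store is dropped.
/// let store = MemoryStore::new();
/// ```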
#[derive(Default, Debug)]
pub struct MemoryStore {
    static_account: Arc<StdRwLock<Option<StaticAccountData>>>,

    account: StdRwLock<Option<String>>,
    // Map of sender_key to map of session_id to serialized pickle
    sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
    inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,

    /// Map of room id -> session id -> backup version.
    ///
    /// The latest backup in which each session is stored. Equivalent to
    /// `backed_up_to` in [`IndexedDbCryptoStore`].
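    ///
    /// E.g. an entry of `"!room:s" -> "<session id>" -> BackupVersion("bkp1")`
    /// (illustrative values only) means that session was most recently backed
    /// up in backup version `bkp1`.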
    inbound_group_sessions_backed_up_to:
        StdRwLock<HashMap<OwnedRoomId, HashMap<SessionId, BackupVersion>>>,

    outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
    private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
    tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
    olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
    devices: DeviceStore,
    identities: StdRwLock<HashMap<OwnedUserId, String>>,
    outgoing_key_requests: StdRwLock<HashMap<OwnedTransactionId, GossipRequest>>,
    key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
    direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEntry>>>,
    custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
    leases: StdRwLock<HashMap<String, Lease>>,
    secret_inbox: StdRwLock<HashMap<String, Vec<GossippedSecret>>>,
    backup_keys: RwLock<BackupKeys>,
    dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
    next_batch_token: RwLock<Option<String>>,
    room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
    room_key_bundles:
        StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, StoredRoomKeyBundleData>>>,

    save_changes_lock: Arc<Mutex<()>>,
}

impl MemoryStore {
    /// Create a new empty `MemoryStore`.
    pub fn new() -> Self {
        Self::default()
    }

    fn get_static_account(&self) -> Option<StaticAccountData> {
        self.static_account.read().clone()
    }

    pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
        for device in devices {
            let _ = self.devices.add(device);
        }
    }

    fn delete_devices(&self, devices: Vec<DeviceData>) {
        for device in devices {
            let _ = self.devices.remove(device.user_id(), device.device_id());
        }
    }

    fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
        let mut session_store = self.sessions.write();

        for (session_id, pickle) in sessions {
            let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();

            // insert or replace if exists
            entry.insert(
                session_id,
                serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
            );
        }
    }

    fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
        self.outbound_group_sessions
            .write()
            .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
    }

    fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
        *self.private_identity.write() = private_identity;
    }

    /// Return all the [`InboundGroupSession`]s we have, paired with the
    /// `backed_up_to` value for each one (`None` for sessions that have not
    /// been backed up).
    async fn get_inbound_group_sessions_and_backed_up_to(
        &self,
    ) -> Result<Vec<(InboundGroupSession, Option<BackupVersion>)>> {
        let lookup = |s: &InboundGroupSession| {
            self.inbound_group_sessions_backed_up_to
                .read()
                .get(&s.room_id)?
                .get(s.session_id())
                .cloned()
        };

        Ok(self
            .get_inbound_group_sessions()
            .await?
            .into_iter()
            .map(|s| {
                let v = lookup(&s);
                (s, v)
            })
            .collect())
    }
}

type Result<T> = std::result::Result<T, Infallible>;

#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl CryptoStore for MemoryStore {
    type Error = Infallible;

    async fn load_account(&self) -> Result<Option<Account>> {
        let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
            serde_json::from_str(acc)
                .expect("Deserialization failed: invalid pickled account JSON format")
        });

        if let Some(pickle) = pickled_account {
            let account =
                Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");

            *self.static_account.write() = Some(account.static_data().clone());

            Ok(Some(account))
        } else {
            Ok(None)
        }
    }

    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
        Ok(self.private_identity.read().clone())
    }

    async fn next_batch_token(&self) -> Result<Option<String>> {
        Ok(self.next_batch_token.read().await.clone())
    }

    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
        let _guard = self.save_changes_lock.lock().await;

        let pickled_account = if let Some(account) = changes.account {
            *self.static_account.write() = Some(account.static_data().clone());
            Some(account.pickle())
        } else {
            None
        };

        *self.account.write() = pickled_account.map(|pickle| {
            serde_json::to_string(&pickle)
                .expect("Serialization failed: invalid pickled account JSON format")
        });

        Ok(())
    }

    async fn save_changes(&self, changes: Changes) -> Result<()> {
        let _guard = self.save_changes_lock.lock().await;

        let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
        for session in changes.sessions {
            let session_id = session.session_id().to_owned();
            let pickle = session.pickle().await;
            pickled_session.push((session_id.clone(), pickle));
        }
        self.save_sessions(pickled_session);

        self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
        self.save_outbound_group_sessions(changes.outbound_group_sessions);
        self.save_private_identity(changes.private_identity);

        self.save_devices(changes.devices.new);
        self.save_devices(changes.devices.changed);
        self.delete_devices(changes.devices.deleted);

        {
            let mut identities = self.identities.write();
            for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
                identities.insert(
                    identity.user_id().to_owned(),
                    serde_json::to_string(&identity)
                        .expect("UserIdentityData should always serialize to json"),
                );
            }
        }

        {
            let mut olm_hashes = self.olm_hashes.write();
            for hash in changes.message_hashes {
                olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
            }
        }

        {
            let mut outgoing_key_requests = self.outgoing_key_requests.write();
            let mut key_requests_by_info = self.key_requests_by_info.write();

            for key_request in changes.key_requests {
                let id = key_request.request_id.clone();
                let info_string = encode_key_info(&key_request.info);

                outgoing_key_requests.insert(id.clone(), key_request);
                key_requests_by_info.insert(info_string, id);
            }
        }

        if let Some(key) = changes.backup_decryption_key {
            self.backup_keys.write().await.decryption_key = Some(key);
        }

        if let Some(version) = changes.backup_version {
            self.backup_keys.write().await.backup_version = Some(version);
        }

        if let Some(pickle_key) = changes.dehydrated_device_pickle_key {
            let mut lock = self.dehydrated_device_pickle_key.write().await;
            *lock = Some(pickle_key);
        }

        {
            let mut secret_inbox = self.secret_inbox.write();
            for secret in changes.secrets {
                secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret);
            }
        }

        {
            let mut direct_withheld_info = self.direct_withheld_info.write();
            for (room_id, data) in changes.withheld_session_info {
                for (session_id, event) in data {
                    direct_withheld_info
                        .entry(room_id.to_owned())
                        .or_default()
                        .insert(session_id, event);
                }
            }
        }

        if let Some(next_batch_token) = changes.next_batch_token {
            *self.next_batch_token.write().await = Some(next_batch_token);
        }

        if !changes.room_settings.is_empty() {
            let mut settings = self.room_settings.write();
            settings.extend(changes.room_settings);
        }

        if !changes.received_room_key_bundles.is_empty() {
            let mut room_key_bundles = self.room_key_bundles.write();
            for bundle in changes.received_room_key_bundles {
                room_key_bundles
                    .entry(bundle.bundle_data.room_id.clone())
                    .or_default()
                    .insert(bundle.sender_user.clone(), bundle);
            }
        }

        Ok(())
    }

    async fn save_inbound_group_sessions(
        &self,
        sessions: Vec<InboundGroupSession>,
        backed_up_to_version: Option<&str>,
    ) -> Result<()> {
        for session in sessions {
            let room_id = session.room_id();
            let session_id = session.session_id();

            // Sanity-check that the data in the sessions corresponds to
            // backed_up_to_version
            let backed_up = session.backed_up();
            if backed_up != backed_up_to_version.is_some() {
                warn!(
                    backed_up,
                    backed_up_to_version,
                    "Session backed-up flag does not correspond to backup version setting",
                );
            }

            if let Some(backup_version) = backed_up_to_version {
                self.inbound_group_sessions_backed_up_to
                    .write()
                    .entry(room_id.to_owned())
                    .or_default()
                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
            }

            let pickle = session.pickle().await;
            self.inbound_group_sessions
                .write()
                .entry(session.room_id().to_owned())
                .or_default()
                .insert(
                    session.session_id().to_owned(),
                    serde_json::to_string(&pickle)
                        .expect("Pickle data should serialize to json"),
                );
        }
        Ok(())
    }

    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
        let device_keys = self.get_own_device().await?.as_device_keys().clone();

        if let Some(pickles) = self.sessions.read().get(sender_key) {
            let mut sessions: Vec<Session> = Vec::new();
            for serialized_pickle in pickles.values() {
                let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
                    .expect("Pickle deserialization should work");
                let session = Session::from_pickle(device_keys.clone(), pickle)
                    .expect("Expect from pickle to always work");
                sessions.push(session);
            }
            Ok(Some(sessions))
        } else {
            Ok(None)
        }
    }

    async fn get_inbound_group_session(
        &self,
        room_id: &RoomId,
        session_id: &str,
    ) -> Result<Option<InboundGroupSession>> {
        let pickle: Option<PickledInboundGroupSession> = self
            .inbound_group_sessions
            .read()
            .get(room_id)
            .and_then(|m| m.get(session_id))
            .map(|ser| {
                serde_json::from_str(ser).expect("Pickle deserialization should work")
            });

        Ok(pickle.map(|p| {
            InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
        }))
    }

    async fn get_withheld_info(
        &self,
        room_id: &RoomId,
        session_id: &str,
    ) -> Result<Option<RoomKeyWithheldEntry>> {
        Ok(self
            .direct_withheld_info
            .read()
            .get(room_id)
            .and_then(|e| e.get(session_id).cloned()))
    }

    async fn get_withheld_sessions_by_room_id(
        &self,
        room_id: &RoomId,
    ) -> crate::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
        Ok(self
            .direct_withheld_info
            .read()
            .get(room_id)
            .map(|e| e.values().cloned().collect())
            .unwrap_or_default())
    }

    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
        let inbounds = self
            .inbound_group_sessions
            .read()
            .values()
            .flat_map(HashMap::values)
            .map(|ser| {
                let pickle: PickledInboundGroupSession =
                    serde_json::from_str(ser).expect("Pickle deserialization should work");
                InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
            })
            .collect();
        Ok(inbounds)
    }

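    /// Count the total number of inbound group sessions, and how many of them
    /// are marked as backed up in the given `backup_version` (zero when
    /// `backup_version` is `None`).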
    async fn inbound_group_session_counts(
        &self,
        backup_version: Option<&str>,
    ) -> Result<RoomKeyCounts> {
        let backed_up = if let Some(backup_version) = backup_version {
            self.get_inbound_group_sessions_and_backed_up_to()
                .await?
                .into_iter()
                // Count the sessions backed up in the required backup
                .filter(|(_, o)| o.as_ref().is_some_and(|o| o.as_str() == backup_version))
                .count()
        } else {
            // We weren't given a backup version to check against - this
            // doesn't make much sense, but we can easily answer that nothing
            // is backed up in a nonexistent backup.
            0
        };

        let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
        Ok(RoomKeyCounts { total, backed_up })
    }

    async fn get_inbound_group_sessions_by_room_id(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<InboundGroupSession>> {
        let inbounds = match self.inbound_group_sessions.read().get(room_id) {
            None => Vec::new(),
            Some(v) => v
                .values()
                .map(|ser| {
                    let pickle: PickledInboundGroupSession =
                        serde_json::from_str(ser).expect("Pickle deserialization should work");
                    InboundGroupSession::from_pickle(pickle)
                        .expect("Expect from pickle to always work")
                })
                .collect(),
        };
        Ok(inbounds)
    }

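    /// Fetch a batch of sessions for one device, in ascending session-id
    /// order. A sketch of how a caller might page through all matching
    /// sessions (it assumes `store`, `sender_key` and `sender_data_type`
    /// already exist in the caller's scope):
    ///
    /// ```ignore
    /// let mut after: Option<String> = None;
    /// loop {
    ///     let batch = store
    ///         .get_inbound_group_sessions_for_device_batch(
    ///             sender_key,
    ///             sender_data_type,
    ///             after.clone(),
    ///             50, // page size
    ///         )
    ///         .await?;
    ///     let Some(last) = batch.last() else { break };
    ///     after = Some(last.session_id().to_owned());
    ///     // ... process `batch` ...
    /// }
    /// ```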
    async fn get_inbound_group_sessions_for_device_batch(
        &self,
        sender_key: Curve25519PublicKey,
        sender_data_type: SenderDataType,
        after_session_id: Option<String>,
        limit: usize,
    ) -> Result<Vec<InboundGroupSession>> {
        // First, find all InboundGroupSessions, filtering for those that match the
        // device and sender_data type.
        let mut sessions: Vec<_> = self
            .get_inbound_group_sessions()
            .await?
            .into_iter()
            .filter(|session: &InboundGroupSession| {
                session.creator_info.curve25519_key == sender_key
                    && session.sender_data.to_type() == sender_data_type
            })
            .collect();

        // Then, sort the sessions in order of ascending session ID...
        sessions.sort_by_key(|s| s.session_id().to_owned());

        // Figure out where in the array to start returning results from
        let start_index = {
            match after_session_id {
                None => 0,
                Some(id) => {
                    // We're looking for the first session with a session ID strictly after `id`; if
                    // there are none, the end of the array.
                    sessions
                        .iter()
                        .position(|session| session.session_id() > id.as_str())
                        .unwrap_or(sessions.len())
                }
            }
        };

        // Return up to `limit` items from the array, starting from `start_index`
        Ok(sessions.drain(start_index..).take(limit).collect())
    }

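    /// A session is eligible for backup unless it is already marked as backed
    /// up in *this* `backup_version`; sessions backed up only to a different
    /// version are returned again.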
    async fn inbound_group_sessions_for_backup(
        &self,
        backup_version: &str,
        limit: usize,
    ) -> Result<Vec<InboundGroupSession>> {
        Ok(self
            .get_inbound_group_sessions_and_backed_up_to()
            .await?
            .into_iter()
            .filter_map(|(session, backed_up_to)| {
                if let Some(ref existing_version) = backed_up_to
                    && existing_version.as_str() == backup_version
                {
                    // This session is already backed up in the required backup
                    None
                } else {
                    // It's not backed up, or it's backed up in a different backup
                    Some(session)
                }
            })
            .take(limit)
            .collect())
    }

    async fn mark_inbound_group_sessions_as_backed_up(
        &self,
        backup_version: &str,
        room_and_session_ids: &[(&RoomId, &str)],
    ) -> Result<()> {
        for &(room_id, session_id) in room_and_session_ids {
            let session = self.get_inbound_group_session(room_id, session_id).await?;

            if let Some(session) = session {
                session.mark_as_backed_up();

                self.inbound_group_sessions_backed_up_to
                    .write()
                    .entry(room_id.to_owned())
                    .or_default()
                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));

                // Save it back
                let updated_pickle = session.pickle().await;

                self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
                    session_id.to_owned(),
                    serde_json::to_string(&updated_pickle)
                        .expect("Pickle serialization should work"),
                );
            }
        }

        Ok(())
    }

    async fn reset_backup_state(&self) -> Result<()> {
        // Nothing to do here: we remember which backup version each session
        // was backed up to in `mark_inbound_group_sessions_as_backed_up`, and
        // the required version is passed in to
        // `inbound_group_sessions_for_backup`, where we can compare it
        // against the version we stored.

        Ok(())
    }

    async fn load_backup_keys(&self) -> Result<BackupKeys> {
        Ok(self.backup_keys.read().await.to_owned())
    }

    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
        Ok(self.dehydrated_device_pickle_key.read().await.to_owned())
    }

    async fn delete_dehydrated_device_pickle_key(&self) -> Result<()> {
        let mut lock = self.dehydrated_device_pickle_key.write().await;
        *lock = None;
        Ok(())
    }

    async fn get_outbound_group_session(
        &self,
        room_id: &RoomId,
    ) -> Result<Option<OutboundGroupSession>> {
        Ok(self.outbound_group_sessions.read().get(room_id).cloned())
    }

    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
        Ok(self.tracked_users.read().values().cloned().collect())
    }

    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
        self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
            let user_id: OwnedUserId = user_id.to_owned().into();
            (user_id.clone(), TrackedUser { user_id, dirty: *dirty })
        }));
        Ok(())
    }

    async fn get_device(
        &self,
        user_id: &UserId,
        device_id: &DeviceId,
    ) -> Result<Option<DeviceData>> {
        Ok(self.devices.get(user_id, device_id))
    }

    async fn get_user_devices(
        &self,
        user_id: &UserId,
    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
        Ok(self.devices.user_devices(user_id))
    }

    async fn get_own_device(&self) -> Result<DeviceData> {
        let account =
            self.get_static_account().expect("Expect account to exist when getting own device");

        Ok(self
            .devices
            .get(&account.user_id, &account.device_id)
            .expect("Invalid state: Should always have our own device"))
    }

    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
        let serialized = self.identities.read().get(user_id).cloned();
        match serialized {
            None => Ok(None),
            Some(serialized) => {
                let id: UserIdentityData = serde_json::from_str(serialized.as_str())
                    .expect("Only valid serialized identities are saved");
                Ok(Some(id))
            }
        }
    }

    async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
        Ok(self
            .olm_hashes
            .write()
            .entry(message_hash.sender_key.to_owned())
            .or_default()
            .contains(&message_hash.hash))
    }

    async fn get_outgoing_secret_requests(
        &self,
        request_id: &TransactionId,
    ) -> Result<Option<GossipRequest>> {
        Ok(self.outgoing_key_requests.read().get(request_id).cloned())
    }

    async fn get_secret_request_by_info(
        &self,
        key_info: &SecretInfo,
    ) -> Result<Option<GossipRequest>> {
        let key_info_string = encode_key_info(key_info);

        Ok(self
            .key_requests_by_info
            .read()
            .get(&key_info_string)
            .and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
    }

    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
        Ok(self
            .outgoing_key_requests
            .read()
            .values()
            .filter(|req| !req.sent_out)
            .cloned()
            .collect())
    }

    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
        let req = self.outgoing_key_requests.write().remove(request_id);
        if let Some(i) = req {
            let key_info_string = encode_key_info(&i.info);
            self.key_requests_by_info.write().remove(&key_info_string);
        }

        Ok(())
    }

    async fn get_secrets_from_inbox(
        &self,
        secret_name: &SecretName,
    ) -> Result<Vec<GossippedSecret>> {
        Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
    }

    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
        self.secret_inbox.write().remove(secret_name.as_str());

        Ok(())
    }

    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
        Ok(self.room_settings.read().get(room_id).cloned())
    }

    async fn get_received_room_key_bundle_data(
        &self,
        room_id: &RoomId,
        user_id: &UserId,
    ) -> Result<Option<StoredRoomKeyBundleData>> {
        let guard = self.room_key_bundles.read();

        let result = guard.get(room_id).and_then(|bundles| bundles.get(user_id).cloned());

        Ok(result)
    }

    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
        Ok(self.custom_values.read().get(key).cloned())
    }

    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
        self.custom_values.write().insert(key.to_owned(), value);
        Ok(())
    }

    async fn remove_custom_value(&self, key: &str) -> Result<()> {
        self.custom_values.write().remove(key);
        Ok(())
    }

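    /// Try to acquire (or renew) the cross-process lock named `key` on behalf
    /// of `holder`, backed by the in-memory `leases` map via the shared
    /// `try_take_leased_lock` helper.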
    async fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<Option<CrossProcessLockGeneration>> {
        Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
    }

    async fn get_size(&self) -> Result<Option<usize>> {
        Ok(None)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use matrix_sdk_test::async_test;
    use ruma::{RoomId, room_id, user_id};
    use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};

    use super::SessionId;
    use crate::{
        DeviceData,
        identities::device::testing::get_device,
        olm::{
            Account, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
            tests::get_account_and_session_test_helper,
        },
        store::{
            CryptoStore,
            memorystore::MemoryStore,
            types::{Changes, DeviceChanges, PendingChanges},
        },
    };

    #[async_test]
    async fn test_session_store() {
        let (account, session) = get_account_and_session_test_helper();
        let own_device = DeviceData::from_account(&account);
        let store = MemoryStore::new();

        assert!(store.load_account().await.unwrap().is_none());

        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![own_device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();

        store
            .save_changes(Changes { sessions: vec![session.clone()], ..Default::default() })
            .await
            .unwrap();

        let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_inbound_group_session_store() {
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";

        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
        let inbound = InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap();

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(vec![inbound.clone()], None).await.unwrap();

        let loaded_session =
            store.get_inbound_group_session(room_id, outbound.session_id()).await.unwrap().unwrap();
        assert_eq!(inbound, loaded_session);
    }

    #[async_test]
    async fn test_backing_up_marks_sessions_as_backed_up() {
        // Given there are 2 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(2, room_id).await;

        // When I mark them as backed up
        mark_backed_up(&store, room_id, "bkp1", &sessions).await;

        // Then their backed_up_to field is set
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_a_second_set_of_sessions_updates_their_backup_order() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark 0 and 1 as backed up in bkp1
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And 1 and 2 as backed up in bkp2
        mark_backed_up(&store, room_id, "bkp2", &sessions[1..]).await;

        // Then 0 is backed up in bkp1, and 1 and 2 are backed up in bkp2
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp2");
        assert_eq!(but[sessions[2].session_id()], "bkp2");
    }

    #[async_test]
    async fn test_backing_up_again_to_the_same_version_has_no_effect() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark the first two as backed up in the first backup
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And the last 2 as backed up in the same backup version
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..]).await;

        // Then they all get the same backed_up_to value
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
        assert_eq!(but[sessions[2].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_can_increase_backed_up_to() {
        // Given we have backed up some sessions to 2 backup versions, an older and a
        // newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "newer_bkp", &sessions[1..2]).await;

        // When I ask to back up the un-backed-up ones to the older backup
        mark_backed_up(&store, room_id, "older_bkp", &sessions[2..]).await;

        // Then each session lists the backup it was most recently included in
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "newer_bkp");
        assert_eq!(but[sessions[2].session_id()], "older_bkp");
        assert_eq!(but[sessions[3].session_id()], "older_bkp");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_overwrites_a_newer_one() {
        // Given we have backed up to 2 backup versions, an older and a newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions).await;
        // Sanity: they are backed up in older_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "older_bkp");
        mark_backed_up(&store, room_id, "newer_bkp", &sessions).await;
        // Sanity: they are now backed up in newer_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "newer_bkp");

        // When I ask to back up some to the older version
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;

        // Then the older backup overwrites: we don't consider the order here at all
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "older_bkp");
        assert_eq!(but[sessions[2].session_id()], "newer_bkp");
        assert_eq!(but[sessions[3].session_id()], "newer_bkp");
    }

    #[async_test]
    async fn test_not_backed_up_sessions_are_eligible_for_backup() {
        // Given there are 4 sessions, 2 of which are already backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told the last 2 only
        assert_eq!(to_backup, &[sessions[2].clone(), sessions[3].clone()]);
    }

    #[async_test]
    async fn test_all_sessions_are_eligible_for_backup_if_version_is_unknown() {
        // Given there are 4 sessions, 2 of which are already backed up in bkp1
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up in an unknown version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("unknown_bkp", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up all of them
        assert_eq!(
            to_backup,
            &[sessions[0].clone(), sessions[1].clone(), sessions[2].clone(), sessions[3].clone()]
        );
    }

    #[async_test]
    async fn test_sessions_backed_up_to_a_later_version_are_eligible_for_backup() {
        // Given there are 4 sessions, some backed up to three different versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp0", &sessions[..1]).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[2..3]).await;

        // When I ask which to back up in the middle version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up everything not in the version I asked about
        assert_eq!(
            to_backup,
            &[
                sessions[0].clone(), // Backed up in bkp0
                // sessions[1] is backed up in bkp1 already, which we asked about
                sessions[2].clone(), // Backed up in bkp2
                sessions[3].clone(), // Not backed up
            ]
        );
    }

    #[async_test]
    async fn test_outbound_group_session_store() {
        // Given an outbound session
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_outbound_group_sessions(vec![outbound.clone()]);

        // Then we can get it out again
        let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
        assert_eq!(
            serde_json::to_string(&outbound.pickle().await).unwrap(),
            serde_json::to_string(&loaded_session.pickle().await).unwrap()
        );
    }

    #[async_test]
    async fn test_tracked_users_are_stored_once_per_user_id() {
        // Given a store containing 2 tracked users, both dirty
        let user1 = user_id!("@user1:s");
        let user2 = user_id!("@user2:s");
        let user3 = user_id!("@user3:s");
        let store = MemoryStore::new();
        store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();

        // When we mark one as clean and add another
        store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();

        // Then we can get them out again and their dirty flags are correct
        let loaded_tracked_users =
            store.load_tracked_users().await.expect("failed to load tracked users");

        let tracked_contains = |user_id, dirty| {
            loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
        };

        assert!(tracked_contains(user1, true));
        assert!(tracked_contains(user2, false));
        assert!(tracked_contains(user3, false));
        assert_eq!(loaded_tracked_users.len(), 3);
    }

    #[async_test]
    async fn test_private_identity_store() {
        // Given a private identity
        let private_identity = PrivateCrossSigningIdentity::empty(user_id!("@u:s"));

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_private_identity(Some(private_identity.clone()));

        // Then we can get it out again
        let loaded_identity =
            store.load_identity().await.expect("failed to load private identity").unwrap();

        assert_eq!(loaded_identity.user_id(), user_id!("@u:s"));
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = MemoryStore::new();

        store.save_devices(vec![device.clone()]);

        let loaded_device =
            store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.get_user_devices(device.user_id()).await.unwrap();

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.delete_devices(vec![device.clone()]);
        assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
    }

    #[async_test]
    async fn test_message_hash() {
        let store = MemoryStore::new();

        let hash =
            OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };

        let mut changes = Changes::default();
        changes.message_hashes.push(hash.clone());

        assert!(!store.is_message_known(&hash).await.unwrap());
        store.save_changes(changes).await.unwrap();
        assert!(store.is_message_known(&hash).await.unwrap());
    }

    #[async_test]
    async fn test_key_counts_of_empty_store_are_zero() {
        // Given an empty store
        let store = MemoryStore::new();

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("")).await.unwrap();

        // Then the answer is zero
        assert_eq!(key_counts.total, 0);
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_sessions_reports_the_number_of_sessions() {
        // Given a store with sessions
        let room_id = room_id!("!test:localhost");
        let (store, _) = store_with_sessions(4, room_id).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 4);
        // And none are backed up
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_reports_the_number_backed_up_in_this_backup() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(5, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 5);
        // And the backed_up count matches how many were backed up
        assert_eq!(key_counts.backed_up, 2);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_for_null_backup_reports_zero() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys, providing None as the backup version
        let key_counts = store.inbound_group_session_counts(None).await.unwrap();

        // Then we ignore everything and just say zero
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_only_reports_sessions_in_the_version_specified() {
        // Given a store with sessions, backed up in several versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[3..]).await;

        // When we count keys for bkp2
        let key_counts = store.inbound_group_session_counts(Some("bkp2")).await.unwrap();

        // Then the backed_up count reflects how many were backed up in bkp2 only
        assert_eq!(key_counts.backed_up, 1);
    }

    /// Mark the supplied sessions as backed up in the supplied backup version
    async fn mark_backed_up(
        store: &MemoryStore,
        room_id: &RoomId,
        backup_version: &str,
        sessions: &[InboundGroupSession],
    ) {
        let rooms_and_ids: Vec<_> = sessions.iter().map(|s| (room_id, s.session_id())).collect();

        store
            .mark_inbound_group_sessions_as_backed_up(backup_version, &rooms_and_ids)
            .await
            .expect("Failed to mark sessions as backed up");
    }

    // Create a MemoryStore containing the supplied number of sessions.
    //
    // Sessions are returned in alphabetical order of session id.
    async fn store_with_sessions(
        num_sessions: usize,
        room_id: &RoomId,
    ) -> (MemoryStore, Vec<InboundGroupSession>) {
        let (account, _) = get_account_and_session_test_helper();

        let mut sessions = Vec::with_capacity(num_sessions);
        for _ in 0..num_sessions {
            sessions.push(new_session(&account, room_id).await);
        }
        sessions.sort_by_key(|s| s.session_id().to_owned());

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(sessions.clone(), None).await.unwrap();

        (store, sessions)
    }

    // Create a new InboundGroupSession
    async fn new_session(account: &Account, room_id: &RoomId) -> InboundGroupSession {
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap()
    }

    /// Find the session_id and backed_up_to value for each of the sessions in
    /// the store.
    async fn backed_up_tos(store: &MemoryStore) -> HashMap<SessionId, String> {
        store
            .get_inbound_group_sessions_and_backed_up_to()
            .await
            .expect("Unable to get inbound group sessions and backup order")
            .iter()
            .map(|(s, o)| {
                (
                    s.session_id().to_owned(),
                    o.as_ref().map(|v| v.as_str().to_owned()).unwrap_or("".to_owned()),
                )
            })
            .collect()
    }
}

#[cfg(test)]
mod integration_tests {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
    };

    use async_trait::async_trait;
    use matrix_sdk_common::cross_process_lock::CrossProcessLockGeneration;
    use ruma::{
        DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, events::secret::request::SecretName,
    };
    use vodozemac::Curve25519PublicKey;

    use super::MemoryStore;
    use crate::{
        Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, Session, UserIdentityData,
        cryptostore_integration_tests, cryptostore_integration_tests_time,
        olm::{
            InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
            SenderDataType, StaticAccountData,
        },
        store::{
            CryptoStore,
            types::{
                BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
                RoomKeyWithheldEntry, RoomSettings, StoredRoomKeyBundleData, TrackedUser,
            },
        },
    };

    /// Reference-counted handle to a [`MemoryStore`] used during a test. The
    /// `STORES` map in [`get_store`] keeps a clone of the same underlying
    /// store, so the data survives when this handle is dropped.
1308    #[derive(Clone, Debug)]
1309    struct PersistentMemoryStore(Arc<MemoryStore>);
1310
1311    impl PersistentMemoryStore {
1312        fn new() -> Self {
1313            Self(Arc::new(MemoryStore::new()))
1314        }
1315
1316        fn get_static_account(&self) -> Option<StaticAccountData> {
1317            self.0.get_static_account()
1318        }
1319    }
1320
1321    /// Return a clone of the store for the test with the supplied name. Note:
1322    /// dropping this store won't destroy its data, since
1323    /// [PersistentMemoryStore] is a reference-counted smart pointer
1324    /// to an underlying [MemoryStore].
1325    async fn get_store(
1326        name: &str,
1327        _passphrase: Option<&str>,
1328        clear_data: bool,
1329    ) -> PersistentMemoryStore {
1330        // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
1331        // the store, we keep its data alive. This simulates the behaviour of
1332        // the other stores, which keep their data in a real DB, allowing us to
1333        // test MemoryStore using the same code.
1334        static STORES: OnceLock<Mutex<HashMap<String, PersistentMemoryStore>>> = OnceLock::new();
1335        let stores = STORES.get_or_init(|| Mutex::new(HashMap::new()));
1336
1337        let mut stores = stores.lock().unwrap();
1338
1339        if clear_data {
1340            // Create a new PersistentMemoryStore
1341            let new_store = PersistentMemoryStore::new();
1342            stores.insert(name.to_owned(), new_store.clone());
1343            new_store
1344        } else {
1345            stores.entry(name.to_owned()).or_insert_with(PersistentMemoryStore::new).clone()
1346        }
1347    }

    /// Forwards all methods to the underlying [MemoryStore].
    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl CryptoStore for PersistentMemoryStore {
        type Error = <MemoryStore as CryptoStore>::Error;

        async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
            self.0.load_account().await
        }

        async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
            self.0.load_identity().await
        }

        async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error> {
            self.0.save_changes(changes).await
        }

        async fn save_pending_changes(&self, changes: PendingChanges) -> Result<(), Self::Error> {
            self.0.save_pending_changes(changes).await
        }

        async fn save_inbound_group_sessions(
            &self,
            sessions: Vec<InboundGroupSession>,
            backed_up_to_version: Option<&str>,
        ) -> Result<(), Self::Error> {
            self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await
        }

        async fn get_sessions(
            &self,
            sender_key: &str,
        ) -> Result<Option<Vec<Session>>, Self::Error> {
            self.0.get_sessions(sender_key).await
        }

        async fn get_inbound_group_session(
            &self,
            room_id: &RoomId,
            session_id: &str,
        ) -> Result<Option<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_session(room_id, session_id).await
        }

        async fn get_withheld_info(
            &self,
            room_id: &RoomId,
            session_id: &str,
        ) -> Result<Option<RoomKeyWithheldEntry>, Self::Error> {
            self.0.get_withheld_info(room_id, session_id).await
        }

        async fn get_withheld_sessions_by_room_id(
            &self,
            room_id: &RoomId,
        ) -> Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
            self.0.get_withheld_sessions_by_room_id(room_id).await
        }

        async fn get_inbound_group_sessions(
            &self,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_sessions().await
        }

        async fn inbound_group_session_counts(
            &self,
            backup_version: Option<&str>,
        ) -> Result<RoomKeyCounts, Self::Error> {
            self.0.inbound_group_session_counts(backup_version).await
        }

        async fn get_inbound_group_sessions_by_room_id(
            &self,
            room_id: &RoomId,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.get_inbound_group_sessions_by_room_id(room_id).await
        }

        async fn get_inbound_group_sessions_for_device_batch(
            &self,
            sender_key: Curve25519PublicKey,
            sender_data_type: SenderDataType,
            after_session_id: Option<String>,
            limit: usize,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0
                .get_inbound_group_sessions_for_device_batch(
                    sender_key,
                    sender_data_type,
                    after_session_id,
                    limit,
                )
                .await
        }

        async fn inbound_group_sessions_for_backup(
            &self,
            backup_version: &str,
            limit: usize,
        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
            self.0.inbound_group_sessions_for_backup(backup_version, limit).await
        }

        async fn mark_inbound_group_sessions_as_backed_up(
            &self,
            backup_version: &str,
            room_and_session_ids: &[(&RoomId, &str)],
        ) -> Result<(), Self::Error> {
            self.0
                .mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
                .await
        }

        async fn reset_backup_state(&self) -> Result<(), Self::Error> {
            self.0.reset_backup_state().await
        }

        async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
            self.0.load_backup_keys().await
        }

        async fn load_dehydrated_device_pickle_key(
            &self,
        ) -> Result<Option<DehydratedDeviceKey>, Self::Error> {
            self.0.load_dehydrated_device_pickle_key().await
        }

        async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
            self.0.delete_dehydrated_device_pickle_key().await
        }

        async fn get_outbound_group_session(
            &self,
            room_id: &RoomId,
        ) -> Result<Option<OutboundGroupSession>, Self::Error> {
            self.0.get_outbound_group_session(room_id).await
        }

        async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
            self.0.load_tracked_users().await
        }

        async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error> {
            self.0.save_tracked_users(users).await
        }

        async fn get_device(
            &self,
            user_id: &UserId,
            device_id: &DeviceId,
        ) -> Result<Option<DeviceData>, Self::Error> {
            self.0.get_device(user_id, device_id).await
        }

        async fn get_user_devices(
            &self,
            user_id: &UserId,
        ) -> Result<HashMap<OwnedDeviceId, DeviceData>, Self::Error> {
            self.0.get_user_devices(user_id).await
        }

        async fn get_own_device(&self) -> Result<DeviceData, Self::Error> {
            self.0.get_own_device().await
        }

        async fn get_user_identity(
            &self,
            user_id: &UserId,
        ) -> Result<Option<UserIdentityData>, Self::Error> {
            self.0.get_user_identity(user_id).await
        }

        async fn is_message_known(
            &self,
            message_hash: &OlmMessageHash,
        ) -> Result<bool, Self::Error> {
            self.0.is_message_known(message_hash).await
        }

        async fn get_outgoing_secret_requests(
            &self,
            request_id: &TransactionId,
        ) -> Result<Option<GossipRequest>, Self::Error> {
            self.0.get_outgoing_secret_requests(request_id).await
        }

        async fn get_secret_request_by_info(
            &self,
            secret_info: &SecretInfo,
        ) -> Result<Option<GossipRequest>, Self::Error> {
            self.0.get_secret_request_by_info(secret_info).await
        }

        async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
            self.0.get_unsent_secret_requests().await
        }

        async fn delete_outgoing_secret_requests(
            &self,
            request_id: &TransactionId,
        ) -> Result<(), Self::Error> {
            self.0.delete_outgoing_secret_requests(request_id).await
        }

        async fn get_secrets_from_inbox(
            &self,
            secret_name: &SecretName,
        ) -> Result<Vec<GossippedSecret>, Self::Error> {
            self.0.get_secrets_from_inbox(secret_name).await
        }

        async fn delete_secrets_from_inbox(
            &self,
            secret_name: &SecretName,
        ) -> Result<(), Self::Error> {
            self.0.delete_secrets_from_inbox(secret_name).await
        }

        async fn get_room_settings(
            &self,
            room_id: &RoomId,
        ) -> Result<Option<RoomSettings>, Self::Error> {
            self.0.get_room_settings(room_id).await
        }

        async fn get_received_room_key_bundle_data(
            &self,
            room_id: &RoomId,
            user_id: &UserId,
        ) -> crate::store::Result<Option<StoredRoomKeyBundleData>, Self::Error> {
            self.0.get_received_room_key_bundle_data(room_id, user_id).await
        }

        async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
            self.0.get_custom_value(key).await
        }

        async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<(), Self::Error> {
            self.0.set_custom_value(key, value).await
        }

        async fn remove_custom_value(&self, key: &str) -> Result<(), Self::Error> {
            self.0.remove_custom_value(key).await
        }

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
            self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
            self.0.next_batch_token().await
        }

        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
            self.0.get_size().await
        }
    }

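    // Run the shared cryptostore integration test suites against
    // `PersistentMemoryStore`; the second macro adds the tests that depend
    // on wall-clock time.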
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
}