// matrix_sdk_crypto/store/memorystore.rs
1// Copyright 2020, 2026 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, HashMap, HashSet},
17    convert::Infallible,
18    sync::Arc,
19};
20
21use async_trait::async_trait;
22use matrix_sdk_common::{
23    cross_process_lock::{
24        CrossProcessLockGeneration,
25        memory_store_helper::{Lease, try_take_leased_lock},
26    },
27    locks::RwLock as StdRwLock,
28};
29use ruma::{
30    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
31    UserId, events::secret::request::SecretName,
32};
33use tokio::sync::{Mutex, RwLock};
34use tracing::warn;
35use vodozemac::Curve25519PublicKey;
36use zeroize::Zeroizing;
37
38use super::{
39    Account, CryptoStore, InboundGroupSession, Session,
40    caches::DeviceStore,
41    types::{
42        BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
43        StoredRoomKeyBundleData, TrackedUser,
44    },
45};
46use crate::{
47    gossiping::{GossipRequest, SecretInfo},
48    identities::{DeviceData, UserIdentityData},
49    olm::{
50        OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
51        PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
52    },
53    store::types::{RoomKeyWithheldEntry, RoomPendingKeyBundleDetails},
54};
55
56fn encode_key_info(info: &SecretInfo) -> String {
57    match info {
58        SecretInfo::KeyRequest(info) => {
59            format!("{}{}{}", info.room_id(), info.algorithm(), info.session_id())
60        }
61        SecretInfo::SecretRequest(i) => i.as_ref().to_owned(),
62    }
63}
64
/// Convenience alias: inbound group sessions are keyed by their session ID.
type SessionId = String;

/// The "version" of a backup - newtype wrapper around a String.
#[derive(Clone, Debug, PartialEq, Eq)]
struct BackupVersion(String);

// Idiomatic conversion: implementing `From<&str>` (rather than an inherent
// `fn from`) keeps the `BackupVersion::from(s)` call sites working while also
// giving callers `Into<BackupVersion>` for free.
impl From<&str> for BackupVersion {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}

impl BackupVersion {
    /// Borrow the backup version as a plain string slice.
    fn as_str(&self) -> &str {
        &self.0
    }
}
80
/// An in-memory only store that will forget all the E2EE key once it's dropped.
#[derive(Default, Debug)]
pub struct MemoryStore {
    // Static (non-secret) data of the account, cached on load/save.
    static_account: Arc<StdRwLock<Option<StaticAccountData>>>,

    // JSON-serialized pickled account, if one has been saved.
    account: StdRwLock<Option<String>>,
    // Map of sender_key to map of session_id to serialized pickle
    sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
    // Map of room id to map of session_id to serialized pickle.
    inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,

    /// Map room id -> session id -> backup order number
    /// The latest backup in which this session is stored. Equivalent to
    /// `backed_up_to` in [`IndexedDbCryptoStore`]
    inbound_group_sessions_backed_up_to:
        StdRwLock<HashMap<OwnedRoomId, HashMap<SessionId, BackupVersion>>>,

    // One outbound group session per room.
    outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
    private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
    tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
    // Map of sender key to the set of olm message hashes seen from it.
    olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
    devices: DeviceStore,
    // Map of user id to JSON-serialized `UserIdentityData`.
    identities: StdRwLock<HashMap<OwnedUserId, String>>,
    outgoing_key_requests: StdRwLock<HashMap<OwnedTransactionId, GossipRequest>>,
    // Reverse index: encoded key info -> request id, see `encode_key_info`.
    key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
    direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEntry>>>,
    custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
    leases: StdRwLock<HashMap<String, Lease>>,
    // Received secrets, keyed by secret name, zeroized on drop.
    secret_inbox: StdRwLock<HashMap<String, Vec<Zeroizing<String>>>>,
    backup_keys: RwLock<BackupKeys>,
    dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
    next_batch_token: RwLock<Option<String>>,
    room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
    room_key_bundles:
        StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, StoredRoomKeyBundleData>>>,
    room_key_backups_fully_downloaded: StdRwLock<HashSet<OwnedRoomId>>,
    rooms_pending_key_bundle: StdRwLock<HashMap<OwnedRoomId, RoomPendingKeyBundleDetails>>,

    // Serializes concurrent `save_changes`/`save_pending_changes` calls.
    save_changes_lock: Arc<Mutex<()>>,
}
120
121impl MemoryStore {
122    /// Create a new empty `MemoryStore`.
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    fn get_static_account(&self) -> Option<StaticAccountData> {
128        self.static_account.read().clone()
129    }
130
131    pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
132        for device in devices {
133            let _ = self.devices.add(device);
134        }
135    }
136
137    fn delete_devices(&self, devices: Vec<DeviceData>) {
138        for device in devices {
139            let _ = self.devices.remove(device.user_id(), device.device_id());
140        }
141    }
142
143    fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
144        let mut session_store = self.sessions.write();
145
146        for (session_id, pickle) in sessions {
147            let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();
148
149            // insert or replace if exists
150            entry.insert(
151                session_id,
152                serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
153            );
154        }
155    }
156
157    fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
158        self.outbound_group_sessions
159            .write()
160            .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
161    }
162
163    fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
164        *self.private_identity.write() = private_identity;
165    }
166
167    /// Return all the [`InboundGroupSession`]s we have, paired with the
168    /// `backed_up_to` value for each one (or "" where it is missing, which
169    /// should never happen).
170    async fn get_inbound_group_sessions_and_backed_up_to(
171        &self,
172    ) -> Result<Vec<(InboundGroupSession, Option<BackupVersion>)>> {
173        let lookup = |s: &InboundGroupSession| {
174            self.inbound_group_sessions_backed_up_to
175                .read()
176                .get(&s.room_id)?
177                .get(s.session_id())
178                .cloned()
179        };
180
181        Ok(self
182            .get_inbound_group_sessions()
183            .await?
184            .into_iter()
185            .map(|s| {
186                let v = lookup(&s);
187                (s, v)
188            })
189            .collect())
190    }
191}
192
/// The memory store cannot fail, so its operations use [`Infallible`] as the
/// error type.
type Result<T> = std::result::Result<T, Infallible>;
194
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl CryptoStore for MemoryStore {
    type Error = Infallible;

    async fn load_account(&self) -> Result<Option<Account>> {
        // The account is stored as a JSON-serialized pickle.
        let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
            serde_json::from_str(acc)
                .expect("Deserialization failed: invalid pickled account JSON format")
        });

        if let Some(pickle) = pickled_account {
            let account =
                Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");

            // Cache the static data so `get_static_account` works from now on.
            *self.static_account.write() = Some(account.static_data().clone());

            Ok(Some(account))
        } else {
            Ok(None)
        }
    }

    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
        Ok(self.private_identity.read().clone())
    }

    async fn next_batch_token(&self) -> Result<Option<String>> {
        Ok(self.next_batch_token.read().await.clone())
    }

    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
        // Serialize with concurrent `save_changes` calls.
        let _guard = self.save_changes_lock.lock().await;

        let pickled_account = if let Some(account) = changes.account {
            // Keep the static-data cache in sync with the saved account.
            *self.static_account.write() = Some(account.static_data().clone());
            Some(account.pickle())
        } else {
            None
        };

        *self.account.write() = pickled_account.map(|pickle| {
            serde_json::to_string(&pickle)
                .expect("Serialization failed: invalid pickled account JSON format")
        });

        Ok(())
    }

    async fn save_changes(&self, changes: Changes) -> Result<()> {
        // Serialize with concurrent `save_pending_changes` calls.
        let _guard = self.save_changes_lock.lock().await;

        // Pickle the sessions first (async), then store them all at once.
        let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
        for session in changes.sessions {
            let session_id = session.session_id().to_owned();
            let pickle = session.pickle().await;
            pickled_session.push((session_id.clone(), pickle));
        }
        self.save_sessions(pickled_session);

        self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
        self.save_outbound_group_sessions(changes.outbound_group_sessions);
        self.save_private_identity(changes.private_identity);

        self.save_devices(changes.devices.new);
        self.save_devices(changes.devices.changed);
        self.delete_devices(changes.devices.deleted);

        // New and changed identities are stored the same way: serialized to
        // JSON, keyed by user ID.
        {
            let mut identities = self.identities.write();
            for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
                identities.insert(
                    identity.user_id().to_owned(),
                    serde_json::to_string(&identity)
                        .expect("UserIdentityData should always serialize to json"),
                );
            }
        }

        {
            let mut olm_hashes = self.olm_hashes.write();
            for hash in changes.message_hashes {
                olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
            }
        }

        // Keep the forward map and the by-info reverse index consistent.
        {
            let mut outgoing_key_requests = self.outgoing_key_requests.write();
            let mut key_requests_by_info = self.key_requests_by_info.write();

            for key_request in changes.key_requests {
                let id = key_request.request_id.clone();
                let info_string = encode_key_info(&key_request.info);

                outgoing_key_requests.insert(id.clone(), key_request);
                key_requests_by_info.insert(info_string, id);
            }
        }

        if let Some(key) = changes.backup_decryption_key {
            self.backup_keys.write().await.decryption_key = Some(key);
        }

        if let Some(version) = changes.backup_version {
            self.backup_keys.write().await.backup_version = Some(version);
        }

        if let Some(pickle_key) = changes.dehydrated_device_pickle_key {
            let mut lock = self.dehydrated_device_pickle_key.write().await;
            *lock = Some(pickle_key);
        }

        {
            let mut secret_inbox = self.secret_inbox.write();
            for secret in changes.secrets {
                secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret.secret);
            }
        }

        {
            let mut direct_withheld_info = self.direct_withheld_info.write();
            for (room_id, data) in changes.withheld_session_info {
                for (session_id, event) in data {
                    direct_withheld_info
                        .entry(room_id.to_owned())
                        .or_default()
                        .insert(session_id, event);
                }
            }
        }

        if let Some(next_batch_token) = changes.next_batch_token {
            *self.next_batch_token.write().await = Some(next_batch_token);
        }

        if !changes.room_settings.is_empty() {
            let mut settings = self.room_settings.write();
            settings.extend(changes.room_settings);
        }

        if !changes.received_room_key_bundles.is_empty() {
            let mut room_key_bundles = self.room_key_bundles.write();
            for bundle in changes.received_room_key_bundles {
                room_key_bundles
                    .entry(bundle.bundle_data.room_id.clone())
                    .or_default()
                    .insert(bundle.sender_user.clone(), bundle);
            }
        }

        if !changes.room_key_backups_fully_downloaded.is_empty() {
            let mut room_key_backups_fully_downloaded =
                self.room_key_backups_fully_downloaded.write();
            for room in changes.room_key_backups_fully_downloaded {
                room_key_backups_fully_downloaded.insert(room);
            }
        }

        // `Some(details)` inserts/overwrites; `None` clears the entry.
        if !changes.rooms_pending_key_bundle.is_empty() {
            let mut lock = self.rooms_pending_key_bundle.write();
            for (room, details) in changes.rooms_pending_key_bundle {
                if let Some(details) = details {
                    lock.insert(room, details);
                } else {
                    lock.remove(&room);
                }
            }
        }

        Ok(())
    }

    async fn save_inbound_group_sessions(
        &self,
        sessions: Vec<InboundGroupSession>,
        backed_up_to_version: Option<&str>,
    ) -> Result<()> {
        for session in sessions {
            let room_id = session.room_id();
            let session_id = session.session_id();

            // Sanity-check that the data in the sessions corresponds to backed_up_version
            let backed_up = session.backed_up();
            if backed_up != backed_up_to_version.is_some() {
                warn!(
                    backed_up,
                    backed_up_to_version,
                    "Session backed-up flag does not correspond to backup version setting",
                );
            }

            if let Some(backup_version) = backed_up_to_version {
                self.inbound_group_sessions_backed_up_to
                    .write()
                    .entry(room_id.to_owned())
                    .or_default()
                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
            }

            // Store the session itself as a JSON-serialized pickle.
            let pickle = session.pickle().await;
            self.inbound_group_sessions
                .write()
                .entry(session.room_id().to_owned())
                .or_default()
                .insert(
                    session.session_id().to_owned(),
                    serde_json::to_string(&pickle)
                        .expect("Pickle pickle data should serialize to json"),
                );
        }
        Ok(())
    }

    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
        // Unpickling a session needs our own device keys.
        let device_keys = self.get_own_device().await?.as_device_keys().clone();

        if let Some(pickles) = self.sessions.read().get(sender_key) {
            let mut sessions: Vec<Session> = Vec::new();
            for serialized_pickle in pickles.values() {
                let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
                    .expect("Pickle pickle deserialization should work");
                let session = Session::from_pickle(device_keys.clone(), pickle)
                    .expect("Expect from pickle to always work");
                sessions.push(session);
            }
            Ok(Some(sessions))
        } else {
            Ok(None)
        }
    }

    async fn get_inbound_group_session(
        &self,
        room_id: &RoomId,
        session_id: &str,
    ) -> Result<Option<InboundGroupSession>> {
        // NOTE(review): the `and_then` closure deserializes straight into an
        // `Option<PickledInboundGroupSession>` (non-null JSON becomes `Some`);
        // a plain `map` would arguably read more clearly.
        let pickle: Option<PickledInboundGroupSession> = self
            .inbound_group_sessions
            .read()
            .get(room_id)
            .and_then(|m| m.get(session_id))
            .and_then(|ser| {
                serde_json::from_str(ser).expect("Pickle pickle deserialization should work")
            });

        Ok(pickle.map(|p| {
            InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
        }))
    }

    async fn get_withheld_info(
        &self,
        room_id: &RoomId,
        session_id: &str,
    ) -> Result<Option<RoomKeyWithheldEntry>> {
        Ok(self
            .direct_withheld_info
            .read()
            .get(room_id)
            .and_then(|e| Some(e.get(session_id)?.to_owned())))
    }

    async fn get_withheld_sessions_by_room_id(
        &self,
        room_id: &RoomId,
    ) -> crate::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
        Ok(self
            .direct_withheld_info
            .read()
            .get(room_id)
            .map(|e| e.values().cloned().collect())
            .unwrap_or_default())
    }

    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
        // Unpickle every stored session across all rooms.
        let inbounds = self
            .inbound_group_sessions
            .read()
            .values()
            .flat_map(HashMap::values)
            .map(|ser| {
                let pickle: PickledInboundGroupSession =
                    serde_json::from_str(ser).expect("Pickle deserialization should work");
                InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
            })
            .collect();
        Ok(inbounds)
    }

    async fn inbound_group_session_counts(
        &self,
        backup_version: Option<&str>,
    ) -> Result<RoomKeyCounts> {
        let backed_up = if let Some(backup_version) = backup_version {
            self.get_inbound_group_sessions_and_backed_up_to()
                .await?
                .into_iter()
                // Count the sessions backed up in the required backup
                .filter(|(_, o)| o.as_ref().is_some_and(|o| o.as_str() == backup_version))
                .count()
        } else {
            // No backup version was supplied, so no session can be counted as
            // backed up.
            0
        };

        let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
        Ok(RoomKeyCounts { total, backed_up })
    }

    async fn get_inbound_group_sessions_by_room_id(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<InboundGroupSession>> {
        let inbounds = match self.inbound_group_sessions.read().get(room_id) {
            None => Vec::new(),
            Some(v) => v
                .values()
                .map(|ser| {
                    let pickle: PickledInboundGroupSession =
                        serde_json::from_str(ser).expect("Pickle deserialization should work");
                    InboundGroupSession::from_pickle(pickle)
                        .expect("Expect from pickle to always work")
                })
                .collect(),
        };
        Ok(inbounds)
    }

    async fn get_inbound_group_sessions_for_device_batch(
        &self,
        sender_key: Curve25519PublicKey,
        sender_data_type: SenderDataType,
        after_session_id: Option<String>,
        limit: usize,
    ) -> Result<Vec<InboundGroupSession>> {
        // First, find all InboundGroupSessions, filtering for those that match the
        // device and sender_data type.
        let mut sessions: Vec<_> = self
            .get_inbound_group_sessions()
            .await?
            .into_iter()
            .filter(|session: &InboundGroupSession| {
                session.creator_info.curve25519_key == sender_key
                    && session.sender_data.to_type() == sender_data_type
            })
            .collect();

        // Then, sort the sessions in order of ascending session ID...
        sessions.sort_by_key(|s| s.session_id().to_owned());

        // Figure out where in the array to start returning results from
        let start_index = {
            match after_session_id {
                None => 0,
                Some(id) => {
                    // We're looking for the first session with a session ID strictly after `id`; if
                    // there are none, the end of the array.
                    sessions
                        .iter()
                        .position(|session| session.session_id() > id.as_str())
                        .unwrap_or(sessions.len())
                }
            }
        };

        // Return up to `limit` items from the array, starting from `start_index`
        Ok(sessions.drain(start_index..).take(limit).collect())
    }

    async fn inbound_group_sessions_for_backup(
        &self,
        backup_version: &str,
        limit: usize,
    ) -> Result<Vec<InboundGroupSession>> {
        Ok(self
            .get_inbound_group_sessions_and_backed_up_to()
            .await?
            .into_iter()
            .filter_map(|(session, backed_up_to)| {
                if let Some(ref existing_version) = backed_up_to
                    && existing_version.as_str() == backup_version
                {
                    // This session is already backed up in the required backup
                    None
                } else {
                    // It's not backed up, or it's backed up in a different backup
                    Some(session)
                }
            })
            .take(limit)
            .collect())
    }

    async fn mark_inbound_group_sessions_as_backed_up(
        &self,
        backup_version: &str,
        room_and_session_ids: &[(&RoomId, &str)],
    ) -> Result<()> {
        for &(room_id, session_id) in room_and_session_ids {
            let session = self.get_inbound_group_session(room_id, session_id).await?;

            if let Some(session) = session {
                session.mark_as_backed_up();

                // Record which backup version this session is now stored in.
                self.inbound_group_sessions_backed_up_to
                    .write()
                    .entry(room_id.to_owned())
                    .or_default()
                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));

                // Save it back
                let updated_pickle = session.pickle().await;

                self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
                    session_id.to_owned(),
                    serde_json::to_string(&updated_pickle)
                        .expect("Pickle serialization should work"),
                );
            }
        }

        Ok(())
    }

    async fn reset_backup_state(&self) -> Result<()> {
        // Nothing to do here, because we remember which backup versions we backed up to
        // in `mark_inbound_group_sessions_as_backed_up`, so we don't need to
        // reset anything here because the required version is passed in to
        // `inbound_group_sessions_for_backup`, and we can compare against the
        // version we stored.

        Ok(())
    }

    async fn load_backup_keys(&self) -> Result<BackupKeys> {
        Ok(self.backup_keys.read().await.to_owned())
    }

    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
        Ok(self.dehydrated_device_pickle_key.read().await.to_owned())
    }

    async fn delete_dehydrated_device_pickle_key(&self) -> Result<()> {
        let mut lock = self.dehydrated_device_pickle_key.write().await;
        *lock = None;
        Ok(())
    }

    async fn get_outbound_group_session(
        &self,
        room_id: &RoomId,
    ) -> Result<Option<OutboundGroupSession>> {
        Ok(self.outbound_group_sessions.read().get(room_id).cloned())
    }

    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
        Ok(self.tracked_users.read().values().cloned().collect())
    }

    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
        self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
            let user_id: OwnedUserId = user_id.to_owned().into();
            (user_id.clone(), TrackedUser { user_id, dirty: *dirty })
        }));
        Ok(())
    }

    async fn get_device(
        &self,
        user_id: &UserId,
        device_id: &DeviceId,
    ) -> Result<Option<DeviceData>> {
        Ok(self.devices.get(user_id, device_id))
    }

    async fn get_user_devices(
        &self,
        user_id: &UserId,
    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
        Ok(self.devices.user_devices(user_id))
    }

    async fn get_own_device(&self) -> Result<DeviceData> {
        let account =
            self.get_static_account().expect("Expect account to exist when getting own device");

        Ok(self
            .devices
            .get(&account.user_id, &account.device_id)
            .expect("Invalid state: Should always have a own device"))
    }

    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
        // Clone the serialized form first so the lock isn't held while
        // deserializing.
        let serialized = self.identities.read().get(user_id).cloned();
        match serialized {
            None => Ok(None),
            Some(serialized) => {
                let id: UserIdentityData = serde_json::from_str(serialized.as_str())
                    .expect("Only valid serialized identity are saved");
                Ok(Some(id))
            }
        }
    }

    async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
        // NOTE(review): this takes the write lock and inserts an empty set for
        // previously-unseen sender keys, even though it is only a query.
        Ok(self
            .olm_hashes
            .write()
            .entry(message_hash.sender_key.to_owned())
            .or_default()
            .contains(&message_hash.hash))
    }

    async fn get_outgoing_secret_requests(
        &self,
        request_id: &TransactionId,
    ) -> Result<Option<GossipRequest>> {
        Ok(self.outgoing_key_requests.read().get(request_id).cloned())
    }

    async fn get_secret_request_by_info(
        &self,
        key_info: &SecretInfo,
    ) -> Result<Option<GossipRequest>> {
        // Resolve via the reverse index built in `save_changes`.
        let key_info_string = encode_key_info(key_info);

        Ok(self
            .key_requests_by_info
            .read()
            .get(&key_info_string)
            .and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
    }

    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
        Ok(self
            .outgoing_key_requests
            .read()
            .values()
            .filter(|req| !req.sent_out)
            .cloned()
            .collect())
    }

    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
        // Remove from the forward map, then clean up the reverse index.
        let req = self.outgoing_key_requests.write().remove(request_id);
        if let Some(i) = req {
            let key_info_string = encode_key_info(&i.info);
            self.key_requests_by_info.write().remove(&key_info_string);
        }

        Ok(())
    }

    async fn get_secrets_from_inbox(
        &self,
        secret_name: &SecretName,
    ) -> Result<Vec<Zeroizing<String>>> {
        // NOTE(review): takes the write lock and inserts an empty entry for
        // unknown secret names, even though this is a read.
        Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
    }

    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
        self.secret_inbox.write().remove(secret_name.as_str());

        Ok(())
    }

    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
        Ok(self.room_settings.read().get(room_id).cloned())
    }

    async fn get_received_room_key_bundle_data(
        &self,
        room_id: &RoomId,
        user_id: &UserId,
    ) -> Result<Option<StoredRoomKeyBundleData>> {
        let guard = self.room_key_bundles.read();

        let result = guard.get(room_id).and_then(|bundles| bundles.get(user_id).cloned());

        Ok(result)
    }

    async fn get_pending_key_bundle_details_for_room(
        &self,
        room_id: &RoomId,
    ) -> Result<Option<RoomPendingKeyBundleDetails>> {
        Ok(self.rooms_pending_key_bundle.read().get(room_id).cloned())
    }

    async fn get_all_rooms_pending_key_bundles(&self) -> Result<Vec<RoomPendingKeyBundleDetails>> {
        Ok(self.rooms_pending_key_bundle.read().values().cloned().collect())
    }

    async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
        let guard = self.room_key_backups_fully_downloaded.read();
        Ok(guard.contains(room_id))
    }

    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
        Ok(self.custom_values.read().get(key).cloned())
    }

    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
        self.custom_values.write().insert(key.to_owned(), value);
        Ok(())
    }

    async fn remove_custom_value(&self, key: &str) -> Result<()> {
        self.custom_values.write().remove(key);
        Ok(())
    }

    async fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<Option<CrossProcessLockGeneration>> {
        Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
    }

    async fn get_size(&self) -> Result<Option<usize>> {
        // An in-memory store has no meaningful on-disk size.
        Ok(None)
    }
}
822
823#[cfg(test)]
824mod tests {
825    use std::collections::HashMap;
826
827    use matrix_sdk_test::async_test;
828    use ruma::{RoomId, room_id, user_id};
829    use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};
830
831    use super::SessionId;
832    use crate::{
833        DeviceData,
834        identities::device::testing::get_device,
835        olm::{
836            Account, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
837            tests::get_account_and_session_test_helper,
838        },
839        store::{
840            CryptoStore,
841            memorystore::MemoryStore,
842            types::{Changes, DeviceChanges, PendingChanges},
843        },
844    };
845
846    #[async_test]
847    async fn test_session_store() {
848        let (account, session) = get_account_and_session_test_helper();
849        let own_device = DeviceData::from_account(&account);
850        let store = MemoryStore::new();
851
852        assert!(store.load_account().await.unwrap().is_none());
853
854        store
855            .save_changes(Changes {
856                devices: DeviceChanges { new: vec![own_device], ..Default::default() },
857                ..Default::default()
858            })
859            .await
860            .unwrap();
861        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
862
863        store
864            .save_changes(Changes { sessions: (vec![session.clone()]), ..Default::default() })
865            .await
866            .unwrap();
867
868        let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();
869
870        let loaded_session = &sessions[0];
871
872        assert_eq!(&session, loaded_session);
873    }
874
875    #[async_test]
876    async fn test_inbound_group_session_store() {
877        let (account, _) = get_account_and_session_test_helper();
878        let room_id = room_id!("!test:localhost");
879        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
880
881        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
882        let inbound = InboundGroupSession::new(
883            Curve25519PublicKey::from_base64(curve_key).unwrap(),
884            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
885            room_id,
886            &outbound.session_key().await,
887            SenderData::unknown(),
888            None,
889            outbound.settings().algorithm.to_owned(),
890            None,
891            false,
892        )
893        .unwrap();
894
895        let store = MemoryStore::new();
896        store.save_inbound_group_sessions(vec![inbound.clone()], None).await.unwrap();
897
898        let loaded_session =
899            store.get_inbound_group_session(room_id, outbound.session_id()).await.unwrap().unwrap();
900        assert_eq!(inbound, loaded_session);
901    }
902
903    #[async_test]
904    async fn test_backing_up_marks_sessions_as_backed_up() {
905        // Given there are 2 sessions
906        let room_id = room_id!("!test:localhost");
907        let (store, sessions) = store_with_sessions(2, room_id).await;
908
909        // When I mark them as backed up
910        mark_backed_up(&store, room_id, "bkp1", &sessions).await;
911
912        // Then their backed_up_to field is set
913        let but = backed_up_tos(&store).await;
914        assert_eq!(but[sessions[0].session_id()], "bkp1");
915        assert_eq!(but[sessions[1].session_id()], "bkp1");
916    }
917
918    #[async_test]
919    async fn test_backing_up_a_second_set_of_sessions_updates_their_backup_order() {
920        // Given there are 3 sessions
921        let room_id = room_id!("!test:localhost");
922        let (store, sessions) = store_with_sessions(3, room_id).await;
923
924        // When I mark 0 and 1 as backed up in bkp1
925        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
926
927        // And 1 and 2 as backed up in bkp2
928        mark_backed_up(&store, room_id, "bkp2", &sessions[1..]).await;
929
930        // Then 0 is backed up in bkp1 and the 1 and 2 are backed up in bkp2
931        let but = backed_up_tos(&store).await;
932        assert_eq!(but[sessions[0].session_id()], "bkp1");
933        assert_eq!(but[sessions[1].session_id()], "bkp2");
934        assert_eq!(but[sessions[2].session_id()], "bkp2");
935    }
936
937    #[async_test]
938    async fn test_backing_up_again_to_the_same_version_has_no_effect() {
939        // Given there are 3 sessions
940        let room_id = room_id!("!test:localhost");
941        let (store, sessions) = store_with_sessions(3, room_id).await;
942
943        // When I mark the first two as backed up in the first backup
944        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
945
946        // And the last 2 as backed up in the same backup version
947        mark_backed_up(&store, room_id, "bkp1", &sessions[1..]).await;
948
949        // Then they all get the same backed_up_to value
950        let but = backed_up_tos(&store).await;
951        assert_eq!(but[sessions[0].session_id()], "bkp1");
952        assert_eq!(but[sessions[1].session_id()], "bkp1");
953        assert_eq!(but[sessions[2].session_id()], "bkp1");
954    }
955
956    #[async_test]
957    async fn test_backing_up_to_an_old_backup_version_can_increase_backed_up_to() {
958        // Given we have backed up some sessions to 2 backup versions, an older and a
959        // newer
960        let room_id = room_id!("!test:localhost");
961        let (store, sessions) = store_with_sessions(4, room_id).await;
962        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
963        mark_backed_up(&store, room_id, "newer_bkp", &sessions[1..2]).await;
964
965        // When I ask to back up the un-backed-up ones to the older backup
966        mark_backed_up(&store, room_id, "older_bkp", &sessions[2..]).await;
967
968        // Then each session lists the backup it was most recently included in
969        let but = backed_up_tos(&store).await;
970        assert_eq!(but[sessions[0].session_id()], "older_bkp");
971        assert_eq!(but[sessions[1].session_id()], "newer_bkp");
972        assert_eq!(but[sessions[2].session_id()], "older_bkp");
973        assert_eq!(but[sessions[3].session_id()], "older_bkp");
974    }
975
976    #[async_test]
977    async fn test_backing_up_to_an_old_backup_version_overwrites_a_newer_one() {
978        // Given we have backed up to 2 backup versions, an older and a newer
979        let room_id = room_id!("!test:localhost");
980        let (store, sessions) = store_with_sessions(4, room_id).await;
981        mark_backed_up(&store, room_id, "older_bkp", &sessions).await;
982        // Sanity: they are backed up in order number 1
983        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "older_bkp");
984        mark_backed_up(&store, room_id, "newer_bkp", &sessions).await;
985        // Sanity: they are backed up in order number 2
986        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "newer_bkp");
987
988        // When I ask to back up some to the older version
989        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
990
991        // Then older backup overwrites: we don't consider the order here at all
992        let but = backed_up_tos(&store).await;
993        assert_eq!(but[sessions[0].session_id()], "older_bkp");
994        assert_eq!(but[sessions[1].session_id()], "older_bkp");
995        assert_eq!(but[sessions[2].session_id()], "newer_bkp");
996        assert_eq!(but[sessions[3].session_id()], "newer_bkp");
997    }
998
999    #[async_test]
1000    async fn test_not_backed_up_sessions_are_eligible_for_backup() {
1001        // Given there are 4 sessions, 2 of which are already backed up
1002        let room_id = room_id!("!test:localhost");
1003        let (store, sessions) = store_with_sessions(4, room_id).await;
1004        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
1005
1006        // When I ask which to back up
1007        let mut to_backup = store
1008            .inbound_group_sessions_for_backup("bkp1", 10)
1009            .await
1010            .expect("Failed to ask for sessions to backup");
1011        to_backup.sort_by_key(|s| s.session_id().to_owned());
1012
1013        // Then I am told the last 2 only
1014        assert_eq!(to_backup, &[sessions[2].clone(), sessions[3].clone()]);
1015    }
1016
1017    #[async_test]
1018    async fn test_all_sessions_are_eligible_for_backup_if_version_is_unknown() {
1019        // Given there are 4 sessions, 2 of which are already backed up in bkp1
1020        let room_id = room_id!("!test:localhost");
1021        let (store, sessions) = store_with_sessions(4, room_id).await;
1022        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
1023
1024        // When I ask which to back up in an unknown version
1025        let mut to_backup = store
1026            .inbound_group_sessions_for_backup("unknown_bkp", 10)
1027            .await
1028            .expect("Failed to ask for sessions to backup");
1029        to_backup.sort_by_key(|s| s.session_id().to_owned());
1030
1031        // Then I am told to back up all of them
1032        assert_eq!(
1033            to_backup,
1034            &[sessions[0].clone(), sessions[1].clone(), sessions[2].clone(), sessions[3].clone()]
1035        );
1036    }
1037
1038    #[async_test]
1039    async fn test_sessions_backed_up_to_a_later_version_are_eligible_for_backup() {
1040        // Given there are 4 sessions, some backed up to three different versions
1041        let room_id = room_id!("!test:localhost");
1042        let (store, sessions) = store_with_sessions(4, room_id).await;
1043        mark_backed_up(&store, room_id, "bkp0", &sessions[..1]).await;
1044        mark_backed_up(&store, room_id, "bkp1", &sessions[1..2]).await;
1045        mark_backed_up(&store, room_id, "bkp2", &sessions[2..3]).await;
1046
1047        // When I ask which to back up in the middle version
1048        let mut to_backup = store
1049            .inbound_group_sessions_for_backup("bkp1", 10)
1050            .await
1051            .expect("Failed to ask for sessions to backup");
1052        to_backup.sort_by_key(|s| s.session_id().to_owned());
1053
1054        // Then I am told to back up everything not in the version I asked about
1055        assert_eq!(
1056            to_backup,
1057            &[
1058                sessions[0].clone(), // Backed up in bkp0
1059                // sessions[1] is backed up in bkp1 already, which we asked about
1060                sessions[2].clone(), // Backed up in bkp2
1061                sessions[3].clone(), // Not backed up
1062            ]
1063        );
1064    }
1065
1066    #[async_test]
1067    async fn test_outbound_group_session_store() {
1068        // Given an outbound session
1069        let (account, _) = get_account_and_session_test_helper();
1070        let room_id = room_id!("!test:localhost");
1071        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
1072
1073        // When we save it to the store
1074        let store = MemoryStore::new();
1075        store.save_outbound_group_sessions(vec![outbound.clone()]);
1076
1077        // Then we can get it out again
1078        let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
1079        assert_eq!(
1080            serde_json::to_string(&outbound.pickle().await).unwrap(),
1081            serde_json::to_string(&loaded_session.pickle().await).unwrap()
1082        );
1083    }
1084
1085    #[async_test]
1086    async fn test_tracked_users_are_stored_once_per_user_id() {
1087        // Given a store containing 2 tracked users, both dirty
1088        let user1 = user_id!("@user1:s");
1089        let user2 = user_id!("@user2:s");
1090        let user3 = user_id!("@user3:s");
1091        let store = MemoryStore::new();
1092        store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();
1093
1094        // When we mark one as clean and add another
1095        store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();
1096
1097        // Then we can get them out again and their dirty flags are correct
1098        let loaded_tracked_users =
1099            store.load_tracked_users().await.expect("failed to load tracked users");
1100
1101        let tracked_contains = |user_id, dirty| {
1102            loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
1103        };
1104
1105        assert!(tracked_contains(user1, true));
1106        assert!(tracked_contains(user2, false));
1107        assert!(tracked_contains(user3, false));
1108        assert_eq!(loaded_tracked_users.len(), 3);
1109    }
1110
1111    #[async_test]
1112    async fn test_private_identity_store() {
1113        // Given a private identity
1114        let private_identity = PrivateCrossSigningIdentity::empty(user_id!("@u:s"));
1115
1116        // When we save it to the store
1117        let store = MemoryStore::new();
1118        store.save_private_identity(Some(private_identity.clone()));
1119
1120        // Then we can get it out again
1121        let loaded_identity =
1122            store.load_identity().await.expect("failed to load private identity").unwrap();
1123
1124        assert_eq!(loaded_identity.user_id(), user_id!("@u:s"));
1125    }
1126
1127    #[async_test]
1128    async fn test_device_store() {
1129        let device = get_device();
1130        let store = MemoryStore::new();
1131
1132        store.save_devices(vec![device.clone()]);
1133
1134        let loaded_device =
1135            store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
1136
1137        assert_eq!(device, loaded_device);
1138
1139        let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
1140
1141        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
1142        assert_eq!(user_devices.values().next().unwrap(), &device);
1143
1144        let loaded_device = user_devices.get(device.device_id()).unwrap();
1145
1146        assert_eq!(&device, loaded_device);
1147
1148        store.delete_devices(vec![device.clone()]);
1149        assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
1150    }
1151
1152    #[async_test]
1153    async fn test_message_hash() {
1154        let store = MemoryStore::new();
1155
1156        let hash =
1157            OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
1158
1159        let mut changes = Changes::default();
1160        changes.message_hashes.push(hash.clone());
1161
1162        assert!(!store.is_message_known(&hash).await.unwrap());
1163        store.save_changes(changes).await.unwrap();
1164        assert!(store.is_message_known(&hash).await.unwrap());
1165    }
1166
1167    #[async_test]
1168    async fn test_key_counts_of_empty_store_are_zero() {
1169        // Given an empty store
1170        let store = MemoryStore::new();
1171
1172        // When we count keys
1173        let key_counts = store.inbound_group_session_counts(Some("")).await.unwrap();
1174
1175        // Then the answer is zero
1176        assert_eq!(key_counts.total, 0);
1177        assert_eq!(key_counts.backed_up, 0);
1178    }
1179
1180    #[async_test]
1181    async fn test_counting_sessions_reports_the_number_of_sessions() {
1182        // Given a store with sessions
1183        let room_id = room_id!("!test:localhost");
1184        let (store, _) = store_with_sessions(4, room_id).await;
1185
1186        // When we count keys
1187        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();
1188
1189        // Then the answer equals the number of sessions we created
1190        assert_eq!(key_counts.total, 4);
1191        // And none are backed up
1192        assert_eq!(key_counts.backed_up, 0);
1193    }
1194
1195    #[async_test]
1196    async fn test_counting_backed_up_sessions_reports_the_number_backed_up_in_this_backup() {
1197        // Given a store with sessions, some backed up
1198        let room_id = room_id!("!test:localhost");
1199        let (store, sessions) = store_with_sessions(5, room_id).await;
1200        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;
1201
1202        // When we count keys
1203        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();
1204
1205        // Then the answer equals the number of sessions we created
1206        assert_eq!(key_counts.total, 5);
1207        // And the backed_up count matches how many were backed up
1208        assert_eq!(key_counts.backed_up, 2);
1209    }
1210
1211    #[async_test]
1212    async fn test_counting_backed_up_sessions_for_null_backup_reports_zero() {
1213        // Given a store with sessions, some backed up
1214        let room_id = room_id!("!test:localhost");
1215        let (store, sessions) = store_with_sessions(4, room_id).await;
1216        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;
1217
1218        // When we count keys, providing None as the backup version
1219        let key_counts = store.inbound_group_session_counts(None).await.unwrap();
1220
1221        // Then we ignore everything and just say zero
1222        assert_eq!(key_counts.backed_up, 0);
1223    }
1224
1225    #[async_test]
1226    async fn test_counting_backed_up_sessions_only_reports_sessions_in_the_version_specified() {
1227        // Given a store with sessions, backed up in several versions
1228        let room_id = room_id!("!test:localhost");
1229        let (store, sessions) = store_with_sessions(4, room_id).await;
1230        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
1231        mark_backed_up(&store, room_id, "bkp2", &sessions[3..]).await;
1232
1233        // When we count keys for bkp2
1234        let key_counts = store.inbound_group_session_counts(Some("bkp2")).await.unwrap();
1235
1236        // Then the backed_up count reflects how many were backed up in bkp2 only
1237        assert_eq!(key_counts.backed_up, 1);
1238    }
1239
1240    /// Mark the supplied sessions as backed up in the supplied backup version
1241    async fn mark_backed_up(
1242        store: &MemoryStore,
1243        room_id: &RoomId,
1244        backup_version: &str,
1245        sessions: &[InboundGroupSession],
1246    ) {
1247        let rooms_and_ids: Vec<_> = sessions.iter().map(|s| (room_id, s.session_id())).collect();
1248
1249        store
1250            .mark_inbound_group_sessions_as_backed_up(backup_version, &rooms_and_ids)
1251            .await
1252            .expect("Failed to mark sessions as backed up");
1253    }
1254
1255    // Create a MemoryStore containing the supplied number of sessions.
1256    //
1257    // Sessions are returned in alphabetical order of session id.
1258    async fn store_with_sessions(
1259        num_sessions: usize,
1260        room_id: &RoomId,
1261    ) -> (MemoryStore, Vec<InboundGroupSession>) {
1262        let (account, _) = get_account_and_session_test_helper();
1263
1264        let mut sessions = Vec::with_capacity(num_sessions);
1265        for _ in 0..num_sessions {
1266            sessions.push(new_session(&account, room_id).await);
1267        }
1268        sessions.sort_by_key(|s| s.session_id().to_owned());
1269
1270        let store = MemoryStore::new();
1271        store.save_inbound_group_sessions(sessions.clone(), None).await.unwrap();
1272
1273        (store, sessions)
1274    }
1275
1276    // Create a new InboundGroupSession
1277    async fn new_session(account: &Account, room_id: &RoomId) -> InboundGroupSession {
1278        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
1279        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
1280
1281        InboundGroupSession::new(
1282            Curve25519PublicKey::from_base64(curve_key).unwrap(),
1283            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
1284            room_id,
1285            &outbound.session_key().await,
1286            SenderData::unknown(),
1287            None,
1288            outbound.settings().algorithm.to_owned(),
1289            None,
1290            false,
1291        )
1292        .unwrap()
1293    }
1294
1295    /// Find the session_id and backed_up_to value for each of the sessions in
1296    /// the store.
1297    async fn backed_up_tos(store: &MemoryStore) -> HashMap<SessionId, String> {
1298        store
1299            .get_inbound_group_sessions_and_backed_up_to()
1300            .await
1301            .expect("Unable to get inbound group sessions and backup order")
1302            .iter()
1303            .map(|(s, o)| {
1304                (
1305                    s.session_id().to_owned(),
1306                    o.as_ref().map(|v| v.as_str().to_owned()).unwrap_or("".to_owned()),
1307                )
1308            })
1309            .collect()
1310    }
1311}
1312
1313#[cfg(test)]
1314mod integration_tests {
1315    use std::{
1316        collections::HashMap,
1317        sync::{Arc, Mutex, OnceLock},
1318    };
1319
1320    use async_trait::async_trait;
1321    use matrix_sdk_common::cross_process_lock::CrossProcessLockGeneration;
1322    use ruma::{
1323        DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, events::secret::request::SecretName,
1324    };
1325    use vodozemac::Curve25519PublicKey;
1326    use zeroize::Zeroizing;
1327
1328    use super::MemoryStore;
1329    use crate::{
1330        Account, DeviceData, GossipRequest, SecretInfo, Session, UserIdentityData,
1331        cryptostore_integration_tests, cryptostore_integration_tests_time,
1332        olm::{
1333            InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
1334            SenderDataType, StaticAccountData,
1335        },
1336        store::{
1337            CryptoStore,
1338            types::{
1339                BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
1340                RoomKeyWithheldEntry, RoomPendingKeyBundleDetails, RoomSettings,
1341                StoredRoomKeyBundleData, TrackedUser,
1342            },
1343        },
1344    };
1345
    /// Holds on to a MemoryStore during a test, and moves it back into STORES
    /// when this is dropped
    ///
    /// Cloning is cheap: all clones share the same underlying [`MemoryStore`]
    /// through the inner [`Arc`], so data written via one handle is visible
    /// through every other handle.
    #[derive(Clone, Debug)]
    struct PersistentMemoryStore(Arc<MemoryStore>);
1350
    impl PersistentMemoryStore {
        /// Wrap a brand-new, empty [`MemoryStore`].
        fn new() -> Self {
            Self(Arc::new(MemoryStore::new()))
        }

        /// Forward to the wrapped store's static account data accessor.
        fn get_static_account(&self) -> Option<StaticAccountData> {
            self.0.get_static_account()
        }
    }
1360
1361    /// Return a clone of the store for the test with the supplied name. Note:
1362    /// dropping this store won't destroy its data, since
1363    /// [PersistentMemoryStore] is a reference-counted smart pointer
1364    /// to an underlying [MemoryStore].
1365    async fn get_store(
1366        name: &str,
1367        _passphrase: Option<&str>,
1368        clear_data: bool,
1369    ) -> PersistentMemoryStore {
1370        // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
1371        // the store, we keep its data alive. This simulates the behaviour of
1372        // the other stores, which keep their data in a real DB, allowing us to
1373        // test MemoryStore using the same code.
1374        static STORES: OnceLock<Mutex<HashMap<String, PersistentMemoryStore>>> = OnceLock::new();
1375        let stores = STORES.get_or_init(|| Mutex::new(HashMap::new()));
1376
1377        let mut stores = stores.lock().unwrap();
1378
1379        if clear_data {
1380            // Create a new PersistentMemoryStore
1381            let new_store = PersistentMemoryStore::new();
1382            stores.insert(name.to_owned(), new_store.clone());
1383            new_store
1384        } else {
1385            stores.entry(name.to_owned()).or_insert_with(PersistentMemoryStore::new).clone()
1386        }
1387    }
1388
1389    /// Forwards all methods to the underlying [MemoryStore].
1390    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
1391    #[cfg_attr(not(target_family = "wasm"), async_trait)]
1392    impl CryptoStore for PersistentMemoryStore {
1393        type Error = <MemoryStore as CryptoStore>::Error;
1394
1395        async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
1396            self.0.load_account().await
1397        }
1398
1399        async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
1400            self.0.load_identity().await
1401        }
1402
1403        async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error> {
1404            self.0.save_changes(changes).await
1405        }
1406
1407        async fn save_pending_changes(&self, changes: PendingChanges) -> Result<(), Self::Error> {
1408            self.0.save_pending_changes(changes).await
1409        }
1410
1411        async fn save_inbound_group_sessions(
1412            &self,
1413            sessions: Vec<InboundGroupSession>,
1414            backed_up_to_version: Option<&str>,
1415        ) -> Result<(), Self::Error> {
1416            self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await
1417        }
1418
1419        async fn get_sessions(
1420            &self,
1421            sender_key: &str,
1422        ) -> Result<Option<Vec<Session>>, Self::Error> {
1423            self.0.get_sessions(sender_key).await
1424        }
1425
1426        async fn get_inbound_group_session(
1427            &self,
1428            room_id: &RoomId,
1429            session_id: &str,
1430        ) -> Result<Option<InboundGroupSession>, Self::Error> {
1431            self.0.get_inbound_group_session(room_id, session_id).await
1432        }
1433
1434        async fn get_withheld_info(
1435            &self,
1436            room_id: &RoomId,
1437            session_id: &str,
1438        ) -> Result<Option<RoomKeyWithheldEntry>, Self::Error> {
1439            self.0.get_withheld_info(room_id, session_id).await
1440        }
1441
1442        async fn get_withheld_sessions_by_room_id(
1443            &self,
1444            room_id: &RoomId,
1445        ) -> Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
1446            self.0.get_withheld_sessions_by_room_id(room_id).await
1447        }
1448
1449        async fn get_inbound_group_sessions(
1450            &self,
1451        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1452            self.0.get_inbound_group_sessions().await
1453        }
1454
1455        async fn inbound_group_session_counts(
1456            &self,
1457            backup_version: Option<&str>,
1458        ) -> Result<RoomKeyCounts, Self::Error> {
1459            self.0.inbound_group_session_counts(backup_version).await
1460        }
1461
1462        async fn get_inbound_group_sessions_by_room_id(
1463            &self,
1464            room_id: &RoomId,
1465        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1466            self.0.get_inbound_group_sessions_by_room_id(room_id).await
1467        }
1468
1469        async fn get_inbound_group_sessions_for_device_batch(
1470            &self,
1471            sender_key: Curve25519PublicKey,
1472            sender_data_type: SenderDataType,
1473            after_session_id: Option<String>,
1474            limit: usize,
1475        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1476            self.0
1477                .get_inbound_group_sessions_for_device_batch(
1478                    sender_key,
1479                    sender_data_type,
1480                    after_session_id,
1481                    limit,
1482                )
1483                .await
1484        }
1485
1486        async fn inbound_group_sessions_for_backup(
1487            &self,
1488            backup_version: &str,
1489            limit: usize,
1490        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1491            self.0.inbound_group_sessions_for_backup(backup_version, limit).await
1492        }
1493
1494        async fn mark_inbound_group_sessions_as_backed_up(
1495            &self,
1496            backup_version: &str,
1497            room_and_session_ids: &[(&RoomId, &str)],
1498        ) -> Result<(), Self::Error> {
1499            self.0
1500                .mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
1501                .await
1502        }
1503
1504        async fn reset_backup_state(&self) -> Result<(), Self::Error> {
1505            self.0.reset_backup_state().await
1506        }
1507
1508        async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
1509            self.0.load_backup_keys().await
1510        }
1511
1512        async fn load_dehydrated_device_pickle_key(
1513            &self,
1514        ) -> Result<Option<DehydratedDeviceKey>, Self::Error> {
1515            self.0.load_dehydrated_device_pickle_key().await
1516        }
1517
1518        async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1519            self.0.delete_dehydrated_device_pickle_key().await
1520        }
1521
1522        async fn get_outbound_group_session(
1523            &self,
1524            room_id: &RoomId,
1525        ) -> Result<Option<OutboundGroupSession>, Self::Error> {
1526            self.0.get_outbound_group_session(room_id).await
1527        }
1528
1529        async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
1530            self.0.load_tracked_users().await
1531        }
1532
1533        async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error> {
1534            self.0.save_tracked_users(users).await
1535        }
1536
1537        async fn get_device(
1538            &self,
1539            user_id: &UserId,
1540            device_id: &DeviceId,
1541        ) -> Result<Option<DeviceData>, Self::Error> {
1542            self.0.get_device(user_id, device_id).await
1543        }
1544
1545        async fn get_user_devices(
1546            &self,
1547            user_id: &UserId,
1548        ) -> Result<HashMap<OwnedDeviceId, DeviceData>, Self::Error> {
1549            self.0.get_user_devices(user_id).await
1550        }
1551
1552        async fn get_own_device(&self) -> Result<DeviceData, Self::Error> {
1553            self.0.get_own_device().await
1554        }
1555
1556        async fn get_user_identity(
1557            &self,
1558            user_id: &UserId,
1559        ) -> Result<Option<UserIdentityData>, Self::Error> {
1560            self.0.get_user_identity(user_id).await
1561        }
1562
1563        async fn is_message_known(
1564            &self,
1565            message_hash: &OlmMessageHash,
1566        ) -> Result<bool, Self::Error> {
1567            self.0.is_message_known(message_hash).await
1568        }
1569
1570        async fn get_outgoing_secret_requests(
1571            &self,
1572            request_id: &TransactionId,
1573        ) -> Result<Option<GossipRequest>, Self::Error> {
1574            self.0.get_outgoing_secret_requests(request_id).await
1575        }
1576
1577        async fn get_secret_request_by_info(
1578            &self,
1579            secret_info: &SecretInfo,
1580        ) -> Result<Option<GossipRequest>, Self::Error> {
1581            self.0.get_secret_request_by_info(secret_info).await
1582        }
1583
1584        async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
1585            self.0.get_unsent_secret_requests().await
1586        }
1587
1588        async fn delete_outgoing_secret_requests(
1589            &self,
1590            request_id: &TransactionId,
1591        ) -> Result<(), Self::Error> {
1592            self.0.delete_outgoing_secret_requests(request_id).await
1593        }
1594
1595        async fn get_secrets_from_inbox(
1596            &self,
1597            secret_name: &SecretName,
1598        ) -> Result<Vec<Zeroizing<String>>, Self::Error> {
1599            self.0.get_secrets_from_inbox(secret_name).await
1600        }
1601
1602        async fn delete_secrets_from_inbox(
1603            &self,
1604            secret_name: &SecretName,
1605        ) -> Result<(), Self::Error> {
1606            self.0.delete_secrets_from_inbox(secret_name).await
1607        }
1608
1609        async fn get_room_settings(
1610            &self,
1611            room_id: &RoomId,
1612        ) -> Result<Option<RoomSettings>, Self::Error> {
1613            self.0.get_room_settings(room_id).await
1614        }
1615
1616        async fn get_received_room_key_bundle_data(
1617            &self,
1618            room_id: &RoomId,
1619            user_id: &UserId,
1620        ) -> crate::store::Result<Option<StoredRoomKeyBundleData>, Self::Error> {
1621            self.0.get_received_room_key_bundle_data(room_id, user_id).await
1622        }
1623
1624        async fn has_downloaded_all_room_keys(
1625            &self,
1626            room_id: &RoomId,
1627        ) -> Result<bool, Self::Error> {
1628            self.0.has_downloaded_all_room_keys(room_id).await
1629        }
1630
1631        async fn get_pending_key_bundle_details_for_room(
1632            &self,
1633            room_id: &RoomId,
1634        ) -> Result<Option<RoomPendingKeyBundleDetails>, Self::Error> {
1635            self.0.get_pending_key_bundle_details_for_room(room_id).await
1636        }
1637
1638        async fn get_all_rooms_pending_key_bundles(
1639            &self,
1640        ) -> Result<Vec<RoomPendingKeyBundleDetails>, Self::Error> {
1641            self.0.get_all_rooms_pending_key_bundles().await
1642        }
1643
1644        async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
1645            self.0.get_custom_value(key).await
1646        }
1647
1648        async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<(), Self::Error> {
1649            self.0.set_custom_value(key, value).await
1650        }
1651
1652        async fn remove_custom_value(&self, key: &str) -> Result<(), Self::Error> {
1653            self.0.remove_custom_value(key).await
1654        }
1655
1656        async fn try_take_leased_lock(
1657            &self,
1658            lease_duration_ms: u32,
1659            key: &str,
1660            holder: &str,
1661        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
1662            self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
1663        }
1664
1665        async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1666            self.0.next_batch_token().await
1667        }
1668
1669        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1670            self.0.get_size().await
1671        }
1672    }
1673
    // Instantiate the shared cryptostore integration test suites (including the
    // time-based variant) against this store implementation.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
1676}