Skip to main content

matrix_sdk_crypto/store/
memorystore.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, HashMap, HashSet},
17    convert::Infallible,
18    sync::Arc,
19};
20
21use async_trait::async_trait;
22use matrix_sdk_common::{
23    cross_process_lock::{
24        CrossProcessLockGeneration,
25        memory_store_helper::{Lease, try_take_leased_lock},
26    },
27    locks::RwLock as StdRwLock,
28};
29use ruma::{
30    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
31    UserId, events::secret::request::SecretName,
32};
33use tokio::sync::{Mutex, RwLock};
34use tracing::warn;
35use vodozemac::Curve25519PublicKey;
36
37use super::{
38    Account, CryptoStore, InboundGroupSession, Session,
39    caches::DeviceStore,
40    types::{
41        BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts, RoomSettings,
42        StoredRoomKeyBundleData, TrackedUser,
43    },
44};
45use crate::{
46    gossiping::{GossipRequest, GossippedSecret, SecretInfo},
47    identities::{DeviceData, UserIdentityData},
48    olm::{
49        OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
50        PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
51    },
52    store::types::{RoomKeyWithheldEntry, RoomPendingKeyBundleDetails},
53};
54
55fn encode_key_info(info: &SecretInfo) -> String {
56    match info {
57        SecretInfo::KeyRequest(info) => {
58            format!("{}{}{}", info.room_id(), info.algorithm(), info.session_id())
59        }
60        SecretInfo::SecretRequest(i) => i.as_ref().to_owned(),
61    }
62}
63
/// Session id of an inbound group session, used as the inner map key of
/// `MemoryStore::inbound_group_sessions_backed_up_to`.
type SessionId = String;
65
/// The "version" of a backup: a thin newtype around the version `String`.
#[derive(Clone, Debug, PartialEq)]
struct BackupVersion(String);

impl BackupVersion {
    /// Build a `BackupVersion` by copying the given string slice.
    fn from(s: &str) -> Self {
        BackupVersion(String::from(s))
    }

    /// Borrow the wrapped version as a string slice.
    fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
79
/// An in-memory only store that will forget all the E2EE keys once it's
/// dropped.
///
/// Most fields live behind non-async `StdRwLock`s. The backup keys, the
/// dehydrated device pickle key and the next batch token use async
/// `RwLock`s instead. The account, Olm sessions, inbound group sessions and
/// user identities are kept as JSON-serialized pickles, not as live objects.
#[derive(Default, Debug)]
pub struct MemoryStore {
    /// Cached static data of the account, refreshed whenever the account is
    /// loaded or saved.
    static_account: Arc<StdRwLock<Option<StaticAccountData>>>,

    /// The JSON-serialized pickled account, if one has been saved.
    account: StdRwLock<Option<String>>,
    // Map of sender_key to map of session_id to serialized pickle
    sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
    /// Map room id -> session id -> JSON-serialized pickled inbound group
    /// session.
    inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,

    /// Map room id -> session id -> backup version.
    /// The latest backup in which this session is stored. Equivalent to
    /// `backed_up_to` in `IndexedDbCryptoStore`. Sessions that were never
    /// recorded as backed up have no entry.
    inbound_group_sessions_backed_up_to:
        StdRwLock<HashMap<OwnedRoomId, HashMap<SessionId, BackupVersion>>>,

    /// The current outbound group session of each room (one per room).
    outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
    private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
    tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
    /// Map sender key -> set of hashes of Olm messages seen from that key.
    olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
    devices: DeviceStore,
    /// Map user id -> JSON-serialized `UserIdentityData`.
    identities: StdRwLock<HashMap<OwnedUserId, String>>,
    outgoing_key_requests: StdRwLock<HashMap<OwnedTransactionId, GossipRequest>>,
    /// Secondary index: `encode_key_info` string -> request id in
    /// `outgoing_key_requests`.
    key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
    /// Map room id -> session id -> withheld information for that session.
    direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEntry>>>,
    custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
    /// Leases used by `try_take_leased_lock` (presumably keyed by lock key;
    /// see the `memory_store_helper` module).
    leases: StdRwLock<HashMap<String, Lease>>,
    /// Map secret name -> secrets received for it and not yet deleted.
    secret_inbox: StdRwLock<HashMap<String, Vec<GossippedSecret>>>,
    backup_keys: RwLock<BackupKeys>,
    dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
    next_batch_token: RwLock<Option<String>>,
    room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
    /// Map room id -> sender user id -> received room key bundle data.
    room_key_bundles:
        StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, StoredRoomKeyBundleData>>>,
    room_key_backups_fully_downloaded: StdRwLock<HashSet<OwnedRoomId>>,
    rooms_pending_key_bundle: StdRwLock<HashMap<OwnedRoomId, RoomPendingKeyBundleDetails>>,

    /// Serializes calls to `save_changes` and `save_pending_changes`.
    save_changes_lock: Arc<Mutex<()>>,
}
119
120impl MemoryStore {
121    /// Create a new empty `MemoryStore`.
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    fn get_static_account(&self) -> Option<StaticAccountData> {
127        self.static_account.read().clone()
128    }
129
130    pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
131        for device in devices {
132            let _ = self.devices.add(device);
133        }
134    }
135
136    fn delete_devices(&self, devices: Vec<DeviceData>) {
137        for device in devices {
138            let _ = self.devices.remove(device.user_id(), device.device_id());
139        }
140    }
141
142    fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
143        let mut session_store = self.sessions.write();
144
145        for (session_id, pickle) in sessions {
146            let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();
147
148            // insert or replace if exists
149            entry.insert(
150                session_id,
151                serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
152            );
153        }
154    }
155
156    fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
157        self.outbound_group_sessions
158            .write()
159            .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
160    }
161
162    fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
163        *self.private_identity.write() = private_identity;
164    }
165
166    /// Return all the [`InboundGroupSession`]s we have, paired with the
167    /// `backed_up_to` value for each one (or "" where it is missing, which
168    /// should never happen).
169    async fn get_inbound_group_sessions_and_backed_up_to(
170        &self,
171    ) -> Result<Vec<(InboundGroupSession, Option<BackupVersion>)>> {
172        let lookup = |s: &InboundGroupSession| {
173            self.inbound_group_sessions_backed_up_to
174                .read()
175                .get(&s.room_id)?
176                .get(s.session_id())
177                .cloned()
178        };
179
180        Ok(self
181            .get_inbound_group_sessions()
182            .await?
183            .into_iter()
184            .map(|s| {
185                let v = lookup(&s);
186                (s, v)
187            })
188            .collect())
189    }
190}
191
/// Result alias used throughout this store; its error type is [`Infallible`].
type Result<T> = core::result::Result<T, Infallible>;
193
194#[cfg_attr(target_family = "wasm", async_trait(?Send))]
195#[cfg_attr(not(target_family = "wasm"), async_trait)]
196impl CryptoStore for MemoryStore {
197    type Error = Infallible;
198
199    async fn load_account(&self) -> Result<Option<Account>> {
200        let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
201            serde_json::from_str(acc)
202                .expect("Deserialization failed: invalid pickled account JSON format")
203        });
204
205        if let Some(pickle) = pickled_account {
206            let account =
207                Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");
208
209            *self.static_account.write() = Some(account.static_data().clone());
210
211            Ok(Some(account))
212        } else {
213            Ok(None)
214        }
215    }
216
217    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
218        Ok(self.private_identity.read().clone())
219    }
220
221    async fn next_batch_token(&self) -> Result<Option<String>> {
222        Ok(self.next_batch_token.read().await.clone())
223    }
224
225    async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
226        let _guard = self.save_changes_lock.lock().await;
227
228        let pickled_account = if let Some(account) = changes.account {
229            *self.static_account.write() = Some(account.static_data().clone());
230            Some(account.pickle())
231        } else {
232            None
233        };
234
235        *self.account.write() = pickled_account.map(|pickle| {
236            serde_json::to_string(&pickle)
237                .expect("Serialization failed: invalid pickled account JSON format")
238        });
239
240        Ok(())
241    }
242
243    async fn save_changes(&self, changes: Changes) -> Result<()> {
244        let _guard = self.save_changes_lock.lock().await;
245
246        let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
247        for session in changes.sessions {
248            let session_id = session.session_id().to_owned();
249            let pickle = session.pickle().await;
250            pickled_session.push((session_id.clone(), pickle));
251        }
252        self.save_sessions(pickled_session);
253
254        self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
255        self.save_outbound_group_sessions(changes.outbound_group_sessions);
256        self.save_private_identity(changes.private_identity);
257
258        self.save_devices(changes.devices.new);
259        self.save_devices(changes.devices.changed);
260        self.delete_devices(changes.devices.deleted);
261
262        {
263            let mut identities = self.identities.write();
264            for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
265                identities.insert(
266                    identity.user_id().to_owned(),
267                    serde_json::to_string(&identity)
268                        .expect("UserIdentityData should always serialize to json"),
269                );
270            }
271        }
272
273        {
274            let mut olm_hashes = self.olm_hashes.write();
275            for hash in changes.message_hashes {
276                olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
277            }
278        }
279
280        {
281            let mut outgoing_key_requests = self.outgoing_key_requests.write();
282            let mut key_requests_by_info = self.key_requests_by_info.write();
283
284            for key_request in changes.key_requests {
285                let id = key_request.request_id.clone();
286                let info_string = encode_key_info(&key_request.info);
287
288                outgoing_key_requests.insert(id.clone(), key_request);
289                key_requests_by_info.insert(info_string, id);
290            }
291        }
292
293        if let Some(key) = changes.backup_decryption_key {
294            self.backup_keys.write().await.decryption_key = Some(key);
295        }
296
297        if let Some(version) = changes.backup_version {
298            self.backup_keys.write().await.backup_version = Some(version);
299        }
300
301        if let Some(pickle_key) = changes.dehydrated_device_pickle_key {
302            let mut lock = self.dehydrated_device_pickle_key.write().await;
303            *lock = Some(pickle_key);
304        }
305
306        {
307            let mut secret_inbox = self.secret_inbox.write();
308            for secret in changes.secrets {
309                secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret);
310            }
311        }
312
313        {
314            let mut direct_withheld_info = self.direct_withheld_info.write();
315            for (room_id, data) in changes.withheld_session_info {
316                for (session_id, event) in data {
317                    direct_withheld_info
318                        .entry(room_id.to_owned())
319                        .or_default()
320                        .insert(session_id, event);
321                }
322            }
323        }
324
325        if let Some(next_batch_token) = changes.next_batch_token {
326            *self.next_batch_token.write().await = Some(next_batch_token);
327        }
328
329        if !changes.room_settings.is_empty() {
330            let mut settings = self.room_settings.write();
331            settings.extend(changes.room_settings);
332        }
333
334        if !changes.received_room_key_bundles.is_empty() {
335            let mut room_key_bundles = self.room_key_bundles.write();
336            for bundle in changes.received_room_key_bundles {
337                room_key_bundles
338                    .entry(bundle.bundle_data.room_id.clone())
339                    .or_default()
340                    .insert(bundle.sender_user.clone(), bundle);
341            }
342        }
343
344        if !changes.room_key_backups_fully_downloaded.is_empty() {
345            let mut room_key_backups_fully_downloaded =
346                self.room_key_backups_fully_downloaded.write();
347            for room in changes.room_key_backups_fully_downloaded {
348                room_key_backups_fully_downloaded.insert(room);
349            }
350        }
351
352        if !changes.rooms_pending_key_bundle.is_empty() {
353            let mut lock = self.rooms_pending_key_bundle.write();
354            for (room, details) in changes.rooms_pending_key_bundle {
355                if let Some(details) = details {
356                    lock.insert(room, details);
357                } else {
358                    lock.remove(&room);
359                }
360            }
361        }
362
363        Ok(())
364    }
365
366    async fn save_inbound_group_sessions(
367        &self,
368        sessions: Vec<InboundGroupSession>,
369        backed_up_to_version: Option<&str>,
370    ) -> Result<()> {
371        for session in sessions {
372            let room_id = session.room_id();
373            let session_id = session.session_id();
374
375            // Sanity-check that the data in the sessions corresponds to backed_up_version
376            let backed_up = session.backed_up();
377            if backed_up != backed_up_to_version.is_some() {
378                warn!(
379                    backed_up,
380                    backed_up_to_version,
381                    "Session backed-up flag does not correspond to backup version setting",
382                );
383            }
384
385            if let Some(backup_version) = backed_up_to_version {
386                self.inbound_group_sessions_backed_up_to
387                    .write()
388                    .entry(room_id.to_owned())
389                    .or_default()
390                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
391            }
392
393            let pickle = session.pickle().await;
394            self.inbound_group_sessions
395                .write()
396                .entry(session.room_id().to_owned())
397                .or_default()
398                .insert(
399                    session.session_id().to_owned(),
400                    serde_json::to_string(&pickle)
401                        .expect("Pickle pickle data should serialize to json"),
402                );
403        }
404        Ok(())
405    }
406
407    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
408        let device_keys = self.get_own_device().await?.as_device_keys().clone();
409
410        if let Some(pickles) = self.sessions.read().get(sender_key) {
411            let mut sessions: Vec<Session> = Vec::new();
412            for serialized_pickle in pickles.values() {
413                let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
414                    .expect("Pickle pickle deserialization should work");
415                let session = Session::from_pickle(device_keys.clone(), pickle)
416                    .expect("Expect from pickle to always work");
417                sessions.push(session);
418            }
419            Ok(Some(sessions))
420        } else {
421            Ok(None)
422        }
423    }
424
425    async fn get_inbound_group_session(
426        &self,
427        room_id: &RoomId,
428        session_id: &str,
429    ) -> Result<Option<InboundGroupSession>> {
430        let pickle: Option<PickledInboundGroupSession> = self
431            .inbound_group_sessions
432            .read()
433            .get(room_id)
434            .and_then(|m| m.get(session_id))
435            .and_then(|ser| {
436                serde_json::from_str(ser).expect("Pickle pickle deserialization should work")
437            });
438
439        Ok(pickle.map(|p| {
440            InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
441        }))
442    }
443
444    async fn get_withheld_info(
445        &self,
446        room_id: &RoomId,
447        session_id: &str,
448    ) -> Result<Option<RoomKeyWithheldEntry>> {
449        Ok(self
450            .direct_withheld_info
451            .read()
452            .get(room_id)
453            .and_then(|e| Some(e.get(session_id)?.to_owned())))
454    }
455
456    async fn get_withheld_sessions_by_room_id(
457        &self,
458        room_id: &RoomId,
459    ) -> crate::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
460        Ok(self
461            .direct_withheld_info
462            .read()
463            .get(room_id)
464            .map(|e| e.values().cloned().collect())
465            .unwrap_or_default())
466    }
467
468    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
469        let inbounds = self
470            .inbound_group_sessions
471            .read()
472            .values()
473            .flat_map(HashMap::values)
474            .map(|ser| {
475                let pickle: PickledInboundGroupSession =
476                    serde_json::from_str(ser).expect("Pickle deserialization should work");
477                InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
478            })
479            .collect();
480        Ok(inbounds)
481    }
482
483    async fn inbound_group_session_counts(
484        &self,
485        backup_version: Option<&str>,
486    ) -> Result<RoomKeyCounts> {
487        let backed_up = if let Some(backup_version) = backup_version {
488            self.get_inbound_group_sessions_and_backed_up_to()
489                .await?
490                .into_iter()
491                // Count the sessions backed up in the required backup
492                .filter(|(_, o)| o.as_ref().is_some_and(|o| o.as_str() == backup_version))
493                .count()
494        } else {
495            // We asked about a nonexistent backup version - this doesn't make much sense,
496            // but we can easily answer that nothing is backed up in this
497            // nonexistent backup.
498            0
499        };
500
501        let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
502        Ok(RoomKeyCounts { total, backed_up })
503    }
504
505    async fn get_inbound_group_sessions_by_room_id(
506        &self,
507        room_id: &RoomId,
508    ) -> Result<Vec<InboundGroupSession>> {
509        let inbounds = match self.inbound_group_sessions.read().get(room_id) {
510            None => Vec::new(),
511            Some(v) => v
512                .values()
513                .map(|ser| {
514                    let pickle: PickledInboundGroupSession =
515                        serde_json::from_str(ser).expect("Pickle deserialization should work");
516                    InboundGroupSession::from_pickle(pickle)
517                        .expect("Expect from pickle to always work")
518                })
519                .collect(),
520        };
521        Ok(inbounds)
522    }
523
524    async fn get_inbound_group_sessions_for_device_batch(
525        &self,
526        sender_key: Curve25519PublicKey,
527        sender_data_type: SenderDataType,
528        after_session_id: Option<String>,
529        limit: usize,
530    ) -> Result<Vec<InboundGroupSession>> {
531        // First, find all InboundGroupSessions, filtering for those that match the
532        // device and sender_data type.
533        let mut sessions: Vec<_> = self
534            .get_inbound_group_sessions()
535            .await?
536            .into_iter()
537            .filter(|session: &InboundGroupSession| {
538                session.creator_info.curve25519_key == sender_key
539                    && session.sender_data.to_type() == sender_data_type
540            })
541            .collect();
542
543        // Then, sort the sessions in order of ascending session ID...
544        sessions.sort_by_key(|s| s.session_id().to_owned());
545
546        // Figure out where in the array to start returning results from
547        let start_index = {
548            match after_session_id {
549                None => 0,
550                Some(id) => {
551                    // We're looking for the first session with a session ID strictly after `id`; if
552                    // there are none, the end of the array.
553                    sessions
554                        .iter()
555                        .position(|session| session.session_id() > id.as_str())
556                        .unwrap_or(sessions.len())
557                }
558            }
559        };
560
561        // Return up to `limit` items from the array, starting from `start_index`
562        Ok(sessions.drain(start_index..).take(limit).collect())
563    }
564
565    async fn inbound_group_sessions_for_backup(
566        &self,
567        backup_version: &str,
568        limit: usize,
569    ) -> Result<Vec<InboundGroupSession>> {
570        Ok(self
571            .get_inbound_group_sessions_and_backed_up_to()
572            .await?
573            .into_iter()
574            .filter_map(|(session, backed_up_to)| {
575                if let Some(ref existing_version) = backed_up_to
576                    && existing_version.as_str() == backup_version
577                {
578                    // This session is already backed up in the required backup
579                    None
580                } else {
581                    // It's not backed up, or it's backed up in a different backup
582                    Some(session)
583                }
584            })
585            .take(limit)
586            .collect())
587    }
588
589    async fn mark_inbound_group_sessions_as_backed_up(
590        &self,
591        backup_version: &str,
592        room_and_session_ids: &[(&RoomId, &str)],
593    ) -> Result<()> {
594        for &(room_id, session_id) in room_and_session_ids {
595            let session = self.get_inbound_group_session(room_id, session_id).await?;
596
597            if let Some(session) = session {
598                session.mark_as_backed_up();
599
600                self.inbound_group_sessions_backed_up_to
601                    .write()
602                    .entry(room_id.to_owned())
603                    .or_default()
604                    .insert(session_id.to_owned(), BackupVersion::from(backup_version));
605
606                // Save it back
607                let updated_pickle = session.pickle().await;
608
609                self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
610                    session_id.to_owned(),
611                    serde_json::to_string(&updated_pickle)
612                        .expect("Pickle serialization should work"),
613                );
614            }
615        }
616
617        Ok(())
618    }
619
620    async fn reset_backup_state(&self) -> Result<()> {
621        // Nothing to do here, because we remember which backup versions we backed up to
622        // in `mark_inbound_group_sessions_as_backed_up`, so we don't need to
623        // reset anything here because the required version is passed in to
624        // `inbound_group_sessions_for_backup`, and we can compare against the
625        // version we stored.
626
627        Ok(())
628    }
629
630    async fn load_backup_keys(&self) -> Result<BackupKeys> {
631        Ok(self.backup_keys.read().await.to_owned())
632    }
633
634    async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
635        Ok(self.dehydrated_device_pickle_key.read().await.to_owned())
636    }
637
638    async fn delete_dehydrated_device_pickle_key(&self) -> Result<()> {
639        let mut lock = self.dehydrated_device_pickle_key.write().await;
640        *lock = None;
641        Ok(())
642    }
643
644    async fn get_outbound_group_session(
645        &self,
646        room_id: &RoomId,
647    ) -> Result<Option<OutboundGroupSession>> {
648        Ok(self.outbound_group_sessions.read().get(room_id).cloned())
649    }
650
651    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
652        Ok(self.tracked_users.read().values().cloned().collect())
653    }
654
655    async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
656        self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
657            let user_id: OwnedUserId = user_id.to_owned().into();
658            (user_id.clone(), TrackedUser { user_id, dirty: *dirty })
659        }));
660        Ok(())
661    }
662
663    async fn get_device(
664        &self,
665        user_id: &UserId,
666        device_id: &DeviceId,
667    ) -> Result<Option<DeviceData>> {
668        Ok(self.devices.get(user_id, device_id))
669    }
670
671    async fn get_user_devices(
672        &self,
673        user_id: &UserId,
674    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
675        Ok(self.devices.user_devices(user_id))
676    }
677
678    async fn get_own_device(&self) -> Result<DeviceData> {
679        let account =
680            self.get_static_account().expect("Expect account to exist when getting own device");
681
682        Ok(self
683            .devices
684            .get(&account.user_id, &account.device_id)
685            .expect("Invalid state: Should always have a own device"))
686    }
687
688    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
689        let serialized = self.identities.read().get(user_id).cloned();
690        match serialized {
691            None => Ok(None),
692            Some(serialized) => {
693                let id: UserIdentityData = serde_json::from_str(serialized.as_str())
694                    .expect("Only valid serialized identity are saved");
695                Ok(Some(id))
696            }
697        }
698    }
699
700    async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
701        Ok(self
702            .olm_hashes
703            .write()
704            .entry(message_hash.sender_key.to_owned())
705            .or_default()
706            .contains(&message_hash.hash))
707    }
708
709    async fn get_outgoing_secret_requests(
710        &self,
711        request_id: &TransactionId,
712    ) -> Result<Option<GossipRequest>> {
713        Ok(self.outgoing_key_requests.read().get(request_id).cloned())
714    }
715
716    async fn get_secret_request_by_info(
717        &self,
718        key_info: &SecretInfo,
719    ) -> Result<Option<GossipRequest>> {
720        let key_info_string = encode_key_info(key_info);
721
722        Ok(self
723            .key_requests_by_info
724            .read()
725            .get(&key_info_string)
726            .and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
727    }
728
729    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
730        Ok(self
731            .outgoing_key_requests
732            .read()
733            .values()
734            .filter(|req| !req.sent_out)
735            .cloned()
736            .collect())
737    }
738
739    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
740        let req = self.outgoing_key_requests.write().remove(request_id);
741        if let Some(i) = req {
742            let key_info_string = encode_key_info(&i.info);
743            self.key_requests_by_info.write().remove(&key_info_string);
744        }
745
746        Ok(())
747    }
748
749    async fn get_secrets_from_inbox(
750        &self,
751        secret_name: &SecretName,
752    ) -> Result<Vec<GossippedSecret>> {
753        Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
754    }
755
756    async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
757        self.secret_inbox.write().remove(secret_name.as_str());
758
759        Ok(())
760    }
761
762    async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
763        Ok(self.room_settings.read().get(room_id).cloned())
764    }
765
766    async fn get_received_room_key_bundle_data(
767        &self,
768        room_id: &RoomId,
769        user_id: &UserId,
770    ) -> Result<Option<StoredRoomKeyBundleData>> {
771        let guard = self.room_key_bundles.read();
772
773        let result = guard.get(room_id).and_then(|bundles| bundles.get(user_id).cloned());
774
775        Ok(result)
776    }
777
778    async fn get_pending_key_bundle_details_for_room(
779        &self,
780        room_id: &RoomId,
781    ) -> Result<Option<RoomPendingKeyBundleDetails>> {
782        Ok(self.rooms_pending_key_bundle.read().get(room_id).cloned())
783    }
784
785    async fn get_all_rooms_pending_key_bundles(&self) -> Result<Vec<RoomPendingKeyBundleDetails>> {
786        Ok(self.rooms_pending_key_bundle.read().values().cloned().collect())
787    }
788
789    async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
790        let guard = self.room_key_backups_fully_downloaded.read();
791        Ok(guard.contains(room_id))
792    }
793
794    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
795        Ok(self.custom_values.read().get(key).cloned())
796    }
797
798    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
799        self.custom_values.write().insert(key.to_owned(), value);
800        Ok(())
801    }
802
803    async fn remove_custom_value(&self, key: &str) -> Result<()> {
804        self.custom_values.write().remove(key);
805        Ok(())
806    }
807
808    async fn try_take_leased_lock(
809        &self,
810        lease_duration_ms: u32,
811        key: &str,
812        holder: &str,
813    ) -> Result<Option<CrossProcessLockGeneration>> {
814        Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
815    }
816
817    async fn get_size(&self) -> Result<Option<usize>> {
818        Ok(None)
819    }
820}
821
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use matrix_sdk_test::async_test;
    use ruma::{RoomId, room_id, user_id};
    use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};

    use super::SessionId;
    use crate::{
        DeviceData,
        identities::device::testing::get_device,
        olm::{
            Account, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
            tests::get_account_and_session_test_helper,
        },
        store::{
            CryptoStore,
            memorystore::MemoryStore,
            types::{Changes, DeviceChanges, PendingChanges},
        },
    };

    #[async_test]
    async fn test_session_store() {
        let (account, session) = get_account_and_session_test_helper();
        let own_device = DeviceData::from_account(&account);
        let store = MemoryStore::new();

        assert!(store.load_account().await.unwrap().is_none());

        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![own_device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();

        store
            .save_changes(Changes { sessions: vec![session.clone()], ..Default::default() })
            .await
            .unwrap();

        let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_inbound_group_session_store() {
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";

        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
        let inbound = InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap();

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(vec![inbound.clone()], None).await.unwrap();

        // Look the session up by the *outbound* session id: the inbound copy must
        // be retrievable under the same id.
        let loaded_session =
            store.get_inbound_group_session(room_id, outbound.session_id()).await.unwrap().unwrap();
        assert_eq!(inbound, loaded_session);
    }

    #[async_test]
    async fn test_backing_up_marks_sessions_as_backed_up() {
        // Given there are 2 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(2, room_id).await;

        // When I mark them as backed up
        mark_backed_up(&store, room_id, "bkp1", &sessions).await;

        // Then their backed_up_to field is set
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_a_second_set_of_sessions_updates_their_backup_order() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark 0 and 1 as backed up in bkp1
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And 1 and 2 as backed up in bkp2
        mark_backed_up(&store, room_id, "bkp2", &sessions[1..]).await;

        // Then 0 is backed up in bkp1 and the 1 and 2 are backed up in bkp2
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp2");
        assert_eq!(but[sessions[2].session_id()], "bkp2");
    }

    #[async_test]
    async fn test_backing_up_again_to_the_same_version_has_no_effect() {
        // Given there are 3 sessions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(3, room_id).await;

        // When I mark the first two as backed up in the first backup
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // And the last 2 as backed up in the same backup version
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..]).await;

        // Then they all get the same backed_up_to value
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "bkp1");
        assert_eq!(but[sessions[1].session_id()], "bkp1");
        assert_eq!(but[sessions[2].session_id()], "bkp1");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_can_increase_backed_up_to() {
        // Given we have backed up some sessions to 2 backup versions, an older and a
        // newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "newer_bkp", &sessions[1..2]).await;

        // When I ask to back up the un-backed-up ones to the older backup
        mark_backed_up(&store, room_id, "older_bkp", &sessions[2..]).await;

        // Then each session lists the backup it was most recently included in
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "newer_bkp");
        assert_eq!(but[sessions[2].session_id()], "older_bkp");
        assert_eq!(but[sessions[3].session_id()], "older_bkp");
    }

    #[async_test]
    async fn test_backing_up_to_an_old_backup_version_overwrites_a_newer_one() {
        // Given we have backed up to 2 backup versions, an older and a newer
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "older_bkp", &sessions).await;
        // Sanity: they are all recorded as backed up to older_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "older_bkp");
        mark_backed_up(&store, room_id, "newer_bkp", &sessions).await;
        // Sanity: they are now all recorded as backed up to newer_bkp
        assert_eq!(backed_up_tos(&store).await[sessions[0].session_id()], "newer_bkp");

        // When I ask to back up some to the older version
        mark_backed_up(&store, room_id, "older_bkp", &sessions[..2]).await;

        // Then the most recent mark wins, even though it names the older version:
        // no ordering between backup versions is taken into account
        let but = backed_up_tos(&store).await;
        assert_eq!(but[sessions[0].session_id()], "older_bkp");
        assert_eq!(but[sessions[1].session_id()], "older_bkp");
        assert_eq!(but[sessions[2].session_id()], "newer_bkp");
        assert_eq!(but[sessions[3].session_id()], "newer_bkp");
    }

    #[async_test]
    async fn test_not_backed_up_sessions_are_eligible_for_backup() {
        // Given there are 4 sessions, 2 of which are already backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told the last 2 only
        assert_eq!(to_backup, &[sessions[2].clone(), sessions[3].clone()]);
    }

    #[async_test]
    async fn test_all_sessions_are_eligible_for_backup_if_version_is_unknown() {
        // Given there are 4 sessions, 2 of which are already backed up in bkp1
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;

        // When I ask which to back up in an unknown version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("unknown_bkp", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up all of them
        assert_eq!(
            to_backup,
            &[sessions[0].clone(), sessions[1].clone(), sessions[2].clone(), sessions[3].clone()]
        );
    }

    #[async_test]
    async fn test_sessions_backed_up_to_a_later_version_are_eligible_for_backup() {
        // Given there are 4 sessions, some backed up to three different versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp0", &sessions[..1]).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[1..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[2..3]).await;

        // When I ask which to back up in the middle version
        let mut to_backup = store
            .inbound_group_sessions_for_backup("bkp1", 10)
            .await
            .expect("Failed to ask for sessions to backup");
        to_backup.sort_by_key(|s| s.session_id().to_owned());

        // Then I am told to back up everything not in the version I asked about
        assert_eq!(
            to_backup,
            &[
                sessions[0].clone(), // Backed up in bkp0
                // sessions[1] is backed up in bkp1 already, which we asked about
                sessions[2].clone(), // Backed up in bkp2
                sessions[3].clone(), // Not backed up
            ]
        );
    }

    #[async_test]
    async fn test_outbound_group_session_store() {
        // Given an outbound session
        let (account, _) = get_account_and_session_test_helper();
        let room_id = room_id!("!test:localhost");
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_outbound_group_sessions(vec![outbound.clone()]);

        // Then we can get it out again
        let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
        assert_eq!(
            serde_json::to_string(&outbound.pickle().await).unwrap(),
            serde_json::to_string(&loaded_session.pickle().await).unwrap()
        );
    }

    #[async_test]
    async fn test_tracked_users_are_stored_once_per_user_id() {
        // Given a store containing 2 tracked users, both dirty
        let user1 = user_id!("@user1:s");
        let user2 = user_id!("@user2:s");
        let user3 = user_id!("@user3:s");
        let store = MemoryStore::new();
        store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();

        // When we mark one as clean and add another
        store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();

        // Then we can get them out again and their dirty flags are correct
        let loaded_tracked_users =
            store.load_tracked_users().await.expect("failed to load tracked users");

        let tracked_contains = |user_id, dirty| {
            loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
        };

        assert!(tracked_contains(user1, true));
        assert!(tracked_contains(user2, false));
        assert!(tracked_contains(user3, false));
        assert_eq!(loaded_tracked_users.len(), 3);
    }

    #[async_test]
    async fn test_private_identity_store() {
        // Given a private identity
        let private_identity = PrivateCrossSigningIdentity::empty(user_id!("@u:s"));

        // When we save it to the store
        let store = MemoryStore::new();
        store.save_private_identity(Some(private_identity.clone()));

        // Then we can get it out again
        let loaded_identity =
            store.load_identity().await.expect("failed to load private identity").unwrap();

        assert_eq!(loaded_identity.user_id(), user_id!("@u:s"));
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = MemoryStore::new();

        store.save_devices(vec![device.clone()]);

        let loaded_device =
            store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.get_user_devices(device.user_id()).await.unwrap();

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.delete_devices(vec![device.clone()]);
        assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
    }

    #[async_test]
    async fn test_message_hash() {
        let store = MemoryStore::new();

        let hash =
            OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };

        let mut changes = Changes::default();
        changes.message_hashes.push(hash.clone());

        assert!(!store.is_message_known(&hash).await.unwrap());
        store.save_changes(changes).await.unwrap();
        assert!(store.is_message_known(&hash).await.unwrap());
    }

    #[async_test]
    async fn test_key_counts_of_empty_store_are_zero() {
        // Given an empty store
        let store = MemoryStore::new();

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("")).await.unwrap();

        // Then the answer is zero
        assert_eq!(key_counts.total, 0);
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_sessions_reports_the_number_of_sessions() {
        // Given a store with sessions
        let room_id = room_id!("!test:localhost");
        let (store, _) = store_with_sessions(4, room_id).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the answer equals the number of sessions we created
        assert_eq!(key_counts.total, 4);
        // And none are backed up
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_reports_the_number_backed_up_in_this_backup() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(5, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys
        let key_counts = store.inbound_group_session_counts(Some("bkp")).await.unwrap();

        // Then the total equals the number of sessions we created
        assert_eq!(key_counts.total, 5);
        // And the backed_up count matches how many were backed up
        assert_eq!(key_counts.backed_up, 2);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_for_null_backup_reports_zero() {
        // Given a store with sessions, some backed up
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp", &sessions[..2]).await;

        // When we count keys, providing None as the backup version
        let key_counts = store.inbound_group_session_counts(None).await.unwrap();

        // Then we ignore everything and just say zero
        assert_eq!(key_counts.backed_up, 0);
    }

    #[async_test]
    async fn test_counting_backed_up_sessions_only_reports_sessions_in_the_version_specified() {
        // Given a store with sessions, backed up in several versions
        let room_id = room_id!("!test:localhost");
        let (store, sessions) = store_with_sessions(4, room_id).await;
        mark_backed_up(&store, room_id, "bkp1", &sessions[..2]).await;
        mark_backed_up(&store, room_id, "bkp2", &sessions[3..]).await;

        // When we count keys for bkp2
        let key_counts = store.inbound_group_session_counts(Some("bkp2")).await.unwrap();

        // Then the backed_up count reflects how many were backed up in bkp2 only
        assert_eq!(key_counts.backed_up, 1);
    }

    /// Mark the supplied sessions as backed up in the supplied backup version.
    ///
    /// Panics if the store reports an error.
    async fn mark_backed_up(
        store: &MemoryStore,
        room_id: &RoomId,
        backup_version: &str,
        sessions: &[InboundGroupSession],
    ) {
        let rooms_and_ids: Vec<_> = sessions.iter().map(|s| (room_id, s.session_id())).collect();

        store
            .mark_inbound_group_sessions_as_backed_up(backup_version, &rooms_and_ids)
            .await
            .expect("Failed to mark sessions as backed up");
    }

    /// Create a [`MemoryStore`] containing `num_sessions` fresh inbound group
    /// sessions for `room_id`, none of them marked as backed up.
    ///
    /// Sessions are returned sorted by session id.
    async fn store_with_sessions(
        num_sessions: usize,
        room_id: &RoomId,
    ) -> (MemoryStore, Vec<InboundGroupSession>) {
        let (account, _) = get_account_and_session_test_helper();

        let mut sessions = Vec::with_capacity(num_sessions);
        for _ in 0..num_sessions {
            sessions.push(new_session(&account, room_id).await);
        }
        sessions.sort_by_key(|s| s.session_id().to_owned());

        let store = MemoryStore::new();
        store.save_inbound_group_sessions(sessions.clone(), None).await.unwrap();

        (store, sessions)
    }

    /// Create a new [`InboundGroupSession`] for `room_id` from a freshly created
    /// outbound session of `account`, using fixed sender keys and unknown
    /// sender data.
    async fn new_session(account: &Account, room_id: &RoomId) -> InboundGroupSession {
        let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
        let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

        InboundGroupSession::new(
            Curve25519PublicKey::from_base64(curve_key).unwrap(),
            Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
            room_id,
            &outbound.session_key().await,
            SenderData::unknown(),
            None,
            outbound.settings().algorithm.to_owned(),
            None,
            false,
        )
        .unwrap()
    }

    /// Map each session in the store to the backup version it is currently
    /// recorded as backed up to, or to `""` if it has not been backed up.
    async fn backed_up_tos(store: &MemoryStore) -> HashMap<SessionId, String> {
        store
            .get_inbound_group_sessions_and_backed_up_to()
            .await
            .expect("Unable to get inbound group sessions and their backed-up-to versions")
            .iter()
            .map(|(s, o)| {
                (
                    s.session_id().to_owned(),
                    o.as_ref().map(|v| v.as_str().to_owned()).unwrap_or_default(),
                )
            })
            .collect()
    }
}
1311
1312#[cfg(test)]
1313mod integration_tests {
1314    use std::{
1315        collections::HashMap,
1316        sync::{Arc, Mutex, OnceLock},
1317    };
1318
1319    use async_trait::async_trait;
1320    use matrix_sdk_common::cross_process_lock::CrossProcessLockGeneration;
1321    use ruma::{
1322        DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, events::secret::request::SecretName,
1323    };
1324    use vodozemac::Curve25519PublicKey;
1325
1326    use super::MemoryStore;
1327    use crate::{
1328        Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, Session, UserIdentityData,
1329        cryptostore_integration_tests, cryptostore_integration_tests_time,
1330        olm::{
1331            InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
1332            SenderDataType, StaticAccountData,
1333        },
1334        store::{
1335            CryptoStore,
1336            types::{
1337                BackupKeys, Changes, DehydratedDeviceKey, PendingChanges, RoomKeyCounts,
1338                RoomKeyWithheldEntry, RoomPendingKeyBundleDetails, RoomSettings,
1339                StoredRoomKeyBundleData, TrackedUser,
1340            },
1341        },
1342    };
1343
1344    /// Holds on to a MemoryStore during a test, and moves it back into STORES
1345    /// when this is dropped
1346    #[derive(Clone, Debug)]
1347    struct PersistentMemoryStore(Arc<MemoryStore>);
1348
1349    impl PersistentMemoryStore {
1350        fn new() -> Self {
1351            Self(Arc::new(MemoryStore::new()))
1352        }
1353
1354        fn get_static_account(&self) -> Option<StaticAccountData> {
1355            self.0.get_static_account()
1356        }
1357    }
1358
1359    /// Return a clone of the store for the test with the supplied name. Note:
1360    /// dropping this store won't destroy its data, since
1361    /// [PersistentMemoryStore] is a reference-counted smart pointer
1362    /// to an underlying [MemoryStore].
1363    async fn get_store(
1364        name: &str,
1365        _passphrase: Option<&str>,
1366        clear_data: bool,
1367    ) -> PersistentMemoryStore {
1368        // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
1369        // the store, we keep its data alive. This simulates the behaviour of
1370        // the other stores, which keep their data in a real DB, allowing us to
1371        // test MemoryStore using the same code.
1372        static STORES: OnceLock<Mutex<HashMap<String, PersistentMemoryStore>>> = OnceLock::new();
1373        let stores = STORES.get_or_init(|| Mutex::new(HashMap::new()));
1374
1375        let mut stores = stores.lock().unwrap();
1376
1377        if clear_data {
1378            // Create a new PersistentMemoryStore
1379            let new_store = PersistentMemoryStore::new();
1380            stores.insert(name.to_owned(), new_store.clone());
1381            new_store
1382        } else {
1383            stores.entry(name.to_owned()).or_insert_with(PersistentMemoryStore::new).clone()
1384        }
1385    }
1386
1387    /// Forwards all methods to the underlying [MemoryStore].
1388    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
1389    #[cfg_attr(not(target_family = "wasm"), async_trait)]
1390    impl CryptoStore for PersistentMemoryStore {
1391        type Error = <MemoryStore as CryptoStore>::Error;
1392
1393        async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
1394            self.0.load_account().await
1395        }
1396
1397        async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
1398            self.0.load_identity().await
1399        }
1400
1401        async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error> {
1402            self.0.save_changes(changes).await
1403        }
1404
1405        async fn save_pending_changes(&self, changes: PendingChanges) -> Result<(), Self::Error> {
1406            self.0.save_pending_changes(changes).await
1407        }
1408
1409        async fn save_inbound_group_sessions(
1410            &self,
1411            sessions: Vec<InboundGroupSession>,
1412            backed_up_to_version: Option<&str>,
1413        ) -> Result<(), Self::Error> {
1414            self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await
1415        }
1416
1417        async fn get_sessions(
1418            &self,
1419            sender_key: &str,
1420        ) -> Result<Option<Vec<Session>>, Self::Error> {
1421            self.0.get_sessions(sender_key).await
1422        }
1423
1424        async fn get_inbound_group_session(
1425            &self,
1426            room_id: &RoomId,
1427            session_id: &str,
1428        ) -> Result<Option<InboundGroupSession>, Self::Error> {
1429            self.0.get_inbound_group_session(room_id, session_id).await
1430        }
1431
1432        async fn get_withheld_info(
1433            &self,
1434            room_id: &RoomId,
1435            session_id: &str,
1436        ) -> Result<Option<RoomKeyWithheldEntry>, Self::Error> {
1437            self.0.get_withheld_info(room_id, session_id).await
1438        }
1439
1440        async fn get_withheld_sessions_by_room_id(
1441            &self,
1442            room_id: &RoomId,
1443        ) -> Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
1444            self.0.get_withheld_sessions_by_room_id(room_id).await
1445        }
1446
1447        async fn get_inbound_group_sessions(
1448            &self,
1449        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1450            self.0.get_inbound_group_sessions().await
1451        }
1452
1453        async fn inbound_group_session_counts(
1454            &self,
1455            backup_version: Option<&str>,
1456        ) -> Result<RoomKeyCounts, Self::Error> {
1457            self.0.inbound_group_session_counts(backup_version).await
1458        }
1459
1460        async fn get_inbound_group_sessions_by_room_id(
1461            &self,
1462            room_id: &RoomId,
1463        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1464            self.0.get_inbound_group_sessions_by_room_id(room_id).await
1465        }
1466
1467        async fn get_inbound_group_sessions_for_device_batch(
1468            &self,
1469            sender_key: Curve25519PublicKey,
1470            sender_data_type: SenderDataType,
1471            after_session_id: Option<String>,
1472            limit: usize,
1473        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1474            self.0
1475                .get_inbound_group_sessions_for_device_batch(
1476                    sender_key,
1477                    sender_data_type,
1478                    after_session_id,
1479                    limit,
1480                )
1481                .await
1482        }
1483
1484        async fn inbound_group_sessions_for_backup(
1485            &self,
1486            backup_version: &str,
1487            limit: usize,
1488        ) -> Result<Vec<InboundGroupSession>, Self::Error> {
1489            self.0.inbound_group_sessions_for_backup(backup_version, limit).await
1490        }
1491
1492        async fn mark_inbound_group_sessions_as_backed_up(
1493            &self,
1494            backup_version: &str,
1495            room_and_session_ids: &[(&RoomId, &str)],
1496        ) -> Result<(), Self::Error> {
1497            self.0
1498                .mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
1499                .await
1500        }
1501
1502        async fn reset_backup_state(&self) -> Result<(), Self::Error> {
1503            self.0.reset_backup_state().await
1504        }
1505
1506        async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
1507            self.0.load_backup_keys().await
1508        }
1509
1510        async fn load_dehydrated_device_pickle_key(
1511            &self,
1512        ) -> Result<Option<DehydratedDeviceKey>, Self::Error> {
1513            self.0.load_dehydrated_device_pickle_key().await
1514        }
1515
1516        async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
1517            self.0.delete_dehydrated_device_pickle_key().await
1518        }
1519
1520        async fn get_outbound_group_session(
1521            &self,
1522            room_id: &RoomId,
1523        ) -> Result<Option<OutboundGroupSession>, Self::Error> {
1524            self.0.get_outbound_group_session(room_id).await
1525        }
1526
1527        async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
1528            self.0.load_tracked_users().await
1529        }
1530
1531        async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error> {
1532            self.0.save_tracked_users(users).await
1533        }
1534
1535        async fn get_device(
1536            &self,
1537            user_id: &UserId,
1538            device_id: &DeviceId,
1539        ) -> Result<Option<DeviceData>, Self::Error> {
1540            self.0.get_device(user_id, device_id).await
1541        }
1542
1543        async fn get_user_devices(
1544            &self,
1545            user_id: &UserId,
1546        ) -> Result<HashMap<OwnedDeviceId, DeviceData>, Self::Error> {
1547            self.0.get_user_devices(user_id).await
1548        }
1549
1550        async fn get_own_device(&self) -> Result<DeviceData, Self::Error> {
1551            self.0.get_own_device().await
1552        }
1553
1554        async fn get_user_identity(
1555            &self,
1556            user_id: &UserId,
1557        ) -> Result<Option<UserIdentityData>, Self::Error> {
1558            self.0.get_user_identity(user_id).await
1559        }
1560
1561        async fn is_message_known(
1562            &self,
1563            message_hash: &OlmMessageHash,
1564        ) -> Result<bool, Self::Error> {
1565            self.0.is_message_known(message_hash).await
1566        }
1567
1568        async fn get_outgoing_secret_requests(
1569            &self,
1570            request_id: &TransactionId,
1571        ) -> Result<Option<GossipRequest>, Self::Error> {
1572            self.0.get_outgoing_secret_requests(request_id).await
1573        }
1574
1575        async fn get_secret_request_by_info(
1576            &self,
1577            secret_info: &SecretInfo,
1578        ) -> Result<Option<GossipRequest>, Self::Error> {
1579            self.0.get_secret_request_by_info(secret_info).await
1580        }
1581
1582        async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
1583            self.0.get_unsent_secret_requests().await
1584        }
1585
1586        async fn delete_outgoing_secret_requests(
1587            &self,
1588            request_id: &TransactionId,
1589        ) -> Result<(), Self::Error> {
1590            self.0.delete_outgoing_secret_requests(request_id).await
1591        }
1592
1593        async fn get_secrets_from_inbox(
1594            &self,
1595            secret_name: &SecretName,
1596        ) -> Result<Vec<GossippedSecret>, Self::Error> {
1597            self.0.get_secrets_from_inbox(secret_name).await
1598        }
1599
1600        async fn delete_secrets_from_inbox(
1601            &self,
1602            secret_name: &SecretName,
1603        ) -> Result<(), Self::Error> {
1604            self.0.delete_secrets_from_inbox(secret_name).await
1605        }
1606
1607        async fn get_room_settings(
1608            &self,
1609            room_id: &RoomId,
1610        ) -> Result<Option<RoomSettings>, Self::Error> {
1611            self.0.get_room_settings(room_id).await
1612        }
1613
1614        async fn get_received_room_key_bundle_data(
1615            &self,
1616            room_id: &RoomId,
1617            user_id: &UserId,
1618        ) -> crate::store::Result<Option<StoredRoomKeyBundleData>, Self::Error> {
1619            self.0.get_received_room_key_bundle_data(room_id, user_id).await
1620        }
1621
1622        async fn has_downloaded_all_room_keys(
1623            &self,
1624            room_id: &RoomId,
1625        ) -> Result<bool, Self::Error> {
1626            self.0.has_downloaded_all_room_keys(room_id).await
1627        }
1628
1629        async fn get_pending_key_bundle_details_for_room(
1630            &self,
1631            room_id: &RoomId,
1632        ) -> Result<Option<RoomPendingKeyBundleDetails>, Self::Error> {
1633            self.0.get_pending_key_bundle_details_for_room(room_id).await
1634        }
1635
1636        async fn get_all_rooms_pending_key_bundles(
1637            &self,
1638        ) -> Result<Vec<RoomPendingKeyBundleDetails>, Self::Error> {
1639            self.0.get_all_rooms_pending_key_bundles().await
1640        }
1641
1642        async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
1643            self.0.get_custom_value(key).await
1644        }
1645
1646        async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<(), Self::Error> {
1647            self.0.set_custom_value(key, value).await
1648        }
1649
1650        async fn remove_custom_value(&self, key: &str) -> Result<(), Self::Error> {
1651            self.0.remove_custom_value(key).await
1652        }
1653
1654        async fn try_take_leased_lock(
1655            &self,
1656            lease_duration_ms: u32,
1657            key: &str,
1658            holder: &str,
1659        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
1660            self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
1661        }
1662
1663        async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1664            self.0.next_batch_token().await
1665        }
1666
1667        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1668            self.0.get_size().await
1669        }
1670    }
1671
    // Expand the crate's shared crypto-store integration test macros in this
    // module. NOTE(review): presumably these generate test cases that exercise
    // the forwarding `CryptoStore` impl above (including time-dependent ones
    // for the `_time` variant) — confirm against the macro definitions.
    cryptostore_integration_tests!();
    cryptostore_integration_tests_time!();
1674}