Skip to main content

matrix_sdk_base/store/
memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cmp::Reverse,
17    collections::{BTreeMap, BTreeSet, HashMap},
18    sync::RwLock,
19};
20
21use async_trait::async_trait;
22use growable_bloom_filter::GrowableBloom;
23use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
24use ruma::{
25    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
26    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
27    canonical_json::{RedactedBecause, redact},
28    events::{
29        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
30        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
31        presence::PresenceEvent,
32        receipt::{Receipt, ReceiptThread, ReceiptType},
33        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
34    },
35    serde::Raw,
36    time::Instant,
37};
38use tracing::{debug, instrument, warn};
39
40use super::{
41    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
42    RoomLoadSettings, StateChanges, StateStore, StoreError, SupportedVersionsResponse,
43    TtlStoreValue, WellKnownResponse,
44    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
45    traits::ComposerDraft,
46};
47use crate::{
48    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
49    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
50    store::{
51        QueueWedgeError, StoredThreadSubscription,
52        traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps},
53    },
54};
55
/// All data held by [`MemoryStore`], grouped in one struct so that the store
/// can keep it behind a single lock.
#[derive(Debug, Default)]
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// Per-user room lists, served for `StateStoreDataKey::RecentlyVisitedRooms`.
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// Composer drafts, keyed by room and optional thread root event; served
    /// for `StateStoreDataKey::ComposerDraft`.
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// Per-user avatar URLs, served for `StateStoreDataKey::UserAvatarUrl`.
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// Last stored sync token, if any.
    sync_token: Option<String>,
    /// Value stored under `StateStoreDataKey::SupportedVersions`.
    supported_versions: Option<TtlStoreValue<SupportedVersionsResponse>>,
    /// Value stored under `StateStoreDataKey::WellKnown`.
    well_known: Option<TtlStoreValue<Option<WellKnownResponse>>>,
    /// Values stored under `StateStoreDataKey::Filter`, keyed by filter name.
    filters: HashMap<String, String>,
    /// Value stored under `StateStoreDataKey::UtdHookManagerData`.
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Flag behind `StateStoreDataKey::OneTimeKeyAlreadyUploaded`: `true`
    /// once that key has been set, `false` after it has been removed.
    one_time_key_uploaded_error: bool,
    /// Global account data events, keyed by event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// Member profiles: room -> user -> profile.
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// Display name lookup: room -> display name -> users recorded for that
    /// name.
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Membership states: room -> user -> membership, filled from sync
    /// `m.room.member` state events.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Room infos, keyed by room.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// Sync state events: room -> event type -> state key -> raw event.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// Room account data events: room -> event type -> raw event.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// Stripped state events: room -> event type -> state key -> raw event.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// Membership states: room -> user -> membership, filled from stripped
    /// `m.room.member` state events.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Presence events, keyed by user.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// Receipts indexed by user: room -> (receipt type, thread) -> user ->
    /// (event the receipt points at, receipt). Receipt type and thread are
    /// stored in their string form.
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,
    /// Receipts indexed by event: room -> (receipt type, thread) -> event ->
    /// user -> receipt. Receipt type and thread are stored in their string
    /// form.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Opaque key/value pairs accessed through the `*_custom_value` methods.
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Queued send requests, grouped by room.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Dependent queued send requests, grouped by room.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// Values stored under `StateStoreDataKey::SeenKnockRequests`, keyed by
    /// room.
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
    /// Thread subscriptions: room -> event -> stored subscription.
    // NOTE(review): the event key is presumably the thread root; confirm
    // against the code that writes this map.
    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, StoredThreadSubscription>>,
    /// Value stored under `StateStoreDataKey::ThreadSubscriptionsCatchupTokens`.
    thread_subscriptions_catchup_tokens: Option<Vec<ThreadSubscriptionCatchupToken>>,
}
96
/// In-memory, non-persistent implementation of the `StateStore`.
///
/// Default if no other is configured at startup.
#[derive(Debug, Default)]
pub struct MemoryStore {
    /// Every piece of stored data, guarded by one reader-writer lock.
    inner: RwLock<MemoryStoreInner>,
}
104
105impl MemoryStore {
106    /// Create a new empty MemoryStore
107    pub fn new() -> Self {
108        Self::default()
109    }
110
111    fn get_user_room_receipt_event_impl(
112        &self,
113        room_id: &RoomId,
114        receipt_type: ReceiptType,
115        thread: ReceiptThread,
116        user_id: &UserId,
117    ) -> Option<(OwnedEventId, Receipt)> {
118        self.inner
119            .read()
120            .unwrap()
121            .room_user_receipts
122            .get(room_id)?
123            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
124            .get(user_id)
125            .cloned()
126    }
127
128    fn get_event_room_receipt_events_impl(
129        &self,
130        room_id: &RoomId,
131        receipt_type: ReceiptType,
132        thread: ReceiptThread,
133        event_id: &EventId,
134    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
135        Some(
136            self.inner
137                .read()
138                .unwrap()
139                .room_event_receipts
140                .get(room_id)?
141                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
142                .get(event_id)?
143                .iter()
144                .map(|(key, value)| (key.clone(), value.clone()))
145                .collect(),
146        )
147    }
148}
149
150#[cfg_attr(target_family = "wasm", async_trait(?Send))]
151#[cfg_attr(not(target_family = "wasm"), async_trait)]
152impl StateStore for MemoryStore {
153    type Error = StoreError;
154
155    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
156        let inner = self.inner.read().unwrap();
157
158        Ok(match key {
159            StateStoreDataKey::SyncToken => {
160                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
161            }
162            StateStoreDataKey::SupportedVersions => {
163                inner.supported_versions.clone().map(StateStoreDataValue::SupportedVersions)
164            }
165            StateStoreDataKey::WellKnown => {
166                inner.well_known.clone().map(StateStoreDataValue::WellKnown)
167            }
168            StateStoreDataKey::Filter(filter_name) => {
169                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
170            }
171            StateStoreDataKey::UserAvatarUrl(user_id) => {
172                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
173            }
174            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
175                .recently_visited_rooms
176                .get(user_id)
177                .cloned()
178                .map(StateStoreDataValue::RecentlyVisitedRooms),
179            StateStoreDataKey::UtdHookManagerData => {
180                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
181            }
182            StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
183                .one_time_key_uploaded_error
184                .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
185            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
186                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
187                inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
188            }
189            StateStoreDataKey::SeenKnockRequests(room_id) => inner
190                .seen_knock_requests
191                .get(room_id)
192                .cloned()
193                .map(StateStoreDataValue::SeenKnockRequests),
194            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
195                .thread_subscriptions_catchup_tokens
196                .clone()
197                .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
198        })
199    }
200
201    async fn set_kv_data(
202        &self,
203        key: StateStoreDataKey<'_>,
204        value: StateStoreDataValue,
205    ) -> Result<()> {
206        let mut inner = self.inner.write().unwrap();
207        match key {
208            StateStoreDataKey::SyncToken => {
209                inner.sync_token =
210                    Some(value.into_sync_token().expect("Session data not a sync token"))
211            }
212            StateStoreDataKey::Filter(filter_name) => {
213                inner.filters.insert(
214                    filter_name.to_owned(),
215                    value.into_filter().expect("Session data not a filter"),
216                );
217            }
218            StateStoreDataKey::UserAvatarUrl(user_id) => {
219                inner.user_avatar_url.insert(
220                    user_id.to_owned(),
221                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
222                );
223            }
224            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
225                inner.recently_visited_rooms.insert(
226                    user_id.to_owned(),
227                    value
228                        .into_recently_visited_rooms()
229                        .expect("Session data not a list of recently visited rooms"),
230                );
231            }
232            StateStoreDataKey::UtdHookManagerData => {
233                inner.utd_hook_manager_data = Some(
234                    value
235                        .into_utd_hook_manager_data()
236                        .expect("Session data not the hook manager data"),
237                );
238            }
239            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
240                inner.one_time_key_uploaded_error = true;
241            }
242            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
243                inner.composer_drafts.insert(
244                    (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
245                    value.into_composer_draft().expect("Session data not a composer draft"),
246                );
247            }
248            StateStoreDataKey::SupportedVersions => {
249                inner.supported_versions = Some(
250                    value
251                        .into_supported_versions()
252                        .expect("Session data not containing supported versions"),
253                );
254            }
255            StateStoreDataKey::WellKnown => {
256                inner.well_known =
257                    Some(value.into_well_known().expect("Session data not containing well-known"));
258            }
259            StateStoreDataKey::SeenKnockRequests(room_id) => {
260                inner.seen_knock_requests.insert(
261                    room_id.to_owned(),
262                    value
263                        .into_seen_knock_requests()
264                        .expect("Session data is not a set of seen join request ids"),
265                );
266            }
267            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
268                inner.thread_subscriptions_catchup_tokens =
269                    Some(value.into_thread_subscriptions_catchup_tokens().expect(
270                        "Session data is not a list of thread subscription catchup tokens",
271                    ));
272            }
273        }
274
275        Ok(())
276    }
277
278    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
279        let mut inner = self.inner.write().unwrap();
280        match key {
281            StateStoreDataKey::SyncToken => inner.sync_token = None,
282            StateStoreDataKey::SupportedVersions => inner.supported_versions = None,
283            StateStoreDataKey::WellKnown => inner.well_known = None,
284            StateStoreDataKey::Filter(filter_name) => {
285                inner.filters.remove(filter_name);
286            }
287            StateStoreDataKey::UserAvatarUrl(user_id) => {
288                inner.user_avatar_url.remove(user_id);
289            }
290            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
291                inner.recently_visited_rooms.remove(user_id);
292            }
293            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
294            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
295                inner.one_time_key_uploaded_error = false
296            }
297            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
298                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
299                inner.composer_drafts.remove(&key);
300            }
301            StateStoreDataKey::SeenKnockRequests(room_id) => {
302                inner.seen_knock_requests.remove(room_id);
303            }
304            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
305                inner.thread_subscriptions_catchup_tokens = None;
306            }
307        }
308        Ok(())
309    }
310
    /// Apply a batch of `StateChanges` to the in-memory maps.
    ///
    /// The write lock is taken once and held for the whole call. The sections
    /// of `changes` are applied in a fixed order: sync token, profile
    /// deletions, profiles, display name maps, account data, room account
    /// data, sync state, room infos, presence, stripped state, receipts and
    /// finally redactions.
    ///
    /// # Errors
    ///
    /// Only the redaction pass can fail (event or redaction (de)serialization,
    /// or the redaction itself). Nothing is rolled back in that case: every
    /// change applied before the failure stays in the store.
    #[instrument(skip(self, changes))]
    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
        // Only used to log how long the save took.
        let now = Instant::now();

        let mut inner = self.inner.write().unwrap();

        if let Some(s) = &changes.sync_token {
            inner.sync_token = Some(s.to_owned());
        }

        // Deletions run before insertions, so a profile that is both deleted
        // and provided in the same batch ends up stored.
        for (room, users) in &changes.profiles_to_delete {
            let Some(room_profiles) = inner.profiles.get_mut(room) else {
                continue;
            };
            for user in users {
                room_profiles.remove(user);
            }
        }

        for (room, users) in &changes.profiles {
            for (user_id, profile) in users {
                inner
                    .profiles
                    .entry(room.clone())
                    .or_default()
                    .insert(user_id.clone(), profile.clone());
            }
        }

        // Each display name entry replaces the whole user set stored for it.
        for (room, map) in &changes.ambiguity_maps {
            for (display_name, display_names) in map {
                inner
                    .display_names
                    .entry(room.clone())
                    .or_default()
                    .insert(display_name.clone(), display_names.clone());
            }
        }

        for (event_type, event) in &changes.account_data {
            inner.account_data.insert(event_type.clone(), event.clone());
        }

        for (room, events) in &changes.room_account_data {
            for (event_type, event) in events {
                inner
                    .room_account_data
                    .entry(room.clone())
                    .or_default()
                    .insert(event_type.clone(), event.clone());
            }
        }

        for (room, event_types) in &changes.state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());
                    // Storing a sync state event for a room discards all the
                    // stripped state held for that room.
                    inner.stripped_room_state.remove(room);

                    if *event_type == StateEventType::RoomMember {
                        let event =
                            match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    // The raw event was already stored above;
                                    // only the membership index is skipped.
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(event_id, "Failed to deserialize member event: {e}");
                                    continue;
                                }
                            };

                        // Same as for the state itself: stripped memberships
                        // of the room are discarded.
                        inner.stripped_members.remove(room);

                        inner
                            .members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key().to_owned(), event.membership().clone());
                    }
                }
            }
        }

        for (room_id, info) in &changes.room_infos {
            inner.room_info.insert(room_id.clone(), info.clone());
        }

        for (sender, event) in &changes.presence {
            inner.presence.insert(sender.clone(), event.clone());
        }

        // Unlike sync state above, stripped state does not evict anything.
        for (room, event_types) in &changes.stripped_state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .stripped_room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());

                    if *event_type == StateEventType::RoomMember {
                        let event =
                            match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    // The raw event was already stored above;
                                    // only the membership index is skipped.
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(
                                        event_id,
                                        "Failed to deserialize stripped member event: {e}"
                                    );
                                    continue;
                                }
                            };

                        inner
                            .stripped_members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key, event.content.membership.clone());
                    }
                }
            }
        }

        // Receipts are kept in two indexes (by user and by event) that are
        // updated together.
        for (room, content) in &changes.receipts {
            for (event_id, receipts) in &content.0 {
                for (receipt_type, receipts) in receipts {
                    for (user_id, receipt) in receipts {
                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
                        // Add the receipt to the room user receipts. `insert`
                        // hands back the receipt it replaced, if any.
                        if let Some((old_event, _)) = inner
                            .room_user_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread.clone()))
                            .or_default()
                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
                        {
                            // Remove the old receipt from the room event
                            // receipts, so the user is listed on one event
                            // only for this receipt type and thread.
                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
                                && let Some(event_map) =
                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
                                && let Some(user_map) = event_map.get_mut(&old_event)
                            {
                                user_map.remove(user_id);
                            }
                        }

                        // Add the receipt to the room event receipts
                        inner
                            .room_event_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread))
                            .or_default()
                            .entry(event_id.clone())
                            .or_default()
                            .insert(user_id.clone(), receipt.clone());
                    }
                }
            }
        }

        // Redaction rules come from the stored room info when there is one,
        // and from the fallback room version rules (with a warning) otherwise.
        let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
            room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
                warn!(
                    ?room_id,
                    "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
                );
                ROOM_VERSION_RULES_FALLBACK
            }).redaction
        };

        // Reborrow the guard as a plain `&mut` so that `room_state` (mutably)
        // and `room_info` (immutably) can be borrowed at the same time below.
        let inner = &mut *inner;
        for (room_id, redactions) in &changes.redactions {
            // Computed lazily, at most once per room, on the first event that
            // actually needs redacting.
            let mut redaction_rules = None;

            if let Some(room) = inner.room_state.get_mut(room_id) {
                for ref_room_mu in room.values_mut() {
                    for raw_evt in ref_room_mu.values_mut() {
                        // Stored events without a readable `event_id` are
                        // left untouched.
                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
                            && let Some(redaction) = redactions.get(&event_id)
                        {
                            let redacted = redact(
                                raw_evt.deserialize_as::<CanonicalJsonObject>()?,
                                redaction_rules.get_or_insert_with(|| {
                                    make_redaction_rules(&inner.room_info, room_id)
                                }),
                                Some(RedactedBecause::from_raw_event(redaction)?),
                            )
                            .map_err(StoreError::Redaction)?;
                            // Replace the stored event with its redacted form.
                            *raw_evt = Raw::new(&redacted)?.cast_unchecked();
                        }
                    }
                }
            }
        }

        debug!("Saved changes in {:?}", now.elapsed());

        Ok(())
    }
522
523    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
524        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
525    }
526
527    async fn get_presence_events(
528        &self,
529        user_ids: &[OwnedUserId],
530    ) -> Result<Vec<Raw<PresenceEvent>>> {
531        let presence = &self.inner.read().unwrap().presence;
532        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
533    }
534
535    async fn get_state_event(
536        &self,
537        room_id: &RoomId,
538        event_type: StateEventType,
539        state_key: &str,
540    ) -> Result<Option<RawAnySyncOrStrippedState>> {
541        Ok(self
542            .get_state_events_for_keys(room_id, event_type, &[state_key])
543            .await?
544            .into_iter()
545            .next())
546    }
547
548    async fn get_state_events(
549        &self,
550        room_id: &RoomId,
551        event_type: StateEventType,
552    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
553        fn get_events<T>(
554            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
555            room_id: &RoomId,
556            event_type: &StateEventType,
557            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
558        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
559            let state_events = state_map.get(room_id)?.get(event_type)?;
560            Some(state_events.values().cloned().map(to_enum).collect())
561        }
562
563        let inner = self.inner.read().unwrap();
564        Ok(get_events(
565            &inner.stripped_room_state,
566            room_id,
567            &event_type,
568            RawAnySyncOrStrippedState::Stripped,
569        )
570        .or_else(|| {
571            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
572        })
573        .unwrap_or_default())
574    }
575
576    async fn get_state_events_for_keys(
577        &self,
578        room_id: &RoomId,
579        event_type: StateEventType,
580        state_keys: &[&str],
581    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
582        let inner = self.inner.read().unwrap();
583
584        if let Some(stripped_state_events) =
585            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
586        {
587            Ok(state_keys
588                .iter()
589                .filter_map(|k| {
590                    stripped_state_events
591                        .get(*k)
592                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
593                })
594                .collect())
595        } else if let Some(sync_state_events) =
596            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
597        {
598            Ok(state_keys
599                .iter()
600                .filter_map(|k| {
601                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
602                })
603                .collect())
604        } else {
605            Ok(Vec::new())
606        }
607    }
608
609    async fn get_profile(
610        &self,
611        room_id: &RoomId,
612        user_id: &UserId,
613    ) -> Result<Option<MinimalRoomMemberEvent>> {
614        Ok(self
615            .inner
616            .read()
617            .unwrap()
618            .profiles
619            .get(room_id)
620            .and_then(|room_profiles| room_profiles.get(user_id))
621            .cloned())
622    }
623
624    async fn get_profiles<'a>(
625        &self,
626        room_id: &RoomId,
627        user_ids: &'a [OwnedUserId],
628    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
629        if user_ids.is_empty() {
630            return Ok(BTreeMap::new());
631        }
632
633        let profiles = &self.inner.read().unwrap().profiles;
634        let Some(room_profiles) = profiles.get(room_id) else {
635            return Ok(BTreeMap::new());
636        };
637
638        Ok(user_ids
639            .iter()
640            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
641            .collect())
642    }
643
644    #[instrument(skip(self, memberships))]
645    async fn get_user_ids(
646        &self,
647        room_id: &RoomId,
648        memberships: RoomMemberships,
649    ) -> Result<Vec<OwnedUserId>> {
650        /// Get the user IDs for the given room with the given memberships and
651        /// stripped state.
652        ///
653        /// If `memberships` is empty, returns all user IDs in the room with the
654        /// given stripped state.
655        fn get_user_ids_inner(
656            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
657            room_id: &RoomId,
658            memberships: RoomMemberships,
659        ) -> Vec<OwnedUserId> {
660            members
661                .get(room_id)
662                .map(|members| {
663                    members
664                        .iter()
665                        .filter_map(|(user_id, membership)| {
666                            memberships.matches(membership).then_some(user_id)
667                        })
668                        .cloned()
669                        .collect()
670                })
671                .unwrap_or_default()
672        }
673        let inner = self.inner.read().unwrap();
674        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
675        if !v.is_empty() {
676            return Ok(v);
677        }
678        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
679    }
680
681    async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
682        let memory_store_inner = self.inner.read().unwrap();
683        let room_infos = &memory_store_inner.room_info;
684
685        Ok(match room_load_settings {
686            RoomLoadSettings::All => room_infos.values().cloned().collect(),
687
688            RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
689                Some(room_info) => vec![room_info.clone()],
690                None => vec![],
691            },
692        })
693    }
694
695    async fn get_users_with_display_name(
696        &self,
697        room_id: &RoomId,
698        display_name: &DisplayName,
699    ) -> Result<BTreeSet<OwnedUserId>> {
700        Ok(self
701            .inner
702            .read()
703            .unwrap()
704            .display_names
705            .get(room_id)
706            .and_then(|room_names| room_names.get(display_name).cloned())
707            .unwrap_or_default())
708    }
709
710    async fn get_users_with_display_names<'a>(
711        &self,
712        room_id: &RoomId,
713        display_names: &'a [DisplayName],
714    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
715        if display_names.is_empty() {
716            return Ok(HashMap::new());
717        }
718
719        let inner = self.inner.read().unwrap();
720        let Some(room_names) = inner.display_names.get(room_id) else {
721            return Ok(HashMap::new());
722        };
723
724        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
725    }
726
727    async fn get_account_data_event(
728        &self,
729        event_type: GlobalAccountDataEventType,
730    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
731        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
732    }
733
734    async fn get_room_account_data_event(
735        &self,
736        room_id: &RoomId,
737        event_type: RoomAccountDataEventType,
738    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
739        Ok(self
740            .inner
741            .read()
742            .unwrap()
743            .room_account_data
744            .get(room_id)
745            .and_then(|m| m.get(&event_type))
746            .cloned())
747    }
748
749    async fn get_user_room_receipt_event(
750        &self,
751        room_id: &RoomId,
752        receipt_type: ReceiptType,
753        thread: ReceiptThread,
754        user_id: &UserId,
755    ) -> Result<Option<(OwnedEventId, Receipt)>> {
756        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
757    }
758
759    async fn get_event_room_receipt_events(
760        &self,
761        room_id: &RoomId,
762        receipt_type: ReceiptType,
763        thread: ReceiptThread,
764        event_id: &EventId,
765    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
766        Ok(self
767            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
768            .unwrap_or_default())
769    }
770
771    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
772        Ok(self.inner.read().unwrap().custom.get(key).cloned())
773    }
774
775    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
776        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
777    }
778
779    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
780        Ok(self.inner.write().unwrap().custom.remove(key))
781    }
782
783    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
784        let mut inner = self.inner.write().unwrap();
785
786        inner.profiles.remove(room_id);
787        inner.display_names.remove(room_id);
788        inner.members.remove(room_id);
789        inner.room_info.remove(room_id);
790        inner.room_state.remove(room_id);
791        inner.room_account_data.remove(room_id);
792        inner.stripped_room_state.remove(room_id);
793        inner.stripped_members.remove(room_id);
794        inner.room_user_receipts.remove(room_id);
795        inner.room_event_receipts.remove(room_id);
796        inner.send_queue_events.remove(room_id);
797        inner.dependent_send_queue_events.remove(room_id);
798        inner.thread_subscriptions.remove(room_id);
799
800        Ok(())
801    }
802
803    async fn save_send_queue_request(
804        &self,
805        room_id: &RoomId,
806        transaction_id: OwnedTransactionId,
807        created_at: MilliSecondsSinceUnixEpoch,
808        kind: QueuedRequestKind,
809        priority: usize,
810    ) -> Result<(), Self::Error> {
811        self.inner
812            .write()
813            .unwrap()
814            .send_queue_events
815            .entry(room_id.to_owned())
816            .or_default()
817            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
818        Ok(())
819    }
820
821    async fn update_send_queue_request(
822        &self,
823        room_id: &RoomId,
824        transaction_id: &TransactionId,
825        kind: QueuedRequestKind,
826    ) -> Result<bool, Self::Error> {
827        if let Some(entry) = self
828            .inner
829            .write()
830            .unwrap()
831            .send_queue_events
832            .entry(room_id.to_owned())
833            .or_default()
834            .iter_mut()
835            .find(|item| item.transaction_id == transaction_id)
836        {
837            entry.kind = kind;
838            entry.error = None;
839            Ok(true)
840        } else {
841            Ok(false)
842        }
843    }
844
845    async fn remove_send_queue_request(
846        &self,
847        room_id: &RoomId,
848        transaction_id: &TransactionId,
849    ) -> Result<bool, Self::Error> {
850        let mut inner = self.inner.write().unwrap();
851        let q = &mut inner.send_queue_events;
852
853        let entry = q.get_mut(room_id);
854        if let Some(entry) = entry {
855            // Find the event by id in its room queue, and remove it if present.
856            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
857                entry.remove(pos);
858                // And if this was the last event before removal, remove the entire room entry.
859                if entry.is_empty() {
860                    q.remove(room_id);
861                }
862                return Ok(true);
863            }
864        }
865
866        Ok(false)
867    }
868
869    async fn load_send_queue_requests(
870        &self,
871        room_id: &RoomId,
872    ) -> Result<Vec<QueuedRequest>, Self::Error> {
873        let mut ret = self
874            .inner
875            .write()
876            .unwrap()
877            .send_queue_events
878            .entry(room_id.to_owned())
879            .or_default()
880            .clone();
881        // Inverted order of priority, use stable sort to keep insertion order.
882        ret.sort_by_key(|item| Reverse(item.priority));
883        Ok(ret)
884    }
885
886    async fn update_send_queue_request_status(
887        &self,
888        room_id: &RoomId,
889        transaction_id: &TransactionId,
890        error: Option<QueueWedgeError>,
891    ) -> Result<(), Self::Error> {
892        if let Some(entry) = self
893            .inner
894            .write()
895            .unwrap()
896            .send_queue_events
897            .entry(room_id.to_owned())
898            .or_default()
899            .iter_mut()
900            .find(|item| item.transaction_id == transaction_id)
901        {
902            entry.error = error;
903        }
904        Ok(())
905    }
906
907    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
908        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
909    }
910
911    async fn save_dependent_queued_request(
912        &self,
913        room: &RoomId,
914        parent_transaction_id: &TransactionId,
915        own_transaction_id: ChildTransactionId,
916        created_at: MilliSecondsSinceUnixEpoch,
917        content: DependentQueuedRequestKind,
918    ) -> Result<(), Self::Error> {
919        self.inner
920            .write()
921            .unwrap()
922            .dependent_send_queue_events
923            .entry(room.to_owned())
924            .or_default()
925            .push(DependentQueuedRequest {
926                kind: content,
927                parent_transaction_id: parent_transaction_id.to_owned(),
928                own_transaction_id,
929                parent_key: None,
930                created_at,
931            });
932        Ok(())
933    }
934
935    async fn mark_dependent_queued_requests_as_ready(
936        &self,
937        room: &RoomId,
938        parent_txn_id: &TransactionId,
939        sent_parent_key: SentRequestKey,
940    ) -> Result<usize, Self::Error> {
941        let mut inner = self.inner.write().unwrap();
942        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
943        let mut num_updated = 0;
944        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
945            d.parent_key = Some(sent_parent_key.clone());
946            num_updated += 1;
947        }
948        Ok(num_updated)
949    }
950
951    async fn update_dependent_queued_request(
952        &self,
953        room: &RoomId,
954        own_transaction_id: &ChildTransactionId,
955        new_content: DependentQueuedRequestKind,
956    ) -> Result<bool, Self::Error> {
957        let mut inner = self.inner.write().unwrap();
958        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
959        for d in dependents.iter_mut() {
960            if d.own_transaction_id == *own_transaction_id {
961                d.kind = new_content;
962                return Ok(true);
963            }
964        }
965        Ok(false)
966    }
967
968    async fn remove_dependent_queued_request(
969        &self,
970        room: &RoomId,
971        txn_id: &ChildTransactionId,
972    ) -> Result<bool, Self::Error> {
973        let mut inner = self.inner.write().unwrap();
974        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
975        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
976            dependents.remove(pos);
977            Ok(true)
978        } else {
979            Ok(false)
980        }
981    }
982
983    async fn load_dependent_queued_requests(
984        &self,
985        room: &RoomId,
986    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
987        Ok(self
988            .inner
989            .read()
990            .unwrap()
991            .dependent_send_queue_events
992            .get(room)
993            .cloned()
994            .unwrap_or_default())
995    }
996
997    async fn upsert_thread_subscriptions(
998        &self,
999        updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
1000    ) -> Result<(), Self::Error> {
1001        let mut inner = self.inner.write().unwrap();
1002
1003        for (room_id, thread_id, mut new) in updates {
1004            let room_subs = inner.thread_subscriptions.entry(room_id.to_owned()).or_default();
1005
1006            if let Some(previous) = room_subs.get(thread_id) {
1007                if *previous == new {
1008                    continue;
1009                }
1010                if !compare_thread_subscription_bump_stamps(
1011                    previous.bump_stamp,
1012                    &mut new.bump_stamp,
1013                ) {
1014                    continue;
1015                }
1016            }
1017
1018            room_subs.insert(thread_id.to_owned(), new);
1019        }
1020
1021        Ok(())
1022    }
1023
1024    async fn load_thread_subscription(
1025        &self,
1026        room: &RoomId,
1027        thread_id: &EventId,
1028    ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1029        let inner = self.inner.read().unwrap();
1030        Ok(inner
1031            .thread_subscriptions
1032            .get(room)
1033            .and_then(|subscriptions| subscriptions.get(thread_id))
1034            .copied())
1035    }
1036
1037    async fn remove_thread_subscription(
1038        &self,
1039        room: &RoomId,
1040        thread_id: &EventId,
1041    ) -> Result<(), Self::Error> {
1042        let mut inner = self.inner.write().unwrap();
1043
1044        let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1045            return Ok(());
1046        };
1047
1048        room_subs.remove(thread_id);
1049
1050        if room_subs.is_empty() {
1051            // If there are no more subscriptions for this room, remove the room entry.
1052            inner.thread_subscriptions.remove(room);
1053        }
1054
1055        Ok(())
1056    }
1057
1058    async fn optimize(&self) -> Result<(), Self::Error> {
1059        Ok(())
1060    }
1061
1062    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1063        Ok(None)
1064    }
1065}
1066
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    // Builds a fresh, empty in-memory store.
    // NOTE(review): presumably picked up by name by the macro invoked below;
    // confirm against the macro definition before renaming.
    async fn get_store() -> Result<impl StateStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    // Instantiates the `statestore_integration_tests!` suite in this module.
    statestore_integration_tests!();
}