matrix_sdk_base/store/memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, BTreeSet, HashMap},
17    sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
23use ruma::{
24    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
25    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
26    canonical_json::{RedactedBecause, redact},
27    events::{
28        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30        presence::PresenceEvent,
31        receipt::{Receipt, ReceiptThread, ReceiptType},
32        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
33    },
34    serde::Raw,
35    time::Instant,
36};
37use tracing::{debug, instrument, warn};
38
39use super::{
40    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
41    RoomLoadSettings, StateChanges, StateStore, StoreError,
42    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
43    traits::{ComposerDraft, ServerInfo},
44};
45use crate::{
46    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
47    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
48    store::{
49        QueueWedgeError, StoredThreadSubscription,
50        traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps},
51    },
52};
53
/// The mutable contents of a [`MemoryStore`], kept behind a single `RwLock`.
///
/// Nothing in here is ever written to disk.
#[derive(Debug, Default)]
// The nested map types below are spelled out in full on purpose so the shape
// of every index is visible at the field itself; that trips clippy's
// `type_complexity` lint, hence the allow.
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// Recently visited rooms, per user.
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// Composer drafts, keyed by room and optional thread root event ID.
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// Avatar URL stored per user through the key-value API.
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// The most recently saved sync token, if any.
    sync_token: Option<String>,
    /// Cached server information, if any.
    server_info: Option<ServerInfo>,
    /// Filters, keyed by filter name.
    filters: HashMap<String, String>,
    /// Data of the UTD hook manager, stored as a growable bloom filter.
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Set by `StateStoreDataKey::OneTimeKeyAlreadyUploaded`, cleared when that key is removed.
    one_time_key_uploaded_error: bool,
    /// Global account data, keyed by event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// Minimal member events, per room and user.
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// Per room: display name -> every user currently using that name.
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Membership per room and user, derived from sync-state member events.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Room infos, keyed by room ID.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// Sync state events: room -> event type -> state key -> raw event.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// Room account data: room -> event type -> raw event.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// Stripped state events, laid out like `room_state`.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// Membership per room and user, derived from stripped member events.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Presence events, keyed by user.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// Receipts indexed by user: room -> (receipt type string, thread string) -> user ->
    /// (event the receipt points at, receipt). Holds one entry per user.
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,
    /// Reverse receipt index: room -> (receipt type string, thread string) -> event ->
    /// user -> receipt. `save_changes` keeps it in step with `room_user_receipts`.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Arbitrary byte key-value pairs used by the `*_custom_value` methods.
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Queued send requests, per room; new requests are appended at the end.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Requests that depend on other queued requests, per room.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// Knock requests already seen: room -> event ID -> user ID.
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
    /// Thread subscriptions: room -> event ID -> stored subscription.
    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, StoredThreadSubscription>>,
    /// Catch-up tokens for thread subscriptions, if any have been stored.
    thread_subscriptions_catchup_tokens: Option<Vec<ThreadSubscriptionCatchupToken>>,
}
93
/// In-memory, non-persistent implementation of the `StateStore`.
///
/// Default if no other is configured at startup.
///
/// All data lives in one `MemoryStoreInner` behind a `std::sync::RwLock`.
/// Read accessors take the shared lock and mutators take the exclusive lock.
/// Lock poisoning is not handled: the accessors `unwrap()` the lock result and
/// will therefore panic if a previous lock holder panicked.
#[derive(Debug, Default)]
pub struct MemoryStore {
    /// All mutable store state.
    inner: RwLock<MemoryStoreInner>,
}
101
102impl MemoryStore {
103    /// Create a new empty MemoryStore
104    pub fn new() -> Self {
105        Self::default()
106    }
107
108    fn get_user_room_receipt_event_impl(
109        &self,
110        room_id: &RoomId,
111        receipt_type: ReceiptType,
112        thread: ReceiptThread,
113        user_id: &UserId,
114    ) -> Option<(OwnedEventId, Receipt)> {
115        self.inner
116            .read()
117            .unwrap()
118            .room_user_receipts
119            .get(room_id)?
120            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
121            .get(user_id)
122            .cloned()
123    }
124
125    fn get_event_room_receipt_events_impl(
126        &self,
127        room_id: &RoomId,
128        receipt_type: ReceiptType,
129        thread: ReceiptThread,
130        event_id: &EventId,
131    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
132        Some(
133            self.inner
134                .read()
135                .unwrap()
136                .room_event_receipts
137                .get(room_id)?
138                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
139                .get(event_id)?
140                .iter()
141                .map(|(key, value)| (key.clone(), value.clone()))
142                .collect(),
143        )
144    }
145}
146
147#[cfg_attr(target_family = "wasm", async_trait(?Send))]
148#[cfg_attr(not(target_family = "wasm"), async_trait)]
149impl StateStore for MemoryStore {
150    type Error = StoreError;
151
152    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
153        let inner = self.inner.read().unwrap();
154
155        Ok(match key {
156            StateStoreDataKey::SyncToken => {
157                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
158            }
159            StateStoreDataKey::ServerInfo => {
160                inner.server_info.clone().map(StateStoreDataValue::ServerInfo)
161            }
162            StateStoreDataKey::Filter(filter_name) => {
163                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
164            }
165            StateStoreDataKey::UserAvatarUrl(user_id) => {
166                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
167            }
168            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
169                .recently_visited_rooms
170                .get(user_id)
171                .cloned()
172                .map(StateStoreDataValue::RecentlyVisitedRooms),
173            StateStoreDataKey::UtdHookManagerData => {
174                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
175            }
176            StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
177                .one_time_key_uploaded_error
178                .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
179            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
180                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
181                inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
182            }
183            StateStoreDataKey::SeenKnockRequests(room_id) => inner
184                .seen_knock_requests
185                .get(room_id)
186                .cloned()
187                .map(StateStoreDataValue::SeenKnockRequests),
188            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
189                .thread_subscriptions_catchup_tokens
190                .clone()
191                .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
192        })
193    }
194
195    async fn set_kv_data(
196        &self,
197        key: StateStoreDataKey<'_>,
198        value: StateStoreDataValue,
199    ) -> Result<()> {
200        let mut inner = self.inner.write().unwrap();
201        match key {
202            StateStoreDataKey::SyncToken => {
203                inner.sync_token =
204                    Some(value.into_sync_token().expect("Session data not a sync token"))
205            }
206            StateStoreDataKey::Filter(filter_name) => {
207                inner.filters.insert(
208                    filter_name.to_owned(),
209                    value.into_filter().expect("Session data not a filter"),
210                );
211            }
212            StateStoreDataKey::UserAvatarUrl(user_id) => {
213                inner.user_avatar_url.insert(
214                    user_id.to_owned(),
215                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
216                );
217            }
218            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
219                inner.recently_visited_rooms.insert(
220                    user_id.to_owned(),
221                    value
222                        .into_recently_visited_rooms()
223                        .expect("Session data not a list of recently visited rooms"),
224                );
225            }
226            StateStoreDataKey::UtdHookManagerData => {
227                inner.utd_hook_manager_data = Some(
228                    value
229                        .into_utd_hook_manager_data()
230                        .expect("Session data not the hook manager data"),
231                );
232            }
233            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
234                inner.one_time_key_uploaded_error = true;
235            }
236            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
237                inner.composer_drafts.insert(
238                    (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
239                    value.into_composer_draft().expect("Session data not a composer draft"),
240                );
241            }
242            StateStoreDataKey::ServerInfo => {
243                inner.server_info = Some(
244                    value.into_server_info().expect("Session data not containing server info"),
245                );
246            }
247            StateStoreDataKey::SeenKnockRequests(room_id) => {
248                inner.seen_knock_requests.insert(
249                    room_id.to_owned(),
250                    value
251                        .into_seen_knock_requests()
252                        .expect("Session data is not a set of seen join request ids"),
253                );
254            }
255            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
256                inner.thread_subscriptions_catchup_tokens =
257                    Some(value.into_thread_subscriptions_catchup_tokens().expect(
258                        "Session data is not a list of thread subscription catchup tokens",
259                    ));
260            }
261        }
262
263        Ok(())
264    }
265
266    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
267        let mut inner = self.inner.write().unwrap();
268        match key {
269            StateStoreDataKey::SyncToken => inner.sync_token = None,
270            StateStoreDataKey::ServerInfo => inner.server_info = None,
271            StateStoreDataKey::Filter(filter_name) => {
272                inner.filters.remove(filter_name);
273            }
274            StateStoreDataKey::UserAvatarUrl(user_id) => {
275                inner.user_avatar_url.remove(user_id);
276            }
277            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
278                inner.recently_visited_rooms.remove(user_id);
279            }
280            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
281            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
282                inner.one_time_key_uploaded_error = false
283            }
284            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
285                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
286                inner.composer_drafts.remove(&key);
287            }
288            StateStoreDataKey::SeenKnockRequests(room_id) => {
289                inner.seen_knock_requests.remove(room_id);
290            }
291            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
292                inner.thread_subscriptions_catchup_tokens = None;
293            }
294        }
295        Ok(())
296    }
297
    /// Applies a batch of [`StateChanges`] while holding the write lock.
    ///
    /// Sections are applied in this order: sync token, profile deletions,
    /// profiles, display-name ambiguity maps, global account data, room
    /// account data, sync state (plus the derived member map), room infos,
    /// presence, stripped state (plus the derived stripped member map),
    /// receipts, and finally redactions of stored sync state.
    ///
    /// Member events that fail to deserialize are logged at debug level and
    /// skipped rather than treated as errors.
    ///
    /// # Errors
    ///
    /// Only the redaction step can fail: turning a stored event into canonical
    /// JSON, parsing the redaction event, redacting, or re-serializing.
    /// Everything applied before such a failure stays applied; nothing is
    /// rolled back.
    #[instrument(skip(self, changes))]
    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
        let now = Instant::now();

        let mut inner = self.inner.write().unwrap();

        if let Some(s) = &changes.sync_token {
            inner.sync_token = Some(s.to_owned());
        }

        // Deletions run before insertions, so a profile listed in both
        // `profiles_to_delete` and `profiles` ends up stored.
        for (room, users) in &changes.profiles_to_delete {
            let Some(room_profiles) = inner.profiles.get_mut(room) else {
                continue;
            };
            for user in users {
                room_profiles.remove(user);
            }
        }

        for (room, users) in &changes.profiles {
            for (user_id, profile) in users {
                inner
                    .profiles
                    .entry(room.clone())
                    .or_default()
                    .insert(user_id.clone(), profile.clone());
            }
        }

        // Each entry replaces the whole set of users sharing that display name.
        for (room, map) in &changes.ambiguity_maps {
            for (display_name, display_names) in map {
                inner
                    .display_names
                    .entry(room.clone())
                    .or_default()
                    .insert(display_name.clone(), display_names.clone());
            }
        }

        for (event_type, event) in &changes.account_data {
            inner.account_data.insert(event_type.clone(), event.clone());
        }

        for (room, events) in &changes.room_account_data {
            for (event_type, event) in events {
                inner
                    .room_account_data
                    .entry(room.clone())
                    .or_default()
                    .insert(event_type.clone(), event.clone());
            }
        }

        for (room, event_types) in &changes.state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());
                    // Storing sync state for a room drops any stripped state held for it.
                    inner.stripped_room_state.remove(room);

                    if *event_type == StateEventType::RoomMember {
                        // The raw event was already stored above; a member event that
                        // fails to deserialize only skips the membership update.
                        let event =
                            match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(event_id, "Failed to deserialize member event: {e}");
                                    continue;
                                }
                            };

                        // Sync membership supersedes the stripped membership for this room.
                        inner.stripped_members.remove(room);

                        inner
                            .members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key().to_owned(), event.membership().clone());
                    }
                }
            }
        }

        for (room_id, info) in &changes.room_infos {
            inner.room_info.insert(room_id.clone(), info.clone());
        }

        for (sender, event) in &changes.presence {
            inner.presence.insert(sender.clone(), event.clone());
        }

        for (room, event_types) in &changes.stripped_state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .stripped_room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());

                    if *event_type == StateEventType::RoomMember {
                        let event =
                            match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(
                                        event_id,
                                        "Failed to deserialize stripped member event: {e}"
                                    );
                                    continue;
                                }
                            };

                        inner
                            .stripped_members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key, event.content.membership.clone());
                    }
                }
            }
        }

        // Two receipt indexes are maintained together: per user
        // (`room_user_receipts`) and per event (`room_event_receipts`). When a
        // user's receipt moves to a new event, the user is unlinked from the
        // event the old receipt pointed at.
        for (room, content) in &changes.receipts {
            for (event_id, receipts) in &content.0 {
                for (receipt_type, receipts) in receipts {
                    for (user_id, receipt) in receipts {
                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
                        // Add the receipt to the room user receipts
                        if let Some((old_event, _)) = inner
                            .room_user_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread.clone()))
                            .or_default()
                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
                        {
                            // Remove the old receipt from the room event receipts
                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
                                && let Some(event_map) =
                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
                                && let Some(user_map) = event_map.get_mut(&old_event)
                            {
                                user_map.remove(user_id);
                            }
                        }

                        // Add the receipt to the room event receipts
                        inner
                            .room_event_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread))
                            .or_default()
                            .entry(event_id.clone())
                            .or_default()
                            .insert(user_id.clone(), receipt.clone());
                    }
                }
            }
        }

        // Redaction rules come from the stored room info. When no `RoomInfo` is
        // stored for the room, this falls back to the fallback room version's
        // rules and logs a warning.
        let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
            room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
                warn!(
                    ?room_id,
                    "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
                );
                ROOM_VERSION_RULES_FALLBACK
            }).redaction
        };

        // Reborrow the guard as a plain `&mut MemoryStoreInner`. This lets
        // `room_state` (mutably) and `room_info` (shared, inside the closure
        // below) be borrowed at the same time as disjoint fields.
        let inner = &mut *inner;
        for (room_id, redactions) in &changes.redactions {
            // Computed lazily, at most once per room, and only if some stored
            // event actually needs redacting.
            let mut redaction_rules = None;

            // Only sync state is redacted here; stripped state is not touched.
            if let Some(room) = inner.room_state.get_mut(room_id) {
                for ref_room_mu in room.values_mut() {
                    for raw_evt in ref_room_mu.values_mut() {
                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
                            && let Some(redaction) = redactions.get(&event_id)
                        {
                            let redacted = redact(
                                raw_evt.deserialize_as::<CanonicalJsonObject>()?,
                                redaction_rules.get_or_insert_with(|| {
                                    make_redaction_rules(&inner.room_info, room_id)
                                }),
                                Some(RedactedBecause::from_raw_event(redaction)?),
                            )
                            .map_err(StoreError::Redaction)?;
                            // Replace the stored event in place with its redacted form.
                            *raw_evt = Raw::new(&redacted)?.cast_unchecked();
                        }
                    }
                }
            }
        }

        debug!("Saved changes in {:?}", now.elapsed());

        Ok(())
    }
509
510    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
511        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
512    }
513
514    async fn get_presence_events(
515        &self,
516        user_ids: &[OwnedUserId],
517    ) -> Result<Vec<Raw<PresenceEvent>>> {
518        let presence = &self.inner.read().unwrap().presence;
519        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
520    }
521
522    async fn get_state_event(
523        &self,
524        room_id: &RoomId,
525        event_type: StateEventType,
526        state_key: &str,
527    ) -> Result<Option<RawAnySyncOrStrippedState>> {
528        Ok(self
529            .get_state_events_for_keys(room_id, event_type, &[state_key])
530            .await?
531            .into_iter()
532            .next())
533    }
534
535    async fn get_state_events(
536        &self,
537        room_id: &RoomId,
538        event_type: StateEventType,
539    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
540        fn get_events<T>(
541            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
542            room_id: &RoomId,
543            event_type: &StateEventType,
544            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
545        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
546            let state_events = state_map.get(room_id)?.get(event_type)?;
547            Some(state_events.values().cloned().map(to_enum).collect())
548        }
549
550        let inner = self.inner.read().unwrap();
551        Ok(get_events(
552            &inner.stripped_room_state,
553            room_id,
554            &event_type,
555            RawAnySyncOrStrippedState::Stripped,
556        )
557        .or_else(|| {
558            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
559        })
560        .unwrap_or_default())
561    }
562
563    async fn get_state_events_for_keys(
564        &self,
565        room_id: &RoomId,
566        event_type: StateEventType,
567        state_keys: &[&str],
568    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
569        let inner = self.inner.read().unwrap();
570
571        if let Some(stripped_state_events) =
572            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
573        {
574            Ok(state_keys
575                .iter()
576                .filter_map(|k| {
577                    stripped_state_events
578                        .get(*k)
579                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
580                })
581                .collect())
582        } else if let Some(sync_state_events) =
583            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
584        {
585            Ok(state_keys
586                .iter()
587                .filter_map(|k| {
588                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
589                })
590                .collect())
591        } else {
592            Ok(Vec::new())
593        }
594    }
595
596    async fn get_profile(
597        &self,
598        room_id: &RoomId,
599        user_id: &UserId,
600    ) -> Result<Option<MinimalRoomMemberEvent>> {
601        Ok(self
602            .inner
603            .read()
604            .unwrap()
605            .profiles
606            .get(room_id)
607            .and_then(|room_profiles| room_profiles.get(user_id))
608            .cloned())
609    }
610
611    async fn get_profiles<'a>(
612        &self,
613        room_id: &RoomId,
614        user_ids: &'a [OwnedUserId],
615    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
616        if user_ids.is_empty() {
617            return Ok(BTreeMap::new());
618        }
619
620        let profiles = &self.inner.read().unwrap().profiles;
621        let Some(room_profiles) = profiles.get(room_id) else {
622            return Ok(BTreeMap::new());
623        };
624
625        Ok(user_ids
626            .iter()
627            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
628            .collect())
629    }
630
631    #[instrument(skip(self, memberships))]
632    async fn get_user_ids(
633        &self,
634        room_id: &RoomId,
635        memberships: RoomMemberships,
636    ) -> Result<Vec<OwnedUserId>> {
637        /// Get the user IDs for the given room with the given memberships and
638        /// stripped state.
639        ///
640        /// If `memberships` is empty, returns all user IDs in the room with the
641        /// given stripped state.
642        fn get_user_ids_inner(
643            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
644            room_id: &RoomId,
645            memberships: RoomMemberships,
646        ) -> Vec<OwnedUserId> {
647            members
648                .get(room_id)
649                .map(|members| {
650                    members
651                        .iter()
652                        .filter_map(|(user_id, membership)| {
653                            memberships.matches(membership).then_some(user_id)
654                        })
655                        .cloned()
656                        .collect()
657                })
658                .unwrap_or_default()
659        }
660        let inner = self.inner.read().unwrap();
661        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
662        if !v.is_empty() {
663            return Ok(v);
664        }
665        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
666    }
667
668    async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
669        let memory_store_inner = self.inner.read().unwrap();
670        let room_infos = &memory_store_inner.room_info;
671
672        Ok(match room_load_settings {
673            RoomLoadSettings::All => room_infos.values().cloned().collect(),
674
675            RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
676                Some(room_info) => vec![room_info.clone()],
677                None => vec![],
678            },
679        })
680    }
681
682    async fn get_users_with_display_name(
683        &self,
684        room_id: &RoomId,
685        display_name: &DisplayName,
686    ) -> Result<BTreeSet<OwnedUserId>> {
687        Ok(self
688            .inner
689            .read()
690            .unwrap()
691            .display_names
692            .get(room_id)
693            .and_then(|room_names| room_names.get(display_name).cloned())
694            .unwrap_or_default())
695    }
696
697    async fn get_users_with_display_names<'a>(
698        &self,
699        room_id: &RoomId,
700        display_names: &'a [DisplayName],
701    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
702        if display_names.is_empty() {
703            return Ok(HashMap::new());
704        }
705
706        let inner = self.inner.read().unwrap();
707        let Some(room_names) = inner.display_names.get(room_id) else {
708            return Ok(HashMap::new());
709        };
710
711        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
712    }
713
714    async fn get_account_data_event(
715        &self,
716        event_type: GlobalAccountDataEventType,
717    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
718        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
719    }
720
721    async fn get_room_account_data_event(
722        &self,
723        room_id: &RoomId,
724        event_type: RoomAccountDataEventType,
725    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
726        Ok(self
727            .inner
728            .read()
729            .unwrap()
730            .room_account_data
731            .get(room_id)
732            .and_then(|m| m.get(&event_type))
733            .cloned())
734    }
735
736    async fn get_user_room_receipt_event(
737        &self,
738        room_id: &RoomId,
739        receipt_type: ReceiptType,
740        thread: ReceiptThread,
741        user_id: &UserId,
742    ) -> Result<Option<(OwnedEventId, Receipt)>> {
743        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
744    }
745
746    async fn get_event_room_receipt_events(
747        &self,
748        room_id: &RoomId,
749        receipt_type: ReceiptType,
750        thread: ReceiptThread,
751        event_id: &EventId,
752    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
753        Ok(self
754            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
755            .unwrap_or_default())
756    }
757
758    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
759        Ok(self.inner.read().unwrap().custom.get(key).cloned())
760    }
761
762    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
763        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
764    }
765
766    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
767        Ok(self.inner.write().unwrap().custom.remove(key))
768    }
769
770    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
771        let mut inner = self.inner.write().unwrap();
772
773        inner.profiles.remove(room_id);
774        inner.display_names.remove(room_id);
775        inner.members.remove(room_id);
776        inner.room_info.remove(room_id);
777        inner.room_state.remove(room_id);
778        inner.room_account_data.remove(room_id);
779        inner.stripped_room_state.remove(room_id);
780        inner.stripped_members.remove(room_id);
781        inner.room_user_receipts.remove(room_id);
782        inner.room_event_receipts.remove(room_id);
783        inner.send_queue_events.remove(room_id);
784        inner.dependent_send_queue_events.remove(room_id);
785        inner.thread_subscriptions.remove(room_id);
786
787        Ok(())
788    }
789
790    async fn save_send_queue_request(
791        &self,
792        room_id: &RoomId,
793        transaction_id: OwnedTransactionId,
794        created_at: MilliSecondsSinceUnixEpoch,
795        kind: QueuedRequestKind,
796        priority: usize,
797    ) -> Result<(), Self::Error> {
798        self.inner
799            .write()
800            .unwrap()
801            .send_queue_events
802            .entry(room_id.to_owned())
803            .or_default()
804            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
805        Ok(())
806    }
807
808    async fn update_send_queue_request(
809        &self,
810        room_id: &RoomId,
811        transaction_id: &TransactionId,
812        kind: QueuedRequestKind,
813    ) -> Result<bool, Self::Error> {
814        if let Some(entry) = self
815            .inner
816            .write()
817            .unwrap()
818            .send_queue_events
819            .entry(room_id.to_owned())
820            .or_default()
821            .iter_mut()
822            .find(|item| item.transaction_id == transaction_id)
823        {
824            entry.kind = kind;
825            entry.error = None;
826            Ok(true)
827        } else {
828            Ok(false)
829        }
830    }
831
832    async fn remove_send_queue_request(
833        &self,
834        room_id: &RoomId,
835        transaction_id: &TransactionId,
836    ) -> Result<bool, Self::Error> {
837        let mut inner = self.inner.write().unwrap();
838        let q = &mut inner.send_queue_events;
839
840        let entry = q.get_mut(room_id);
841        if let Some(entry) = entry {
842            // Find the event by id in its room queue, and remove it if present.
843            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
844                entry.remove(pos);
845                // And if this was the last event before removal, remove the entire room entry.
846                if entry.is_empty() {
847                    q.remove(room_id);
848                }
849                return Ok(true);
850            }
851        }
852
853        Ok(false)
854    }
855
856    async fn load_send_queue_requests(
857        &self,
858        room_id: &RoomId,
859    ) -> Result<Vec<QueuedRequest>, Self::Error> {
860        let mut ret = self
861            .inner
862            .write()
863            .unwrap()
864            .send_queue_events
865            .entry(room_id.to_owned())
866            .or_default()
867            .clone();
868        // Inverted order of priority, use stable sort to keep insertion order.
869        ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
870        Ok(ret)
871    }
872
873    async fn update_send_queue_request_status(
874        &self,
875        room_id: &RoomId,
876        transaction_id: &TransactionId,
877        error: Option<QueueWedgeError>,
878    ) -> Result<(), Self::Error> {
879        if let Some(entry) = self
880            .inner
881            .write()
882            .unwrap()
883            .send_queue_events
884            .entry(room_id.to_owned())
885            .or_default()
886            .iter_mut()
887            .find(|item| item.transaction_id == transaction_id)
888        {
889            entry.error = error;
890        }
891        Ok(())
892    }
893
894    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
895        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
896    }
897
898    async fn save_dependent_queued_request(
899        &self,
900        room: &RoomId,
901        parent_transaction_id: &TransactionId,
902        own_transaction_id: ChildTransactionId,
903        created_at: MilliSecondsSinceUnixEpoch,
904        content: DependentQueuedRequestKind,
905    ) -> Result<(), Self::Error> {
906        self.inner
907            .write()
908            .unwrap()
909            .dependent_send_queue_events
910            .entry(room.to_owned())
911            .or_default()
912            .push(DependentQueuedRequest {
913                kind: content,
914                parent_transaction_id: parent_transaction_id.to_owned(),
915                own_transaction_id,
916                parent_key: None,
917                created_at,
918            });
919        Ok(())
920    }
921
922    async fn mark_dependent_queued_requests_as_ready(
923        &self,
924        room: &RoomId,
925        parent_txn_id: &TransactionId,
926        sent_parent_key: SentRequestKey,
927    ) -> Result<usize, Self::Error> {
928        let mut inner = self.inner.write().unwrap();
929        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
930        let mut num_updated = 0;
931        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
932            d.parent_key = Some(sent_parent_key.clone());
933            num_updated += 1;
934        }
935        Ok(num_updated)
936    }
937
938    async fn update_dependent_queued_request(
939        &self,
940        room: &RoomId,
941        own_transaction_id: &ChildTransactionId,
942        new_content: DependentQueuedRequestKind,
943    ) -> Result<bool, Self::Error> {
944        let mut inner = self.inner.write().unwrap();
945        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
946        for d in dependents.iter_mut() {
947            if d.own_transaction_id == *own_transaction_id {
948                d.kind = new_content;
949                return Ok(true);
950            }
951        }
952        Ok(false)
953    }
954
955    async fn remove_dependent_queued_request(
956        &self,
957        room: &RoomId,
958        txn_id: &ChildTransactionId,
959    ) -> Result<bool, Self::Error> {
960        let mut inner = self.inner.write().unwrap();
961        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
962        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
963            dependents.remove(pos);
964            Ok(true)
965        } else {
966            Ok(false)
967        }
968    }
969
970    async fn load_dependent_queued_requests(
971        &self,
972        room: &RoomId,
973    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
974        Ok(self
975            .inner
976            .read()
977            .unwrap()
978            .dependent_send_queue_events
979            .get(room)
980            .cloned()
981            .unwrap_or_default())
982    }
983
984    async fn upsert_thread_subscription(
985        &self,
986        room: &RoomId,
987        thread_id: &EventId,
988        mut new: StoredThreadSubscription,
989    ) -> Result<(), Self::Error> {
990        let mut inner = self.inner.write().unwrap();
991        let room_subs = inner.thread_subscriptions.entry(room.to_owned()).or_default();
992
993        if let Some(previous) = room_subs.get(thread_id) {
994            // Nothing to do.
995            if *previous == new {
996                return Ok(());
997            }
998            if !compare_thread_subscription_bump_stamps(previous.bump_stamp, &mut new.bump_stamp) {
999                return Ok(());
1000            }
1001        }
1002
1003        room_subs.insert(thread_id.to_owned(), new);
1004
1005        Ok(())
1006    }
1007
1008    async fn load_thread_subscription(
1009        &self,
1010        room: &RoomId,
1011        thread_id: &EventId,
1012    ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1013        let inner = self.inner.read().unwrap();
1014        Ok(inner
1015            .thread_subscriptions
1016            .get(room)
1017            .and_then(|subscriptions| subscriptions.get(thread_id))
1018            .copied())
1019    }
1020
1021    async fn remove_thread_subscription(
1022        &self,
1023        room: &RoomId,
1024        thread_id: &EventId,
1025    ) -> Result<(), Self::Error> {
1026        let mut inner = self.inner.write().unwrap();
1027
1028        let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1029            return Ok(());
1030        };
1031
1032        room_subs.remove(thread_id);
1033
1034        if room_subs.is_empty() {
1035            // If there are no more subscriptions for this room, remove the room entry.
1036            inner.thread_subscriptions.remove(room);
1037        }
1038
1039        Ok(())
1040    }
1041}
1042
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    // NOTE(review): presumably `statestore_integration_tests!` calls `get_store` to obtain the
    // store each generated test runs against — confirm against the macro definition.
    /// Builds the store under test: a new `MemoryStore`.
    async fn get_store() -> Result<impl StateStore> {
        Ok(MemoryStore::new())
    }

    // Presumably expands to the shared `StateStore` integration test suite, run against the
    // store returned by `get_store` above.
    statestore_integration_tests!();
}