// matrix_sdk_base/store/memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, BTreeSet, HashMap},
17    sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
23use ruma::{
24    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
25    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
26    canonical_json::{RedactedBecause, redact},
27    events::{
28        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30        presence::PresenceEvent,
31        receipt::{Receipt, ReceiptThread, ReceiptType},
32        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
33    },
34    serde::Raw,
35    time::Instant,
36};
37use tracing::{debug, instrument, warn};
38
39use super::{
40    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
41    RoomLoadSettings, StateChanges, StateStore, StoreError,
42    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
43    traits::{ComposerDraft, ServerInfo},
44};
45use crate::{
46    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
47    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
48    store::{QueueWedgeError, ThreadSubscription},
49};
50
/// The mutable state behind a `MemoryStore`.
///
/// Every field is a plain in-memory collection; nothing here is persisted.
#[derive(Debug, Default)]
// The nested map types below are intentionally spelled out in full.
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// Recently visited rooms per user (`StateStoreDataKey::RecentlyVisitedRooms`).
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// Composer drafts keyed by room and optional thread root
    /// (`StateStoreDataKey::ComposerDraft`).
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// Avatar URL per user (`StateStoreDataKey::UserAvatarUrl`).
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// Latest sync token, set through `StateStoreDataKey::SyncToken` or by
    /// `save_changes`.
    sync_token: Option<String>,
    /// Value of `StateStoreDataKey::ServerInfo`.
    server_info: Option<ServerInfo>,
    /// Filters keyed by filter name (`StateStoreDataKey::Filter`).
    filters: HashMap<String, String>,
    /// Value of `StateStoreDataKey::UtdHookManagerData`.
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Whether `StateStoreDataKey::OneTimeKeyAlreadyUploaded` is set; the key
    /// carries no payload, only this flag.
    one_time_key_uploaded_error: bool,
    /// Global account data: the latest event per event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// Member profiles: room -> user -> minimal member event.
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// Display-name ambiguity maps: room -> display name -> users using it.
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Membership taken from regular (non-stripped) member events:
    /// room -> user -> membership.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Room info per room.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// Regular state: room -> event type -> state key -> raw event. Redactions
    /// saved through `save_changes` rewrite these events in place.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// Room account data: room -> event type -> latest event.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// Stripped state, laid out like `room_state`. A room's entry is dropped
    /// when regular state is saved for that room.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// Membership taken from stripped member events; a room's entry is dropped
    /// when a regular member event is saved for that room.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Latest presence event per user.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// Each user's latest receipt: room -> (receipt type, thread) -> user ->
    /// (receipted event ID, receipt). The key pair is the `ReceiptType` and the
    /// `ReceiptThread` rendered as strings.
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,
    /// Reverse index of `room_user_receipts`: room -> (receipt type, thread) ->
    /// event -> user -> receipt.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Arbitrary key/value bytes (`get_custom_value` / `set_custom_value`).
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Queued send requests per room, kept in insertion order.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Dependent queued requests per room.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// Knock requests marked as seen: room -> event ID -> user ID
    /// (`StateStoreDataKey::SeenKnockRequests`).
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
    /// Thread subscriptions: room -> event ID (presumably the thread root —
    /// confirm against callers) -> subscription.
    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
}
89
/// In-memory, non-persistent implementation of the `StateStore`.
///
/// Default if no other is configured at startup.
///
/// All data lives in a single inner state value behind a
/// `std::sync::RwLock`: reads take the shared lock, writes the exclusive one.
/// Every access `unwrap()`s the lock, so a poisoned lock makes the store panic.
#[derive(Debug, Default)]
pub struct MemoryStore {
    /// The store's entire state, guarded by one lock.
    inner: RwLock<MemoryStoreInner>,
}
97
98impl MemoryStore {
99    /// Create a new empty MemoryStore
100    pub fn new() -> Self {
101        Self::default()
102    }
103
104    fn get_user_room_receipt_event_impl(
105        &self,
106        room_id: &RoomId,
107        receipt_type: ReceiptType,
108        thread: ReceiptThread,
109        user_id: &UserId,
110    ) -> Option<(OwnedEventId, Receipt)> {
111        self.inner
112            .read()
113            .unwrap()
114            .room_user_receipts
115            .get(room_id)?
116            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
117            .get(user_id)
118            .cloned()
119    }
120
121    fn get_event_room_receipt_events_impl(
122        &self,
123        room_id: &RoomId,
124        receipt_type: ReceiptType,
125        thread: ReceiptThread,
126        event_id: &EventId,
127    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
128        Some(
129            self.inner
130                .read()
131                .unwrap()
132                .room_event_receipts
133                .get(room_id)?
134                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
135                .get(event_id)?
136                .iter()
137                .map(|(key, value)| (key.clone(), value.clone()))
138                .collect(),
139        )
140    }
141}
142
143#[cfg_attr(target_family = "wasm", async_trait(?Send))]
144#[cfg_attr(not(target_family = "wasm"), async_trait)]
145impl StateStore for MemoryStore {
146    type Error = StoreError;
147
148    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
149        let inner = self.inner.read().unwrap();
150
151        Ok(match key {
152            StateStoreDataKey::SyncToken => {
153                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
154            }
155            StateStoreDataKey::ServerInfo => {
156                inner.server_info.clone().map(StateStoreDataValue::ServerInfo)
157            }
158            StateStoreDataKey::Filter(filter_name) => {
159                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
160            }
161            StateStoreDataKey::UserAvatarUrl(user_id) => {
162                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
163            }
164            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
165                .recently_visited_rooms
166                .get(user_id)
167                .cloned()
168                .map(StateStoreDataValue::RecentlyVisitedRooms),
169            StateStoreDataKey::UtdHookManagerData => {
170                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
171            }
172            StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
173                .one_time_key_uploaded_error
174                .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
175            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
176                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
177                inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
178            }
179            StateStoreDataKey::SeenKnockRequests(room_id) => inner
180                .seen_knock_requests
181                .get(room_id)
182                .cloned()
183                .map(StateStoreDataValue::SeenKnockRequests),
184        })
185    }
186
187    async fn set_kv_data(
188        &self,
189        key: StateStoreDataKey<'_>,
190        value: StateStoreDataValue,
191    ) -> Result<()> {
192        let mut inner = self.inner.write().unwrap();
193        match key {
194            StateStoreDataKey::SyncToken => {
195                inner.sync_token =
196                    Some(value.into_sync_token().expect("Session data not a sync token"))
197            }
198            StateStoreDataKey::Filter(filter_name) => {
199                inner.filters.insert(
200                    filter_name.to_owned(),
201                    value.into_filter().expect("Session data not a filter"),
202                );
203            }
204            StateStoreDataKey::UserAvatarUrl(user_id) => {
205                inner.user_avatar_url.insert(
206                    user_id.to_owned(),
207                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
208                );
209            }
210            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
211                inner.recently_visited_rooms.insert(
212                    user_id.to_owned(),
213                    value
214                        .into_recently_visited_rooms()
215                        .expect("Session data not a list of recently visited rooms"),
216                );
217            }
218            StateStoreDataKey::UtdHookManagerData => {
219                inner.utd_hook_manager_data = Some(
220                    value
221                        .into_utd_hook_manager_data()
222                        .expect("Session data not the hook manager data"),
223                );
224            }
225            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
226                inner.one_time_key_uploaded_error = true;
227            }
228            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
229                inner.composer_drafts.insert(
230                    (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
231                    value.into_composer_draft().expect("Session data not a composer draft"),
232                );
233            }
234            StateStoreDataKey::ServerInfo => {
235                inner.server_info = Some(
236                    value.into_server_info().expect("Session data not containing server info"),
237                );
238            }
239            StateStoreDataKey::SeenKnockRequests(room_id) => {
240                inner.seen_knock_requests.insert(
241                    room_id.to_owned(),
242                    value
243                        .into_seen_knock_requests()
244                        .expect("Session data is not a set of seen join request ids"),
245                );
246            }
247        }
248
249        Ok(())
250    }
251
252    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
253        let mut inner = self.inner.write().unwrap();
254        match key {
255            StateStoreDataKey::SyncToken => inner.sync_token = None,
256            StateStoreDataKey::ServerInfo => inner.server_info = None,
257            StateStoreDataKey::Filter(filter_name) => {
258                inner.filters.remove(filter_name);
259            }
260            StateStoreDataKey::UserAvatarUrl(user_id) => {
261                inner.user_avatar_url.remove(user_id);
262            }
263            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
264                inner.recently_visited_rooms.remove(user_id);
265            }
266            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
267            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
268                inner.one_time_key_uploaded_error = false
269            }
270            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
271                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
272                inner.composer_drafts.remove(&key);
273            }
274            StateStoreDataKey::SeenKnockRequests(room_id) => {
275                inner.seen_knock_requests.remove(room_id);
276            }
277        }
278        Ok(())
279    }
280
281    #[instrument(skip(self, changes))]
282    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
283        let now = Instant::now();
284
285        let mut inner = self.inner.write().unwrap();
286
287        if let Some(s) = &changes.sync_token {
288            inner.sync_token = Some(s.to_owned());
289        }
290
291        for (room, users) in &changes.profiles_to_delete {
292            let Some(room_profiles) = inner.profiles.get_mut(room) else {
293                continue;
294            };
295            for user in users {
296                room_profiles.remove(user);
297            }
298        }
299
300        for (room, users) in &changes.profiles {
301            for (user_id, profile) in users {
302                inner
303                    .profiles
304                    .entry(room.clone())
305                    .or_default()
306                    .insert(user_id.clone(), profile.clone());
307            }
308        }
309
310        for (room, map) in &changes.ambiguity_maps {
311            for (display_name, display_names) in map {
312                inner
313                    .display_names
314                    .entry(room.clone())
315                    .or_default()
316                    .insert(display_name.clone(), display_names.clone());
317            }
318        }
319
320        for (event_type, event) in &changes.account_data {
321            inner.account_data.insert(event_type.clone(), event.clone());
322        }
323
324        for (room, events) in &changes.room_account_data {
325            for (event_type, event) in events {
326                inner
327                    .room_account_data
328                    .entry(room.clone())
329                    .or_default()
330                    .insert(event_type.clone(), event.clone());
331            }
332        }
333
334        for (room, event_types) in &changes.state {
335            for (event_type, events) in event_types {
336                for (state_key, raw_event) in events {
337                    inner
338                        .room_state
339                        .entry(room.clone())
340                        .or_default()
341                        .entry(event_type.clone())
342                        .or_default()
343                        .insert(state_key.to_owned(), raw_event.clone());
344                    inner.stripped_room_state.remove(room);
345
346                    if *event_type == StateEventType::RoomMember {
347                        let event =
348                            match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
349                                Ok(ev) => ev,
350                                Err(e) => {
351                                    let event_id: Option<String> =
352                                        raw_event.get_field("event_id").ok().flatten();
353                                    debug!(event_id, "Failed to deserialize member event: {e}");
354                                    continue;
355                                }
356                            };
357
358                        inner.stripped_members.remove(room);
359
360                        inner
361                            .members
362                            .entry(room.clone())
363                            .or_default()
364                            .insert(event.state_key().to_owned(), event.membership().clone());
365                    }
366                }
367            }
368        }
369
370        for (room_id, info) in &changes.room_infos {
371            inner.room_info.insert(room_id.clone(), info.clone());
372        }
373
374        for (sender, event) in &changes.presence {
375            inner.presence.insert(sender.clone(), event.clone());
376        }
377
378        for (room, event_types) in &changes.stripped_state {
379            for (event_type, events) in event_types {
380                for (state_key, raw_event) in events {
381                    inner
382                        .stripped_room_state
383                        .entry(room.clone())
384                        .or_default()
385                        .entry(event_type.clone())
386                        .or_default()
387                        .insert(state_key.to_owned(), raw_event.clone());
388
389                    if *event_type == StateEventType::RoomMember {
390                        let event =
391                            match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
392                                Ok(ev) => ev,
393                                Err(e) => {
394                                    let event_id: Option<String> =
395                                        raw_event.get_field("event_id").ok().flatten();
396                                    debug!(
397                                        event_id,
398                                        "Failed to deserialize stripped member event: {e}"
399                                    );
400                                    continue;
401                                }
402                            };
403
404                        inner
405                            .stripped_members
406                            .entry(room.clone())
407                            .or_default()
408                            .insert(event.state_key, event.content.membership.clone());
409                    }
410                }
411            }
412        }
413
414        for (room, content) in &changes.receipts {
415            for (event_id, receipts) in &content.0 {
416                for (receipt_type, receipts) in receipts {
417                    for (user_id, receipt) in receipts {
418                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
419                        // Add the receipt to the room user receipts
420                        if let Some((old_event, _)) = inner
421                            .room_user_receipts
422                            .entry(room.clone())
423                            .or_default()
424                            .entry((receipt_type.to_string(), thread.clone()))
425                            .or_default()
426                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
427                        {
428                            // Remove the old receipt from the room event receipts
429                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
430                                && let Some(event_map) =
431                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
432                                && let Some(user_map) = event_map.get_mut(&old_event)
433                            {
434                                user_map.remove(user_id);
435                            }
436                        }
437
438                        // Add the receipt to the room event receipts
439                        inner
440                            .room_event_receipts
441                            .entry(room.clone())
442                            .or_default()
443                            .entry((receipt_type.to_string(), thread))
444                            .or_default()
445                            .entry(event_id.clone())
446                            .or_default()
447                            .insert(user_id.clone(), receipt.clone());
448                    }
449                }
450            }
451        }
452
453        let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
454            room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
455                warn!(
456                    ?room_id,
457                    "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
458                );
459                ROOM_VERSION_RULES_FALLBACK
460            }).redaction
461        };
462
463        let inner = &mut *inner;
464        for (room_id, redactions) in &changes.redactions {
465            let mut redaction_rules = None;
466
467            if let Some(room) = inner.room_state.get_mut(room_id) {
468                for ref_room_mu in room.values_mut() {
469                    for raw_evt in ref_room_mu.values_mut() {
470                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
471                            && let Some(redaction) = redactions.get(&event_id)
472                        {
473                            let redacted = redact(
474                                raw_evt.deserialize_as::<CanonicalJsonObject>()?,
475                                redaction_rules.get_or_insert_with(|| {
476                                    make_redaction_rules(&inner.room_info, room_id)
477                                }),
478                                Some(RedactedBecause::from_raw_event(redaction)?),
479                            )
480                            .map_err(StoreError::Redaction)?;
481                            *raw_evt = Raw::new(&redacted)?.cast_unchecked();
482                        }
483                    }
484                }
485            }
486        }
487
488        debug!("Saved changes in {:?}", now.elapsed());
489
490        Ok(())
491    }
492
493    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
494        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
495    }
496
497    async fn get_presence_events(
498        &self,
499        user_ids: &[OwnedUserId],
500    ) -> Result<Vec<Raw<PresenceEvent>>> {
501        let presence = &self.inner.read().unwrap().presence;
502        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
503    }
504
505    async fn get_state_event(
506        &self,
507        room_id: &RoomId,
508        event_type: StateEventType,
509        state_key: &str,
510    ) -> Result<Option<RawAnySyncOrStrippedState>> {
511        Ok(self
512            .get_state_events_for_keys(room_id, event_type, &[state_key])
513            .await?
514            .into_iter()
515            .next())
516    }
517
518    async fn get_state_events(
519        &self,
520        room_id: &RoomId,
521        event_type: StateEventType,
522    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
523        fn get_events<T>(
524            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
525            room_id: &RoomId,
526            event_type: &StateEventType,
527            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
528        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
529            let state_events = state_map.get(room_id)?.get(event_type)?;
530            Some(state_events.values().cloned().map(to_enum).collect())
531        }
532
533        let inner = self.inner.read().unwrap();
534        Ok(get_events(
535            &inner.stripped_room_state,
536            room_id,
537            &event_type,
538            RawAnySyncOrStrippedState::Stripped,
539        )
540        .or_else(|| {
541            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
542        })
543        .unwrap_or_default())
544    }
545
546    async fn get_state_events_for_keys(
547        &self,
548        room_id: &RoomId,
549        event_type: StateEventType,
550        state_keys: &[&str],
551    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
552        let inner = self.inner.read().unwrap();
553
554        if let Some(stripped_state_events) =
555            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
556        {
557            Ok(state_keys
558                .iter()
559                .filter_map(|k| {
560                    stripped_state_events
561                        .get(*k)
562                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
563                })
564                .collect())
565        } else if let Some(sync_state_events) =
566            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
567        {
568            Ok(state_keys
569                .iter()
570                .filter_map(|k| {
571                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
572                })
573                .collect())
574        } else {
575            Ok(Vec::new())
576        }
577    }
578
579    async fn get_profile(
580        &self,
581        room_id: &RoomId,
582        user_id: &UserId,
583    ) -> Result<Option<MinimalRoomMemberEvent>> {
584        Ok(self
585            .inner
586            .read()
587            .unwrap()
588            .profiles
589            .get(room_id)
590            .and_then(|room_profiles| room_profiles.get(user_id))
591            .cloned())
592    }
593
594    async fn get_profiles<'a>(
595        &self,
596        room_id: &RoomId,
597        user_ids: &'a [OwnedUserId],
598    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
599        if user_ids.is_empty() {
600            return Ok(BTreeMap::new());
601        }
602
603        let profiles = &self.inner.read().unwrap().profiles;
604        let Some(room_profiles) = profiles.get(room_id) else {
605            return Ok(BTreeMap::new());
606        };
607
608        Ok(user_ids
609            .iter()
610            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
611            .collect())
612    }
613
614    #[instrument(skip(self, memberships))]
615    async fn get_user_ids(
616        &self,
617        room_id: &RoomId,
618        memberships: RoomMemberships,
619    ) -> Result<Vec<OwnedUserId>> {
620        /// Get the user IDs for the given room with the given memberships and
621        /// stripped state.
622        ///
623        /// If `memberships` is empty, returns all user IDs in the room with the
624        /// given stripped state.
625        fn get_user_ids_inner(
626            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
627            room_id: &RoomId,
628            memberships: RoomMemberships,
629        ) -> Vec<OwnedUserId> {
630            members
631                .get(room_id)
632                .map(|members| {
633                    members
634                        .iter()
635                        .filter_map(|(user_id, membership)| {
636                            memberships.matches(membership).then_some(user_id)
637                        })
638                        .cloned()
639                        .collect()
640                })
641                .unwrap_or_default()
642        }
643        let inner = self.inner.read().unwrap();
644        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
645        if !v.is_empty() {
646            return Ok(v);
647        }
648        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
649    }
650
651    async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
652        let memory_store_inner = self.inner.read().unwrap();
653        let room_infos = &memory_store_inner.room_info;
654
655        Ok(match room_load_settings {
656            RoomLoadSettings::All => room_infos.values().cloned().collect(),
657
658            RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
659                Some(room_info) => vec![room_info.clone()],
660                None => vec![],
661            },
662        })
663    }
664
665    async fn get_users_with_display_name(
666        &self,
667        room_id: &RoomId,
668        display_name: &DisplayName,
669    ) -> Result<BTreeSet<OwnedUserId>> {
670        Ok(self
671            .inner
672            .read()
673            .unwrap()
674            .display_names
675            .get(room_id)
676            .and_then(|room_names| room_names.get(display_name).cloned())
677            .unwrap_or_default())
678    }
679
680    async fn get_users_with_display_names<'a>(
681        &self,
682        room_id: &RoomId,
683        display_names: &'a [DisplayName],
684    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
685        if display_names.is_empty() {
686            return Ok(HashMap::new());
687        }
688
689        let inner = self.inner.read().unwrap();
690        let Some(room_names) = inner.display_names.get(room_id) else {
691            return Ok(HashMap::new());
692        };
693
694        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
695    }
696
697    async fn get_account_data_event(
698        &self,
699        event_type: GlobalAccountDataEventType,
700    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
701        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
702    }
703
704    async fn get_room_account_data_event(
705        &self,
706        room_id: &RoomId,
707        event_type: RoomAccountDataEventType,
708    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
709        Ok(self
710            .inner
711            .read()
712            .unwrap()
713            .room_account_data
714            .get(room_id)
715            .and_then(|m| m.get(&event_type))
716            .cloned())
717    }
718
719    async fn get_user_room_receipt_event(
720        &self,
721        room_id: &RoomId,
722        receipt_type: ReceiptType,
723        thread: ReceiptThread,
724        user_id: &UserId,
725    ) -> Result<Option<(OwnedEventId, Receipt)>> {
726        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
727    }
728
729    async fn get_event_room_receipt_events(
730        &self,
731        room_id: &RoomId,
732        receipt_type: ReceiptType,
733        thread: ReceiptThread,
734        event_id: &EventId,
735    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
736        Ok(self
737            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
738            .unwrap_or_default())
739    }
740
741    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
742        Ok(self.inner.read().unwrap().custom.get(key).cloned())
743    }
744
745    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
746        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
747    }
748
749    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
750        Ok(self.inner.write().unwrap().custom.remove(key))
751    }
752
753    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
754        let mut inner = self.inner.write().unwrap();
755
756        inner.profiles.remove(room_id);
757        inner.display_names.remove(room_id);
758        inner.members.remove(room_id);
759        inner.room_info.remove(room_id);
760        inner.room_state.remove(room_id);
761        inner.room_account_data.remove(room_id);
762        inner.stripped_room_state.remove(room_id);
763        inner.stripped_members.remove(room_id);
764        inner.room_user_receipts.remove(room_id);
765        inner.room_event_receipts.remove(room_id);
766        inner.send_queue_events.remove(room_id);
767        inner.dependent_send_queue_events.remove(room_id);
768        inner.thread_subscriptions.remove(room_id);
769
770        Ok(())
771    }
772
773    async fn save_send_queue_request(
774        &self,
775        room_id: &RoomId,
776        transaction_id: OwnedTransactionId,
777        created_at: MilliSecondsSinceUnixEpoch,
778        kind: QueuedRequestKind,
779        priority: usize,
780    ) -> Result<(), Self::Error> {
781        self.inner
782            .write()
783            .unwrap()
784            .send_queue_events
785            .entry(room_id.to_owned())
786            .or_default()
787            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
788        Ok(())
789    }
790
791    async fn update_send_queue_request(
792        &self,
793        room_id: &RoomId,
794        transaction_id: &TransactionId,
795        kind: QueuedRequestKind,
796    ) -> Result<bool, Self::Error> {
797        if let Some(entry) = self
798            .inner
799            .write()
800            .unwrap()
801            .send_queue_events
802            .entry(room_id.to_owned())
803            .or_default()
804            .iter_mut()
805            .find(|item| item.transaction_id == transaction_id)
806        {
807            entry.kind = kind;
808            entry.error = None;
809            Ok(true)
810        } else {
811            Ok(false)
812        }
813    }
814
815    async fn remove_send_queue_request(
816        &self,
817        room_id: &RoomId,
818        transaction_id: &TransactionId,
819    ) -> Result<bool, Self::Error> {
820        let mut inner = self.inner.write().unwrap();
821        let q = &mut inner.send_queue_events;
822
823        let entry = q.get_mut(room_id);
824        if let Some(entry) = entry {
825            // Find the event by id in its room queue, and remove it if present.
826            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
827                entry.remove(pos);
828                // And if this was the last event before removal, remove the entire room entry.
829                if entry.is_empty() {
830                    q.remove(room_id);
831                }
832                return Ok(true);
833            }
834        }
835
836        Ok(false)
837    }
838
839    async fn load_send_queue_requests(
840        &self,
841        room_id: &RoomId,
842    ) -> Result<Vec<QueuedRequest>, Self::Error> {
843        let mut ret = self
844            .inner
845            .write()
846            .unwrap()
847            .send_queue_events
848            .entry(room_id.to_owned())
849            .or_default()
850            .clone();
851        // Inverted order of priority, use stable sort to keep insertion order.
852        ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
853        Ok(ret)
854    }
855
856    async fn update_send_queue_request_status(
857        &self,
858        room_id: &RoomId,
859        transaction_id: &TransactionId,
860        error: Option<QueueWedgeError>,
861    ) -> Result<(), Self::Error> {
862        if let Some(entry) = self
863            .inner
864            .write()
865            .unwrap()
866            .send_queue_events
867            .entry(room_id.to_owned())
868            .or_default()
869            .iter_mut()
870            .find(|item| item.transaction_id == transaction_id)
871        {
872            entry.error = error;
873        }
874        Ok(())
875    }
876
877    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
878        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
879    }
880
881    async fn save_dependent_queued_request(
882        &self,
883        room: &RoomId,
884        parent_transaction_id: &TransactionId,
885        own_transaction_id: ChildTransactionId,
886        created_at: MilliSecondsSinceUnixEpoch,
887        content: DependentQueuedRequestKind,
888    ) -> Result<(), Self::Error> {
889        self.inner
890            .write()
891            .unwrap()
892            .dependent_send_queue_events
893            .entry(room.to_owned())
894            .or_default()
895            .push(DependentQueuedRequest {
896                kind: content,
897                parent_transaction_id: parent_transaction_id.to_owned(),
898                own_transaction_id,
899                parent_key: None,
900                created_at,
901            });
902        Ok(())
903    }
904
905    async fn mark_dependent_queued_requests_as_ready(
906        &self,
907        room: &RoomId,
908        parent_txn_id: &TransactionId,
909        sent_parent_key: SentRequestKey,
910    ) -> Result<usize, Self::Error> {
911        let mut inner = self.inner.write().unwrap();
912        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
913        let mut num_updated = 0;
914        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
915            d.parent_key = Some(sent_parent_key.clone());
916            num_updated += 1;
917        }
918        Ok(num_updated)
919    }
920
921    async fn update_dependent_queued_request(
922        &self,
923        room: &RoomId,
924        own_transaction_id: &ChildTransactionId,
925        new_content: DependentQueuedRequestKind,
926    ) -> Result<bool, Self::Error> {
927        let mut inner = self.inner.write().unwrap();
928        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
929        for d in dependents.iter_mut() {
930            if d.own_transaction_id == *own_transaction_id {
931                d.kind = new_content;
932                return Ok(true);
933            }
934        }
935        Ok(false)
936    }
937
938    async fn remove_dependent_queued_request(
939        &self,
940        room: &RoomId,
941        txn_id: &ChildTransactionId,
942    ) -> Result<bool, Self::Error> {
943        let mut inner = self.inner.write().unwrap();
944        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
945        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
946            dependents.remove(pos);
947            Ok(true)
948        } else {
949            Ok(false)
950        }
951    }
952
953    async fn load_dependent_queued_requests(
954        &self,
955        room: &RoomId,
956    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
957        Ok(self
958            .inner
959            .read()
960            .unwrap()
961            .dependent_send_queue_events
962            .get(room)
963            .cloned()
964            .unwrap_or_default())
965    }
966
967    async fn upsert_thread_subscription(
968        &self,
969        room: &RoomId,
970        thread_id: &EventId,
971        subscription: ThreadSubscription,
972    ) -> Result<(), Self::Error> {
973        self.inner
974            .write()
975            .unwrap()
976            .thread_subscriptions
977            .entry(room.to_owned())
978            .or_default()
979            .insert(thread_id.to_owned(), subscription);
980        Ok(())
981    }
982
983    async fn load_thread_subscription(
984        &self,
985        room: &RoomId,
986        thread_id: &EventId,
987    ) -> Result<Option<ThreadSubscription>, Self::Error> {
988        let inner = self.inner.read().unwrap();
989        Ok(inner
990            .thread_subscriptions
991            .get(room)
992            .and_then(|subscriptions| subscriptions.get(thread_id))
993            .copied())
994    }
995
996    async fn remove_thread_subscription(
997        &self,
998        room: &RoomId,
999        thread_id: &EventId,
1000    ) -> Result<(), Self::Error> {
1001        let mut inner = self.inner.write().unwrap();
1002
1003        let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1004            return Ok(());
1005        };
1006
1007        room_subs.remove(thread_id);
1008
1009        if room_subs.is_empty() {
1010            // If there are no more subscriptions for this room, remove the room entry.
1011            inner.thread_subscriptions.remove(room);
1012        }
1013
1014        Ok(())
1015    }
1016}
1017
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    /// Builds a fresh, empty in-memory store for the shared state-store integration tests.
    async fn get_store() -> Result<impl StateStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    // Expands to the common `StateStore` test suite, which obtains stores via `get_store`.
    statestore_integration_tests!();
}