Skip to main content

matrix_sdk_base/store/
memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cmp::Reverse,
17    collections::{BTreeMap, BTreeSet, HashMap},
18    sync::RwLock,
19};
20
21use async_trait::async_trait;
22use growable_bloom_filter::GrowableBloom;
23use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
24use ruma::{
25    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
26    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
27    api::client::discovery::get_capabilities::v3::Capabilities,
28    canonical_json::{RedactedBecause, redact},
29    events::{
30        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
31        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
32        presence::PresenceEvent,
33        receipt::{Receipt, ReceiptThread, ReceiptType},
34        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
35    },
36    serde::Raw,
37    time::Instant,
38};
39use tracing::{debug, instrument, warn};
40
41use super::{
42    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
43    RoomLoadSettings, StateChanges, StateStore, StoreError, SupportedVersionsResponse,
44    TtlStoreValue, WellKnownResponse,
45    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
46    traits::ComposerDraft,
47};
48use crate::{
49    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
50    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
51    store::{
52        QueueWedgeError, StoredThreadSubscription,
53        traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps},
54    },
55};
56
57#[derive(Debug, Default)]
58#[allow(clippy::type_complexity)]
59struct MemoryStoreInner {
60    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
61    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
62    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
63    sync_token: Option<String>,
64    supported_versions: Option<TtlStoreValue<SupportedVersionsResponse>>,
65    well_known: Option<TtlStoreValue<Option<WellKnownResponse>>>,
66    filters: HashMap<String, String>,
67    utd_hook_manager_data: Option<GrowableBloom>,
68    one_time_key_uploaded_error: bool,
69    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
70    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
71    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
72    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
73    room_info: HashMap<OwnedRoomId, RoomInfo>,
74    room_state:
75        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
76    room_account_data:
77        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
78    stripped_room_state:
79        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
80    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
81    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
82    room_user_receipts: HashMap<
83        OwnedRoomId,
84        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
85    >,
86    room_event_receipts: HashMap<
87        OwnedRoomId,
88        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
89    >,
90    custom: HashMap<Vec<u8>, Vec<u8>>,
91    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
92    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
93    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
94    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, StoredThreadSubscription>>,
95    thread_subscriptions_catchup_tokens: Option<Vec<ThreadSubscriptionCatchupToken>>,
96    homeserver_capabilities: Option<Capabilities>,
97}
98
99/// In-memory, non-persistent implementation of the `StateStore`.
100///
101/// Default if no other is configured at startup.
102#[derive(Debug, Default)]
103pub struct MemoryStore {
104    inner: RwLock<MemoryStoreInner>,
105}
106
107impl MemoryStore {
108    /// Create a new empty MemoryStore
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    fn get_user_room_receipt_event_impl(
114        &self,
115        room_id: &RoomId,
116        receipt_type: ReceiptType,
117        thread: ReceiptThread,
118        user_id: &UserId,
119    ) -> Option<(OwnedEventId, Receipt)> {
120        self.inner
121            .read()
122            .unwrap()
123            .room_user_receipts
124            .get(room_id)?
125            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
126            .get(user_id)
127            .cloned()
128    }
129
130    fn get_event_room_receipt_events_impl(
131        &self,
132        room_id: &RoomId,
133        receipt_type: ReceiptType,
134        thread: ReceiptThread,
135        event_id: &EventId,
136    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
137        Some(
138            self.inner
139                .read()
140                .unwrap()
141                .room_event_receipts
142                .get(room_id)?
143                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
144                .get(event_id)?
145                .iter()
146                .map(|(key, value)| (key.clone(), value.clone()))
147                .collect(),
148        )
149    }
150}
151
152#[cfg_attr(target_family = "wasm", async_trait(?Send))]
153#[cfg_attr(not(target_family = "wasm"), async_trait)]
154impl StateStore for MemoryStore {
155    type Error = StoreError;
156
157    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
158        let inner = self.inner.read().unwrap();
159
160        Ok(match key {
161            StateStoreDataKey::SyncToken => {
162                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
163            }
164            StateStoreDataKey::SupportedVersions => {
165                inner.supported_versions.clone().map(StateStoreDataValue::SupportedVersions)
166            }
167            StateStoreDataKey::WellKnown => {
168                inner.well_known.clone().map(StateStoreDataValue::WellKnown)
169            }
170            StateStoreDataKey::Filter(filter_name) => {
171                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
172            }
173            StateStoreDataKey::UserAvatarUrl(user_id) => {
174                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
175            }
176            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
177                .recently_visited_rooms
178                .get(user_id)
179                .cloned()
180                .map(StateStoreDataValue::RecentlyVisitedRooms),
181            StateStoreDataKey::UtdHookManagerData => {
182                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
183            }
184            StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
185                .one_time_key_uploaded_error
186                .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
187            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
188                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
189                inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
190            }
191            StateStoreDataKey::SeenKnockRequests(room_id) => inner
192                .seen_knock_requests
193                .get(room_id)
194                .cloned()
195                .map(StateStoreDataValue::SeenKnockRequests),
196            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
197                .thread_subscriptions_catchup_tokens
198                .clone()
199                .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
200            StateStoreDataKey::HomeserverCapabilities => inner
201                .homeserver_capabilities
202                .clone()
203                .map(StateStoreDataValue::HomeserverCapabilities),
204        })
205    }
206
207    async fn set_kv_data(
208        &self,
209        key: StateStoreDataKey<'_>,
210        value: StateStoreDataValue,
211    ) -> Result<()> {
212        let mut inner = self.inner.write().unwrap();
213        match key {
214            StateStoreDataKey::SyncToken => {
215                inner.sync_token =
216                    Some(value.into_sync_token().expect("Session data not a sync token"))
217            }
218            StateStoreDataKey::Filter(filter_name) => {
219                inner.filters.insert(
220                    filter_name.to_owned(),
221                    value.into_filter().expect("Session data not a filter"),
222                );
223            }
224            StateStoreDataKey::UserAvatarUrl(user_id) => {
225                inner.user_avatar_url.insert(
226                    user_id.to_owned(),
227                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
228                );
229            }
230            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
231                inner.recently_visited_rooms.insert(
232                    user_id.to_owned(),
233                    value
234                        .into_recently_visited_rooms()
235                        .expect("Session data not a list of recently visited rooms"),
236                );
237            }
238            StateStoreDataKey::UtdHookManagerData => {
239                inner.utd_hook_manager_data = Some(
240                    value
241                        .into_utd_hook_manager_data()
242                        .expect("Session data not the hook manager data"),
243                );
244            }
245            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
246                inner.one_time_key_uploaded_error = true;
247            }
248            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
249                inner.composer_drafts.insert(
250                    (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
251                    value.into_composer_draft().expect("Session data not a composer draft"),
252                );
253            }
254            StateStoreDataKey::SupportedVersions => {
255                inner.supported_versions = Some(
256                    value
257                        .into_supported_versions()
258                        .expect("Session data not containing supported versions"),
259                );
260            }
261            StateStoreDataKey::WellKnown => {
262                inner.well_known =
263                    Some(value.into_well_known().expect("Session data not containing well-known"));
264            }
265            StateStoreDataKey::SeenKnockRequests(room_id) => {
266                inner.seen_knock_requests.insert(
267                    room_id.to_owned(),
268                    value
269                        .into_seen_knock_requests()
270                        .expect("Session data is not a set of seen join request ids"),
271                );
272            }
273            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
274                inner.thread_subscriptions_catchup_tokens =
275                    Some(value.into_thread_subscriptions_catchup_tokens().expect(
276                        "Session data is not a list of thread subscription catchup tokens",
277                    ));
278            }
279            StateStoreDataKey::HomeserverCapabilities => {
280                inner.homeserver_capabilities = Some(
281                    value
282                        .into_homeserver_capabilities()
283                        .expect("Session data is not a homeserver capabilities"),
284                );
285            }
286        }
287
288        Ok(())
289    }
290
291    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
292        let mut inner = self.inner.write().unwrap();
293        match key {
294            StateStoreDataKey::SyncToken => inner.sync_token = None,
295            StateStoreDataKey::SupportedVersions => inner.supported_versions = None,
296            StateStoreDataKey::WellKnown => inner.well_known = None,
297            StateStoreDataKey::Filter(filter_name) => {
298                inner.filters.remove(filter_name);
299            }
300            StateStoreDataKey::UserAvatarUrl(user_id) => {
301                inner.user_avatar_url.remove(user_id);
302            }
303            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
304                inner.recently_visited_rooms.remove(user_id);
305            }
306            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
307            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
308                inner.one_time_key_uploaded_error = false
309            }
310            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
311                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
312                inner.composer_drafts.remove(&key);
313            }
314            StateStoreDataKey::SeenKnockRequests(room_id) => {
315                inner.seen_knock_requests.remove(room_id);
316            }
317            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
318                inner.thread_subscriptions_catchup_tokens = None;
319            }
320            StateStoreDataKey::HomeserverCapabilities => inner.homeserver_capabilities = None,
321        }
322        Ok(())
323    }
324
325    #[instrument(skip(self, changes))]
326    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
327        let now = Instant::now();
328
329        let mut inner = self.inner.write().unwrap();
330
331        if let Some(s) = &changes.sync_token {
332            inner.sync_token = Some(s.to_owned());
333        }
334
335        for (room, users) in &changes.profiles_to_delete {
336            let Some(room_profiles) = inner.profiles.get_mut(room) else {
337                continue;
338            };
339            for user in users {
340                room_profiles.remove(user);
341            }
342        }
343
344        for (room, users) in &changes.profiles {
345            for (user_id, profile) in users {
346                inner
347                    .profiles
348                    .entry(room.clone())
349                    .or_default()
350                    .insert(user_id.clone(), profile.clone());
351            }
352        }
353
354        for (room, map) in &changes.ambiguity_maps {
355            for (display_name, display_names) in map {
356                inner
357                    .display_names
358                    .entry(room.clone())
359                    .or_default()
360                    .insert(display_name.clone(), display_names.clone());
361            }
362        }
363
364        for (event_type, event) in &changes.account_data {
365            inner.account_data.insert(event_type.clone(), event.clone());
366        }
367
368        for (room, events) in &changes.room_account_data {
369            for (event_type, event) in events {
370                inner
371                    .room_account_data
372                    .entry(room.clone())
373                    .or_default()
374                    .insert(event_type.clone(), event.clone());
375            }
376        }
377
378        for (room, event_types) in &changes.state {
379            for (event_type, events) in event_types {
380                for (state_key, raw_event) in events {
381                    inner
382                        .room_state
383                        .entry(room.clone())
384                        .or_default()
385                        .entry(event_type.clone())
386                        .or_default()
387                        .insert(state_key.to_owned(), raw_event.clone());
388                    inner.stripped_room_state.remove(room);
389
390                    if *event_type == StateEventType::RoomMember {
391                        let event =
392                            match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
393                                Ok(ev) => ev,
394                                Err(e) => {
395                                    let event_id: Option<String> =
396                                        raw_event.get_field("event_id").ok().flatten();
397                                    debug!(event_id, "Failed to deserialize member event: {e}");
398                                    continue;
399                                }
400                            };
401
402                        inner.stripped_members.remove(room);
403
404                        inner
405                            .members
406                            .entry(room.clone())
407                            .or_default()
408                            .insert(event.state_key().to_owned(), event.membership().clone());
409                    }
410                }
411            }
412        }
413
414        for (room_id, info) in &changes.room_infos {
415            inner.room_info.insert(room_id.clone(), info.clone());
416        }
417
418        for (sender, event) in &changes.presence {
419            inner.presence.insert(sender.clone(), event.clone());
420        }
421
422        for (room, event_types) in &changes.stripped_state {
423            for (event_type, events) in event_types {
424                for (state_key, raw_event) in events {
425                    inner
426                        .stripped_room_state
427                        .entry(room.clone())
428                        .or_default()
429                        .entry(event_type.clone())
430                        .or_default()
431                        .insert(state_key.to_owned(), raw_event.clone());
432
433                    if *event_type == StateEventType::RoomMember {
434                        let event =
435                            match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
436                                Ok(ev) => ev,
437                                Err(e) => {
438                                    let event_id: Option<String> =
439                                        raw_event.get_field("event_id").ok().flatten();
440                                    debug!(
441                                        event_id,
442                                        "Failed to deserialize stripped member event: {e}"
443                                    );
444                                    continue;
445                                }
446                            };
447
448                        inner
449                            .stripped_members
450                            .entry(room.clone())
451                            .or_default()
452                            .insert(event.state_key, event.content.membership.clone());
453                    }
454                }
455            }
456        }
457
458        for (room, content) in &changes.receipts {
459            for (event_id, receipts) in &content.0 {
460                for (receipt_type, receipts) in receipts {
461                    for (user_id, receipt) in receipts {
462                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
463                        // Add the receipt to the room user receipts
464                        if let Some((old_event, _)) = inner
465                            .room_user_receipts
466                            .entry(room.clone())
467                            .or_default()
468                            .entry((receipt_type.to_string(), thread.clone()))
469                            .or_default()
470                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
471                        {
472                            // Remove the old receipt from the room event receipts
473                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
474                                && let Some(event_map) =
475                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
476                                && let Some(user_map) = event_map.get_mut(&old_event)
477                            {
478                                user_map.remove(user_id);
479                            }
480                        }
481
482                        // Add the receipt to the room event receipts
483                        inner
484                            .room_event_receipts
485                            .entry(room.clone())
486                            .or_default()
487                            .entry((receipt_type.to_string(), thread))
488                            .or_default()
489                            .entry(event_id.clone())
490                            .or_default()
491                            .insert(user_id.clone(), receipt.clone());
492                    }
493                }
494            }
495        }
496
497        let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
498            room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
499                warn!(
500                    ?room_id,
501                    "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
502                );
503                ROOM_VERSION_RULES_FALLBACK
504            }).redaction
505        };
506
507        let inner = &mut *inner;
508        for (room_id, redactions) in &changes.redactions {
509            let mut redaction_rules = None;
510
511            if let Some(room) = inner.room_state.get_mut(room_id) {
512                for ref_room_mu in room.values_mut() {
513                    for raw_evt in ref_room_mu.values_mut() {
514                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
515                            && let Some(redaction) = redactions.get(&event_id)
516                        {
517                            let redacted = redact(
518                                raw_evt.deserialize_as::<CanonicalJsonObject>()?,
519                                redaction_rules.get_or_insert_with(|| {
520                                    make_redaction_rules(&inner.room_info, room_id)
521                                }),
522                                Some(RedactedBecause::from_raw_event(redaction)?),
523                            )
524                            .map_err(StoreError::Redaction)?;
525                            *raw_evt = Raw::new(&redacted)?.cast_unchecked();
526                        }
527                    }
528                }
529            }
530        }
531
532        debug!("Saved changes in {:?}", now.elapsed());
533
534        Ok(())
535    }
536
537    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
538        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
539    }
540
541    async fn get_presence_events(
542        &self,
543        user_ids: &[OwnedUserId],
544    ) -> Result<Vec<Raw<PresenceEvent>>> {
545        let presence = &self.inner.read().unwrap().presence;
546        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
547    }
548
549    async fn get_state_event(
550        &self,
551        room_id: &RoomId,
552        event_type: StateEventType,
553        state_key: &str,
554    ) -> Result<Option<RawAnySyncOrStrippedState>> {
555        Ok(self
556            .get_state_events_for_keys(room_id, event_type, &[state_key])
557            .await?
558            .into_iter()
559            .next())
560    }
561
562    async fn get_state_events(
563        &self,
564        room_id: &RoomId,
565        event_type: StateEventType,
566    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
567        fn get_events<T>(
568            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
569            room_id: &RoomId,
570            event_type: &StateEventType,
571            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
572        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
573            let state_events = state_map.get(room_id)?.get(event_type)?;
574            Some(state_events.values().cloned().map(to_enum).collect())
575        }
576
577        let inner = self.inner.read().unwrap();
578        Ok(get_events(
579            &inner.stripped_room_state,
580            room_id,
581            &event_type,
582            RawAnySyncOrStrippedState::Stripped,
583        )
584        .or_else(|| {
585            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
586        })
587        .unwrap_or_default())
588    }
589
590    async fn get_state_events_for_keys(
591        &self,
592        room_id: &RoomId,
593        event_type: StateEventType,
594        state_keys: &[&str],
595    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
596        let inner = self.inner.read().unwrap();
597
598        if let Some(stripped_state_events) =
599            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
600        {
601            Ok(state_keys
602                .iter()
603                .filter_map(|k| {
604                    stripped_state_events
605                        .get(*k)
606                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
607                })
608                .collect())
609        } else if let Some(sync_state_events) =
610            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
611        {
612            Ok(state_keys
613                .iter()
614                .filter_map(|k| {
615                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
616                })
617                .collect())
618        } else {
619            Ok(Vec::new())
620        }
621    }
622
623    async fn get_profile(
624        &self,
625        room_id: &RoomId,
626        user_id: &UserId,
627    ) -> Result<Option<MinimalRoomMemberEvent>> {
628        Ok(self
629            .inner
630            .read()
631            .unwrap()
632            .profiles
633            .get(room_id)
634            .and_then(|room_profiles| room_profiles.get(user_id))
635            .cloned())
636    }
637
638    async fn get_profiles<'a>(
639        &self,
640        room_id: &RoomId,
641        user_ids: &'a [OwnedUserId],
642    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
643        if user_ids.is_empty() {
644            return Ok(BTreeMap::new());
645        }
646
647        let profiles = &self.inner.read().unwrap().profiles;
648        let Some(room_profiles) = profiles.get(room_id) else {
649            return Ok(BTreeMap::new());
650        };
651
652        Ok(user_ids
653            .iter()
654            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
655            .collect())
656    }
657
658    #[instrument(skip(self, memberships))]
659    async fn get_user_ids(
660        &self,
661        room_id: &RoomId,
662        memberships: RoomMemberships,
663    ) -> Result<Vec<OwnedUserId>> {
664        /// Get the user IDs for the given room with the given memberships and
665        /// stripped state.
666        ///
667        /// If `memberships` is empty, returns all user IDs in the room with the
668        /// given stripped state.
669        fn get_user_ids_inner(
670            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
671            room_id: &RoomId,
672            memberships: RoomMemberships,
673        ) -> Vec<OwnedUserId> {
674            members
675                .get(room_id)
676                .map(|members| {
677                    members
678                        .iter()
679                        .filter_map(|(user_id, membership)| {
680                            memberships.matches(membership).then_some(user_id)
681                        })
682                        .cloned()
683                        .collect()
684                })
685                .unwrap_or_default()
686        }
687        let inner = self.inner.read().unwrap();
688        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
689        if !v.is_empty() {
690            return Ok(v);
691        }
692        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
693    }
694
695    async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
696        let memory_store_inner = self.inner.read().unwrap();
697        let room_infos = &memory_store_inner.room_info;
698
699        Ok(match room_load_settings {
700            RoomLoadSettings::All => room_infos.values().cloned().collect(),
701
702            RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
703                Some(room_info) => vec![room_info.clone()],
704                None => vec![],
705            },
706        })
707    }
708
709    async fn get_users_with_display_name(
710        &self,
711        room_id: &RoomId,
712        display_name: &DisplayName,
713    ) -> Result<BTreeSet<OwnedUserId>> {
714        Ok(self
715            .inner
716            .read()
717            .unwrap()
718            .display_names
719            .get(room_id)
720            .and_then(|room_names| room_names.get(display_name).cloned())
721            .unwrap_or_default())
722    }
723
724    async fn get_users_with_display_names<'a>(
725        &self,
726        room_id: &RoomId,
727        display_names: &'a [DisplayName],
728    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
729        if display_names.is_empty() {
730            return Ok(HashMap::new());
731        }
732
733        let inner = self.inner.read().unwrap();
734        let Some(room_names) = inner.display_names.get(room_id) else {
735            return Ok(HashMap::new());
736        };
737
738        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
739    }
740
741    async fn get_account_data_event(
742        &self,
743        event_type: GlobalAccountDataEventType,
744    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
745        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
746    }
747
748    async fn get_room_account_data_event(
749        &self,
750        room_id: &RoomId,
751        event_type: RoomAccountDataEventType,
752    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
753        Ok(self
754            .inner
755            .read()
756            .unwrap()
757            .room_account_data
758            .get(room_id)
759            .and_then(|m| m.get(&event_type))
760            .cloned())
761    }
762
763    async fn get_user_room_receipt_event(
764        &self,
765        room_id: &RoomId,
766        receipt_type: ReceiptType,
767        thread: ReceiptThread,
768        user_id: &UserId,
769    ) -> Result<Option<(OwnedEventId, Receipt)>> {
770        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
771    }
772
773    async fn get_event_room_receipt_events(
774        &self,
775        room_id: &RoomId,
776        receipt_type: ReceiptType,
777        thread: ReceiptThread,
778        event_id: &EventId,
779    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
780        Ok(self
781            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
782            .unwrap_or_default())
783    }
784
785    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
786        Ok(self.inner.read().unwrap().custom.get(key).cloned())
787    }
788
789    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
790        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
791    }
792
793    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
794        Ok(self.inner.write().unwrap().custom.remove(key))
795    }
796
797    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
798        let mut inner = self.inner.write().unwrap();
799
800        inner.profiles.remove(room_id);
801        inner.display_names.remove(room_id);
802        inner.members.remove(room_id);
803        inner.room_info.remove(room_id);
804        inner.room_state.remove(room_id);
805        inner.room_account_data.remove(room_id);
806        inner.stripped_room_state.remove(room_id);
807        inner.stripped_members.remove(room_id);
808        inner.room_user_receipts.remove(room_id);
809        inner.room_event_receipts.remove(room_id);
810        inner.send_queue_events.remove(room_id);
811        inner.dependent_send_queue_events.remove(room_id);
812        inner.thread_subscriptions.remove(room_id);
813
814        Ok(())
815    }
816
817    async fn save_send_queue_request(
818        &self,
819        room_id: &RoomId,
820        transaction_id: OwnedTransactionId,
821        created_at: MilliSecondsSinceUnixEpoch,
822        kind: QueuedRequestKind,
823        priority: usize,
824    ) -> Result<(), Self::Error> {
825        self.inner
826            .write()
827            .unwrap()
828            .send_queue_events
829            .entry(room_id.to_owned())
830            .or_default()
831            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
832        Ok(())
833    }
834
835    async fn update_send_queue_request(
836        &self,
837        room_id: &RoomId,
838        transaction_id: &TransactionId,
839        kind: QueuedRequestKind,
840    ) -> Result<bool, Self::Error> {
841        if let Some(entry) = self
842            .inner
843            .write()
844            .unwrap()
845            .send_queue_events
846            .entry(room_id.to_owned())
847            .or_default()
848            .iter_mut()
849            .find(|item| item.transaction_id == transaction_id)
850        {
851            entry.kind = kind;
852            entry.error = None;
853            Ok(true)
854        } else {
855            Ok(false)
856        }
857    }
858
859    async fn remove_send_queue_request(
860        &self,
861        room_id: &RoomId,
862        transaction_id: &TransactionId,
863    ) -> Result<bool, Self::Error> {
864        let mut inner = self.inner.write().unwrap();
865        let q = &mut inner.send_queue_events;
866
867        let entry = q.get_mut(room_id);
868        if let Some(entry) = entry {
869            // Find the event by id in its room queue, and remove it if present.
870            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
871                entry.remove(pos);
872                // And if this was the last event before removal, remove the entire room entry.
873                if entry.is_empty() {
874                    q.remove(room_id);
875                }
876                return Ok(true);
877            }
878        }
879
880        Ok(false)
881    }
882
883    async fn load_send_queue_requests(
884        &self,
885        room_id: &RoomId,
886    ) -> Result<Vec<QueuedRequest>, Self::Error> {
887        let mut ret = self
888            .inner
889            .write()
890            .unwrap()
891            .send_queue_events
892            .entry(room_id.to_owned())
893            .or_default()
894            .clone();
895        // Inverted order of priority, use stable sort to keep insertion order.
896        ret.sort_by_key(|item| Reverse(item.priority));
897        Ok(ret)
898    }
899
900    async fn update_send_queue_request_status(
901        &self,
902        room_id: &RoomId,
903        transaction_id: &TransactionId,
904        error: Option<QueueWedgeError>,
905    ) -> Result<(), Self::Error> {
906        if let Some(entry) = self
907            .inner
908            .write()
909            .unwrap()
910            .send_queue_events
911            .entry(room_id.to_owned())
912            .or_default()
913            .iter_mut()
914            .find(|item| item.transaction_id == transaction_id)
915        {
916            entry.error = error;
917        }
918        Ok(())
919    }
920
921    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
922        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
923    }
924
925    async fn save_dependent_queued_request(
926        &self,
927        room: &RoomId,
928        parent_transaction_id: &TransactionId,
929        own_transaction_id: ChildTransactionId,
930        created_at: MilliSecondsSinceUnixEpoch,
931        content: DependentQueuedRequestKind,
932    ) -> Result<(), Self::Error> {
933        self.inner
934            .write()
935            .unwrap()
936            .dependent_send_queue_events
937            .entry(room.to_owned())
938            .or_default()
939            .push(DependentQueuedRequest {
940                kind: content,
941                parent_transaction_id: parent_transaction_id.to_owned(),
942                own_transaction_id,
943                parent_key: None,
944                created_at,
945            });
946        Ok(())
947    }
948
949    async fn mark_dependent_queued_requests_as_ready(
950        &self,
951        room: &RoomId,
952        parent_txn_id: &TransactionId,
953        sent_parent_key: SentRequestKey,
954    ) -> Result<usize, Self::Error> {
955        let mut inner = self.inner.write().unwrap();
956        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
957        let mut num_updated = 0;
958        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
959            d.parent_key = Some(sent_parent_key.clone());
960            num_updated += 1;
961        }
962        Ok(num_updated)
963    }
964
965    async fn update_dependent_queued_request(
966        &self,
967        room: &RoomId,
968        own_transaction_id: &ChildTransactionId,
969        new_content: DependentQueuedRequestKind,
970    ) -> Result<bool, Self::Error> {
971        let mut inner = self.inner.write().unwrap();
972        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
973        for d in dependents.iter_mut() {
974            if d.own_transaction_id == *own_transaction_id {
975                d.kind = new_content;
976                return Ok(true);
977            }
978        }
979        Ok(false)
980    }
981
982    async fn remove_dependent_queued_request(
983        &self,
984        room: &RoomId,
985        txn_id: &ChildTransactionId,
986    ) -> Result<bool, Self::Error> {
987        let mut inner = self.inner.write().unwrap();
988        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
989        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
990            dependents.remove(pos);
991            Ok(true)
992        } else {
993            Ok(false)
994        }
995    }
996
997    async fn load_dependent_queued_requests(
998        &self,
999        room: &RoomId,
1000    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
1001        Ok(self
1002            .inner
1003            .read()
1004            .unwrap()
1005            .dependent_send_queue_events
1006            .get(room)
1007            .cloned()
1008            .unwrap_or_default())
1009    }
1010
1011    async fn upsert_thread_subscriptions(
1012        &self,
1013        updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
1014    ) -> Result<(), Self::Error> {
1015        let mut inner = self.inner.write().unwrap();
1016
1017        for (room_id, thread_id, mut new) in updates {
1018            let room_subs = inner.thread_subscriptions.entry(room_id.to_owned()).or_default();
1019
1020            if let Some(previous) = room_subs.get(thread_id) {
1021                if *previous == new {
1022                    continue;
1023                }
1024                if !compare_thread_subscription_bump_stamps(
1025                    previous.bump_stamp,
1026                    &mut new.bump_stamp,
1027                ) {
1028                    continue;
1029                }
1030            }
1031
1032            room_subs.insert(thread_id.to_owned(), new);
1033        }
1034
1035        Ok(())
1036    }
1037
1038    async fn load_thread_subscription(
1039        &self,
1040        room: &RoomId,
1041        thread_id: &EventId,
1042    ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1043        let inner = self.inner.read().unwrap();
1044        Ok(inner
1045            .thread_subscriptions
1046            .get(room)
1047            .and_then(|subscriptions| subscriptions.get(thread_id))
1048            .copied())
1049    }
1050
1051    async fn remove_thread_subscription(
1052        &self,
1053        room: &RoomId,
1054        thread_id: &EventId,
1055    ) -> Result<(), Self::Error> {
1056        let mut inner = self.inner.write().unwrap();
1057
1058        let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1059            return Ok(());
1060        };
1061
1062        room_subs.remove(thread_id);
1063
1064        if room_subs.is_empty() {
1065            // If there are no more subscriptions for this room, remove the room entry.
1066            inner.thread_subscriptions.remove(room);
1067        }
1068
1069        Ok(())
1070    }
1071
    /// No-op for this store: performs no work and always returns `Ok(())`.
    async fn optimize(&self) -> Result<(), Self::Error> {
        Ok(())
    }
1075
1076    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1077        Ok(None)
1078    }
1079}
1080
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    /// Store factory used by `statestore_integration_tests!`: a fresh, empty
    /// `MemoryStore` for each test.
    async fn get_store() -> Result<impl StateStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    statestore_integration_tests!();
}