matrix_sdk_base/store/
memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, BTreeSet, HashMap},
17    sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use ruma::{
23    canonical_json::{redact, RedactedBecause},
24    events::{
25        presence::PresenceEvent,
26        receipt::{Receipt, ReceiptThread, ReceiptType},
27        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
28        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30    },
31    serde::Raw,
32    time::Instant,
33    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
34    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, RoomVersionId, TransactionId, UserId,
35};
36use tracing::{debug, instrument, warn};
37
38use super::{
39    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
40    traits::{ComposerDraft, ServerCapabilities},
41    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
42    StateChanges, StateStore, StoreError,
43};
44use crate::{
45    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
46    store::QueueWedgeError,
47    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
48};
49
50#[derive(Debug, Default)]
51#[allow(clippy::type_complexity)]
52struct MemoryStoreInner {
53    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
54    composer_drafts: HashMap<OwnedRoomId, ComposerDraft>,
55    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
56    sync_token: Option<String>,
57    server_capabilities: Option<ServerCapabilities>,
58    filters: HashMap<String, String>,
59    utd_hook_manager_data: Option<GrowableBloom>,
60    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
61    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
62    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
63    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
64    room_info: HashMap<OwnedRoomId, RoomInfo>,
65    room_state:
66        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
67    room_account_data:
68        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
69    stripped_room_state:
70        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
71    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
72    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
73    room_user_receipts: HashMap<
74        OwnedRoomId,
75        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
76    >,
77
78    room_event_receipts: HashMap<
79        OwnedRoomId,
80        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
81    >,
82    custom: HashMap<Vec<u8>, Vec<u8>>,
83    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
84    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
85    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
86}
87
88/// In-memory, non-persistent implementation of the `StateStore`.
89///
90/// Default if no other is configured at startup.
91#[derive(Debug, Default)]
92pub struct MemoryStore {
93    inner: RwLock<MemoryStoreInner>,
94}
95
96impl MemoryStore {
97    /// Create a new empty MemoryStore
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    fn get_user_room_receipt_event_impl(
103        &self,
104        room_id: &RoomId,
105        receipt_type: ReceiptType,
106        thread: ReceiptThread,
107        user_id: &UserId,
108    ) -> Option<(OwnedEventId, Receipt)> {
109        self.inner
110            .read()
111            .unwrap()
112            .room_user_receipts
113            .get(room_id)?
114            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
115            .get(user_id)
116            .cloned()
117    }
118
119    fn get_event_room_receipt_events_impl(
120        &self,
121        room_id: &RoomId,
122        receipt_type: ReceiptType,
123        thread: ReceiptThread,
124        event_id: &EventId,
125    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
126        Some(
127            self.inner
128                .read()
129                .unwrap()
130                .room_event_receipts
131                .get(room_id)?
132                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
133                .get(event_id)?
134                .iter()
135                .map(|(key, value)| (key.clone(), value.clone()))
136                .collect(),
137        )
138    }
139}
140
141#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
142#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
143impl StateStore for MemoryStore {
144    type Error = StoreError;
145
146    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
147        let inner = self.inner.read().unwrap();
148        Ok(match key {
149            StateStoreDataKey::SyncToken => {
150                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
151            }
152            StateStoreDataKey::ServerCapabilities => {
153                inner.server_capabilities.clone().map(StateStoreDataValue::ServerCapabilities)
154            }
155            StateStoreDataKey::Filter(filter_name) => {
156                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
157            }
158            StateStoreDataKey::UserAvatarUrl(user_id) => {
159                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
160            }
161            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
162                .recently_visited_rooms
163                .get(user_id)
164                .cloned()
165                .map(StateStoreDataValue::RecentlyVisitedRooms),
166            StateStoreDataKey::UtdHookManagerData => {
167                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
168            }
169            StateStoreDataKey::ComposerDraft(room_id) => {
170                inner.composer_drafts.get(room_id).cloned().map(StateStoreDataValue::ComposerDraft)
171            }
172            StateStoreDataKey::SeenKnockRequests(room_id) => inner
173                .seen_knock_requests
174                .get(room_id)
175                .cloned()
176                .map(StateStoreDataValue::SeenKnockRequests),
177        })
178    }
179
180    async fn set_kv_data(
181        &self,
182        key: StateStoreDataKey<'_>,
183        value: StateStoreDataValue,
184    ) -> Result<()> {
185        let mut inner = self.inner.write().unwrap();
186        match key {
187            StateStoreDataKey::SyncToken => {
188                inner.sync_token =
189                    Some(value.into_sync_token().expect("Session data not a sync token"))
190            }
191            StateStoreDataKey::Filter(filter_name) => {
192                inner.filters.insert(
193                    filter_name.to_owned(),
194                    value.into_filter().expect("Session data not a filter"),
195                );
196            }
197            StateStoreDataKey::UserAvatarUrl(user_id) => {
198                inner.user_avatar_url.insert(
199                    user_id.to_owned(),
200                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
201                );
202            }
203            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
204                inner.recently_visited_rooms.insert(
205                    user_id.to_owned(),
206                    value
207                        .into_recently_visited_rooms()
208                        .expect("Session data not a list of recently visited rooms"),
209                );
210            }
211            StateStoreDataKey::UtdHookManagerData => {
212                inner.utd_hook_manager_data = Some(
213                    value
214                        .into_utd_hook_manager_data()
215                        .expect("Session data not the hook manager data"),
216                );
217            }
218            StateStoreDataKey::ComposerDraft(room_id) => {
219                inner.composer_drafts.insert(
220                    room_id.to_owned(),
221                    value.into_composer_draft().expect("Session data not a composer draft"),
222                );
223            }
224            StateStoreDataKey::ServerCapabilities => {
225                inner.server_capabilities = Some(
226                    value
227                        .into_server_capabilities()
228                        .expect("Session data not containing server capabilities"),
229                );
230            }
231            StateStoreDataKey::SeenKnockRequests(room_id) => {
232                inner.seen_knock_requests.insert(
233                    room_id.to_owned(),
234                    value
235                        .into_seen_knock_requests()
236                        .expect("Session data is not a set of seen join request ids"),
237                );
238            }
239        }
240
241        Ok(())
242    }
243
244    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
245        let mut inner = self.inner.write().unwrap();
246        match key {
247            StateStoreDataKey::SyncToken => inner.sync_token = None,
248            StateStoreDataKey::ServerCapabilities => inner.server_capabilities = None,
249            StateStoreDataKey::Filter(filter_name) => {
250                inner.filters.remove(filter_name);
251            }
252            StateStoreDataKey::UserAvatarUrl(user_id) => {
253                inner.user_avatar_url.remove(user_id);
254            }
255            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
256                inner.recently_visited_rooms.remove(user_id);
257            }
258            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
259            StateStoreDataKey::ComposerDraft(room_id) => {
260                inner.composer_drafts.remove(room_id);
261            }
262            StateStoreDataKey::SeenKnockRequests(room_id) => {
263                inner.seen_knock_requests.remove(room_id);
264            }
265        }
266        Ok(())
267    }
268
269    #[instrument(skip(self, changes))]
270    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
271        let now = Instant::now();
272
273        let mut inner = self.inner.write().unwrap();
274
275        if let Some(s) = &changes.sync_token {
276            inner.sync_token = Some(s.to_owned());
277        }
278
279        for (room, users) in &changes.profiles_to_delete {
280            let Some(room_profiles) = inner.profiles.get_mut(room) else {
281                continue;
282            };
283            for user in users {
284                room_profiles.remove(user);
285            }
286        }
287
288        for (room, users) in &changes.profiles {
289            for (user_id, profile) in users {
290                inner
291                    .profiles
292                    .entry(room.clone())
293                    .or_default()
294                    .insert(user_id.clone(), profile.clone());
295            }
296        }
297
298        for (room, map) in &changes.ambiguity_maps {
299            for (display_name, display_names) in map {
300                inner
301                    .display_names
302                    .entry(room.clone())
303                    .or_default()
304                    .insert(display_name.clone(), display_names.clone());
305            }
306        }
307
308        for (event_type, event) in &changes.account_data {
309            inner.account_data.insert(event_type.clone(), event.clone());
310        }
311
312        for (room, events) in &changes.room_account_data {
313            for (event_type, event) in events {
314                inner
315                    .room_account_data
316                    .entry(room.clone())
317                    .or_default()
318                    .insert(event_type.clone(), event.clone());
319            }
320        }
321
322        for (room, event_types) in &changes.state {
323            for (event_type, events) in event_types {
324                for (state_key, raw_event) in events {
325                    inner
326                        .room_state
327                        .entry(room.clone())
328                        .or_default()
329                        .entry(event_type.clone())
330                        .or_default()
331                        .insert(state_key.to_owned(), raw_event.clone());
332                    inner.stripped_room_state.remove(room);
333
334                    if *event_type == StateEventType::RoomMember {
335                        let event = match raw_event.deserialize_as::<SyncRoomMemberEvent>() {
336                            Ok(ev) => ev,
337                            Err(e) => {
338                                let event_id: Option<String> =
339                                    raw_event.get_field("event_id").ok().flatten();
340                                debug!(event_id, "Failed to deserialize member event: {e}");
341                                continue;
342                            }
343                        };
344
345                        inner.stripped_members.remove(room);
346
347                        inner
348                            .members
349                            .entry(room.clone())
350                            .or_default()
351                            .insert(event.state_key().to_owned(), event.membership().clone());
352                    }
353                }
354            }
355        }
356
357        for (room_id, info) in &changes.room_infos {
358            inner.room_info.insert(room_id.clone(), info.clone());
359        }
360
361        for (sender, event) in &changes.presence {
362            inner.presence.insert(sender.clone(), event.clone());
363        }
364
365        for (room, event_types) in &changes.stripped_state {
366            for (event_type, events) in event_types {
367                for (state_key, raw_event) in events {
368                    inner
369                        .stripped_room_state
370                        .entry(room.clone())
371                        .or_default()
372                        .entry(event_type.clone())
373                        .or_default()
374                        .insert(state_key.to_owned(), raw_event.clone());
375
376                    if *event_type == StateEventType::RoomMember {
377                        let event = match raw_event.deserialize_as::<StrippedRoomMemberEvent>() {
378                            Ok(ev) => ev,
379                            Err(e) => {
380                                let event_id: Option<String> =
381                                    raw_event.get_field("event_id").ok().flatten();
382                                debug!(
383                                    event_id,
384                                    "Failed to deserialize stripped member event: {e}"
385                                );
386                                continue;
387                            }
388                        };
389
390                        inner
391                            .stripped_members
392                            .entry(room.clone())
393                            .or_default()
394                            .insert(event.state_key, event.content.membership.clone());
395                    }
396                }
397            }
398        }
399
400        for (room, content) in &changes.receipts {
401            for (event_id, receipts) in &content.0 {
402                for (receipt_type, receipts) in receipts {
403                    for (user_id, receipt) in receipts {
404                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
405                        // Add the receipt to the room user receipts
406                        if let Some((old_event, _)) = inner
407                            .room_user_receipts
408                            .entry(room.clone())
409                            .or_default()
410                            .entry((receipt_type.to_string(), thread.clone()))
411                            .or_default()
412                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
413                        {
414                            // Remove the old receipt from the room event receipts
415                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room) {
416                                if let Some(event_map) =
417                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
418                                {
419                                    if let Some(user_map) = event_map.get_mut(&old_event) {
420                                        user_map.remove(user_id);
421                                    }
422                                }
423                            }
424                        }
425
426                        // Add the receipt to the room event receipts
427                        inner
428                            .room_event_receipts
429                            .entry(room.clone())
430                            .or_default()
431                            .entry((receipt_type.to_string(), thread))
432                            .or_default()
433                            .entry(event_id.clone())
434                            .or_default()
435                            .insert(user_id.clone(), receipt.clone());
436                    }
437                }
438            }
439        }
440
441        let make_room_version = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
442            room_info.get(room_id).and_then(|info| info.room_version().cloned()).unwrap_or_else(
443                || {
444                    warn!(?room_id, "Unable to find the room version, assuming version 9");
445                    RoomVersionId::V9
446                },
447            )
448        };
449
450        let inner = &mut *inner;
451        for (room_id, redactions) in &changes.redactions {
452            let mut room_version = None;
453
454            if let Some(room) = inner.room_state.get_mut(room_id) {
455                for ref_room_mu in room.values_mut() {
456                    for raw_evt in ref_room_mu.values_mut() {
457                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id") {
458                            if let Some(redaction) = redactions.get(&event_id) {
459                                let redacted = redact(
460                                    raw_evt.deserialize_as::<CanonicalJsonObject>()?,
461                                    room_version.get_or_insert_with(|| {
462                                        make_room_version(&inner.room_info, room_id)
463                                    }),
464                                    Some(RedactedBecause::from_raw_event(redaction)?),
465                                )
466                                .map_err(StoreError::Redaction)?;
467                                *raw_evt = Raw::new(&redacted)?.cast();
468                            }
469                        }
470                    }
471                }
472            }
473        }
474
475        debug!("Saved changes in {:?}", now.elapsed());
476
477        Ok(())
478    }
479
480    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
481        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
482    }
483
484    async fn get_presence_events(
485        &self,
486        user_ids: &[OwnedUserId],
487    ) -> Result<Vec<Raw<PresenceEvent>>> {
488        let presence = &self.inner.read().unwrap().presence;
489        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
490    }
491
492    async fn get_state_event(
493        &self,
494        room_id: &RoomId,
495        event_type: StateEventType,
496        state_key: &str,
497    ) -> Result<Option<RawAnySyncOrStrippedState>> {
498        Ok(self
499            .get_state_events_for_keys(room_id, event_type, &[state_key])
500            .await?
501            .into_iter()
502            .next())
503    }
504
505    async fn get_state_events(
506        &self,
507        room_id: &RoomId,
508        event_type: StateEventType,
509    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
510        fn get_events<T>(
511            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
512            room_id: &RoomId,
513            event_type: &StateEventType,
514            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
515        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
516            let state_events = state_map.get(room_id)?.get(event_type)?;
517            Some(state_events.values().cloned().map(to_enum).collect())
518        }
519
520        let inner = self.inner.read().unwrap();
521        Ok(get_events(
522            &inner.stripped_room_state,
523            room_id,
524            &event_type,
525            RawAnySyncOrStrippedState::Stripped,
526        )
527        .or_else(|| {
528            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
529        })
530        .unwrap_or_default())
531    }
532
533    async fn get_state_events_for_keys(
534        &self,
535        room_id: &RoomId,
536        event_type: StateEventType,
537        state_keys: &[&str],
538    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
539        let inner = self.inner.read().unwrap();
540
541        if let Some(stripped_state_events) =
542            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
543        {
544            Ok(state_keys
545                .iter()
546                .filter_map(|k| {
547                    stripped_state_events
548                        .get(*k)
549                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
550                })
551                .collect())
552        } else if let Some(sync_state_events) =
553            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
554        {
555            Ok(state_keys
556                .iter()
557                .filter_map(|k| {
558                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
559                })
560                .collect())
561        } else {
562            Ok(Vec::new())
563        }
564    }
565
566    async fn get_profile(
567        &self,
568        room_id: &RoomId,
569        user_id: &UserId,
570    ) -> Result<Option<MinimalRoomMemberEvent>> {
571        Ok(self
572            .inner
573            .read()
574            .unwrap()
575            .profiles
576            .get(room_id)
577            .and_then(|room_profiles| room_profiles.get(user_id))
578            .cloned())
579    }
580
581    async fn get_profiles<'a>(
582        &self,
583        room_id: &RoomId,
584        user_ids: &'a [OwnedUserId],
585    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
586        if user_ids.is_empty() {
587            return Ok(BTreeMap::new());
588        }
589
590        let profiles = &self.inner.read().unwrap().profiles;
591        let Some(room_profiles) = profiles.get(room_id) else {
592            return Ok(BTreeMap::new());
593        };
594
595        Ok(user_ids
596            .iter()
597            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
598            .collect())
599    }
600
601    #[instrument(skip(self, memberships))]
602    async fn get_user_ids(
603        &self,
604        room_id: &RoomId,
605        memberships: RoomMemberships,
606    ) -> Result<Vec<OwnedUserId>> {
607        /// Get the user IDs for the given room with the given memberships and
608        /// stripped state.
609        ///
610        /// If `memberships` is empty, returns all user IDs in the room with the
611        /// given stripped state.
612        fn get_user_ids_inner(
613            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
614            room_id: &RoomId,
615            memberships: RoomMemberships,
616        ) -> Vec<OwnedUserId> {
617            members
618                .get(room_id)
619                .map(|members| {
620                    members
621                        .iter()
622                        .filter_map(|(user_id, membership)| {
623                            memberships.matches(membership).then_some(user_id)
624                        })
625                        .cloned()
626                        .collect()
627                })
628                .unwrap_or_default()
629        }
630        let inner = self.inner.read().unwrap();
631        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
632        if !v.is_empty() {
633            return Ok(v);
634        }
635        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
636    }
637
638    async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
639        Ok(self.inner.read().unwrap().room_info.values().cloned().collect())
640    }
641
642    async fn get_users_with_display_name(
643        &self,
644        room_id: &RoomId,
645        display_name: &DisplayName,
646    ) -> Result<BTreeSet<OwnedUserId>> {
647        Ok(self
648            .inner
649            .read()
650            .unwrap()
651            .display_names
652            .get(room_id)
653            .and_then(|room_names| room_names.get(display_name).cloned())
654            .unwrap_or_default())
655    }
656
657    async fn get_users_with_display_names<'a>(
658        &self,
659        room_id: &RoomId,
660        display_names: &'a [DisplayName],
661    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
662        if display_names.is_empty() {
663            return Ok(HashMap::new());
664        }
665
666        let inner = self.inner.read().unwrap();
667        let Some(room_names) = inner.display_names.get(room_id) else {
668            return Ok(HashMap::new());
669        };
670
671        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
672    }
673
674    async fn get_account_data_event(
675        &self,
676        event_type: GlobalAccountDataEventType,
677    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
678        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
679    }
680
681    async fn get_room_account_data_event(
682        &self,
683        room_id: &RoomId,
684        event_type: RoomAccountDataEventType,
685    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
686        Ok(self
687            .inner
688            .read()
689            .unwrap()
690            .room_account_data
691            .get(room_id)
692            .and_then(|m| m.get(&event_type))
693            .cloned())
694    }
695
696    async fn get_user_room_receipt_event(
697        &self,
698        room_id: &RoomId,
699        receipt_type: ReceiptType,
700        thread: ReceiptThread,
701        user_id: &UserId,
702    ) -> Result<Option<(OwnedEventId, Receipt)>> {
703        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
704    }
705
706    async fn get_event_room_receipt_events(
707        &self,
708        room_id: &RoomId,
709        receipt_type: ReceiptType,
710        thread: ReceiptThread,
711        event_id: &EventId,
712    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
713        Ok(self
714            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
715            .unwrap_or_default())
716    }
717
718    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
719        Ok(self.inner.read().unwrap().custom.get(key).cloned())
720    }
721
722    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
723        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
724    }
725
726    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
727        Ok(self.inner.write().unwrap().custom.remove(key))
728    }
729
730    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
731        let mut inner = self.inner.write().unwrap();
732
733        inner.profiles.remove(room_id);
734        inner.display_names.remove(room_id);
735        inner.members.remove(room_id);
736        inner.room_info.remove(room_id);
737        inner.room_state.remove(room_id);
738        inner.room_account_data.remove(room_id);
739        inner.stripped_room_state.remove(room_id);
740        inner.stripped_members.remove(room_id);
741        inner.room_user_receipts.remove(room_id);
742        inner.room_event_receipts.remove(room_id);
743        inner.send_queue_events.remove(room_id);
744        inner.dependent_send_queue_events.remove(room_id);
745
746        Ok(())
747    }
748
749    async fn save_send_queue_request(
750        &self,
751        room_id: &RoomId,
752        transaction_id: OwnedTransactionId,
753        created_at: MilliSecondsSinceUnixEpoch,
754        kind: QueuedRequestKind,
755        priority: usize,
756    ) -> Result<(), Self::Error> {
757        self.inner
758            .write()
759            .unwrap()
760            .send_queue_events
761            .entry(room_id.to_owned())
762            .or_default()
763            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
764        Ok(())
765    }
766
767    async fn update_send_queue_request(
768        &self,
769        room_id: &RoomId,
770        transaction_id: &TransactionId,
771        kind: QueuedRequestKind,
772    ) -> Result<bool, Self::Error> {
773        if let Some(entry) = self
774            .inner
775            .write()
776            .unwrap()
777            .send_queue_events
778            .entry(room_id.to_owned())
779            .or_default()
780            .iter_mut()
781            .find(|item| item.transaction_id == transaction_id)
782        {
783            entry.kind = kind;
784            entry.error = None;
785            Ok(true)
786        } else {
787            Ok(false)
788        }
789    }
790
791    async fn remove_send_queue_request(
792        &self,
793        room_id: &RoomId,
794        transaction_id: &TransactionId,
795    ) -> Result<bool, Self::Error> {
796        let mut inner = self.inner.write().unwrap();
797        let q = &mut inner.send_queue_events;
798
799        let entry = q.get_mut(room_id);
800        if let Some(entry) = entry {
801            // Find the event by id in its room queue, and remove it if present.
802            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
803                entry.remove(pos);
804                // And if this was the last event before removal, remove the entire room entry.
805                if entry.is_empty() {
806                    q.remove(room_id);
807                }
808                return Ok(true);
809            }
810        }
811
812        Ok(false)
813    }
814
815    async fn load_send_queue_requests(
816        &self,
817        room_id: &RoomId,
818    ) -> Result<Vec<QueuedRequest>, Self::Error> {
819        let mut ret = self
820            .inner
821            .write()
822            .unwrap()
823            .send_queue_events
824            .entry(room_id.to_owned())
825            .or_default()
826            .clone();
827        // Inverted order of priority, use stable sort to keep insertion order.
828        ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
829        Ok(ret)
830    }
831
832    async fn update_send_queue_request_status(
833        &self,
834        room_id: &RoomId,
835        transaction_id: &TransactionId,
836        error: Option<QueueWedgeError>,
837    ) -> Result<(), Self::Error> {
838        if let Some(entry) = self
839            .inner
840            .write()
841            .unwrap()
842            .send_queue_events
843            .entry(room_id.to_owned())
844            .or_default()
845            .iter_mut()
846            .find(|item| item.transaction_id == transaction_id)
847        {
848            entry.error = error;
849        }
850        Ok(())
851    }
852
853    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
854        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
855    }
856
857    async fn save_dependent_queued_request(
858        &self,
859        room: &RoomId,
860        parent_transaction_id: &TransactionId,
861        own_transaction_id: ChildTransactionId,
862        created_at: MilliSecondsSinceUnixEpoch,
863        content: DependentQueuedRequestKind,
864    ) -> Result<(), Self::Error> {
865        self.inner
866            .write()
867            .unwrap()
868            .dependent_send_queue_events
869            .entry(room.to_owned())
870            .or_default()
871            .push(DependentQueuedRequest {
872                kind: content,
873                parent_transaction_id: parent_transaction_id.to_owned(),
874                own_transaction_id,
875                parent_key: None,
876                created_at,
877            });
878        Ok(())
879    }
880
881    async fn mark_dependent_queued_requests_as_ready(
882        &self,
883        room: &RoomId,
884        parent_txn_id: &TransactionId,
885        sent_parent_key: SentRequestKey,
886    ) -> Result<usize, Self::Error> {
887        let mut inner = self.inner.write().unwrap();
888        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
889        let mut num_updated = 0;
890        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
891            d.parent_key = Some(sent_parent_key.clone());
892            num_updated += 1;
893        }
894        Ok(num_updated)
895    }
896
897    async fn update_dependent_queued_request(
898        &self,
899        room: &RoomId,
900        own_transaction_id: &ChildTransactionId,
901        new_content: DependentQueuedRequestKind,
902    ) -> Result<bool, Self::Error> {
903        let mut inner = self.inner.write().unwrap();
904        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
905        for d in dependents.iter_mut() {
906            if d.own_transaction_id == *own_transaction_id {
907                d.kind = new_content;
908                return Ok(true);
909            }
910        }
911        Ok(false)
912    }
913
914    async fn remove_dependent_queued_request(
915        &self,
916        room: &RoomId,
917        txn_id: &ChildTransactionId,
918    ) -> Result<bool, Self::Error> {
919        let mut inner = self.inner.write().unwrap();
920        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
921        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
922            dependents.remove(pos);
923            Ok(true)
924        } else {
925            Ok(false)
926        }
927    }
928
929    /// List all the dependent send queue events.
930    ///
931    /// This returns absolutely all the dependent send queue events, whether
932    /// they have an event id or not.
933    async fn load_dependent_queued_requests(
934        &self,
935        room: &RoomId,
936    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
937        Ok(self
938            .inner
939            .read()
940            .unwrap()
941            .dependent_send_queue_events
942            .get(room)
943            .cloned()
944            .unwrap_or_default())
945    }
946}
947
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    /// Builds the store under test: a fresh [`MemoryStore`].
    ///
    /// NOTE(review): presumably called by the tests that
    /// `statestore_integration_tests!` expands to; the macro is defined
    /// outside of this file, so confirm there.
    async fn get_store() -> Result<impl StateStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    // Instantiate the state store integration test suite for this module.
    statestore_integration_tests!();
}