1use std::{
16 collections::{BTreeMap, BTreeSet, HashMap},
17 sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
23use ruma::{
24 CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
25 OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
26 canonical_json::{RedactedBecause, redact},
27 events::{
28 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29 AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30 presence::PresenceEvent,
31 receipt::{Receipt, ReceiptThread, ReceiptType},
32 room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
33 },
34 serde::Raw,
35 time::Instant,
36};
37use tracing::{debug, instrument, warn};
38
39use super::{
40 DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
41 RoomLoadSettings, StateChanges, StateStore, StoreError,
42 send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
43 traits::{ComposerDraft, ServerInfo},
44};
45use crate::{
46 MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
47 deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
48 store::{
49 QueueWedgeError, StoredThreadSubscription,
50 traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps},
51 },
52};
53
/// All data held by a [`MemoryStore`], guarded by a single lock.
#[derive(Debug, Default)]
// The nested map types below are spelled out inline rather than aliased.
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// `StateStoreDataKey::RecentlyVisitedRooms`: recently visited rooms per user.
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// `StateStoreDataKey::ComposerDraft`: drafts keyed by room and optional thread root.
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// `StateStoreDataKey::UserAvatarUrl`: avatar URL per user.
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// `StateStoreDataKey::SyncToken`, also overwritten by `save_changes`.
    sync_token: Option<String>,
    /// `StateStoreDataKey::ServerInfo`.
    server_info: Option<ServerInfo>,
    /// `StateStoreDataKey::Filter`: filter name to stored filter string.
    filters: HashMap<String, String>,
    /// `StateStoreDataKey::UtdHookManagerData`.
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Flag behind `StateStoreDataKey::OneTimeKeyAlreadyUploaded`; set/cleared, never valued.
    one_time_key_uploaded_error: bool,
    /// Global account data, by event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// Per room: user to minimal member event.
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// Per room: display name to the set of users using it (from `changes.ambiguity_maps`).
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Per room: membership state taken from regular (non-stripped) member events.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Per room: the room info.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// Per room: event type, then state key, to the latest regular state event.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// Per room: room account data, by event type.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// Per room: event type, then state key, to the latest stripped state event.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// Per room: membership state taken from stripped member events.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Latest presence event per user.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// Per room, per (receipt type string, optional thread string): each user's latest
    /// receipt and the event it points at.
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,
    /// Reverse index of `room_user_receipts`: per room and (type, thread), event to the
    /// receipts users have on it.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Arbitrary byte key/value pairs set via `set_custom_value`.
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Per room: queued requests, in insertion order. `remove_send_queue_request` drops a
    /// room's entry once its queue becomes empty.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Per room: requests that depend on a parent queued request.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// `StateStoreDataKey::SeenKnockRequests`: per room, event id to user id.
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
    /// Per room: thread id to stored subscription.
    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, StoredThreadSubscription>>,
    /// `StateStoreDataKey::ThreadSubscriptionsCatchupTokens`.
    thread_subscriptions_catchup_tokens: Option<Vec<ThreadSubscriptionCatchupToken>>,
}
93
/// In-memory implementation of [`StateStore`].
///
/// All data lives in one `MemoryStoreInner` behind a `std::sync::RwLock`; nothing is
/// written to disk. Every access `unwrap()`s the lock, so a panic while the lock is held
/// (poisoning it) makes later calls panic too.
#[derive(Debug, Default)]
pub struct MemoryStore {
    inner: RwLock<MemoryStoreInner>,
}
101
102impl MemoryStore {
103 pub fn new() -> Self {
105 Self::default()
106 }
107
108 fn get_user_room_receipt_event_impl(
109 &self,
110 room_id: &RoomId,
111 receipt_type: ReceiptType,
112 thread: ReceiptThread,
113 user_id: &UserId,
114 ) -> Option<(OwnedEventId, Receipt)> {
115 self.inner
116 .read()
117 .unwrap()
118 .room_user_receipts
119 .get(room_id)?
120 .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
121 .get(user_id)
122 .cloned()
123 }
124
125 fn get_event_room_receipt_events_impl(
126 &self,
127 room_id: &RoomId,
128 receipt_type: ReceiptType,
129 thread: ReceiptThread,
130 event_id: &EventId,
131 ) -> Option<Vec<(OwnedUserId, Receipt)>> {
132 Some(
133 self.inner
134 .read()
135 .unwrap()
136 .room_event_receipts
137 .get(room_id)?
138 .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
139 .get(event_id)?
140 .iter()
141 .map(|(key, value)| (key.clone(), value.clone()))
142 .collect(),
143 )
144 }
145}
146
147#[cfg_attr(target_family = "wasm", async_trait(?Send))]
148#[cfg_attr(not(target_family = "wasm"), async_trait)]
149impl StateStore for MemoryStore {
150 type Error = StoreError;
151
152 async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
153 let inner = self.inner.read().unwrap();
154
155 Ok(match key {
156 StateStoreDataKey::SyncToken => {
157 inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
158 }
159 StateStoreDataKey::ServerInfo => {
160 inner.server_info.clone().map(StateStoreDataValue::ServerInfo)
161 }
162 StateStoreDataKey::Filter(filter_name) => {
163 inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
164 }
165 StateStoreDataKey::UserAvatarUrl(user_id) => {
166 inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
167 }
168 StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
169 .recently_visited_rooms
170 .get(user_id)
171 .cloned()
172 .map(StateStoreDataValue::RecentlyVisitedRooms),
173 StateStoreDataKey::UtdHookManagerData => {
174 inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
175 }
176 StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
177 .one_time_key_uploaded_error
178 .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
179 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
180 let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
181 inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
182 }
183 StateStoreDataKey::SeenKnockRequests(room_id) => inner
184 .seen_knock_requests
185 .get(room_id)
186 .cloned()
187 .map(StateStoreDataValue::SeenKnockRequests),
188 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
189 .thread_subscriptions_catchup_tokens
190 .clone()
191 .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
192 })
193 }
194
195 async fn set_kv_data(
196 &self,
197 key: StateStoreDataKey<'_>,
198 value: StateStoreDataValue,
199 ) -> Result<()> {
200 let mut inner = self.inner.write().unwrap();
201 match key {
202 StateStoreDataKey::SyncToken => {
203 inner.sync_token =
204 Some(value.into_sync_token().expect("Session data not a sync token"))
205 }
206 StateStoreDataKey::Filter(filter_name) => {
207 inner.filters.insert(
208 filter_name.to_owned(),
209 value.into_filter().expect("Session data not a filter"),
210 );
211 }
212 StateStoreDataKey::UserAvatarUrl(user_id) => {
213 inner.user_avatar_url.insert(
214 user_id.to_owned(),
215 value.into_user_avatar_url().expect("Session data not a user avatar url"),
216 );
217 }
218 StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
219 inner.recently_visited_rooms.insert(
220 user_id.to_owned(),
221 value
222 .into_recently_visited_rooms()
223 .expect("Session data not a list of recently visited rooms"),
224 );
225 }
226 StateStoreDataKey::UtdHookManagerData => {
227 inner.utd_hook_manager_data = Some(
228 value
229 .into_utd_hook_manager_data()
230 .expect("Session data not the hook manager data"),
231 );
232 }
233 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
234 inner.one_time_key_uploaded_error = true;
235 }
236 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
237 inner.composer_drafts.insert(
238 (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
239 value.into_composer_draft().expect("Session data not a composer draft"),
240 );
241 }
242 StateStoreDataKey::ServerInfo => {
243 inner.server_info = Some(
244 value.into_server_info().expect("Session data not containing server info"),
245 );
246 }
247 StateStoreDataKey::SeenKnockRequests(room_id) => {
248 inner.seen_knock_requests.insert(
249 room_id.to_owned(),
250 value
251 .into_seen_knock_requests()
252 .expect("Session data is not a set of seen join request ids"),
253 );
254 }
255 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
256 inner.thread_subscriptions_catchup_tokens =
257 Some(value.into_thread_subscriptions_catchup_tokens().expect(
258 "Session data is not a list of thread subscription catchup tokens",
259 ));
260 }
261 }
262
263 Ok(())
264 }
265
266 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
267 let mut inner = self.inner.write().unwrap();
268 match key {
269 StateStoreDataKey::SyncToken => inner.sync_token = None,
270 StateStoreDataKey::ServerInfo => inner.server_info = None,
271 StateStoreDataKey::Filter(filter_name) => {
272 inner.filters.remove(filter_name);
273 }
274 StateStoreDataKey::UserAvatarUrl(user_id) => {
275 inner.user_avatar_url.remove(user_id);
276 }
277 StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
278 inner.recently_visited_rooms.remove(user_id);
279 }
280 StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
281 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
282 inner.one_time_key_uploaded_error = false
283 }
284 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
285 let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
286 inner.composer_drafts.remove(&key);
287 }
288 StateStoreDataKey::SeenKnockRequests(room_id) => {
289 inner.seen_knock_requests.remove(room_id);
290 }
291 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
292 inner.thread_subscriptions_catchup_tokens = None;
293 }
294 }
295 Ok(())
296 }
297
    /// Apply a batch of [`StateChanges`] to the store under a single write lock.
    ///
    /// Sections are applied in order: sync token, profile deletions, profiles, ambiguity
    /// maps, global and room account data, regular state (plus members), room infos,
    /// presence, stripped state (plus stripped members), receipts, and finally redactions.
    ///
    /// # Errors
    ///
    /// Fails if a redacted event cannot be (de)serialized or redacted. Because of the `?`
    /// inside the redaction loop, everything applied before the failing event in this call
    /// stays applied.
    #[instrument(skip(self, changes))]
    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
        let now = Instant::now();

        let mut inner = self.inner.write().unwrap();

        if let Some(s) = &changes.sync_token {
            inner.sync_token = Some(s.to_owned());
        }

        // Deletions run before insertions, so a profile present in both is kept.
        for (room, users) in &changes.profiles_to_delete {
            let Some(room_profiles) = inner.profiles.get_mut(room) else {
                continue;
            };
            for user in users {
                room_profiles.remove(user);
            }
        }

        for (room, users) in &changes.profiles {
            for (user_id, profile) in users {
                inner
                    .profiles
                    .entry(room.clone())
                    .or_default()
                    .insert(user_id.clone(), profile.clone());
            }
        }

        for (room, map) in &changes.ambiguity_maps {
            for (display_name, display_names) in map {
                inner
                    .display_names
                    .entry(room.clone())
                    .or_default()
                    .insert(display_name.clone(), display_names.clone());
            }
        }

        for (event_type, event) in &changes.account_data {
            inner.account_data.insert(event_type.clone(), event.clone());
        }

        for (room, events) in &changes.room_account_data {
            for (event_type, event) in events {
                inner
                    .room_account_data
                    .entry(room.clone())
                    .or_default()
                    .insert(event_type.clone(), event.clone());
            }
        }

        for (room, event_types) in &changes.state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());
                    // Storing any regular state for a room drops its stripped state.
                    inner.stripped_room_state.remove(room);

                    if *event_type == StateEventType::RoomMember {
                        let event =
                            match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    // Undecodable member events are logged and skipped; the
                                    // raw event is still kept in `room_state` above.
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(event_id, "Failed to deserialize member event: {e}");
                                    continue;
                                }
                            };

                        inner.stripped_members.remove(room);

                        inner
                            .members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key().to_owned(), event.membership().clone());
                    }
                }
            }
        }

        for (room_id, info) in &changes.room_infos {
            inner.room_info.insert(room_id.clone(), info.clone());
        }

        for (sender, event) in &changes.presence {
            inner.presence.insert(sender.clone(), event.clone());
        }

        for (room, event_types) in &changes.stripped_state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .stripped_room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());

                    if *event_type == StateEventType::RoomMember {
                        let event =
                            match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(
                                        event_id,
                                        "Failed to deserialize stripped member event: {e}"
                                    );
                                    continue;
                                }
                            };

                        inner
                            .stripped_members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key, event.content.membership.clone());
                    }
                }
            }
        }

        for (room, content) in &changes.receipts {
            for (event_id, receipts) in &content.0 {
                for (receipt_type, receipts) in receipts {
                    for (user_id, receipt) in receipts {
                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
                        // Record the user's latest receipt. If it replaces an older one, remove
                        // the user from the old event's entry in the reverse index.
                        if let Some((old_event, _)) = inner
                            .room_user_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread.clone()))
                            .or_default()
                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
                        {
                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
                                && let Some(event_map) =
                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
                                && let Some(user_map) = event_map.get_mut(&old_event)
                            {
                                user_map.remove(user_id);
                            }
                        }

                        inner
                            .room_event_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread))
                            .or_default()
                            .entry(event_id.clone())
                            .or_default()
                            .insert(user_id.clone(), receipt.clone());
                    }
                }
            }
        }

        // Redaction rules come from the room's info, falling back (with a warning) to
        // the default room version's rules when the room is unknown.
        let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
            room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
                warn!(
                    ?room_id,
                    "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
                );
                ROOM_VERSION_RULES_FALLBACK
            }).redaction
        };

        // Reborrow the guard as a plain `&mut MemoryStoreInner` so that `room_state` (mutably)
        // and `room_info` (immutably) can be borrowed at the same time below.
        let inner = &mut *inner;
        for (room_id, redactions) in &changes.redactions {
            // Computed lazily, at most once per room, and only if an event needs redacting.
            let mut redaction_rules = None;

            // Only regular state is redacted here; stripped state is left untouched.
            if let Some(room) = inner.room_state.get_mut(room_id) {
                for ref_room_mu in room.values_mut() {
                    for raw_evt in ref_room_mu.values_mut() {
                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
                            && let Some(redaction) = redactions.get(&event_id)
                        {
                            let redacted = redact(
                                raw_evt.deserialize_as::<CanonicalJsonObject>()?,
                                redaction_rules.get_or_insert_with(|| {
                                    make_redaction_rules(&inner.room_info, room_id)
                                }),
                                Some(RedactedBecause::from_raw_event(redaction)?),
                            )
                            .map_err(StoreError::Redaction)?;
                            *raw_evt = Raw::new(&redacted)?.cast_unchecked();
                        }
                    }
                }
            }
        }

        debug!("Saved changes in {:?}", now.elapsed());

        Ok(())
    }
509
510 async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
511 Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
512 }
513
514 async fn get_presence_events(
515 &self,
516 user_ids: &[OwnedUserId],
517 ) -> Result<Vec<Raw<PresenceEvent>>> {
518 let presence = &self.inner.read().unwrap().presence;
519 Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
520 }
521
522 async fn get_state_event(
523 &self,
524 room_id: &RoomId,
525 event_type: StateEventType,
526 state_key: &str,
527 ) -> Result<Option<RawAnySyncOrStrippedState>> {
528 Ok(self
529 .get_state_events_for_keys(room_id, event_type, &[state_key])
530 .await?
531 .into_iter()
532 .next())
533 }
534
535 async fn get_state_events(
536 &self,
537 room_id: &RoomId,
538 event_type: StateEventType,
539 ) -> Result<Vec<RawAnySyncOrStrippedState>> {
540 fn get_events<T>(
541 state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
542 room_id: &RoomId,
543 event_type: &StateEventType,
544 to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
545 ) -> Option<Vec<RawAnySyncOrStrippedState>> {
546 let state_events = state_map.get(room_id)?.get(event_type)?;
547 Some(state_events.values().cloned().map(to_enum).collect())
548 }
549
550 let inner = self.inner.read().unwrap();
551 Ok(get_events(
552 &inner.stripped_room_state,
553 room_id,
554 &event_type,
555 RawAnySyncOrStrippedState::Stripped,
556 )
557 .or_else(|| {
558 get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
559 })
560 .unwrap_or_default())
561 }
562
563 async fn get_state_events_for_keys(
564 &self,
565 room_id: &RoomId,
566 event_type: StateEventType,
567 state_keys: &[&str],
568 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
569 let inner = self.inner.read().unwrap();
570
571 if let Some(stripped_state_events) =
572 inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
573 {
574 Ok(state_keys
575 .iter()
576 .filter_map(|k| {
577 stripped_state_events
578 .get(*k)
579 .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
580 })
581 .collect())
582 } else if let Some(sync_state_events) =
583 inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
584 {
585 Ok(state_keys
586 .iter()
587 .filter_map(|k| {
588 sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
589 })
590 .collect())
591 } else {
592 Ok(Vec::new())
593 }
594 }
595
596 async fn get_profile(
597 &self,
598 room_id: &RoomId,
599 user_id: &UserId,
600 ) -> Result<Option<MinimalRoomMemberEvent>> {
601 Ok(self
602 .inner
603 .read()
604 .unwrap()
605 .profiles
606 .get(room_id)
607 .and_then(|room_profiles| room_profiles.get(user_id))
608 .cloned())
609 }
610
611 async fn get_profiles<'a>(
612 &self,
613 room_id: &RoomId,
614 user_ids: &'a [OwnedUserId],
615 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
616 if user_ids.is_empty() {
617 return Ok(BTreeMap::new());
618 }
619
620 let profiles = &self.inner.read().unwrap().profiles;
621 let Some(room_profiles) = profiles.get(room_id) else {
622 return Ok(BTreeMap::new());
623 };
624
625 Ok(user_ids
626 .iter()
627 .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
628 .collect())
629 }
630
631 #[instrument(skip(self, memberships))]
632 async fn get_user_ids(
633 &self,
634 room_id: &RoomId,
635 memberships: RoomMemberships,
636 ) -> Result<Vec<OwnedUserId>> {
637 fn get_user_ids_inner(
643 members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
644 room_id: &RoomId,
645 memberships: RoomMemberships,
646 ) -> Vec<OwnedUserId> {
647 members
648 .get(room_id)
649 .map(|members| {
650 members
651 .iter()
652 .filter_map(|(user_id, membership)| {
653 memberships.matches(membership).then_some(user_id)
654 })
655 .cloned()
656 .collect()
657 })
658 .unwrap_or_default()
659 }
660 let inner = self.inner.read().unwrap();
661 let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
662 if !v.is_empty() {
663 return Ok(v);
664 }
665 Ok(get_user_ids_inner(&inner.members, room_id, memberships))
666 }
667
668 async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
669 let memory_store_inner = self.inner.read().unwrap();
670 let room_infos = &memory_store_inner.room_info;
671
672 Ok(match room_load_settings {
673 RoomLoadSettings::All => room_infos.values().cloned().collect(),
674
675 RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
676 Some(room_info) => vec![room_info.clone()],
677 None => vec![],
678 },
679 })
680 }
681
682 async fn get_users_with_display_name(
683 &self,
684 room_id: &RoomId,
685 display_name: &DisplayName,
686 ) -> Result<BTreeSet<OwnedUserId>> {
687 Ok(self
688 .inner
689 .read()
690 .unwrap()
691 .display_names
692 .get(room_id)
693 .and_then(|room_names| room_names.get(display_name).cloned())
694 .unwrap_or_default())
695 }
696
697 async fn get_users_with_display_names<'a>(
698 &self,
699 room_id: &RoomId,
700 display_names: &'a [DisplayName],
701 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
702 if display_names.is_empty() {
703 return Ok(HashMap::new());
704 }
705
706 let inner = self.inner.read().unwrap();
707 let Some(room_names) = inner.display_names.get(room_id) else {
708 return Ok(HashMap::new());
709 };
710
711 Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
712 }
713
714 async fn get_account_data_event(
715 &self,
716 event_type: GlobalAccountDataEventType,
717 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
718 Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
719 }
720
721 async fn get_room_account_data_event(
722 &self,
723 room_id: &RoomId,
724 event_type: RoomAccountDataEventType,
725 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
726 Ok(self
727 .inner
728 .read()
729 .unwrap()
730 .room_account_data
731 .get(room_id)
732 .and_then(|m| m.get(&event_type))
733 .cloned())
734 }
735
736 async fn get_user_room_receipt_event(
737 &self,
738 room_id: &RoomId,
739 receipt_type: ReceiptType,
740 thread: ReceiptThread,
741 user_id: &UserId,
742 ) -> Result<Option<(OwnedEventId, Receipt)>> {
743 Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
744 }
745
746 async fn get_event_room_receipt_events(
747 &self,
748 room_id: &RoomId,
749 receipt_type: ReceiptType,
750 thread: ReceiptThread,
751 event_id: &EventId,
752 ) -> Result<Vec<(OwnedUserId, Receipt)>> {
753 Ok(self
754 .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
755 .unwrap_or_default())
756 }
757
758 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
759 Ok(self.inner.read().unwrap().custom.get(key).cloned())
760 }
761
762 async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
763 Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
764 }
765
766 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
767 Ok(self.inner.write().unwrap().custom.remove(key))
768 }
769
770 async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
771 let mut inner = self.inner.write().unwrap();
772
773 inner.profiles.remove(room_id);
774 inner.display_names.remove(room_id);
775 inner.members.remove(room_id);
776 inner.room_info.remove(room_id);
777 inner.room_state.remove(room_id);
778 inner.room_account_data.remove(room_id);
779 inner.stripped_room_state.remove(room_id);
780 inner.stripped_members.remove(room_id);
781 inner.room_user_receipts.remove(room_id);
782 inner.room_event_receipts.remove(room_id);
783 inner.send_queue_events.remove(room_id);
784 inner.dependent_send_queue_events.remove(room_id);
785 inner.thread_subscriptions.remove(room_id);
786
787 Ok(())
788 }
789
790 async fn save_send_queue_request(
791 &self,
792 room_id: &RoomId,
793 transaction_id: OwnedTransactionId,
794 created_at: MilliSecondsSinceUnixEpoch,
795 kind: QueuedRequestKind,
796 priority: usize,
797 ) -> Result<(), Self::Error> {
798 self.inner
799 .write()
800 .unwrap()
801 .send_queue_events
802 .entry(room_id.to_owned())
803 .or_default()
804 .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
805 Ok(())
806 }
807
808 async fn update_send_queue_request(
809 &self,
810 room_id: &RoomId,
811 transaction_id: &TransactionId,
812 kind: QueuedRequestKind,
813 ) -> Result<bool, Self::Error> {
814 if let Some(entry) = self
815 .inner
816 .write()
817 .unwrap()
818 .send_queue_events
819 .entry(room_id.to_owned())
820 .or_default()
821 .iter_mut()
822 .find(|item| item.transaction_id == transaction_id)
823 {
824 entry.kind = kind;
825 entry.error = None;
826 Ok(true)
827 } else {
828 Ok(false)
829 }
830 }
831
832 async fn remove_send_queue_request(
833 &self,
834 room_id: &RoomId,
835 transaction_id: &TransactionId,
836 ) -> Result<bool, Self::Error> {
837 let mut inner = self.inner.write().unwrap();
838 let q = &mut inner.send_queue_events;
839
840 let entry = q.get_mut(room_id);
841 if let Some(entry) = entry {
842 if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
844 entry.remove(pos);
845 if entry.is_empty() {
847 q.remove(room_id);
848 }
849 return Ok(true);
850 }
851 }
852
853 Ok(false)
854 }
855
856 async fn load_send_queue_requests(
857 &self,
858 room_id: &RoomId,
859 ) -> Result<Vec<QueuedRequest>, Self::Error> {
860 let mut ret = self
861 .inner
862 .write()
863 .unwrap()
864 .send_queue_events
865 .entry(room_id.to_owned())
866 .or_default()
867 .clone();
868 ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
870 Ok(ret)
871 }
872
873 async fn update_send_queue_request_status(
874 &self,
875 room_id: &RoomId,
876 transaction_id: &TransactionId,
877 error: Option<QueueWedgeError>,
878 ) -> Result<(), Self::Error> {
879 if let Some(entry) = self
880 .inner
881 .write()
882 .unwrap()
883 .send_queue_events
884 .entry(room_id.to_owned())
885 .or_default()
886 .iter_mut()
887 .find(|item| item.transaction_id == transaction_id)
888 {
889 entry.error = error;
890 }
891 Ok(())
892 }
893
894 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
895 Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
896 }
897
898 async fn save_dependent_queued_request(
899 &self,
900 room: &RoomId,
901 parent_transaction_id: &TransactionId,
902 own_transaction_id: ChildTransactionId,
903 created_at: MilliSecondsSinceUnixEpoch,
904 content: DependentQueuedRequestKind,
905 ) -> Result<(), Self::Error> {
906 self.inner
907 .write()
908 .unwrap()
909 .dependent_send_queue_events
910 .entry(room.to_owned())
911 .or_default()
912 .push(DependentQueuedRequest {
913 kind: content,
914 parent_transaction_id: parent_transaction_id.to_owned(),
915 own_transaction_id,
916 parent_key: None,
917 created_at,
918 });
919 Ok(())
920 }
921
922 async fn mark_dependent_queued_requests_as_ready(
923 &self,
924 room: &RoomId,
925 parent_txn_id: &TransactionId,
926 sent_parent_key: SentRequestKey,
927 ) -> Result<usize, Self::Error> {
928 let mut inner = self.inner.write().unwrap();
929 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
930 let mut num_updated = 0;
931 for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
932 d.parent_key = Some(sent_parent_key.clone());
933 num_updated += 1;
934 }
935 Ok(num_updated)
936 }
937
938 async fn update_dependent_queued_request(
939 &self,
940 room: &RoomId,
941 own_transaction_id: &ChildTransactionId,
942 new_content: DependentQueuedRequestKind,
943 ) -> Result<bool, Self::Error> {
944 let mut inner = self.inner.write().unwrap();
945 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
946 for d in dependents.iter_mut() {
947 if d.own_transaction_id == *own_transaction_id {
948 d.kind = new_content;
949 return Ok(true);
950 }
951 }
952 Ok(false)
953 }
954
955 async fn remove_dependent_queued_request(
956 &self,
957 room: &RoomId,
958 txn_id: &ChildTransactionId,
959 ) -> Result<bool, Self::Error> {
960 let mut inner = self.inner.write().unwrap();
961 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
962 if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
963 dependents.remove(pos);
964 Ok(true)
965 } else {
966 Ok(false)
967 }
968 }
969
970 async fn load_dependent_queued_requests(
971 &self,
972 room: &RoomId,
973 ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
974 Ok(self
975 .inner
976 .read()
977 .unwrap()
978 .dependent_send_queue_events
979 .get(room)
980 .cloned()
981 .unwrap_or_default())
982 }
983
984 async fn upsert_thread_subscription(
985 &self,
986 room: &RoomId,
987 thread_id: &EventId,
988 mut new: StoredThreadSubscription,
989 ) -> Result<(), Self::Error> {
990 let mut inner = self.inner.write().unwrap();
991 let room_subs = inner.thread_subscriptions.entry(room.to_owned()).or_default();
992
993 if let Some(previous) = room_subs.get(thread_id) {
994 if *previous == new {
996 return Ok(());
997 }
998 if !compare_thread_subscription_bump_stamps(previous.bump_stamp, &mut new.bump_stamp) {
999 return Ok(());
1000 }
1001 }
1002
1003 room_subs.insert(thread_id.to_owned(), new);
1004
1005 Ok(())
1006 }
1007
1008 async fn load_thread_subscription(
1009 &self,
1010 room: &RoomId,
1011 thread_id: &EventId,
1012 ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1013 let inner = self.inner.read().unwrap();
1014 Ok(inner
1015 .thread_subscriptions
1016 .get(room)
1017 .and_then(|subscriptions| subscriptions.get(thread_id))
1018 .copied())
1019 }
1020
1021 async fn remove_thread_subscription(
1022 &self,
1023 room: &RoomId,
1024 thread_id: &EventId,
1025 ) -> Result<(), Self::Error> {
1026 let mut inner = self.inner.write().unwrap();
1027
1028 let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1029 return Ok(());
1030 };
1031
1032 room_subs.remove(thread_id);
1033
1034 if room_subs.is_empty() {
1035 inner.thread_subscriptions.remove(room);
1037 }
1038
1039 Ok(())
1040 }
1041}
1042
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    /// Fresh, empty store for the shared state store integration tests.
    async fn get_store() -> Result<impl StateStore> {
        Ok(MemoryStore::new())
    }

    // The macro is defined elsewhere; presumably it expands to the shared `StateStore` test
    // suite, which obtains its stores through `get_store()` above.
    statestore_integration_tests!();
}