1use std::{
16 collections::{BTreeMap, BTreeSet, HashMap},
17 sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use ruma::{
23 canonical_json::{redact, RedactedBecause},
24 events::{
25 presence::PresenceEvent,
26 receipt::{Receipt, ReceiptThread, ReceiptType},
27 room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
28 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29 AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30 },
31 serde::Raw,
32 time::Instant,
33 CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
34 OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, RoomVersionId, TransactionId, UserId,
35};
36use tracing::{debug, instrument, warn};
37
38use super::{
39 send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
40 traits::{ComposerDraft, ServerCapabilities},
41 DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
42 StateChanges, StateStore, StoreError,
43};
44use crate::{
45 deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
46 store::QueueWedgeError,
47 MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
48};
49
50#[derive(Debug, Default)]
51#[allow(clippy::type_complexity)]
52struct MemoryStoreInner {
53 recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
54 composer_drafts: HashMap<OwnedRoomId, ComposerDraft>,
55 user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
56 sync_token: Option<String>,
57 server_capabilities: Option<ServerCapabilities>,
58 filters: HashMap<String, String>,
59 utd_hook_manager_data: Option<GrowableBloom>,
60 account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
61 profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
62 display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
63 members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
64 room_info: HashMap<OwnedRoomId, RoomInfo>,
65 room_state:
66 HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
67 room_account_data:
68 HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
69 stripped_room_state:
70 HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
71 stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
72 presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
73 room_user_receipts: HashMap<
74 OwnedRoomId,
75 HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
76 >,
77
78 room_event_receipts: HashMap<
79 OwnedRoomId,
80 HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
81 >,
82 custom: HashMap<Vec<u8>, Vec<u8>>,
83 send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
84 dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
85 seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
86}
87
88#[derive(Debug, Default)]
92pub struct MemoryStore {
93 inner: RwLock<MemoryStoreInner>,
94}
95
96impl MemoryStore {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 fn get_user_room_receipt_event_impl(
103 &self,
104 room_id: &RoomId,
105 receipt_type: ReceiptType,
106 thread: ReceiptThread,
107 user_id: &UserId,
108 ) -> Option<(OwnedEventId, Receipt)> {
109 self.inner
110 .read()
111 .unwrap()
112 .room_user_receipts
113 .get(room_id)?
114 .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
115 .get(user_id)
116 .cloned()
117 }
118
119 fn get_event_room_receipt_events_impl(
120 &self,
121 room_id: &RoomId,
122 receipt_type: ReceiptType,
123 thread: ReceiptThread,
124 event_id: &EventId,
125 ) -> Option<Vec<(OwnedUserId, Receipt)>> {
126 Some(
127 self.inner
128 .read()
129 .unwrap()
130 .room_event_receipts
131 .get(room_id)?
132 .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
133 .get(event_id)?
134 .iter()
135 .map(|(key, value)| (key.clone(), value.clone()))
136 .collect(),
137 )
138 }
139}
140
141#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
142#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
143impl StateStore for MemoryStore {
144 type Error = StoreError;
145
146 async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
147 let inner = self.inner.read().unwrap();
148 Ok(match key {
149 StateStoreDataKey::SyncToken => {
150 inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
151 }
152 StateStoreDataKey::ServerCapabilities => {
153 inner.server_capabilities.clone().map(StateStoreDataValue::ServerCapabilities)
154 }
155 StateStoreDataKey::Filter(filter_name) => {
156 inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
157 }
158 StateStoreDataKey::UserAvatarUrl(user_id) => {
159 inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
160 }
161 StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
162 .recently_visited_rooms
163 .get(user_id)
164 .cloned()
165 .map(StateStoreDataValue::RecentlyVisitedRooms),
166 StateStoreDataKey::UtdHookManagerData => {
167 inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
168 }
169 StateStoreDataKey::ComposerDraft(room_id) => {
170 inner.composer_drafts.get(room_id).cloned().map(StateStoreDataValue::ComposerDraft)
171 }
172 StateStoreDataKey::SeenKnockRequests(room_id) => inner
173 .seen_knock_requests
174 .get(room_id)
175 .cloned()
176 .map(StateStoreDataValue::SeenKnockRequests),
177 })
178 }
179
180 async fn set_kv_data(
181 &self,
182 key: StateStoreDataKey<'_>,
183 value: StateStoreDataValue,
184 ) -> Result<()> {
185 let mut inner = self.inner.write().unwrap();
186 match key {
187 StateStoreDataKey::SyncToken => {
188 inner.sync_token =
189 Some(value.into_sync_token().expect("Session data not a sync token"))
190 }
191 StateStoreDataKey::Filter(filter_name) => {
192 inner.filters.insert(
193 filter_name.to_owned(),
194 value.into_filter().expect("Session data not a filter"),
195 );
196 }
197 StateStoreDataKey::UserAvatarUrl(user_id) => {
198 inner.user_avatar_url.insert(
199 user_id.to_owned(),
200 value.into_user_avatar_url().expect("Session data not a user avatar url"),
201 );
202 }
203 StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
204 inner.recently_visited_rooms.insert(
205 user_id.to_owned(),
206 value
207 .into_recently_visited_rooms()
208 .expect("Session data not a list of recently visited rooms"),
209 );
210 }
211 StateStoreDataKey::UtdHookManagerData => {
212 inner.utd_hook_manager_data = Some(
213 value
214 .into_utd_hook_manager_data()
215 .expect("Session data not the hook manager data"),
216 );
217 }
218 StateStoreDataKey::ComposerDraft(room_id) => {
219 inner.composer_drafts.insert(
220 room_id.to_owned(),
221 value.into_composer_draft().expect("Session data not a composer draft"),
222 );
223 }
224 StateStoreDataKey::ServerCapabilities => {
225 inner.server_capabilities = Some(
226 value
227 .into_server_capabilities()
228 .expect("Session data not containing server capabilities"),
229 );
230 }
231 StateStoreDataKey::SeenKnockRequests(room_id) => {
232 inner.seen_knock_requests.insert(
233 room_id.to_owned(),
234 value
235 .into_seen_knock_requests()
236 .expect("Session data is not a set of seen join request ids"),
237 );
238 }
239 }
240
241 Ok(())
242 }
243
244 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
245 let mut inner = self.inner.write().unwrap();
246 match key {
247 StateStoreDataKey::SyncToken => inner.sync_token = None,
248 StateStoreDataKey::ServerCapabilities => inner.server_capabilities = None,
249 StateStoreDataKey::Filter(filter_name) => {
250 inner.filters.remove(filter_name);
251 }
252 StateStoreDataKey::UserAvatarUrl(user_id) => {
253 inner.user_avatar_url.remove(user_id);
254 }
255 StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
256 inner.recently_visited_rooms.remove(user_id);
257 }
258 StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
259 StateStoreDataKey::ComposerDraft(room_id) => {
260 inner.composer_drafts.remove(room_id);
261 }
262 StateStoreDataKey::SeenKnockRequests(room_id) => {
263 inner.seen_knock_requests.remove(room_id);
264 }
265 }
266 Ok(())
267 }
268
269 #[instrument(skip(self, changes))]
270 async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
271 let now = Instant::now();
272
273 let mut inner = self.inner.write().unwrap();
274
275 if let Some(s) = &changes.sync_token {
276 inner.sync_token = Some(s.to_owned());
277 }
278
279 for (room, users) in &changes.profiles_to_delete {
280 let Some(room_profiles) = inner.profiles.get_mut(room) else {
281 continue;
282 };
283 for user in users {
284 room_profiles.remove(user);
285 }
286 }
287
288 for (room, users) in &changes.profiles {
289 for (user_id, profile) in users {
290 inner
291 .profiles
292 .entry(room.clone())
293 .or_default()
294 .insert(user_id.clone(), profile.clone());
295 }
296 }
297
298 for (room, map) in &changes.ambiguity_maps {
299 for (display_name, display_names) in map {
300 inner
301 .display_names
302 .entry(room.clone())
303 .or_default()
304 .insert(display_name.clone(), display_names.clone());
305 }
306 }
307
308 for (event_type, event) in &changes.account_data {
309 inner.account_data.insert(event_type.clone(), event.clone());
310 }
311
312 for (room, events) in &changes.room_account_data {
313 for (event_type, event) in events {
314 inner
315 .room_account_data
316 .entry(room.clone())
317 .or_default()
318 .insert(event_type.clone(), event.clone());
319 }
320 }
321
322 for (room, event_types) in &changes.state {
323 for (event_type, events) in event_types {
324 for (state_key, raw_event) in events {
325 inner
326 .room_state
327 .entry(room.clone())
328 .or_default()
329 .entry(event_type.clone())
330 .or_default()
331 .insert(state_key.to_owned(), raw_event.clone());
332 inner.stripped_room_state.remove(room);
333
334 if *event_type == StateEventType::RoomMember {
335 let event = match raw_event.deserialize_as::<SyncRoomMemberEvent>() {
336 Ok(ev) => ev,
337 Err(e) => {
338 let event_id: Option<String> =
339 raw_event.get_field("event_id").ok().flatten();
340 debug!(event_id, "Failed to deserialize member event: {e}");
341 continue;
342 }
343 };
344
345 inner.stripped_members.remove(room);
346
347 inner
348 .members
349 .entry(room.clone())
350 .or_default()
351 .insert(event.state_key().to_owned(), event.membership().clone());
352 }
353 }
354 }
355 }
356
357 for (room_id, info) in &changes.room_infos {
358 inner.room_info.insert(room_id.clone(), info.clone());
359 }
360
361 for (sender, event) in &changes.presence {
362 inner.presence.insert(sender.clone(), event.clone());
363 }
364
365 for (room, event_types) in &changes.stripped_state {
366 for (event_type, events) in event_types {
367 for (state_key, raw_event) in events {
368 inner
369 .stripped_room_state
370 .entry(room.clone())
371 .or_default()
372 .entry(event_type.clone())
373 .or_default()
374 .insert(state_key.to_owned(), raw_event.clone());
375
376 if *event_type == StateEventType::RoomMember {
377 let event = match raw_event.deserialize_as::<StrippedRoomMemberEvent>() {
378 Ok(ev) => ev,
379 Err(e) => {
380 let event_id: Option<String> =
381 raw_event.get_field("event_id").ok().flatten();
382 debug!(
383 event_id,
384 "Failed to deserialize stripped member event: {e}"
385 );
386 continue;
387 }
388 };
389
390 inner
391 .stripped_members
392 .entry(room.clone())
393 .or_default()
394 .insert(event.state_key, event.content.membership.clone());
395 }
396 }
397 }
398 }
399
400 for (room, content) in &changes.receipts {
401 for (event_id, receipts) in &content.0 {
402 for (receipt_type, receipts) in receipts {
403 for (user_id, receipt) in receipts {
404 let thread = receipt.thread.as_str().map(ToOwned::to_owned);
405 if let Some((old_event, _)) = inner
407 .room_user_receipts
408 .entry(room.clone())
409 .or_default()
410 .entry((receipt_type.to_string(), thread.clone()))
411 .or_default()
412 .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
413 {
414 if let Some(receipt_map) = inner.room_event_receipts.get_mut(room) {
416 if let Some(event_map) =
417 receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
418 {
419 if let Some(user_map) = event_map.get_mut(&old_event) {
420 user_map.remove(user_id);
421 }
422 }
423 }
424 }
425
426 inner
428 .room_event_receipts
429 .entry(room.clone())
430 .or_default()
431 .entry((receipt_type.to_string(), thread))
432 .or_default()
433 .entry(event_id.clone())
434 .or_default()
435 .insert(user_id.clone(), receipt.clone());
436 }
437 }
438 }
439 }
440
441 let make_room_version = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
442 room_info.get(room_id).and_then(|info| info.room_version().cloned()).unwrap_or_else(
443 || {
444 warn!(?room_id, "Unable to find the room version, assuming version 9");
445 RoomVersionId::V9
446 },
447 )
448 };
449
450 let inner = &mut *inner;
451 for (room_id, redactions) in &changes.redactions {
452 let mut room_version = None;
453
454 if let Some(room) = inner.room_state.get_mut(room_id) {
455 for ref_room_mu in room.values_mut() {
456 for raw_evt in ref_room_mu.values_mut() {
457 if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id") {
458 if let Some(redaction) = redactions.get(&event_id) {
459 let redacted = redact(
460 raw_evt.deserialize_as::<CanonicalJsonObject>()?,
461 room_version.get_or_insert_with(|| {
462 make_room_version(&inner.room_info, room_id)
463 }),
464 Some(RedactedBecause::from_raw_event(redaction)?),
465 )
466 .map_err(StoreError::Redaction)?;
467 *raw_evt = Raw::new(&redacted)?.cast();
468 }
469 }
470 }
471 }
472 }
473 }
474
475 debug!("Saved changes in {:?}", now.elapsed());
476
477 Ok(())
478 }
479
480 async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
481 Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
482 }
483
484 async fn get_presence_events(
485 &self,
486 user_ids: &[OwnedUserId],
487 ) -> Result<Vec<Raw<PresenceEvent>>> {
488 let presence = &self.inner.read().unwrap().presence;
489 Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
490 }
491
492 async fn get_state_event(
493 &self,
494 room_id: &RoomId,
495 event_type: StateEventType,
496 state_key: &str,
497 ) -> Result<Option<RawAnySyncOrStrippedState>> {
498 Ok(self
499 .get_state_events_for_keys(room_id, event_type, &[state_key])
500 .await?
501 .into_iter()
502 .next())
503 }
504
505 async fn get_state_events(
506 &self,
507 room_id: &RoomId,
508 event_type: StateEventType,
509 ) -> Result<Vec<RawAnySyncOrStrippedState>> {
510 fn get_events<T>(
511 state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
512 room_id: &RoomId,
513 event_type: &StateEventType,
514 to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
515 ) -> Option<Vec<RawAnySyncOrStrippedState>> {
516 let state_events = state_map.get(room_id)?.get(event_type)?;
517 Some(state_events.values().cloned().map(to_enum).collect())
518 }
519
520 let inner = self.inner.read().unwrap();
521 Ok(get_events(
522 &inner.stripped_room_state,
523 room_id,
524 &event_type,
525 RawAnySyncOrStrippedState::Stripped,
526 )
527 .or_else(|| {
528 get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
529 })
530 .unwrap_or_default())
531 }
532
533 async fn get_state_events_for_keys(
534 &self,
535 room_id: &RoomId,
536 event_type: StateEventType,
537 state_keys: &[&str],
538 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
539 let inner = self.inner.read().unwrap();
540
541 if let Some(stripped_state_events) =
542 inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
543 {
544 Ok(state_keys
545 .iter()
546 .filter_map(|k| {
547 stripped_state_events
548 .get(*k)
549 .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
550 })
551 .collect())
552 } else if let Some(sync_state_events) =
553 inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
554 {
555 Ok(state_keys
556 .iter()
557 .filter_map(|k| {
558 sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
559 })
560 .collect())
561 } else {
562 Ok(Vec::new())
563 }
564 }
565
566 async fn get_profile(
567 &self,
568 room_id: &RoomId,
569 user_id: &UserId,
570 ) -> Result<Option<MinimalRoomMemberEvent>> {
571 Ok(self
572 .inner
573 .read()
574 .unwrap()
575 .profiles
576 .get(room_id)
577 .and_then(|room_profiles| room_profiles.get(user_id))
578 .cloned())
579 }
580
581 async fn get_profiles<'a>(
582 &self,
583 room_id: &RoomId,
584 user_ids: &'a [OwnedUserId],
585 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
586 if user_ids.is_empty() {
587 return Ok(BTreeMap::new());
588 }
589
590 let profiles = &self.inner.read().unwrap().profiles;
591 let Some(room_profiles) = profiles.get(room_id) else {
592 return Ok(BTreeMap::new());
593 };
594
595 Ok(user_ids
596 .iter()
597 .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
598 .collect())
599 }
600
601 #[instrument(skip(self, memberships))]
602 async fn get_user_ids(
603 &self,
604 room_id: &RoomId,
605 memberships: RoomMemberships,
606 ) -> Result<Vec<OwnedUserId>> {
607 fn get_user_ids_inner(
613 members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
614 room_id: &RoomId,
615 memberships: RoomMemberships,
616 ) -> Vec<OwnedUserId> {
617 members
618 .get(room_id)
619 .map(|members| {
620 members
621 .iter()
622 .filter_map(|(user_id, membership)| {
623 memberships.matches(membership).then_some(user_id)
624 })
625 .cloned()
626 .collect()
627 })
628 .unwrap_or_default()
629 }
630 let inner = self.inner.read().unwrap();
631 let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
632 if !v.is_empty() {
633 return Ok(v);
634 }
635 Ok(get_user_ids_inner(&inner.members, room_id, memberships))
636 }
637
638 async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
639 Ok(self.inner.read().unwrap().room_info.values().cloned().collect())
640 }
641
642 async fn get_users_with_display_name(
643 &self,
644 room_id: &RoomId,
645 display_name: &DisplayName,
646 ) -> Result<BTreeSet<OwnedUserId>> {
647 Ok(self
648 .inner
649 .read()
650 .unwrap()
651 .display_names
652 .get(room_id)
653 .and_then(|room_names| room_names.get(display_name).cloned())
654 .unwrap_or_default())
655 }
656
657 async fn get_users_with_display_names<'a>(
658 &self,
659 room_id: &RoomId,
660 display_names: &'a [DisplayName],
661 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
662 if display_names.is_empty() {
663 return Ok(HashMap::new());
664 }
665
666 let inner = self.inner.read().unwrap();
667 let Some(room_names) = inner.display_names.get(room_id) else {
668 return Ok(HashMap::new());
669 };
670
671 Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
672 }
673
674 async fn get_account_data_event(
675 &self,
676 event_type: GlobalAccountDataEventType,
677 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
678 Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
679 }
680
681 async fn get_room_account_data_event(
682 &self,
683 room_id: &RoomId,
684 event_type: RoomAccountDataEventType,
685 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
686 Ok(self
687 .inner
688 .read()
689 .unwrap()
690 .room_account_data
691 .get(room_id)
692 .and_then(|m| m.get(&event_type))
693 .cloned())
694 }
695
696 async fn get_user_room_receipt_event(
697 &self,
698 room_id: &RoomId,
699 receipt_type: ReceiptType,
700 thread: ReceiptThread,
701 user_id: &UserId,
702 ) -> Result<Option<(OwnedEventId, Receipt)>> {
703 Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
704 }
705
706 async fn get_event_room_receipt_events(
707 &self,
708 room_id: &RoomId,
709 receipt_type: ReceiptType,
710 thread: ReceiptThread,
711 event_id: &EventId,
712 ) -> Result<Vec<(OwnedUserId, Receipt)>> {
713 Ok(self
714 .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
715 .unwrap_or_default())
716 }
717
718 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
719 Ok(self.inner.read().unwrap().custom.get(key).cloned())
720 }
721
722 async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
723 Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
724 }
725
726 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
727 Ok(self.inner.write().unwrap().custom.remove(key))
728 }
729
730 async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
731 let mut inner = self.inner.write().unwrap();
732
733 inner.profiles.remove(room_id);
734 inner.display_names.remove(room_id);
735 inner.members.remove(room_id);
736 inner.room_info.remove(room_id);
737 inner.room_state.remove(room_id);
738 inner.room_account_data.remove(room_id);
739 inner.stripped_room_state.remove(room_id);
740 inner.stripped_members.remove(room_id);
741 inner.room_user_receipts.remove(room_id);
742 inner.room_event_receipts.remove(room_id);
743 inner.send_queue_events.remove(room_id);
744 inner.dependent_send_queue_events.remove(room_id);
745
746 Ok(())
747 }
748
749 async fn save_send_queue_request(
750 &self,
751 room_id: &RoomId,
752 transaction_id: OwnedTransactionId,
753 created_at: MilliSecondsSinceUnixEpoch,
754 kind: QueuedRequestKind,
755 priority: usize,
756 ) -> Result<(), Self::Error> {
757 self.inner
758 .write()
759 .unwrap()
760 .send_queue_events
761 .entry(room_id.to_owned())
762 .or_default()
763 .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
764 Ok(())
765 }
766
767 async fn update_send_queue_request(
768 &self,
769 room_id: &RoomId,
770 transaction_id: &TransactionId,
771 kind: QueuedRequestKind,
772 ) -> Result<bool, Self::Error> {
773 if let Some(entry) = self
774 .inner
775 .write()
776 .unwrap()
777 .send_queue_events
778 .entry(room_id.to_owned())
779 .or_default()
780 .iter_mut()
781 .find(|item| item.transaction_id == transaction_id)
782 {
783 entry.kind = kind;
784 entry.error = None;
785 Ok(true)
786 } else {
787 Ok(false)
788 }
789 }
790
791 async fn remove_send_queue_request(
792 &self,
793 room_id: &RoomId,
794 transaction_id: &TransactionId,
795 ) -> Result<bool, Self::Error> {
796 let mut inner = self.inner.write().unwrap();
797 let q = &mut inner.send_queue_events;
798
799 let entry = q.get_mut(room_id);
800 if let Some(entry) = entry {
801 if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
803 entry.remove(pos);
804 if entry.is_empty() {
806 q.remove(room_id);
807 }
808 return Ok(true);
809 }
810 }
811
812 Ok(false)
813 }
814
815 async fn load_send_queue_requests(
816 &self,
817 room_id: &RoomId,
818 ) -> Result<Vec<QueuedRequest>, Self::Error> {
819 let mut ret = self
820 .inner
821 .write()
822 .unwrap()
823 .send_queue_events
824 .entry(room_id.to_owned())
825 .or_default()
826 .clone();
827 ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
829 Ok(ret)
830 }
831
832 async fn update_send_queue_request_status(
833 &self,
834 room_id: &RoomId,
835 transaction_id: &TransactionId,
836 error: Option<QueueWedgeError>,
837 ) -> Result<(), Self::Error> {
838 if let Some(entry) = self
839 .inner
840 .write()
841 .unwrap()
842 .send_queue_events
843 .entry(room_id.to_owned())
844 .or_default()
845 .iter_mut()
846 .find(|item| item.transaction_id == transaction_id)
847 {
848 entry.error = error;
849 }
850 Ok(())
851 }
852
853 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
854 Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
855 }
856
857 async fn save_dependent_queued_request(
858 &self,
859 room: &RoomId,
860 parent_transaction_id: &TransactionId,
861 own_transaction_id: ChildTransactionId,
862 created_at: MilliSecondsSinceUnixEpoch,
863 content: DependentQueuedRequestKind,
864 ) -> Result<(), Self::Error> {
865 self.inner
866 .write()
867 .unwrap()
868 .dependent_send_queue_events
869 .entry(room.to_owned())
870 .or_default()
871 .push(DependentQueuedRequest {
872 kind: content,
873 parent_transaction_id: parent_transaction_id.to_owned(),
874 own_transaction_id,
875 parent_key: None,
876 created_at,
877 });
878 Ok(())
879 }
880
881 async fn mark_dependent_queued_requests_as_ready(
882 &self,
883 room: &RoomId,
884 parent_txn_id: &TransactionId,
885 sent_parent_key: SentRequestKey,
886 ) -> Result<usize, Self::Error> {
887 let mut inner = self.inner.write().unwrap();
888 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
889 let mut num_updated = 0;
890 for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
891 d.parent_key = Some(sent_parent_key.clone());
892 num_updated += 1;
893 }
894 Ok(num_updated)
895 }
896
897 async fn update_dependent_queued_request(
898 &self,
899 room: &RoomId,
900 own_transaction_id: &ChildTransactionId,
901 new_content: DependentQueuedRequestKind,
902 ) -> Result<bool, Self::Error> {
903 let mut inner = self.inner.write().unwrap();
904 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
905 for d in dependents.iter_mut() {
906 if d.own_transaction_id == *own_transaction_id {
907 d.kind = new_content;
908 return Ok(true);
909 }
910 }
911 Ok(false)
912 }
913
914 async fn remove_dependent_queued_request(
915 &self,
916 room: &RoomId,
917 txn_id: &ChildTransactionId,
918 ) -> Result<bool, Self::Error> {
919 let mut inner = self.inner.write().unwrap();
920 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
921 if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
922 dependents.remove(pos);
923 Ok(true)
924 } else {
925 Ok(false)
926 }
927 }
928
929 async fn load_dependent_queued_requests(
934 &self,
935 room: &RoomId,
936 ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
937 Ok(self
938 .inner
939 .read()
940 .unwrap()
941 .dependent_send_queue_events
942 .get(room)
943 .cloned()
944 .unwrap_or_default())
945 }
946}
947
#[cfg(test)]
mod tests {
    use super::MemoryStore;
    use super::{Result, StateStore};

    /// Store factory consumed by `statestore_integration_tests!`: every test
    /// gets its own fresh, empty in-memory store.
    async fn get_store() -> Result<impl StateStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    statestore_integration_tests!();
}