1use std::{
16 borrow::Borrow,
17 collections::{BTreeMap, BTreeSet, HashMap},
18 fmt,
19 ops::Deref,
20 sync::Arc,
21};
22
23use as_variant::as_variant;
24use async_trait::async_trait;
25use growable_bloom_filter::GrowableBloom;
26use matrix_sdk_common::{AsyncTraitDeps, ttl::TtlValue};
27use ruma::{
28 EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri, OwnedRoomId,
29 OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
30 api::{
31 MatrixVersion, SupportedVersions,
32 client::{
33 discovery::{
34 discover_homeserver::{self, HomeserverInfo, IdentityServerInfo, TileServerInfo},
35 get_capabilities::v3::Capabilities,
36 },
37 rtc::RtcTransport,
38 },
39 },
40 events::{
41 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, EmptyStateKey, GlobalAccountDataEvent,
42 GlobalAccountDataEventContent, GlobalAccountDataEventType, RedactContent,
43 RedactedStateEventContent, RoomAccountDataEvent, RoomAccountDataEventContent,
44 RoomAccountDataEventType, StateEventType, StaticEventContent, StaticStateEventContent,
45 presence::PresenceEvent,
46 receipt::{Receipt, ReceiptThread, ReceiptType},
47 },
48 serde::Raw,
49};
50use serde::{Deserialize, Serialize};
51use thiserror::Error;
52use tokio::sync::{Mutex, MutexGuard};
53
54use super::{
55 ChildTransactionId, DependentQueuedRequest, DependentQueuedRequestKind, QueueWedgeError,
56 QueuedRequest, QueuedRequestKind, RoomLoadSettings, StateChanges, StoreError,
57 send_queue::SentRequestKey,
58};
59use crate::{
60 MinimalRoomMemberEvent, RoomInfo, RoomMemberships,
61 deserialized_responses::{
62 DisplayName, RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState,
63 },
64 store::StoredThreadSubscription,
65};
66
67#[cfg_attr(target_family = "wasm", async_trait(?Send))]
70#[cfg_attr(not(target_family = "wasm"), async_trait)]
71pub trait StateStore: AsyncTraitDeps {
72 type Error: fmt::Debug + Into<StoreError> + From<serde_json::Error>;
74
75 async fn get_kv_data(
81 &self,
82 key: StateStoreDataKey<'_>,
83 ) -> Result<Option<StateStoreDataValue>, Self::Error>;
84
85 async fn set_kv_data(
95 &self,
96 key: StateStoreDataKey<'_>,
97 value: StateStoreDataValue,
98 ) -> Result<(), Self::Error>;
99
100 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error>;
106
107 async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error>;
109
110 async fn get_presence_event(
117 &self,
118 user_id: &UserId,
119 ) -> Result<Option<Raw<PresenceEvent>>, Self::Error>;
120
121 async fn get_presence_events(
127 &self,
128 user_ids: &[OwnedUserId],
129 ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error>;
130
131 async fn get_state_event(
139 &self,
140 room_id: &RoomId,
141 event_type: StateEventType,
142 state_key: &str,
143 ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error>;
144
145 async fn get_state_events(
153 &self,
154 room_id: &RoomId,
155 event_type: StateEventType,
156 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error>;
157
158 async fn get_state_events_for_keys(
169 &self,
170 room_id: &RoomId,
171 event_type: StateEventType,
172 state_keys: &[&str],
173 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error>;
174
175 async fn get_profile(
183 &self,
184 room_id: &RoomId,
185 user_id: &UserId,
186 ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error>;
187
188 async fn get_profiles<'a>(
196 &self,
197 room_id: &RoomId,
198 user_ids: &'a [OwnedUserId],
199 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error>;
200
201 async fn get_user_ids(
204 &self,
205 room_id: &RoomId,
206 memberships: RoomMemberships,
207 ) -> Result<Vec<OwnedUserId>, Self::Error>;
208
209 async fn get_room_infos(
211 &self,
212 room_load_settings: &RoomLoadSettings,
213 ) -> Result<Vec<RoomInfo>, Self::Error>;
214
215 async fn get_users_with_display_name(
224 &self,
225 room_id: &RoomId,
226 display_name: &DisplayName,
227 ) -> Result<BTreeSet<OwnedUserId>, Self::Error>;
228
229 async fn get_users_with_display_names<'a>(
237 &self,
238 room_id: &RoomId,
239 display_names: &'a [DisplayName],
240 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error>;
241
242 async fn get_account_data_event(
248 &self,
249 event_type: GlobalAccountDataEventType,
250 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error>;
251
252 async fn get_room_account_data_event(
262 &self,
263 room_id: &RoomId,
264 event_type: RoomAccountDataEventType,
265 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error>;
266
267 async fn get_user_room_receipt_event(
280 &self,
281 room_id: &RoomId,
282 receipt_type: ReceiptType,
283 thread: ReceiptThread,
284 user_id: &UserId,
285 ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error>;
286
287 async fn get_event_room_receipt_events(
301 &self,
302 room_id: &RoomId,
303 receipt_type: ReceiptType,
304 thread: ReceiptThread,
305 event_id: &EventId,
306 ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error>;
307
308 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;
314
315 async fn set_custom_value(
324 &self,
325 key: &[u8],
326 value: Vec<u8>,
327 ) -> Result<Option<Vec<u8>>, Self::Error>;
328
329 async fn set_custom_value_no_read(
343 &self,
344 key: &[u8],
345 value: Vec<u8>,
346 ) -> Result<(), Self::Error> {
347 self.set_custom_value(key, value).await.map(|_| ())
348 }
349
350 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;
356
357 async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error>;
363
364 async fn save_send_queue_request(
374 &self,
375 room_id: &RoomId,
376 transaction_id: OwnedTransactionId,
377 created_at: MilliSecondsSinceUnixEpoch,
378 request: QueuedRequestKind,
379 priority: usize,
380 ) -> Result<(), Self::Error>;
381
382 async fn update_send_queue_request(
394 &self,
395 room_id: &RoomId,
396 transaction_id: &TransactionId,
397 content: QueuedRequestKind,
398 ) -> Result<bool, Self::Error>;
399
400 async fn remove_send_queue_request(
406 &self,
407 room_id: &RoomId,
408 transaction_id: &TransactionId,
409 ) -> Result<bool, Self::Error>;
410
411 async fn load_send_queue_requests(
417 &self,
418 room_id: &RoomId,
419 ) -> Result<Vec<QueuedRequest>, Self::Error>;
420
421 async fn update_send_queue_request_status(
424 &self,
425 room_id: &RoomId,
426 transaction_id: &TransactionId,
427 error: Option<QueueWedgeError>,
428 ) -> Result<(), Self::Error>;
429
430 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error>;
432
433 async fn save_dependent_queued_request(
436 &self,
437 room_id: &RoomId,
438 parent_txn_id: &TransactionId,
439 own_txn_id: ChildTransactionId,
440 created_at: MilliSecondsSinceUnixEpoch,
441 content: DependentQueuedRequestKind,
442 ) -> Result<(), Self::Error>;
443
444 async fn mark_dependent_queued_requests_as_ready(
453 &self,
454 room_id: &RoomId,
455 parent_txn_id: &TransactionId,
456 sent_parent_key: SentRequestKey,
457 ) -> Result<usize, Self::Error>;
458
459 async fn update_dependent_queued_request(
463 &self,
464 room_id: &RoomId,
465 own_transaction_id: &ChildTransactionId,
466 new_content: DependentQueuedRequestKind,
467 ) -> Result<bool, Self::Error>;
468
469 async fn remove_dependent_queued_request(
474 &self,
475 room: &RoomId,
476 own_txn_id: &ChildTransactionId,
477 ) -> Result<bool, Self::Error>;
478
479 async fn load_dependent_queued_requests(
485 &self,
486 room: &RoomId,
487 ) -> Result<Vec<DependentQueuedRequest>, Self::Error>;
488
489 async fn upsert_thread_subscriptions(
499 &self,
500 updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
501 ) -> Result<(), Self::Error>;
502
503 async fn remove_thread_subscription(
507 &self,
508 room: &RoomId,
509 thread_id: &EventId,
510 ) -> Result<(), Self::Error>;
511
512 async fn load_thread_subscription(
516 &self,
517 room: &RoomId,
518 thread_id: &EventId,
519 ) -> Result<Option<StoredThreadSubscription>, Self::Error>;
520
521 async fn close(&self) -> Result<(), Self::Error>;
527
528 async fn reopen(&self) -> Result<(), Self::Error>;
531
532 #[doc(hidden)]
538 async fn optimize(&self) -> Result<(), Self::Error>;
539
540 async fn get_size(&self) -> Result<Option<usize>, Self::Error>;
542}
543
544#[cfg_attr(target_family = "wasm", async_trait(?Send))]
545#[cfg_attr(not(target_family = "wasm"), async_trait)]
546impl<T: StateStore> StateStore for &T {
547 type Error = T::Error;
548
549 async fn get_kv_data(
550 &self,
551 key: StateStoreDataKey<'_>,
552 ) -> Result<Option<StateStoreDataValue>, Self::Error> {
553 (*self).get_kv_data(key).await
554 }
555
556 async fn set_kv_data(
557 &self,
558 key: StateStoreDataKey<'_>,
559 value: StateStoreDataValue,
560 ) -> Result<(), Self::Error> {
561 (*self).set_kv_data(key, value).await
562 }
563
564 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
565 (*self).remove_kv_data(key).await
566 }
567
568 async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
569 (*self).save_changes(changes).await
570 }
571
572 async fn get_presence_event(
573 &self,
574 user_id: &UserId,
575 ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
576 (*self).get_presence_event(user_id).await
577 }
578
579 async fn get_presence_events(
580 &self,
581 user_ids: &[OwnedUserId],
582 ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
583 (*self).get_presence_events(user_ids).await
584 }
585
586 async fn get_state_event(
587 &self,
588 room_id: &RoomId,
589 event_type: StateEventType,
590 state_key: &str,
591 ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
592 (*self).get_state_event(room_id, event_type, state_key).await
593 }
594
595 async fn get_state_events(
596 &self,
597 room_id: &RoomId,
598 event_type: StateEventType,
599 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
600 (*self).get_state_events(room_id, event_type).await
601 }
602
603 async fn get_state_events_for_keys(
604 &self,
605 room_id: &RoomId,
606 event_type: StateEventType,
607 state_keys: &[&str],
608 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
609 (*self).get_state_events_for_keys(room_id, event_type, state_keys).await
610 }
611
612 async fn get_profile(
613 &self,
614 room_id: &RoomId,
615 user_id: &UserId,
616 ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
617 (*self).get_profile(room_id, user_id).await
618 }
619
620 async fn get_profiles<'a>(
621 &self,
622 room_id: &RoomId,
623 user_ids: &'a [OwnedUserId],
624 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
625 (*self).get_profiles(room_id, user_ids).await
626 }
627
628 async fn get_user_ids(
629 &self,
630 room_id: &RoomId,
631 memberships: RoomMemberships,
632 ) -> Result<Vec<OwnedUserId>, Self::Error> {
633 (*self).get_user_ids(room_id, memberships).await
634 }
635
636 async fn get_room_infos(
637 &self,
638 room_load_settings: &RoomLoadSettings,
639 ) -> Result<Vec<RoomInfo>, Self::Error> {
640 (*self).get_room_infos(room_load_settings).await
641 }
642
643 async fn get_users_with_display_name(
644 &self,
645 room_id: &RoomId,
646 display_name: &DisplayName,
647 ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
648 (*self).get_users_with_display_name(room_id, display_name).await
649 }
650
651 async fn get_users_with_display_names<'a>(
652 &self,
653 room_id: &RoomId,
654 display_names: &'a [DisplayName],
655 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
656 (*self).get_users_with_display_names(room_id, display_names).await
657 }
658
659 async fn get_account_data_event(
660 &self,
661 event_type: GlobalAccountDataEventType,
662 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
663 (*self).get_account_data_event(event_type).await
664 }
665
666 async fn get_room_account_data_event(
667 &self,
668 room_id: &RoomId,
669 event_type: RoomAccountDataEventType,
670 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
671 (*self).get_room_account_data_event(room_id, event_type).await
672 }
673
674 async fn get_user_room_receipt_event(
675 &self,
676 room_id: &RoomId,
677 receipt_type: ReceiptType,
678 thread: ReceiptThread,
679 user_id: &UserId,
680 ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
681 (*self).get_user_room_receipt_event(room_id, receipt_type, thread, user_id).await
682 }
683
684 async fn get_event_room_receipt_events(
685 &self,
686 room_id: &RoomId,
687 receipt_type: ReceiptType,
688 thread: ReceiptThread,
689 event_id: &EventId,
690 ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
691 (*self).get_event_room_receipt_events(room_id, receipt_type, thread, event_id).await
692 }
693
694 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
695 (*self).get_custom_value(key).await
696 }
697
698 async fn set_custom_value(
699 &self,
700 key: &[u8],
701 value: Vec<u8>,
702 ) -> Result<Option<Vec<u8>>, Self::Error> {
703 (*self).set_custom_value(key, value).await
704 }
705
706 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
707 (*self).remove_custom_value(key).await
708 }
709
710 async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
711 (*self).remove_room(room_id).await
712 }
713
714 async fn save_send_queue_request(
715 &self,
716 room_id: &RoomId,
717 transaction_id: OwnedTransactionId,
718 created_at: MilliSecondsSinceUnixEpoch,
719 request: QueuedRequestKind,
720 priority: usize,
721 ) -> Result<(), Self::Error> {
722 (*self)
723 .save_send_queue_request(room_id, transaction_id, created_at, request, priority)
724 .await
725 }
726
727 async fn update_send_queue_request(
728 &self,
729 room_id: &RoomId,
730 transaction_id: &TransactionId,
731 content: QueuedRequestKind,
732 ) -> Result<bool, Self::Error> {
733 (*self).update_send_queue_request(room_id, transaction_id, content).await
734 }
735
736 async fn remove_send_queue_request(
737 &self,
738 room_id: &RoomId,
739 transaction_id: &TransactionId,
740 ) -> Result<bool, Self::Error> {
741 (*self).remove_send_queue_request(room_id, transaction_id).await
742 }
743
744 async fn load_send_queue_requests(
745 &self,
746 room_id: &RoomId,
747 ) -> Result<Vec<QueuedRequest>, Self::Error> {
748 (*self).load_send_queue_requests(room_id).await
749 }
750
751 async fn update_send_queue_request_status(
752 &self,
753 room_id: &RoomId,
754 transaction_id: &TransactionId,
755 error: Option<QueueWedgeError>,
756 ) -> Result<(), Self::Error> {
757 (*self).update_send_queue_request_status(room_id, transaction_id, error).await
758 }
759
760 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
761 (*self).load_rooms_with_unsent_requests().await
762 }
763
764 async fn save_dependent_queued_request(
765 &self,
766 room_id: &RoomId,
767 parent_txn_id: &TransactionId,
768 own_txn_id: ChildTransactionId,
769 created_at: MilliSecondsSinceUnixEpoch,
770 content: DependentQueuedRequestKind,
771 ) -> Result<(), Self::Error> {
772 (*self)
773 .save_dependent_queued_request(room_id, parent_txn_id, own_txn_id, created_at, content)
774 .await
775 }
776
777 async fn mark_dependent_queued_requests_as_ready(
778 &self,
779 room_id: &RoomId,
780 parent_txn_id: &TransactionId,
781 sent_parent_key: SentRequestKey,
782 ) -> Result<usize, Self::Error> {
783 (*self)
784 .mark_dependent_queued_requests_as_ready(room_id, parent_txn_id, sent_parent_key)
785 .await
786 }
787
788 async fn update_dependent_queued_request(
789 &self,
790 room_id: &RoomId,
791 own_transaction_id: &ChildTransactionId,
792 new_content: DependentQueuedRequestKind,
793 ) -> Result<bool, Self::Error> {
794 (*self).update_dependent_queued_request(room_id, own_transaction_id, new_content).await
795 }
796
797 async fn remove_dependent_queued_request(
798 &self,
799 room: &RoomId,
800 own_txn_id: &ChildTransactionId,
801 ) -> Result<bool, Self::Error> {
802 (*self).remove_dependent_queued_request(room, own_txn_id).await
803 }
804
805 async fn load_dependent_queued_requests(
806 &self,
807 room: &RoomId,
808 ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
809 (*self).load_dependent_queued_requests(room).await
810 }
811
812 async fn upsert_thread_subscriptions(
813 &self,
814 updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
815 ) -> Result<(), Self::Error> {
816 (*self).upsert_thread_subscriptions(updates).await
817 }
818
819 async fn remove_thread_subscription(
820 &self,
821 room: &RoomId,
822 thread_id: &EventId,
823 ) -> Result<(), Self::Error> {
824 (*self).remove_thread_subscription(room, thread_id).await
825 }
826
827 async fn load_thread_subscription(
828 &self,
829 room: &RoomId,
830 thread_id: &EventId,
831 ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
832 (*self).load_thread_subscription(room, thread_id).await
833 }
834
835 async fn close(&self) -> Result<(), Self::Error> {
836 (*self).close().await
837 }
838
839 async fn reopen(&self) -> Result<(), Self::Error> {
840 (*self).reopen().await
841 }
842
843 async fn optimize(&self) -> Result<(), Self::Error> {
844 (*self).optimize().await
845 }
846
847 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
848 (*self).get_size().await
849 }
850}
851
852#[cfg_attr(target_family = "wasm", async_trait(?Send))]
853#[cfg_attr(not(target_family = "wasm"), async_trait)]
854impl<T: StateStore + ?Sized> StateStore for Arc<T> {
855 type Error = T::Error;
856
857 async fn get_kv_data(
858 &self,
859 key: StateStoreDataKey<'_>,
860 ) -> Result<Option<StateStoreDataValue>, Self::Error> {
861 self.deref().get_kv_data(key).await
862 }
863
864 async fn set_kv_data(
865 &self,
866 key: StateStoreDataKey<'_>,
867 value: StateStoreDataValue,
868 ) -> Result<(), Self::Error> {
869 self.deref().set_kv_data(key, value).await
870 }
871
872 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
873 self.deref().remove_kv_data(key).await
874 }
875
876 async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
877 self.deref().save_changes(changes).await
878 }
879
880 async fn get_presence_event(
881 &self,
882 user_id: &UserId,
883 ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
884 self.deref().get_presence_event(user_id).await
885 }
886
887 async fn get_presence_events(
888 &self,
889 user_ids: &[OwnedUserId],
890 ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
891 self.deref().get_presence_events(user_ids).await
892 }
893
894 async fn get_state_event(
895 &self,
896 room_id: &RoomId,
897 event_type: StateEventType,
898 state_key: &str,
899 ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
900 self.deref().get_state_event(room_id, event_type, state_key).await
901 }
902
903 async fn get_state_events(
904 &self,
905 room_id: &RoomId,
906 event_type: StateEventType,
907 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
908 self.deref().get_state_events(room_id, event_type).await
909 }
910
911 async fn get_state_events_for_keys(
912 &self,
913 room_id: &RoomId,
914 event_type: StateEventType,
915 state_keys: &[&str],
916 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
917 self.deref().get_state_events_for_keys(room_id, event_type, state_keys).await
918 }
919
920 async fn get_profile(
921 &self,
922 room_id: &RoomId,
923 user_id: &UserId,
924 ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
925 self.deref().get_profile(room_id, user_id).await
926 }
927
928 async fn get_profiles<'a>(
929 &self,
930 room_id: &RoomId,
931 user_ids: &'a [OwnedUserId],
932 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
933 self.deref().get_profiles(room_id, user_ids).await
934 }
935
936 async fn get_user_ids(
937 &self,
938 room_id: &RoomId,
939 memberships: RoomMemberships,
940 ) -> Result<Vec<OwnedUserId>, Self::Error> {
941 self.deref().get_user_ids(room_id, memberships).await
942 }
943
944 async fn get_room_infos(
945 &self,
946 room_load_settings: &RoomLoadSettings,
947 ) -> Result<Vec<RoomInfo>, Self::Error> {
948 self.deref().get_room_infos(room_load_settings).await
949 }
950
951 async fn get_users_with_display_name(
952 &self,
953 room_id: &RoomId,
954 display_name: &DisplayName,
955 ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
956 self.deref().get_users_with_display_name(room_id, display_name).await
957 }
958
959 async fn get_users_with_display_names<'a>(
960 &self,
961 room_id: &RoomId,
962 display_names: &'a [DisplayName],
963 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
964 self.deref().get_users_with_display_names(room_id, display_names).await
965 }
966
967 async fn get_account_data_event(
968 &self,
969 event_type: GlobalAccountDataEventType,
970 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
971 self.deref().get_account_data_event(event_type).await
972 }
973
974 async fn get_room_account_data_event(
975 &self,
976 room_id: &RoomId,
977 event_type: RoomAccountDataEventType,
978 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
979 self.deref().get_room_account_data_event(room_id, event_type).await
980 }
981
982 async fn get_user_room_receipt_event(
983 &self,
984 room_id: &RoomId,
985 receipt_type: ReceiptType,
986 thread: ReceiptThread,
987 user_id: &UserId,
988 ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
989 self.deref().get_user_room_receipt_event(room_id, receipt_type, thread, user_id).await
990 }
991
992 async fn get_event_room_receipt_events(
993 &self,
994 room_id: &RoomId,
995 receipt_type: ReceiptType,
996 thread: ReceiptThread,
997 event_id: &EventId,
998 ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
999 self.deref().get_event_room_receipt_events(room_id, receipt_type, thread, event_id).await
1000 }
1001
1002 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
1003 self.deref().get_custom_value(key).await
1004 }
1005
1006 async fn set_custom_value(
1007 &self,
1008 key: &[u8],
1009 value: Vec<u8>,
1010 ) -> Result<Option<Vec<u8>>, Self::Error> {
1011 self.deref().set_custom_value(key, value).await
1012 }
1013
1014 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
1015 self.deref().remove_custom_value(key).await
1016 }
1017
1018 async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
1019 self.deref().remove_room(room_id).await
1020 }
1021
1022 async fn save_send_queue_request(
1023 &self,
1024 room_id: &RoomId,
1025 transaction_id: OwnedTransactionId,
1026 created_at: MilliSecondsSinceUnixEpoch,
1027 request: QueuedRequestKind,
1028 priority: usize,
1029 ) -> Result<(), Self::Error> {
1030 self.deref()
1031 .save_send_queue_request(room_id, transaction_id, created_at, request, priority)
1032 .await
1033 }
1034
1035 async fn update_send_queue_request(
1036 &self,
1037 room_id: &RoomId,
1038 transaction_id: &TransactionId,
1039 content: QueuedRequestKind,
1040 ) -> Result<bool, Self::Error> {
1041 self.deref().update_send_queue_request(room_id, transaction_id, content).await
1042 }
1043
1044 async fn remove_send_queue_request(
1045 &self,
1046 room_id: &RoomId,
1047 transaction_id: &TransactionId,
1048 ) -> Result<bool, Self::Error> {
1049 self.deref().remove_send_queue_request(room_id, transaction_id).await
1050 }
1051
1052 async fn load_send_queue_requests(
1053 &self,
1054 room_id: &RoomId,
1055 ) -> Result<Vec<QueuedRequest>, Self::Error> {
1056 self.deref().load_send_queue_requests(room_id).await
1057 }
1058
1059 async fn update_send_queue_request_status(
1060 &self,
1061 room_id: &RoomId,
1062 transaction_id: &TransactionId,
1063 error: Option<QueueWedgeError>,
1064 ) -> Result<(), Self::Error> {
1065 self.deref().update_send_queue_request_status(room_id, transaction_id, error).await
1066 }
1067
1068 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
1069 self.deref().load_rooms_with_unsent_requests().await
1070 }
1071
1072 async fn save_dependent_queued_request(
1073 &self,
1074 room_id: &RoomId,
1075 parent_txn_id: &TransactionId,
1076 own_txn_id: ChildTransactionId,
1077 created_at: MilliSecondsSinceUnixEpoch,
1078 content: DependentQueuedRequestKind,
1079 ) -> Result<(), Self::Error> {
1080 self.deref()
1081 .save_dependent_queued_request(room_id, parent_txn_id, own_txn_id, created_at, content)
1082 .await
1083 }
1084
1085 async fn mark_dependent_queued_requests_as_ready(
1086 &self,
1087 room_id: &RoomId,
1088 parent_txn_id: &TransactionId,
1089 sent_parent_key: SentRequestKey,
1090 ) -> Result<usize, Self::Error> {
1091 self.deref()
1092 .mark_dependent_queued_requests_as_ready(room_id, parent_txn_id, sent_parent_key)
1093 .await
1094 }
1095
1096 async fn update_dependent_queued_request(
1097 &self,
1098 room_id: &RoomId,
1099 own_transaction_id: &ChildTransactionId,
1100 new_content: DependentQueuedRequestKind,
1101 ) -> Result<bool, Self::Error> {
1102 self.deref().update_dependent_queued_request(room_id, own_transaction_id, new_content).await
1103 }
1104
1105 async fn remove_dependent_queued_request(
1106 &self,
1107 room: &RoomId,
1108 own_txn_id: &ChildTransactionId,
1109 ) -> Result<bool, Self::Error> {
1110 self.deref().remove_dependent_queued_request(room, own_txn_id).await
1111 }
1112
1113 async fn load_dependent_queued_requests(
1114 &self,
1115 room: &RoomId,
1116 ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
1117 self.deref().load_dependent_queued_requests(room).await
1118 }
1119
1120 async fn upsert_thread_subscriptions(
1121 &self,
1122 updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
1123 ) -> Result<(), Self::Error> {
1124 self.deref().upsert_thread_subscriptions(updates).await
1125 }
1126
1127 async fn remove_thread_subscription(
1128 &self,
1129 room: &RoomId,
1130 thread_id: &EventId,
1131 ) -> Result<(), Self::Error> {
1132 self.deref().remove_thread_subscription(room, thread_id).await
1133 }
1134
1135 async fn load_thread_subscription(
1136 &self,
1137 room: &RoomId,
1138 thread_id: &EventId,
1139 ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1140 self.deref().load_thread_subscription(room, thread_id).await
1141 }
1142
1143 async fn close(&self) -> Result<(), Self::Error> {
1144 self.deref().close().await
1145 }
1146
1147 async fn reopen(&self) -> Result<(), Self::Error> {
1148 self.deref().reopen().await
1149 }
1150
1151 async fn optimize(&self) -> Result<(), Self::Error> {
1152 self.deref().optimize().await
1153 }
1154
1155 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1156 self.deref().get_size().await
1157 }
1158}
1159
/// Transparent wrapper around an inner store value `T`.
///
/// The `StateStore` implementation for this wrapper (further down in this
/// file) exposes `StoreError` as its error type and converts the inner
/// store's errors with `Into::into`.
#[repr(transparent)]
struct EraseStateStoreError<T>(T);

/// Debug-formats the wrapper exactly like the wrapped value.
///
/// No `EraseStateStoreError(..)` decoration is added, and the formatter
/// (with its flags, such as `{:#?}`) is handed through unchanged.
#[cfg(not(tarpaulin_include))]
impl<T: fmt::Debug> fmt::Debug for EraseStateStoreError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let EraseStateStoreError(inner) = self;
        fmt::Debug::fmt(inner, f)
    }
}
1169
1170#[cfg_attr(target_family = "wasm", async_trait(?Send))]
1171#[cfg_attr(not(target_family = "wasm"), async_trait)]
1172impl<T: StateStore> StateStore for EraseStateStoreError<T> {
1173 type Error = StoreError;
1174
1175 async fn get_kv_data(
1176 &self,
1177 key: StateStoreDataKey<'_>,
1178 ) -> Result<Option<StateStoreDataValue>, Self::Error> {
1179 self.0.get_kv_data(key).await.map_err(Into::into)
1180 }
1181
1182 async fn set_kv_data(
1183 &self,
1184 key: StateStoreDataKey<'_>,
1185 value: StateStoreDataValue,
1186 ) -> Result<(), Self::Error> {
1187 self.0.set_kv_data(key, value).await.map_err(Into::into)
1188 }
1189
1190 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
1191 self.0.remove_kv_data(key).await.map_err(Into::into)
1192 }
1193
1194 async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
1195 self.0.save_changes(changes).await.map_err(Into::into)
1196 }
1197
1198 async fn get_presence_event(
1199 &self,
1200 user_id: &UserId,
1201 ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
1202 self.0.get_presence_event(user_id).await.map_err(Into::into)
1203 }
1204
1205 async fn get_presence_events(
1206 &self,
1207 user_ids: &[OwnedUserId],
1208 ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
1209 self.0.get_presence_events(user_ids).await.map_err(Into::into)
1210 }
1211
1212 async fn get_state_event(
1213 &self,
1214 room_id: &RoomId,
1215 event_type: StateEventType,
1216 state_key: &str,
1217 ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
1218 self.0.get_state_event(room_id, event_type, state_key).await.map_err(Into::into)
1219 }
1220
1221 async fn get_state_events(
1222 &self,
1223 room_id: &RoomId,
1224 event_type: StateEventType,
1225 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
1226 self.0.get_state_events(room_id, event_type).await.map_err(Into::into)
1227 }
1228
1229 async fn get_state_events_for_keys(
1230 &self,
1231 room_id: &RoomId,
1232 event_type: StateEventType,
1233 state_keys: &[&str],
1234 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
1235 self.0.get_state_events_for_keys(room_id, event_type, state_keys).await.map_err(Into::into)
1236 }
1237
1238 async fn get_profile(
1239 &self,
1240 room_id: &RoomId,
1241 user_id: &UserId,
1242 ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
1243 self.0.get_profile(room_id, user_id).await.map_err(Into::into)
1244 }
1245
1246 async fn get_profiles<'a>(
1247 &self,
1248 room_id: &RoomId,
1249 user_ids: &'a [OwnedUserId],
1250 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
1251 self.0.get_profiles(room_id, user_ids).await.map_err(Into::into)
1252 }
1253
1254 async fn get_user_ids(
1255 &self,
1256 room_id: &RoomId,
1257 memberships: RoomMemberships,
1258 ) -> Result<Vec<OwnedUserId>, Self::Error> {
1259 self.0.get_user_ids(room_id, memberships).await.map_err(Into::into)
1260 }
1261
1262 async fn get_room_infos(
1263 &self,
1264 room_load_settings: &RoomLoadSettings,
1265 ) -> Result<Vec<RoomInfo>, Self::Error> {
1266 self.0.get_room_infos(room_load_settings).await.map_err(Into::into)
1267 }
1268
1269 async fn get_users_with_display_name(
1270 &self,
1271 room_id: &RoomId,
1272 display_name: &DisplayName,
1273 ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
1274 self.0.get_users_with_display_name(room_id, display_name).await.map_err(Into::into)
1275 }
1276
1277 async fn get_users_with_display_names<'a>(
1278 &self,
1279 room_id: &RoomId,
1280 display_names: &'a [DisplayName],
1281 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
1282 self.0.get_users_with_display_names(room_id, display_names).await.map_err(Into::into)
1283 }
1284
1285 async fn get_account_data_event(
1286 &self,
1287 event_type: GlobalAccountDataEventType,
1288 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
1289 self.0.get_account_data_event(event_type).await.map_err(Into::into)
1290 }
1291
1292 async fn get_room_account_data_event(
1293 &self,
1294 room_id: &RoomId,
1295 event_type: RoomAccountDataEventType,
1296 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
1297 self.0.get_room_account_data_event(room_id, event_type).await.map_err(Into::into)
1298 }
1299
1300 async fn get_user_room_receipt_event(
1301 &self,
1302 room_id: &RoomId,
1303 receipt_type: ReceiptType,
1304 thread: ReceiptThread,
1305 user_id: &UserId,
1306 ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
1307 self.0
1308 .get_user_room_receipt_event(room_id, receipt_type, thread, user_id)
1309 .await
1310 .map_err(Into::into)
1311 }
1312
1313 async fn get_event_room_receipt_events(
1314 &self,
1315 room_id: &RoomId,
1316 receipt_type: ReceiptType,
1317 thread: ReceiptThread,
1318 event_id: &EventId,
1319 ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
1320 self.0
1321 .get_event_room_receipt_events(room_id, receipt_type, thread, event_id)
1322 .await
1323 .map_err(Into::into)
1324 }
1325
1326 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
1327 self.0.get_custom_value(key).await.map_err(Into::into)
1328 }
1329
1330 async fn set_custom_value(
1331 &self,
1332 key: &[u8],
1333 value: Vec<u8>,
1334 ) -> Result<Option<Vec<u8>>, Self::Error> {
1335 self.0.set_custom_value(key, value).await.map_err(Into::into)
1336 }
1337
1338 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
1339 self.0.remove_custom_value(key).await.map_err(Into::into)
1340 }
1341
1342 async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
1343 self.0.remove_room(room_id).await.map_err(Into::into)
1344 }
1345
1346 async fn save_send_queue_request(
1347 &self,
1348 room_id: &RoomId,
1349 transaction_id: OwnedTransactionId,
1350 created_at: MilliSecondsSinceUnixEpoch,
1351 content: QueuedRequestKind,
1352 priority: usize,
1353 ) -> Result<(), Self::Error> {
1354 self.0
1355 .save_send_queue_request(room_id, transaction_id, created_at, content, priority)
1356 .await
1357 .map_err(Into::into)
1358 }
1359
1360 async fn update_send_queue_request(
1361 &self,
1362 room_id: &RoomId,
1363 transaction_id: &TransactionId,
1364 content: QueuedRequestKind,
1365 ) -> Result<bool, Self::Error> {
1366 self.0.update_send_queue_request(room_id, transaction_id, content).await.map_err(Into::into)
1367 }
1368
1369 async fn remove_send_queue_request(
1370 &self,
1371 room_id: &RoomId,
1372 transaction_id: &TransactionId,
1373 ) -> Result<bool, Self::Error> {
1374 self.0.remove_send_queue_request(room_id, transaction_id).await.map_err(Into::into)
1375 }
1376
1377 async fn load_send_queue_requests(
1378 &self,
1379 room_id: &RoomId,
1380 ) -> Result<Vec<QueuedRequest>, Self::Error> {
1381 self.0.load_send_queue_requests(room_id).await.map_err(Into::into)
1382 }
1383
1384 async fn update_send_queue_request_status(
1385 &self,
1386 room_id: &RoomId,
1387 transaction_id: &TransactionId,
1388 error: Option<QueueWedgeError>,
1389 ) -> Result<(), Self::Error> {
1390 self.0
1391 .update_send_queue_request_status(room_id, transaction_id, error)
1392 .await
1393 .map_err(Into::into)
1394 }
1395
1396 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
1397 self.0.load_rooms_with_unsent_requests().await.map_err(Into::into)
1398 }
1399
1400 async fn save_dependent_queued_request(
1401 &self,
1402 room_id: &RoomId,
1403 parent_txn_id: &TransactionId,
1404 own_txn_id: ChildTransactionId,
1405 created_at: MilliSecondsSinceUnixEpoch,
1406 content: DependentQueuedRequestKind,
1407 ) -> Result<(), Self::Error> {
1408 self.0
1409 .save_dependent_queued_request(room_id, parent_txn_id, own_txn_id, created_at, content)
1410 .await
1411 .map_err(Into::into)
1412 }
1413
1414 async fn mark_dependent_queued_requests_as_ready(
1415 &self,
1416 room_id: &RoomId,
1417 parent_txn_id: &TransactionId,
1418 sent_parent_key: SentRequestKey,
1419 ) -> Result<usize, Self::Error> {
1420 self.0
1421 .mark_dependent_queued_requests_as_ready(room_id, parent_txn_id, sent_parent_key)
1422 .await
1423 .map_err(Into::into)
1424 }
1425
1426 async fn remove_dependent_queued_request(
1427 &self,
1428 room_id: &RoomId,
1429 own_txn_id: &ChildTransactionId,
1430 ) -> Result<bool, Self::Error> {
1431 self.0.remove_dependent_queued_request(room_id, own_txn_id).await.map_err(Into::into)
1432 }
1433
1434 async fn load_dependent_queued_requests(
1435 &self,
1436 room_id: &RoomId,
1437 ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
1438 self.0.load_dependent_queued_requests(room_id).await.map_err(Into::into)
1439 }
1440
1441 async fn update_dependent_queued_request(
1442 &self,
1443 room_id: &RoomId,
1444 own_transaction_id: &ChildTransactionId,
1445 new_content: DependentQueuedRequestKind,
1446 ) -> Result<bool, Self::Error> {
1447 self.0
1448 .update_dependent_queued_request(room_id, own_transaction_id, new_content)
1449 .await
1450 .map_err(Into::into)
1451 }
1452
1453 async fn upsert_thread_subscriptions(
1454 &self,
1455 updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
1456 ) -> Result<(), Self::Error> {
1457 self.0.upsert_thread_subscriptions(updates).await.map_err(Into::into)
1458 }
1459
1460 async fn load_thread_subscription(
1461 &self,
1462 room: &RoomId,
1463 thread_id: &EventId,
1464 ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1465 self.0.load_thread_subscription(room, thread_id).await.map_err(Into::into)
1466 }
1467
1468 async fn remove_thread_subscription(
1469 &self,
1470 room: &RoomId,
1471 thread_id: &EventId,
1472 ) -> Result<(), Self::Error> {
1473 self.0.remove_thread_subscription(room, thread_id).await.map_err(Into::into)
1474 }
1475
1476 async fn close(&self) -> Result<(), Self::Error> {
1477 self.0.close().await.map_err(Into::into)
1478 }
1479
1480 async fn reopen(&self) -> Result<(), Self::Error> {
1481 self.0.reopen().await.map_err(Into::into)
1482 }
1483
1484 async fn optimize(&self) -> Result<(), Self::Error> {
1485 self.0.optimize().await.map_err(Into::into)
1486 }
1487
1488 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1489 self.0.get_size().await.map_err(Into::into)
1490 }
1491}
1492
/// A [`StateStore`] wrapper whose `save_changes` and `remove_room` operations are serialized
/// through a mutex.
///
/// Callers that need to keep the lock across several steps can take it via
/// [`SaveLockedStateStore::lock`] and then call the `*_with_guard` methods with the guard.
#[derive(Debug, Clone)]
pub struct SaveLockedStateStore<T = Arc<DynStateStore>> {
    // The wrapped store that performs the actual work.
    store: T,
    // Held while saving changes or removing a room. It sits behind an `Arc`, so clones of this
    // wrapper share the same mutex.
    lock: Arc<Mutex<()>>,
}
1500
/// Error returned by the `*_with_guard` methods of [`SaveLockedStateStore`] when the guard
/// they receive does not lock that store's own mutex.
#[derive(Debug, Error)]
#[error("a mutex guard was provided, but it does not reference the correct mutex")]
pub struct IncorrectMutexGuardError;
1507
1508impl From<IncorrectMutexGuardError> for StoreError {
1509 fn from(value: IncorrectMutexGuardError) -> Self {
1510 Self::backend(value)
1511 }
1512}
1513
1514impl<T> SaveLockedStateStore<T> {
1515 pub fn new(store: T) -> Self {
1517 Self { store, lock: Arc::new(Mutex::new(())) }
1518 }
1519
1520 pub fn lock(&self) -> &Mutex<()> {
1523 self.lock.as_ref()
1524 }
1525}
1526
1527impl<T: StateStore> SaveLockedStateStore<T> {
1528 pub async fn save_changes_with_guard(
1533 &self,
1534 guard: &MutexGuard<'_, ()>,
1535 changes: &StateChanges,
1536 ) -> Result<(), StoreError> {
1537 if !std::ptr::eq(MutexGuard::mutex(guard), self.lock()) {
1538 Err(IncorrectMutexGuardError.into())
1539 } else {
1540 self.store.save_changes(changes).await.map_err(Into::into)
1541 }
1542 }
1543
1544 pub async fn remove_room_with_guard(
1549 &self,
1550 guard: &MutexGuard<'_, ()>,
1551 room_id: &RoomId,
1552 ) -> Result<(), StoreError> {
1553 if !std::ptr::eq(MutexGuard::mutex(guard), self.lock()) {
1554 Err(IncorrectMutexGuardError.into())
1555 } else {
1556 self.store.remove_room(room_id).await.map_err(Into::into)
1557 }
1558 }
1559}
1560
// Delegating `StateStore` implementation. Every method forwards to the wrapped store unchanged,
// except `save_changes` and `remove_room`. Those two first acquire `self.lock`, so they are
// serialized with each other and with any caller holding a guard obtained from `lock()`.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl<T: StateStore> StateStore for SaveLockedStateStore<T> {
    // Errors from the wrapped store are passed through untouched.
    type Error = T::Error;

    async fn get_kv_data(
        &self,
        key: StateStoreDataKey<'_>,
    ) -> Result<Option<StateStoreDataValue>, Self::Error> {
        self.store.get_kv_data(key).await
    }

    async fn set_kv_data(
        &self,
        key: StateStoreDataKey<'_>,
        value: StateStoreDataValue,
    ) -> Result<(), Self::Error> {
        self.store.set_kv_data(key, value).await
    }

    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
        self.store.remove_kv_data(key).await
    }

    /// Saves `changes` while holding this wrapper's lock.
    async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
        // Must stay bound to a named variable (not `_`), otherwise the guard would be dropped
        // immediately and the save would run unlocked.
        let _guard = self.lock.lock().await;
        self.store.save_changes(changes).await
    }

    async fn get_presence_event(
        &self,
        user_id: &UserId,
    ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
        self.store.get_presence_event(user_id).await
    }

    async fn get_presence_events(
        &self,
        user_ids: &[OwnedUserId],
    ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
        self.store.get_presence_events(user_ids).await
    }

    async fn get_state_event(
        &self,
        room_id: &RoomId,
        event_type: StateEventType,
        state_key: &str,
    ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
        self.store.get_state_event(room_id, event_type, state_key).await
    }

    async fn get_state_events(
        &self,
        room_id: &RoomId,
        event_type: StateEventType,
    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
        self.store.get_state_events(room_id, event_type).await
    }

    async fn get_state_events_for_keys(
        &self,
        room_id: &RoomId,
        event_type: StateEventType,
        state_keys: &[&str],
    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
        self.store.get_state_events_for_keys(room_id, event_type, state_keys).await
    }

    async fn get_profile(
        &self,
        room_id: &RoomId,
        user_id: &UserId,
    ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
        self.store.get_profile(room_id, user_id).await
    }

    async fn get_profiles<'a>(
        &self,
        room_id: &RoomId,
        user_ids: &'a [OwnedUserId],
    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
        self.store.get_profiles(room_id, user_ids).await
    }

    async fn get_user_ids(
        &self,
        room_id: &RoomId,
        memberships: RoomMemberships,
    ) -> Result<Vec<OwnedUserId>, Self::Error> {
        self.store.get_user_ids(room_id, memberships).await
    }

    async fn get_room_infos(
        &self,
        room_load_settings: &RoomLoadSettings,
    ) -> Result<Vec<RoomInfo>, Self::Error> {
        self.store.get_room_infos(room_load_settings).await
    }

    async fn get_users_with_display_name(
        &self,
        room_id: &RoomId,
        display_name: &DisplayName,
    ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
        self.store.get_users_with_display_name(room_id, display_name).await
    }

    async fn get_users_with_display_names<'a>(
        &self,
        room_id: &RoomId,
        display_names: &'a [DisplayName],
    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
        self.store.get_users_with_display_names(room_id, display_names).await
    }

    async fn get_account_data_event(
        &self,
        event_type: GlobalAccountDataEventType,
    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
        self.store.get_account_data_event(event_type).await
    }

    async fn get_room_account_data_event(
        &self,
        room_id: &RoomId,
        event_type: RoomAccountDataEventType,
    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
        self.store.get_room_account_data_event(room_id, event_type).await
    }

    async fn get_user_room_receipt_event(
        &self,
        room_id: &RoomId,
        receipt_type: ReceiptType,
        thread: ReceiptThread,
        user_id: &UserId,
    ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
        self.store.get_user_room_receipt_event(room_id, receipt_type, thread, user_id).await
    }

    async fn get_event_room_receipt_events(
        &self,
        room_id: &RoomId,
        receipt_type: ReceiptType,
        thread: ReceiptThread,
        event_id: &EventId,
    ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
        self.store.get_event_room_receipt_events(room_id, receipt_type, thread, event_id).await
    }

    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        self.store.get_custom_value(key).await
    }

    async fn set_custom_value(
        &self,
        key: &[u8],
        value: Vec<u8>,
    ) -> Result<Option<Vec<u8>>, Self::Error> {
        self.store.set_custom_value(key, value).await
    }

    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        self.store.remove_custom_value(key).await
    }

    /// Removes `room_id` while holding this wrapper's lock.
    async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
        // Named binding keeps the guard alive until the removal has finished.
        let _guard = self.lock.lock().await;
        self.store.remove_room(room_id).await
    }

    async fn save_send_queue_request(
        &self,
        room_id: &RoomId,
        transaction_id: OwnedTransactionId,
        created_at: MilliSecondsSinceUnixEpoch,
        request: QueuedRequestKind,
        priority: usize,
    ) -> Result<(), Self::Error> {
        self.store
            .save_send_queue_request(room_id, transaction_id, created_at, request, priority)
            .await
    }

    async fn update_send_queue_request(
        &self,
        room_id: &RoomId,
        transaction_id: &TransactionId,
        content: QueuedRequestKind,
    ) -> Result<bool, Self::Error> {
        self.store.update_send_queue_request(room_id, transaction_id, content).await
    }

    async fn remove_send_queue_request(
        &self,
        room_id: &RoomId,
        transaction_id: &TransactionId,
    ) -> Result<bool, Self::Error> {
        self.store.remove_send_queue_request(room_id, transaction_id).await
    }

    async fn load_send_queue_requests(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<QueuedRequest>, Self::Error> {
        self.store.load_send_queue_requests(room_id).await
    }

    async fn update_send_queue_request_status(
        &self,
        room_id: &RoomId,
        transaction_id: &TransactionId,
        error: Option<QueueWedgeError>,
    ) -> Result<(), Self::Error> {
        self.store.update_send_queue_request_status(room_id, transaction_id, error).await
    }

    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
        self.store.load_rooms_with_unsent_requests().await
    }

    async fn save_dependent_queued_request(
        &self,
        room_id: &RoomId,
        parent_txn_id: &TransactionId,
        own_txn_id: ChildTransactionId,
        created_at: MilliSecondsSinceUnixEpoch,
        content: DependentQueuedRequestKind,
    ) -> Result<(), Self::Error> {
        self.store
            .save_dependent_queued_request(room_id, parent_txn_id, own_txn_id, created_at, content)
            .await
    }

    async fn mark_dependent_queued_requests_as_ready(
        &self,
        room_id: &RoomId,
        parent_txn_id: &TransactionId,
        sent_parent_key: SentRequestKey,
    ) -> Result<usize, Self::Error> {
        self.store
            .mark_dependent_queued_requests_as_ready(room_id, parent_txn_id, sent_parent_key)
            .await
    }

    async fn update_dependent_queued_request(
        &self,
        room_id: &RoomId,
        own_transaction_id: &ChildTransactionId,
        new_content: DependentQueuedRequestKind,
    ) -> Result<bool, Self::Error> {
        self.store.update_dependent_queued_request(room_id, own_transaction_id, new_content).await
    }

    async fn remove_dependent_queued_request(
        &self,
        room: &RoomId,
        own_txn_id: &ChildTransactionId,
    ) -> Result<bool, Self::Error> {
        self.store.remove_dependent_queued_request(room, own_txn_id).await
    }

    async fn load_dependent_queued_requests(
        &self,
        room: &RoomId,
    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
        self.store.load_dependent_queued_requests(room).await
    }

    async fn upsert_thread_subscriptions(
        &self,
        updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
    ) -> Result<(), Self::Error> {
        self.store.upsert_thread_subscriptions(updates).await
    }

    async fn load_thread_subscription(
        &self,
        room: &RoomId,
        thread_id: &EventId,
    ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
        self.store.load_thread_subscription(room, thread_id).await
    }

    async fn remove_thread_subscription(
        &self,
        room: &RoomId,
        thread_id: &EventId,
    ) -> Result<(), Self::Error> {
        self.store.remove_thread_subscription(room, thread_id).await
    }

    async fn close(&self) -> Result<(), Self::Error> {
        self.store.close().await
    }

    async fn reopen(&self) -> Result<(), Self::Error> {
        self.store.reopen().await
    }

    async fn optimize(&self) -> Result<(), Self::Error> {
        self.store.optimize().await
    }

    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
        self.store.get_size().await
    }
}
1870
1871#[cfg_attr(target_family = "wasm", async_trait(?Send))]
1873#[cfg_attr(not(target_family = "wasm"), async_trait)]
1874pub trait StateStoreExt: StateStore {
1875 async fn get_state_event_static<C>(
1881 &self,
1882 room_id: &RoomId,
1883 ) -> Result<Option<RawSyncOrStrippedState<C>>, Self::Error>
1884 where
1885 C: StaticEventContent<IsPrefix = ruma::events::False>
1886 + StaticStateEventContent<StateKey = EmptyStateKey>
1887 + RedactContent,
1888 C::Redacted: RedactedStateEventContent,
1889 {
1890 Ok(self.get_state_event(room_id, C::TYPE.into(), "").await?.map(|raw| raw.cast()))
1891 }
1892
1893 async fn get_state_event_static_for_key<C, K>(
1899 &self,
1900 room_id: &RoomId,
1901 state_key: &K,
1902 ) -> Result<Option<RawSyncOrStrippedState<C>>, Self::Error>
1903 where
1904 C: StaticEventContent<IsPrefix = ruma::events::False>
1905 + StaticStateEventContent
1906 + RedactContent,
1907 C::StateKey: Borrow<K>,
1908 C::Redacted: RedactedStateEventContent,
1909 K: AsRef<str> + ?Sized + Sync,
1910 {
1911 Ok(self
1912 .get_state_event(room_id, C::TYPE.into(), state_key.as_ref())
1913 .await?
1914 .map(|raw| raw.cast()))
1915 }
1916
1917 async fn get_state_events_static<C>(
1923 &self,
1924 room_id: &RoomId,
1925 ) -> Result<Vec<RawSyncOrStrippedState<C>>, Self::Error>
1926 where
1927 C: StaticEventContent<IsPrefix = ruma::events::False>
1928 + StaticStateEventContent
1929 + RedactContent,
1930 C::Redacted: RedactedStateEventContent,
1931 {
1932 Ok(self
1934 .get_state_events(room_id, C::TYPE.into())
1935 .await?
1936 .into_iter()
1937 .map(|raw| raw.cast())
1938 .collect())
1939 }
1940
1941 async fn get_state_events_for_keys_static<'a, C, K, I>(
1950 &self,
1951 room_id: &RoomId,
1952 state_keys: I,
1953 ) -> Result<Vec<RawSyncOrStrippedState<C>>, Self::Error>
1954 where
1955 C: StaticEventContent<IsPrefix = ruma::events::False>
1956 + StaticStateEventContent
1957 + RedactContent,
1958 C::StateKey: Borrow<K>,
1959 C::Redacted: RedactedStateEventContent,
1960 K: AsRef<str> + Sized + Sync + 'a,
1961 I: IntoIterator<Item = &'a K> + Send,
1962 I::IntoIter: Send,
1963 {
1964 Ok(self
1965 .get_state_events_for_keys(
1966 room_id,
1967 C::TYPE.into(),
1968 &state_keys.into_iter().map(|k| k.as_ref()).collect::<Vec<_>>(),
1969 )
1970 .await?
1971 .into_iter()
1972 .map(|raw| raw.cast())
1973 .collect())
1974 }
1975
1976 async fn get_account_data_event_static<C>(
1978 &self,
1979 ) -> Result<Option<Raw<GlobalAccountDataEvent<C>>>, Self::Error>
1980 where
1981 C: StaticEventContent<IsPrefix = ruma::events::False> + GlobalAccountDataEventContent,
1982 {
1983 Ok(self.get_account_data_event(C::TYPE.into()).await?.map(Raw::cast_unchecked))
1984 }
1985
1986 async fn get_room_account_data_event_static<C>(
1994 &self,
1995 room_id: &RoomId,
1996 ) -> Result<Option<Raw<RoomAccountDataEvent<C>>>, Self::Error>
1997 where
1998 C: StaticEventContent<IsPrefix = ruma::events::False> + RoomAccountDataEventContent,
1999 {
2000 Ok(self
2001 .get_room_account_data_event(room_id, C::TYPE.into())
2002 .await?
2003 .map(Raw::cast_unchecked))
2004 }
2005
2006 async fn get_member_event(
2014 &self,
2015 room_id: &RoomId,
2016 state_key: &UserId,
2017 ) -> Result<Option<RawMemberEvent>, Self::Error> {
2018 self.get_state_event_static_for_key(room_id, state_key).await
2019 }
2020}
2021
2022#[cfg_attr(target_family = "wasm", async_trait(?Send))]
2023#[cfg_attr(not(target_family = "wasm"), async_trait)]
2024impl<T: StateStore + ?Sized> StateStoreExt for T {}
2025
/// A type-erased [`StateStore`] whose error type is fixed to [`StoreError`].
pub type DynStateStore = dyn StateStore<Error = StoreError>;
2028
/// Conversion of a store into a shared, type-erased [`DynStateStore`].
pub trait IntoStateStore {
    /// Wraps `self` into an `Arc<DynStateStore>`.
    #[doc(hidden)]
    fn into_state_store(self) -> Arc<DynStateStore>;
}
2038
2039impl<T> IntoStateStore for T
2040where
2041 T: StateStore + Sized + 'static,
2042{
2043 fn into_state_store(self) -> Arc<DynStateStore> {
2044 Arc::new(EraseStateStoreError(self))
2045 }
2046}
2047
/// A serializable copy of a supported-versions response, as kept in the state store.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SupportedVersionsResponse {
    /// The raw version strings.
    pub versions: Vec<String>,

    /// Unstable feature flags, keyed by feature name.
    pub unstable_features: BTreeMap<String, bool>,
}
2057
2058impl SupportedVersionsResponse {
2059 pub fn supported_versions(&self) -> SupportedVersions {
2065 let mut supported_versions =
2066 SupportedVersions::from_parts(&self.versions, &self.unstable_features);
2067
2068 if supported_versions.versions.is_empty() {
2071 supported_versions.versions.insert(MatrixVersion::V1_0);
2072 }
2073
2074 supported_versions
2075 }
2076}
2077
/// A serializable subset of the homeserver discovery (`.well-known`) response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WellKnownResponse {
    /// Homeserver information.
    pub homeserver: HomeserverInfo,

    /// Identity server information, if present.
    pub identity_server: Option<IdentityServerInfo>,

    /// Tile server information, if present.
    pub tile_server: Option<TileServerInfo>,

    /// RTC transports advertised in the response.
    pub rtc_foci: Vec<RtcTransport>,
}
2093
2094impl From<discover_homeserver::Response> for WellKnownResponse {
2095 fn from(response: discover_homeserver::Response) -> Self {
2096 Self {
2097 homeserver: response.homeserver,
2098 identity_server: response.identity_server,
2099 tile_server: response.tile_server,
2100 rtc_foci: response.rtc_foci,
2101 }
2102 }
2103}
2104
/// A value that can be written with, or read back by, the key-value methods of a
/// [`StateStore`] (`set_kv_data` / `get_kv_data`).
///
/// Each variant has a same-named counterpart in [`StateStoreDataKey`].
#[derive(Debug, Clone)]
pub enum StateStoreDataValue {
    /// A sync token.
    SyncToken(String),

    /// The supported-versions response, wrapped in a [`TtlValue`].
    SupportedVersions(TtlValue<SupportedVersionsResponse>),

    /// The well-known response, or `None`, wrapped in a [`TtlValue`].
    WellKnown(TtlValue<Option<WellKnownResponse>>),

    /// A filter, stored as a string (presumably a filter ID — confirm).
    Filter(String),

    /// A user's avatar URL.
    UserAvatarUrl(OwnedMxcUri),

    /// A list of recently visited rooms.
    RecentlyVisitedRooms(Vec<OwnedRoomId>),

    /// Data for the UTD hook manager, stored as a [`GrowableBloom`] filter.
    UtdHookManagerData(GrowableBloom),

    /// A payload-less marker value.
    OneTimeKeyAlreadyUploaded,

    /// A saved composer draft.
    ComposerDraft(ComposerDraft),

    /// Seen knock requests, as a map from event ID to user ID.
    SeenKnockRequests(BTreeMap<OwnedEventId, OwnedUserId>),

    /// The list of thread subscription catch-up tokens.
    ThreadSubscriptionsCatchupTokens(Vec<ThreadSubscriptionCatchupToken>),

    /// The homeserver capabilities, wrapped in a [`TtlValue`].
    HomeserverCapabilities(TtlValue<Capabilities>),
}
2152
/// A pair of tokens used when catching up on thread subscriptions.
///
/// NOTE(review): these presumably bound a range of thread subscriptions that still has to be
/// fetched — confirm against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ThreadSubscriptionCatchupToken {
    /// The token to start from.
    pub from: String,

    /// An optional token to stop at; `None` presumably means no upper bound — confirm.
    pub to: Option<String>,
}
2178
/// A saved draft of a message being composed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ComposerDraft {
    /// The draft text as plain text.
    pub plain_text: String,
    /// The draft text as HTML, if available.
    pub html_text: Option<String>,
    /// Whether this is a new message, a reply, or an edit.
    pub draft_type: ComposerDraftType,
    /// Attachments of the draft. `#[serde(default)]` makes this an empty list when the field
    /// is missing from the serialized data.
    #[serde(default)]
    pub attachments: Vec<DraftAttachment>,
}
2193
/// An attachment of a [`ComposerDraft`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct DraftAttachment {
    /// The attachment's file name.
    pub filename: String,
    /// The attachment's kind-specific content.
    pub content: DraftAttachmentContent,
}
2202
/// Kind-specific content of a [`DraftAttachment`].
///
/// Serialized as an internally tagged enum: a `type` field holds the variant name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum DraftAttachmentContent {
    /// An image attachment.
    Image {
        /// The raw image bytes.
        data: Vec<u8>,
        /// The MIME type, if known.
        mimetype: Option<String>,
        /// The size, if known (presumably in bytes).
        size: Option<u64>,
        /// The width, if known (presumably in pixels).
        width: Option<u64>,
        /// The height, if known (presumably in pixels).
        height: Option<u64>,
        /// A BlurHash string, if any.
        blurhash: Option<String>,
        /// An optional thumbnail.
        thumbnail: Option<DraftThumbnail>,
    },
    /// A video attachment.
    Video {
        /// The raw video bytes.
        data: Vec<u8>,
        /// The MIME type, if known.
        mimetype: Option<String>,
        /// The size, if known (presumably in bytes).
        size: Option<u64>,
        /// The width, if known (presumably in pixels).
        width: Option<u64>,
        /// The height, if known (presumably in pixels).
        height: Option<u64>,
        /// The duration, if known.
        duration: Option<std::time::Duration>,
        /// A BlurHash string, if any.
        blurhash: Option<String>,
        /// An optional thumbnail.
        thumbnail: Option<DraftThumbnail>,
    },
    /// An audio attachment.
    Audio {
        /// The raw audio bytes.
        data: Vec<u8>,
        /// The MIME type, if known.
        mimetype: Option<String>,
        /// The size, if known (presumably in bytes).
        size: Option<u64>,
        /// The duration, if known.
        duration: Option<std::time::Duration>,
    },
    /// A generic file attachment.
    File {
        /// The raw file bytes.
        data: Vec<u8>,
        /// The MIME type, if known.
        mimetype: Option<String>,
        /// The size, if known (presumably in bytes).
        size: Option<u64>,
    },
}
2264
/// A thumbnail attached to an image or video [`DraftAttachmentContent`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct DraftThumbnail {
    /// The thumbnail's file name.
    pub filename: String,
    /// The raw thumbnail bytes.
    pub data: Vec<u8>,
    /// The MIME type, if known.
    pub mimetype: Option<String>,
    /// The width, if known (presumably in pixels).
    pub width: Option<u64>,
    /// The height, if known (presumably in pixels).
    pub height: Option<u64>,
    /// The size, if known (presumably in bytes).
    pub size: Option<u64>,
}
2281
/// What a [`ComposerDraft`] is for.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ComposerDraftType {
    /// A brand new message.
    NewMessage,
    /// A reply to another event.
    Reply {
        /// The ID of the event being replied to.
        event_id: OwnedEventId,
    },
    /// An edit of an existing event.
    Edit {
        /// The ID of the event being edited.
        event_id: OwnedEventId,
    },
}
2298
2299impl StateStoreDataValue {
2300 pub fn into_sync_token(self) -> Option<String> {
2302 as_variant!(self, Self::SyncToken)
2303 }
2304
2305 pub fn into_filter(self) -> Option<String> {
2307 as_variant!(self, Self::Filter)
2308 }
2309
2310 pub fn into_user_avatar_url(self) -> Option<OwnedMxcUri> {
2312 as_variant!(self, Self::UserAvatarUrl)
2313 }
2314
2315 pub fn into_recently_visited_rooms(self) -> Option<Vec<OwnedRoomId>> {
2317 as_variant!(self, Self::RecentlyVisitedRooms)
2318 }
2319
2320 pub fn into_utd_hook_manager_data(self) -> Option<GrowableBloom> {
2322 as_variant!(self, Self::UtdHookManagerData)
2323 }
2324
2325 pub fn into_composer_draft(self) -> Option<ComposerDraft> {
2327 as_variant!(self, Self::ComposerDraft)
2328 }
2329
2330 pub fn into_supported_versions(self) -> Option<TtlValue<SupportedVersionsResponse>> {
2332 as_variant!(self, Self::SupportedVersions)
2333 }
2334
2335 pub fn into_well_known(self) -> Option<TtlValue<Option<WellKnownResponse>>> {
2337 as_variant!(self, Self::WellKnown)
2338 }
2339
2340 pub fn into_seen_knock_requests(self) -> Option<BTreeMap<OwnedEventId, OwnedUserId>> {
2342 as_variant!(self, Self::SeenKnockRequests)
2343 }
2344
2345 pub fn into_thread_subscriptions_catchup_tokens(
2348 self,
2349 ) -> Option<Vec<ThreadSubscriptionCatchupToken>> {
2350 as_variant!(self, Self::ThreadSubscriptionsCatchupTokens)
2351 }
2352
2353 pub fn into_homeserver_capabilities(self) -> Option<TtlValue<Capabilities>> {
2356 as_variant!(self, Self::HomeserverCapabilities)
2357 }
2358}
2359
/// A key for the key-value methods of a [`StateStore`] (`get_kv_data` / `set_kv_data` /
/// `remove_kv_data`).
///
/// It only borrows its data, which is why it can be `Copy`. Each variant has a same-named
/// counterpart in [`StateStoreDataValue`].
#[derive(Debug, Clone, Copy)]
pub enum StateStoreDataKey<'a> {
    /// The sync token.
    SyncToken,

    /// The supported-versions response.
    SupportedVersions,

    /// The well-known response.
    WellKnown,

    /// A filter, identified by the given string (presumably a filter name — confirm).
    Filter(&'a str),

    /// The avatar URL of the given user.
    UserAvatarUrl(&'a UserId),

    /// The recently visited rooms of the given user.
    RecentlyVisitedRooms(&'a UserId),

    /// Data for the UTD hook manager.
    UtdHookManagerData,

    /// Payload-less marker; see [`StateStoreDataValue::OneTimeKeyAlreadyUploaded`].
    OneTimeKeyAlreadyUploaded,

    /// The composer draft of a room, optionally scoped by an event ID (presumably a thread
    /// root — confirm).
    ComposerDraft(&'a RoomId, Option<&'a EventId>),

    /// The seen knock requests of the given room.
    SeenKnockRequests(&'a RoomId),

    /// The thread subscription catch-up tokens.
    ThreadSubscriptionsCatchupTokens,

    /// The homeserver capabilities.
    HomeserverCapabilities,
}
2404
2405impl StateStoreDataKey<'_> {
2406 pub const SYNC_TOKEN: &'static str = "sync_token";
2408
2409 pub const SUPPORTED_VERSIONS: &'static str = "server_capabilities"; pub const WELL_KNOWN: &'static str = "well_known";
2416
2417 pub const FILTER: &'static str = "filter";
2419
2420 pub const USER_AVATAR_URL: &'static str = "user_avatar_url";
2423
2424 pub const RECENTLY_VISITED_ROOMS: &'static str = "recently_visited_rooms";
2427
2428 pub const UTD_HOOK_MANAGER_DATA: &'static str = "utd_hook_manager_data";
2431
2432 pub const ONE_TIME_KEY_ALREADY_UPLOADED: &'static str = "one_time_key_already_uploaded";
2435
2436 pub const COMPOSER_DRAFT: &'static str = "composer_draft";
2439
2440 pub const SEEN_KNOCK_REQUESTS: &'static str = "seen_knock_requests";
2443
2444 pub const THREAD_SUBSCRIPTIONS_CATCHUP_TOKENS: &'static str =
2447 "thread_subscriptions_catchup_tokens";
2448
2449 pub const HOMESERVER_CAPABILITIES: &'static str = "homeserver_capabilities";
2451}
2452
/// Decides whether a new thread subscription should be stored, given the bump stamp of the
/// previously stored one (for the same room and thread).
///
/// - No previous bump stamp: always store; `new` is left as is.
/// - Previous bump stamp, but `new` has none: `new` inherits the previous stamp and is stored.
/// - Both present: store only if the new stamp is strictly greater than the previous one.
///
/// Returns `true` if the new subscription should be stored and `false` if it should be
/// ignored. May overwrite `*new` as described above.
pub fn compare_thread_subscription_bump_stamps(
    previous: Option<u64>,
    new: &mut Option<u64>,
) -> bool {
    // Match on a copy so `*new` can be updated freely inside the arms.
    match (previous, *new) {
        (None, _) => true,
        (Some(prev_bump), None) => {
            *new = Some(prev_bump);
            true
        }
        (Some(prev_bump), Some(new_bump)) => new_bump > prev_bump,
    }
}
2483
#[cfg(test)]
mod tests {
    mod save_locked_state_store {
        use std::time::Duration;

        use assert_matches::assert_matches;
        use futures_util::future::{self, Either};
        #[cfg(all(target_family = "wasm", target_os = "unknown"))]
        use gloo_timers::future::sleep;
        use matrix_sdk_common::executor::spawn;
        use matrix_sdk_test::async_test;
        use ruma::room_id;
        use tokio::sync::Mutex;
        #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
        use tokio::time::sleep;

        use crate::{
            StateChanges, StateStore,
            store::{IntoStateStore, MemoryStore, Result, SaveLockedStateStore},
        };

        /// Store factory for the shared suite below (presumably consumed by
        /// `statestore_integration_tests!`): a `SaveLockedStateStore` over an in-memory store.
        async fn get_store() -> Result<impl StateStore> {
            Ok(SaveLockedStateStore::new(MemoryStore::new()))
        }

        statestore_integration_tests!();

        #[async_test]
        async fn test_save_changes_only_accepts_guard_for_underlying_mutex() {
            let state_store = SaveLockedStateStore::new(MemoryStore::new());
            let state_changes = StateChanges::default();
            state_store
                .save_changes_with_guard(&state_store.lock().lock().await, &state_changes)
                .await
                .expect("state store accepts guard for underlying mutex");

            // A guard for an unrelated mutex must be rejected.
            let mutex = Mutex::new(());
            state_store
                .save_changes_with_guard(&mutex.lock().await, &state_changes)
                .await
                .expect_err("state store does not accept guard for unknown mutex");
        }

        #[async_test]
        async fn test_remove_room_only_accepts_guard_for_underlying_mutex() {
            let state_store = SaveLockedStateStore::new(MemoryStore::new());
            let room_id = room_id!("!room");
            state_store
                .remove_room_with_guard(&state_store.lock().lock().await, room_id)
                .await
                .expect("state store accepts guard for underlying mutex");

            // A guard for an unrelated mutex must be rejected.
            let mutex = Mutex::new(());
            state_store
                .remove_room_with_guard(&mutex.lock().await, room_id)
                .await
                .expect_err("state store does not accept guard for unknown mutex");
        }

        /// Returned by [`timeout`] when the future did not complete in time.
        #[derive(Debug)]
        struct Elapsed;

        /// Awaits `f`, failing with [`Elapsed`] if it does not complete within `duration`.
        async fn timeout<F: Future + Unpin>(
            duration: Duration,
            f: F,
        ) -> Result<F::Output, Elapsed> {
            #[cfg(all(target_family = "wasm", target_os = "unknown"))]
            {
                match future::select(sleep(duration), f).await {
                    Either::Left(_) => Err(Elapsed),
                    Either::Right((output, _)) => Ok(output),
                }
            }
            #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
            {
                tokio::time::timeout(duration, f).await.map_err(|_| Elapsed)
            }
        }

        #[async_test]
        async fn test_state_store_waits_to_acquire_lock_before_saving_changes() {
            let state_store = SaveLockedStateStore::new(MemoryStore::new().into_state_store());

            // Holds the store's lock for a while. NOTE(review): this assumes the executor
            // polls `lock_task` before `save_task` (e.g. FIFO scheduling) — confirm.
            let lock_task = spawn({
                let state_store = state_store.clone();
                async move {
                    let lock = state_store.lock();
                    let _guard = lock.lock().await;
                    sleep(Duration::from_secs(5)).await;
                }
            });

            // Must block on the lock, so it cannot finish before `lock_task`.
            let save_task =
                spawn(async move { state_store.save_changes(&StateChanges::default()).await });

            assert_matches!(future::select(lock_task, save_task).await, Either::Left((_, save_task)) => {
                // Once the lock is released, the save should complete promptly.
                timeout(Duration::from_millis(100), save_task)
                    .await
                    .expect("task completes before timeout")
                    .expect("task completes successfully")
                    .expect("task saves changes");
            });
        }

        #[async_test]
        async fn test_state_store_waits_to_acquire_lock_before_removing_room() {
            let state_store = SaveLockedStateStore::new(MemoryStore::new().into_state_store());

            // Holds the store's lock for a while. NOTE(review): this assumes the executor
            // polls `lock_task` before `remove_task` (e.g. FIFO scheduling) — confirm.
            let lock_task = spawn({
                let state_store = state_store.clone();
                async move {
                    let lock = state_store.lock();
                    let _guard = lock.lock().await;
                    sleep(Duration::from_secs(5)).await;
                }
            });

            // Must block on the lock, so it cannot finish before `lock_task`.
            let remove_task =
                spawn(async move { state_store.remove_room(room_id!("!room")).await });

            assert_matches!(future::select(lock_task, remove_task).await, Either::Left((_, remove_task)) => {
                // Once the lock is released, the removal should complete promptly.
                timeout(Duration::from_millis(100), remove_task)
                    .await
                    .expect("task completes before timeout")
                    .expect("task completes successfully")
                    .expect("task removes room");
            });
        }
    }
}