1use std::{
117 collections::{BTreeMap, BTreeSet},
118 pin::Pin,
119 sync::Weak,
120};
121
122use as_variant::as_variant;
123use futures_core::Stream;
124use futures_util::{StreamExt, future::join_all, pin_mut};
125#[cfg(doc)]
126use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
127use matrix_sdk_base::{
128 crypto::{
129 store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
130 types::events::room::encrypted::EncryptedEvent,
131 },
132 deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
133 event_cache::store::EventCacheStoreLockState,
134 locks::Mutex,
135 task_monitor::BackgroundTaskHandle,
136 timer,
137};
138#[cfg(doc)]
139use matrix_sdk_common::deserialized_responses::EncryptionInfo;
140use ruma::{
141 OwnedEventId, OwnedRoomId, RoomId,
142 events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
143 push::Action,
144 serde::Raw,
145};
146use tokio::sync::{
147 broadcast::{self, Sender},
148 mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
149};
150use tokio_stream::wrappers::{
151 BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
152};
153use tracing::{info, instrument, trace, warn};
154
155#[cfg(doc)]
156use super::RoomEventCache;
157use super::{
158 EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheGenericUpdate,
159 RoomEventCacheUpdate, TimelineVectorDiffs,
160 caches::room::{PostProcessingOrigin, RoomEventCacheLinkedChunkUpdate},
161};
162use crate::{Client, Result, Room, encryption::backups::BackupState, room::PushContext};
163
/// A borrowed room key session ID, as used to look up events in the store.
type SessionId<'a> = &'a str;
/// An owned room key session ID.
type OwnedSessionId = String;

/// The ID of an event that could not be decrypted, paired with its raw
/// (still encrypted) content.
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// The ID of an event paired with its decrypted form.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// An event to replace in the caches: its ID, its new decrypted form, and
/// the push actions to set on it (`None` leaves the existing ones untouched).
///
/// Used both for freshly redecrypted UTDs and for decrypted events whose
/// encryption info was refreshed.
pub(in crate::event_cache) type ResolvedUtd =
    (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
171
/// An explicit request, sent to the redecryptor task, to revisit events of a
/// single room.
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
    /// The room whose events should be revisited.
    pub room_id: OwnedRoomId,
    /// Session IDs whose UTDs should be retried: each one is passed to the
    /// redecryption of stored UTDs encrypted with that session.
    pub utd_session_ids: BTreeSet<OwnedSessionId>,
    /// Session IDs whose already-decrypted events should have their
    /// encryption info recomputed.
    pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
184
/// A report broadcast by the redecryptor to its subscribers.
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
    /// Some events of a room were redecrypted, or had their encryption info
    /// updated, and were replaced in the event cache.
    ResolvedUtds {
        /// The room the events belong to.
        room_id: OwnedRoomId,
        /// The IDs of the events that were handed over for replacement.
        events: BTreeSet<OwnedEventId>,
    },
    /// One of the streams the redecryptor listens to lagged or was
    /// regenerated, so some updates may have been missed. All in-memory
    /// events are retried before this is sent.
    Lagging,
    /// The key backup became enabled. All in-memory events are retried before
    /// this is sent.
    BackupAvailable,
}
204
/// The channels connecting the event cache to its redecryptor task.
pub(super) struct RedecryptorChannels {
    /// Broadcast sender for [`RedecryptorReport`]s.
    utd_reporter: Sender<RedecryptorReport>,
    /// Sender side of the explicit [`DecryptionRetryRequest`] queue.
    pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
    /// Receiver side of the request queue, wrapped in `Option` so that it
    /// can be taken out once. Presumably it is handed to the redecryptor task
    /// when that task is created — confirm against the caller.
    pub(super) decryption_request_receiver:
        Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
211
212impl RedecryptorChannels {
213 pub(super) fn new() -> Self {
214 let (utd_reporter, _) = broadcast::channel(100);
215 let (decryption_request_sender, decryption_request_receiver) = unbounded_channel();
216
217 Self {
218 utd_reporter,
219 decryption_request_sender,
220 decryption_request_receiver: Mutex::new(Some(decryption_request_receiver)),
221 }
222 }
223}
224
225fn filter_timeline_event_to_utd(
230 event: TimelineEvent,
231) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
232 let event_id = event.event_id();
233
234 let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
237 event_id.zip(event)
240}
241
242fn filter_timeline_event_to_decrypted(
248 event: TimelineEvent,
249) -> Option<(OwnedEventId, DecryptedRoomEvent)> {
250 let event_id = event.event_id();
251
252 let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
253 event_id.zip(event)
256}
257
258impl EventCache {
259 async fn get_utds(
267 &self,
268 room_id: &RoomId,
269 session_id: SessionId<'_>,
270 ) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
271 let events = match self.inner.store.lock().await? {
272 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
277 guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
278 }
279 };
280
281 Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
282 }
283
284 async fn get_utds_from_memory(&self) -> BTreeMap<OwnedRoomId, Vec<EventIdAndUtd>> {
287 let mut utds = BTreeMap::new();
288
289 for (room_id, room_cache) in self.inner.by_room.read().await.iter() {
290 let room_utds: Vec<_> = room_cache
291 .events()
292 .await
293 .into_iter()
294 .flatten()
295 .filter_map(filter_timeline_event_to_utd)
296 .collect();
297
298 utds.insert(room_id.to_owned(), room_utds);
299 }
300
301 utds
302 }
303
304 async fn get_decrypted_events(
305 &self,
306 room_id: &RoomId,
307 session_id: SessionId<'_>,
308 ) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
309 let events = match self.inner.store.lock().await? {
310 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
315 guard.get_room_events(room_id, None, Some(session_id)).await?
316 }
317 };
318
319 Ok(events.into_iter().filter_map(filter_timeline_event_to_decrypted).collect())
320 }
321
322 async fn get_decrypted_events_from_memory(
323 &self,
324 ) -> BTreeMap<OwnedRoomId, Vec<EventIdAndEvent>> {
325 let mut decrypted_events = BTreeMap::new();
326
327 for (room_id, room_cache) in self.inner.by_room.read().await.iter() {
328 let room_utds: Vec<_> = room_cache
329 .events()
330 .await
331 .into_iter()
332 .flatten()
333 .filter_map(filter_timeline_event_to_decrypted)
334 .collect();
335
336 decrypted_events.insert(room_id.to_owned(), room_utds);
337 }
338
339 decrypted_events
340 }
341
342 #[instrument(skip_all, fields(room_id))]
354 async fn on_resolved_utds(
355 &self,
356 room_id: &RoomId,
357 events: Vec<ResolvedUtd>,
358 ) -> Result<(), EventCacheError> {
359 if events.is_empty() {
360 trace!("No events were redecrypted or updated, nothing to replace");
361 return Ok(());
362 }
363
364 timer!("Resolving UTDs");
365
366 let (room_cache, _drop_handles) = self.for_room(room_id).await?;
369 let mut state = room_cache.state().write().await?;
370
371 let event_ids: BTreeSet<_> =
372 events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
373 let mut new_events = Vec::with_capacity(events.len());
374
375 if let Some(pinned_cache) = state.pinned_event_cache() {
377 pinned_cache.replace_utds(&events).await?;
378 }
379
380 join_all(state.event_focused_caches().map(|cache| cache.replace_utds(&events))).await;
386
387 for (event_id, decrypted, actions) in events {
389 if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
393 target_event.kind = TimelineEventKind::Decrypted(decrypted);
394
395 if let Some(actions) = actions {
396 target_event.set_push_actions(actions);
397 }
398
399 state.replace_event_at(location, target_event.clone()).await?;
402 new_events.push(target_event);
403 }
404 }
405
406 state.post_process_new_events(new_events, PostProcessingOrigin::Redecryption).await?;
407
408 room_cache.update_sender().send(
411 RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs {
412 diffs: state.room_linked_chunk_mut().updates_as_vector_diffs(),
413 origin: EventsOrigin::Cache,
414 }),
415 Some(RoomEventCacheGenericUpdate { room_id: room_id.to_owned() }),
416 );
417
418 let report =
423 RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
424 let _ = self.inner.redecryption_channels.utd_reporter.send(report);
425
426 Ok(())
427 }
428
429 async fn decrypt_event(
431 &self,
432 room_id: &RoomId,
433 room: Option<&Room>,
434 push_context: Option<&PushContext>,
435 event: &Raw<EncryptedEvent>,
436 ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
437 if let Some(room) = room {
438 match room
439 .decrypt_event(
440 event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
441 push_context,
442 )
443 .await
444 {
445 Ok(maybe_decrypted) => {
446 let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
447
448 if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
449 Some((decrypted, actions))
450 } else {
451 warn!(
452 "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
453 );
454 None
455 }
456 }
457 Err(e) => {
458 warn!(
459 "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
460 );
461 None
462 }
463 }
464 } else {
465 let client = self.inner.client().ok()?;
466 let machine = client.olm_machine().await;
467 let machine = machine.as_ref()?;
468
469 match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
470 Ok(decrypted) => Some((decrypted, None)),
471 Err(e) => {
472 warn!(
473 "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
474 );
475 None
476 }
477 }
478 }
479 }
480
481 #[instrument(skip_all, fields(room_id, session_id))]
484 async fn retry_decryption(
485 &self,
486 room_id: &RoomId,
487 session_id: SessionId<'_>,
488 ) -> Result<(), EventCacheError> {
489 let events = self.get_utds(room_id, session_id).await?;
491 self.retry_decryption_for_events(room_id, events).await
492 }
493
494 #[instrument(skip_all, fields(updates.linked_chunk_id))]
496 async fn retry_decryption_for_event_cache_updates(
497 &self,
498 updates: RoomEventCacheLinkedChunkUpdate,
499 ) -> Result<(), EventCacheError> {
500 let room_id = updates.linked_chunk_id.room_id();
501 let events: Vec<_> = updates
502 .updates
503 .into_iter()
504 .flat_map(|updates| updates.into_items())
505 .filter_map(filter_timeline_event_to_utd)
506 .collect();
507
508 self.retry_decryption_for_events(room_id, events).await
509 }
510
511 async fn retry_decryption_for_in_memory_events(&self) {
512 let utds = self.get_utds_from_memory().await;
513
514 for (room_id, utds) in utds.into_iter() {
515 if let Err(e) = self.retry_decryption_for_events(&room_id, utds).await {
516 warn!(%room_id, "Failed to redecrypt in-memory events {e:?}");
517 }
518 }
519 }
520
521 #[instrument(skip_all, fields(room_id, session_id))]
523 async fn retry_decryption_for_events(
524 &self,
525 room_id: &RoomId,
526 events: Vec<EventIdAndUtd>,
527 ) -> Result<(), EventCacheError> {
528 trace!("Retrying to decrypt");
529
530 if events.is_empty() {
531 trace!("No relevant events found.");
532 return Ok(());
533 }
534
535 let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
536 let push_context =
537 if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
538
539 let mut decrypted_events = Vec::with_capacity(events.len());
541
542 for (event_id, event) in events {
543 if let Some((decrypted, actions)) = self
546 .decrypt_event(
547 room_id,
548 room.as_ref(),
549 push_context.as_ref(),
550 event.cast_ref_unchecked(),
551 )
552 .await
553 {
554 decrypted_events.push((event_id, decrypted, actions));
555 }
556 }
557
558 let event_ids: BTreeSet<_> =
559 decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
560
561 if !event_ids.is_empty() {
562 trace!(?event_ids, "Successfully redecrypted events");
563 }
564
565 self.on_resolved_utds(room_id, decrypted_events).await?;
568
569 Ok(())
570 }
571
572 async fn update_encryption_info_for_events(
574 &self,
575 room: &Room,
576 events: Vec<EventIdAndEvent>,
577 ) -> Result<(), EventCacheError> {
578 let mut updated_events = Vec::with_capacity(events.len());
580
581 for (event_id, mut event) in events {
582 if let Some(session_id) = event.encryption_info.session_id() {
583 let new_encryption_info =
584 room.get_encryption_info(session_id, &event.encryption_info.sender).await;
585
586 if let Some(new_encryption_info) = new_encryption_info
588 && event.encryption_info != new_encryption_info
589 {
590 event.encryption_info = new_encryption_info;
591 updated_events.push((event_id, event, None));
592 }
593 }
594 }
595
596 let event_ids: BTreeSet<_> =
597 updated_events.iter().map(|(event_id, _, _)| event_id).collect();
598
599 if !event_ids.is_empty() {
600 trace!(?event_ids, "Replacing the encryption info of some events");
601 }
602
603 self.on_resolved_utds(room.room_id(), updated_events).await
604 }
605
606 #[instrument(skip_all, fields(room_id, session_id))]
607 async fn update_encryption_info(
608 &self,
609 room_id: &RoomId,
610 session_id: SessionId<'_>,
611 ) -> Result<(), EventCacheError> {
612 trace!("Updating encryption info");
613
614 let Ok(client) = self.inner.client() else {
615 return Ok(());
616 };
617
618 let Some(room) = client.get_room(room_id) else {
619 return Ok(());
620 };
621
622 let events = self.get_decrypted_events(room_id, session_id).await?;
624
625 if events.is_empty() {
626 trace!("No relevant events found.");
627 return Ok(());
628 }
629
630 self.update_encryption_info_for_events(&room, events).await
632 }
633
634 async fn retry_update_encryption_info_for_in_memory_events(&self) {
635 let decrypted_events = self.get_decrypted_events_from_memory().await;
636
637 for (room_id, events) in decrypted_events.into_iter() {
638 let Some(room) = self.inner.client().ok().and_then(|c| c.get_room(&room_id)) else {
639 continue;
640 };
641
642 if let Err(e) = self.update_encryption_info_for_events(&room, events).await {
643 warn!(
644 %room_id,
645 "Failed to replace the encryption info for in-memory events {e:?}"
646 );
647 }
648 }
649 }
650
651 async fn retry_in_memory_events(&self) {
662 self.retry_decryption_for_in_memory_events().await;
663 self.retry_update_encryption_info_for_in_memory_events().await;
664 }
665
666 pub fn request_decryption(&self, request: DecryptionRetryRequest) {
707 let _ =
708 self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
709 |_| warn!("Requesting a decryption while the redecryption task has been shut down"),
710 );
711 }
712
713 pub fn subscribe_to_decryption_reports(
764 &self,
765 ) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
766 BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
767 }
768}
769
770#[inline(always)]
771fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
772 cache.upgrade().map(|inner| EventCache { inner })
773}
774
775async fn send_report_and_retry_memory_events(
776 cache: &Weak<EventCacheInner>,
777 report: RedecryptorReport,
778) -> Result<(), ()> {
779 let Some(cache) = upgrade_event_cache(cache) else {
780 return Err(());
781 };
782
783 cache.retry_in_memory_events().await;
784 let _ = cache.inner.redecryption_channels.utd_reporter.send(report);
785
786 Ok(())
787}
788
/// Owner of the background task that retries decrypting UTDs in the event
/// cache.
///
/// [`Redecryptor::new`] spawns the task with `abort_on_drop()`, so dropping
/// this value aborts the task.
pub(crate) struct Redecryptor {
    _task: BackgroundTaskHandle,
}
798
799impl Redecryptor {
800 pub(super) fn new(
805 client: &Client,
806 cache: Weak<EventCacheInner>,
807 receiver: UnboundedReceiver<DecryptionRetryRequest>,
808 linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
809 ) -> Self {
810 let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());
811 let backup_state_stream = client.encryption().backups().state_stream();
812
813 let task = client
814 .task_monitor()
815 .spawn_background_task("event_cache::redecryptor", async {
816 let request_redecryption_stream = UnboundedReceiverStream::new(receiver);
817
818 Self::listen_for_room_keys_task(
819 cache,
820 request_redecryption_stream,
821 linked_chunk_stream,
822 backup_state_stream,
823 )
824 .await;
825 })
826 .abort_on_drop();
827
828 Self { _task: task }
829 }
830
831 async fn subscribe_to_room_key_stream(
836 cache: &Weak<EventCacheInner>,
837 ) -> Option<(
838 impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
839 impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
840 )> {
841 let event_cache = cache.upgrade()?;
842 let client = event_cache.client().ok()?;
843 let machine = client.olm_machine().await;
844
845 machine.as_ref().map(|m| {
846 (m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
847 })
848 }
849
850 async fn redecryption_loop(
851 cache: &Weak<EventCacheInner>,
852 decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
853 events_stream: &mut Pin<
854 &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
855 >,
856 backup_state_stream: &mut Pin<
857 &mut impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
858 >,
859 ) -> bool {
860 let Some((room_key_stream, withheld_stream)) =
861 Self::subscribe_to_room_key_stream(cache).await
862 else {
863 return false;
864 };
865
866 pin_mut!(room_key_stream);
867 pin_mut!(withheld_stream);
868
869 loop {
870 tokio::select! {
871 Some(request) = decryption_request_stream.next() => {
874 let Some(cache) = upgrade_event_cache(cache) else {
875 break false;
876 };
877
878 trace!(?request, "Received a redecryption request");
879
880 for session_id in request.utd_session_ids {
881 let _ = cache
882 .retry_decryption(&request.room_id, &session_id)
883 .await
884 .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
885 }
886
887 for session_id in request.refresh_info_session_ids {
888 let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
889 warn!(
890 room_id = %request.room_id,
891 session_id = session_id,
892 "Unable to update the encryption info {e:?}",
893 ));
894 }
895 }
896 room_keys = room_key_stream.next() => {
899 match room_keys {
900 Some(Ok(room_keys)) => {
901 let Some(cache) = upgrade_event_cache(cache) else {
905 break false;
906 };
907
908 trace!(?room_keys, "Received new room keys");
909
910 for key in &room_keys {
911 let _ = cache
912 .retry_decryption(&key.room_id, &key.session_id)
913 .await
914 .inspect_err(|e| warn!("Error redecrypting {e:?}"));
915 }
916
917 for key in room_keys {
918 let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
919 warn!(
920 room_id = %key.room_id,
921 session_id = key.session_id,
922 "Unable to update the encryption info {e:?}",
923 ));
924 }
925 },
926 Some(Err(_)) => {
927 warn!("The room key stream lagged, reporting the lag to our listeners");
934
935 if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
936 break false;
937 }
938 },
939 None => {
942 break true;
943 }
944 }
945 }
946 withheld_info = withheld_stream.next() => {
947 match withheld_info {
948 Some(infos) => {
949 let Some(cache) = upgrade_event_cache(cache) else {
950 break false;
951 };
952
953 trace!(?infos, "Received new withheld infos");
954
955 for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
956 let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
957 warn!(
958 room_id = %room_id,
959 session_id = session_id,
960 "Unable to update the encryption info {e:?}",
961 ));
962 }
963 }
964 None => break true,
967 }
968 }
969 Some(event_updates) = events_stream.next() => {
973 match event_updates {
974 Ok(updates) => {
975 let Some(cache) = upgrade_event_cache(cache) else {
976 break false;
977 };
978
979 let linked_chunk_id = updates.linked_chunk_id.to_owned();
980
981 let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
982 warn!(
983 %linked_chunk_id,
984 "Unable to handle UTDs from event cache updates {e:?}",
985 )
986 );
987 }
988 Err(_) => {
989 if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
990 break false;
991 }
992 }
993 }
994 }
995 Some(backup_state_update) = backup_state_stream.next() => {
996 match backup_state_update {
997 Ok(state) => {
998 match state {
999 BackupState::Unknown |
1000 BackupState::Creating |
1001 BackupState::Enabling |
1002 BackupState::Resuming |
1003 BackupState::Downloading |
1004 BackupState::Disabling =>{
1005 }
1008 BackupState::Enabled => {
1009 if send_report_and_retry_memory_events(cache, RedecryptorReport::BackupAvailable).await.is_err() {
1014 break false;
1015 }
1016 }
1017 }
1018 }
1019 Err(_) => {
1020 if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
1021 break false;
1022 }
1023 }
1024 }
1025 }
1026 else => break false,
1027 }
1028 }
1029 }
1030
1031 async fn listen_for_room_keys_task(
1032 cache: Weak<EventCacheInner>,
1033 decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
1034 events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
1035 backup_state_stream: impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
1036 ) {
1037 pin_mut!(decryption_request_stream);
1041 pin_mut!(events_stream);
1042 pin_mut!(backup_state_stream);
1043
1044 while Self::redecryption_loop(
1045 &cache,
1046 &mut decryption_request_stream,
1047 &mut events_stream,
1048 &mut backup_state_stream,
1049 )
1050 .await
1051 {
1052 info!("Regenerating the re-decryption streams");
1053
1054 if send_report_and_retry_memory_events(&cache, RedecryptorReport::Lagging)
1057 .await
1058 .is_err()
1059 {
1060 break;
1061 }
1062 }
1063
1064 info!("Shutting down the event cache redecryptor");
1065 }
1066}
1067
1068#[cfg(not(target_family = "wasm"))]
1069#[cfg(test)]
1070mod tests {
1071 use std::{
1072 collections::BTreeSet,
1073 sync::{
1074 Arc,
1075 atomic::{AtomicBool, Ordering},
1076 },
1077 time::Duration,
1078 };
1079
1080 use assert_matches2::assert_matches;
1081 use async_trait::async_trait;
1082 use eyeball_im::VectorDiff;
1083 use matrix_sdk_base::{
1084 cross_process_lock::CrossProcessLockGeneration,
1085 crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
1086 deserialized_responses::{TimelineEventKind, VerificationState},
1087 event_cache::{
1088 Event, Gap,
1089 store::{EventCacheStore, EventCacheStoreError, MemoryStore},
1090 },
1091 linked_chunk::{
1092 ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
1093 RawChunk, Update,
1094 },
1095 locks::Mutex,
1096 sleep::sleep,
1097 store::StoreConfig,
1098 };
1099 use matrix_sdk_common::cross_process_lock::CrossProcessLockConfig;
1100 use matrix_sdk_test::{JoinedRoomBuilder, async_test, event_factory::EventFactory};
1101 use ruma::{
1102 EventId, OwnedEventId, RoomId, RoomVersionId, device_id, event_id,
1103 events::{AnySyncTimelineEvent, relation::RelationType},
1104 room_id,
1105 serde::Raw,
1106 user_id,
1107 };
1108 use serde_json::json;
1109 use tokio::sync::oneshot::{self, Sender};
1110 use tracing::{Instrument, info};
1111
1112 use crate::{
1113 Client, assert_let_timeout,
1114 encryption::EncryptionSettings,
1115 event_cache::{
1116 DecryptionRetryRequest, RoomEventCacheGenericUpdate, RoomEventCacheUpdate,
1117 TimelineVectorDiffs,
1118 },
1119 test_utils::mocks::MatrixMockServer,
1120 };
1121
1122 #[derive(Debug, Clone)]
1127 struct DelayingStore {
1128 memory_store: MemoryStore,
1129 delaying: Arc<AtomicBool>,
1130 foo: Arc<Mutex<Option<Sender<()>>>>,
1131 }
1132
1133 impl DelayingStore {
1134 fn new() -> Self {
1135 Self {
1136 memory_store: MemoryStore::new(),
1137 delaying: AtomicBool::new(true).into(),
1138 foo: Arc::new(Mutex::new(None)),
1139 }
1140 }
1141
1142 async fn stop_delaying(&self) {
1143 let (sender, receiver) = oneshot::channel();
1144
1145 {
1146 *self.foo.lock() = Some(sender);
1147 }
1148
1149 self.delaying.store(false, Ordering::SeqCst);
1150
1151 receiver.await.expect("We should be able to receive a response")
1152 }
1153 }
1154
1155 #[cfg_attr(target_family = "wasm", async_trait(?Send))]
1156 #[cfg_attr(not(target_family = "wasm"), async_trait)]
1157 impl EventCacheStore for DelayingStore {
1158 type Error = EventCacheStoreError;
1159
1160 async fn try_take_leased_lock(
1161 &self,
1162 lease_duration_ms: u32,
1163 key: &str,
1164 holder: &str,
1165 ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
1166 self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
1167 }
1168
1169 async fn handle_linked_chunk_updates(
1170 &self,
1171 linked_chunk_id: LinkedChunkId<'_>,
1172 updates: Vec<Update<Event, Gap>>,
1173 ) -> Result<(), Self::Error> {
1174 while self.delaying.load(Ordering::SeqCst) {
1180 sleep(Duration::from_millis(10)).await;
1181 }
1182
1183 let sender = self.foo.lock().take();
1184 let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;
1185
1186 if let Some(sender) = sender {
1187 sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
1188 }
1189
1190 ret
1191 }
1192
1193 async fn load_all_chunks(
1194 &self,
1195 linked_chunk_id: LinkedChunkId<'_>,
1196 ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
1197 self.memory_store.load_all_chunks(linked_chunk_id).await
1198 }
1199
1200 async fn load_all_chunks_metadata(
1201 &self,
1202 linked_chunk_id: LinkedChunkId<'_>,
1203 ) -> Result<Vec<ChunkMetadata>, Self::Error> {
1204 self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
1205 }
1206
1207 async fn load_last_chunk(
1208 &self,
1209 linked_chunk_id: LinkedChunkId<'_>,
1210 ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
1211 self.memory_store.load_last_chunk(linked_chunk_id).await
1212 }
1213
1214 async fn load_previous_chunk(
1215 &self,
1216 linked_chunk_id: LinkedChunkId<'_>,
1217 before_chunk_identifier: ChunkIdentifier,
1218 ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
1219 self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
1220 }
1221
1222 async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
1223 self.memory_store.clear_all_linked_chunks().await
1224 }
1225
1226 async fn filter_duplicated_events(
1227 &self,
1228 linked_chunk_id: LinkedChunkId<'_>,
1229 events: Vec<OwnedEventId>,
1230 ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
1231 self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
1232 }
1233
1234 async fn find_event(
1235 &self,
1236 room_id: &RoomId,
1237 event_id: &EventId,
1238 ) -> Result<Option<Event>, Self::Error> {
1239 self.memory_store.find_event(room_id, event_id).await
1240 }
1241
1242 async fn find_event_relations(
1243 &self,
1244 room_id: &RoomId,
1245 event_id: &EventId,
1246 filters: Option<&[RelationType]>,
1247 ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
1248 self.memory_store.find_event_relations(room_id, event_id, filters).await
1249 }
1250
1251 async fn get_room_events(
1252 &self,
1253 room_id: &RoomId,
1254 event_type: Option<&str>,
1255 session_id: Option<&str>,
1256 ) -> Result<Vec<Event>, Self::Error> {
1257 self.memory_store.get_room_events(room_id, event_type, session_id).await
1258 }
1259
1260 async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
1261 self.memory_store.save_event(room_id, event).await
1262 }
1263
1264 async fn optimize(&self) -> Result<(), Self::Error> {
1265 self.memory_store.optimize().await
1266 }
1267
1268 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1269 self.memory_store.get_size().await
1270 }
1271 }
1272
1273 async fn set_up_clients(
1274 room_id: &RoomId,
1275 alice_enables_cross_signing: bool,
1276 use_delayed_store: bool,
1277 ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
1278 let alice_span = tracing::info_span!("alice");
1279 let bob_span = tracing::info_span!("bob");
1280
1281 let alice_user_id = user_id!("@alice:localhost");
1282 let alice_device_id = device_id!("ALICEDEVICE");
1283 let bob_user_id = user_id!("@bob:localhost");
1284 let bob_device_id = device_id!("BOBDEVICE");
1285
1286 let matrix_mock_server = MatrixMockServer::new().await;
1287 matrix_mock_server.mock_crypto_endpoints_preset().await;
1288
1289 let encryption_settings = EncryptionSettings {
1290 auto_enable_cross_signing: alice_enables_cross_signing,
1291 ..Default::default()
1292 };
1293
1294 let alice = matrix_mock_server
1297 .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
1298 .on_builder(|builder| {
1299 builder
1300 .with_enable_share_history_on_invite(true)
1301 .with_encryption_settings(encryption_settings)
1302 })
1303 .build()
1304 .instrument(alice_span.clone())
1305 .await;
1306
1307 let encryption_settings =
1308 EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };
1309
1310 let (store_config, store) = if use_delayed_store {
1311 let store = DelayingStore::new();
1312
1313 (
1314 StoreConfig::new(CrossProcessLockConfig::multi_process(
1315 "delayed_store_event_cache_test",
1316 ))
1317 .event_cache_store(store.clone()),
1318 Some(store),
1319 )
1320 } else {
1321 (
1322 StoreConfig::new(CrossProcessLockConfig::multi_process(
1323 "normal_store_event_cache_test",
1324 )),
1325 None,
1326 )
1327 };
1328
1329 let bob = matrix_mock_server
1330 .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
1331 .on_builder(|builder| {
1332 builder
1333 .with_enable_share_history_on_invite(true)
1334 .with_encryption_settings(encryption_settings)
1335 .store_config(store_config)
1336 })
1337 .build()
1338 .instrument(bob_span.clone())
1339 .await;
1340
1341 bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");
1342
1343 matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;
1345
1346 let event_factory = EventFactory::new().room(room_id).sender(alice_user_id);
1347
1348 let room_builder = JoinedRoomBuilder::new(room_id)
1350 .add_state_event(event_factory.create(alice_user_id, RoomVersionId::V1))
1351 .add_state_event(event_factory.room_encryption());
1352
1353 matrix_mock_server
1354 .mock_sync()
1355 .ok_and_run(&alice, |builder| {
1356 builder.add_joined_room(room_builder.clone());
1357 })
1358 .instrument(alice_span)
1359 .await;
1360
1361 matrix_mock_server
1362 .mock_sync()
1363 .ok_and_run(&bob, |builder| {
1364 builder.add_joined_room(room_builder);
1365 })
1366 .instrument(bob_span)
1367 .await;
1368
1369 (alice, bob, matrix_mock_server, store)
1370 }
1371
1372 async fn prepare_room(
1373 matrix_mock_server: &MatrixMockServer,
1374 event_factory: &EventFactory,
1375 alice: &Client,
1376 bob: &Client,
1377 room_id: &RoomId,
1378 ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
1379 let alice_user_id = alice.user_id().unwrap();
1380 let bob_user_id = bob.user_id().unwrap();
1381
1382 let alice_member_event = event_factory.member(alice_user_id).into_raw();
1383 let bob_member_event = event_factory.member(bob_user_id).into_raw();
1384
1385 let room = alice
1386 .get_room(room_id)
1387 .expect("Alice should have access to the room now that we synced");
1388
1389 let event_type = "m.room.message";
1394 let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});
1395
1396 let event_id = event_id!("$some_id");
1397 let (event_receiver, mock) =
1398 matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
1399 let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;
1400
1401 {
1402 let _guard = mock.mock_once().mount_as_scoped().await;
1403
1404 matrix_mock_server
1405 .mock_get_members()
1406 .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
1407 .mock_once()
1408 .mount()
1409 .await;
1410
1411 room.send_raw(event_type, content)
1412 .await
1413 .expect("We should be able to send an initial message");
1414 };
1415
1416 let event = event_receiver.await.expect("Alice should have sent the event by now");
1418 let room_key = room_key.await;
1419
1420 (event, room_key)
1421 }
1422
1423 #[async_test]
1424 async fn test_redecryptor() {
1425 let room_id = room_id!("!test:localhost");
1426
1427 let event_factory = EventFactory::new().room(room_id);
1428 let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;
1429
1430 let (event, room_key) =
1431 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1432
1433 let event_cache = bob.event_cache();
1436 let (room_cache, _) = event_cache
1437 .for_room(room_id)
1438 .await
1439 .expect("We should be able to get to the event cache for a specific room");
1440
1441 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1442 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1443
1444 bob.inner
1447 .base_client
1448 .regenerate_olm(None)
1449 .await
1450 .expect("We should be able to regenerate the Olm machine");
1451
1452 matrix_mock_server
1454 .mock_sync()
1455 .ok_and_run(&bob, |builder| {
1456 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1457 })
1458 .await;
1459
1460 assert_let_timeout!(
1463 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1464 subscriber.recv()
1465 );
1466
1467 assert_eq!(diffs.len(), 1);
1470 assert_matches!(&diffs[0], VectorDiff::Append { values });
1471 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1472
1473 assert_let_timeout!(
1474 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1475 );
1476 assert_eq!(expected_room_id, room_id);
1477 assert!(generic_stream.is_empty());
1478
1479 matrix_mock_server
1481 .mock_sync()
1482 .ok_and_run(&bob, |builder| {
1483 builder.add_to_device_event(
1484 room_key
1485 .deserialize_as()
1486 .expect("We should be able to deserialize the room key"),
1487 );
1488 })
1489 .await;
1490
1491 assert_let_timeout!(
1493 Duration::from_secs(1),
1494 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1495 subscriber.recv()
1496 );
1497
1498 assert_eq!(diffs.len(), 1);
1500 assert_matches!(&diffs[0], VectorDiff::Set { index, value });
1501 assert_eq!(*index, 0);
1502 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1503
1504 assert_let_timeout!(
1505 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1506 );
1507 assert_eq!(expected_room_id, room_id);
1508 assert!(generic_stream.is_empty());
1509 }
1510
1511 #[async_test]
1512 async fn test_redecryptor_updating_encryption_info() {
1513 let bob_span = tracing::info_span!("bob");
1514
1515 let room_id = room_id!("!test:localhost");
1516
1517 let event_factory = EventFactory::new().room(room_id);
1518 let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;
1519
1520 let (event, room_key) =
1521 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1522
1523 let event_cache = bob.event_cache();
1526 let (room_cache, _) = event_cache
1527 .for_room(room_id)
1528 .instrument(bob_span.clone())
1529 .await
1530 .expect("We should be able to get to the event cache for a specific room");
1531
1532 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1533 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1534
1535 matrix_mock_server
1537 .mock_sync()
1538 .ok_and_run(&bob, |builder| {
1539 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1540 })
1541 .instrument(bob_span.clone())
1542 .await;
1543
1544 assert_let_timeout!(
1547 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1548 subscriber.recv()
1549 );
1550
1551 assert_eq!(diffs.len(), 1);
1554 assert_matches!(&diffs[0], VectorDiff::Append { values });
1555 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1556
1557 assert_let_timeout!(
1558 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1559 );
1560 assert_eq!(expected_room_id, room_id);
1561 assert!(generic_stream.is_empty());
1562
1563 matrix_mock_server
1565 .mock_sync()
1566 .ok_and_run(&bob, |builder| {
1567 builder.add_to_device_event(
1568 room_key
1569 .deserialize_as()
1570 .expect("We should be able to deserialize the room key"),
1571 );
1572 })
1573 .instrument(bob_span.clone())
1574 .await;
1575
1576 assert_let_timeout!(
1578 Duration::from_secs(1),
1579 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1580 subscriber.recv()
1581 );
1582
1583 assert_eq!(diffs.len(), 1);
1585 assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
1586 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1587
1588 let encryption_info = value.encryption_info().unwrap();
1589 assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));
1590
1591 assert_let_timeout!(
1592 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1593 );
1594 assert_eq!(expected_room_id, room_id);
1595 assert!(generic_stream.is_empty());
1596
1597 let session_id = encryption_info.session_id().unwrap().to_owned();
1598 let alice_user_id = alice.user_id().unwrap();
1599
1600 alice
1602 .encryption()
1603 .bootstrap_cross_signing(None)
1604 .await
1605 .expect("Alice should be able to create the cross-signing keys");
1606
1607 bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
1608 matrix_mock_server
1609 .mock_sync()
1610 .ok_and_run(&bob, |builder| {
1611 builder.add_change_device(alice_user_id);
1612 })
1613 .instrument(bob_span.clone())
1614 .await;
1615
1616 bob.event_cache().request_decryption(DecryptionRetryRequest {
1617 room_id: room_id.into(),
1618 utd_session_ids: BTreeSet::new(),
1619 refresh_info_session_ids: BTreeSet::from([session_id]),
1620 });
1621
1622 assert_let_timeout!(
1625 Duration::from_secs(1),
1626 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1627 subscriber.recv()
1628 );
1629
1630 assert_eq!(diffs.len(), 1);
1631 assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
1632 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1633 let encryption_info = value.encryption_info().unwrap();
1634
1635 assert_matches!(
1636 &encryption_info.verification_state,
1637 VerificationState::Unverified(_),
1638 "The event should now know about the identity but still be unverified"
1639 );
1640
1641 assert_let_timeout!(
1642 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1643 );
1644 assert_eq!(expected_room_id, room_id);
1645 assert!(generic_stream.is_empty());
1646 }
1647
1648 #[async_test]
1649 async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
1650 let room_id = room_id!("!test:localhost");
1651
1652 let event_factory = EventFactory::new().room(room_id);
1653 let (alice, bob, matrix_mock_server, delayed_store) =
1654 set_up_clients(room_id, true, true).await;
1655
1656 let delayed_store = delayed_store.unwrap();
1657
1658 let (event, room_key) =
1659 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1660
1661 let event_cache = bob.event_cache();
1662
1663 let (room_cache, _) = event_cache
1665 .for_room(room_id)
1666 .await
1667 .expect("We should be able to get to the event cache for a specific room");
1668
1669 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1670 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1671
1672 matrix_mock_server
1674 .mock_sync()
1675 .ok_and_run(&bob, |builder| {
1676 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1677 })
1678 .await;
1679
1680 matrix_mock_server
1682 .mock_sync()
1683 .ok_and_run(&bob, |builder| {
1684 builder.add_to_device_event(
1685 room_key
1686 .deserialize_as()
1687 .expect("We should be able to deserialize the room key"),
1688 );
1689 })
1690 .await;
1691
1692 info!("Stopping the delay");
1693 delayed_store.stop_delaying().await;
1694
1695 assert_let_timeout!(
1702 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1703 subscriber.recv()
1704 );
1705
1706 assert_eq!(diffs.len(), 1);
1709 assert_matches!(&diffs[0], VectorDiff::Append { values });
1710 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1711
1712 assert_let_timeout!(
1713 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1714 );
1715 assert_eq!(expected_room_id, room_id);
1716
1717 assert_let_timeout!(
1719 Duration::from_secs(1),
1720 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1721 subscriber.recv()
1722 );
1723
1724 assert_eq!(diffs.len(), 1);
1726 assert_matches!(&diffs[0], VectorDiff::Set { index, value });
1727 assert_eq!(*index, 0);
1728 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1729
1730 assert_let_timeout!(
1731 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1732 );
1733 assert_eq!(expected_room_id, room_id);
1734 assert!(generic_stream.is_empty());
1735 }
1736}