1use std::{
117 collections::{BTreeMap, BTreeSet},
118 pin::Pin,
119 sync::Weak,
120};
121
122use as_variant::as_variant;
123use futures_core::Stream;
124use futures_util::{StreamExt, future::join_all, pin_mut};
125#[cfg(doc)]
126use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
127use matrix_sdk_base::{
128 crypto::{
129 store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
130 types::events::room::encrypted::EncryptedEvent,
131 },
132 deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
133 event_cache::store::EventCacheStoreLockState,
134 locks::Mutex,
135 task_monitor::BackgroundTaskHandle,
136 timer,
137};
138#[cfg(doc)]
139use matrix_sdk_common::deserialized_responses::EncryptionInfo;
140use ruma::{
141 OwnedEventId, OwnedRoomId, RoomId,
142 events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
143 push::Action,
144 serde::Raw,
145};
146use tokio::sync::{
147 broadcast::{self, Sender},
148 mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
149};
150use tokio_stream::wrappers::{
151 BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
152};
153use tracing::{info, instrument, trace, warn};
154
155#[cfg(doc)]
156use super::RoomEventCache;
157use super::{
158 EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheGenericUpdate,
159 RoomEventCacheUpdate, TimelineVectorDiffs,
160 caches::room::{PostProcessingOrigin, RoomEventCacheLinkedChunkUpdate},
161};
162use crate::{Client, Result, Room, encryption::backups::BackupState, room::PushContext};
163
/// A borrowed session ID, as passed to the lookups in this module.
type SessionId<'a> = &'a str;
/// The owned counterpart of [`SessionId`].
type OwnedSessionId = String;

/// An event ID paired with the raw event carried by a
/// [`TimelineEventKind::UnableToDecrypt`] timeline event.
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// An event ID paired with the payload of a [`TimelineEventKind::Decrypted`]
/// timeline event.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// An event ID, the decrypted event that should replace the cached one, and
/// optional push actions; the actions are only applied to the cached event
/// when they are `Some`.
pub(in crate::event_cache) type ResolvedUtd =
    (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
171
/// A request, submitted through [`EventCache::request_decryption`], asking
/// the redecryptor task to revisit events of a room for a set of sessions.
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
    /// The room the sessions below belong to.
    pub room_id: OwnedRoomId,
    /// Session IDs for which decryption of the stored unable-to-decrypt
    /// events should be retried.
    pub utd_session_ids: BTreeSet<OwnedSessionId>,
    /// Session IDs for which the encryption info of already decrypted events
    /// should be recomputed.
    pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
184
/// A report broadcast by the redecryptor to the subscribers obtained from
/// [`EventCache::subscribe_to_decryption_reports`].
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
    /// Some cached events were replaced, either because they were
    /// successfully redecrypted or because their encryption info changed.
    ResolvedUtds {
        /// The room the replaced events belong to.
        room_id: OwnedRoomId,
        /// The IDs of the events that were handed over for replacement.
        events: BTreeSet<OwnedEventId>,
    },
    /// One of the streams the redecryptor listens to reported an error, or
    /// the room key streams had to be regenerated; the in-memory events were
    /// retried before this report was sent.
    Lagging,
    /// The backup state switched to [`BackupState::Enabled`]; the in-memory
    /// events were retried before this report was sent.
    BackupAvailable,
}
204
/// The channels used to talk to, and hear back from, the redecryptor task.
pub(super) struct RedecryptorChannels {
    /// Broadcast sender used to publish [`RedecryptorReport`]s.
    utd_reporter: Sender<RedecryptorReport>,
    /// Sending half used to submit [`DecryptionRetryRequest`]s.
    pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
    /// Receiving half of the request channel, wrapped in an `Option` so that
    /// it can be taken out of this struct (presumably by whoever spawns the
    /// [`Redecryptor`] — confirm against the caller).
    pub(super) decryption_request_receiver:
        Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
211
212impl RedecryptorChannels {
213 pub(super) fn new() -> Self {
214 let (utd_reporter, _) = broadcast::channel(100);
215 let (decryption_request_sender, decryption_request_receiver) = unbounded_channel();
216
217 Self {
218 utd_reporter,
219 decryption_request_sender,
220 decryption_request_receiver: Mutex::new(Some(decryption_request_receiver)),
221 }
222 }
223}
224
225fn filter_timeline_event_to_utd(
230 event: TimelineEvent,
231) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
232 let event_id = event.event_id();
233
234 let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
237 event_id.zip(event)
240}
241
242fn filter_timeline_event_to_decrypted(
248 event: TimelineEvent,
249) -> Option<(OwnedEventId, DecryptedRoomEvent)> {
250 let event_id = event.event_id();
251
252 let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
253 event_id.zip(event)
256}
257
258impl EventCache {
259 async fn get_utds(
267 &self,
268 room_id: &RoomId,
269 session_id: SessionId<'_>,
270 ) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
271 let events = match self.inner.store.lock().await? {
272 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
277 guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
278 }
279 };
280
281 Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
282 }
283
284 async fn get_utds_from_memory(&self) -> BTreeMap<OwnedRoomId, Vec<EventIdAndUtd>> {
287 let mut utds = BTreeMap::new();
288
289 for (room_id, caches) in self.inner.by_room.read().await.iter() {
290 let room_utds: Vec<_> = caches
291 .all_events()
292 .await
293 .into_iter()
294 .flatten()
295 .filter_map(filter_timeline_event_to_utd)
296 .collect();
297
298 utds.insert(room_id.to_owned(), room_utds);
299 }
300
301 utds
302 }
303
304 async fn get_decrypted_events(
305 &self,
306 room_id: &RoomId,
307 session_id: SessionId<'_>,
308 ) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
309 let events = match self.inner.store.lock().await? {
310 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
315 guard.get_room_events(room_id, None, Some(session_id)).await?
316 }
317 };
318
319 Ok(events.into_iter().filter_map(filter_timeline_event_to_decrypted).collect())
320 }
321
322 async fn get_decrypted_events_from_memory(
323 &self,
324 ) -> BTreeMap<OwnedRoomId, Vec<EventIdAndEvent>> {
325 let mut decrypted_events = BTreeMap::new();
326
327 for (room_id, caches) in self.inner.by_room.read().await.iter() {
328 let room_utds: Vec<_> = caches
329 .all_events()
330 .await
331 .into_iter()
332 .flatten()
333 .filter_map(filter_timeline_event_to_decrypted)
334 .collect();
335
336 decrypted_events.insert(room_id.to_owned(), room_utds);
337 }
338
339 decrypted_events
340 }
341
342 #[instrument(skip_all, fields(room_id))]
354 async fn on_resolved_utds(
355 &self,
356 room_id: &RoomId,
357 events: Vec<ResolvedUtd>,
358 ) -> Result<(), EventCacheError> {
359 if events.is_empty() {
360 trace!("No events were redecrypted or updated, nothing to replace");
361 return Ok(());
362 }
363
364 timer!("Resolving UTDs");
365
366 let (room_cache, _drop_handles) = self.for_room(room_id).await?;
369 let mut state = room_cache.state().write().await?;
370
371 let event_ids: BTreeSet<_> =
372 events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
373 let mut new_events = Vec::with_capacity(events.len());
374
375 if let Some(pinned_cache) = state.pinned_event_cache() {
377 pinned_cache.replace_utds(&events).await?;
378 }
379
380 join_all(state.event_focused_caches().map(|cache| cache.replace_utds(&events))).await;
386
387 for (event_id, decrypted, actions) in events {
389 if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
393 target_event.kind = TimelineEventKind::Decrypted(decrypted);
394
395 if let Some(actions) = actions {
396 target_event.set_push_actions(actions);
397 }
398
399 state.replace_event_at(location, target_event.clone()).await?;
402 new_events.push(target_event);
403 }
404 }
405
406 let receipt_event = None;
413
414 state
415 .post_process_new_events(new_events, PostProcessingOrigin::Redecryption, receipt_event)
416 .await?;
417
418 room_cache.update_sender().send(
421 RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs {
422 diffs: state.room_linked_chunk_mut().updates_as_vector_diffs(),
423 origin: EventsOrigin::Cache,
424 }),
425 Some(RoomEventCacheGenericUpdate { room_id: room_id.to_owned() }),
426 );
427
428 let report =
433 RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
434 let _ = self.inner.redecryption_channels.utd_reporter.send(report);
435
436 Ok(())
437 }
438
439 async fn decrypt_event(
441 &self,
442 room_id: &RoomId,
443 room: Option<&Room>,
444 push_context: Option<&PushContext>,
445 event: &Raw<EncryptedEvent>,
446 ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
447 if let Some(room) = room {
448 match room
449 .decrypt_event(
450 event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
451 push_context,
452 )
453 .await
454 {
455 Ok(maybe_decrypted) => {
456 let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
457
458 if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
459 Some((decrypted, actions))
460 } else {
461 warn!(
462 "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
463 );
464 None
465 }
466 }
467 Err(e) => {
468 warn!(
469 "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
470 );
471 None
472 }
473 }
474 } else {
475 let client = self.inner.client().ok()?;
476 let machine = client.olm_machine().await;
477 let machine = machine.as_ref()?;
478
479 match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
480 Ok(decrypted) => Some((decrypted, None)),
481 Err(e) => {
482 warn!(
483 "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
484 );
485 None
486 }
487 }
488 }
489 }
490
491 #[instrument(skip_all, fields(room_id, session_id))]
494 async fn retry_decryption(
495 &self,
496 room_id: &RoomId,
497 session_id: SessionId<'_>,
498 ) -> Result<(), EventCacheError> {
499 let events = self.get_utds(room_id, session_id).await?;
501 self.retry_decryption_for_events(room_id, events).await
502 }
503
504 #[instrument(skip_all, fields(updates.linked_chunk_id))]
506 async fn retry_decryption_for_event_cache_updates(
507 &self,
508 updates: RoomEventCacheLinkedChunkUpdate,
509 ) -> Result<(), EventCacheError> {
510 let room_id = updates.linked_chunk_id.room_id();
511 let events: Vec<_> = updates
512 .updates
513 .into_iter()
514 .flat_map(|updates| updates.into_items())
515 .filter_map(filter_timeline_event_to_utd)
516 .collect();
517
518 self.retry_decryption_for_events(room_id, events).await
519 }
520
521 async fn retry_decryption_for_in_memory_events(&self) {
522 let utds = self.get_utds_from_memory().await;
523
524 for (room_id, utds) in utds.into_iter() {
525 if let Err(e) = self.retry_decryption_for_events(&room_id, utds).await {
526 warn!(%room_id, "Failed to redecrypt in-memory events {e:?}");
527 }
528 }
529 }
530
531 #[instrument(skip_all, fields(room_id, session_id))]
533 async fn retry_decryption_for_events(
534 &self,
535 room_id: &RoomId,
536 events: Vec<EventIdAndUtd>,
537 ) -> Result<(), EventCacheError> {
538 trace!("Retrying to decrypt");
539
540 if events.is_empty() {
541 trace!("No relevant events found.");
542 return Ok(());
543 }
544
545 let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
546 let push_context =
547 if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
548
549 let mut decrypted_events = Vec::with_capacity(events.len());
551
552 for (event_id, event) in events {
553 if let Some((decrypted, actions)) = self
556 .decrypt_event(
557 room_id,
558 room.as_ref(),
559 push_context.as_ref(),
560 event.cast_ref_unchecked(),
561 )
562 .await
563 {
564 decrypted_events.push((event_id, decrypted, actions));
565 }
566 }
567
568 let event_ids: BTreeSet<_> =
569 decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
570
571 if !event_ids.is_empty() {
572 trace!(?event_ids, "Successfully redecrypted events");
573 }
574
575 self.on_resolved_utds(room_id, decrypted_events).await?;
578
579 Ok(())
580 }
581
582 async fn update_encryption_info_for_events(
584 &self,
585 room: &Room,
586 events: Vec<EventIdAndEvent>,
587 ) -> Result<(), EventCacheError> {
588 let mut updated_events = Vec::with_capacity(events.len());
590
591 for (event_id, mut event) in events {
592 if let Some(session_id) = event.encryption_info.session_id() {
593 let new_encryption_info =
594 room.get_encryption_info(session_id, &event.encryption_info.sender).await;
595
596 if let Some(new_encryption_info) = new_encryption_info
598 && event.encryption_info != new_encryption_info
599 {
600 event.encryption_info = new_encryption_info;
601 updated_events.push((event_id, event, None));
602 }
603 }
604 }
605
606 let event_ids: BTreeSet<_> =
607 updated_events.iter().map(|(event_id, _, _)| event_id).collect();
608
609 if !event_ids.is_empty() {
610 trace!(?event_ids, "Replacing the encryption info of some events");
611 }
612
613 self.on_resolved_utds(room.room_id(), updated_events).await
614 }
615
616 #[instrument(skip_all, fields(room_id, session_id))]
617 async fn update_encryption_info(
618 &self,
619 room_id: &RoomId,
620 session_id: SessionId<'_>,
621 ) -> Result<(), EventCacheError> {
622 trace!("Updating encryption info");
623
624 let Ok(client) = self.inner.client() else {
625 return Ok(());
626 };
627
628 let Some(room) = client.get_room(room_id) else {
629 return Ok(());
630 };
631
632 let events = self.get_decrypted_events(room_id, session_id).await?;
634
635 if events.is_empty() {
636 trace!("No relevant events found.");
637 return Ok(());
638 }
639
640 self.update_encryption_info_for_events(&room, events).await
642 }
643
644 async fn retry_update_encryption_info_for_in_memory_events(&self) {
645 let decrypted_events = self.get_decrypted_events_from_memory().await;
646
647 for (room_id, events) in decrypted_events.into_iter() {
648 let Some(room) = self.inner.client().ok().and_then(|c| c.get_room(&room_id)) else {
649 continue;
650 };
651
652 if let Err(e) = self.update_encryption_info_for_events(&room, events).await {
653 warn!(
654 %room_id,
655 "Failed to replace the encryption info for in-memory events {e:?}"
656 );
657 }
658 }
659 }
660
    /// Revisits every event held in memory: first retries the decryption of
    /// the UTDs, then refreshes the encryption info of the decrypted events.
    async fn retry_in_memory_events(&self) {
        self.retry_decryption_for_in_memory_events().await;
        self.retry_update_encryption_info_for_in_memory_events().await;
    }
675
676 pub fn request_decryption(&self, request: DecryptionRetryRequest) {
717 let _ =
718 self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
719 |_| warn!("Requesting a decryption while the redecryption task has been shut down"),
720 );
721 }
722
723 pub fn subscribe_to_decryption_reports(
774 &self,
775 ) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
776 BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
777 }
778}
779
780#[inline(always)]
781fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
782 cache.upgrade().map(|inner| EventCache { inner })
783}
784
785async fn send_report_and_retry_memory_events(
786 cache: &Weak<EventCacheInner>,
787 report: RedecryptorReport,
788) -> Result<(), ()> {
789 let Some(cache) = upgrade_event_cache(cache) else {
790 return Err(());
791 };
792
793 cache.retry_in_memory_events().await;
794 let _ = cache.inner.redecryption_channels.utd_reporter.send(report);
795
796 Ok(())
797}
798
/// Owner of the background task that retries decryption and refreshes
/// encryption info for the events of the event cache.
pub(crate) struct Redecryptor {
    /// Handle of the background task; it is created with `abort_on_drop()`,
    /// so the handle is only kept around to tie the task to this struct.
    _task: BackgroundTaskHandle,
}
808
809impl Redecryptor {
810 pub(super) fn new(
815 client: &Client,
816 cache: Weak<EventCacheInner>,
817 receiver: UnboundedReceiver<DecryptionRetryRequest>,
818 linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
819 ) -> Self {
820 let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());
821 let backup_state_stream = client.encryption().backups().state_stream();
822
823 let task = client
824 .task_monitor()
825 .spawn_background_task("event_cache::redecryptor", async {
826 let request_redecryption_stream = UnboundedReceiverStream::new(receiver);
827
828 Self::listen_for_room_keys_task(
829 cache,
830 request_redecryption_stream,
831 linked_chunk_stream,
832 backup_state_stream,
833 )
834 .await;
835 })
836 .abort_on_drop();
837
838 Self { _task: task }
839 }
840
841 async fn subscribe_to_room_key_stream(
846 cache: &Weak<EventCacheInner>,
847 ) -> Option<(
848 impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
849 impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
850 )> {
851 let event_cache = cache.upgrade()?;
852 let client = event_cache.client().ok()?;
853 let machine = client.olm_machine().await;
854
855 machine.as_ref().map(|m| {
856 (m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
857 })
858 }
859
    /// Runs one "generation" of the redecryptor event loop.
    ///
    /// The room key streams are subscribed to at the start of every call,
    /// while the request, linked chunk update and backup state streams are
    /// borrowed from the caller so that they survive across generations.
    ///
    /// Returns `true` when the loop should be restarted (one of the room key
    /// streams ended), and `false` when the task should shut down (the event
    /// cache is gone, the room key streams couldn't be subscribed to, or the
    /// `else` branch of the `select!` was taken).
    async fn redecryption_loop(
        cache: &Weak<EventCacheInner>,
        decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
        events_stream: &mut Pin<
            &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
        >,
        backup_state_stream: &mut Pin<
            &mut impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
        >,
    ) -> bool {
        // Without the room key streams there's nothing to listen to.
        let Some((room_key_stream, withheld_stream)) =
            Self::subscribe_to_room_key_stream(cache).await
        else {
            return false;
        };

        pin_mut!(room_key_stream);
        pin_mut!(withheld_stream);

        loop {
            tokio::select! {
                // An explicit request, see `EventCache::request_decryption`.
                Some(request) = decryption_request_stream.next() => {
                    let Some(cache) = upgrade_event_cache(cache) else {
                        break false;
                    };

                    trace!(?request, "Received a redecryption request");

                    // Errors are logged and don't stop the remaining
                    // sessions from being processed.
                    for session_id in request.utd_session_ids {
                        let _ = cache
                            .retry_decryption(&request.room_id, &session_id)
                            .await
                            .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
                    }

                    for session_id in request.refresh_info_session_ids {
                        let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
                            warn!(
                                room_id = %request.room_id,
                                session_id = session_id,
                                "Unable to update the encryption info {e:?}",
                            ));
                    }
                }
                // New room keys were received.
                room_keys = room_key_stream.next() => {
                    match room_keys {
                        Some(Ok(room_keys)) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?room_keys, "Received new room keys");

                            // First retry the UTDs of every key...
                            for key in &room_keys {
                                let _ = cache
                                    .retry_decryption(&key.room_id, &key.session_id)
                                    .await
                                    .inspect_err(|e| warn!("Error redecrypting {e:?}"));
                            }

                            // ... then refresh the encryption info of the
                            // events that were already decrypted.
                            for key in room_keys {
                                let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %key.room_id,
                                        session_id = key.session_id,
                                        "Unable to update the encryption info {e:?}",
                                    ));
                            }
                        },
                        Some(Err(_)) => {
                            warn!("The room key stream lagged, reporting the lag to our listeners");

                            // Some keys may have been missed: retry all the
                            // in-memory events and tell the listeners.
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        },
                        None => {
                            // The stream ended: ask the caller to restart the
                            // loop, which subscribes to the streams again.
                            break true;
                        }
                    }
                }
                // New withheld infos were received; only the encryption info
                // is refreshed for those.
                withheld_info = withheld_stream.next() => {
                    match withheld_info {
                        Some(infos) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?infos, "Received new withheld infos");

                            for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
                                let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %room_id,
                                        session_id = session_id,
                                        "Unable to update the encryption info {e:?}",
                                    ));
                            }
                        }
                        // Same as for the room key stream: restart the loop.
                        None => break true,
                    }
                }
                // The event cache received new events; they may contain UTDs
                // that can be decrypted right away.
                Some(event_updates) = events_stream.next() => {
                    match event_updates {
                        Ok(updates) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            // Keep a copy for the log message: `updates` is
                            // consumed by the call below.
                            let linked_chunk_id = updates.linked_chunk_id.to_owned();

                            let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
                                warn!(
                                    %linked_chunk_id,
                                    "Unable to handle UTDs from event cache updates {e:?}",
                                )
                            );
                        }
                        Err(_) => {
                            // Some updates were missed: retry everything held
                            // in memory and report the lag.
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        }
                    }
                }
                // The state of the key backup changed.
                Some(backup_state_update) = backup_state_stream.next() => {
                    match backup_state_update {
                        Ok(state) => {
                            match state {
                                BackupState::Unknown |
                                BackupState::Creating |
                                BackupState::Enabling |
                                BackupState::Resuming |
                                BackupState::Downloading |
                                BackupState::Disabling => {
                                    // Nothing to do for these states.
                                }
                                BackupState::Enabled => {
                                    // Retry the in-memory events and tell the
                                    // listeners that a backup is available.
                                    if send_report_and_retry_memory_events(cache, RedecryptorReport::BackupAvailable).await.is_err() {
                                        break false;
                                    }
                                }
                            }
                        }
                        Err(_) => {
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        }
                    }
                }
                // Taken by `select!` once every branch has been disabled.
                else => break false,
            }
        }
    }
1040
1041 async fn listen_for_room_keys_task(
1042 cache: Weak<EventCacheInner>,
1043 decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
1044 events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
1045 backup_state_stream: impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
1046 ) {
1047 pin_mut!(decryption_request_stream);
1051 pin_mut!(events_stream);
1052 pin_mut!(backup_state_stream);
1053
1054 while Self::redecryption_loop(
1055 &cache,
1056 &mut decryption_request_stream,
1057 &mut events_stream,
1058 &mut backup_state_stream,
1059 )
1060 .await
1061 {
1062 info!("Regenerating the re-decryption streams");
1063
1064 if send_report_and_retry_memory_events(&cache, RedecryptorReport::Lagging)
1067 .await
1068 .is_err()
1069 {
1070 break;
1071 }
1072 }
1073
1074 info!("Shutting down the event cache redecryptor");
1075 }
1076}
1077
1078#[cfg(not(target_family = "wasm"))]
1079#[cfg(test)]
1080mod tests {
1081 use std::{
1082 collections::BTreeSet,
1083 sync::{
1084 Arc,
1085 atomic::{AtomicBool, Ordering},
1086 },
1087 time::Duration,
1088 };
1089
1090 use assert_matches2::assert_matches;
1091 use async_trait::async_trait;
1092 use eyeball_im::VectorDiff;
1093 use matrix_sdk_base::{
1094 cross_process_lock::CrossProcessLockGeneration,
1095 crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
1096 deserialized_responses::{TimelineEventKind, VerificationState},
1097 event_cache::{
1098 Event, Gap,
1099 store::{EventCacheStore, EventCacheStoreError, MemoryStore},
1100 },
1101 linked_chunk::{
1102 ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
1103 RawChunk, Update,
1104 },
1105 locks::Mutex,
1106 sleep::sleep,
1107 store::StoreConfig,
1108 };
1109 use matrix_sdk_common::cross_process_lock::CrossProcessLockConfig;
1110 use matrix_sdk_test::{JoinedRoomBuilder, async_test, event_factory::EventFactory};
1111 use ruma::{
1112 EventId, OwnedEventId, RoomId, RoomVersionId, device_id, event_id,
1113 events::{AnySyncTimelineEvent, relation::RelationType},
1114 room_id,
1115 serde::Raw,
1116 user_id,
1117 };
1118 use serde_json::json;
1119 use tokio::sync::oneshot::{self, Sender};
1120 use tracing::{Instrument, info};
1121
1122 use crate::{
1123 Client, assert_let_timeout,
1124 encryption::EncryptionSettings,
1125 event_cache::{
1126 DecryptionRetryRequest, RoomEventCacheGenericUpdate, RoomEventCacheUpdate,
1127 TimelineVectorDiffs,
1128 },
1129 test_utils::mocks::MatrixMockServer,
1130 };
1131
    /// An [`EventCacheStore`] wrapping a [`MemoryStore`] whose
    /// `handle_linked_chunk_updates()` blocks until
    /// `DelayingStore::stop_delaying()` is called.
    #[derive(Debug, Clone)]
    struct DelayingStore {
        /// The store every operation is delegated to.
        memory_store: MemoryStore,
        /// While `true`, `handle_linked_chunk_updates()` keeps sleeping.
        delaying: Arc<AtomicBool>,
        /// One-shot sender used to notify `stop_delaying()` that a delayed
        /// `handle_linked_chunk_updates()` call has completed.
        foo: Arc<Mutex<Option<Sender<()>>>>,
    }
1142
1143 impl DelayingStore {
1144 fn new() -> Self {
1145 Self {
1146 memory_store: MemoryStore::new(),
1147 delaying: AtomicBool::new(true).into(),
1148 foo: Arc::new(Mutex::new(None)),
1149 }
1150 }
1151
1152 async fn stop_delaying(&self) {
1153 let (sender, receiver) = oneshot::channel();
1154
1155 {
1156 *self.foo.lock() = Some(sender);
1157 }
1158
1159 self.delaying.store(false, Ordering::SeqCst);
1160
1161 receiver.await.expect("We should be able to receive a response")
1162 }
1163 }
1164
    // Every method delegates to the inner `MemoryStore`; only
    // `handle_linked_chunk_updates()` adds behavior on top of it.
    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl EventCacheStore for DelayingStore {
        type Error = EventCacheStoreError;

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
            self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        /// Waits until the store stops delaying, applies the updates to the
        /// inner store, then notifies a pending `stop_delaying()` call, if
        /// any.
        async fn handle_linked_chunk_updates(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            updates: Vec<Update<Event, Gap>>,
        ) -> Result<(), Self::Error> {
            // Poll the flag until `stop_delaying()` flips it.
            while self.delaying.load(Ordering::SeqCst) {
                sleep(Duration::from_millis(10)).await;
            }

            // Take the sender before touching the storage; the notification
            // itself is only sent once the storage operation is done.
            let sender = self.foo.lock().take();
            let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;

            if let Some(sender) = sender {
                sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
            }

            ret
        }

        async fn load_all_chunks(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_all_chunks(linked_chunk_id).await
        }

        async fn load_all_chunks_metadata(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<ChunkMetadata>, Self::Error> {
            self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
        }

        async fn load_last_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
            self.memory_store.load_last_chunk(linked_chunk_id).await
        }

        async fn load_previous_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            before_chunk_identifier: ChunkIdentifier,
        ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
        }

        async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
            self.memory_store.clear_all_linked_chunks().await
        }

        async fn filter_duplicated_events(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            events: Vec<OwnedEventId>,
        ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
            self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
        }

        async fn find_event(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
        ) -> Result<Option<Event>, Self::Error> {
            self.memory_store.find_event(room_id, event_id).await
        }

        async fn find_event_relations(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
            filters: Option<&[RelationType]>,
        ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
            self.memory_store.find_event_relations(room_id, event_id, filters).await
        }

        async fn get_room_events(
            &self,
            room_id: &RoomId,
            event_type: Option<&str>,
            session_id: Option<&str>,
        ) -> Result<Vec<Event>, Self::Error> {
            self.memory_store.get_room_events(room_id, event_type, session_id).await
        }

        async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
            self.memory_store.save_event(room_id, event).await
        }

        async fn optimize(&self) -> Result<(), Self::Error> {
            self.memory_store.optimize().await
        }

        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
            self.memory_store.get_size().await
        }
    }
1282
    /// Creates two end-to-end encryption capable clients, Alice and Bob,
    /// backed by a fresh mock server, and joins both of them to an encrypted
    /// room with the given ID.
    ///
    /// - `alice_enables_cross_signing` controls the
    ///   `auto_enable_cross_signing` setting of Alice; Bob always has it
    ///   enabled.
    /// - `use_delayed_store` makes Bob use a [`DelayingStore`] as his event
    ///   cache store; the store is then returned as the last tuple element.
    ///
    /// Bob's event cache is subscribed before the function returns.
    async fn set_up_clients(
        room_id: &RoomId,
        alice_enables_cross_signing: bool,
        use_delayed_store: bool,
    ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
        // Dedicated spans make it possible to tell the two clients apart in
        // the logs.
        let alice_span = tracing::info_span!("alice");
        let bob_span = tracing::info_span!("bob");

        let alice_user_id = user_id!("@alice:localhost");
        let alice_device_id = device_id!("ALICEDEVICE");
        let bob_user_id = user_id!("@bob:localhost");
        let bob_device_id = device_id!("BOBDEVICE");

        let matrix_mock_server = MatrixMockServer::new().await;
        matrix_mock_server.mock_crypto_endpoints_preset().await;

        let encryption_settings = EncryptionSettings {
            auto_enable_cross_signing: alice_enables_cross_signing,
            ..Default::default()
        };

        let alice = matrix_mock_server
            .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
            })
            .build()
            .instrument(alice_span.clone())
            .await;

        let encryption_settings =
            EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };

        // Pick Bob's store configuration, optionally with the delaying event
        // cache store.
        let (store_config, store) = if use_delayed_store {
            let store = DelayingStore::new();

            (
                StoreConfig::new(CrossProcessLockConfig::multi_process(
                    "delayed_store_event_cache_test",
                ))
                .event_cache_store(store.clone()),
                Some(store),
            )
        } else {
            (
                StoreConfig::new(CrossProcessLockConfig::multi_process(
                    "normal_store_event_cache_test",
                )),
                None,
            )
        };

        let bob = matrix_mock_server
            .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
                    .store_config(store_config)
            })
            .build()
            .instrument(bob_span.clone())
            .await;

        bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");

        matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;

        let event_factory = EventFactory::new().room(room_id).sender(alice_user_id);

        // The room is created by Alice and has encryption enabled.
        let room_builder = JoinedRoomBuilder::new(room_id)
            .add_state_event(event_factory.create(alice_user_id, RoomVersionId::V1))
            .add_state_event(event_factory.room_encryption());

        // Both clients learn about the room through a sync.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&alice, |builder| {
                builder.add_joined_room(room_builder.clone());
            })
            .instrument(alice_span)
            .await;

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(room_builder);
            })
            .instrument(bob_span)
            .await;

        (alice, bob, matrix_mock_server, store)
    }
1381
    /// Makes Alice send an `m.room.message` event into the room and captures
    /// both the event that was sent and the to-device event that was put on
    /// the mock server as a result.
    ///
    /// Returns the captured event and the captured to-device event (the room
    /// key, going by the variable names), in that order.
    async fn prepare_room(
        matrix_mock_server: &MatrixMockServer,
        event_factory: &EventFactory,
        alice: &Client,
        bob: &Client,
        room_id: &RoomId,
    ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
        let alice_user_id = alice.user_id().unwrap();
        let bob_user_id = bob.user_id().unwrap();

        let alice_member_event = event_factory.member(alice_user_id).into_raw();
        let bob_member_event = event_factory.member(bob_user_id).into_raw();

        let room = alice
            .get_room(room_id)
            .expect("Alice should have access to the room now that we synced");

        let event_type = "m.room.message";
        let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});

        // Set up the capture of the sent event and of the to-device request.
        let event_id = event_id!("$some_id");
        let (event_receiver, mock) =
            matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
        let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;

        {
            // The send mock only lives for the duration of this block.
            let _guard = mock.mock_once().mount_as_scoped().await;

            // The member list served to Alice contains both users.
            matrix_mock_server
                .mock_get_members()
                .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
                .mock_once()
                .mount()
                .await;

            room.send_raw(event_type, content)
                .await
                .expect("We should be able to send an initial message");
        };

        let event = event_receiver.await.expect("Alice should have sent the event by now");
        let room_key = room_key.await;

        (event, room_key)
    }
1432
    /// Checks that an event received as a UTD gets replaced by its decrypted
    /// version once the room key arrives, and that both steps are observable
    /// through the room subscriber and the generic update stream.
    #[async_test]
    async fn test_redecryptor() {
        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let event_cache = bob.event_cache();
        let (room_cache, _) = event_cache
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
        let mut generic_stream = event_cache.subscribe_to_room_generic_updates();

        // NOTE(review): presumably done so that the redecryptor has to
        // regenerate its room key streams — confirm.
        bob.inner
            .base_client
            .regenerate_olm(None)
            .await
            .expect("We should be able to regenerate the Olm machine");

        // Bob receives the encrypted event, but not the room key yet.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );

        // The event shows up as a UTD.
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());

        // Bob now receives the room key as a to-device event.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );

        // The UTD got replaced, in place, by the decrypted event.
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());
    }
1520
1521 #[async_test]
1522 async fn test_redecryptor_updating_encryption_info() {
1523 let bob_span = tracing::info_span!("bob");
1524
1525 let room_id = room_id!("!test:localhost");
1526
1527 let event_factory = EventFactory::new().room(room_id);
1528 let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;
1529
1530 let (event, room_key) =
1531 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1532
1533 let event_cache = bob.event_cache();
1536 let (room_cache, _) = event_cache
1537 .for_room(room_id)
1538 .instrument(bob_span.clone())
1539 .await
1540 .expect("We should be able to get to the event cache for a specific room");
1541
1542 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1543 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1544
1545 matrix_mock_server
1547 .mock_sync()
1548 .ok_and_run(&bob, |builder| {
1549 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1550 })
1551 .instrument(bob_span.clone())
1552 .await;
1553
1554 assert_let_timeout!(
1557 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1558 subscriber.recv()
1559 );
1560
1561 assert_eq!(diffs.len(), 1);
1564 assert_matches!(&diffs[0], VectorDiff::Append { values });
1565 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1566
1567 assert_let_timeout!(
1568 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1569 );
1570 assert_eq!(expected_room_id, room_id);
1571 assert!(generic_stream.is_empty());
1572
1573 matrix_mock_server
1575 .mock_sync()
1576 .ok_and_run(&bob, |builder| {
1577 builder.add_to_device_event(
1578 room_key
1579 .deserialize_as()
1580 .expect("We should be able to deserialize the room key"),
1581 );
1582 })
1583 .instrument(bob_span.clone())
1584 .await;
1585
1586 assert_let_timeout!(
1588 Duration::from_secs(1),
1589 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1590 subscriber.recv()
1591 );
1592
1593 assert_eq!(diffs.len(), 1);
1595 assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
1596 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1597
1598 let encryption_info = value.encryption_info().unwrap();
1599 assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));
1600
1601 assert_let_timeout!(
1602 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1603 );
1604 assert_eq!(expected_room_id, room_id);
1605 assert!(generic_stream.is_empty());
1606
1607 let session_id = encryption_info.session_id().unwrap().to_owned();
1608 let alice_user_id = alice.user_id().unwrap();
1609
1610 alice
1612 .encryption()
1613 .bootstrap_cross_signing(None)
1614 .await
1615 .expect("Alice should be able to create the cross-signing keys");
1616
1617 bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
1618 matrix_mock_server
1619 .mock_sync()
1620 .ok_and_run(&bob, |builder| {
1621 builder.add_change_device(alice_user_id);
1622 })
1623 .instrument(bob_span.clone())
1624 .await;
1625
1626 bob.event_cache().request_decryption(DecryptionRetryRequest {
1627 room_id: room_id.into(),
1628 utd_session_ids: BTreeSet::new(),
1629 refresh_info_session_ids: BTreeSet::from([session_id]),
1630 });
1631
1632 assert_let_timeout!(
1635 Duration::from_secs(1),
1636 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1637 subscriber.recv()
1638 );
1639
1640 assert_eq!(diffs.len(), 1);
1641 assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
1642 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1643 let encryption_info = value.encryption_info().unwrap();
1644
1645 assert_matches!(
1646 &encryption_info.verification_state,
1647 VerificationState::Unverified(_),
1648 "The event should now know about the identity but still be unverified"
1649 );
1650
1651 assert_let_timeout!(
1652 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1653 );
1654 assert_eq!(expected_room_id, room_id);
1655 assert!(generic_stream.is_empty());
1656 }
1657
1658 #[async_test]
1659 async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
1660 let room_id = room_id!("!test:localhost");
1661
1662 let event_factory = EventFactory::new().room(room_id);
1663 let (alice, bob, matrix_mock_server, delayed_store) =
1664 set_up_clients(room_id, true, true).await;
1665
1666 let delayed_store = delayed_store.unwrap();
1667
1668 let (event, room_key) =
1669 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1670
1671 let event_cache = bob.event_cache();
1672
1673 let (room_cache, _) = event_cache
1675 .for_room(room_id)
1676 .await
1677 .expect("We should be able to get to the event cache for a specific room");
1678
1679 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1680 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1681
1682 matrix_mock_server
1684 .mock_sync()
1685 .ok_and_run(&bob, |builder| {
1686 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1687 })
1688 .await;
1689
1690 matrix_mock_server
1692 .mock_sync()
1693 .ok_and_run(&bob, |builder| {
1694 builder.add_to_device_event(
1695 room_key
1696 .deserialize_as()
1697 .expect("We should be able to deserialize the room key"),
1698 );
1699 })
1700 .await;
1701
1702 info!("Stopping the delay");
1703 delayed_store.stop_delaying().await;
1704
1705 assert_let_timeout!(
1712 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1713 subscriber.recv()
1714 );
1715
1716 assert_eq!(diffs.len(), 1);
1719 assert_matches!(&diffs[0], VectorDiff::Append { values });
1720 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1721
1722 assert_let_timeout!(
1723 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1724 );
1725 assert_eq!(expected_room_id, room_id);
1726
1727 assert_let_timeout!(
1729 Duration::from_secs(1),
1730 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1731 subscriber.recv()
1732 );
1733
1734 assert_eq!(diffs.len(), 1);
1736 assert_matches!(&diffs[0], VectorDiff::Set { index, value });
1737 assert_eq!(*index, 0);
1738 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1739
1740 assert_let_timeout!(
1741 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1742 );
1743 assert_eq!(expected_room_id, room_id);
1744 assert!(generic_stream.is_empty());
1745 }
1746}