1use std::{
117 collections::{BTreeMap, BTreeSet},
118 pin::Pin,
119 sync::Weak,
120};
121
122use as_variant::as_variant;
123use futures_core::Stream;
124use futures_util::{StreamExt, pin_mut};
125#[cfg(doc)]
126use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
127use matrix_sdk_base::{
128 crypto::{
129 store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
130 types::events::room::encrypted::EncryptedEvent,
131 },
132 deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
133 event_cache::store::EventCacheStoreLockState,
134 locks::Mutex,
135 task_monitor::BackgroundTaskHandle,
136 timer,
137};
138#[cfg(doc)]
139use matrix_sdk_common::deserialized_responses::EncryptionInfo;
140use ruma::{
141 OwnedEventId, OwnedRoomId, RoomId,
142 events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
143 push::Action,
144 serde::Raw,
145};
146use tokio::sync::{
147 broadcast::{self, Sender},
148 mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
149};
150use tokio_stream::wrappers::{
151 BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
152};
153use tracing::{info, instrument, trace, warn};
154
155#[cfg(doc)]
156use super::RoomEventCache;
157use super::{EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate};
158use crate::{
159 Client, Result, Room,
160 encryption::backups::BackupState,
161 event_cache::{
162 RoomEventCacheGenericUpdate, RoomEventCacheLinkedChunkUpdate, TimelineVectorDiffs,
163 room::PostProcessingOrigin,
164 },
165 room::PushContext,
166};
167
168type SessionId<'a> = &'a str;
169type OwnedSessionId = String;
170
171type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
172type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
173pub(in crate::event_cache) type ResolvedUtd =
174 (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
175
176#[derive(Debug, Clone)]
179pub struct DecryptionRetryRequest {
180 pub room_id: OwnedRoomId,
182 pub utd_session_ids: BTreeSet<OwnedSessionId>,
184 pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
187}
188
189#[derive(Debug, Clone)]
191pub enum RedecryptorReport {
192 ResolvedUtds {
194 room_id: OwnedRoomId,
196 events: BTreeSet<OwnedEventId>,
198 },
199 Lagging,
202 BackupAvailable,
207}
208
209pub(super) struct RedecryptorChannels {
210 utd_reporter: Sender<RedecryptorReport>,
211 pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
212 pub(super) decryption_request_receiver:
213 Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
214}
215
216impl RedecryptorChannels {
217 pub(super) fn new() -> Self {
218 let (utd_reporter, _) = broadcast::channel(100);
219 let (decryption_request_sender, decryption_request_receiver) = unbounded_channel();
220
221 Self {
222 utd_reporter,
223 decryption_request_sender,
224 decryption_request_receiver: Mutex::new(Some(decryption_request_receiver)),
225 }
226 }
227}
228
229fn filter_timeline_event_to_utd(
234 event: TimelineEvent,
235) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
236 let event_id = event.event_id();
237
238 let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
241 event_id.zip(event)
244}
245
246fn filter_timeline_event_to_decrypted(
252 event: TimelineEvent,
253) -> Option<(OwnedEventId, DecryptedRoomEvent)> {
254 let event_id = event.event_id();
255
256 let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
257 event_id.zip(event)
260}
261
262impl EventCache {
263 async fn get_utds(
271 &self,
272 room_id: &RoomId,
273 session_id: SessionId<'_>,
274 ) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
275 let events = match self.inner.store.lock().await? {
276 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
281 guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
282 }
283 };
284
285 Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
286 }
287
288 async fn get_utds_from_memory(&self) -> BTreeMap<OwnedRoomId, Vec<EventIdAndUtd>> {
291 let mut utds = BTreeMap::new();
292
293 for (room_id, room_cache) in self.inner.by_room.read().await.iter() {
294 let room_utds: Vec<_> = room_cache
295 .events()
296 .await
297 .into_iter()
298 .flatten()
299 .filter_map(filter_timeline_event_to_utd)
300 .collect();
301
302 utds.insert(room_id.to_owned(), room_utds);
303 }
304
305 utds
306 }
307
308 async fn get_decrypted_events(
309 &self,
310 room_id: &RoomId,
311 session_id: SessionId<'_>,
312 ) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
313 let events = match self.inner.store.lock().await? {
314 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
319 guard.get_room_events(room_id, None, Some(session_id)).await?
320 }
321 };
322
323 Ok(events.into_iter().filter_map(filter_timeline_event_to_decrypted).collect())
324 }
325
326 async fn get_decrypted_events_from_memory(
327 &self,
328 ) -> BTreeMap<OwnedRoomId, Vec<EventIdAndEvent>> {
329 let mut decrypted_events = BTreeMap::new();
330
331 for (room_id, room_cache) in self.inner.by_room.read().await.iter() {
332 let room_utds: Vec<_> = room_cache
333 .events()
334 .await
335 .into_iter()
336 .flatten()
337 .filter_map(filter_timeline_event_to_decrypted)
338 .collect();
339
340 decrypted_events.insert(room_id.to_owned(), room_utds);
341 }
342
343 decrypted_events
344 }
345
346 #[instrument(skip_all, fields(room_id))]
358 async fn on_resolved_utds(
359 &self,
360 room_id: &RoomId,
361 events: Vec<ResolvedUtd>,
362 ) -> Result<(), EventCacheError> {
363 if events.is_empty() {
364 trace!("No events were redecrypted or updated, nothing to replace");
365 return Ok(());
366 }
367
368 timer!("Resolving UTDs");
369
370 let (room_cache, _drop_handles) = self.for_room(room_id).await?;
373 let mut state = room_cache.inner.state.write().await?;
374
375 let event_ids: BTreeSet<_> =
376 events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
377 let mut new_events = Vec::with_capacity(events.len());
378
379 if let Some(pinned_cache) = state.pinned_event_cache() {
381 pinned_cache.replace_utds(&events).await?;
382 }
383
384 for (event_id, decrypted, actions) in events {
386 if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
390 target_event.kind = TimelineEventKind::Decrypted(decrypted);
391
392 if let Some(actions) = actions {
393 target_event.set_push_actions(actions);
394 }
395
396 state.replace_event_at(location, target_event.clone()).await?;
399 new_events.push(target_event);
400 }
401 }
402
403 state.post_process_new_events(new_events, PostProcessingOrigin::Redecryption).await?;
404
405 let diffs = state.room_linked_chunk().updates_as_vector_diffs();
408
409 let _ = room_cache.inner.update_sender.send(RoomEventCacheUpdate::UpdateTimelineEvents(
410 TimelineVectorDiffs { diffs, origin: EventsOrigin::Cache },
411 ));
412
413 let _ = room_cache
414 .inner
415 .generic_update_sender
416 .send(RoomEventCacheGenericUpdate { room_id: room_id.to_owned() });
417
418 let report =
423 RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
424 let _ = self.inner.redecryption_channels.utd_reporter.send(report);
425
426 Ok(())
427 }
428
429 async fn decrypt_event(
431 &self,
432 room_id: &RoomId,
433 room: Option<&Room>,
434 push_context: Option<&PushContext>,
435 event: &Raw<EncryptedEvent>,
436 ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
437 if let Some(room) = room {
438 match room
439 .decrypt_event(
440 event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
441 push_context,
442 )
443 .await
444 {
445 Ok(maybe_decrypted) => {
446 let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
447
448 if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
449 Some((decrypted, actions))
450 } else {
451 warn!(
452 "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
453 );
454 None
455 }
456 }
457 Err(e) => {
458 warn!(
459 "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
460 );
461 None
462 }
463 }
464 } else {
465 let client = self.inner.client().ok()?;
466 let machine = client.olm_machine().await;
467 let machine = machine.as_ref()?;
468
469 match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
470 Ok(decrypted) => Some((decrypted, None)),
471 Err(e) => {
472 warn!(
473 "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
474 );
475 None
476 }
477 }
478 }
479 }
480
481 #[instrument(skip_all, fields(room_id, session_id))]
484 async fn retry_decryption(
485 &self,
486 room_id: &RoomId,
487 session_id: SessionId<'_>,
488 ) -> Result<(), EventCacheError> {
489 let events = self.get_utds(room_id, session_id).await?;
491 self.retry_decryption_for_events(room_id, events).await
492 }
493
494 #[instrument(skip_all, fields(updates.linked_chunk_id))]
496 async fn retry_decryption_for_event_cache_updates(
497 &self,
498 updates: RoomEventCacheLinkedChunkUpdate,
499 ) -> Result<(), EventCacheError> {
500 let room_id = updates.linked_chunk_id.room_id();
501 let events: Vec<_> = updates
502 .updates
503 .into_iter()
504 .flat_map(|updates| updates.into_items())
505 .filter_map(filter_timeline_event_to_utd)
506 .collect();
507
508 self.retry_decryption_for_events(room_id, events).await
509 }
510
511 async fn retry_decryption_for_in_memory_events(&self) {
512 let utds = self.get_utds_from_memory().await;
513
514 for (room_id, utds) in utds.into_iter() {
515 if let Err(e) = self.retry_decryption_for_events(&room_id, utds).await {
516 warn!(%room_id, "Failed to redecrypt in-memory events {e:?}");
517 }
518 }
519 }
520
521 #[instrument(skip_all, fields(room_id, session_id))]
523 async fn retry_decryption_for_events(
524 &self,
525 room_id: &RoomId,
526 events: Vec<EventIdAndUtd>,
527 ) -> Result<(), EventCacheError> {
528 trace!("Retrying to decrypt");
529
530 if events.is_empty() {
531 trace!("No relevant events found.");
532 return Ok(());
533 }
534
535 let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
536 let push_context =
537 if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
538
539 let mut decrypted_events = Vec::with_capacity(events.len());
541
542 for (event_id, event) in events {
543 if let Some((decrypted, actions)) = self
546 .decrypt_event(
547 room_id,
548 room.as_ref(),
549 push_context.as_ref(),
550 event.cast_ref_unchecked(),
551 )
552 .await
553 {
554 decrypted_events.push((event_id, decrypted, actions));
555 }
556 }
557
558 let event_ids: BTreeSet<_> =
559 decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
560
561 if !event_ids.is_empty() {
562 trace!(?event_ids, "Successfully redecrypted events");
563 }
564
565 self.on_resolved_utds(room_id, decrypted_events).await?;
568
569 Ok(())
570 }
571
572 async fn update_encryption_info_for_events(
574 &self,
575 room: &Room,
576 events: Vec<EventIdAndEvent>,
577 ) -> Result<(), EventCacheError> {
578 let mut updated_events = Vec::with_capacity(events.len());
580
581 for (event_id, mut event) in events {
582 if let Some(session_id) = event.encryption_info.session_id() {
583 let new_encryption_info =
584 room.get_encryption_info(session_id, &event.encryption_info.sender).await;
585
586 if let Some(new_encryption_info) = new_encryption_info
588 && event.encryption_info != new_encryption_info
589 {
590 event.encryption_info = new_encryption_info;
591 updated_events.push((event_id, event, None));
592 }
593 }
594 }
595
596 let event_ids: BTreeSet<_> =
597 updated_events.iter().map(|(event_id, _, _)| event_id).collect();
598
599 if !event_ids.is_empty() {
600 trace!(?event_ids, "Replacing the encryption info of some events");
601 }
602
603 self.on_resolved_utds(room.room_id(), updated_events).await
604 }
605
606 #[instrument(skip_all, fields(room_id, session_id))]
607 async fn update_encryption_info(
608 &self,
609 room_id: &RoomId,
610 session_id: SessionId<'_>,
611 ) -> Result<(), EventCacheError> {
612 trace!("Updating encryption info");
613
614 let Ok(client) = self.inner.client() else {
615 return Ok(());
616 };
617
618 let Some(room) = client.get_room(room_id) else {
619 return Ok(());
620 };
621
622 let events = self.get_decrypted_events(room_id, session_id).await?;
624
625 if events.is_empty() {
626 trace!("No relevant events found.");
627 return Ok(());
628 }
629
630 self.update_encryption_info_for_events(&room, events).await
632 }
633
634 async fn retry_update_encryption_info_for_in_memory_events(&self) {
635 let decrypted_events = self.get_decrypted_events_from_memory().await;
636
637 for (room_id, events) in decrypted_events.into_iter() {
638 let Some(room) = self.inner.client().ok().and_then(|c| c.get_room(&room_id)) else {
639 continue;
640 };
641
642 if let Err(e) = self.update_encryption_info_for_events(&room, events).await {
643 warn!(
644 %room_id,
645 "Failed to replace the encryption info for in-memory events {e:?}"
646 );
647 }
648 }
649 }
650
651 async fn retry_in_memory_events(&self) {
662 self.retry_decryption_for_in_memory_events().await;
663 self.retry_update_encryption_info_for_in_memory_events().await;
664 }
665
666 pub fn request_decryption(&self, request: DecryptionRetryRequest) {
707 let _ =
708 self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
709 |_| warn!("Requesting a decryption while the redecryption task has been shut down"),
710 );
711 }
712
713 pub fn subscribe_to_decryption_reports(
764 &self,
765 ) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
766 BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
767 }
768}
769
770#[inline(always)]
771fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
772 cache.upgrade().map(|inner| EventCache { inner })
773}
774
775async fn send_report_and_retry_memory_events(
776 cache: &Weak<EventCacheInner>,
777 report: RedecryptorReport,
778) -> Result<(), ()> {
779 let Some(cache) = upgrade_event_cache(cache) else {
780 return Err(());
781 };
782
783 cache.retry_in_memory_events().await;
784 let _ = cache.inner.redecryption_channels.utd_reporter.send(report);
785
786 Ok(())
787}
788
789pub(crate) struct Redecryptor {
796 _task: BackgroundTaskHandle,
797}
798
799impl Redecryptor {
800 pub(super) fn new(
805 client: &Client,
806 cache: Weak<EventCacheInner>,
807 receiver: UnboundedReceiver<DecryptionRetryRequest>,
808 linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
809 ) -> Self {
810 let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());
811 let backup_state_stream = client.encryption().backups().state_stream();
812
813 let task = client
814 .task_monitor()
815 .spawn_background_task("event_cache::redecryptor", async {
816 let request_redecryption_stream = UnboundedReceiverStream::new(receiver);
817
818 Self::listen_for_room_keys_task(
819 cache,
820 request_redecryption_stream,
821 linked_chunk_stream,
822 backup_state_stream,
823 )
824 .await;
825 })
826 .abort_on_drop();
827
828 Self { _task: task }
829 }
830
831 async fn subscribe_to_room_key_stream(
836 cache: &Weak<EventCacheInner>,
837 ) -> Option<(
838 impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
839 impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
840 )> {
841 let event_cache = cache.upgrade()?;
842 let client = event_cache.client().ok()?;
843 let machine = client.olm_machine().await;
844
845 machine.as_ref().map(|m| {
846 (m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
847 })
848 }
849
850 async fn redecryption_loop(
851 cache: &Weak<EventCacheInner>,
852 decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
853 events_stream: &mut Pin<
854 &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
855 >,
856 backup_state_stream: &mut Pin<
857 &mut impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
858 >,
859 ) -> bool {
860 let Some((room_key_stream, withheld_stream)) =
861 Self::subscribe_to_room_key_stream(cache).await
862 else {
863 return false;
864 };
865
866 pin_mut!(room_key_stream);
867 pin_mut!(withheld_stream);
868
869 loop {
870 tokio::select! {
871 Some(request) = decryption_request_stream.next() => {
874 let Some(cache) = upgrade_event_cache(cache) else {
875 break false;
876 };
877
878 trace!(?request, "Received a redecryption request");
879
880 for session_id in request.utd_session_ids {
881 let _ = cache
882 .retry_decryption(&request.room_id, &session_id)
883 .await
884 .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
885 }
886
887 for session_id in request.refresh_info_session_ids {
888 let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
889 warn!(
890 room_id = %request.room_id,
891 session_id = session_id,
892 "Unable to update the encryption info {e:?}",
893 ));
894 }
895 }
896 room_keys = room_key_stream.next() => {
899 match room_keys {
900 Some(Ok(room_keys)) => {
901 let Some(cache) = upgrade_event_cache(cache) else {
905 break false;
906 };
907
908 trace!(?room_keys, "Received new room keys");
909
910 for key in &room_keys {
911 let _ = cache
912 .retry_decryption(&key.room_id, &key.session_id)
913 .await
914 .inspect_err(|e| warn!("Error redecrypting {e:?}"));
915 }
916
917 for key in room_keys {
918 let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
919 warn!(
920 room_id = %key.room_id,
921 session_id = key.session_id,
922 "Unable to update the encryption info {e:?}",
923 ));
924 }
925 },
926 Some(Err(_)) => {
927 warn!("The room key stream lagged, reporting the lag to our listeners");
934
935 if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
936 break false;
937 }
938 },
939 None => {
942 break true;
943 }
944 }
945 }
946 withheld_info = withheld_stream.next() => {
947 match withheld_info {
948 Some(infos) => {
949 let Some(cache) = upgrade_event_cache(cache) else {
950 break false;
951 };
952
953 trace!(?infos, "Received new withheld infos");
954
955 for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
956 let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
957 warn!(
958 room_id = %room_id,
959 session_id = session_id,
960 "Unable to update the encryption info {e:?}",
961 ));
962 }
963 }
964 None => break true,
967 }
968 }
969 Some(event_updates) = events_stream.next() => {
973 match event_updates {
974 Ok(updates) => {
975 let Some(cache) = upgrade_event_cache(cache) else {
976 break false;
977 };
978
979 let linked_chunk_id = updates.linked_chunk_id.to_owned();
980
981 let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
982 warn!(
983 %linked_chunk_id,
984 "Unable to handle UTDs from event cache updates {e:?}",
985 )
986 );
987 }
988 Err(_) => {
989 if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
990 break false;
991 }
992 }
993 }
994 }
995 Some(backup_state_update) = backup_state_stream.next() => {
996 match backup_state_update {
997 Ok(state) => {
998 match state {
999 BackupState::Unknown |
1000 BackupState::Creating |
1001 BackupState::Enabling |
1002 BackupState::Resuming |
1003 BackupState::Downloading |
1004 BackupState::Disabling =>{
1005 }
1008 BackupState::Enabled => {
1009 if send_report_and_retry_memory_events(cache, RedecryptorReport::BackupAvailable).await.is_err() {
1014 break false;
1015 }
1016 }
1017 }
1018 }
1019 Err(_) => {
1020 if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
1021 break false;
1022 }
1023 }
1024 }
1025 }
1026 else => break false,
1027 }
1028 }
1029 }
1030
1031 async fn listen_for_room_keys_task(
1032 cache: Weak<EventCacheInner>,
1033 decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
1034 events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
1035 backup_state_stream: impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
1036 ) {
1037 pin_mut!(decryption_request_stream);
1041 pin_mut!(events_stream);
1042 pin_mut!(backup_state_stream);
1043
1044 while Self::redecryption_loop(
1045 &cache,
1046 &mut decryption_request_stream,
1047 &mut events_stream,
1048 &mut backup_state_stream,
1049 )
1050 .await
1051 {
1052 info!("Regenerating the re-decryption streams");
1053
1054 if send_report_and_retry_memory_events(&cache, RedecryptorReport::Lagging)
1057 .await
1058 .is_err()
1059 {
1060 break;
1061 }
1062 }
1063
1064 info!("Shutting down the event cache redecryptor");
1065 }
1066}
1067
1068#[cfg(not(target_family = "wasm"))]
1069#[cfg(test)]
1070mod tests {
1071 use std::{
1072 collections::BTreeSet,
1073 sync::{
1074 Arc,
1075 atomic::{AtomicBool, Ordering},
1076 },
1077 time::Duration,
1078 };
1079
1080 use assert_matches2::assert_matches;
1081 use async_trait::async_trait;
1082 use eyeball_im::VectorDiff;
1083 use matrix_sdk_base::{
1084 cross_process_lock::CrossProcessLockGeneration,
1085 crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
1086 deserialized_responses::{TimelineEventKind, VerificationState},
1087 event_cache::{
1088 Event, Gap,
1089 store::{EventCacheStore, EventCacheStoreError, MemoryStore},
1090 },
1091 linked_chunk::{
1092 ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
1093 RawChunk, Update,
1094 },
1095 locks::Mutex,
1096 sleep::sleep,
1097 store::StoreConfig,
1098 };
1099 use matrix_sdk_test::{
1100 JoinedRoomBuilder, StateTestEvent, async_test, event_factory::EventFactory,
1101 };
1102 use ruma::{
1103 EventId, OwnedEventId, RoomId, device_id, event_id,
1104 events::{AnySyncTimelineEvent, relation::RelationType},
1105 room_id,
1106 serde::Raw,
1107 user_id,
1108 };
1109 use serde_json::json;
1110 use tokio::sync::oneshot::{self, Sender};
1111 use tracing::{Instrument, info};
1112
1113 use crate::{
1114 Client, assert_let_timeout,
1115 encryption::EncryptionSettings,
1116 event_cache::{
1117 DecryptionRetryRequest, RoomEventCacheGenericUpdate, RoomEventCacheUpdate,
1118 TimelineVectorDiffs,
1119 },
1120 test_utils::mocks::MatrixMockServer,
1121 };
1122
1123 #[derive(Debug, Clone)]
1128 struct DelayingStore {
1129 memory_store: MemoryStore,
1130 delaying: Arc<AtomicBool>,
1131 foo: Arc<Mutex<Option<Sender<()>>>>,
1132 }
1133
1134 impl DelayingStore {
1135 fn new() -> Self {
1136 Self {
1137 memory_store: MemoryStore::new(),
1138 delaying: AtomicBool::new(true).into(),
1139 foo: Arc::new(Mutex::new(None)),
1140 }
1141 }
1142
1143 async fn stop_delaying(&self) {
1144 let (sender, receiver) = oneshot::channel();
1145
1146 {
1147 *self.foo.lock() = Some(sender);
1148 }
1149
1150 self.delaying.store(false, Ordering::SeqCst);
1151
1152 receiver.await.expect("We should be able to receive a response")
1153 }
1154 }
1155
1156 #[cfg_attr(target_family = "wasm", async_trait(?Send))]
1157 #[cfg_attr(not(target_family = "wasm"), async_trait)]
1158 impl EventCacheStore for DelayingStore {
1159 type Error = EventCacheStoreError;
1160
1161 async fn try_take_leased_lock(
1162 &self,
1163 lease_duration_ms: u32,
1164 key: &str,
1165 holder: &str,
1166 ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
1167 self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
1168 }
1169
1170 async fn handle_linked_chunk_updates(
1171 &self,
1172 linked_chunk_id: LinkedChunkId<'_>,
1173 updates: Vec<Update<Event, Gap>>,
1174 ) -> Result<(), Self::Error> {
1175 while self.delaying.load(Ordering::SeqCst) {
1181 sleep(Duration::from_millis(10)).await;
1182 }
1183
1184 let sender = self.foo.lock().take();
1185 let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;
1186
1187 if let Some(sender) = sender {
1188 sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
1189 }
1190
1191 ret
1192 }
1193
1194 async fn load_all_chunks(
1195 &self,
1196 linked_chunk_id: LinkedChunkId<'_>,
1197 ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
1198 self.memory_store.load_all_chunks(linked_chunk_id).await
1199 }
1200
1201 async fn load_all_chunks_metadata(
1202 &self,
1203 linked_chunk_id: LinkedChunkId<'_>,
1204 ) -> Result<Vec<ChunkMetadata>, Self::Error> {
1205 self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
1206 }
1207
1208 async fn load_last_chunk(
1209 &self,
1210 linked_chunk_id: LinkedChunkId<'_>,
1211 ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
1212 self.memory_store.load_last_chunk(linked_chunk_id).await
1213 }
1214
1215 async fn load_previous_chunk(
1216 &self,
1217 linked_chunk_id: LinkedChunkId<'_>,
1218 before_chunk_identifier: ChunkIdentifier,
1219 ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
1220 self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
1221 }
1222
1223 async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
1224 self.memory_store.clear_all_linked_chunks().await
1225 }
1226
1227 async fn filter_duplicated_events(
1228 &self,
1229 linked_chunk_id: LinkedChunkId<'_>,
1230 events: Vec<OwnedEventId>,
1231 ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
1232 self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
1233 }
1234
1235 async fn find_event(
1236 &self,
1237 room_id: &RoomId,
1238 event_id: &EventId,
1239 ) -> Result<Option<Event>, Self::Error> {
1240 self.memory_store.find_event(room_id, event_id).await
1241 }
1242
1243 async fn find_event_relations(
1244 &self,
1245 room_id: &RoomId,
1246 event_id: &EventId,
1247 filters: Option<&[RelationType]>,
1248 ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
1249 self.memory_store.find_event_relations(room_id, event_id, filters).await
1250 }
1251
1252 async fn get_room_events(
1253 &self,
1254 room_id: &RoomId,
1255 event_type: Option<&str>,
1256 session_id: Option<&str>,
1257 ) -> Result<Vec<Event>, Self::Error> {
1258 self.memory_store.get_room_events(room_id, event_type, session_id).await
1259 }
1260
1261 async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
1262 self.memory_store.save_event(room_id, event).await
1263 }
1264
1265 async fn optimize(&self) -> Result<(), Self::Error> {
1266 self.memory_store.optimize().await
1267 }
1268
1269 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1270 self.memory_store.get_size().await
1271 }
1272 }
1273
1274 async fn set_up_clients(
1275 room_id: &RoomId,
1276 alice_enables_cross_signing: bool,
1277 use_delayed_store: bool,
1278 ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
1279 let alice_span = tracing::info_span!("alice");
1280 let bob_span = tracing::info_span!("bob");
1281
1282 let alice_user_id = user_id!("@alice:localhost");
1283 let alice_device_id = device_id!("ALICEDEVICE");
1284 let bob_user_id = user_id!("@bob:localhost");
1285 let bob_device_id = device_id!("BOBDEVICE");
1286
1287 let matrix_mock_server = MatrixMockServer::new().await;
1288 matrix_mock_server.mock_crypto_endpoints_preset().await;
1289
1290 let encryption_settings = EncryptionSettings {
1291 auto_enable_cross_signing: alice_enables_cross_signing,
1292 ..Default::default()
1293 };
1294
1295 let alice = matrix_mock_server
1298 .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
1299 .on_builder(|builder| {
1300 builder
1301 .with_enable_share_history_on_invite(true)
1302 .with_encryption_settings(encryption_settings)
1303 })
1304 .build()
1305 .instrument(alice_span.clone())
1306 .await;
1307
1308 let encryption_settings =
1309 EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };
1310
1311 let (store_config, store) = if use_delayed_store {
1312 let store = DelayingStore::new();
1313
1314 (
1315 StoreConfig::new("delayed_store_event_cache_test".into())
1316 .event_cache_store(store.clone()),
1317 Some(store),
1318 )
1319 } else {
1320 (StoreConfig::new("normal_store_event_cache_test".into()), None)
1321 };
1322
1323 let bob = matrix_mock_server
1324 .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
1325 .on_builder(|builder| {
1326 builder
1327 .with_enable_share_history_on_invite(true)
1328 .with_encryption_settings(encryption_settings)
1329 .store_config(store_config)
1330 })
1331 .build()
1332 .instrument(bob_span.clone())
1333 .await;
1334
1335 bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");
1336
1337 matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;
1339
1340 let room_builder = JoinedRoomBuilder::new(room_id)
1342 .add_state_event(StateTestEvent::Create)
1343 .add_state_event(StateTestEvent::Encryption);
1344
1345 matrix_mock_server
1346 .mock_sync()
1347 .ok_and_run(&alice, |builder| {
1348 builder.add_joined_room(room_builder.clone());
1349 })
1350 .instrument(alice_span)
1351 .await;
1352
1353 matrix_mock_server
1354 .mock_sync()
1355 .ok_and_run(&bob, |builder| {
1356 builder.add_joined_room(room_builder);
1357 })
1358 .instrument(bob_span)
1359 .await;
1360
1361 (alice, bob, matrix_mock_server, store)
1362 }
1363
1364 async fn prepare_room(
1365 matrix_mock_server: &MatrixMockServer,
1366 event_factory: &EventFactory,
1367 alice: &Client,
1368 bob: &Client,
1369 room_id: &RoomId,
1370 ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
1371 let alice_user_id = alice.user_id().unwrap();
1372 let bob_user_id = bob.user_id().unwrap();
1373
1374 let alice_member_event = event_factory.member(alice_user_id).into_raw();
1375 let bob_member_event = event_factory.member(bob_user_id).into_raw();
1376
1377 let room = alice
1378 .get_room(room_id)
1379 .expect("Alice should have access to the room now that we synced");
1380
1381 let event_type = "m.room.message";
1386 let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});
1387
1388 let event_id = event_id!("$some_id");
1389 let (event_receiver, mock) =
1390 matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
1391 let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;
1392
1393 {
1394 let _guard = mock.mock_once().mount_as_scoped().await;
1395
1396 matrix_mock_server
1397 .mock_get_members()
1398 .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
1399 .mock_once()
1400 .mount()
1401 .await;
1402
1403 room.send_raw(event_type, content)
1404 .await
1405 .expect("We should be able to send an initial message");
1406 };
1407
1408 let event = event_receiver.await.expect("Alice should have sent the event by now");
1410 let room_key = room_key.await;
1411
1412 (event, room_key)
1413 }
1414
1415 #[async_test]
1416 async fn test_redecryptor() {
1417 let room_id = room_id!("!test:localhost");
1418
1419 let event_factory = EventFactory::new().room(room_id);
1420 let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;
1421
1422 let (event, room_key) =
1423 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1424
1425 let event_cache = bob.event_cache();
1428 let (room_cache, _) = event_cache
1429 .for_room(room_id)
1430 .await
1431 .expect("We should be able to get to the event cache for a specific room");
1432
1433 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1434 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1435
1436 bob.inner
1439 .base_client
1440 .regenerate_olm(None)
1441 .await
1442 .expect("We should be able to regenerate the Olm machine");
1443
1444 matrix_mock_server
1446 .mock_sync()
1447 .ok_and_run(&bob, |builder| {
1448 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1449 })
1450 .await;
1451
1452 assert_let_timeout!(
1455 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1456 subscriber.recv()
1457 );
1458
1459 assert_eq!(diffs.len(), 1);
1462 assert_matches!(&diffs[0], VectorDiff::Append { values });
1463 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1464
1465 assert_let_timeout!(
1466 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1467 );
1468 assert_eq!(expected_room_id, room_id);
1469 assert!(generic_stream.is_empty());
1470
1471 matrix_mock_server
1473 .mock_sync()
1474 .ok_and_run(&bob, |builder| {
1475 builder.add_to_device_event(
1476 room_key
1477 .deserialize_as()
1478 .expect("We should be able to deserialize the room key"),
1479 );
1480 })
1481 .await;
1482
1483 assert_let_timeout!(
1485 Duration::from_secs(1),
1486 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1487 subscriber.recv()
1488 );
1489
1490 assert_eq!(diffs.len(), 1);
1492 assert_matches!(&diffs[0], VectorDiff::Set { index, value });
1493 assert_eq!(*index, 0);
1494 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1495
1496 assert_let_timeout!(
1497 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1498 );
1499 assert_eq!(expected_room_id, room_id);
1500 assert!(generic_stream.is_empty());
1501 }
1502
1503 #[async_test]
1504 async fn test_redecryptor_updating_encryption_info() {
1505 let bob_span = tracing::info_span!("bob");
1506
1507 let room_id = room_id!("!test:localhost");
1508
1509 let event_factory = EventFactory::new().room(room_id);
1510 let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;
1511
1512 let (event, room_key) =
1513 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1514
1515 let event_cache = bob.event_cache();
1518 let (room_cache, _) = event_cache
1519 .for_room(room_id)
1520 .instrument(bob_span.clone())
1521 .await
1522 .expect("We should be able to get to the event cache for a specific room");
1523
1524 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1525 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1526
1527 matrix_mock_server
1529 .mock_sync()
1530 .ok_and_run(&bob, |builder| {
1531 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1532 })
1533 .instrument(bob_span.clone())
1534 .await;
1535
1536 assert_let_timeout!(
1539 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1540 subscriber.recv()
1541 );
1542
1543 assert_eq!(diffs.len(), 1);
1546 assert_matches!(&diffs[0], VectorDiff::Append { values });
1547 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1548
1549 assert_let_timeout!(
1550 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1551 );
1552 assert_eq!(expected_room_id, room_id);
1553 assert!(generic_stream.is_empty());
1554
1555 matrix_mock_server
1557 .mock_sync()
1558 .ok_and_run(&bob, |builder| {
1559 builder.add_to_device_event(
1560 room_key
1561 .deserialize_as()
1562 .expect("We should be able to deserialize the room key"),
1563 );
1564 })
1565 .instrument(bob_span.clone())
1566 .await;
1567
1568 assert_let_timeout!(
1570 Duration::from_secs(1),
1571 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1572 subscriber.recv()
1573 );
1574
1575 assert_eq!(diffs.len(), 1);
1577 assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
1578 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1579
1580 let encryption_info = value.encryption_info().unwrap();
1581 assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));
1582
1583 assert_let_timeout!(
1584 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1585 );
1586 assert_eq!(expected_room_id, room_id);
1587 assert!(generic_stream.is_empty());
1588
1589 let session_id = encryption_info.session_id().unwrap().to_owned();
1590 let alice_user_id = alice.user_id().unwrap();
1591
1592 alice
1594 .encryption()
1595 .bootstrap_cross_signing(None)
1596 .await
1597 .expect("Alice should be able to create the cross-signing keys");
1598
1599 bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
1600 matrix_mock_server
1601 .mock_sync()
1602 .ok_and_run(&bob, |builder| {
1603 builder.add_change_device(alice_user_id);
1604 })
1605 .instrument(bob_span.clone())
1606 .await;
1607
1608 bob.event_cache().request_decryption(DecryptionRetryRequest {
1609 room_id: room_id.into(),
1610 utd_session_ids: BTreeSet::new(),
1611 refresh_info_session_ids: BTreeSet::from([session_id]),
1612 });
1613
1614 assert_let_timeout!(
1617 Duration::from_secs(1),
1618 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1619 subscriber.recv()
1620 );
1621
1622 assert_eq!(diffs.len(), 1);
1623 assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
1624 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1625 let encryption_info = value.encryption_info().unwrap();
1626
1627 assert_matches!(
1628 &encryption_info.verification_state,
1629 VerificationState::Unverified(_),
1630 "The event should now know about the identity but still be unverified"
1631 );
1632
1633 assert_let_timeout!(
1634 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1635 );
1636 assert_eq!(expected_room_id, room_id);
1637 assert!(generic_stream.is_empty());
1638 }
1639
1640 #[async_test]
1641 async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
1642 let room_id = room_id!("!test:localhost");
1643
1644 let event_factory = EventFactory::new().room(room_id);
1645 let (alice, bob, matrix_mock_server, delayed_store) =
1646 set_up_clients(room_id, true, true).await;
1647
1648 let delayed_store = delayed_store.unwrap();
1649
1650 let (event, room_key) =
1651 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1652
1653 let event_cache = bob.event_cache();
1654
1655 let (room_cache, _) = event_cache
1657 .for_room(room_id)
1658 .await
1659 .expect("We should be able to get to the event cache for a specific room");
1660
1661 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1662 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1663
1664 matrix_mock_server
1666 .mock_sync()
1667 .ok_and_run(&bob, |builder| {
1668 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1669 })
1670 .await;
1671
1672 matrix_mock_server
1674 .mock_sync()
1675 .ok_and_run(&bob, |builder| {
1676 builder.add_to_device_event(
1677 room_key
1678 .deserialize_as()
1679 .expect("We should be able to deserialize the room key"),
1680 );
1681 })
1682 .await;
1683
1684 info!("Stopping the delay");
1685 delayed_store.stop_delaying().await;
1686
1687 assert_let_timeout!(
1694 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1695 subscriber.recv()
1696 );
1697
1698 assert_eq!(diffs.len(), 1);
1701 assert_matches!(&diffs[0], VectorDiff::Append { values });
1702 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1703
1704 assert_let_timeout!(
1705 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1706 );
1707 assert_eq!(expected_room_id, room_id);
1708
1709 assert_let_timeout!(
1711 Duration::from_secs(1),
1712 Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
1713 subscriber.recv()
1714 );
1715
1716 assert_eq!(diffs.len(), 1);
1718 assert_matches!(&diffs[0], VectorDiff::Set { index, value });
1719 assert_eq!(*index, 0);
1720 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1721
1722 assert_let_timeout!(
1723 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1724 );
1725 assert_eq!(expected_room_id, room_id);
1726 assert!(generic_stream.is_empty());
1727 }
1728}