1use std::{
116 collections::{BTreeMap, BTreeSet},
117 pin::Pin,
118 sync::Weak,
119};
120
121use as_variant::as_variant;
122use futures_core::Stream;
123use futures_util::{StreamExt, pin_mut};
124#[cfg(doc)]
125use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
126use matrix_sdk_base::{
127 crypto::{
128 store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
129 types::events::room::encrypted::EncryptedEvent,
130 },
131 deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
132 event_cache::store::EventCacheStoreLockState,
133 locks::Mutex,
134 timer,
135};
136#[cfg(doc)]
137use matrix_sdk_common::deserialized_responses::EncryptionInfo;
138use matrix_sdk_common::executor::{AbortOnDrop, JoinHandleExt, spawn};
139use ruma::{
140 OwnedEventId, OwnedRoomId, RoomId,
141 events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
142 push::Action,
143 serde::Raw,
144};
145use tokio::sync::{
146 broadcast::{self, Sender},
147 mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
148};
149use tokio_stream::wrappers::{
150 BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
151};
152use tracing::{info, instrument, trace, warn};
153
154#[cfg(doc)]
155use super::RoomEventCache;
156use super::{EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate};
157use crate::{
158 Client, Result, Room,
159 encryption::backups::BackupState,
160 event_cache::{RoomEventCacheGenericUpdate, RoomEventCacheLinkedChunkUpdate},
161 room::PushContext,
162};
163
/// A borrowed room key session ID (as found in e.g. [`RoomKeyInfo::session_id`]).
type SessionId<'a> = &'a str;
/// An owned room key session ID.
type OwnedSessionId = String;

/// An event ID paired with the raw event that could not be decrypted (a UTD).
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// An event ID paired with its already-decrypted form.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// An event ID, its new decrypted form, and optionally the push actions that
/// should replace the event's current ones.
type ResolvedUtd = (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
170
/// A request to retry decrypting, or to refresh the encryption info of, some
/// events of a single room.
///
/// Submitted through [`EventCache::request_decryption`] and processed
/// asynchronously by the redecryptor task.
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
    /// The room whose events should be processed.
    pub room_id: OwnedRoomId,
    /// Session IDs whose unable-to-decrypt events should be decrypted again.
    pub utd_session_ids: BTreeSet<OwnedSessionId>,
    /// Session IDs whose already-decrypted events should have their encryption
    /// info refreshed.
    pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
183
/// A report broadcast by the redecryptor to the subscribers of
/// [`EventCache::subscribe_to_decryption_reports`].
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
    /// Events of a room were successfully redecrypted, or had their encryption
    /// info updated.
    ResolvedUtds {
        /// The room the events belong to.
        room_id: OwnedRoomId,
        /// The IDs of the events that were redecrypted or updated.
        events: BTreeSet<OwnedEventId>,
    },
    /// The redecryptor may have missed some updates: one of its input streams
    /// lagged, or its room key streams had to be recreated. Events held in memory
    /// have been retried before this report is sent.
    Lagging,
    /// The backup state switched to enabled. Events held in memory have been
    /// retried before this report is sent.
    BackupAvailable,
}
203
/// The channels linking the [`EventCache`] with its redecryptor task.
pub(super) struct RedecryptorChannels {
    /// Broadcasts [`RedecryptorReport`]s to every subscriber.
    utd_reporter: Sender<RedecryptorReport>,
    /// Sending half used to enqueue [`DecryptionRetryRequest`]s.
    pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
    /// Receiving half of the request channel, kept in an `Option` so it can be
    /// taken out once — presumably by whoever spawns the redecryptor (TODO confirm).
    pub(super) decryption_request_receiver:
        Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
210
211impl RedecryptorChannels {
212 pub(super) fn new() -> Self {
213 let (utd_reporter, _) = broadcast::channel(100);
214 let (decryption_request_sender, decryption_request_receiver) = unbounded_channel();
215
216 Self {
217 utd_reporter,
218 decryption_request_sender,
219 decryption_request_receiver: Mutex::new(Some(decryption_request_receiver)),
220 }
221 }
222}
223
224fn filter_timeline_event_to_utd(
229 event: TimelineEvent,
230) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
231 let event_id = event.event_id();
232
233 let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
236 event_id.zip(event)
239}
240
241fn filter_timeline_event_to_decrypted(
247 event: TimelineEvent,
248) -> Option<(OwnedEventId, DecryptedRoomEvent)> {
249 let event_id = event.event_id();
250
251 let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
252 event_id.zip(event)
255}
256
257impl EventCache {
258 async fn get_utds(
266 &self,
267 room_id: &RoomId,
268 session_id: SessionId<'_>,
269 ) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
270 let events = match self.inner.store.lock().await? {
271 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
276 guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
277 }
278 };
279
280 Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
281 }
282
283 async fn get_utds_from_memory(&self) -> BTreeMap<OwnedRoomId, Vec<EventIdAndUtd>> {
286 let mut utds = BTreeMap::new();
287
288 for (room_id, room_cache) in self.inner.by_room.read().await.iter() {
289 let room_utds: Vec<_> = room_cache
290 .events()
291 .await
292 .into_iter()
293 .flatten()
294 .filter_map(filter_timeline_event_to_utd)
295 .collect();
296
297 utds.insert(room_id.to_owned(), room_utds);
298 }
299
300 utds
301 }
302
303 async fn get_decrypted_events(
304 &self,
305 room_id: &RoomId,
306 session_id: SessionId<'_>,
307 ) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
308 let events = match self.inner.store.lock().await? {
309 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
314 guard.get_room_events(room_id, None, Some(session_id)).await?
315 }
316 };
317
318 Ok(events.into_iter().filter_map(filter_timeline_event_to_decrypted).collect())
319 }
320
321 async fn get_decrypted_events_from_memory(
322 &self,
323 ) -> BTreeMap<OwnedRoomId, Vec<EventIdAndEvent>> {
324 let mut decrypted_events = BTreeMap::new();
325
326 for (room_id, room_cache) in self.inner.by_room.read().await.iter() {
327 let room_utds: Vec<_> = room_cache
328 .events()
329 .await
330 .into_iter()
331 .flatten()
332 .filter_map(filter_timeline_event_to_decrypted)
333 .collect();
334
335 decrypted_events.insert(room_id.to_owned(), room_utds);
336 }
337
338 decrypted_events
339 }
340
341 #[instrument(skip_all, fields(room_id))]
353 async fn on_resolved_utds(
354 &self,
355 room_id: &RoomId,
356 events: Vec<ResolvedUtd>,
357 ) -> Result<(), EventCacheError> {
358 if events.is_empty() {
359 trace!("No events were redecrypted or updated, nothing to replace");
360 return Ok(());
361 }
362
363 timer!("Resolving UTDs");
364
365 let (room_cache, _drop_handles) = self.for_room(room_id).await?;
368 let mut state = room_cache.inner.state.write().await?;
369
370 let event_ids: BTreeSet<_> =
371 events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
372 let mut new_events = Vec::with_capacity(events.len());
373
374 for (event_id, decrypted, actions) in events {
375 if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
379 target_event.kind = TimelineEventKind::Decrypted(decrypted);
380
381 if let Some(actions) = actions {
382 target_event.set_push_actions(actions);
383 }
384
385 state.replace_event_at(location, target_event.clone()).await?;
388 new_events.push(target_event);
389 }
390 }
391
392 state.post_process_new_events(new_events, false).await?;
393
394 let diffs = state.room_linked_chunk().updates_as_vector_diffs();
397
398 let _ = room_cache.inner.update_sender.send(RoomEventCacheUpdate::UpdateTimelineEvents {
399 diffs,
400 origin: EventsOrigin::Cache,
401 });
402
403 let _ = room_cache
404 .inner
405 .generic_update_sender
406 .send(RoomEventCacheGenericUpdate { room_id: room_id.to_owned() });
407
408 let report =
413 RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
414 let _ = self.inner.redecryption_channels.utd_reporter.send(report);
415
416 Ok(())
417 }
418
419 async fn decrypt_event(
421 &self,
422 room_id: &RoomId,
423 room: Option<&Room>,
424 push_context: Option<&PushContext>,
425 event: &Raw<EncryptedEvent>,
426 ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
427 if let Some(room) = room {
428 match room
429 .decrypt_event(
430 event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
431 push_context,
432 )
433 .await
434 {
435 Ok(maybe_decrypted) => {
436 let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
437
438 if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
439 Some((decrypted, actions))
440 } else {
441 warn!(
442 "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
443 );
444 None
445 }
446 }
447 Err(e) => {
448 warn!(
449 "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
450 );
451 None
452 }
453 }
454 } else {
455 let client = self.inner.client().ok()?;
456 let machine = client.olm_machine().await;
457 let machine = machine.as_ref()?;
458
459 match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
460 Ok(decrypted) => Some((decrypted, None)),
461 Err(e) => {
462 warn!(
463 "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
464 );
465 None
466 }
467 }
468 }
469 }
470
471 #[instrument(skip_all, fields(room_id, session_id))]
474 async fn retry_decryption(
475 &self,
476 room_id: &RoomId,
477 session_id: SessionId<'_>,
478 ) -> Result<(), EventCacheError> {
479 let events = self.get_utds(room_id, session_id).await?;
481 self.retry_decryption_for_events(room_id, events).await
482 }
483
484 #[instrument(skip_all, fields(updates.linked_chunk_id))]
486 async fn retry_decryption_for_event_cache_updates(
487 &self,
488 updates: RoomEventCacheLinkedChunkUpdate,
489 ) -> Result<(), EventCacheError> {
490 let room_id = updates.linked_chunk_id.room_id();
491 let events: Vec<_> = updates
492 .updates
493 .into_iter()
494 .flat_map(|updates| updates.into_items())
495 .filter_map(filter_timeline_event_to_utd)
496 .collect();
497
498 self.retry_decryption_for_events(room_id, events).await
499 }
500
501 async fn retry_decryption_for_in_memory_events(&self) {
502 let utds = self.get_utds_from_memory().await;
503
504 for (room_id, utds) in utds.into_iter() {
505 if let Err(e) = self.retry_decryption_for_events(&room_id, utds).await {
506 warn!(%room_id, "Failed to redecrypt in-memory events {e:?}");
507 }
508 }
509 }
510
511 #[instrument(skip_all, fields(room_id, session_id))]
513 async fn retry_decryption_for_events(
514 &self,
515 room_id: &RoomId,
516 events: Vec<EventIdAndUtd>,
517 ) -> Result<(), EventCacheError> {
518 trace!("Retrying to decrypt");
519
520 if events.is_empty() {
521 trace!("No relevant events found.");
522 return Ok(());
523 }
524
525 let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
526 let push_context =
527 if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
528
529 let mut decrypted_events = Vec::with_capacity(events.len());
531
532 for (event_id, event) in events {
533 if let Some((decrypted, actions)) = self
536 .decrypt_event(
537 room_id,
538 room.as_ref(),
539 push_context.as_ref(),
540 event.cast_ref_unchecked(),
541 )
542 .await
543 {
544 decrypted_events.push((event_id, decrypted, actions));
545 }
546 }
547
548 let event_ids: BTreeSet<_> =
549 decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
550
551 if !event_ids.is_empty() {
552 trace!(?event_ids, "Successfully redecrypted events");
553 }
554
555 self.on_resolved_utds(room_id, decrypted_events).await?;
558
559 Ok(())
560 }
561
562 async fn update_encryption_info_for_events(
564 &self,
565 room: &Room,
566 events: Vec<EventIdAndEvent>,
567 ) -> Result<(), EventCacheError> {
568 let mut updated_events = Vec::with_capacity(events.len());
570
571 for (event_id, mut event) in events {
572 if let Some(session_id) = event.encryption_info.session_id() {
573 let new_encryption_info =
574 room.get_encryption_info(session_id, &event.encryption_info.sender).await;
575
576 if let Some(new_encryption_info) = new_encryption_info
578 && event.encryption_info != new_encryption_info
579 {
580 event.encryption_info = new_encryption_info;
581 updated_events.push((event_id, event, None));
582 }
583 }
584 }
585
586 let event_ids: BTreeSet<_> =
587 updated_events.iter().map(|(event_id, _, _)| event_id).collect();
588
589 if !event_ids.is_empty() {
590 trace!(?event_ids, "Replacing the encryption info of some events");
591 }
592
593 self.on_resolved_utds(room.room_id(), updated_events).await
594 }
595
596 #[instrument(skip_all, fields(room_id, session_id))]
597 async fn update_encryption_info(
598 &self,
599 room_id: &RoomId,
600 session_id: SessionId<'_>,
601 ) -> Result<(), EventCacheError> {
602 trace!("Updating encryption info");
603
604 let Ok(client) = self.inner.client() else {
605 return Ok(());
606 };
607
608 let Some(room) = client.get_room(room_id) else {
609 return Ok(());
610 };
611
612 let events = self.get_decrypted_events(room_id, session_id).await?;
614
615 if events.is_empty() {
616 trace!("No relevant events found.");
617 return Ok(());
618 }
619
620 self.update_encryption_info_for_events(&room, events).await
622 }
623
624 async fn retry_update_encryption_info_for_in_memory_events(&self) {
625 let decrypted_events = self.get_decrypted_events_from_memory().await;
626
627 for (room_id, events) in decrypted_events.into_iter() {
628 let Some(room) = self.inner.client().ok().and_then(|c| c.get_room(&room_id)) else {
629 continue;
630 };
631
632 if let Err(e) = self.update_encryption_info_for_events(&room, events).await {
633 warn!(
634 %room_id,
635 "Failed to replace the encryption info for in-memory events {e:?}"
636 );
637 }
638 }
639 }
640
641 async fn retry_in_memory_events(&self) {
652 self.retry_decryption_for_in_memory_events().await;
653 self.retry_update_encryption_info_for_in_memory_events().await;
654 }
655
656 pub fn request_decryption(&self, request: DecryptionRetryRequest) {
697 let _ =
698 self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
699 |_| warn!("Requesting a decryption while the redecryption task has been shut down"),
700 );
701 }
702
703 pub fn subscribe_to_decryption_reports(
754 &self,
755 ) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
756 BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
757 }
758}
759
760#[inline(always)]
761fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
762 cache.upgrade().map(|inner| EventCache { inner })
763}
764
765async fn send_report_and_retry_memory_events(
766 cache: &Weak<EventCacheInner>,
767 report: RedecryptorReport,
768) -> Result<(), ()> {
769 let Some(cache) = upgrade_event_cache(cache) else {
770 return Err(());
771 };
772
773 cache.retry_in_memory_events().await;
774 let _ = cache.inner.redecryption_channels.utd_reporter.send(report);
775
776 Ok(())
777}
778
/// Handle owning the background task that retries decrypting events of the event
/// cache.
///
/// Dropping this handle aborts the task.
pub(crate) struct Redecryptor {
    /// The spawned task, aborted when dropped.
    _task: AbortOnDrop<()>,
}
788
impl Redecryptor {
    /// Spawns the redecryptor task and returns the handle owning it.
    ///
    /// The task listens to:
    /// - explicit [`DecryptionRetryRequest`]s arriving on `receiver`,
    /// - linked chunk updates broadcast through `linked_chunk_update_sender`,
    /// - the backup state stream of `client`,
    /// - room keys and withheld notices received by the client's Olm machine.
    ///
    /// Only a [`Weak`] reference to the event cache is held by the task, so it does
    /// not keep the cache alive.
    pub(super) fn new(
        client: &Client,
        cache: Weak<EventCacheInner>,
        receiver: UnboundedReceiver<DecryptionRetryRequest>,
        linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
    ) -> Self {
        // `client` and the sender are only borrowed, so subscribe here rather than
        // inside the task. Subscribing now also means updates sent before the task
        // first runs are buffered for it.
        let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());
        let backup_state_stream = client.encryption().backups().state_stream();

        let task = spawn(async {
            let request_redecryption_stream = UnboundedReceiverStream::new(receiver);

            Self::listen_for_room_keys_task(
                cache,
                request_redecryption_stream,
                linked_chunk_stream,
                backup_state_stream,
            )
            .await;
        })
        .abort_on_drop();

        Self { _task: task }
    }

    /// Subscribes to the "room keys received" and "room keys withheld" streams of
    /// the current Olm machine.
    ///
    /// Returns `None` if the event cache was dropped, the client is gone, or there
    /// is currently no Olm machine.
    async fn subscribe_to_room_key_stream(
        cache: &Weak<EventCacheInner>,
    ) -> Option<(
        impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
        impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
    )> {
        let event_cache = cache.upgrade()?;
        let client = event_cache.client().ok()?;
        let machine = client.olm_machine().await;

        machine.as_ref().map(|m| {
            (m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
        })
    }

    /// Runs one generation of the redecryption loop, using freshly subscribed room
    /// key streams.
    ///
    /// Returns `true` when the room key or withheld stream ended, meaning the streams
    /// should be recreated and the loop restarted. Returns `false` when the
    /// redecryptor should stop for good, e.g. because the event cache was dropped.
    ///
    /// NOTE(review): `false` is also returned when the room key streams cannot be
    /// obtained (for instance if no Olm machine exists yet), which shuts the whole
    /// redecryptor down permanently; confirm this can't happen for callers that spawn
    /// it before the Olm machine is available.
    async fn redecryption_loop(
        cache: &Weak<EventCacheInner>,
        decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
        events_stream: &mut Pin<
            &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
        >,
        backup_state_stream: &mut Pin<
            &mut impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
        >,
    ) -> bool {
        let Some((room_key_stream, withheld_stream)) =
            Self::subscribe_to_room_key_stream(cache).await
        else {
            return false;
        };

        pin_mut!(room_key_stream);
        pin_mut!(withheld_stream);

        loop {
            tokio::select! {
                // An explicit request: retry the UTDs of the listed sessions, then
                // refresh the encryption info for the other listed sessions.
                Some(request) = decryption_request_stream.next() => {
                    let Some(cache) = upgrade_event_cache(cache) else {
                        break false;
                    };

                    trace!(?request, "Received a redecryption request");

                    for session_id in request.utd_session_ids {
                        let _ = cache
                            .retry_decryption(&request.room_id, &session_id)
                            .await
                            .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
                    }

                    for session_id in request.refresh_info_session_ids {
                        let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
                            warn!(
                                room_id = %request.room_id,
                                session_id = session_id,
                                "Unable to update the encryption info {e:?}",
                        ));
                    }
                }
                room_keys = room_key_stream.next() => {
                    match room_keys {
                        // New room keys: retry the UTDs of each key's session, then
                        // refresh the encryption info of events of those sessions.
                        Some(Ok(room_keys)) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?room_keys, "Received new room keys");

                            for key in &room_keys {
                                let _ = cache
                                    .retry_decryption(&key.room_id, &key.session_id)
                                    .await
                                    .inspect_err(|e| warn!("Error redecrypting {e:?}"));
                            }

                            for key in room_keys {
                                let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %key.room_id,
                                        session_id = key.session_id,
                                        "Unable to update the encryption info {e:?}",
                                ));
                            }
                        },
                        // Some room key notifications were missed: retry everything
                        // held in memory and tell listeners about the lag.
                        Some(Err(_)) => {
                            warn!("The room key stream lagged, reporting the lag to our listeners");

                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        },
                        // The room key stream ended (presumably the Olm machine was
                        // replaced): ask the caller to recreate the streams.
                        None => {
                            break true
                        }
                    }
                }
                withheld_info = withheld_stream.next() => {
                    match withheld_info {
                        // Keys were withheld: only refresh the encryption info of the
                        // events of the affected sessions.
                        Some(infos) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?infos, "Received new withheld infos");

                            for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
                                let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %room_id,
                                        session_id = session_id,
                                        "Unable to update the encryption info {e:?}",
                                ));
                            }
                        }
                        // The withheld stream ended: ask the caller to recreate the
                        // streams.
                        None => break true
                    }
                }
                Some(event_updates) = events_stream.next() => {
                    match event_updates {
                        // The event cache's linked chunks changed: try to decrypt any
                        // UTDs contained in the update right away.
                        Ok(updates) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            let linked_chunk_id = updates.linked_chunk_id.to_owned();

                            let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
                                warn!(
                                    %linked_chunk_id,
                                    "Unable to handle UTDs from event cache updates {e:?}",
                                )
                            );
                        }
                        // Some linked chunk updates were missed: retry everything held
                        // in memory and tell listeners about the lag.
                        Err(_) => {
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        }
                    }
                }
                Some(backup_state_update) = backup_state_stream.next() => {
                    match backup_state_update {
                        Ok(state) => {
                            match state {
                                // Transitional or inactive states: nothing to do.
                                BackupState::Unknown |
                                BackupState::Creating |
                                BackupState::Enabling |
                                BackupState::Resuming |
                                BackupState::Downloading |
                                BackupState::Disabling =>{
                                }
                                // The backup became enabled, so keys may presumably now be
                                // obtainable from it: retry in-memory events and report it.
                                BackupState::Enabled => {
                                    if send_report_and_retry_memory_events(cache, RedecryptorReport::BackupAvailable).await.is_err() {
                                        break false;
                                    }
                                }
                            }
                        }
                        // Some backup state updates were missed.
                        Err(_) => {
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        }
                    }
                }
                // Only reached if every branch above is disabled.
                else => break false,
            }
        }
    }

    /// The body of the redecryptor task.
    ///
    /// Runs [`Self::redecryption_loop`] repeatedly. Whenever the loop asks for its
    /// room key streams to be regenerated, the in-memory events are retried and a
    /// [`RedecryptorReport::Lagging`] report is broadcast — presumably because room
    /// keys could have been missed in between. Stops as soon as the loop returns
    /// `false` or the event cache is gone.
    async fn listen_for_room_keys_task(
        cache: Weak<EventCacheInner>,
        decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
        events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
        backup_state_stream: impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
    ) {
        // These streams outlive every loop generation, so pin them once here.
        pin_mut!(decryption_request_stream);
        pin_mut!(events_stream);
        pin_mut!(backup_state_stream);

        while Self::redecryption_loop(
            &cache,
            &mut decryption_request_stream,
            &mut events_stream,
            &mut backup_state_stream,
        )
        .await
        {
            info!("Regenerating the re-decryption streams");

            if send_report_and_retry_memory_events(&cache, RedecryptorReport::Lagging)
                .await
                .is_err()
            {
                break;
            }
        }

        info!("Shutting down the event cache redecryptor");
    }
}
1055
1056#[cfg(not(target_family = "wasm"))]
1057#[cfg(test)]
1058mod tests {
1059 use std::{
1060 collections::BTreeSet,
1061 sync::{
1062 Arc,
1063 atomic::{AtomicBool, Ordering},
1064 },
1065 time::Duration,
1066 };
1067
1068 use assert_matches2::assert_matches;
1069 use async_trait::async_trait;
1070 use eyeball_im::VectorDiff;
1071 use matrix_sdk_base::{
1072 cross_process_lock::CrossProcessLockGeneration,
1073 crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
1074 deserialized_responses::{TimelineEventKind, VerificationState},
1075 event_cache::{
1076 Event, Gap,
1077 store::{EventCacheStore, EventCacheStoreError, MemoryStore},
1078 },
1079 linked_chunk::{
1080 ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
1081 RawChunk, Update,
1082 },
1083 locks::Mutex,
1084 sleep::sleep,
1085 store::StoreConfig,
1086 };
1087 use matrix_sdk_test::{
1088 JoinedRoomBuilder, StateTestEvent, async_test, event_factory::EventFactory,
1089 };
1090 use ruma::{
1091 EventId, OwnedEventId, RoomId, device_id, event_id,
1092 events::{AnySyncTimelineEvent, relation::RelationType},
1093 room_id,
1094 serde::Raw,
1095 user_id,
1096 };
1097 use serde_json::json;
1098 use tokio::sync::oneshot::{self, Sender};
1099 use tracing::{Instrument, info};
1100
1101 use crate::{
1102 Client, assert_let_timeout,
1103 encryption::EncryptionSettings,
1104 event_cache::{DecryptionRetryRequest, RoomEventCacheGenericUpdate, RoomEventCacheUpdate},
1105 test_utils::mocks::MatrixMockServer,
1106 };
1107
1108 #[derive(Debug, Clone)]
1113 struct DelayingStore {
1114 memory_store: MemoryStore,
1115 delaying: Arc<AtomicBool>,
1116 foo: Arc<Mutex<Option<Sender<()>>>>,
1117 }
1118
1119 impl DelayingStore {
1120 fn new() -> Self {
1121 Self {
1122 memory_store: MemoryStore::new(),
1123 delaying: AtomicBool::new(true).into(),
1124 foo: Arc::new(Mutex::new(None)),
1125 }
1126 }
1127
1128 async fn stop_delaying(&self) {
1129 let (sender, receiver) = oneshot::channel();
1130
1131 {
1132 *self.foo.lock() = Some(sender);
1133 }
1134
1135 self.delaying.store(false, Ordering::SeqCst);
1136
1137 receiver.await.expect("We should be able to receive a response")
1138 }
1139 }
1140
1141 #[cfg_attr(target_family = "wasm", async_trait(?Send))]
1142 #[cfg_attr(not(target_family = "wasm"), async_trait)]
1143 impl EventCacheStore for DelayingStore {
1144 type Error = EventCacheStoreError;
1145
1146 async fn try_take_leased_lock(
1147 &self,
1148 lease_duration_ms: u32,
1149 key: &str,
1150 holder: &str,
1151 ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
1152 self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
1153 }
1154
1155 async fn handle_linked_chunk_updates(
1156 &self,
1157 linked_chunk_id: LinkedChunkId<'_>,
1158 updates: Vec<Update<Event, Gap>>,
1159 ) -> Result<(), Self::Error> {
1160 while self.delaying.load(Ordering::SeqCst) {
1166 sleep(Duration::from_millis(10)).await;
1167 }
1168
1169 let sender = self.foo.lock().take();
1170 let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;
1171
1172 if let Some(sender) = sender {
1173 sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
1174 }
1175
1176 ret
1177 }
1178
1179 async fn load_all_chunks(
1180 &self,
1181 linked_chunk_id: LinkedChunkId<'_>,
1182 ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
1183 self.memory_store.load_all_chunks(linked_chunk_id).await
1184 }
1185
1186 async fn load_all_chunks_metadata(
1187 &self,
1188 linked_chunk_id: LinkedChunkId<'_>,
1189 ) -> Result<Vec<ChunkMetadata>, Self::Error> {
1190 self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
1191 }
1192
1193 async fn load_last_chunk(
1194 &self,
1195 linked_chunk_id: LinkedChunkId<'_>,
1196 ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
1197 self.memory_store.load_last_chunk(linked_chunk_id).await
1198 }
1199
1200 async fn load_previous_chunk(
1201 &self,
1202 linked_chunk_id: LinkedChunkId<'_>,
1203 before_chunk_identifier: ChunkIdentifier,
1204 ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
1205 self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
1206 }
1207
1208 async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
1209 self.memory_store.clear_all_linked_chunks().await
1210 }
1211
1212 async fn filter_duplicated_events(
1213 &self,
1214 linked_chunk_id: LinkedChunkId<'_>,
1215 events: Vec<OwnedEventId>,
1216 ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
1217 self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
1218 }
1219
1220 async fn find_event(
1221 &self,
1222 room_id: &RoomId,
1223 event_id: &EventId,
1224 ) -> Result<Option<Event>, Self::Error> {
1225 self.memory_store.find_event(room_id, event_id).await
1226 }
1227
1228 async fn find_event_relations(
1229 &self,
1230 room_id: &RoomId,
1231 event_id: &EventId,
1232 filters: Option<&[RelationType]>,
1233 ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
1234 self.memory_store.find_event_relations(room_id, event_id, filters).await
1235 }
1236
1237 async fn get_room_events(
1238 &self,
1239 room_id: &RoomId,
1240 event_type: Option<&str>,
1241 session_id: Option<&str>,
1242 ) -> Result<Vec<Event>, Self::Error> {
1243 self.memory_store.get_room_events(room_id, event_type, session_id).await
1244 }
1245
1246 async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
1247 self.memory_store.save_event(room_id, event).await
1248 }
1249
1250 async fn optimize(&self) -> Result<(), Self::Error> {
1251 self.memory_store.optimize().await
1252 }
1253
1254 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1255 self.memory_store.get_size().await
1256 }
1257 }
1258
1259 async fn set_up_clients(
1260 room_id: &RoomId,
1261 alice_enables_cross_signing: bool,
1262 use_delayed_store: bool,
1263 ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
1264 let alice_span = tracing::info_span!("alice");
1265 let bob_span = tracing::info_span!("bob");
1266
1267 let alice_user_id = user_id!("@alice:localhost");
1268 let alice_device_id = device_id!("ALICEDEVICE");
1269 let bob_user_id = user_id!("@bob:localhost");
1270 let bob_device_id = device_id!("BOBDEVICE");
1271
1272 let matrix_mock_server = MatrixMockServer::new().await;
1273 matrix_mock_server.mock_crypto_endpoints_preset().await;
1274
1275 let encryption_settings = EncryptionSettings {
1276 auto_enable_cross_signing: alice_enables_cross_signing,
1277 ..Default::default()
1278 };
1279
1280 let alice = matrix_mock_server
1283 .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
1284 .on_builder(|builder| {
1285 builder
1286 .with_enable_share_history_on_invite(true)
1287 .with_encryption_settings(encryption_settings)
1288 })
1289 .build()
1290 .instrument(alice_span.clone())
1291 .await;
1292
1293 let encryption_settings =
1294 EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };
1295
1296 let (store_config, store) = if use_delayed_store {
1297 let store = DelayingStore::new();
1298
1299 (
1300 StoreConfig::new("delayed_store_event_cache_test".into())
1301 .event_cache_store(store.clone()),
1302 Some(store),
1303 )
1304 } else {
1305 (StoreConfig::new("normal_store_event_cache_test".into()), None)
1306 };
1307
1308 let bob = matrix_mock_server
1309 .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
1310 .on_builder(|builder| {
1311 builder
1312 .with_enable_share_history_on_invite(true)
1313 .with_encryption_settings(encryption_settings)
1314 .store_config(store_config)
1315 })
1316 .build()
1317 .instrument(bob_span.clone())
1318 .await;
1319
1320 bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");
1321
1322 matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;
1324
1325 let room_builder = JoinedRoomBuilder::new(room_id)
1327 .add_state_event(StateTestEvent::Create)
1328 .add_state_event(StateTestEvent::Encryption);
1329
1330 matrix_mock_server
1331 .mock_sync()
1332 .ok_and_run(&alice, |builder| {
1333 builder.add_joined_room(room_builder.clone());
1334 })
1335 .instrument(alice_span)
1336 .await;
1337
1338 matrix_mock_server
1339 .mock_sync()
1340 .ok_and_run(&bob, |builder| {
1341 builder.add_joined_room(room_builder);
1342 })
1343 .instrument(bob_span)
1344 .await;
1345
1346 (alice, bob, matrix_mock_server, store)
1347 }
1348
1349 async fn prepare_room(
1350 matrix_mock_server: &MatrixMockServer,
1351 event_factory: &EventFactory,
1352 alice: &Client,
1353 bob: &Client,
1354 room_id: &RoomId,
1355 ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
1356 let alice_user_id = alice.user_id().unwrap();
1357 let bob_user_id = bob.user_id().unwrap();
1358
1359 let alice_member_event = event_factory.member(alice_user_id).into_raw();
1360 let bob_member_event = event_factory.member(bob_user_id).into_raw();
1361
1362 let room = alice
1363 .get_room(room_id)
1364 .expect("Alice should have access to the room now that we synced");
1365
1366 let event_type = "m.room.message";
1371 let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});
1372
1373 let event_id = event_id!("$some_id");
1374 let (event_receiver, mock) =
1375 matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
1376 let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;
1377
1378 {
1379 let _guard = mock.mock_once().mount_as_scoped().await;
1380
1381 matrix_mock_server
1382 .mock_get_members()
1383 .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
1384 .mock_once()
1385 .mount()
1386 .await;
1387
1388 room.send_raw(event_type, content)
1389 .await
1390 .expect("We should be able to send an initial message");
1391 };
1392
1393 let event = event_receiver.await.expect("Alice should have sent the event by now");
1395 let room_key = room_key.await;
1396
1397 (event, room_key)
1398 }
1399
1400 #[async_test]
1401 async fn test_redecryptor() {
1402 let room_id = room_id!("!test:localhost");
1403
1404 let event_factory = EventFactory::new().room(room_id);
1405 let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;
1406
1407 let (event, room_key) =
1408 prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
1409
1410 let event_cache = bob.event_cache();
1413 let (room_cache, _) = event_cache
1414 .for_room(room_id)
1415 .await
1416 .expect("We should be able to get to the event cache for a specific room");
1417
1418 let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
1419 let mut generic_stream = event_cache.subscribe_to_room_generic_updates();
1420
1421 bob.inner
1424 .base_client
1425 .regenerate_olm(None)
1426 .await
1427 .expect("We should be able to regenerate the Olm machine");
1428
1429 matrix_mock_server
1431 .mock_sync()
1432 .ok_and_run(&bob, |builder| {
1433 builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
1434 })
1435 .await;
1436
1437 assert_let_timeout!(
1440 Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
1441 );
1442
1443 assert_eq!(diffs.len(), 1);
1446 assert_matches!(&diffs[0], VectorDiff::Append { values });
1447 assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
1448
1449 assert_let_timeout!(
1450 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1451 );
1452 assert_eq!(expected_room_id, room_id);
1453 assert!(generic_stream.is_empty());
1454
1455 matrix_mock_server
1457 .mock_sync()
1458 .ok_and_run(&bob, |builder| {
1459 builder.add_to_device_event(
1460 room_key
1461 .deserialize_as()
1462 .expect("We should be able to deserialize the room key"),
1463 );
1464 })
1465 .await;
1466
1467 assert_let_timeout!(
1469 Duration::from_secs(1),
1470 Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
1471 );
1472
1473 assert_eq!(diffs.len(), 1);
1475 assert_matches!(&diffs[0], VectorDiff::Set { index, value });
1476 assert_eq!(*index, 0);
1477 assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
1478
1479 assert_let_timeout!(
1480 Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
1481 );
1482 assert_eq!(expected_room_id, room_id);
1483 assert!(generic_stream.is_empty());
1484 }
1485
    /// Checks two things:
    /// - the redecryptor turns a UTD into a decrypted event once the room key arrives;
    /// - a [`DecryptionRetryRequest`] that only lists the event's session in
    ///   `refresh_info_session_ids` re-emits the already-decrypted event in place, carrying
    ///   freshly computed encryption info.
    #[async_test]
    async fn test_redecryptor_updating_encryption_info() {
        // Span attached to Bob's futures so his log lines can be told apart from Alice's.
        let bob_span = tracing::info_span!("bob");

        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        // NOTE(review): unlike `test_redecryptor`, the first flag is `false` here; its meaning
        // is defined by `set_up_clients` — confirm. The fourth value (optional delayed store)
        // is unused.
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;

        // `event` is Alice's encrypted message; `room_key` is the captured to-device payload,
        // which Bob only receives when the test injects it via a mocked sync below.
        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let event_cache = bob.event_cache();
        let (room_cache, _) = event_cache
            .for_room(room_id)
            .instrument(bob_span.clone())
            .await
            .expect("We should be able to get to the event cache for a specific room");

        // Subscribe to per-room updates and to generic "room changed" updates before syncing.
        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
        let mut generic_stream = event_cache.subscribe_to_room_generic_updates();

        // Sync the encrypted event into the room while Bob has no key for it yet.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .instrument(bob_span.clone())
            .await;

        // The event is appended to the cache as a UTD.
        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        // Exactly one generic update for this room.
        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());

        // Deliver the captured room key to Bob as a to-device event.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .instrument(bob_span.clone())
            .await;

        // The redecryptor replaces the UTD at index 0 with the decrypted event.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        // At this point the decrypted event's sender is reported as unverified.
        let encryption_info = value.encryption_info().unwrap();
        assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());

        // Keep the event's session ID so we can later ask for its encryption info to be
        // refreshed.
        let session_id = encryption_info.session_id().unwrap().to_owned();
        let alice_user_id = alice.user_id().unwrap();

        // Give Alice cross-signing keys, i.e. a new identity Bob does not know about yet.
        alice
            .encryption()
            .bootstrap_cross_signing(None)
            .await
            .expect("Alice should be able to create the cross-signing keys");

        // Make Bob track Alice and tell him via sync that her devices changed, presumably so he
        // fetches her new identity before the refresh below.
        bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_change_device(alice_user_id);
            })
            .instrument(bob_span.clone())
            .await;

        // Request an encryption-info refresh only (no UTD sessions to retry) for that session.
        bob.event_cache().request_decryption(DecryptionRetryRequest {
            room_id: room_id.into(),
            utd_session_ids: BTreeSet::new(),
            refresh_info_session_ids: BTreeSet::from([session_id]),
        });

        // The already-decrypted event is re-emitted in place at index 0.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
        let encryption_info = value.encryption_info().unwrap();

        // NOTE(review): this matches `Unverified(_)` with a wildcard, exactly like the check
        // before the refresh, so it cannot tell refreshed info from stale info. Only the
        // presence of the `Set` update shows that a refresh happened. Consider asserting the
        // specific verification level before and after.
        assert_matches!(
            &encryption_info.verification_state,
            VerificationState::Unverified(_),
            "The event should now know about the identity but still be unverified"
        );

        // The refresh is announced exactly once on the generic stream as well.
        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());
    }
1619
    /// Checks that an encrypted event still ends up redecrypted when its room key arrives while
    /// the event cache is still processing that event (the race covered by the test's name).
    #[async_test]
    async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        // Unlike the other tests, the last flag is `true` and the fourth value is kept.
        // NOTE(review): presumably this makes `set_up_clients` return `Some` delayed store that
        // holds back store operations until `stop_delaying` is called — confirm.
        let (alice, bob, matrix_mock_server, delayed_store) =
            set_up_clients(room_id, true, true).await;

        // Panics if `set_up_clients` did not hand back a delayed store.
        let delayed_store = delayed_store.unwrap();

        // `event` is Alice's encrypted message; `room_key` is the captured to-device payload,
        // which Bob only receives when the test injects it via a mocked sync below.
        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let event_cache = bob.event_cache();

        let (room_cache, _) = event_cache
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        // Subscribe before syncing so both updates below are observed.
        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
        let mut generic_stream = event_cache.subscribe_to_room_generic_updates();

        // Sync the encrypted event; with the delayed store, its processing is (presumably)
        // still pending after this returns.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        // Deliver the room key before the delay is lifted, i.e. while the event is in flight.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        // Release the delayed store so the pending work can complete.
        info!("Stopping the delay");
        delayed_store.stop_delaying().await;

        // The event is first appended as a UTD...
        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        // No `is_empty` check after this generic update.
        // NOTE(review): presumably because the redecryption's generic update may already be
        // queued behind it — confirm.
        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);

        // ...and then replaced in place at index 0 by its decrypted form, despite the key having
        // arrived before the event was fully processed.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        // The redecryption yields exactly one more generic update, and nothing else follows.
        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());
    }
1706}