1use std::{collections::BTreeSet, pin::Pin, sync::Weak};
114
115use as_variant::as_variant;
116use futures_core::Stream;
117use futures_util::{StreamExt, pin_mut};
118#[cfg(doc)]
119use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
120use matrix_sdk_base::{
121 crypto::{
122 store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
123 types::events::room::encrypted::EncryptedEvent,
124 },
125 deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
126 locks::Mutex,
127 timer,
128};
129#[cfg(doc)]
130use matrix_sdk_common::deserialized_responses::EncryptionInfo;
131use matrix_sdk_common::executor::{AbortOnDrop, JoinHandleExt, spawn};
132use ruma::{
133 OwnedEventId, OwnedRoomId, RoomId,
134 events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
135 push::Action,
136 serde::Raw,
137};
138use tokio::sync::{
139 broadcast::{self, Sender},
140 mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
141};
142use tokio_stream::wrappers::{
143 BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
144};
145use tracing::{info, instrument, trace, warn};
146
147#[cfg(doc)]
148use super::RoomEventCache;
149use super::{EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate};
150use crate::{Room, event_cache::RoomEventCacheLinkedChunkUpdate, room::PushContext};
151
/// A borrowed room key session ID.
type SessionId<'a> = &'a str;
/// An owned room key session ID.
type OwnedSessionId = String;

/// An event ID paired with the raw event we were unable to decrypt.
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// An event ID paired with its already-decrypted event.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// An event ID, the new decrypted form of the event, and optionally the push
/// actions that should replace the event's current ones.
type ResolvedUtd = (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
158
/// An explicit request, for a single room, to retry decrypting events and/or
/// to refresh the encryption info of already-decrypted events.
///
/// Submitted through [`EventCache::request_decryption`].
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
    /// The room the events belong to.
    pub room_id: OwnedRoomId,
    /// Session IDs whose unable-to-decrypt events should be decrypted again.
    pub utd_session_ids: BTreeSet<OwnedSessionId>,
    /// Session IDs whose already-decrypted events should get their
    /// [`EncryptionInfo`] refreshed.
    pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
171
/// A report broadcast by the redecryptor, see
/// [`EventCache::subscribe_to_decryption_reports`].
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
    /// Events of a room were redecrypted, or had their encryption info
    /// updated, and were handed to the event cache for replacement.
    ResolvedUtds {
        /// The room the events belong to.
        room_id: OwnedRoomId,
        /// The IDs of the events that were handed over for replacement.
        events: BTreeSet<OwnedEventId>,
    },
    /// The redecryptor may have missed some work: one of its broadcast
    /// streams lagged, or its room key streams had to be recreated.
    /// Subscribers presumably want to retry decrypting their own UTDs.
    Lagging,
}
186
/// Channels connecting the [`EventCache`] and its [`Redecryptor`] task.
pub(super) struct RedecryptorChannels {
    /// Broadcasts [`RedecryptorReport`]s to every subscriber of
    /// [`EventCache::subscribe_to_decryption_reports`].
    utd_reporter: Sender<RedecryptorReport>,
    /// Sends explicit [`DecryptionRetryRequest`]s to the redecryptor task.
    pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
    /// Receiving half of `decryption_request_sender`. It is wrapped in an
    /// `Option`, presumably so that whoever starts the [`Redecryptor`] can
    /// take it exactly once.
    pub(super) decryption_request_receiver:
        Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
193
194impl RedecryptorChannels {
195 pub(super) fn new() -> Self {
196 let (utd_reporter, _) = broadcast::channel(100);
197 let (decryption_request_sender, decryption_request_receiver) = unbounded_channel();
198
199 Self {
200 utd_reporter,
201 decryption_request_sender,
202 decryption_request_receiver: Mutex::new(Some(decryption_request_receiver)),
203 }
204 }
205}
206
207fn filter_timeline_event_to_utd(
212 event: TimelineEvent,
213) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
214 let event_id = event.event_id();
215
216 let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
219 event_id.zip(event)
222}
223
224impl EventCache {
225 async fn get_utds(
233 &self,
234 room_id: &RoomId,
235 session_id: SessionId<'_>,
236 ) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
237 let events = {
238 let store = self.inner.store.lock().await?;
239 store.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
240 };
241
242 Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
243 }
244
245 async fn get_decrypted_events(
246 &self,
247 room_id: &RoomId,
248 session_id: SessionId<'_>,
249 ) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
250 let filter = |event: TimelineEvent| {
251 let event_id = event.event_id();
252
253 let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
254 event_id.zip(event)
257 };
258
259 let events = {
260 let store = self.inner.store.lock().await?;
261 store.get_room_events(room_id, None, Some(session_id)).await?
262 };
263
264 Ok(events.into_iter().filter_map(filter).collect())
265 }
266
267 #[instrument(skip_all, fields(room_id))]
279 async fn on_resolved_utds(
280 &self,
281 room_id: &RoomId,
282 events: Vec<ResolvedUtd>,
283 ) -> Result<(), EventCacheError> {
284 if events.is_empty() {
285 trace!("No events were redecrypted or updated, nothing to replace");
286 return Ok(());
287 }
288
289 timer!("Resolving UTDs");
290
291 let (room_cache, _drop_handles) = self.for_room(room_id).await?;
294 let mut state = room_cache.inner.state.write().await;
295
296 let event_ids: BTreeSet<_> =
297 events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
298 let mut new_events = Vec::with_capacity(events.len());
299
300 for (event_id, decrypted, actions) in events {
301 if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
305 target_event.kind = TimelineEventKind::Decrypted(decrypted);
306
307 if let Some(actions) = actions {
308 target_event.set_push_actions(actions);
309 }
310
311 state.replace_event_at(location, target_event.clone()).await?;
314 new_events.push(target_event);
315 }
316 }
317
318 state.post_process_new_events(new_events, false).await?;
319
320 let diffs = state.room_linked_chunk_mut().updates_as_vector_diffs();
323
324 let _ = room_cache.inner.sender.send(RoomEventCacheUpdate::UpdateTimelineEvents {
325 diffs,
326 origin: EventsOrigin::Cache,
327 });
328
329 let report =
334 RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
335 let _ = self.inner.redecryption_channels.utd_reporter.send(report);
336
337 Ok(())
338 }
339
340 async fn decrypt_event(
342 &self,
343 room_id: &RoomId,
344 room: Option<&Room>,
345 push_context: Option<&PushContext>,
346 event: &Raw<EncryptedEvent>,
347 ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
348 if let Some(room) = room {
349 match room
350 .decrypt_event(
351 event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
352 push_context,
353 )
354 .await
355 {
356 Ok(maybe_decrypted) => {
357 let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
358
359 if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
360 Some((decrypted, actions))
361 } else {
362 warn!(
363 "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
364 );
365 None
366 }
367 }
368 Err(e) => {
369 warn!(
370 "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
371 );
372 None
373 }
374 }
375 } else {
376 let client = self.inner.client().ok()?;
377 let machine = client.olm_machine().await;
378 let machine = machine.as_ref()?;
379
380 match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
381 Ok(decrypted) => Some((decrypted, None)),
382 Err(e) => {
383 warn!(
384 "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
385 );
386 None
387 }
388 }
389 }
390 }
391
392 #[instrument(skip_all, fields(room_id, session_id))]
395 async fn retry_decryption(
396 &self,
397 room_id: &RoomId,
398 session_id: SessionId<'_>,
399 ) -> Result<(), EventCacheError> {
400 let events = self.get_utds(room_id, session_id).await?;
402 self.retry_decryption_for_events(room_id, events).await
403 }
404
405 #[instrument(skip_all, fields(updates.linked_chunk_id))]
407 async fn retry_decryption_for_event_cache_updates(
408 &self,
409 updates: RoomEventCacheLinkedChunkUpdate,
410 ) -> Result<(), EventCacheError> {
411 let room_id = updates.linked_chunk_id.room_id();
412 let events: Vec<_> = updates
413 .updates
414 .into_iter()
415 .flat_map(|updates| updates.into_items())
416 .filter_map(filter_timeline_event_to_utd)
417 .collect();
418
419 self.retry_decryption_for_events(room_id, events).await
420 }
421
422 #[instrument(skip_all, fields(room_id, session_id))]
424 async fn retry_decryption_for_events(
425 &self,
426 room_id: &RoomId,
427 events: Vec<EventIdAndUtd>,
428 ) -> Result<(), EventCacheError> {
429 trace!("Retrying to decrypt");
430
431 if events.is_empty() {
432 trace!("No relevant events found.");
433 return Ok(());
434 }
435
436 let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
437 let push_context =
438 if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
439
440 let mut decrypted_events = Vec::with_capacity(events.len());
442
443 for (event_id, event) in events {
444 if let Some((decrypted, actions)) = self
447 .decrypt_event(
448 room_id,
449 room.as_ref(),
450 push_context.as_ref(),
451 event.cast_ref_unchecked(),
452 )
453 .await
454 {
455 decrypted_events.push((event_id, decrypted, actions));
456 }
457 }
458
459 let event_ids: BTreeSet<_> =
460 decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
461
462 if !event_ids.is_empty() {
463 trace!(?event_ids, "Successfully redecrypted events");
464 }
465
466 self.on_resolved_utds(room_id, decrypted_events).await?;
469
470 Ok(())
471 }
472
473 #[instrument(skip_all, fields(room_id, session_id))]
474 async fn update_encryption_info(
475 &self,
476 room_id: &RoomId,
477 session_id: SessionId<'_>,
478 ) -> Result<(), EventCacheError> {
479 trace!("Updating encryption info");
480
481 let Ok(client) = self.inner.client() else {
482 return Ok(());
483 };
484
485 let Some(room) = client.get_room(room_id) else {
486 return Ok(());
487 };
488
489 let events = self.get_decrypted_events(room_id, session_id).await?;
491
492 if events.is_empty() {
493 trace!("No relevant events found.");
494 return Ok(());
495 }
496
497 let mut updated_events = Vec::with_capacity(events.len());
499
500 for (event_id, mut event) in events {
501 let new_encryption_info =
502 room.get_encryption_info(session_id, &event.encryption_info.sender).await;
503
504 if let Some(new_encryption_info) = new_encryption_info
506 && event.encryption_info != new_encryption_info
507 {
508 event.encryption_info = new_encryption_info;
509 updated_events.push((event_id, event, None));
510 }
511 }
512
513 let event_ids: BTreeSet<_> =
514 updated_events.iter().map(|(event_id, _, _)| event_id).collect();
515
516 if !event_ids.is_empty() {
517 trace!(?event_ids, "Replacing the encryption info of some events");
518 }
519
520 self.on_resolved_utds(room_id, updated_events).await?;
521
522 Ok(())
523 }
524
525 pub fn request_decryption(&self, request: DecryptionRetryRequest) {
566 let _ =
567 self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
568 |_| warn!("Requesting a decryption while the redecryption task has been shut down"),
569 );
570 }
571
572 pub fn subscribe_to_decryption_reports(
615 &self,
616 ) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
617 BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
618 }
619}
620
/// Handle to the background task that redecrypts events of the event cache.
pub(crate) struct Redecryptor {
    /// The spawned task; it is aborted when this handle is dropped.
    _task: AbortOnDrop<()>,
}
630
impl Redecryptor {
    /// Spawn the redecryptor background task.
    ///
    /// Arguments:
    /// - `cache`: a weak reference to the event cache, so the task does not
    ///   keep it alive.
    /// - `receiver`: delivers explicit [`DecryptionRetryRequest`]s.
    /// - `linked_chunk_update_sender`: subscribed to, so the task sees events
    ///   newly added to the event cache.
    ///
    /// The task is aborted when the returned value is dropped.
    pub(super) fn new(
        cache: Weak<EventCacheInner>,
        receiver: UnboundedReceiver<DecryptionRetryRequest>,
        linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
    ) -> Self {
        // Subscribe synchronously, before spawning, so that updates sent from
        // now on are buffered for us even if the task hasn't started yet.
        let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());

        let task = spawn(async {
            let request_redecryption_stream = UnboundedReceiverStream::new(receiver);

            Self::listen_for_room_keys_task(
                cache,
                request_redecryption_stream,
                linked_chunk_stream,
            )
            .await;
        })
        .abort_on_drop();

        Self { _task: task }
    }

    /// Subscribe to the "room keys received" and "room keys withheld" streams
    /// of the client's current [`OlmMachine`].
    ///
    /// Returns `None` in any of these cases:
    /// - the event cache has been dropped;
    /// - the client can't be obtained;
    /// - there is no Olm machine.
    async fn subscribe_to_room_key_stream(
        cache: &Weak<EventCacheInner>,
    ) -> Option<(
        impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
        impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
    )> {
        let event_cache = cache.upgrade()?;
        let client = event_cache.client().ok()?;
        let machine = client.olm_machine().await;

        machine.as_ref().map(|m| {
            (m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
        })
    }

    /// Turn the weak reference back into an [`EventCache`], or return `None`
    /// if the event cache has been dropped.
    #[inline(always)]
    fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
        cache.upgrade().map(|inner| EventCache { inner })
    }

    /// Handle incoming work until something forces this loop to stop.
    ///
    /// Work comes from explicit decryption requests, received room keys,
    /// withheld room key notices and event cache linked chunk updates.
    ///
    /// Returns `true` when the room key stream or the withheld stream ended;
    /// the caller should then recreate the streams (NOTE(review): presumably
    /// this happens when the Olm machine is replaced). Returns `false` when
    /// the task should shut down: the event cache was dropped, or no room key
    /// streams could be obtained.
    async fn redecryption_loop(
        cache: &Weak<EventCacheInner>,
        decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
        events_stream: &mut Pin<
            &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
        >,
    ) -> bool {
        let Some((room_key_stream, withheld_stream)) =
            Self::subscribe_to_room_key_stream(cache).await
        else {
            return false;
        };

        pin_mut!(room_key_stream);
        pin_mut!(withheld_stream);

        loop {
            tokio::select! {
                // Explicit request from `EventCache::request_decryption`. If
                // the stream yields `None`, the pattern fails and this branch is
                // disabled for the current `select!` call only.
                Some(request) = decryption_request_stream.next() => {
                    // Upgrade per message so we never keep the event cache alive.
                    let Some(cache) = Self::upgrade_event_cache(cache) else {
                        break false;
                    };

                    trace!(?request, "Received a redecryption request");

                    for session_id in request.utd_session_ids {
                        let _ = cache
                            .retry_decryption(&request.room_id, &session_id)
                            .await
                            .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
                    }

                    for session_id in request.refresh_info_session_ids {
                        let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
                            warn!(
                                room_id = %request.room_id,
                                session_id = session_id,
                                "Unable to update the encryption info {e:?}",
                        ));
                    }
                }
                // New room keys: this pattern is irrefutable, so `None` (end of
                // stream) is handled explicitly below.
                room_keys = room_key_stream.next() => {
                    match room_keys {
                        Some(Ok(room_keys)) => {
                            // The event cache is gone; nothing left to redecrypt.
                            let Some(cache) = Self::upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?room_keys, "Received new room keys");

                            // First retry decrypting UTDs for each key...
                            for key in &room_keys {
                                let _ = cache
                                    .retry_decryption(&key.room_id, &key.session_id)
                                    .await
                                    .inspect_err(|e| warn!("Error redecrypting {e:?}"));
                            }

                            // ...then refresh the encryption info of events
                            // decrypted with those sessions.
                            for key in room_keys {
                                let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %key.room_id,
                                        session_id = key.session_id,
                                        "Unable to update the encryption info {e:?}",
                                ));
                            }
                        },
                        Some(Err(_)) => {
                            // The room key broadcast lagged and some keys were
                            // skipped; tell subscribers we may have missed work.
                            let Some(cache) = Self::upgrade_event_cache(cache) else {
                                break false;
                            };

                            let message = RedecryptorReport::Lagging;
                            let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
                        },
                        None => {
                            // The room key stream ended: ask the caller to
                            // recreate the streams.
                            break true
                        }
                    }
                }
                // Withheld room key notices only affect encryption info.
                withheld_info = withheld_stream.next() => {
                    match withheld_info {
                        Some(infos) => {
                            let Some(cache) = Self::upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?infos, "Received new withheld infos");

                            for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
                                let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %room_id,
                                        session_id = session_id,
                                        "Unable to update the encryption info {e:?}",
                                ));
                            }
                        }
                        // The withheld stream ended: recreate the streams.
                        None => break true
                    }
                }
                // Events newly written to the event cache may be UTDs whose
                // room key already arrived; retry them.
                Some(event_updates) = events_stream.next() => {
                    match event_updates {
                        Ok(updates) => {
                            let Some(cache) = Self::upgrade_event_cache(cache) else {
                                break false;
                            };

                            let linked_chunk_id = updates.linked_chunk_id.to_owned();

                            let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
                                warn!(
                                    %linked_chunk_id,
                                    "Unable to handle UTDs from event cache updates {e:?}",
                                )
                            );
                        }
                        Err(_) => {
                            // Some linked chunk updates were skipped.
                            let Some(cache) = Self::upgrade_event_cache(cache) else {
                                break false;
                            };

                            let message = RedecryptorReport::Lagging;
                            let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
                        }
                    }
                }
                // Only taken when every branch is disabled. NOTE(review): the
                // room key and withheld branches use irrefutable patterns, so
                // this presumably never fires; it acts as a safeguard.
                else => break false,
            }
        }
    }

    /// Body of the redecryptor task: run `redecryption_loop` and, as long as
    /// it asks for it, recreate the room key streams and run it again.
    ///
    /// The request and linked chunk streams are pinned once and reused across
    /// restarts. Each restart broadcasts a [`RedecryptorReport::Lagging`],
    /// since room keys may have been missed while the streams were recreated.
    async fn listen_for_room_keys_task(
        cache: Weak<EventCacheInner>,
        decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
        events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
    ) {
        pin_mut!(decryption_request_stream);
        pin_mut!(events_stream);

        while Self::redecryption_loop(&cache, &mut decryption_request_stream, &mut events_stream)
            .await
        {
            info!("Regenerating the re-decryption streams");

            let Some(cache) = Self::upgrade_event_cache(&cache) else {
                break;
            };

            // Keys may have arrived while we weren't subscribed; let
            // subscribers know they might need to retry on their side.
            let message = RedecryptorReport::Lagging;
            let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
        }

        info!("Shutting down the event cache redecryptor");
    }
}
862
#[cfg(not(target_family = "wasm"))]
#[cfg(test)]
mod tests {
    use std::{
        collections::BTreeSet,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::Duration,
    };

    use assert_matches2::assert_matches;
    use async_trait::async_trait;
    use eyeball_im::VectorDiff;
    use matrix_sdk_base::{
        cross_process_lock::CrossProcessLockGeneration,
        crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
        deserialized_responses::{TimelineEventKind, VerificationState},
        event_cache::{
            Event, Gap,
            store::{EventCacheStore, EventCacheStoreError, MemoryStore},
        },
        linked_chunk::{
            ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
            RawChunk, Update,
        },
        locks::Mutex,
        sleep::sleep,
        store::StoreConfig,
    };
    use matrix_sdk_test::{
        JoinedRoomBuilder, StateTestEvent, async_test, event_factory::EventFactory,
    };
    use ruma::{
        EventId, OwnedEventId, RoomId, device_id, event_id,
        events::{AnySyncTimelineEvent, relation::RelationType},
        room_id,
        serde::Raw,
        user_id,
    };
    use serde_json::json;
    use tokio::sync::oneshot::{self, Sender};
    use tracing::{Instrument, info};

    use crate::{
        Client, assert_let_timeout,
        encryption::EncryptionSettings,
        event_cache::{DecryptionRetryRequest, RoomEventCacheUpdate},
        test_utils::mocks::MatrixMockServer,
    };

    /// An [`EventCacheStore`] wrapping a [`MemoryStore`].
    ///
    /// `handle_linked_chunk_updates` stalls until
    /// [`DelayingStore::stop_delaying`] is called; every other method simply
    /// forwards to the memory store.
    #[derive(Debug, Clone)]
    struct DelayingStore {
        memory_store: MemoryStore,
        // While `true`, linked chunk writes wait in a sleep loop.
        delaying: Arc<AtomicBool>,
        // Set by `stop_delaying`; the next linked chunk write takes it and
        // signals once the write to the memory store is done.
        foo: Arc<Mutex<Option<Sender<()>>>>,
    }

    impl DelayingStore {
        /// Create a store that starts out delaying linked chunk writes.
        fn new() -> Self {
            Self {
                memory_store: MemoryStore::new(),
                delaying: AtomicBool::new(true).into(),
                foo: Arc::new(Mutex::new(None)),
            }
        }

        /// Release the delayed writes and wait until a linked chunk write that
        /// took the notification sender has been applied to the memory store.
        async fn stop_delaying(&self) {
            let (sender, receiver) = oneshot::channel();

            // Install the sender before clearing the flag, so that a write
            // leaving the delay loop can find it.
            {
                *self.foo.lock() = Some(sender);
            }

            self.delaying.store(false, Ordering::SeqCst);

            receiver.await.expect("We should be able to receive a response")
        }
    }

    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl EventCacheStore for DelayingStore {
        type Error = EventCacheStoreError;

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
            self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        async fn handle_linked_chunk_updates(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            updates: Vec<Update<Event, Gap>>,
        ) -> Result<(), Self::Error> {
            // Hold the write back, polling every 10 ms, until `stop_delaying`
            // clears the flag. This lets tests deliver a room key while the
            // event is still being persisted.
            while self.delaying.load(Ordering::SeqCst) {
                sleep(Duration::from_millis(10)).await;
            }

            let sender = self.foo.lock().take();
            let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;

            // Tell `stop_delaying` the write has been applied.
            if let Some(sender) = sender {
                sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
            }

            ret
        }

        async fn load_all_chunks(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_all_chunks(linked_chunk_id).await
        }

        async fn load_all_chunks_metadata(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<ChunkMetadata>, Self::Error> {
            self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
        }

        async fn load_last_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
            self.memory_store.load_last_chunk(linked_chunk_id).await
        }

        async fn load_previous_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            before_chunk_identifier: ChunkIdentifier,
        ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
        }

        async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
            self.memory_store.clear_all_linked_chunks().await
        }

        async fn filter_duplicated_events(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            events: Vec<OwnedEventId>,
        ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
            self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
        }

        async fn find_event(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
        ) -> Result<Option<Event>, Self::Error> {
            self.memory_store.find_event(room_id, event_id).await
        }

        async fn find_event_relations(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
            filters: Option<&[RelationType]>,
        ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
            self.memory_store.find_event_relations(room_id, event_id, filters).await
        }

        async fn get_room_events(
            &self,
            room_id: &RoomId,
            event_type: Option<&str>,
            session_id: Option<&str>,
        ) -> Result<Vec<Event>, Self::Error> {
            self.memory_store.get_room_events(room_id, event_type, session_id).await
        }

        async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
            self.memory_store.save_event(room_id, event).await
        }
    }

    /// Create Alice's and Bob's clients against a shared mock server, and
    /// make both join the encrypted room `room_id`.
    ///
    /// - Alice auto-enables cross-signing only if `alice_enables_cross_signing`
    ///   is set; Bob always does.
    /// - When `use_delayed_store` is set, Bob's event cache is backed by a
    ///   [`DelayingStore`], which is returned as the last tuple element.
    async fn set_up_clients(
        room_id: &RoomId,
        alice_enables_cross_signing: bool,
        use_delayed_store: bool,
    ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
        let alice_span = tracing::info_span!("alice");
        let bob_span = tracing::info_span!("bob");

        let alice_user_id = user_id!("@alice:localhost");
        let alice_device_id = device_id!("ALICEDEVICE");
        let bob_user_id = user_id!("@bob:localhost");
        let bob_device_id = device_id!("BOBDEVICE");

        let matrix_mock_server = MatrixMockServer::new().await;
        matrix_mock_server.mock_crypto_endpoints_preset().await;

        let encryption_settings = EncryptionSettings {
            auto_enable_cross_signing: alice_enables_cross_signing,
            ..Default::default()
        };

        // Alice: the sender of the encrypted message.
        let alice = matrix_mock_server
            .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
            })
            .build()
            .instrument(alice_span.clone())
            .await;

        let encryption_settings =
            EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };

        let (store_config, store) = if use_delayed_store {
            let store = DelayingStore::new();

            (
                StoreConfig::new("delayed_store_event_cache_test".into())
                    .event_cache_store(store.clone()),
                Some(store),
            )
        } else {
            (StoreConfig::new("normal_store_event_cache_test".into()), None)
        };

        // Bob: the receiver, whose event cache is under test.
        let bob = matrix_mock_server
            .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
                    .store_config(store_config)
            })
            .build()
            .instrument(bob_span.clone())
            .await;

        bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");

        // Presumably makes each client aware of the other's devices and keys.
        matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;

        // An encrypted room both users have joined.
        let room_builder = JoinedRoomBuilder::new(room_id)
            .add_state_event(StateTestEvent::Create)
            .add_state_event(StateTestEvent::Encryption);

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&alice, |builder| {
                builder.add_joined_room(room_builder.clone());
            })
            .instrument(alice_span)
            .await;

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(room_builder);
            })
            .instrument(bob_span)
            .await;

        (alice, bob, matrix_mock_server, store)
    }

    /// Have Alice send an encrypted message in `room_id`.
    ///
    /// Returns the raw event captured by the mocked send endpoint, and the
    /// to-device event carrying the room key captured by the mocked
    /// to-device endpoint.
    async fn prepare_room(
        matrix_mock_server: &MatrixMockServer,
        event_factory: &EventFactory,
        alice: &Client,
        bob: &Client,
        room_id: &RoomId,
    ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
        let alice_user_id = alice.user_id().unwrap();
        let bob_user_id = bob.user_id().unwrap();

        let alice_member_event = event_factory.member(alice_user_id).into_raw();
        let bob_member_event = event_factory.member(bob_user_id).into_raw();

        let room = alice
            .get_room(room_id)
            .expect("Alice should have access to the room now that we synced");

        // The plaintext message Alice sends; it gets encrypted on sending.
        let event_type = "m.room.message";
        let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});

        let event_id = event_id!("$some_id");
        let (event_receiver, mock) =
            matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
        let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;

        {
            let _guard = mock.mock_once().mount_as_scoped().await;

            // Presumably lets Alice know Bob is a member, so the room key is
            // shared with him.
            matrix_mock_server
                .mock_get_members()
                .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
                .mock_once()
                .mount()
                .await;

            room.send_raw(event_type, content)
                .await
                .expect("We should be able to send an initial message");
        };

        // Both captures are now complete.
        let event = event_receiver.await.expect("Alice should have sent the event by now");
        let room_key = room_key.await;

        (event, room_key)
    }

    #[async_test]
    async fn test_redecryptor() {
        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        // Subscribe to Bob's room event cache before anything reaches it.
        let (room_cache, _) = bob
            .event_cache()
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await;

        // Regenerate Bob's Olm machine. NOTE(review): presumably this checks
        // that the redecryptor resubscribes to the new machine's key streams.
        bob.inner
            .base_client
            .regenerate_olm(None)
            .await
            .expect("We should be able to regenerate the Olm machine");

        // Bob receives the encrypted event without having the room key.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        // The event is first appended as a UTD.
        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        // Now deliver the room key as a to-device event.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        // The UTD is replaced in place by its decrypted version.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
    }

    #[async_test]
    async fn test_redecryptor_updating_encryption_info() {
        let bob_span = tracing::info_span!("bob");

        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        // Alice starts without cross-signing.
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        // Subscribe to Bob's room event cache before anything reaches it.
        let (room_cache, _) = bob
            .event_cache()
            .for_room(room_id)
            .instrument(bob_span.clone())
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await;

        // Bob receives the encrypted event without having the room key.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .instrument(bob_span.clone())
            .await;

        // The event is first appended as a UTD.
        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        // Deliver the room key.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .instrument(bob_span.clone())
            .await;

        // The event gets decrypted in place.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        let encryption_info = value.encryption_info().unwrap();
        assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));
        let session_id = encryption_info.session_id().unwrap().to_owned();

        let alice_user_id = alice.user_id().unwrap();

        // Alice now sets up cross-signing, changing her identity.
        alice
            .encryption()
            .bootstrap_cross_signing(None)
            .await
            .expect("Alice should be able to create the cross-signing keys");

        // Make Bob notice Alice's device/identity change.
        bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_change_device(alice_user_id);
            })
            .instrument(bob_span.clone())
            .await;

        // Explicitly ask for the encryption info of the session to be refreshed.
        bob.event_cache().request_decryption(DecryptionRetryRequest {
            room_id: room_id.into(),
            utd_session_ids: BTreeSet::new(),
            refresh_info_session_ids: BTreeSet::from([session_id]),
        });

        // The event is replaced in place with refreshed encryption info.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
        let encryption_info = value.encryption_info().unwrap();

        assert_matches!(
            &encryption_info.verification_state,
            VerificationState::Unverified(_),
            "The event should now know about the identity but still be unverified"
        );
    }

    #[async_test]
    async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        // Bob's event cache is backed by a `DelayingStore`.
        let (alice, bob, matrix_mock_server, delayed_store) =
            set_up_clients(room_id, true, true).await;

        let delayed_store = delayed_store.unwrap();

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let (room_cache, _) = bob
            .event_cache()
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await;

        // Bob receives the encrypted event; its storage is held back by the
        // delaying store.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        // The room key arrives while the event is still being persisted.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        info!("Stopping the delay");
        delayed_store.stop_delaying().await;

        // The event is still appended as a UTD first...
        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        // ...and then redecrypted, even though its key arrived earlier.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
    }
}