1use std::{collections::BTreeSet, sync::Arc};
16
17use imbl::Vector;
18use itertools::{Either, Itertools as _};
19use matrix_sdk::{
20 deserialized_responses::TimelineEventKind as SdkTimelineEventKind, executor::JoinHandle,
21};
22use tokio::sync::{
23 mpsc::{self, Receiver, Sender},
24 RwLock,
25};
26use tracing::{debug, error, field, info, info_span, Instrument as _};
27
28use crate::timeline::{
29 controller::{TimelineSettings, TimelineState},
30 event_item::EventTimelineItemKind,
31 traits::{Decryptor, RoomDataProvider},
32 EncryptedMessage, EventTimelineItem, TimelineItem, TimelineItemKind,
33};
34
/// Handle to a background task that retries decryption of unable-to-decrypt
/// (UTD) timeline items, and re-fetches encryption info for already-decrypted
/// items, each time a request is submitted via [`DecryptionRetryTask::decrypt`].
///
/// Cloning the handle is cheap: every clone shares the same channel sender and
/// the same reference-counted task handle.
#[derive(Clone, Debug)]
pub struct DecryptionRetryTask<D: Decryptor> {
    /// Sending side of the bounded request channel read by the background
    /// task. The task's receive loop ends once every sender has been dropped.
    sender: Sender<DecryptionRetryRequest<D>>,

    /// Handle of the spawned background task, kept alive as long as any clone
    /// of this struct exists; it is never read (hence the leading underscore).
    /// NOTE(review): whether dropping the last handle aborts or merely detaches
    /// the task depends on `matrix_sdk::executor::JoinHandle` — confirm.
    _task_handle: Arc<JoinHandle<()>>,
}
58
/// Capacity of the bounded request channel feeding the background decryption
/// task. When the buffer is full, [`DecryptionRetryTask::decrypt`] waits
/// asynchronously until the task has consumed a pending request.
const CHANNEL_BUFFER_SIZE: usize = 100;
63
64impl<D: Decryptor> DecryptionRetryTask<D> {
65 pub(crate) fn new<P: RoomDataProvider>(
66 state: Arc<RwLock<TimelineState>>,
67 room_data_provider: P,
68 ) -> Self {
69 let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
71
72 let handle =
75 matrix_sdk::executor::spawn(decryption_task(state, room_data_provider, receiver));
76
77 Self { sender, _task_handle: Arc::new(handle) }
79 }
80
81 pub(crate) async fn decrypt(
84 &self,
85 decryptor: D,
86 session_ids: Option<BTreeSet<String>>,
87 settings: TimelineSettings,
88 ) {
89 let res =
90 self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
91
92 if let Err(error) = res {
93 error!("Failed to send decryption retry request: {}", error);
94 }
95 }
96}
97
/// A single unit of work submitted to the background decryption task.
struct DecryptionRetryRequest<D: Decryptor> {
    /// Decryptor used to re-attempt decryption of unable-to-decrypt items.
    decryptor: D,
    /// Session IDs the retry is restricted to; `None` means every session is
    /// eligible.
    session_ids: Option<BTreeSet<String>>,
    /// Timeline settings forwarded to the decryption retry.
    settings: TimelineSettings,
}
105
106async fn decryption_task<D: Decryptor>(
110 state: Arc<RwLock<TimelineState>>,
111 room_data_provider: impl RoomDataProvider,
112 mut receiver: Receiver<DecryptionRetryRequest<D>>,
113) {
114 debug!("Decryption task starting.");
115
116 while let Some(request) = receiver.recv().await {
117 let should_retry = |session_id: &str| {
118 if let Some(session_ids) = &request.session_ids {
119 session_ids.contains(session_id)
120 } else {
121 true
122 }
123 };
124
125 let mut state = state.write().await;
129 let (retry_decryption_indices, retry_info_indices) =
130 compute_event_indices_to_retry_decryption(&state.items, should_retry);
131
132 if !retry_info_indices.is_empty() {
134 debug!("Retrying fetching encryption info");
135 retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
136 }
137
138 if !retry_decryption_indices.is_empty() {
140 debug!("Retrying decryption");
141 decrypt_by_index(
142 &mut state,
143 &request.settings,
144 &room_data_provider,
145 request.decryptor,
146 should_retry,
147 retry_decryption_indices,
148 )
149 .await
150 }
151 }
152
153 debug!("Decryption task stopping.");
154}
155
156fn compute_event_indices_to_retry_decryption(
164 items: &Vector<Arc<TimelineItem>>,
165 should_retry: impl Fn(&str) -> bool,
166) -> (Vec<usize>, Vec<usize>) {
167 use Either::{Left, Right};
168
169 let should_retry_event = |event: &EventTimelineItem| {
171 let session_id = if let Some(encrypted_message) = event.content().as_unable_to_decrypt() {
172 encrypted_message.session_id()
174 } else {
175 event
178 .as_remote()
179 .and_then(|remote| remote.encryption_info.as_ref()?.session_id.as_ref())
180 .map(String::as_str)
181 };
182
183 if let Some(session_id) = session_id {
184 should_retry(session_id)
186 } else {
187 false
189 }
190 };
191
192 items
193 .iter()
194 .enumerate()
195 .filter_map(|(idx, item)| {
196 item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
197 })
198 .partition_map(
200 |(idx, event)| {
201 if event.content().is_unable_to_decrypt() {
202 Left(idx)
203 } else {
204 Right(idx)
205 }
206 },
207 )
208}
209
210pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
213 state: &mut TimelineState,
214 retry_indices: Vec<usize>,
215 room_data_provider: &P,
216) {
217 for idx in retry_indices {
218 let old_item = state.items.get(idx);
219 if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
220 state.items.replace(idx, new_item);
221 }
222 }
223}
224
225async fn make_replacement_for<P: RoomDataProvider>(
229 room_data_provider: &P,
230 item: Option<&Arc<TimelineItem>>,
231) -> Option<Arc<TimelineItem>> {
232 let item = item?;
233 let event = item.as_event()?;
234 let remote = event.as_remote()?;
235 let session_id = remote.encryption_info.as_ref()?.session_id.as_deref()?;
236
237 let new_encryption_info =
238 room_data_provider.get_encryption_info(session_id, &event.sender).await;
239 let mut new_remote = remote.clone();
240 new_remote.encryption_info = new_encryption_info;
241 let new_item = item.with_kind(TimelineItemKind::Event(
242 event.with_kind(EventTimelineItemKind::Remote(new_remote)),
243 ));
244
245 Some(new_item)
246}
247
/// Re-attempts decryption of the unable-to-decrypt (UTD) items at
/// `retry_indices`, delegating the per-item bookkeeping to
/// `TimelineState::retry_event_decryption`.
///
/// Only Megolm-encrypted UTDs whose session ID passes `should_retry` are
/// attempted. When a retried event comes back decrypted, the timeline's
/// unable-to-decrypt hook (if one is configured) is notified of the late
/// decryption.
async fn decrypt_by_index<D: Decryptor>(
    state: &mut TimelineState,
    settings: &TimelineSettings,
    room_data_provider: &impl RoomDataProvider,
    decryptor: D,
    should_retry: impl Fn(&str) -> bool,
    retry_indices: Vec<usize>,
) {
    // Fetched once and shared by every per-item attempt below.
    let push_ctx = room_data_provider.push_context().await;
    let push_ctx = push_ctx.as_ref();
    let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();

    // Builds the future retrying a single item. It resolves to `Some(event)`
    // on successful decryption, and to `None` when the item is skipped or
    // still cannot be decrypted.
    let retry_one = |item: Arc<TimelineItem>| {
        let decryptor = decryptor.clone();
        let should_retry = &should_retry;
        let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
        async move {
            let event_item = item.as_event()?;

            // Only Megolm UTDs belonging to an accepted session are retried.
            let session_id = match event_item.content().as_unable_to_decrypt()? {
                EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
                    if should_retry(session_id) =>
                {
                    session_id
                }
                EncryptedMessage::MegolmV1AesSha2 { .. }
                | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
                | EncryptedMessage::Unknown => return None,
            };

            // Fill in the span fields declared as `field::Empty` below.
            tracing::Span::current().record("session_id", session_id);

            let Some(remote_event) = event_item.as_remote() else {
                error!("Key for unable-to-decrypt timeline item is not an event ID");
                return None;
            };

            tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));

            // Decryption is retried from the raw event JSON stored on the item.
            let Some(original_json) = &remote_event.original_json else {
                error!("UTD item must contain original JSON");
                return None;
            };

            match decryptor.decrypt_event_impl(original_json, push_ctx).await {
                Ok(event) => {
                    // An `Ok` result can still be a UTD; treat it as a failure.
                    if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
                        info!(
                            "Failed to decrypt event after receiving room key: {:?}",
                            utd_info.reason
                        );
                        None
                    } else {
                        // Tell the UTD hook this event was eventually decrypted.
                        if let Some(hook) = unable_to_decrypt_hook {
                            hook.on_late_decrypt(&remote_event.event_id).await;
                        }

                        Some(event)
                    }
                }
                Err(e) => {
                    info!("Failed to decrypt event after receiving room key: {e}");
                    None
                }
            }
        }
        .instrument(info_span!(
            "retry_one",
            session_id = field::Empty,
            event_id = field::Empty
        ))
    };

    state.retry_event_decryption(retry_one, retry_indices, room_data_provider, settings).await;
}
326
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::vector;
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id, MilliSecondsSinceUnixEpoch,
        OwnedTransactionId,
    };

    use crate::timeline::{
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
        EncryptedMessage, EventSendState, EventTimelineItem, MsgLikeContent,
        ReactionsByKeyBySender, TimelineDetails, TimelineItem, TimelineItemContent,
        TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
    };

    #[test]
    fn test_non_events_are_not_retried() {
        // Given a timeline containing only virtual items
        let timeline = vector![TimelineItem::read_marker(), date_divider()];
        // When we compute which indices to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then nothing needs decrypting
        assert!(answer.0.is_empty());
        // And nothing needs its encryption info re-fetched
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_non_remote_events_are_not_retried() {
        // Given a timeline containing only a local echo
        let timeline = vector![local_event()];
        // When we compute which indices to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then nothing needs decrypting
        assert!(answer.0.is_empty());
        // And nothing needs its encryption info re-fetched
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_utds_are_retried() {
        // Given a timeline containing a single UTD
        let timeline = vector![utd_event("session1")];
        // When we compute which indices to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then the UTD is scheduled for decryption
        assert_eq!(answer.0, vec![0]);
        // And no encryption info needs re-fetching
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        // Given a timeline containing a single decrypted remote event
        let timeline = vector![decrypted_event("session1")];
        // When we compute which indices to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then nothing needs decrypting
        assert!(answer.0.is_empty());
        // And the event's encryption info is scheduled for re-fetching
        assert_eq!(answer.1, vec![0]);
    }

    #[test]
    fn test_only_required_sessions_are_retried() {
        // Accept only "session1".
        fn retry(s: &str) -> bool {
            s == "session1"
        }

        // Given a mix of virtual items, UTDs, decrypted events and a local
        // echo, spread over two sessions
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        // When we compute which indices to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, retry);

        // Then only the "session1" UTDs are scheduled for decryption
        assert_eq!(answer.0, vec![1, 2]);
        // And only the "session1" decrypted events get their info re-fetched
        assert_eq!(answer.1, vec![5, 6]);
    }

    /// Session filter that accepts every session.
    fn always_retry(_: &str) -> bool {
        true
    }

    /// A virtual date-divider item.
    fn date_divider() -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Virtual(VirtualTimelineItem::DateDivider(timestamp())),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// A local echo that has not been sent yet, with redacted content.
    fn local_event() -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet,
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::MsgLike(MsgLikeContent::redacted()),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote unable-to-decrypt event encrypted with Megolm in `session_id`.
    /// It has no encryption info and no original JSON.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: None,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::MsgLike(MsgLikeContent::unable_to_decrypt(
                    EncryptedMessage::from_content(
                        RoomEncryptedEventContent::new(
                            EncryptedEventScheme::MegolmV1AesSha2(MegolmV1AesSha2Content::from(
                                MegolmV1AesSha2ContentInit {
                                    ciphertext: "cyf".to_owned(),
                                    sender_key: "sendk".to_owned(),
                                    device_id: owned_device_id!("DEV"),
                                    session_id: session_id.to_owned(),
                                },
                            )),
                            None,
                        ),
                        UtdCause::Unknown,
                    ),
                )),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote, successfully decrypted text message whose encryption info
    /// records `session_id`.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: Some(EncryptionInfo {
                sender: owned_user_id!("@u:s.co"),
                sender_device: None,
                algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                    curve25519_key: "".to_owned(),
                    sender_claimed_keys: BTreeMap::new(),
                },
                verification_state: VerificationState::Verified,
                session_id: Some(session_id.to_owned()),
            }),
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::message(
                    RoomMessageEventContent::text_plain("hi"),
                    None,
                    ReactionsByKeyBySender::default(),
                    None,
                    None,
                    None,
                ),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A fixed timestamp (the Unix epoch), so the fixtures are deterministic.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}