1use std::{collections::BTreeSet, sync::Arc};
16
17use imbl::Vector;
18use itertools::{Either, Itertools as _};
19use matrix_sdk::{
20 deserialized_responses::TimelineEventKind as SdkTimelineEventKind, executor::JoinHandle,
21};
22use tokio::sync::{
23 mpsc::{self, Receiver, Sender},
24 RwLock,
25};
26use tracing::{debug, error, field, info, info_span, Instrument as _};
27
28use crate::timeline::{
29 controller::{TimelineSettings, TimelineState},
30 event_item::EventTimelineItemKind,
31 traits::{Decryptor, RoomDataProvider},
32 EncryptedMessage, EventTimelineItem, TimelineItem, TimelineItemContent, TimelineItemKind,
33};
34
35#[derive(Clone, Debug)]
45pub struct DecryptionRetryTask<D: Decryptor> {
46 sender: Sender<DecryptionRetryRequest<D>>,
51
52 _task_handle: Arc<JoinHandle<()>>,
57}
58
/// Capacity of the bounded channel between `DecryptionRetryTask` handles and
/// the background decryption task. Once the channel is full,
/// `DecryptionRetryTask::decrypt` waits for a free slot before queueing.
const CHANNEL_BUFFER_SIZE: usize = 100;
63
64impl<D: Decryptor> DecryptionRetryTask<D> {
65 pub(crate) fn new<P: RoomDataProvider>(
66 state: Arc<RwLock<TimelineState>>,
67 room_data_provider: P,
68 ) -> Self {
69 let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
71
72 let handle =
75 matrix_sdk::executor::spawn(decryption_task(state, room_data_provider, receiver));
76
77 Self { sender, _task_handle: Arc::new(handle) }
79 }
80
81 pub(crate) async fn decrypt(
84 &self,
85 decryptor: D,
86 session_ids: Option<BTreeSet<String>>,
87 settings: TimelineSettings,
88 ) {
89 let res =
90 self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
91
92 if let Err(error) = res {
93 error!("Failed to send decryption retry request: {}", error);
94 }
95 }
96}
97
98struct DecryptionRetryRequest<D: Decryptor> {
101 decryptor: D,
102 session_ids: Option<BTreeSet<String>>,
103 settings: TimelineSettings,
104}
105
106async fn decryption_task<D: Decryptor>(
110 state: Arc<RwLock<TimelineState>>,
111 room_data_provider: impl RoomDataProvider,
112 mut receiver: Receiver<DecryptionRetryRequest<D>>,
113) {
114 debug!("Decryption task starting.");
115
116 while let Some(request) = receiver.recv().await {
117 let should_retry = |session_id: &str| {
118 if let Some(session_ids) = &request.session_ids {
119 session_ids.contains(session_id)
120 } else {
121 true
122 }
123 };
124
125 let mut state = state.write().await;
129 let (retry_decryption_indices, retry_info_indices) =
130 compute_event_indices_to_retry_decryption(&state.items, should_retry);
131
132 if !retry_info_indices.is_empty() {
134 debug!("Retrying fetching encryption info");
135 retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
136 }
137
138 if !retry_decryption_indices.is_empty() {
140 debug!("Retrying decryption");
141 decrypt_by_index(
142 &mut state,
143 &request.settings,
144 &room_data_provider,
145 request.decryptor,
146 should_retry,
147 retry_decryption_indices,
148 )
149 .await
150 }
151 }
152
153 debug!("Decryption task stopping.");
154}
155
156fn compute_event_indices_to_retry_decryption(
164 items: &Vector<Arc<TimelineItem>>,
165 should_retry: impl Fn(&str) -> bool,
166) -> (Vec<usize>, Vec<usize>) {
167 use Either::{Left, Right};
168
169 let should_retry_event = |event: &EventTimelineItem| {
171 let session_id =
172 if let TimelineItemContent::UnableToDecrypt(encrypted_message) = event.content() {
173 encrypted_message.session_id()
175 } else {
176 event
179 .as_remote()
180 .and_then(|remote| remote.encryption_info.as_ref()?.session_id.as_ref())
181 .map(String::as_str)
182 };
183
184 if let Some(session_id) = session_id {
185 should_retry(session_id)
187 } else {
188 false
190 }
191 };
192
193 items
194 .iter()
195 .enumerate()
196 .filter_map(|(idx, item)| {
197 item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
198 })
199 .partition_map(
201 |(idx, event)| {
202 if event.content().is_unable_to_decrypt() {
203 Left(idx)
204 } else {
205 Right(idx)
206 }
207 },
208 )
209}
210
211pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
214 state: &mut TimelineState,
215 retry_indices: Vec<usize>,
216 room_data_provider: &P,
217) {
218 for idx in retry_indices {
219 let old_item = state.items.get(idx);
220 if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
221 state.items.replace(idx, new_item);
222 }
223 }
224}
225
226async fn make_replacement_for<P: RoomDataProvider>(
230 room_data_provider: &P,
231 item: Option<&Arc<TimelineItem>>,
232) -> Option<Arc<TimelineItem>> {
233 let item = item?;
234 let event = item.as_event()?;
235 let remote = event.as_remote()?;
236 let session_id = remote.encryption_info.as_ref()?.session_id.as_deref()?;
237
238 let new_encryption_info =
239 room_data_provider.get_encryption_info(session_id, &event.sender).await;
240 let mut new_remote = remote.clone();
241 new_remote.encryption_info = new_encryption_info;
242 let new_item = item.with_kind(TimelineItemKind::Event(
243 event.with_kind(EventTimelineItemKind::Remote(new_remote)),
244 ));
245
246 Some(new_item)
247}
248
/// Retries decrypting the unable-to-decrypt items at `retry_indices`.
///
/// Builds a `retry_one` callback and hands it, along with the room's push-rules
/// context and `settings`, to `TimelineState::retry_event_decryption`. The
/// callback re-runs `decryptor` on an item's original JSON. It only does so for
/// Megolm UTDs whose session id passes `should_retry`. It returns the decrypted
/// event, or `None` to leave the item as it is.
async fn decrypt_by_index<D: Decryptor>(
    state: &mut TimelineState,
    settings: &TimelineSettings,
    room_data_provider: &impl RoomDataProvider,
    decryptor: D,
    should_retry: impl Fn(&str) -> bool,
    retry_indices: Vec<usize>,
) {
    let push_rules_context = room_data_provider.push_rules_and_context().await;
    let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();

    // Presumably invoked by `retry_event_decryption` for each item at
    // `retry_indices` — TODO confirm. Each call gets its own clones of the
    // decryptor and hook so the returned future can own them.
    let retry_one = |item: Arc<TimelineItem>| {
        let decryptor = decryptor.clone();
        let should_retry = &should_retry;
        let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
        async move {
            let event_item = item.as_event()?;

            // Only Megolm-encrypted UTDs from one of the requested sessions are
            // retried; everything else is left untouched.
            let session_id = match event_item.content().as_unable_to_decrypt()? {
                EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
                    if should_retry(session_id) =>
                {
                    session_id
                }
                EncryptedMessage::MegolmV1AesSha2 { .. }
                | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
                | EncryptedMessage::Unknown => return None,
            };

            // Fill in the fields declared empty on the `retry_one` span below.
            tracing::Span::current().record("session_id", session_id);

            let Some(remote_event) = event_item.as_remote() else {
                error!("Key for unable-to-decrypt timeline item is not an event ID");
                return None;
            };

            tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));

            // Decryption is re-run on the raw JSON stored with the remote item.
            let Some(original_json) = &remote_event.original_json else {
                error!("UTD item must contain original JSON");
                return None;
            };

            match decryptor.decrypt_event_impl(original_json).await {
                Ok(event) => {
                    // The decryptor can return successfully and still yield a
                    // UTD; that counts as a failure here.
                    if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
                        info!(
                            "Failed to decrypt event after receiving room key: {:?}",
                            utd_info.reason
                        );
                        None
                    } else {
                        // Notify the UTD hook (if one is configured) that this
                        // event has now been decrypted late.
                        if let Some(hook) = unable_to_decrypt_hook {
                            hook.on_late_decrypt(&remote_event.event_id).await;
                        }

                        Some(event)
                    }
                }
                Err(e) => {
                    info!("Failed to decrypt event after receiving room key: {e}");
                    None
                }
            }
        }
        .instrument(info_span!(
            "retry_one",
            session_id = field::Empty,
            event_id = field::Empty
        ))
    };

    state
        .retry_event_decryption(
            retry_one,
            retry_indices,
            push_rules_context,
            room_data_provider,
            settings,
        )
        .await;
}
334
/// Unit tests for `compute_event_indices_to_retry_decryption`, built on small
/// hand-made timelines.
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::{vector, Vector};
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id, MilliSecondsSinceUnixEpoch,
        OwnedTransactionId,
    };

    use crate::timeline::{
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
        EventSendState, EventTimelineItem, ReactionsByKeyBySender, TimelineDetails, TimelineItem,
        TimelineItemContent, TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
    };

    #[test]
    fn test_non_events_are_not_retried() {
        // Given a timeline containing only virtual items
        let timeline = vector![TimelineItem::read_marker(), date_divider()];
        // When we work out which items to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then none are selected for decryption...
        assert!(answer.0.is_empty());
        // ...and none for re-fetching encryption info
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_non_remote_events_are_not_retried() {
        // Given a timeline containing only a local (not yet sent) event
        let timeline = vector![local_event()];
        // When we work out which items to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then none are selected for decryption...
        assert!(answer.0.is_empty());
        // ...and none for re-fetching encryption info
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_utds_are_retried() {
        // Given a timeline containing a single UTD
        let timeline = vector![utd_event("session1")];
        // When we work out which items to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then the UTD is selected for decryption...
        assert_eq!(answer.0, vec![0]);
        // ...and nothing is selected for re-fetching encryption info
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        // Given a timeline containing a single decrypted remote event
        let timeline = vector![decrypted_event("session1")];
        // When we work out which items to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then nothing is selected for decryption...
        assert!(answer.0.is_empty());
        // ...and the event is selected for re-fetching its encryption info
        assert_eq!(answer.1, vec![0]);
    }

    #[test]
    fn test_only_required_sessions_are_retried() {
        // Only "session1" is requested.
        fn retry(s: &str) -> bool {
            s == "session1"
        }

        // Given a mix of virtual items, UTDs, decrypted events and a local
        // event, spread across two sessions
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        // When we work out which items to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, retry);

        // Then only the "session1" UTDs are selected for decryption...
        assert_eq!(answer.0, vec![1, 2]);
        // ...and only the "session1" decrypted events for re-fetching info
        assert_eq!(answer.1, vec![5, 6]);
    }

    /// Session filter that accepts every session.
    fn always_retry(_: &str) -> bool {
        true
    }

    /// A virtual date-divider item.
    fn date_divider() -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Virtual(VirtualTimelineItem::DateDivider(timestamp())),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// A local event item that has not been sent yet.
    fn local_event() -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet,
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::RedactedMessage,
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote unable-to-decrypt item encrypted with Megolm under `session_id`.
    /// It has no encryption info and no original JSON.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: None,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::unable_to_decrypt(
                    RoomEncryptedEventContent::new(
                        EncryptedEventScheme::MegolmV1AesSha2(MegolmV1AesSha2Content::from(
                            MegolmV1AesSha2ContentInit {
                                ciphertext: "cyf".to_owned(),
                                sender_key: "sendk".to_owned(),
                                device_id: owned_device_id!("DEV"),
                                session_id: session_id.to_owned(),
                            },
                        )),
                        None,
                    ),
                    UtdCause::Unknown,
                ),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote, successfully decrypted text message whose encryption info
    /// references `session_id`.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: Some(EncryptionInfo {
                sender: owned_user_id!("@u:s.co"),
                sender_device: None,
                algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                    curve25519_key: "".to_owned(),
                    sender_claimed_keys: BTreeMap::new(),
                },
                verification_state: VerificationState::Verified,
                session_id: Some(session_id.to_owned()),
            }),
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::message(
                    RoomMessageEventContent::text_plain("hi"),
                    None,
                    &Vector::new(),
                    ReactionsByKeyBySender::default(),
                ),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A fixed timestamp (the Unix epoch) so fixtures are deterministic.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}