matrix_sdk_ui/timeline/controller/
decryption_retry_task.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeSet, sync::Arc};
16
17use imbl::Vector;
18use itertools::{Either, Itertools as _};
19use matrix_sdk::{
20    deserialized_responses::TimelineEventKind as SdkTimelineEventKind, executor::JoinHandle,
21};
22use tokio::sync::{
23    mpsc::{self, Receiver, Sender},
24    RwLock,
25};
26use tracing::{debug, error, field, info, info_span, Instrument as _};
27
28use crate::timeline::{
29    controller::{TimelineSettings, TimelineState},
30    event_item::EventTimelineItemKind,
31    traits::{Decryptor, RoomDataProvider},
32    EncryptedMessage, EventTimelineItem, TimelineItem, TimelineItemKind,
33};
34
/// Holds a long-running task that is used to retry decryption of items in the
/// timeline when new information about a session is received.
///
/// Creating an instance with [`DecryptionRetryTask::new`] creates the async
/// task, and a channel that is used to communicate with it.
///
/// The underlying async task will stop soon after the [`DecryptionRetryTask`]
/// is dropped, because it waits for the channel to close, which happens when we
/// drop the sending side.
///
/// Cloning this struct clones the channel sender and shares the same task
/// handle, so every clone talks to the same task. The channel only closes, and
/// the task only stops, once the last clone has been dropped.
#[derive(Clone, Debug)]
pub struct DecryptionRetryTask<D: Decryptor> {
    /// The sending side of the channel that we have open to the long-running
    /// async task. Every time we want to retry decrypting some events, we
    /// send a [`DecryptionRetryRequest`] along this channel. Users of this
    /// struct call [`DecryptionRetryTask::decrypt`] to do this.
    sender: Sender<DecryptionRetryRequest<D>>,

    /// The join handle of the task. We don't actually use this, since the task
    /// will end soon after we are dropped, because when `sender` is dropped the
    /// task will see that the channel closed, but we hold on to the handle to
    /// indicate that we own the task. It is wrapped in an [`Arc`] so that
    /// clones of this struct share the one handle.
    _task_handle: Arc<JoinHandle<()>>,
}
58
/// Capacity of the channel that carries retry requests to the long-running
/// task.
///
/// Once this many requests are queued, [`DecryptionRetryTask::decrypt`] waits
/// (asynchronously) until the task has taken one off the queue. We don't
/// normally expect more than one or two to be queued at a time, so waiting
/// should be a rare occurrence.
const CHANNEL_BUFFER_SIZE: usize = 100;
63
64impl<D: Decryptor> DecryptionRetryTask<D> {
65    pub(crate) fn new<P: RoomDataProvider>(
66        state: Arc<RwLock<TimelineState>>,
67        room_data_provider: P,
68    ) -> Self {
69        // We will send decryption requests down this channel to the long-running task
70        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
71
72        // Spawn the long-running task, providing the receiver so we can listen for
73        // decryption requests
74        let handle =
75            matrix_sdk::executor::spawn(decryption_task(state, room_data_provider, receiver));
76
77        // Keep hold of the sender so we can send off decryption requests to the task.
78        Self { sender, _task_handle: Arc::new(handle) }
79    }
80
81    /// Use the supplied decryptor to attempt redecryption of the events
82    /// associated with the supplied session IDs.
83    pub(crate) async fn decrypt(
84        &self,
85        decryptor: D,
86        session_ids: Option<BTreeSet<String>>,
87        settings: TimelineSettings,
88    ) {
89        let res =
90            self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
91
92        if let Err(error) = res {
93            error!("Failed to send decryption retry request: {}", error);
94        }
95    }
96}
97
/// The information sent across the channel to the long-running task requesting
/// that the supplied set of sessions be retried.
struct DecryptionRetryRequest<D: Decryptor> {
    /// Used to attempt decryption of each unable-to-decrypt event that is
    /// selected for retry.
    decryptor: D,

    /// The IDs of the sessions whose events should be retried. `None` means
    /// that events from every session are retried.
    session_ids: Option<BTreeSet<String>>,

    /// Timeline settings, passed through to
    /// `TimelineState::retry_event_decryption` when retrying decryption.
    settings: TimelineSettings,
}
105
106/// Long-running task that waits for decryption requests to come through the
107/// supplied channel `receiver` and act on them. Stops when the channel is
108/// closed, i.e. when the sender side is dropped.
109async fn decryption_task<D: Decryptor>(
110    state: Arc<RwLock<TimelineState>>,
111    room_data_provider: impl RoomDataProvider,
112    mut receiver: Receiver<DecryptionRetryRequest<D>>,
113) {
114    debug!("Decryption task starting.");
115
116    while let Some(request) = receiver.recv().await {
117        let should_retry = |session_id: &str| {
118            if let Some(session_ids) = &request.session_ids {
119                session_ids.contains(session_id)
120            } else {
121                true
122            }
123        };
124
125        // Find the indices of events that are in the supplied sessions, distinguishing
126        // between UTDs which we need to decrypt, and already-decrypted events where we
127        // only need to re-fetch encryption info.
128        let mut state = state.write().await;
129        let (retry_decryption_indices, retry_info_indices) =
130            compute_event_indices_to_retry_decryption(&state.items, should_retry);
131
132        // Retry fetching encryption info for events that are already decrypted
133        if !retry_info_indices.is_empty() {
134            debug!("Retrying fetching encryption info");
135            retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
136        }
137
138        // Retry decrypting any unable-to-decrypt messages
139        if !retry_decryption_indices.is_empty() {
140            debug!("Retrying decryption");
141            decrypt_by_index(
142                &mut state,
143                &request.settings,
144                &room_data_provider,
145                request.decryptor,
146                should_retry,
147                retry_decryption_indices,
148            )
149            .await
150        }
151    }
152
153    debug!("Decryption task stopping.");
154}
155
156/// Decide which events should be retried, either for re-decryption, or, if they
157/// are already decrypted, for re-checking their encryption info.
158///
159/// Returns a tuple `(retry_decryption_indices, retry_info_indices)` where
160/// `retry_decryption_indices` is a list of the indices of UTDs to try
161/// decrypting, and retry_info_indices is a list of the indices of
162/// already-decrypted events whose encryption info we can re-fetch.
163fn compute_event_indices_to_retry_decryption(
164    items: &Vector<Arc<TimelineItem>>,
165    should_retry: impl Fn(&str) -> bool,
166) -> (Vec<usize>, Vec<usize>) {
167    use Either::{Left, Right};
168
169    // We retry an event if its session ID should be retried
170    let should_retry_event = |event: &EventTimelineItem| {
171        let session_id = if let Some(encrypted_message) = event.content().as_unable_to_decrypt() {
172            // UTDs carry their session ID inside the content
173            encrypted_message.session_id()
174        } else {
175            // Non-UTDs only have a session ID if they are remote and have it in the
176            // EncryptionInfo
177            event
178                .as_remote()
179                .and_then(|remote| remote.encryption_info.as_ref()?.session_id.as_ref())
180                .map(String::as_str)
181        };
182
183        if let Some(session_id) = session_id {
184            // Should we retry this session ID?
185            should_retry(session_id)
186        } else {
187            // No session ID: don't retry this event
188            false
189        }
190    };
191
192    items
193        .iter()
194        .enumerate()
195        .filter_map(|(idx, item)| {
196            item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
197        })
198        // Break the result into 2 lists: (utds, decrypted)
199        .partition_map(
200            |(idx, event)| {
201                if event.content().is_unable_to_decrypt() {
202                    Left(idx)
203                } else {
204                    Right(idx)
205                }
206            },
207        )
208}
209
210/// Try to fetch [`EncryptionInfo`] for the events with the supplied
211/// indices, and update them where we succeed.
212pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
213    state: &mut TimelineState,
214    retry_indices: Vec<usize>,
215    room_data_provider: &P,
216) {
217    for idx in retry_indices {
218        let old_item = state.items.get(idx);
219        if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
220            state.items.replace(idx, new_item);
221        }
222    }
223}
224
225/// Create a replacement TimelineItem for the supplied one, with new
226/// [`EncryptionInfo`] from the supplied `room_data_provider`. Returns None if
227/// the supplied item is not a remote event, or if it doesn't have a session ID.
228async fn make_replacement_for<P: RoomDataProvider>(
229    room_data_provider: &P,
230    item: Option<&Arc<TimelineItem>>,
231) -> Option<Arc<TimelineItem>> {
232    let item = item?;
233    let event = item.as_event()?;
234    let remote = event.as_remote()?;
235    let session_id = remote.encryption_info.as_ref()?.session_id.as_deref()?;
236
237    let new_encryption_info =
238        room_data_provider.get_encryption_info(session_id, &event.sender).await;
239    let mut new_remote = remote.clone();
240    new_remote.encryption_info = new_encryption_info;
241    let new_item = item.with_kind(TimelineItemKind::Event(
242        event.with_kind(EventTimelineItemKind::Remote(new_remote)),
243    ));
244
245    Some(new_item)
246}
247
/// Attempt to decrypt the unable-to-decrypt (UTD) timeline items found at
/// `retry_indices`.
///
/// This builds a per-item `retry_one` closure and hands it, together with
/// `retry_indices`, `room_data_provider` and `settings`, to
/// `TimelineState::retry_event_decryption`, which drives the retries. How that
/// method applies the results to the timeline is not visible from here.
///
/// For a given item, `retry_one` resolves to `Some(event)` only if all of the
/// following hold, and to `None` otherwise:
///
/// * the item is an event whose content is a Megolm UTD, and `should_retry`
///   accepts its session ID;
/// * the event is a remote event that still has its original JSON;
/// * `decryptor` returns an event whose kind is not `UnableToDecrypt`.
///
/// On success, the `unable_to_decrypt_hook` stored in `state.meta` (if any) is
/// told about the late decryption before the event is returned.
async fn decrypt_by_index<D: Decryptor>(
    state: &mut TimelineState,
    settings: &TimelineSettings,
    room_data_provider: &impl RoomDataProvider,
    decryptor: D,
    should_retry: impl Fn(&str) -> bool,
    retry_indices: Vec<usize>,
) {
    // Fetched once up front and shared by every retry below.
    let push_ctx = room_data_provider.push_context().await;
    let push_ctx = push_ctx.as_ref();
    let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();

    let retry_one = |item: Arc<TimelineItem>| {
        // Each invocation gets its own decryptor and hook, because the
        // `async move` block below takes ownership of what it captures.
        let decryptor = decryptor.clone();
        // Rebind as a reference so that the `async move` block captures the
        // reference rather than moving the predicate itself.
        let should_retry = &should_retry;
        let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
        async move {
            let event_item = item.as_event()?;

            // Only Megolm UTDs whose session was asked for are retried; every
            // other kind of content is skipped.
            let session_id = match event_item.content().as_unable_to_decrypt()? {
                EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
                    if should_retry(session_id) =>
                {
                    session_id
                }
                EncryptedMessage::MegolmV1AesSha2 { .. }
                | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
                | EncryptedMessage::Unknown => return None,
            };

            // Fill in the `session_id` field declared empty on the span below.
            tracing::Span::current().record("session_id", session_id);

            let Some(remote_event) = event_item.as_remote() else {
                error!("Key for unable-to-decrypt timeline item is not an event ID");
                return None;
            };

            // Likewise for the span's `event_id` field.
            tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));

            let Some(original_json) = &remote_event.original_json else {
                error!("UTD item must contain original JSON");
                return None;
            };

            match decryptor.decrypt_event_impl(original_json, push_ctx).await {
                Ok(event) => {
                    // An `Ok` result can still be an unable-to-decrypt event,
                    // which counts as a failed retry.
                    if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
                        info!(
                            "Failed to decrypt event after receiving room key: {:?}",
                            utd_info.reason
                        );
                        None
                    } else {
                        // Notify observers that we managed to eventually decrypt an event.
                        if let Some(hook) = unable_to_decrypt_hook {
                            hook.on_late_decrypt(&remote_event.event_id).await;
                        }

                        Some(event)
                    }
                }
                Err(e) => {
                    info!("Failed to decrypt event after receiving room key: {e}");
                    None
                }
            }
        }
        // Run each retry inside its own span; its fields are recorded above
        // once they are known.
        .instrument(info_span!(
            "retry_one",
            session_id = field::Empty,
            event_id = field::Empty
        ))
    };

    state.retry_event_decryption(retry_one, retry_indices, room_data_provider, settings).await;
}
326
/// Unit tests for `compute_event_indices_to_retry_decryption`, checking which
/// timeline items are selected for re-decryption and which for re-fetching of
/// encryption info.
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::vector;
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id, MilliSecondsSinceUnixEpoch,
        OwnedTransactionId,
    };

    use crate::timeline::{
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
        EncryptedMessage, EventSendState, EventTimelineItem, MsgLikeContent,
        ReactionsByKeyBySender, TimelineDetails, TimelineItem, TimelineItemContent,
        TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
    };

    #[test]
    fn test_non_events_are_not_retried() {
        // Given a timeline with only non-events
        let timeline = vector![TimelineItem::read_marker(), date_divider()];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry nothing
        assert!(answer.0.is_empty());
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_non_remote_events_are_not_retried() {
        // Given a timeline with only local events
        let timeline = vector![local_event()];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry nothing
        assert!(answer.0.is_empty());
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_utds_are_retried() {
        // Given a timeline with a UTD
        let timeline = vector![utd_event("session1")];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry decrypting it, and don't refetch any encryption info
        assert_eq!(answer.0, vec![0]);
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        // Given a timeline with a decrypted event
        let timeline = vector![decrypted_event("session1")];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we don't need to decrypt anything, but we do refetch the encryption info
        assert!(answer.0.is_empty());
        assert_eq!(answer.1, vec![0]);
    }

    #[test]
    fn test_only_required_sessions_are_retried() {
        // Given we want to retry everything in session1 only

        fn retry(s: &str) -> bool {
            s == "session1"
        }

        // And we have a timeline containing non-events, local events, UTDs and
        // decrypted events
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, retry);

        // Then we re-decrypt the UTDs, and refetch the decrypted events' info
        assert_eq!(answer.0, vec![1, 2]);
        assert_eq!(answer.1, vec![5, 6]);
    }

    /// A retry predicate that accepts every session ID.
    fn always_retry(_: &str) -> bool {
        true
    }

    /// Build a virtual (non-event) timeline item: a date divider stamped with
    /// [`timestamp`].
    fn date_divider() -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Virtual(VirtualTimelineItem::DateDivider(timestamp())),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// Build a local event item that has not been sent yet. Its content is a
    /// redacted message, so it is neither a UTD nor a remote event.
    fn local_event() -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet,
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::MsgLike(MsgLikeContent::redacted()),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// Build a remote event item whose content is an unable-to-decrypt Megolm
    /// message for `session_id`.
    ///
    /// The item has no `encryption_info`, so the session ID is only available
    /// from the content. The event ID and unique ID are fixed placeholder
    /// values, shared by every item this helper builds.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: None,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::MsgLike(MsgLikeContent::unable_to_decrypt(
                    EncryptedMessage::from_content(
                        RoomEncryptedEventContent::new(
                            EncryptedEventScheme::MegolmV1AesSha2(MegolmV1AesSha2Content::from(
                                MegolmV1AesSha2ContentInit {
                                    ciphertext: "cyf".to_owned(),
                                    sender_key: "sendk".to_owned(),
                                    device_id: owned_device_id!("DEV"),
                                    session_id: session_id.to_owned(),
                                },
                            )),
                            None,
                        ),
                        UtdCause::Unknown,
                    ),
                )),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// Build a remote event item with plain-text message content and
    /// `encryption_info` that records `session_id`, i.e. an event that has
    /// already been decrypted.
    ///
    /// The event ID and unique ID are fixed placeholder values, shared by every
    /// item this helper builds.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: Some(EncryptionInfo {
                sender: owned_user_id!("@u:s.co"),
                sender_device: None,
                algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                    curve25519_key: "".to_owned(),
                    sender_claimed_keys: BTreeMap::new(),
                },
                verification_state: VerificationState::Verified,
                session_id: Some(session_id.to_owned()),
            }),
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::message(
                    RoomMessageEventContent::text_plain("hi"),
                    None,
                    ReactionsByKeyBySender::default(),
                    None,
                    None,
                    None,
                ),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A fixed timestamp (the Unix epoch) used by every fixture.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}