matrix_sdk_ui/timeline/controller/
decryption_retry_task.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeSet, sync::Arc};
16
17use imbl::Vector;
18use itertools::{Either, Itertools as _};
19use matrix_sdk::{
20    deserialized_responses::TimelineEventKind as SdkTimelineEventKind, executor::JoinHandle,
21};
22use tokio::sync::{
23    mpsc::{self, Receiver, Sender},
24    RwLock,
25};
26use tracing::{debug, error, field, info, info_span, Instrument as _};
27
28use crate::timeline::{
29    controller::{TimelineSettings, TimelineState},
30    event_item::EventTimelineItemKind,
31    traits::{Decryptor, RoomDataProvider},
32    EncryptedMessage, EventTimelineItem, TimelineItem, TimelineItemContent, TimelineItemKind,
33};
34
35/// Holds a long-running task that is used to retry decryption of items in the
36/// timeline when new information about a session is received.
37///
38/// Creating an instance with [`DecryptionRetryTask::new`] creates the async
39/// task, and a channel that is used to communicate with it.
40///
41/// The underlying async task will stop soon after the [`DecryptionRetryTask`]
42/// is dropped, because it waits for the channel to close, which happens when we
43/// drop the sending side.
44#[derive(Clone, Debug)]
45pub struct DecryptionRetryTask<D: Decryptor> {
46    /// The sending side of the channel that we have open to the long-running
47    /// async task. Every time we want to retry decrypting some events, we
48    /// send a [`DecryptionRetryRequest`] along this channel. Users of this
49    /// struct call [`DecryptionRetryTask::decrypt`] to do this.
50    sender: Sender<DecryptionRetryRequest<D>>,
51
52    /// The join handle of the task. We don't actually use this, since the task
53    /// will end soon after we are dropped, because when `sender` is dropped the
54    /// task will see that the channel closed, but we hold on to the handle to
55    /// indicate that we own the task.
56    _task_handle: Arc<JoinHandle<()>>,
57}
58
/// Capacity of the request channel: the number of retry requests that may sit
/// in the queue before a further attempt to enqueue one has to wait. Only one
/// or two requests are normally expected to be pending at any time, so having
/// to wait should be rare.
const CHANNEL_BUFFER_SIZE: usize = 100;
63
64impl<D: Decryptor> DecryptionRetryTask<D> {
65    pub(crate) fn new<P: RoomDataProvider>(
66        state: Arc<RwLock<TimelineState>>,
67        room_data_provider: P,
68    ) -> Self {
69        // We will send decryption requests down this channel to the long-running task
70        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
71
72        // Spawn the long-running task, providing the receiver so we can listen for
73        // decryption requests
74        let handle =
75            matrix_sdk::executor::spawn(decryption_task(state, room_data_provider, receiver));
76
77        // Keep hold of the sender so we can send off decryption requests to the task.
78        Self { sender, _task_handle: Arc::new(handle) }
79    }
80
81    /// Use the supplied decryptor to attempt redecryption of the events
82    /// associated with the supplied session IDs.
83    pub(crate) async fn decrypt(
84        &self,
85        decryptor: D,
86        session_ids: Option<BTreeSet<String>>,
87        settings: TimelineSettings,
88    ) {
89        let res =
90            self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
91
92        if let Err(error) = res {
93            error!("Failed to send decryption retry request: {}", error);
94        }
95    }
96}
97
98/// The information sent across the channel to the long-running task requesting
99/// that the supplied set of sessions be retried.
100struct DecryptionRetryRequest<D: Decryptor> {
101    decryptor: D,
102    session_ids: Option<BTreeSet<String>>,
103    settings: TimelineSettings,
104}
105
106/// Long-running task that waits for decryption requests to come through the
107/// supplied channel `receiver` and act on them. Stops when the channel is
108/// closed, i.e. when the sender side is dropped.
109async fn decryption_task<D: Decryptor>(
110    state: Arc<RwLock<TimelineState>>,
111    room_data_provider: impl RoomDataProvider,
112    mut receiver: Receiver<DecryptionRetryRequest<D>>,
113) {
114    debug!("Decryption task starting.");
115
116    while let Some(request) = receiver.recv().await {
117        let should_retry = |session_id: &str| {
118            if let Some(session_ids) = &request.session_ids {
119                session_ids.contains(session_id)
120            } else {
121                true
122            }
123        };
124
125        // Find the indices of events that are in the supplied sessions, distinguishing
126        // between UTDs which we need to decrypt, and already-decrypted events where we
127        // only need to re-fetch encryption info.
128        let mut state = state.write().await;
129        let (retry_decryption_indices, retry_info_indices) =
130            compute_event_indices_to_retry_decryption(&state.items, should_retry);
131
132        // Retry fetching encryption info for events that are already decrypted
133        if !retry_info_indices.is_empty() {
134            debug!("Retrying fetching encryption info");
135            retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
136        }
137
138        // Retry decrypting any unable-to-decrypt messages
139        if !retry_decryption_indices.is_empty() {
140            debug!("Retrying decryption");
141            decrypt_by_index(
142                &mut state,
143                &request.settings,
144                &room_data_provider,
145                request.decryptor,
146                should_retry,
147                retry_decryption_indices,
148            )
149            .await
150        }
151    }
152
153    debug!("Decryption task stopping.");
154}
155
156/// Decide which events should be retried, either for re-decryption, or, if they
157/// are already decrypted, for re-checking their encryption info.
158///
159/// Returns a tuple `(retry_decryption_indices, retry_info_indices)` where
160/// `retry_decryption_indices` is a list of the indices of UTDs to try
161/// decrypting, and retry_info_indices is a list of the indices of
162/// already-decrypted events whose encryption info we can re-fetch.
163fn compute_event_indices_to_retry_decryption(
164    items: &Vector<Arc<TimelineItem>>,
165    should_retry: impl Fn(&str) -> bool,
166) -> (Vec<usize>, Vec<usize>) {
167    use Either::{Left, Right};
168
169    // We retry an event if its session ID should be retried
170    let should_retry_event = |event: &EventTimelineItem| {
171        let session_id =
172            if let TimelineItemContent::UnableToDecrypt(encrypted_message) = event.content() {
173                // UTDs carry their session ID inside the content
174                encrypted_message.session_id()
175            } else {
176                // Non-UTDs only have a session ID if they are remote and have it in the
177                // EncryptionInfo
178                event
179                    .as_remote()
180                    .and_then(|remote| remote.encryption_info.as_ref()?.session_id.as_ref())
181                    .map(String::as_str)
182            };
183
184        if let Some(session_id) = session_id {
185            // Should we retry this session ID?
186            should_retry(session_id)
187        } else {
188            // No session ID: don't retry this event
189            false
190        }
191    };
192
193    items
194        .iter()
195        .enumerate()
196        .filter_map(|(idx, item)| {
197            item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
198        })
199        // Break the result into 2 lists: (utds, decrypted)
200        .partition_map(
201            |(idx, event)| {
202                if event.content().is_unable_to_decrypt() {
203                    Left(idx)
204                } else {
205                    Right(idx)
206                }
207            },
208        )
209}
210
211/// Try to fetch [`EncryptionInfo`] for the events with the supplied
212/// indices, and update them where we succeed.
213pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
214    state: &mut TimelineState,
215    retry_indices: Vec<usize>,
216    room_data_provider: &P,
217) {
218    for idx in retry_indices {
219        let old_item = state.items.get(idx);
220        if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
221            state.items.replace(idx, new_item);
222        }
223    }
224}
225
226/// Create a replacement TimelineItem for the supplied one, with new
227/// [`EncryptionInfo`] from the supplied `room_data_provider`. Returns None if
228/// the supplied item is not a remote event, or if it doesn't have a session ID.
229async fn make_replacement_for<P: RoomDataProvider>(
230    room_data_provider: &P,
231    item: Option<&Arc<TimelineItem>>,
232) -> Option<Arc<TimelineItem>> {
233    let item = item?;
234    let event = item.as_event()?;
235    let remote = event.as_remote()?;
236    let session_id = remote.encryption_info.as_ref()?.session_id.as_deref()?;
237
238    let new_encryption_info =
239        room_data_provider.get_encryption_info(session_id, &event.sender).await;
240    let mut new_remote = remote.clone();
241    new_remote.encryption_info = new_encryption_info;
242    let new_item = item.with_kind(TimelineItemKind::Event(
243        event.with_kind(EventTimelineItemKind::Remote(new_remote)),
244    ));
245
246    Some(new_item)
247}
248
249/// Attempt decryption of the events encrypted with the session IDs in the
250/// supplied decryption `request`.
251async fn decrypt_by_index<D: Decryptor>(
252    state: &mut TimelineState,
253    settings: &TimelineSettings,
254    room_data_provider: &impl RoomDataProvider,
255    decryptor: D,
256    should_retry: impl Fn(&str) -> bool,
257    retry_indices: Vec<usize>,
258) {
259    let push_rules_context = room_data_provider.push_rules_and_context().await;
260    let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();
261
262    let retry_one = |item: Arc<TimelineItem>| {
263        let decryptor = decryptor.clone();
264        let should_retry = &should_retry;
265        let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
266        async move {
267            let event_item = item.as_event()?;
268
269            let session_id = match event_item.content().as_unable_to_decrypt()? {
270                EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
271                    if should_retry(session_id) =>
272                {
273                    session_id
274                }
275                EncryptedMessage::MegolmV1AesSha2 { .. }
276                | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
277                | EncryptedMessage::Unknown => return None,
278            };
279
280            tracing::Span::current().record("session_id", session_id);
281
282            let Some(remote_event) = event_item.as_remote() else {
283                error!("Key for unable-to-decrypt timeline item is not an event ID");
284                return None;
285            };
286
287            tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));
288
289            let Some(original_json) = &remote_event.original_json else {
290                error!("UTD item must contain original JSON");
291                return None;
292            };
293
294            match decryptor.decrypt_event_impl(original_json).await {
295                Ok(event) => {
296                    if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
297                        info!(
298                            "Failed to decrypt event after receiving room key: {:?}",
299                            utd_info.reason
300                        );
301                        None
302                    } else {
303                        // Notify observers that we managed to eventually decrypt an event.
304                        if let Some(hook) = unable_to_decrypt_hook {
305                            hook.on_late_decrypt(&remote_event.event_id).await;
306                        }
307
308                        Some(event)
309                    }
310                }
311                Err(e) => {
312                    info!("Failed to decrypt event after receiving room key: {e}");
313                    None
314                }
315            }
316        }
317        .instrument(info_span!(
318            "retry_one",
319            session_id = field::Empty,
320            event_id = field::Empty
321        ))
322    };
323
324    state
325        .retry_event_decryption(
326            retry_one,
327            retry_indices,
328            push_rules_context,
329            room_data_provider,
330            settings,
331        )
332        .await;
333}
334
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::{vector, Vector};
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id, MilliSecondsSinceUnixEpoch,
        OwnedTransactionId,
    };

    use crate::timeline::{
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
        EventSendState, EventTimelineItem, ReactionsByKeyBySender, TimelineDetails, TimelineItem,
        TimelineItemContent, TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
    };

    #[test]
    fn test_non_events_are_not_retried() {
        // Given a timeline made up of virtual items only
        let timeline = vector![TimelineItem::read_marker(), date_divider()];

        // When we compute what should be retried
        let (to_decrypt, to_refetch) =
            compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then nothing is selected
        assert!(to_decrypt.is_empty());
        assert!(to_refetch.is_empty());
    }

    #[test]
    fn test_non_remote_events_are_not_retried() {
        // Given a timeline made up of local events only
        let timeline = vector![local_event()];

        // When we compute what should be retried
        let (to_decrypt, to_refetch) =
            compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then nothing is selected
        assert!(to_decrypt.is_empty());
        assert!(to_refetch.is_empty());
    }

    #[test]
    fn test_utds_are_retried() {
        // Given a timeline holding a single UTD
        let timeline = vector![utd_event("session1")];

        // When we compute what should be retried
        let (to_decrypt, to_refetch) =
            compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then the UTD is retried, and no encryption info is re-fetched
        assert_eq!(to_decrypt, vec![0]);
        assert!(to_refetch.is_empty());
    }

    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        // Given a timeline holding a single decrypted event
        let timeline = vector![decrypted_event("session1")];

        // When we compute what should be retried
        let (to_decrypt, to_refetch) =
            compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then there is nothing to decrypt, but the encryption info is re-fetched
        assert!(to_decrypt.is_empty());
        assert_eq!(to_refetch, vec![0]);
    }

    #[test]
    fn test_only_required_sessions_are_retried() {
        // Given that only session1 should be retried
        let only_session1 = |session_id: &str| session_id == "session1";

        // And a timeline mixing virtual items, local events, UTDs and decrypted
        // events
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        // When we compute what should be retried
        let (to_decrypt, to_refetch) =
            compute_event_indices_to_retry_decryption(&timeline, only_session1);

        // Then session1's UTDs are re-decrypted, and session1's decrypted events
        // get their info re-fetched
        assert_eq!(to_decrypt, vec![1, 2]);
        assert_eq!(to_refetch, vec![5, 6]);
    }

    /// A retry predicate that accepts every session.
    fn always_retry(_: &str) -> bool {
        true
    }

    /// A virtual date divider item.
    fn date_divider() -> Arc<TimelineItem> {
        let divider = VirtualTimelineItem::DateDivider(timestamp());
        TimelineItem::new(
            TimelineItemKind::Virtual(divider),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// A local (not yet sent) event.
    fn local_event() -> Arc<TimelineItem> {
        let kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet,
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        event_item(TimelineItemContent::RedactedMessage, kind)
    }

    /// A remote unable-to-decrypt event belonging to the given Megolm session.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let scheme = EncryptedEventScheme::MegolmV1AesSha2(MegolmV1AesSha2Content::from(
            MegolmV1AesSha2ContentInit {
                ciphertext: "cyf".to_owned(),
                sender_key: "sendk".to_owned(),
                device_id: owned_device_id!("DEV"),
                session_id: session_id.to_owned(),
            },
        ));
        let content = TimelineItemContent::unable_to_decrypt(
            RoomEncryptedEventContent::new(scheme, None),
            UtdCause::Unknown,
        );

        event_item(content, remote_kind(None))
    }

    /// A remote, successfully decrypted event whose encryption info names the
    /// given session.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let encryption_info = EncryptionInfo {
            sender: owned_user_id!("@u:s.co"),
            sender_device: None,
            algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                curve25519_key: "".to_owned(),
                sender_claimed_keys: BTreeMap::new(),
            },
            verification_state: VerificationState::Verified,
            session_id: Some(session_id.to_owned()),
        };
        let content = TimelineItemContent::message(
            RoomMessageEventContent::text_plain("hi"),
            None,
            &Vector::new(),
            ReactionsByKeyBySender::default(),
        );

        event_item(content, remote_kind(Some(encryption_info)))
    }

    /// The remote event kind shared by the UTD and decrypted fixtures, which
    /// only differ in their encryption info.
    fn remote_kind(encryption_info: Option<EncryptionInfo>) -> EventTimelineItemKind {
        EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        })
    }

    /// Wrap the given content and kind into an event timeline item, using the
    /// fixed sender, profile, timestamp and unique ID shared by all event
    /// fixtures.
    fn event_item(content: TimelineItemContent, kind: EventTimelineItemKind) -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                content,
                kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// The Unix epoch as a Matrix timestamp.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}
553}