matrix_sdk/encryption/
tasks.rs

1// Copyright 2023-2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeMap, sync::Arc, time::Duration};
16
17use matrix_sdk_common::failures_cache::FailuresCache;
18use ruma::{
19    events::room::encrypted::{EncryptedEventScheme, OriginalSyncRoomEncryptedEvent},
20    serde::Raw,
21    OwnedEventId, OwnedRoomId,
22};
23use tokio::sync::{
24    mpsc::{self, UnboundedReceiver},
25    Mutex,
26};
27use tracing::{debug, trace, warn};
28
29use crate::{
30    client::WeakClient,
31    encryption::backups::UploadState,
32    executor::{spawn, JoinHandle},
33    Client,
34};
35
/// A cache of room keys we have already downloaded from the backup, or are
/// about to download.
///
/// Entries are [`RoomKeyInfo`] pairs; the [`FailuresCache`] type is reused as
/// the underlying container.
type DownloadCache = FailuresCache<RoomKeyInfo>;
38
/// Container for the handles of encryption-related background tasks.
///
/// Every field is optional; the derived [`Default`] leaves them all unset.
#[derive(Default)]
pub(crate) struct ClientTasks {
    /// Task that backs up room keys whenever it is triggered.
    #[cfg(feature = "e2e-encryption")]
    pub(crate) upload_room_keys: Option<BackupUploadingTask>,
    /// Task that fetches room keys from the backup for events we could not
    /// decrypt.
    #[cfg(feature = "e2e-encryption")]
    pub(crate) download_room_keys: Option<BackupDownloadTask>,
    /// NOTE(review): presumably a task refreshing the recovery state once a
    /// backup has happened; it is spawned outside this file, confirm there.
    #[cfg(feature = "e2e-encryption")]
    pub(crate) update_recovery_state_after_backup: Option<JoinHandle<()>>,
    /// NOTE(review): presumably the task performing the end-to-end encryption
    /// setup; it is spawned outside this file, confirm there.
    pub(crate) setup_e2ee: Option<JoinHandle<()>>,
}
49
/// Handle to a background task that backs up room keys on demand.
///
/// Dropping this handle aborts the task, except on `wasm32` targets where the
/// `Drop` implementation does nothing.
#[cfg(feature = "e2e-encryption")]
pub(crate) struct BackupUploadingTask {
    /// Channel used to wake up the listener; each `()` message requests one
    /// backup run.
    sender: mpsc::UnboundedSender<()>,
    /// Handle of the spawned listener task. It is only read by the `Drop`
    /// implementation, and not at all on `wasm32`, hence the `dead_code`
    /// allowance.
    #[allow(dead_code)]
    join_handle: JoinHandle<()>,
}
56
#[cfg(feature = "e2e-encryption")]
impl Drop for BackupUploadingTask {
    /// Stop the background listener when the handle goes away.
    ///
    /// The abort is compiled out on `wasm32`, where this is a no-op.
    fn drop(&mut self) {
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.join_handle.abort();
        }
    }
}
64
#[cfg(feature = "e2e-encryption")]
impl BackupUploadingTask {
    /// Spawn the listener task and return a handle that can trigger uploads.
    ///
    /// # Arguments
    ///
    /// * `client` - A weak reference to the `Client` used to perform the
    ///   backups.
    pub(crate) fn new(client: WeakClient) -> Self {
        let (sender, receiver) = mpsc::unbounded_channel();
        let join_handle = spawn(Self::listen(client, receiver));

        Self { sender, join_handle }
    }

    /// Ask the listener task to run a backup of the room keys.
    ///
    /// A failure to send (the listener is gone) is deliberately ignored.
    pub(crate) fn trigger_upload(&self) {
        let _ = self.sender.send(());
    }

    /// Wait for upload triggers and back up the room keys for each of them.
    ///
    /// Runs until the trigger channel is closed, or until the `Client` can no
    /// longer be obtained from the weak reference.
    ///
    /// # Arguments
    ///
    /// * `client` - A weak reference to the `Client` used to perform the
    ///   backups.
    /// * `receiver` - The source of upload triggers.
    pub(crate) async fn listen(client: WeakClient, mut receiver: UnboundedReceiver<()>) {
        while let Some(()) = receiver.recv().await {
            let Some(strong_client) = client.get() else {
                trace!("Client got dropped, shutting down the task");
                break;
            };

            let progress = &strong_client.inner.e2ee.backup_state.upload_progress;

            let outcome = strong_client.encryption().backups().backup_room_keys().await;
            if let Err(e) = outcome {
                progress.set(UploadState::Error);
                warn!("Error backing up room keys {e:?}");
                // Falling through to the `Idle` update below is intentional:
                // *every* single state update is propagated to the caller, so
                // observers get to see `Error` followed by `Idle`.
            }

            progress.set(UploadState::Idle);
        }
    }
}
102
/// Information about a request for a backup download for an undecryptable
/// event.
///
/// `Debug` is derived so that a request can be attached to log statements.
#[derive(Debug)]
struct RoomKeyDownloadRequest {
    /// The room in which the event was sent.
    room_id: OwnedRoomId,

    /// The ID of the event we could not decrypt.
    ///
    /// Used as the key under which the task handling this request is tracked.
    event_id: OwnedEventId,

    /// The event we could not decrypt, in its raw (still serialized) form.
    event: Raw<OriginalSyncRoomEncryptedEvent>,

    /// The unique ID of the room key (Megolm session) that the event was
    /// encrypted with.
    megolm_session_id: String,
}
119
120impl RoomKeyDownloadRequest {
121    pub fn to_room_key_info(&self) -> RoomKeyInfo {
122        (self.room_id.clone(), self.megolm_session_id.clone())
123    }
124}
125
/// Identifies a room key: the ID of the room it belongs to, paired with the
/// ID of the Megolm session.
pub type RoomKeyInfo = (OwnedRoomId, String);
127
/// Handle to a background task that downloads room keys from the backup for
/// events we were unable to decrypt.
///
/// Dropping this handle aborts the listener task, except on `wasm32` targets
/// where the `Drop` implementation does nothing.
pub(crate) struct BackupDownloadTask {
    /// Channel used to hand download requests over to the listener task.
    sender: mpsc::UnboundedSender<RoomKeyDownloadRequest>,
    /// Handle of the spawned listener task. It is only read by the `Drop`
    /// implementation, and not at all on `wasm32`, hence the `dead_code`
    /// allowance.
    #[allow(dead_code)]
    join_handle: JoinHandle<()>,
}
133
#[cfg(feature = "e2e-encryption")]
impl Drop for BackupDownloadTask {
    /// Stop the background listener when the handle goes away.
    ///
    /// The abort is compiled out on `wasm32`, where this is a no-op.
    fn drop(&mut self) {
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.join_handle.abort();
        }
    }
}
141
142impl BackupDownloadTask {
143    #[cfg(not(test))]
144    const DOWNLOAD_DELAY_MILLIS: u64 = 100;
145
146    pub(crate) fn new(client: WeakClient) -> Self {
147        let (sender, receiver) = mpsc::unbounded_channel();
148
149        let join_handle = spawn(async move {
150            Self::listen(client, receiver).await;
151        });
152
153        Self { sender, join_handle }
154    }
155
156    /// Trigger a backup download for the keys for the given event.
157    ///
158    /// Does nothing unless the event is encrypted using `m.megolm.v1.aes-sha2`.
159    /// Otherwise, tells the listener task to set off a task to do a backup
160    /// download, unless there is one already running.
161    pub(crate) fn trigger_download_for_utd_event(
162        &self,
163        room_id: OwnedRoomId,
164        event: Raw<OriginalSyncRoomEncryptedEvent>,
165    ) {
166        if let Ok(deserialized_event) = event.deserialize() {
167            if let EncryptedEventScheme::MegolmV1AesSha2(c) = deserialized_event.content.scheme {
168                let _ = self.sender.send(RoomKeyDownloadRequest {
169                    room_id,
170                    event_id: deserialized_event.event_id,
171                    event,
172                    megolm_session_id: c.session_id,
173                });
174            }
175        }
176    }
177
178    /// Listen for incoming [`RoomKeyDownloadRequest`]s and process them.
179    ///
180    /// This will keep running until either the request channel is closed, or
181    /// all other references to `Client` are dropped.
182    ///
183    /// # Arguments
184    ///
185    /// * `receiver` - The source of incoming [`RoomKeyDownloadRequest`]s.
186    async fn listen(client: WeakClient, mut receiver: UnboundedReceiver<RoomKeyDownloadRequest>) {
187        let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(client)));
188
189        while let Some(room_key_download_request) = receiver.recv().await {
190            let mut state_guard = state.lock().await;
191
192            if state_guard.client.strong_count() == 0 {
193                trace!("Client got dropped, shutting down the task");
194                break;
195            }
196
197            // Check that we don't already have a task to process this event, and fire one
198            // off else if not.
199            let event_id = &room_key_download_request.event_id;
200            if !state_guard.active_tasks.contains_key(event_id) {
201                let event_id = event_id.to_owned();
202                let task =
203                    spawn(Self::handle_download_request(state.clone(), room_key_download_request));
204                state_guard.active_tasks.insert(event_id, task);
205            }
206        }
207    }
208
209    /// Handle a request to download a room key for a given event.
210    ///
211    /// Sleeps for a while to see if the key turns up; then checks if we still
212    /// want to do a download, and does the download if so.
213    async fn handle_download_request(
214        state: Arc<Mutex<BackupDownloadTaskListenerState>>,
215        download_request: RoomKeyDownloadRequest,
216    ) {
217        // Wait a bit, perhaps the room key will arrive in the meantime.
218        #[cfg(not(test))]
219        crate::sleep::sleep(Duration::from_millis(Self::DOWNLOAD_DELAY_MILLIS)).await;
220
221        // Now take the lock, and check that we still want to do a download. If we do,
222        // keep hold of a strong reference to the `Client`.
223        let client = {
224            let mut state = state.lock().await;
225
226            let Some(client) = state.client.get() else {
227                // The client was dropped while we were sleeping. We should just bail out;
228                // the main BackupDownloadTask loop will bail out too.
229                return;
230            };
231
232            // Check that we still want to do a download.
233            if !state.should_download(&client, &download_request).await {
234                // We decided against doing a download. Mark the job done for this event before
235                // dropping the lock.
236                state.active_tasks.remove(&download_request.event_id);
237                return;
238            }
239
240            // Before we drop the lock, indicate to other tasks that may be considering this
241            // room key, that we're going to go ahead and do a download.
242            state.downloaded_room_keys.insert(download_request.to_room_key_info());
243
244            client
245        };
246
247        // Do the download without holding the lock.
248        let result = client
249            .encryption()
250            .backups()
251            .download_room_key(&download_request.room_id, &download_request.megolm_session_id)
252            .await;
253
254        // Then take the lock again to update the state.
255        {
256            let mut state = state.lock().await;
257            let room_key_info = download_request.to_room_key_info();
258
259            match result {
260                Ok(true) => {
261                    // We successfully downloaded the room key. We can clear any record of previous
262                    // backoffs from the failures cache, because we won't be needing them again.
263                    state.failures_cache.remove(std::iter::once(&room_key_info))
264                }
265                Ok(false) => {
266                    // We did not find a valid backup decryption key or backup version, we did not
267                    // even attempt to download the room key.
268                    state.downloaded_room_keys.remove(std::iter::once(&room_key_info));
269                }
270                Err(_) => {
271                    // We were unable to download the room key. Update the failure cache so that we
272                    // back off from more requests, and also remove the entry from the list of
273                    // room keys that we are downloading.
274                    state.downloaded_room_keys.remove(std::iter::once(&room_key_info));
275                    state.failures_cache.insert(room_key_info);
276                }
277            }
278
279            state.active_tasks.remove(&download_request.event_id);
280        }
281    }
282}
283
/// The state for an active [`BackupDownloadTask`].
///
/// It is shared, behind a mutex, between the listener loop and the per-event
/// download tasks that the loop spawns.
struct BackupDownloadTaskListenerState {
    /// Reference to the `Client`, which will be used to fire off the download
    /// requests.
    ///
    /// This is a weak reference, so that the tasks notice when the `Client`
    /// has been dropped.
    client: WeakClient,

    /// A record of backup download attempts that have recently failed.
    ///
    /// An entry is added when a download returns an error, and cleared when a
    /// download of the same room key succeeds.
    failures_cache: FailuresCache<RoomKeyInfo>,

    /// Map from event ID to download task
    ///
    /// A task removes its own entry once it is done with the event, except
    /// when it bails out because the `Client` was dropped.
    active_tasks: BTreeMap<OwnedEventId, JoinHandle<()>>,

    /// A list of room keys that we have already downloaded, or are about to
    /// download.
    ///
    /// The idea here is that once we've (successfully) downloaded a room key
    /// from the backup, there's not much point trying again even if we get
    /// another UTD event that uses the same room key.
    ///
    /// An entry is taken out again when the download fails or turns out not
    /// to be attempted at all.
    downloaded_room_keys: DownloadCache,
}
304
305impl BackupDownloadTaskListenerState {
306    /// Prepare a new `BackupDownloadTaskListenerState`.
307    ///
308    /// # Arguments
309    ///
310    /// * `client` - A reference to the `Client`, which is used to fire off the
311    ///   backup download request.
312    pub fn new(client: WeakClient) -> Self {
313        Self {
314            client,
315            failures_cache: FailuresCache::with_settings(Duration::from_secs(60 * 60 * 24), 60),
316            active_tasks: Default::default(),
317            downloaded_room_keys: DownloadCache::with_settings(
318                Duration::from_secs(60 * 60 * 24),
319                60,
320            ),
321        }
322    }
323
324    /// Check if we should set off a download for the given request.
325    ///
326    /// Checks if:
327    ///  * we already have the key,
328    ///  * we have already downloaded this room key, or are about to do so, or
329    ///  * we've backed off from trying to download this room key.
330    ///
331    /// If any of the above are true, returns `false`. Otherwise, returns
332    /// `true`.
333    pub async fn should_download(
334        &self,
335        client: &Client,
336        download_request: &RoomKeyDownloadRequest,
337    ) -> bool {
338        // Check that the Client has an OlmMachine
339        let machine_guard = client.olm_machine().await;
340        let Some(machine) = machine_guard.as_ref() else {
341            return false;
342        };
343
344        // If backups aren't enabled, there's no point in trying to download a room key.
345        if !client.encryption().backups().are_enabled().await {
346            debug!(
347                ?download_request,
348                "Not performing backup download because backups are not enabled"
349            );
350
351            return false;
352        }
353
354        // Check if the keys for this message have arrived in the meantime.
355        // If we get a StoreError doing the lookup, we assume the keys haven't arrived
356        // (though if the store is returning errors, probably something else is
357        // going to go wrong very soon).
358        if machine
359            .is_room_key_available(download_request.event.cast_ref(), &download_request.room_id)
360            .await
361            .unwrap_or(false)
362        {
363            debug!(?download_request, "Not performing backup download because key became available while we were sleeping");
364            return false;
365        }
366
367        // Check if we already downloaded this room key, or another task is in the
368        // process of doing so.
369        let room_key_info = download_request.to_room_key_info();
370        if self.downloaded_room_keys.contains(&room_key_info) {
371            debug!(
372                ?download_request,
373                "Not performing backup download because this room key has already been downloaded recently"
374            );
375            return false;
376        };
377
378        // Check if we're backing off from attempts to download this room key
379        if self.failures_cache.contains(&room_key_info) {
380            debug!(
381                ?download_request,
382                "Not performing backup download because this room key failed to download recently"
383            );
384            return false;
385        }
386
387        debug!(?download_request, "Performing backup download");
388        true
389    }
390}
391
/// Tests for the backup download task.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod test {
    use matrix_sdk_test::async_test;
    use ruma::{event_id, room_id};
    use serde_json::json;
    use wiremock::MockServer;

    use super::*;
    use crate::test_utils::logged_in_client;

    // Test that, if backups are not enabled, we don't incorrectly mark a room key
    // as downloaded.
    #[async_test]
    async fn test_disabled_backup_does_not_mark_room_key_as_downloaded() {
        let room_id = room_id!("!DovneieKSTkdHKpIXy:morpheus.localhost");
        let event_id = event_id!("$JbFHtZpEJiH8uaajZjPLz0QUZc1xtBR9rPGBOjF6WFM");
        // The session ID of the request; the check below looks up this exact
        // value, it does not need to match the one in the event content.
        let session_id = "session_id";

        let server = MockServer::start().await;
        let client = logged_in_client(Some(server.uri())).await;
        let weak_client = WeakClient::from_client(&client);

        let event_content = json!({
            "event_id": event_id,
            "origin_server_ts": 1698579035927u64,
            "sender": "@example2:morpheus.localhost",
            "type": "m.room.encrypted",
            "content": {
                "algorithm": "m.megolm.v1.aes-sha2",
                "ciphertext": "AwgAEpABhetEzzZzyYrxtEVUtlJnZtJcURBlQUQJ9irVeklCTs06LwgTMQj61PMUS4Vy\
                               YOX+PD67+hhU40/8olOww+Ud0m2afjMjC3wFX+4fFfSkoWPVHEmRVucfcdSF1RSB4EmK\
                               PIP4eo1X6x8kCIMewBvxl2sI9j4VNvDvAN7M3zkLJfFLOFHbBviI4FN7hSFHFeM739Zg\
                               iwxEs3hIkUXEiAfrobzaMEM/zY7SDrTdyffZndgJo7CZOVhoV6vuaOhmAy4X2t4UnbuV\
                               JGJjKfV57NAhp8W+9oT7ugwO",
                "device_id": "KIUVQQSDTM",
                "sender_key": "LvryVyoCjdONdBCi2vvoSbI34yTOx7YrCFACUEKoXnc",
                "session_id": "64H7XKokIx0ASkYDHZKlT5zd/Zccz/cQspPNdvnNULA"
            }
        });

        // A descriptive message makes a broken fixture easy to spot in the
        // test output.
        let event: Raw<OriginalSyncRoomEncryptedEvent> = serde_json::from_value(event_content)
            .expect("The event fixture should convert into a raw encrypted event");

        let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(weak_client)));
        let download_request = RoomKeyDownloadRequest {
            room_id: room_id.into(),
            megolm_session_id: session_id.to_owned(),
            event,
            event_id: event_id.into(),
        };

        assert!(
            !client.encryption().backups().are_enabled().await,
            "Backups should not be enabled."
        );

        // Run the handler directly, without going through the listener loop.
        BackupDownloadTask::handle_download_request(state.clone(), download_request).await;

        {
            let state = state.lock().await;
            assert!(
                !state.downloaded_room_keys.contains(&(room_id.to_owned(), session_id.to_owned())),
                "Backups are not enabled, we should not mark any room keys as downloaded."
            )
        }
    }
}