matrix_sdk_crypto/session_manager/group_sessions/
mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod share_strategy;
16
17use std::{
18    collections::{BTreeMap, BTreeSet},
19    fmt::Debug,
20    iter,
21    iter::zip,
22    sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28    deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30#[cfg(feature = "experimental-encrypted-state-events")]
31use ruma::events::AnyStateEventContent;
32use ruma::{
33    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
34    UserId,
35    events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
36    serde::Raw,
37    to_device::DeviceIdOrAllDevices,
38};
39use serde::Serialize;
40pub use share_strategy::CollectStrategy;
41#[cfg(feature = "experimental-send-custom-to-device")]
42pub(crate) use share_strategy::split_devices_for_share_strategy;
43pub(crate) use share_strategy::{
44    CollectRecipientsResult, withheld_code_for_device_for_share_strategy,
45};
46use tracing::{Instrument, debug, error, info, instrument, trace, warn};
47
48#[cfg(feature = "experimental-encrypted-state-events")]
49use crate::types::events::room::encrypted::RoomEncryptedEventContent;
50use crate::{
51    Device, DeviceData, EncryptionSettings, OlmError,
52    error::{EventError, MegolmResult, OlmResult},
53    identities::device::MaybeEncryptedRoomKey,
54    olm::{
55        InboundGroupSession, OutboundGroupSession, OutboundGroupSessionEncryptionResult,
56        SenderData, SenderDataFinder, Session, ShareInfo, ShareState,
57    },
58    store::{CryptoStoreWrapper, Result as StoreResult, Store, types::Changes},
59    types::{
60        events::{
61            EventType, room::encrypted::ToDeviceEncryptedEventContent,
62            room_key_bundle::RoomKeyBundleContent,
63        },
64        requests::ToDeviceRequest,
65    },
66};
67
68#[derive(Clone, Debug)]
69pub(crate) struct GroupSessionCache {
70    store: Store,
71    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
72    /// A map from the request id to the group session that the request belongs
73    /// to. Used to mark requests belonging to the session as shared.
74    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
75}
76
77impl GroupSessionCache {
78    pub(crate) fn new(store: Store) -> Self {
79        Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
80    }
81
82    pub(crate) fn insert(&self, session: OutboundGroupSession) {
83        self.sessions.write().insert(session.room_id().to_owned(), session);
84    }
85
86    /// Either get a session for the given room from the cache or load it from
87    /// the store.
88    ///
89    /// # Arguments
90    ///
91    /// * `room_id` - The id of the room this session is used for.
92    pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
93        // Get the cached session, if there isn't one load one from the store
94        // and put it in the cache.
95        if let Some(s) = self.sessions.read().get(room_id) {
96            return Some(s.clone());
97        }
98
99        match self.store.get_outbound_group_session(room_id).await {
100            Ok(Some(s)) => {
101                {
102                    let mut sessions_being_shared = self.sessions_being_shared.write();
103                    for request_id in s.pending_request_ids() {
104                        sessions_being_shared.insert(request_id, s.clone());
105                    }
106                }
107
108                self.sessions.write().insert(room_id.to_owned(), s.clone());
109
110                Some(s)
111            }
112            Ok(None) => None,
113            Err(e) => {
114                error!("Couldn't restore an outbound group session: {e:?}");
115                None
116            }
117        }
118    }
119
120    /// Get an outbound group session for a room, if one exists.
121    ///
122    /// # Arguments
123    ///
124    /// * `room_id` - The id of the room for which we should get the outbound
125    ///   group session.
126    fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
127        self.sessions.read().get(room_id).cloned()
128    }
129
130    /// Returns whether any session is withheld with the given device and code.
131    fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
132        self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
133    }
134
135    fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
136        self.sessions_being_shared.write().remove(id)
137    }
138
139    fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
140        self.sessions_being_shared.write().insert(id, session);
141    }
142}
143
144#[derive(Debug, Clone)]
145pub(crate) struct GroupSessionManager {
146    /// Store for the encryption keys.
147    /// Persists all the encryption keys so a client can resume the session
148    /// without the need to create new keys.
149    store: Store,
150    /// The currently active outbound group sessions.
151    sessions: GroupSessionCache,
152}
153
154impl GroupSessionManager {
    /// Chunk size used when recipient devices are split up into to-device
    /// requests; each chunk holds at most this many devices.
    const MAX_TO_DEVICE_MESSAGES: usize = 250;
156
157    pub fn new(store: Store) -> Self {
158        Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
159    }
160
161    pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
162        if let Some(s) = self.sessions.get(room_id) {
163            s.invalidate_session();
164
165            let mut changes = Changes::default();
166            changes.outbound_group_sessions.push(s.clone());
167            self.store.save_changes(changes).await?;
168
169            Ok(true)
170        } else {
171            Ok(false)
172        }
173    }
174
175    pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
176        let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
177            return Ok(());
178        };
179
180        let no_olm = session.mark_request_as_sent(request_id);
181
182        let mut changes = Changes::default();
183
184        for (user_id, devices) in &no_olm {
185            for device_id in devices {
186                let device = self.store.get_device(user_id, device_id).await;
187
188                if let Ok(Some(device)) = device {
189                    device.mark_withheld_code_as_sent();
190                    changes.devices.changed.push(device.inner.clone());
191                } else {
192                    error!(
193                        ?request_id,
194                        "Marking to-device no olm as sent but device not found, might \
195                            have been deleted?"
196                    );
197                }
198            }
199        }
200
201        changes.outbound_group_sessions.push(session.clone());
202        self.store.save_changes(changes).await
203    }
204
    /// Test-only accessor for the cached outbound group session of a room.
    ///
    /// Only the in-memory cache is consulted, the store is not queried.
    #[cfg(test)]
    pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
        self.sessions.get(room_id)
    }
209
210    pub async fn encrypt(
211        &self,
212        room_id: &RoomId,
213        event_type: &str,
214        content: &Raw<AnyMessageLikeEventContent>,
215    ) -> MegolmResult<OutboundGroupSessionEncryptionResult> {
216        let session =
217            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
218
219        assert!(!session.expired(), "Session expired");
220
221        let result = session.encrypt(event_type, content).await;
222
223        let mut changes = Changes::default();
224        changes.outbound_group_sessions.push(session);
225        self.store.save_changes(changes).await?;
226
227        Ok(result)
228    }
229
230    /// Encrypts a state event for the given room using its outbound group
231    /// session.
232    ///
233    /// # Arguments
234    ///
235    /// * `room_id` - The ID of the room where the state event will be sent.
236    /// * `event_type` - The type of the state event to encrypt.
237    /// * `state_key` - The state key associated with the event.
238    /// * `content` - The raw content of the state event to encrypt.
239    ///
240    /// # Returns
241    ///
242    /// Returns the raw encrypted state event content.
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if saving changes to the store fails.
247    ///
248    /// # Panics
249    ///
250    /// Panics if no session exists for the given room ID, or the session
251    /// has expired.
252    #[cfg(feature = "experimental-encrypted-state-events")]
253    pub async fn encrypt_state(
254        &self,
255        room_id: &RoomId,
256        event_type: &str,
257        state_key: &str,
258        content: &Raw<AnyStateEventContent>,
259    ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
260        let session =
261            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
262
263        assert!(!session.expired(), "Session expired");
264
265        let content = session.encrypt_state(event_type, state_key, content).await;
266
267        let mut changes = Changes::default();
268        changes.outbound_group_sessions.push(session);
269        self.store.save_changes(changes).await?;
270
271        Ok(content)
272    }
273
274    /// Create a new outbound group session.
275    ///
276    /// This also creates a matching inbound group session.
277    pub async fn create_outbound_group_session(
278        &self,
279        room_id: &RoomId,
280        settings: EncryptionSettings,
281        own_sender_data: SenderData,
282    ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
283        let (outbound, inbound) = self
284            .store
285            .static_account()
286            .create_group_session_pair(room_id, settings, own_sender_data)
287            .await
288            .map_err(|_| EventError::UnsupportedAlgorithm)?;
289
290        self.sessions.insert(outbound.clone());
291        Ok((outbound, inbound))
292    }
293
294    pub async fn get_or_create_outbound_session(
295        &self,
296        room_id: &RoomId,
297        settings: EncryptionSettings,
298        own_sender_data: SenderData,
299    ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
300        let outbound_session = self.sessions.get_or_load(room_id).await;
301
302        // If there is no session or the session has expired or is invalid,
303        // create a new one.
304        if let Some(s) = outbound_session {
305            if s.expired() || s.invalidated() {
306                self.create_outbound_group_session(room_id, settings, own_sender_data)
307                    .await
308                    .map(|(o, i)| (o, i.into()))
309            } else {
310                Ok((s, None))
311            }
312        } else {
313            self.create_outbound_group_session(room_id, settings, own_sender_data)
314                .await
315                .map(|(o, i)| (o, i.into()))
316        }
317    }
318
319    /// Encrypt the given group session key for the given devices and create
320    /// to-device requests that sends the encrypted content to them.
321    ///
322    /// See also [`encrypt_content_for_devices`] which is similar
323    /// but is not specific to group sessions, and does not return the
324    /// [`ShareInfo`] data.
325    async fn encrypt_session_for(
326        store: Arc<CryptoStoreWrapper>,
327        group_session: OutboundGroupSession,
328        devices: Vec<DeviceData>,
329    ) -> OlmResult<(
330        EncryptForDevicesResult,
331        BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
332    )> {
333        // Use a named type instead of a tuple with rather long type name
334        pub struct DeviceResult {
335            device: DeviceData,
336            maybe_encrypted_room_key: MaybeEncryptedRoomKey,
337        }
338
339        let mut result_builder = EncryptForDevicesResultBuilder::default();
340        let mut share_infos = BTreeMap::new();
341
342        // XXX is there a way to do this that doesn't involve cloning the
343        // `Arc<CryptoStoreWrapper>` for each device?
344        let encrypt = |store: Arc<CryptoStoreWrapper>,
345                       device: DeviceData,
346                       session: OutboundGroupSession| async move {
347            let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
348
349            Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
350        };
351
352        let tasks: Vec<_> = devices
353            .iter()
354            .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
355            .collect();
356
357        let results = join_all(tasks).await;
358
359        for result in results {
360            let result = result.expect("Encryption task panicked")?;
361
362            match result.maybe_encrypted_room_key {
363                MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
364                    result_builder.on_successful_encryption(&result.device, *used_session, message);
365
366                    let user_id = result.device.user_id().to_owned();
367                    let device_id = result.device.device_id().to_owned();
368                    share_infos
369                        .entry(user_id)
370                        .or_insert_with(BTreeMap::new)
371                        .insert(device_id, *share_info);
372                }
373                MaybeEncryptedRoomKey::MissingSession => {
374                    result_builder.on_missing_session(result.device);
375                }
376            }
377        }
378
379        Ok((result_builder.into_result(), share_infos))
380    }
381
    /// Given a list of users and an outbound session, return the list of users
    /// and their devices that this session should be shared with.
    ///
    /// Returns information indicating whether the session needs to be rotated
    /// and the list of users/devices that should or should not receive the
    /// session (the latter together with a withheld reason).
    ///
    /// This is a thin wrapper which forwards to
    /// `share_strategy::collect_session_recipients` using this manager's
    /// store.
    #[instrument(skip_all)]
    pub async fn collect_session_recipients(
        &self,
        users: impl Iterator<Item = &UserId>,
        settings: &EncryptionSettings,
        outbound: &OutboundGroupSession,
    ) -> OlmResult<CollectRecipientsResult> {
        share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
    }
397
398    async fn encrypt_request(
399        store: Arc<CryptoStoreWrapper>,
400        chunk: Vec<DeviceData>,
401        outbound: OutboundGroupSession,
402        sessions: GroupSessionCache,
403    ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
404        let (result, share_infos) =
405            Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
406
407        if let Some(request) = result.to_device_request {
408            let id = request.txn_id.clone();
409            outbound.add_request(id.clone(), request.into(), share_infos);
410            sessions.mark_as_being_shared(id, outbound.clone());
411        }
412
413        Ok((result.updated_olm_sessions, result.no_olm_devices))
414    }
415
    /// Get a handle to the cache of outbound group sessions.
    ///
    /// The returned value is a clone of the manager's cache; the session maps
    /// inside it live behind `Arc`s, so the clone refers to the same maps.
    pub(crate) fn session_cache(&self) -> GroupSessionCache {
        self.sessions.clone()
    }
419
420    async fn maybe_rotate_group_session(
421        &self,
422        should_rotate: bool,
423        room_id: &RoomId,
424        outbound: OutboundGroupSession,
425        encryption_settings: EncryptionSettings,
426        changes: &mut Changes,
427        own_device: Option<Device>,
428    ) -> OlmResult<OutboundGroupSession> {
429        Ok(if should_rotate {
430            let old_session_id = outbound.session_id();
431
432            let (outbound, mut inbound) = self
433                .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
434                .await?;
435
436            // Use our own device info to populate the SenderData that validates the
437            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
438            // are sending out.
439            let own_sender_data = if let Some(device) = own_device {
440                SenderDataFinder::find_using_device_data(
441                    &self.store,
442                    device.inner.clone(),
443                    &inbound,
444                )
445                .await?
446            } else {
447                error!("Unable to find our own device!");
448                SenderData::unknown()
449            };
450            inbound.sender_data = own_sender_data;
451
452            changes.outbound_group_sessions.push(outbound.clone());
453            changes.inbound_group_sessions.push(inbound);
454
455            debug!(
456                old_session_id = old_session_id,
457                session_id = outbound.session_id(),
458                "A user or device has left the room since we last sent a \
459                message, or the encryption settings have changed. Rotating the \
460                room key.",
461            );
462
463            outbound
464        } else {
465            outbound
466        })
467    }
468
469    async fn encrypt_for_devices(
470        &self,
471        recipient_devices: Vec<DeviceData>,
472        group_session: &OutboundGroupSession,
473        changes: &mut Changes,
474    ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
475        // If we have some recipients, log them here.
476        if !recipient_devices.is_empty() {
477            let recipients = recipient_list_to_users_and_devices(&recipient_devices);
478
479            // If there are new recipients we need to persist the outbound group
480            // session as the to-device requests are persisted with the session.
481            changes.outbound_group_sessions = vec![group_session.clone()];
482
483            let message_index = group_session.message_index().await;
484
485            info!(
486                ?recipients,
487                message_index,
488                room_id = ?group_session.room_id(),
489                session_id = group_session.session_id(),
490                "Trying to encrypt a room key",
491            );
492        }
493
494        // Chunk the recipients out so each to-device request will contain a
495        // limited amount of to-device messages.
496        //
497        // Create concurrent tasks for each chunk of recipients.
498        let tasks: Vec<_> = recipient_devices
499            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
500            .map(|chunk| {
501                spawn(Self::encrypt_request(
502                    self.store.crypto_store(),
503                    chunk.to_vec(),
504                    group_session.clone(),
505                    self.sessions.clone(),
506                ))
507            })
508            .collect();
509
510        let mut withheld_devices = Vec::new();
511
512        // Wait for all the tasks to finish up and queue up the Olm session that
513        // was used to encrypt the room key to be persisted again. This is
514        // needed because each encryption step will mutate the Olm session,
515        // ratcheting its state forward.
516        for result in join_all(tasks).await {
517            let result = result.expect("Encryption task panicked");
518
519            let (used_sessions, failed_no_olm) = result?;
520
521            changes.sessions.extend(used_sessions);
522            withheld_devices.extend(failed_no_olm);
523        }
524
525        Ok(withheld_devices)
526    }
527
528    fn is_withheld_to(
529        &self,
530        group_session: &OutboundGroupSession,
531        device: &DeviceData,
532        code: &WithheldCode,
533    ) -> bool {
534        // The `m.no_olm` withheld code is special because it is supposed to be sent
535        // only once for a given device. The `Device` remembers the flag if we
536        // already sent a `m.no_olm` to this particular device so let's check
537        // that first.
538        //
539        // Keep in mind that any outbound group session might want to send this code to
540        // the device. So we need to check if any of our outbound group sessions
541        // is attempting to send the code to the device.
542        //
543        // This still has a slight race where some other thread might remove the
544        // outbound group session while a third is marking the device as having
545        // received the code.
546        //
547        // Since nothing terrible happens if we do end up sending the withheld code
548        // twice, and removing the race requires us to lock the store because the
549        // `OutboundGroupSession` and the `Device` both interact with the flag we'll
550        // leave it be.
551        if code == &WithheldCode::NoOlm {
552            device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
553        } else {
554            group_session.sharing_view().is_withheld_to(device, code)
555        }
556    }
557
558    fn handle_withheld_devices(
559        &self,
560        group_session: &OutboundGroupSession,
561        withheld_devices: Vec<(DeviceData, WithheldCode)>,
562    ) -> OlmResult<()> {
563        // Convert a withheld code for the group session into a to-device event content.
564        let to_content = |code| {
565            let content = group_session.withheld_code(code);
566            Raw::new(&content).expect("We can always serialize a withheld content info").cast()
567        };
568
569        // Helper to convert a chunk of device and withheld code pairs into a to-device
570        // request and it's accompanying share info.
571        let chunk_to_request = |chunk| {
572            let mut messages = BTreeMap::new();
573            let mut share_infos = BTreeMap::new();
574
575            for (device, code) in chunk {
576                let device: DeviceData = device;
577                let code: WithheldCode = code;
578
579                let user_id = device.user_id().to_owned();
580                let device_id = device.device_id().to_owned();
581
582                let share_info = ShareInfo::new_withheld(code.to_owned());
583                let content = to_content(code);
584
585                messages
586                    .entry(user_id.to_owned())
587                    .or_insert_with(BTreeMap::new)
588                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
589
590                share_infos
591                    .entry(user_id)
592                    .or_insert_with(BTreeMap::new)
593                    .insert(device_id, share_info);
594            }
595
596            let txn_id = TransactionId::new();
597
598            let request = ToDeviceRequest {
599                event_type: ToDeviceEventType::from("m.room_key.withheld"),
600                txn_id,
601                messages,
602            };
603
604            (request, share_infos)
605        };
606
607        let result: Vec<_> = withheld_devices
608            .into_iter()
609            .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
610            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
611            .into_iter()
612            .map(chunk_to_request)
613            .collect();
614
615        for (request, share_info) in result {
616            if !request.messages.is_empty() {
617                let txn_id = request.txn_id.to_owned();
618                group_session.add_request(txn_id.to_owned(), request.into(), share_info);
619
620                self.sessions.mark_as_being_shared(txn_id, group_session.clone());
621            }
622        }
623
624        Ok(())
625    }
626
627    fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
628        for request in requests {
629            let message_list = Self::to_device_request_to_log_list(request);
630            info!(
631                request_id = ?request.txn_id,
632                ?message_list,
633                "Created batch of to-device messages of type {}",
634                request.event_type
635            );
636        }
637    }
638
639    /// Given a to-device request, build a recipient map suitable for logging.
640    ///
641    /// Returns a list of triples of (message_id, user id, device_id).
642    fn to_device_request_to_log_list(
643        request: &Arc<ToDeviceRequest>,
644    ) -> Vec<(String, String, String)> {
645        #[derive(serde::Deserialize)]
646        struct ContentStub<'a> {
647            #[serde(borrow, default, rename = "org.matrix.msgid")]
648            message_id: Option<&'a str>,
649        }
650
651        let mut result: Vec<(String, String, String)> = Vec::new();
652
653        for (user_id, device_map) in &request.messages {
654            for (device, content) in device_map {
655                let message_id: Option<&str> = content
656                    .deserialize_as_unchecked::<ContentStub<'_>>()
657                    .expect("We should be able to deserialize the content we generated")
658                    .message_id;
659
660                result.push((
661                    message_id.unwrap_or("<undefined>").to_owned(),
662                    user_id.to_string(),
663                    device.to_string(),
664                ));
665            }
666        }
667        result
668    }
669
    /// Get to-device requests to share a room key with users in a room.
    ///
    /// An existing outbound group session is reused if possible, otherwise a
    /// new one is created. The session is rotated if collecting the
    /// recipients says so. The room key is then encrypted for every
    /// recipient device that still needs it, and withheld notices are queued
    /// up for the devices that won't receive the key.
    ///
    /// # Arguments
    ///
    /// `room_id` - The room id of the room where the room key will be used.
    ///
    /// `users` - The list of users that should receive the room key.
    ///
    /// `encryption_settings` - The settings that should be used for
    /// the room key.
    ///
    /// # Returns
    ///
    /// The pending to-device requests of the outbound group session. If there
    /// are none and the session wasn't marked as shared yet, it is marked as
    /// shared.
    ///
    /// # Errors
    ///
    /// Errors of the store, of the recipient collection and of the encryption
    /// steps are propagated to the caller.
    #[instrument(skip(self, users, encryption_settings), fields(session_id))]
    pub async fn share_room_key(
        &self,
        room_id: &RoomId,
        users: impl Iterator<Item = &UserId>,
        encryption_settings: impl Into<EncryptionSettings>,
    ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
        trace!("Checking if a room key needs to be shared");

        // Look up our own device, it's used further down to fill in the
        // sender data of newly created inbound group sessions.
        let account = self.store.static_account();
        let device = self.store.get_device(account.user_id(), account.device_id()).await?;

        let encryption_settings = encryption_settings.into();
        let mut changes = Changes::default();

        // Try to get an existing session or create a new one.
        let (outbound, inbound) = self
            .get_or_create_outbound_session(
                room_id,
                encryption_settings.clone(),
                SenderData::unknown(),
            )
            .await?;
        // Fill in the `session_id` span field declared in the `instrument`
        // attribute.
        tracing::Span::current().record("session_id", outbound.session_id());

        // Having an inbound group session here means that we created a new
        // group session pair, which we then need to store.
        if let Some(mut inbound) = inbound {
            // Use our own device info to populate the SenderData that validates the
            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
            // are sending out.
            let own_sender_data = if let Some(device) = &device {
                SenderDataFinder::find_using_device_data(
                    &self.store,
                    device.inner.clone(),
                    &inbound,
                )
                .await?
            } else {
                error!("Unable to find our own device!");
                SenderData::unknown()
            };
            inbound.sender_data = own_sender_data;

            changes.outbound_group_sessions.push(outbound.clone());
            changes.inbound_group_sessions.push(inbound);
        }

        // Collect the recipient devices and check if either the settings
        // or the recipient list changed in a way that requires the
        // session to be rotated.
        let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
            self.collect_session_recipients(users, &encryption_settings, &outbound).await?;

        // From here on `outbound` is either the session from above or, if a
        // rotation was necessary, its freshly created replacement.
        let outbound = self
            .maybe_rotate_group_session(
                should_rotate,
                room_id,
                outbound,
                encryption_settings,
                &mut changes,
                device,
            )
            .await?;

        // Filter out the devices that already received this room key or have a
        // to-device message already queued up.
        let devices: Vec<_> = devices
            .into_iter()
            .flat_map(|(_, d)| {
                d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
                    ShareState::NotShared => true,
                    ShareState::Shared { message_index: _, olm_wedging_index } => {
                        // If the recipient device's Olm wedging index is higher
                        // than the value that we stored with the session, that
                        // means that they tried to unwedge the session since we
                        // last shared the room key.  So we re-share it with
                        // them in case they weren't able to decrypt the room
                        // key the last time we shared it.
                        olm_wedging_index < d.olm_wedging_index
                    }
                    _ => false,
                })
            })
            .collect();

        // The `encrypt_for_devices()` method adds the to-device requests that will send
        // out the room key to the `OutboundGroupSession`. It doesn't do that
        // for the m.room_key_withheld events since we might have more of those
        // coming from the `collect_session_recipients()` method. Instead they get
        // returned by the method.
        let unable_to_encrypt_devices =
            self.encrypt_for_devices(devices, &outbound, &mut changes).await?;

        // Merge the withheld recipients.
        withheld_devices.extend(unable_to_encrypt_devices);

        // Now handle and add the withheld recipients to the resulting requests to the
        // `OutboundGroupSession`.
        self.handle_withheld_devices(&outbound, withheld_devices)?;

        // The to-device requests get added to the outbound group session, this
        // way we're making sure that they are persisted and scoped to the
        // session.
        let requests = outbound.pending_requests();

        if requests.is_empty() {
            // Nothing needs to be sent out, so the session can be considered
            // to be shared right away.
            if !outbound.shared() {
                debug!("The room key doesn't need to be shared with anyone. Marking as shared.");

                outbound.mark_as_shared();
                changes.outbound_group_sessions.push(outbound.clone());
            }
        } else {
            Self::log_room_key_sharing_result(&requests)
        }

        // Persist any changes we might have collected.
        if !changes.is_empty() {
            // `changes` is moved into the store below, so grab the count for
            // the log line first.
            let session_count = changes.sessions.len();

            self.store.save_changes(changes).await?;

            trace!(
                session_count = session_count,
                "Stored the changed sessions after encrypting an room key"
            );
        }

        Ok(requests)
    }
811
812    /// Collect the devices belonging to the given user, and send the details of
813    /// a room key bundle to those devices.
814    ///
815    /// Returns a list of to-device requests which must be sent.
816    ///
817    /// For security reasons, only "safe" [`CollectStrategy`]s are supported, in
818    /// which the recipient must have signed their
819    /// devices. [`CollectStrategy::AllDevices`] and
820    /// [`CollectStrategy::ErrorOnVerifiedUserProblem`] are "unsafe" in this
821    /// respect,and are treated the same as
822    /// [`CollectStrategy::IdentityBasedStrategy`].
823    #[instrument(skip(self, bundle_data))]
824    pub async fn share_room_key_bundle_data(
825        &self,
826        user_id: &UserId,
827        collect_strategy: &CollectStrategy,
828        bundle_data: RoomKeyBundleContent,
829    ) -> OlmResult<Vec<ToDeviceRequest>> {
830        // Only allow conservative sharing strategies
831        let collect_strategy = match collect_strategy {
832            CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
833                warn!(
834                    "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
835                     for room key history sharing",
836                );
837                &CollectStrategy::IdentityBasedStrategy
838            }
839            CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
840                collect_strategy
841            }
842        };
843
844        let mut changes = Changes::default();
845
846        let CollectRecipientsResult { devices, .. } =
847            share_strategy::collect_recipients_for_share_strategy(
848                &self.store,
849                iter::once(user_id),
850                collect_strategy,
851                None,
852            )
853            .await?;
854
855        let devices = devices.into_values().flatten().collect();
856        let event_type = bundle_data.event_type().to_owned();
857        let (requests, _) = self
858            .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
859            .await?;
860
861        // TODO: figure out what to do with withheld devices
862
863        // Persist any changes we might have collected.
864        if !changes.is_empty() {
865            let session_count = changes.sessions.len();
866
867            self.store.save_changes(changes).await?;
868
869            trace!(
870                session_count = session_count,
871                "Stored the changed sessions after encrypting an room key"
872            );
873        }
874
875        Ok(requests)
876    }
877
878    /// Encrypt the given content for the given devices and build to-device
879    /// requests to send the encrypted content to them.
880    ///
881    /// Returns a tuple containing (1) the list of to-device requests, and (2)
882    /// the list of devices that we could not find an olm session for (so
883    /// need a withheld message).
884    pub(crate) async fn encrypt_content_for_devices(
885        &self,
886        recipient_devices: Vec<DeviceData>,
887        event_type: &str,
888        content: impl Serialize + Clone + Send + 'static,
889        changes: &mut Changes,
890    ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
891        let recipients = recipient_list_to_users_and_devices(&recipient_devices);
892        info!(?recipients, "Encrypting content of type {}", event_type);
893
894        // Chunk the recipients out so each to-device request will contain a
895        // limited amount of to-device messages.
896        //
897        // Create concurrent tasks for each chunk of recipients.
898        let tasks: Vec<_> = recipient_devices
899            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
900            .map(|chunk| {
901                spawn(
902                    encrypt_content_for_devices(
903                        self.store.crypto_store(),
904                        event_type.to_owned(),
905                        content.clone(),
906                        chunk.to_vec(),
907                    )
908                    .in_current_span(),
909                )
910            })
911            .collect();
912
913        let mut no_olm_devices = Vec::new();
914        let mut to_device_requests = Vec::new();
915
916        // Wait for all the tasks to finish up and queue up the Olm session that
917        // was used to encrypt the room key to be persisted again. This is
918        // needed because each encryption step will mutate the Olm session,
919        // ratcheting its state forward.
920        for result in join_all(tasks).await {
921            let result = result.expect("Encryption task panicked")?;
922            if let Some(request) = result.to_device_request {
923                to_device_requests.push(request);
924            }
925            changes.sessions.extend(result.updated_olm_sessions);
926            no_olm_devices.extend(result.no_olm_devices);
927        }
928
929        Ok((to_device_requests, no_olm_devices))
930    }
931}
932
933/// Helper for [`GroupSessionManager::encrypt_content_for_devices`].
934///
935/// Encrypt the given content for the given devices and build a to-device
936/// request to send the encrypted content to them.
937///
938/// See also [`GroupSessionManager::encrypt_session_for`], which is similar
939/// but applies specifically to `m.room_key` messages that hold a megolm
940/// session key.
941async fn encrypt_content_for_devices(
942    store: Arc<CryptoStoreWrapper>,
943    event_type: String,
944    content: impl Serialize + Clone + Send + 'static,
945    devices: Vec<DeviceData>,
946) -> OlmResult<EncryptForDevicesResult> {
947    let mut result_builder = EncryptForDevicesResultBuilder::default();
948
949    async fn encrypt(
950        store: Arc<CryptoStoreWrapper>,
951        device: DeviceData,
952        event_type: String,
953        bundle_data: impl Serialize,
954    ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
955        device.encrypt(store.as_ref(), &event_type, bundle_data).await
956    }
957
958    let tasks = devices.iter().map(|device| {
959        spawn(
960            encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
961                .in_current_span(),
962        )
963    });
964
965    let results = join_all(tasks).await;
966
967    for (device, result) in zip(devices, results) {
968        let encryption_result = result.expect("Encryption task panicked");
969
970        match encryption_result {
971            Ok((used_session, message)) => {
972                result_builder.on_successful_encryption(&device, used_session, message.cast());
973            }
974            Err(OlmError::MissingSession) => {
975                // There is no established Olm session for this device
976                result_builder.on_missing_session(device);
977            }
978            Err(e) => return Err(e),
979        }
980    }
981
982    Ok(result_builder.into_result())
983}
984
/// Result of [`GroupSessionManager::encrypt_session_for`] and
/// [`encrypt_content_for_devices`].
///
/// Instances are assembled by
/// [`EncryptForDevicesResultBuilder::into_result`].
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// The request to send the to-device messages containing the encrypted
    /// payload. This is `None` when no message was encrypted successfully,
    /// i.e. when there is nothing to send.
    to_device_request: Option<ToDeviceRequest>,

    /// The devices which lack an Olm session, each paired with the withheld
    /// code that should be sent to the device instead of the payload.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store.
    updated_olm_sessions: Vec<Session>,
}
1000
/// A helper for building [`EncryptForDevicesResult`].
///
/// Per-device encryption outcomes are fed in one at a time, and the
/// accumulated state is finally converted with `into_result()`.
#[derive(Debug, Default)]
struct EncryptForDevicesResultBuilder {
    /// The payloads of the to-device messages, keyed first by the recipient
    /// user and then by the recipient device.
    messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,

    /// The devices which lack an Olm session and therefore need a withheld
    /// code, together with that code.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store.
    updated_olm_sessions: Vec<Session>,
}
1014
1015impl EncryptForDevicesResultBuilder {
1016    /// Record a successful encryption. The encrypted message is added to the
1017    /// list to be sent, and the olm session is added to the list of those
1018    /// that have been modified.
1019    pub fn on_successful_encryption(
1020        &mut self,
1021        device: &DeviceData,
1022        used_session: Session,
1023        message: Raw<AnyToDeviceEventContent>,
1024    ) {
1025        self.updated_olm_sessions.push(used_session);
1026
1027        self.messages
1028            .entry(device.user_id().to_owned())
1029            .or_default()
1030            .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
1031    }
1032
1033    /// Record a device which didn't have an active Olm session.
1034    pub fn on_missing_session(&mut self, device: DeviceData) {
1035        self.no_olm_devices.push((device, WithheldCode::NoOlm));
1036    }
1037
1038    /// Transform the accumulated results into an [`EncryptForDevicesResult`],
1039    /// wrapping the messages, if any, into a `ToDeviceRequest`.
1040    pub fn into_result(self) -> EncryptForDevicesResult {
1041        let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
1042            self;
1043
1044        let mut encrypt_for_devices_result = EncryptForDevicesResult {
1045            to_device_request: None,
1046            updated_olm_sessions,
1047            no_olm_devices,
1048        };
1049
1050        if !messages.is_empty() {
1051            let request = ToDeviceRequest {
1052                event_type: ToDeviceEventType::RoomEncrypted,
1053                txn_id: TransactionId::new(),
1054                messages,
1055            };
1056            trace!(
1057                recipient_count = request.message_count(),
1058                transaction_id = ?request.txn_id,
1059                "Created a to-device request carrying room keys",
1060            );
1061            encrypt_for_devices_result.to_device_request = Some(request);
1062        }
1063
1064        encrypt_for_devices_result
1065    }
1066}
1067
1068fn recipient_list_to_users_and_devices(
1069    recipient_devices: &[DeviceData],
1070) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1071    #[allow(unknown_lints, clippy::unwrap_or_default)] // false positive
1072    recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1073        acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1074        acc
1075    })
1076}
1077
1078#[cfg(test)]
1079mod tests {
1080    use std::{
1081        collections::{BTreeMap, BTreeSet},
1082        iter,
1083        ops::Deref,
1084        sync::Arc,
1085    };
1086
1087    use assert_matches2::assert_let;
1088    use matrix_sdk_common::deserialized_responses::{ProcessedToDeviceEvent, WithheldCode};
1089    use matrix_sdk_test::{async_test, ruma_response_from_json};
1090    use ruma::{
1091        DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1092        api::client::{
1093            keys::{claim_keys, get_keys, upload_keys},
1094            to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1095        },
1096        device_id,
1097        events::room::{
1098            EncryptedFileInit, JsonWebKey, JsonWebKeyInit, history_visibility::HistoryVisibility,
1099        },
1100        owned_room_id, room_id,
1101        serde::Base64,
1102        to_device::DeviceIdOrAllDevices,
1103        user_id,
1104    };
1105    use serde_json::{Value, json};
1106
1107    use crate::{
1108        DecryptionSettings, EncryptionSettings, LocalTrust, OlmMachine, TrustRequirement,
1109        identities::DeviceData,
1110        machine::{
1111            EncryptionSyncChanges, test_helpers::get_machine_pair_with_setup_sessions_test_helper,
1112        },
1113        olm::{Account, SenderData},
1114        session_manager::{CollectStrategy, group_sessions::CollectRecipientsResult},
1115        types::{
1116            DeviceKeys, EventEncryptionAlgorithm,
1117            events::{
1118                room::encrypted::EncryptedToDeviceEvent,
1119                room_key_bundle::RoomKeyBundleContent,
1120                room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1121            },
1122            requests::ToDeviceRequest,
1123        },
1124    };
1125
    /// The user ID (`@alice:example.org`) used for the local test account.
    fn alice_id() -> &'static UserId {
        user_id!("@alice:example.org")
    }
1129
    /// The device ID (`JLAFKJWSCS`) used for the local test account.
    fn alice_device_id() -> &'static DeviceId {
        device_id!("JLAFKJWSCS")
    }
1133
1134    /// Returns a /keys/query response for user "@example:localhost"
1135    fn keys_query_response() -> get_keys::v3::Response {
1136        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1137        let data: Value = serde_json::from_slice(data).unwrap();
1138        ruma_response_from_json(&data)
1139    }
1140
    /// Returns a /keys/query response containing the signed device keys of
    /// the single device `BOBDEVICE` of user `@bob:localhost`.
    fn bob_keys_query_response() -> get_keys::v3::Response {
        let data = json!({
            "device_keys": {
                "@bob:localhost": {
                    "BOBDEVICE": {
                        "user_id": "@bob:localhost",
                        "device_id": "BOBDEVICE",
                        "algorithms": [
                            "m.olm.v1.curve25519-aes-sha2",
                            "m.megolm.v1.aes-sha2",
                            "m.megolm.v2.aes-sha2"
                        ],
                        "keys": {
                            "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
                            "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
                        },
                        "signatures": {
                            "@bob:localhost": {
                                "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
                            }
                        }
                    }
                }
            }
        });
        ruma_response_from_json(&data)
    }
1168
    /// Returns a keys claim response for device `BOBDEVICE` of user
    /// `@bob:localhost`.
    ///
    /// The response carries a single `signed_curve25519` one-time key and
    /// reports no failures.
    fn bob_one_time_key() -> claim_keys::v3::Response {
        let data = json!({
            "failures": {},
            "one_time_keys":{
                "@bob:localhost":{
                    "BOBDEVICE":{
                      "signed_curve25519:AAAAAAAAAAA": {
                          "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
                          "signatures":{
                              "@bob:localhost":{
                                  "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
                              }
                          }
                      }
                    }
                }
            }
        });
        ruma_response_from_json(&data)
    }
1191
1192    /// Returns a key claim response for device `NMMBNBUSNR` of user
1193    /// `@example2:localhost`
1194    fn keys_claim_response() -> claim_keys::v3::Response {
1195        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1196        let data: Value = serde_json::from_slice(data).unwrap();
1197        ruma_response_from_json(&data)
1198    }
1199
1200    async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1201        let keys_query = keys_query_response();
1202        let txn_id = TransactionId::new();
1203
1204        let machine = OlmMachine::new(user_id, device_id).await;
1205
1206        // complete a /keys/query and /keys/claim for @example:localhost
1207        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1208        let (txn_id, _keys_claim_request) = machine
1209            .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1210            .await
1211            .unwrap()
1212            .unwrap();
1213        let keys_claim = keys_claim_response();
1214        machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1215
1216        // complete a /keys/query and /keys/claim for @bob:localhost
1217        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1218        let (txn_id, _keys_claim_request) = machine
1219            .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1220            .await
1221            .unwrap()
1222            .unwrap();
1223        machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1224
1225        machine
1226    }
1227
    /// Creates the default test machine: the one built by
    /// `machine_with_user_test_helper()` for Alice's user and device ID.
    async fn machine() -> OlmMachine {
        machine_with_user_test_helper(alice_id(), alice_device_id()).await
    }
1231
1232    async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1233        let machine = machine().await;
1234        let room_id = room_id!("!test:localhost");
1235        let keys_claim = keys_claim_response();
1236
1237        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1238        let requests =
1239            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1240
1241        let outbound =
1242            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1243
1244        assert!(!outbound.pending_requests().is_empty());
1245        assert!(!outbound.shared());
1246
1247        let response = ToDeviceResponse::new();
1248        for request in requests {
1249            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1250        }
1251
1252        assert!(outbound.shared());
1253        assert!(outbound.pending_requests().is_empty());
1254
1255        machine
1256    }
1257
1258    #[async_test]
1259    async fn test_sharing() {
1260        let machine = machine().await;
1261        let room_id = room_id!("!test:localhost");
1262        let keys_claim = keys_claim_response();
1263
1264        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1265
1266        let requests =
1267            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1268
1269        let event_count: usize = requests
1270            .iter()
1271            .filter(|r| r.event_type == "m.room.encrypted".into())
1272            .map(|r| r.message_count())
1273            .sum();
1274
1275        // The keys claim response has a couple of one-time keys with invalid
1276        // signatures, thus only 148 sessions are actually created, we check
1277        // that all 148 valid sessions get an room key.
1278        assert_eq!(event_count, 148);
1279
1280        let withheld_count: usize = requests
1281            .iter()
1282            .filter(|r| r.event_type == "m.room_key.withheld".into())
1283            .map(|r| r.message_count())
1284            .sum();
1285        assert_eq!(withheld_count, 2);
1286    }
1287
1288    fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1289        requests
1290            .iter()
1291            .filter(|r| r.event_type == "m.room_key.withheld".into())
1292            .map(|r| {
1293                let mut count = 0;
1294                // count targets
1295                for message in r.messages.values() {
1296                    message.iter().for_each(|(_, content)| {
1297                        let withheld: RoomKeyWithheldContent =
1298                            content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1299
1300                        if let MegolmV1AesSha2(content) = withheld
1301                            && content.withheld_code() == code
1302                        {
1303                            count += 1;
1304                        }
1305                    })
1306                }
1307                count
1308            })
1309            .sum()
1310    }
1311
1312    #[async_test]
1313    async fn test_no_olm_sent_once() {
1314        let machine = machine().await;
1315        let keys_claim = keys_claim_response();
1316
1317        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1318
1319        let first_room_id = room_id!("!test:localhost");
1320
1321        let requests = machine
1322            .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1323            .await
1324            .unwrap();
1325
1326        // there will be two no_olm
1327        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1328        assert_eq!(withheld_count, 2);
1329
1330        // Re-sharing same session while request has not been sent should not produces
1331        // withheld
1332        let new_requests = machine
1333            .share_room_key(first_room_id, users, EncryptionSettings::default())
1334            .await
1335            .unwrap();
1336        let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1337        // No additional request was added, still the 2 already pending
1338        assert_eq!(withheld_count, 2);
1339
1340        let response = ToDeviceResponse::new();
1341        for request in requests {
1342            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1343        }
1344
1345        // The fact that an olm was sent should be remembered even if sharing another
1346        // session in an other room.
1347        let second_room_id = room_id!("!other:localhost");
1348        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1349        let requests = machine
1350            .share_room_key(second_room_id, users, EncryptionSettings::default())
1351            .await
1352            .unwrap();
1353
1354        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1355        assert_eq!(withheld_count, 0);
1356
1357        // Help how do I simulate the creation of a new session for the device
1358        // with no session now?
1359    }
1360
1361    #[async_test]
1362    async fn test_ratcheted_sharing() {
1363        let machine = machine_with_shared_room_key_test_helper().await;
1364
1365        let room_id = room_id!("!test:localhost");
1366        let late_joiner = user_id!("@bob:localhost");
1367        let keys_claim = keys_claim_response();
1368
1369        let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1370        users.insert(late_joiner);
1371
1372        let requests = machine
1373            .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1374            .await
1375            .unwrap();
1376
1377        let event_count: usize = requests
1378            .iter()
1379            .filter(|r| r.event_type == "m.room.encrypted".into())
1380            .map(|r| r.message_count())
1381            .sum();
1382        let outbound =
1383            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1384
1385        assert_eq!(event_count, 1);
1386        assert!(!outbound.pending_requests().is_empty());
1387    }
1388
1389    #[async_test]
1390    async fn test_changing_encryption_settings() {
1391        let machine = machine_with_shared_room_key_test_helper().await;
1392        let room_id = room_id!("!test:localhost");
1393        let keys_claim = keys_claim_response();
1394
1395        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1396        let outbound =
1397            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1398
1399        let CollectRecipientsResult { should_rotate, .. } = machine
1400            .inner
1401            .group_session_manager
1402            .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1403            .await
1404            .unwrap();
1405
1406        assert!(!should_rotate);
1407
1408        let settings = EncryptionSettings {
1409            history_visibility: HistoryVisibility::Invited,
1410            ..Default::default()
1411        };
1412
1413        let CollectRecipientsResult { should_rotate, .. } = machine
1414            .inner
1415            .group_session_manager
1416            .collect_session_recipients(users.clone(), &settings, &outbound)
1417            .await
1418            .unwrap();
1419
1420        assert!(should_rotate);
1421
1422        let settings = EncryptionSettings {
1423            algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1424            ..Default::default()
1425        };
1426
1427        let CollectRecipientsResult { should_rotate, .. } = machine
1428            .inner
1429            .group_session_manager
1430            .collect_session_recipients(users, &settings, &outbound)
1431            .await
1432            .unwrap();
1433
1434        assert!(should_rotate);
1435    }
1436
1437    #[async_test]
1438    async fn test_key_recipient_collecting() {
1439        // The user id comes from the fact that the keys_query.json file uses
1440        // this one.
1441        let user_id = user_id!("@example:localhost");
1442        let device_id = device_id!("TESTDEVICE");
1443        let room_id = room_id!("!test:localhost");
1444
1445        let machine = machine_with_user_test_helper(user_id, device_id).await;
1446
1447        let (outbound, _) = machine
1448            .inner
1449            .group_session_manager
1450            .get_or_create_outbound_session(
1451                room_id,
1452                EncryptionSettings::default(),
1453                SenderData::unknown(),
1454            )
1455            .await
1456            .expect("We should be able to create a new session");
1457        let history_visibility = HistoryVisibility::Joined;
1458        let settings = EncryptionSettings { history_visibility, ..Default::default() };
1459
1460        let users = [user_id].into_iter();
1461
1462        let CollectRecipientsResult { devices: recipients, .. } = machine
1463            .inner
1464            .group_session_manager
1465            .collect_session_recipients(users, &settings, &outbound)
1466            .await
1467            .expect("We should be able to collect the session recipients");
1468
1469        assert!(!recipients[user_id].is_empty());
1470
1471        // Make sure that our own device isn't part of the recipients.
1472        assert!(
1473            !recipients[user_id]
1474                .iter()
1475                .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1476        );
1477
1478        let settings = EncryptionSettings {
1479            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1480            ..Default::default()
1481        };
1482        let users = [user_id].into_iter();
1483
1484        let CollectRecipientsResult { devices: recipients, .. } = machine
1485            .inner
1486            .group_session_manager
1487            .collect_session_recipients(users, &settings, &outbound)
1488            .await
1489            .expect("We should be able to collect the session recipients");
1490
1491        assert!(recipients[user_id].is_empty());
1492
1493        let device_id = "AFGUOBTZWM".into();
1494        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1495        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1496        let users = [user_id].into_iter();
1497
1498        let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1499            machine
1500                .inner
1501                .group_session_manager
1502                .collect_session_recipients(users, &settings, &outbound)
1503                .await
1504                .expect("We should be able to collect the session recipients");
1505
1506        assert!(
1507            recipients[user_id]
1508                .iter()
1509                .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1510        );
1511
1512        let devices = machine.get_user_devices(user_id, None).await.unwrap();
1513        devices
1514            .devices()
1515            // Ignore our own device
1516            .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1517            .for_each(|d| {
1518                if d.is_blacklisted() {
1519                    assert!(withheld.iter().any(|(dev, w)| {
1520                        dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1521                    }));
1522                } else if !d.is_verified() {
1523                    // the device should then be in the list of withhelds
1524                    assert!(withheld.iter().any(|(dev, w)| {
1525                        dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1526                    }));
1527                }
1528            });
1529
1530        assert_eq!(149, withheld.len());
1531    }
1532
1533    #[async_test]
1534    async fn test_sharing_withheld_only_trusted() {
1535        let machine = machine().await;
1536        let room_id = room_id!("!test:localhost");
1537        let keys_claim = keys_claim_response();
1538
1539        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1540        let settings = EncryptionSettings {
1541            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1542            ..Default::default()
1543        };
1544
1545        // Trust only one
1546        let user_id = user_id!("@example:localhost");
1547        let device_id = "MWFXPINOAO".into();
1548        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1549        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1550        machine
1551            .get_device(user_id, "MWVTUXDNNM".into(), None)
1552            .await
1553            .unwrap()
1554            .unwrap()
1555            .set_local_trust(LocalTrust::BlackListed)
1556            .await
1557            .unwrap();
1558
1559        let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1560
1561        // One room key should be sent
1562        let room_key_count =
1563            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1564
1565        assert_eq!(1, room_key_count);
1566
1567        let withheld_count =
1568            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1569        // Can be send in one batch
1570        assert_eq!(1, withheld_count);
1571
1572        let event_count: usize = requests
1573            .iter()
1574            .filter(|r| r.event_type == "m.room_key.withheld".into())
1575            .map(|r| r.message_count())
1576            .sum();
1577
1578        // withhelds are sent in clear so all device should be counted (even if no OTK)
1579        assert_eq!(event_count, 149);
1580
1581        // One should be blacklisted
1582        let has_blacklist =
1583            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1584                let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1585                let content = &r.messages[user_id][&device_key];
1586                let withheld: RoomKeyWithheldContent =
1587                    content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1588                if let MegolmV1AesSha2(content) = withheld {
1589                    content.withheld_code() == WithheldCode::Blacklisted
1590                } else {
1591                    false
1592                }
1593            });
1594
1595        assert!(has_blacklist);
1596    }
1597
1598    #[async_test]
1599    async fn test_no_olm_withheld_only_sent_once() {
1600        let keys_query = keys_query_response();
1601        let txn_id = TransactionId::new();
1602
1603        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1604
1605        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1606        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1607
1608        let first_room = room_id!("!test:localhost");
1609        let second_room = room_id!("!test2:localhost");
1610        let bob_id = user_id!("@bob:localhost");
1611
1612        let settings = EncryptionSettings::default();
1613        let users = [bob_id];
1614
1615        let requests = machine
1616            .share_room_key(first_room, users.into_iter(), settings.to_owned())
1617            .await
1618            .unwrap();
1619
1620        // One withheld request should be sent.
1621        let withheld_count =
1622            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1623
1624        assert_eq!(withheld_count, 1);
1625        assert_eq!(requests.len(), 1);
1626
1627        // On the second room key share attempt we're not sending another `m.no_olm`
1628        // code since the first one is taking care of this.
1629        let second_requests =
1630            machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1631
1632        let withheld_count =
1633            second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1634
1635        assert_eq!(withheld_count, 0);
1636        assert_eq!(second_requests.len(), 0);
1637
1638        let response = ToDeviceResponse::new();
1639
1640        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1641
1642        // The device should be marked as having the `m.no_olm` code received only after
1643        // the request has been marked as sent.
1644        assert!(!device.was_withheld_code_sent());
1645
1646        for request in requests {
1647            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1648        }
1649
1650        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1651
1652        assert!(device.was_withheld_code_sent());
1653    }
1654
1655    #[async_test]
1656    async fn test_resend_session_after_unwedging() {
1657        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1658        assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
1659        let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
1660            OneTimeKeyAlgorithm::SignedCurve25519,
1661            UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
1662        )]));
1663        machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
1664
1665        let room_id = room_id!("!test:localhost");
1666
1667        let bob_id = user_id!("@bob:localhost");
1668        let bob_account = Account::new(bob_id);
1669        let keys_query_data = json!({
1670            "device_keys": {
1671                "@bob:localhost": {
1672                    bob_account.device_id.clone(): bob_account.device_keys()
1673                }
1674            }
1675        });
1676        let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
1677        let txn_id = TransactionId::new();
1678        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1679
1680        let alice_device_keys =
1681            device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
1682        let mut alice_otks = device_keys_request.one_time_keys.iter();
1683        let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
1684
1685        {
1686            // Bob creates an Olm session with Alice and encrypts a message to her
1687            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1688            let mut session = bob_account
1689                .create_outbound_session(
1690                    &alice_device,
1691                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1692                    bob_account.device_keys(),
1693                )
1694                .unwrap();
1695            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1696
1697            let to_device =
1698                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1699
1700            // Alice decrypts the message
1701            let sync_changes = EncryptionSyncChanges {
1702                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1703                changed_devices: &Default::default(),
1704                one_time_keys_counts: &Default::default(),
1705                unused_fallback_keys: None,
1706                next_batch_token: None,
1707            };
1708
1709            let decryption_settings =
1710                DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1711
1712            let (decrypted, _) =
1713                machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1714
1715            assert_eq!(1, decrypted.len());
1716        }
1717
1718        // Alice shares the room key with Bob
1719        {
1720            let requests = machine
1721                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1722                .await
1723                .unwrap();
1724
1725            // We should have had one to-device event
1726            let event_count: usize = requests
1727                .iter()
1728                .filter(|r| r.event_type == "m.room.encrypted".into())
1729                .map(|r| r.message_count())
1730                .sum();
1731            assert_eq!(event_count, 1);
1732
1733            let response = ToDeviceResponse::new();
1734            for request in requests {
1735                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1736            }
1737        }
1738
1739        // When Alice shares the room key again, there shouldn't be any
1740        // to-device events, since we already shared with Bob
1741        {
1742            let requests = machine
1743                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1744                .await
1745                .unwrap();
1746
1747            let event_count: usize = requests
1748                .iter()
1749                .filter(|r| r.event_type == "m.room.encrypted".into())
1750                .map(|r| r.message_count())
1751                .sum();
1752            assert_eq!(event_count, 0);
1753        }
1754
1755        // Pretend that Bob wasn't able to decrypt, so he tries to unwedge
1756        {
1757            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1758            let mut session = bob_account
1759                .create_outbound_session(
1760                    &alice_device,
1761                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1762                    bob_account.device_keys(),
1763                )
1764                .unwrap();
1765            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1766
1767            let to_device =
1768                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1769
1770            // Alice decrypts the unwedge message
1771            let sync_changes = EncryptionSyncChanges {
1772                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1773                changed_devices: &Default::default(),
1774                one_time_keys_counts: &Default::default(),
1775                unused_fallback_keys: None,
1776                next_batch_token: None,
1777            };
1778
1779            let decryption_settings =
1780                DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1781
1782            let (decrypted, _) =
1783                machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1784
1785            assert_eq!(1, decrypted.len());
1786        }
1787
1788        // When Alice shares the room key again, it should be re-shared with Bob
1789        {
1790            let requests = machine
1791                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1792                .await
1793                .unwrap();
1794
1795            let event_count: usize = requests
1796                .iter()
1797                .filter(|r| r.event_type == "m.room.encrypted".into())
1798                .map(|r| r.message_count())
1799                .sum();
1800            assert_eq!(event_count, 1);
1801
1802            let response = ToDeviceResponse::new();
1803            for request in requests {
1804                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1805            }
1806        }
1807
1808        // When Alice shares the room key yet again, there shouldn't be any
1809        // to-device events
1810        {
1811            let requests = machine
1812                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1813                .await
1814                .unwrap();
1815
1816            let event_count: usize = requests
1817                .iter()
1818                .filter(|r| r.event_type == "m.room.encrypted".into())
1819                .map(|r| r.message_count())
1820                .sum();
1821            assert_eq!(event_count, 0);
1822        }
1823    }
1824
1825    #[async_test]
1826    async fn test_room_key_bundle_sharing() {
1827        let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1828            user_id!("@alice:localhost"),
1829            user_id!("@bob:localhost"),
1830            false,
1831        )
1832        .await;
1833
1834        // Alice trusts Bob's device
1835        let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1836        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1837
1838        let content = RoomKeyBundleContent {
1839            room_id: owned_room_id!("!room:id"),
1840            file: (EncryptedFileInit {
1841                url: OwnedMxcUri::from("test"),
1842                key: JsonWebKey::from(JsonWebKeyInit {
1843                    kty: "oct".to_owned(),
1844                    key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
1845                    alg: "A256CTR".to_owned(),
1846                    #[allow(clippy::unnecessary_to_owned)]
1847                    k: Base64::new(vec![0u8; 0]),
1848                    ext: true,
1849                }),
1850                iv: Base64::new(vec![0u8; 0]),
1851                hashes: Default::default(),
1852                v: "".to_owned(),
1853            })
1854            .into(),
1855        };
1856
1857        let requests = alice
1858            .share_room_key_bundle_data(
1859                bob.user_id(),
1860                &CollectStrategy::OnlyTrustedDevices,
1861                content,
1862            )
1863            .await
1864            .unwrap();
1865
1866        // There should be exactly one message
1867        let requests: Vec<_> =
1868            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1869        let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1870        assert_eq!(message_count, 1);
1871
1872        // Bob decrypts the message
1873        let bob_message = requests[0]
1874            .messages
1875            .get(bob.user_id())
1876            .unwrap()
1877            .get(&(bob.device_id().to_owned().into()))
1878            .unwrap();
1879        let to_device = EncryptedToDeviceEvent::new(
1880            alice.user_id().to_owned(),
1881            bob_message.deserialize_as_unchecked().unwrap(),
1882        );
1883
1884        let sync_changes = EncryptionSyncChanges {
1885            to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1886            changed_devices: &Default::default(),
1887            one_time_keys_counts: &Default::default(),
1888            unused_fallback_keys: None,
1889            next_batch_token: None,
1890        };
1891
1892        let decryption_settings =
1893            DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1894
1895        let (decrypted, _) =
1896            bob.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1897        assert_eq!(1, decrypted.len());
1898        use crate::types::events::EventType;
1899        assert_let!(
1900            ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1901        );
1902        assert_eq!(
1903            raw.get_field::<String>("type").unwrap().unwrap(),
1904            RoomKeyBundleContent::EVENT_TYPE,
1905        );
1906    }
1907}