matrix_sdk_crypto/session_manager/group_sessions/
mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod share_strategy;
16
17use std::{
18    collections::{BTreeMap, BTreeSet},
19    fmt::Debug,
20    sync::Arc,
21};
22
23use futures_util::future::join_all;
24use itertools::Itertools;
25use matrix_sdk_common::{
26    deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
27};
28use ruma::{
29    events::{AnyMessageLikeEventContent, ToDeviceEventType},
30    serde::Raw,
31    to_device::DeviceIdOrAllDevices,
32    OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
33};
34pub(crate) use share_strategy::CollectRecipientsResult;
35pub use share_strategy::CollectStrategy;
36use tracing::{debug, error, info, instrument, trace};
37
38use crate::{
39    error::{EventError, MegolmResult, OlmResult},
40    identities::device::MaybeEncryptedRoomKey,
41    olm::{
42        InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
43        ShareInfo, ShareState,
44    },
45    store::{Changes, CryptoStoreWrapper, Result as StoreResult, Store},
46    types::{events::room::encrypted::RoomEncryptedEventContent, requests::ToDeviceRequest},
47    Device, DeviceData, EncryptionSettings, OlmError,
48};
49
50#[derive(Clone, Debug)]
51pub(crate) struct GroupSessionCache {
52    store: Store,
53    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
54    /// A map from the request id to the group session that the request belongs
55    /// to. Used to mark requests belonging to the session as shared.
56    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
57}
58
59impl GroupSessionCache {
60    pub(crate) fn new(store: Store) -> Self {
61        Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
62    }
63
64    pub(crate) fn insert(&self, session: OutboundGroupSession) {
65        self.sessions.write().insert(session.room_id().to_owned(), session);
66    }
67
68    /// Either get a session for the given room from the cache or load it from
69    /// the store.
70    ///
71    /// # Arguments
72    ///
73    /// * `room_id` - The id of the room this session is used for.
74    pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
75        // Get the cached session, if there isn't one load one from the store
76        // and put it in the cache.
77        if let Some(s) = self.sessions.read().get(room_id) {
78            return Some(s.clone());
79        }
80
81        match self.store.get_outbound_group_session(room_id).await {
82            Ok(Some(s)) => {
83                {
84                    let mut sessions_being_shared = self.sessions_being_shared.write();
85                    for request_id in s.pending_request_ids() {
86                        sessions_being_shared.insert(request_id, s.clone());
87                    }
88                }
89
90                self.sessions.write().insert(room_id.to_owned(), s.clone());
91
92                Some(s)
93            }
94            Ok(None) => None,
95            Err(e) => {
96                error!("Couldn't restore an outbound group session: {e:?}");
97                None
98            }
99        }
100    }
101
102    /// Get an outbound group session for a room, if one exists.
103    ///
104    /// # Arguments
105    ///
106    /// * `room_id` - The id of the room for which we should get the outbound
107    ///   group session.
108    fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
109        self.sessions.read().get(room_id).cloned()
110    }
111
112    /// Returns whether any session is withheld with the given device and code.
113    fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
114        self.sessions.read().values().any(|s| s.is_withheld_to(device, code))
115    }
116
117    fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
118        self.sessions_being_shared.write().remove(id)
119    }
120
121    fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
122        self.sessions_being_shared.write().insert(id, session);
123    }
124}
125
126#[derive(Debug, Clone)]
127pub(crate) struct GroupSessionManager {
128    /// Store for the encryption keys.
129    /// Persists all the encryption keys so a client can resume the session
130    /// without the need to create new keys.
131    store: Store,
132    /// The currently active outbound group sessions.
133    sessions: GroupSessionCache,
134}
135
136impl GroupSessionManager {
137    const MAX_TO_DEVICE_MESSAGES: usize = 250;
138
139    pub fn new(store: Store) -> Self {
140        Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
141    }
142
143    pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
144        if let Some(s) = self.sessions.get(room_id) {
145            s.invalidate_session();
146
147            let mut changes = Changes::default();
148            changes.outbound_group_sessions.push(s.clone());
149            self.store.save_changes(changes).await?;
150
151            Ok(true)
152        } else {
153            Ok(false)
154        }
155    }
156
157    pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
158        let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
159            return Ok(());
160        };
161
162        let no_olm = session.mark_request_as_sent(request_id);
163
164        let mut changes = Changes::default();
165
166        for (user_id, devices) in &no_olm {
167            for device_id in devices {
168                let device = self.store.get_device(user_id, device_id).await;
169
170                if let Ok(Some(device)) = device {
171                    device.mark_withheld_code_as_sent();
172                    changes.devices.changed.push(device.inner.clone());
173                } else {
174                    error!(
175                        ?request_id,
176                        "Marking to-device no olm as sent but device not found, might \
177                            have been deleted?"
178                    );
179                }
180            }
181        }
182
183        changes.outbound_group_sessions.push(session.clone());
184        self.store.save_changes(changes).await
185    }
186
187    #[cfg(test)]
188    pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
189        self.sessions.get(room_id)
190    }
191
192    pub async fn encrypt(
193        &self,
194        room_id: &RoomId,
195        event_type: &str,
196        content: &Raw<AnyMessageLikeEventContent>,
197    ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
198        let session =
199            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
200
201        assert!(!session.expired(), "Session expired");
202
203        let content = session.encrypt(event_type, content).await;
204
205        let mut changes = Changes::default();
206        changes.outbound_group_sessions.push(session);
207        self.store.save_changes(changes).await?;
208
209        Ok(content)
210    }
211
212    /// Create a new outbound group session.
213    ///
214    /// This also creates a matching inbound group session.
215    pub async fn create_outbound_group_session(
216        &self,
217        room_id: &RoomId,
218        settings: EncryptionSettings,
219        own_sender_data: SenderData,
220    ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
221        let (outbound, inbound) = self
222            .store
223            .static_account()
224            .create_group_session_pair(room_id, settings, own_sender_data)
225            .await
226            .map_err(|_| EventError::UnsupportedAlgorithm)?;
227
228        self.sessions.insert(outbound.clone());
229        Ok((outbound, inbound))
230    }
231
232    pub async fn get_or_create_outbound_session(
233        &self,
234        room_id: &RoomId,
235        settings: EncryptionSettings,
236        own_sender_data: SenderData,
237    ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
238        let outbound_session = self.sessions.get_or_load(room_id).await;
239
240        // If there is no session or the session has expired or is invalid,
241        // create a new one.
242        if let Some(s) = outbound_session {
243            if s.expired() || s.invalidated() {
244                self.create_outbound_group_session(room_id, settings, own_sender_data)
245                    .await
246                    .map(|(o, i)| (o, i.into()))
247            } else {
248                Ok((s, None))
249            }
250        } else {
251            self.create_outbound_group_session(room_id, settings, own_sender_data)
252                .await
253                .map(|(o, i)| (o, i.into()))
254        }
255    }
256
257    /// Encrypt the given content for the given devices and create a to-device
258    /// requests that sends the encrypted content to them.
259    async fn encrypt_session_for(
260        store: Arc<CryptoStoreWrapper>,
261        group_session: OutboundGroupSession,
262        devices: Vec<DeviceData>,
263    ) -> OlmResult<(
264        OwnedTransactionId,
265        ToDeviceRequest,
266        BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
267        Vec<Session>,
268        Vec<(DeviceData, WithheldCode)>,
269    )> {
270        // Use a named type instead of a tuple with rather long type name
271        pub struct DeviceResult {
272            device: DeviceData,
273            maybe_encrypted_room_key: MaybeEncryptedRoomKey,
274        }
275
276        let mut messages = BTreeMap::new();
277        let mut changed_sessions = Vec::new();
278        let mut share_infos = BTreeMap::new();
279        let mut withheld_devices = Vec::new();
280
281        // XXX is there a way to do this that doesn't involve cloning the
282        // `Arc<CryptoStoreWrapper>` for each device?
283        let encrypt = |store: Arc<CryptoStoreWrapper>,
284                       device: DeviceData,
285                       session: OutboundGroupSession| async move {
286            let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
287
288            Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
289        };
290
291        let tasks: Vec<_> = devices
292            .iter()
293            .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
294            .collect();
295
296        let results = join_all(tasks).await;
297
298        for result in results {
299            let result = result.expect("Encryption task panicked")?;
300
301            match result.maybe_encrypted_room_key {
302                MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
303                    changed_sessions.push(used_session);
304
305                    let user_id = result.device.user_id().to_owned();
306                    let device_id = result.device.device_id().to_owned();
307
308                    messages
309                        .entry(user_id.to_owned())
310                        .or_insert_with(BTreeMap::new)
311                        .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), message);
312
313                    share_infos
314                        .entry(user_id)
315                        .or_insert_with(BTreeMap::new)
316                        .insert(device_id, share_info);
317                }
318                MaybeEncryptedRoomKey::Withheld { code } => {
319                    withheld_devices.push((result.device, code));
320                }
321            }
322        }
323
324        let txn_id = TransactionId::new();
325        let request = ToDeviceRequest {
326            event_type: ToDeviceEventType::RoomEncrypted,
327            txn_id: txn_id.to_owned(),
328            messages,
329        };
330
331        Ok((txn_id, request, share_infos, changed_sessions, withheld_devices))
332    }
333
334    /// Given a list of user and an outbound session, return the list of users
335    /// and their devices that this session should be shared with.
336    ///
337    /// Returns information indicating whether the session needs to be rotated
338    /// and the list of users/devices that should receive or not the session
339    /// (with withheld reason).
340    #[instrument(skip_all)]
341    pub async fn collect_session_recipients(
342        &self,
343        users: impl Iterator<Item = &UserId>,
344        settings: &EncryptionSettings,
345        outbound: &OutboundGroupSession,
346    ) -> OlmResult<CollectRecipientsResult> {
347        share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
348    }
349
350    async fn encrypt_request(
351        store: Arc<CryptoStoreWrapper>,
352        chunk: Vec<DeviceData>,
353        outbound: OutboundGroupSession,
354        sessions: GroupSessionCache,
355    ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
356        let (id, request, share_infos, used_sessions, no_olm) =
357            Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
358
359        if !request.messages.is_empty() {
360            trace!(
361                recipient_count = request.message_count(),
362                transaction_id = ?id,
363                "Created a to-device request carrying a room_key"
364            );
365
366            outbound.add_request(id.clone(), request.into(), share_infos);
367            sessions.mark_as_being_shared(id, outbound.clone());
368        }
369
370        Ok((used_sessions, no_olm))
371    }
372
373    pub(crate) fn session_cache(&self) -> GroupSessionCache {
374        self.sessions.clone()
375    }
376
377    async fn maybe_rotate_group_session(
378        &self,
379        should_rotate: bool,
380        room_id: &RoomId,
381        outbound: OutboundGroupSession,
382        encryption_settings: EncryptionSettings,
383        changes: &mut Changes,
384        own_device: Option<Device>,
385    ) -> OlmResult<OutboundGroupSession> {
386        Ok(if should_rotate {
387            let old_session_id = outbound.session_id();
388
389            let (outbound, mut inbound) = self
390                .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
391                .await?;
392
393            // Use our own device info to populate the SenderData that validates the
394            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
395            // are sending out.
396            let own_sender_data = if let Some(device) = own_device {
397                SenderDataFinder::find_using_device_data(
398                    &self.store,
399                    device.inner.clone(),
400                    &inbound,
401                )
402                .await?
403            } else {
404                error!("Unable to find our own device!");
405                SenderData::unknown()
406            };
407            inbound.sender_data = own_sender_data;
408
409            changes.outbound_group_sessions.push(outbound.clone());
410            changes.inbound_group_sessions.push(inbound);
411
412            debug!(
413                old_session_id = old_session_id,
414                session_id = outbound.session_id(),
415                "A user or device has left the room since we last sent a \
416                message, or the encryption settings have changed. Rotating the \
417                room key.",
418            );
419
420            outbound
421        } else {
422            outbound
423        })
424    }
425
426    async fn encrypt_for_devices(
427        &self,
428        recipient_devices: Vec<DeviceData>,
429        group_session: &OutboundGroupSession,
430        changes: &mut Changes,
431    ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
432        // If we have some recipients, log them here.
433        if !recipient_devices.is_empty() {
434            #[allow(unknown_lints, clippy::unwrap_or_default)] // false positive
435            let recipients = recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
436                acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
437                acc
438            });
439
440            // If there are new recipients we need to persist the outbound group
441            // session as the to-device requests are persisted with the session.
442            changes.outbound_group_sessions = vec![group_session.clone()];
443
444            let message_index = group_session.message_index().await;
445
446            info!(
447                ?recipients,
448                message_index,
449                room_id = ?group_session.room_id(),
450                session_id = group_session.session_id(),
451                "Trying to encrypt a room key",
452            );
453        }
454
455        // Chunk the recipients out so each to-device request will contain a
456        // limited amount of to-device messages.
457        //
458        // Create concurrent tasks for each chunk of recipients.
459        let tasks: Vec<_> = recipient_devices
460            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
461            .map(|chunk| {
462                spawn(Self::encrypt_request(
463                    self.store.crypto_store(),
464                    chunk.to_vec(),
465                    group_session.clone(),
466                    self.sessions.clone(),
467                ))
468            })
469            .collect();
470
471        let mut withheld_devices = Vec::new();
472
473        // Wait for all the tasks to finish up and queue up the Olm session that
474        // was used to encrypt the room key to be persisted again. This is
475        // needed because each encryption step will mutate the Olm session,
476        // ratcheting its state forward.
477        for result in join_all(tasks).await {
478            let result = result.expect("Encryption task panicked");
479
480            let (used_sessions, failed_no_olm) = result?;
481
482            changes.sessions.extend(used_sessions);
483            withheld_devices.extend(failed_no_olm);
484        }
485
486        Ok(withheld_devices)
487    }
488
489    fn is_withheld_to(
490        &self,
491        group_session: &OutboundGroupSession,
492        device: &DeviceData,
493        code: &WithheldCode,
494    ) -> bool {
495        // The `m.no_olm` withheld code is special because it is supposed to be sent
496        // only once for a given device. The `Device` remembers the flag if we
497        // already sent a `m.no_olm` to this particular device so let's check
498        // that first.
499        //
500        // Keep in mind that any outbound group session might want to send this code to
501        // the device. So we need to check if any of our outbound group sessions
502        // is attempting to send the code to the device.
503        //
504        // This still has a slight race where some other thread might remove the
505        // outbound group session while a third is marking the device as having
506        // received the code.
507        //
508        // Since nothing terrible happens if we do end up sending the withheld code
509        // twice, and removing the race requires us to lock the store because the
510        // `OutboundGroupSession` and the `Device` both interact with the flag we'll
511        // leave it be.
512        if code == &WithheldCode::NoOlm {
513            device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
514        } else {
515            group_session.is_withheld_to(device, code)
516        }
517    }
518
519    fn handle_withheld_devices(
520        &self,
521        group_session: &OutboundGroupSession,
522        withheld_devices: Vec<(DeviceData, WithheldCode)>,
523    ) -> OlmResult<()> {
524        // Convert a withheld code for the group session into a to-device event content.
525        let to_content = |code| {
526            let content = group_session.withheld_code(code);
527            Raw::new(&content).expect("We can always serialize a withheld content info").cast()
528        };
529
530        // Helper to convert a chunk of device and withheld code pairs into a to-device
531        // request and it's accompanying share info.
532        let chunk_to_request = |chunk| {
533            let mut messages = BTreeMap::new();
534            let mut share_infos = BTreeMap::new();
535
536            for (device, code) in chunk {
537                let device: DeviceData = device;
538                let code: WithheldCode = code;
539
540                let user_id = device.user_id().to_owned();
541                let device_id = device.device_id().to_owned();
542
543                let share_info = ShareInfo::new_withheld(code.to_owned());
544                let content = to_content(code);
545
546                messages
547                    .entry(user_id.to_owned())
548                    .or_insert_with(BTreeMap::new)
549                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
550
551                share_infos
552                    .entry(user_id)
553                    .or_insert_with(BTreeMap::new)
554                    .insert(device_id, share_info);
555            }
556
557            let txn_id = TransactionId::new();
558
559            let request = ToDeviceRequest {
560                event_type: ToDeviceEventType::from("m.room_key.withheld"),
561                txn_id,
562                messages,
563            };
564
565            (request, share_infos)
566        };
567
568        let result: Vec<_> = withheld_devices
569            .into_iter()
570            .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
571            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
572            .into_iter()
573            .map(chunk_to_request)
574            .collect();
575
576        for (request, share_info) in result {
577            if !request.messages.is_empty() {
578                let txn_id = request.txn_id.to_owned();
579                group_session.add_request(txn_id.to_owned(), request.into(), share_info);
580
581                self.sessions.mark_as_being_shared(txn_id, group_session.clone());
582            }
583        }
584
585        Ok(())
586    }
587
588    fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
589        for request in requests {
590            let message_list = Self::to_device_request_to_log_list(request);
591            info!(
592                request_id = ?request.txn_id,
593                ?message_list,
594                "Created batch of to-device messages of type {}",
595                request.event_type
596            );
597        }
598    }
599
600    /// Given a to-device request, build a recipient map suitable for logging.
601    ///
602    /// Returns a list of triples of (message_id, user id, device_id).
603    fn to_device_request_to_log_list(
604        request: &Arc<ToDeviceRequest>,
605    ) -> Vec<(String, String, String)> {
606        #[derive(serde::Deserialize)]
607        struct ContentStub<'a> {
608            #[serde(borrow, default, rename = "org.matrix.msgid")]
609            message_id: Option<&'a str>,
610        }
611
612        let mut result: Vec<(String, String, String)> = Vec::new();
613
614        for (user_id, device_map) in &request.messages {
615            for (device, content) in device_map {
616                let message_id: Option<&str> = content
617                    .deserialize_as::<ContentStub<'_>>()
618                    .expect("We should be able to deserialize the content we generated")
619                    .message_id;
620
621                result.push((
622                    message_id.unwrap_or("<undefined>").to_owned(),
623                    user_id.to_string(),
624                    device.to_string(),
625                ));
626            }
627        }
628        result
629    }
630
631    /// Get to-device requests to share a room key with users in a room.
632    ///
633    /// # Arguments
634    ///
635    /// `room_id` - The room id of the room where the room key will be used.
636    ///
637    /// `users` - The list of users that should receive the room key.
638    ///
639    /// `encryption_settings` - The settings that should be used for
640    /// the room key.
641    #[instrument(skip(self, users, encryption_settings), fields(session_id))]
642    pub async fn share_room_key(
643        &self,
644        room_id: &RoomId,
645        users: impl Iterator<Item = &UserId>,
646        encryption_settings: impl Into<EncryptionSettings>,
647    ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
648        trace!("Checking if a room key needs to be shared");
649
650        let account = self.store.static_account();
651        let device = self.store.get_device(account.user_id(), account.device_id()).await?;
652
653        let encryption_settings = encryption_settings.into();
654        let mut changes = Changes::default();
655
656        // Try to get an existing session or create a new one.
657        let (outbound, inbound) = self
658            .get_or_create_outbound_session(
659                room_id,
660                encryption_settings.clone(),
661                SenderData::unknown(),
662            )
663            .await?;
664        tracing::Span::current().record("session_id", outbound.session_id());
665
666        // Having an inbound group session here means that we created a new
667        // group session pair, which we then need to store.
668        if let Some(mut inbound) = inbound {
669            // Use our own device info to populate the SenderData that validates the
670            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
671            // are sending out.
672            let own_sender_data = if let Some(device) = &device {
673                SenderDataFinder::find_using_device_data(
674                    &self.store,
675                    device.inner.clone(),
676                    &inbound,
677                )
678                .await?
679            } else {
680                error!("Unable to find our own device!");
681                SenderData::unknown()
682            };
683            inbound.sender_data = own_sender_data;
684
685            changes.outbound_group_sessions.push(outbound.clone());
686            changes.inbound_group_sessions.push(inbound);
687        }
688
689        // Collect the recipient devices and check if either the settings
690        // or the recipient list changed in a way that requires the
691        // session to be rotated.
692        let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
693            self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
694
695        let outbound = self
696            .maybe_rotate_group_session(
697                should_rotate,
698                room_id,
699                outbound,
700                encryption_settings,
701                &mut changes,
702                device,
703            )
704            .await?;
705
706        // Filter out the devices that already received this room key or have a
707        // to-device message already queued up.
708        let devices: Vec<_> = devices
709            .into_iter()
710            .flat_map(|(_, d)| {
711                d.into_iter().filter(|d| match outbound.is_shared_with(d) {
712                    ShareState::NotShared => true,
713                    ShareState::Shared { message_index: _, olm_wedging_index } => {
714                        // If the recipient device's Olm wedging index is higher
715                        // than the value that we stored with the session, that
716                        // means that they tried to unwedge the session since we
717                        // last shared the room key.  So we re-share it with
718                        // them in case they weren't able to decrypt the room
719                        // key the last time we shared it.
720                        olm_wedging_index < d.olm_wedging_index
721                    }
722                    _ => false,
723                })
724            })
725            .collect();
726
727        // The `encrypt_for_devices()` method adds the to-device requests that will send
728        // out the room key to the `OutboundGroupSession`. It doesn't do that
729        // for the m.room_key_withheld events since we might have more of those
730        // coming from the `collect_session_recipients()` method. Instead they get
731        // returned by the method.
732        let unable_to_encrypt_devices =
733            self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
734
735        // Merge the withheld recipients.
736        withheld_devices.extend(unable_to_encrypt_devices);
737
738        // Now handle and add the withheld recipients to the resulting requests to the
739        // `OutboundGroupSession`.
740        self.handle_withheld_devices(&outbound, withheld_devices)?;
741
742        // The to-device requests get added to the outbound group session, this
743        // way we're making sure that they are persisted and scoped to the
744        // session.
745        let requests = outbound.pending_requests();
746
747        if requests.is_empty() {
748            if !outbound.shared() {
749                debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
750
751                outbound.mark_as_shared();
752                changes.outbound_group_sessions.push(outbound.clone());
753            }
754        } else {
755            Self::log_room_key_sharing_result(&requests)
756        }
757
758        // Persist any changes we might have collected.
759        if !changes.is_empty() {
760            let session_count = changes.sessions.len();
761
762            self.store.save_changes(changes).await?;
763
764            trace!(
765                session_count = session_count,
766                "Stored the changed sessions after encrypting an room key"
767            );
768        }
769
770        Ok(requests)
771    }
772}
773
774#[cfg(test)]
775mod tests {
776    use std::{
777        collections::{BTreeMap, BTreeSet},
778        iter,
779        ops::Deref,
780        sync::Arc,
781    };
782
783    use assert_matches2::assert_let;
784    use matrix_sdk_common::deserialized_responses::WithheldCode;
785    use matrix_sdk_test::{async_test, ruma_response_from_json};
786    use ruma::{
787        api::client::{
788            keys::{claim_keys, get_keys, upload_keys},
789            to_device::send_event_to_device::v3::Response as ToDeviceResponse,
790        },
791        device_id,
792        events::room::history_visibility::HistoryVisibility,
793        room_id,
794        to_device::DeviceIdOrAllDevices,
795        user_id, DeviceId, OneTimeKeyAlgorithm, TransactionId, UInt, UserId,
796    };
797    use serde_json::{json, Value};
798
799    use crate::{
800        identities::DeviceData,
801        machine::EncryptionSyncChanges,
802        olm::{Account, SenderData},
803        session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
804        types::{
805            events::{
806                room::encrypted::EncryptedToDeviceEvent,
807                room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
808            },
809            requests::ToDeviceRequest,
810            DeviceKeys, EventEncryptionAlgorithm,
811        },
812        EncryptionSettings, LocalTrust, OlmMachine,
813    };
814
    /// User ID of Alice, the local user of most machines built in these tests.
    fn alice_id() -> &'static UserId {
        user_id!("@alice:example.org")
    }
818
    /// Device ID of Alice's local device, paired with [`alice_id`].
    fn alice_device_id() -> &'static DeviceId {
        device_id!("JLAFKJWSCS")
    }
822
823    /// Returns a /keys/query response for user "@example:localhost"
824    fn keys_query_response() -> get_keys::v3::Response {
825        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
826        let data: Value = serde_json::from_slice(data).unwrap();
827        ruma_response_from_json(&data)
828    }
829
830    fn bob_keys_query_response() -> get_keys::v3::Response {
831        let data = json!({
832            "device_keys": {
833                "@bob:localhost": {
834                    "BOBDEVICE": {
835                        "user_id": "@bob:localhost",
836                        "device_id": "BOBDEVICE",
837                        "algorithms": [
838                            "m.olm.v1.curve25519-aes-sha2",
839                            "m.megolm.v1.aes-sha2",
840                            "m.megolm.v2.aes-sha2"
841                        ],
842                        "keys": {
843                            "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
844                            "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
845                        },
846                        "signatures": {
847                            "@bob:localhost": {
848                                "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
849                            }
850                        }
851                    }
852                }
853            }
854        });
855        ruma_response_from_json(&data)
856    }
857
858    /// Returns a keys claim response for device `BOBDEVICE` of user
859    /// `@bob:localhost`.
860    fn bob_one_time_key() -> claim_keys::v3::Response {
861        let data = json!({
862            "failures": {},
863            "one_time_keys":{
864                "@bob:localhost":{
865                    "BOBDEVICE":{
866                      "signed_curve25519:AAAAAAAAAAA": {
867                          "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
868                          "signatures":{
869                              "@bob:localhost":{
870                                  "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
871                              }
872                          }
873                      }
874                    }
875                }
876            }
877        });
878        ruma_response_from_json(&data)
879    }
880
881    /// Returns a key claim response for device `NMMBNBUSNR` of user
882    /// `@example2:localhost`
883    fn keys_claim_response() -> claim_keys::v3::Response {
884        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
885        let data: Value = serde_json::from_slice(data).unwrap();
886        ruma_response_from_json(&data)
887    }
888
889    async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
890        let keys_query = keys_query_response();
891        let txn_id = TransactionId::new();
892
893        let machine = OlmMachine::new(user_id, device_id).await;
894
895        // complete a /keys/query and /keys/claim for @example:localhost
896        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
897        let (txn_id, _keys_claim_request) = machine
898            .get_missing_sessions(iter::once(user_id!("@example:localhost")))
899            .await
900            .unwrap()
901            .unwrap();
902        let keys_claim = keys_claim_response();
903        machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
904
905        // complete a /keys/query and /keys/claim for @bob:localhost
906        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
907        let (txn_id, _keys_claim_request) = machine
908            .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
909            .await
910            .unwrap()
911            .unwrap();
912        machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
913
914        machine
915    }
916
917    async fn machine() -> OlmMachine {
918        machine_with_user_test_helper(alice_id(), alice_device_id()).await
919    }
920
921    async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
922        let machine = machine().await;
923        let room_id = room_id!("!test:localhost");
924        let keys_claim = keys_claim_response();
925
926        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
927        let requests =
928            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
929
930        let outbound =
931            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
932
933        assert!(!outbound.pending_requests().is_empty());
934        assert!(!outbound.shared());
935
936        let response = ToDeviceResponse::new();
937        for request in requests {
938            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
939        }
940
941        assert!(outbound.shared());
942        assert!(outbound.pending_requests().is_empty());
943
944        machine
945    }
946
947    #[async_test]
948    async fn test_sharing() {
949        let machine = machine().await;
950        let room_id = room_id!("!test:localhost");
951        let keys_claim = keys_claim_response();
952
953        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
954
955        let requests =
956            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
957
958        let event_count: usize = requests
959            .iter()
960            .filter(|r| r.event_type == "m.room.encrypted".into())
961            .map(|r| r.message_count())
962            .sum();
963
964        // The keys claim response has a couple of one-time keys with invalid
965        // signatures, thus only 148 sessions are actually created, we check
966        // that all 148 valid sessions get an room key.
967        assert_eq!(event_count, 148);
968
969        let withheld_count: usize = requests
970            .iter()
971            .filter(|r| r.event_type == "m.room_key.withheld".into())
972            .map(|r| r.message_count())
973            .sum();
974        assert_eq!(withheld_count, 2);
975    }
976
977    fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
978        requests
979            .iter()
980            .filter(|r| r.event_type == "m.room_key.withheld".into())
981            .map(|r| {
982                let mut count = 0;
983                // count targets
984                for message in r.messages.values() {
985                    message.iter().for_each(|(_, content)| {
986                        let withheld: RoomKeyWithheldContent =
987                            content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
988
989                        if let MegolmV1AesSha2(content) = withheld {
990                            if content.withheld_code() == code {
991                                count += 1;
992                            }
993                        }
994                    })
995                }
996                count
997            })
998            .sum()
999    }
1000
1001    #[async_test]
1002    async fn test_no_olm_sent_once() {
1003        let machine = machine().await;
1004        let keys_claim = keys_claim_response();
1005
1006        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1007
1008        let first_room_id = room_id!("!test:localhost");
1009
1010        let requests = machine
1011            .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1012            .await
1013            .unwrap();
1014
1015        // there will be two no_olm
1016        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1017        assert_eq!(withheld_count, 2);
1018
1019        // Re-sharing same session while request has not been sent should not produces
1020        // withheld
1021        let new_requests = machine
1022            .share_room_key(first_room_id, users, EncryptionSettings::default())
1023            .await
1024            .unwrap();
1025        let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1026        // No additional request was added, still the 2 already pending
1027        assert_eq!(withheld_count, 2);
1028
1029        let response = ToDeviceResponse::new();
1030        for request in requests {
1031            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1032        }
1033
1034        // The fact that an olm was sent should be remembered even if sharing another
1035        // session in an other room.
1036        let second_room_id = room_id!("!other:localhost");
1037        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1038        let requests = machine
1039            .share_room_key(second_room_id, users, EncryptionSettings::default())
1040            .await
1041            .unwrap();
1042
1043        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1044        assert_eq!(withheld_count, 0);
1045
1046        // Help how do I simulate the creation of a new session for the device
1047        // with no session now?
1048    }
1049
1050    #[async_test]
1051    async fn test_ratcheted_sharing() {
1052        let machine = machine_with_shared_room_key_test_helper().await;
1053
1054        let room_id = room_id!("!test:localhost");
1055        let late_joiner = user_id!("@bob:localhost");
1056        let keys_claim = keys_claim_response();
1057
1058        let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1059        users.insert(late_joiner);
1060
1061        let requests = machine
1062            .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1063            .await
1064            .unwrap();
1065
1066        let event_count: usize = requests
1067            .iter()
1068            .filter(|r| r.event_type == "m.room.encrypted".into())
1069            .map(|r| r.message_count())
1070            .sum();
1071        let outbound =
1072            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1073
1074        assert_eq!(event_count, 1);
1075        assert!(!outbound.pending_requests().is_empty());
1076    }
1077
1078    #[async_test]
1079    async fn test_changing_encryption_settings() {
1080        let machine = machine_with_shared_room_key_test_helper().await;
1081        let room_id = room_id!("!test:localhost");
1082        let keys_claim = keys_claim_response();
1083
1084        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1085        let outbound =
1086            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1087
1088        let CollectRecipientsResult { should_rotate, .. } = machine
1089            .inner
1090            .group_session_manager
1091            .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1092            .await
1093            .unwrap();
1094
1095        assert!(!should_rotate);
1096
1097        let settings = EncryptionSettings {
1098            history_visibility: HistoryVisibility::Invited,
1099            ..Default::default()
1100        };
1101
1102        let CollectRecipientsResult { should_rotate, .. } = machine
1103            .inner
1104            .group_session_manager
1105            .collect_session_recipients(users.clone(), &settings, &outbound)
1106            .await
1107            .unwrap();
1108
1109        assert!(should_rotate);
1110
1111        let settings = EncryptionSettings {
1112            algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1113            ..Default::default()
1114        };
1115
1116        let CollectRecipientsResult { should_rotate, .. } = machine
1117            .inner
1118            .group_session_manager
1119            .collect_session_recipients(users, &settings, &outbound)
1120            .await
1121            .unwrap();
1122
1123        assert!(should_rotate);
1124    }
1125
1126    #[async_test]
1127    async fn test_key_recipient_collecting() {
1128        // The user id comes from the fact that the keys_query.json file uses
1129        // this one.
1130        let user_id = user_id!("@example:localhost");
1131        let device_id = device_id!("TESTDEVICE");
1132        let room_id = room_id!("!test:localhost");
1133
1134        let machine = machine_with_user_test_helper(user_id, device_id).await;
1135
1136        let (outbound, _) = machine
1137            .inner
1138            .group_session_manager
1139            .get_or_create_outbound_session(
1140                room_id,
1141                EncryptionSettings::default(),
1142                SenderData::unknown(),
1143            )
1144            .await
1145            .expect("We should be able to create a new session");
1146        let history_visibility = HistoryVisibility::Joined;
1147        let settings = EncryptionSettings { history_visibility, ..Default::default() };
1148
1149        let users = [user_id].into_iter();
1150
1151        let CollectRecipientsResult { devices: recipients, .. } = machine
1152            .inner
1153            .group_session_manager
1154            .collect_session_recipients(users, &settings, &outbound)
1155            .await
1156            .expect("We should be able to collect the session recipients");
1157
1158        assert!(!recipients[user_id].is_empty());
1159
1160        // Make sure that our own device isn't part of the recipients.
1161        assert!(!recipients[user_id]
1162            .iter()
1163            .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1164
1165        let settings = EncryptionSettings {
1166            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1167            ..Default::default()
1168        };
1169        let users = [user_id].into_iter();
1170
1171        let CollectRecipientsResult { devices: recipients, .. } = machine
1172            .inner
1173            .group_session_manager
1174            .collect_session_recipients(users, &settings, &outbound)
1175            .await
1176            .expect("We should be able to collect the session recipients");
1177
1178        assert!(recipients[user_id].is_empty());
1179
1180        let device_id = "AFGUOBTZWM".into();
1181        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1182        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1183        let users = [user_id].into_iter();
1184
1185        let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1186            machine
1187                .inner
1188                .group_session_manager
1189                .collect_session_recipients(users, &settings, &outbound)
1190                .await
1191                .expect("We should be able to collect the session recipients");
1192
1193        assert!(recipients[user_id]
1194            .iter()
1195            .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1196
1197        let devices = machine.get_user_devices(user_id, None).await.unwrap();
1198        devices
1199            .devices()
1200            // Ignore our own device
1201            .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1202            .for_each(|d| {
1203                if d.is_blacklisted() {
1204                    assert!(withheld.iter().any(|(dev, w)| {
1205                        dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1206                    }));
1207                } else if !d.is_verified() {
1208                    // the device should then be in the list of withhelds
1209                    assert!(withheld.iter().any(|(dev, w)| {
1210                        dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1211                    }));
1212                }
1213            });
1214
1215        assert_eq!(149, withheld.len());
1216    }
1217
1218    #[async_test]
1219    async fn test_sharing_withheld_only_trusted() {
1220        let machine = machine().await;
1221        let room_id = room_id!("!test:localhost");
1222        let keys_claim = keys_claim_response();
1223
1224        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1225        let settings = EncryptionSettings {
1226            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1227            ..Default::default()
1228        };
1229
1230        // Trust only one
1231        let user_id = user_id!("@example:localhost");
1232        let device_id = "MWFXPINOAO".into();
1233        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1234        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1235        machine
1236            .get_device(user_id, "MWVTUXDNNM".into(), None)
1237            .await
1238            .unwrap()
1239            .unwrap()
1240            .set_local_trust(LocalTrust::BlackListed)
1241            .await
1242            .unwrap();
1243
1244        let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1245
1246        // One room key should be sent
1247        let room_key_count =
1248            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1249
1250        assert_eq!(1, room_key_count);
1251
1252        let withheld_count =
1253            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1254        // Can be send in one batch
1255        assert_eq!(1, withheld_count);
1256
1257        let event_count: usize = requests
1258            .iter()
1259            .filter(|r| r.event_type == "m.room_key.withheld".into())
1260            .map(|r| r.message_count())
1261            .sum();
1262
1263        // withhelds are sent in clear so all device should be counted (even if no OTK)
1264        assert_eq!(event_count, 149);
1265
1266        // One should be blacklisted
1267        let has_blacklist =
1268            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1269                let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1270                let content = &r.messages[user_id][&device_key];
1271                let withheld: RoomKeyWithheldContent =
1272                    content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
1273                if let MegolmV1AesSha2(content) = withheld {
1274                    content.withheld_code() == WithheldCode::Blacklisted
1275                } else {
1276                    false
1277                }
1278            });
1279
1280        assert!(has_blacklist);
1281    }
1282
1283    #[async_test]
1284    async fn test_no_olm_withheld_only_sent_once() {
1285        let keys_query = keys_query_response();
1286        let txn_id = TransactionId::new();
1287
1288        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1289
1290        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1291        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1292
1293        let first_room = room_id!("!test:localhost");
1294        let second_room = room_id!("!test2:localhost");
1295        let bob_id = user_id!("@bob:localhost");
1296
1297        let settings = EncryptionSettings::default();
1298        let users = [bob_id];
1299
1300        let requests = machine
1301            .share_room_key(first_room, users.into_iter(), settings.to_owned())
1302            .await
1303            .unwrap();
1304
1305        // One withheld request should be sent.
1306        let withheld_count =
1307            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1308
1309        assert_eq!(withheld_count, 1);
1310        assert_eq!(requests.len(), 1);
1311
1312        // On the second room key share attempt we're not sending another `m.no_olm`
1313        // code since the first one is taking care of this.
1314        let second_requests =
1315            machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1316
1317        let withheld_count =
1318            second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1319
1320        assert_eq!(withheld_count, 0);
1321        assert_eq!(second_requests.len(), 0);
1322
1323        let response = ToDeviceResponse::new();
1324
1325        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1326
1327        // The device should be marked as having the `m.no_olm` code received only after
1328        // the request has been marked as sent.
1329        assert!(!device.was_withheld_code_sent());
1330
1331        for request in requests {
1332            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1333        }
1334
1335        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1336
1337        assert!(device.was_withheld_code_sent());
1338    }
1339
    /// A room key already shared with Bob's device is not re-sent on later
    /// share attempts. The exception is after Bob creates a new Olm session
    /// with Alice ("unwedging"): then it is re-sent exactly once.
    #[async_test]
    async fn test_resend_session_after_unwedging() {
        // Alice uploads her device keys and one-time keys; Bob later consumes
        // those one-time keys to create Olm sessions with her.
        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
        assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
        let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
            OneTimeKeyAlgorithm::SignedCurve25519,
            UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
        )]));
        machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();

        let room_id = room_id!("!test:localhost");

        // Bob is a fresh account; Alice's machine learns about his device
        // through a /keys/query response built from his device keys.
        let bob_id = user_id!("@bob:localhost");
        let bob_account = Account::new(bob_id);
        let keys_query_data = json!({
            "device_keys": {
                "@bob:localhost": {
                    bob_account.device_id.clone(): bob_account.device_keys()
                }
            }
        });
        let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
        let txn_id = TransactionId::new();
        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();

        // Bob's view of Alice's device, built from the keys she uploaded.
        // Each new Olm session Bob creates below takes the next one-time key.
        let alice_device_keys =
            device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
        let mut alice_otks = device_keys_request.one_time_keys.iter();
        let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);

        {
            // Bob creates an Olm session with Alice and encrypts a message to her
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            // Alice decrypts the message
            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };
            let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        // Alice shares the room key with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            // We should have had one to-device event
            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        // When Alice shares the room key again, there shouldn't be any
        // to-device events, since we already shared with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }

        // Pretend that Bob wasn't able to decrypt, so he tries to unwedge
        // by creating a brand-new Olm session with the next one-time key.
        {
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            // Alice decrypts the unwedge message
            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };
            let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        // When Alice shares the room key again, it should be re-shared with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        // When Alice shares the room key yet again, there shouldn't be any
        // to-device events
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }
    }
1499}