1mod share_strategy;
16
17use std::{
18 collections::{BTreeMap, BTreeSet},
19 fmt::Debug,
20 iter,
21 iter::zip,
22 sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28 deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30#[cfg(feature = "experimental-encrypted-state-events")]
31use ruma::events::AnyStateEventContent;
32use ruma::{
33 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
34 UserId,
35 events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
36 serde::Raw,
37 to_device::DeviceIdOrAllDevices,
38};
39use serde::Serialize;
40pub use share_strategy::CollectStrategy;
41#[cfg(feature = "experimental-send-custom-to-device")]
42pub(crate) use share_strategy::split_devices_for_share_strategy;
43pub(crate) use share_strategy::{
44 CollectRecipientsResult, withheld_code_for_device_for_share_strategy,
45};
46use tracing::{Instrument, debug, error, info, instrument, trace, warn};
47
48#[cfg(feature = "experimental-encrypted-state-events")]
49use crate::types::events::room::encrypted::RoomEncryptedEventContent;
50use crate::{
51 Device, DeviceData, EncryptionSettings, OlmError,
52 error::{EventError, MegolmResult, OlmResult},
53 identities::device::MaybeEncryptedRoomKey,
54 olm::{
55 InboundGroupSession, OutboundGroupSession, OutboundGroupSessionEncryptionResult,
56 SenderData, SenderDataFinder, Session, ShareInfo, ShareState,
57 },
58 store::{CryptoStoreWrapper, Result as StoreResult, Store, types::Changes},
59 types::{
60 events::{
61 EventType, room::encrypted::ToDeviceEncryptedEventContent,
62 room_key_bundle::RoomKeyBundleContent,
63 },
64 requests::ToDeviceRequest,
65 },
66};
67
/// Cache of the outbound group sessions, lazily backed by the crypto store.
///
/// Both maps are behind an `Arc`, so clones of the cache share the same data.
#[derive(Clone, Debug)]
pub(crate) struct GroupSessionCache {
    /// Store used to load sessions that are not yet present in the cache.
    store: Store,
    /// The outbound group session of each room, keyed by the session's room ID.
    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
    /// Outbound sessions that have in-flight to-device requests, keyed by the
    /// transaction ID of the request, so that the session can be found again
    /// once the request has been marked as sent.
    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
}
76
77impl GroupSessionCache {
78 pub(crate) fn new(store: Store) -> Self {
79 Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
80 }
81
82 pub(crate) fn insert(&self, session: OutboundGroupSession) {
83 self.sessions.write().insert(session.room_id().to_owned(), session);
84 }
85
86 pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
93 if let Some(s) = self.sessions.read().get(room_id) {
96 return Some(s.clone());
97 }
98
99 match self.store.get_outbound_group_session(room_id).await {
100 Ok(Some(s)) => {
101 {
102 let mut sessions_being_shared = self.sessions_being_shared.write();
103 for request_id in s.pending_request_ids() {
104 sessions_being_shared.insert(request_id, s.clone());
105 }
106 }
107
108 self.sessions.write().insert(room_id.to_owned(), s.clone());
109
110 Some(s)
111 }
112 Ok(None) => None,
113 Err(e) => {
114 error!("Couldn't restore an outbound group session: {e:?}");
115 None
116 }
117 }
118 }
119
120 fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
127 self.sessions.read().get(room_id).cloned()
128 }
129
130 fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
132 self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
133 }
134
135 fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
136 self.sessions_being_shared.write().remove(id)
137 }
138
139 fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
140 self.sessions_being_shared.write().insert(id, session);
141 }
142}
143
/// Creates outbound group sessions for rooms, shares their room keys with
/// recipient devices, and encrypts room events with them.
#[derive(Debug, Clone)]
pub(crate) struct GroupSessionManager {
    /// Store used to persist sessions and to look up devices.
    store: Store,
    /// Cache of the outbound group sessions, keyed by room.
    sessions: GroupSessionCache,
}
153
154impl GroupSessionManager {
155 const MAX_TO_DEVICE_MESSAGES: usize = 250;
156
157 pub fn new(store: Store) -> Self {
158 Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
159 }
160
161 pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
162 if let Some(s) = self.sessions.get(room_id) {
163 s.invalidate_session();
164
165 let mut changes = Changes::default();
166 changes.outbound_group_sessions.push(s.clone());
167 self.store.save_changes(changes).await?;
168
169 Ok(true)
170 } else {
171 Ok(false)
172 }
173 }
174
175 pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
176 let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
177 return Ok(());
178 };
179
180 let no_olm = session.mark_request_as_sent(request_id);
181
182 let mut changes = Changes::default();
183
184 for (user_id, devices) in &no_olm {
185 for device_id in devices {
186 let device = self.store.get_device(user_id, device_id).await;
187
188 if let Ok(Some(device)) = device {
189 device.mark_withheld_code_as_sent();
190 changes.devices.changed.push(device.inner.clone());
191 } else {
192 error!(
193 ?request_id,
194 "Marking to-device no olm as sent but device not found, might \
195 have been deleted?"
196 );
197 }
198 }
199 }
200
201 changes.outbound_group_sessions.push(session.clone());
202 self.store.save_changes(changes).await
203 }
204
205 #[cfg(test)]
206 pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
207 self.sessions.get(room_id)
208 }
209
210 pub async fn encrypt(
211 &self,
212 room_id: &RoomId,
213 event_type: &str,
214 content: &Raw<AnyMessageLikeEventContent>,
215 ) -> MegolmResult<OutboundGroupSessionEncryptionResult> {
216 let session =
217 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
218
219 assert!(!session.expired(), "Session expired");
220
221 let result = session.encrypt(event_type, content).await;
222
223 let mut changes = Changes::default();
224 changes.outbound_group_sessions.push(session);
225 self.store.save_changes(changes).await?;
226
227 Ok(result)
228 }
229
230 #[cfg(feature = "experimental-encrypted-state-events")]
253 pub async fn encrypt_state(
254 &self,
255 room_id: &RoomId,
256 event_type: &str,
257 state_key: &str,
258 content: &Raw<AnyStateEventContent>,
259 ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
260 let session =
261 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
262
263 assert!(!session.expired(), "Session expired");
264
265 let content = session.encrypt_state(event_type, state_key, content).await;
266
267 let mut changes = Changes::default();
268 changes.outbound_group_sessions.push(session);
269 self.store.save_changes(changes).await?;
270
271 Ok(content)
272 }
273
274 pub async fn create_outbound_group_session(
278 &self,
279 room_id: &RoomId,
280 settings: EncryptionSettings,
281 own_sender_data: SenderData,
282 ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
283 let (outbound, inbound) = self
284 .store
285 .static_account()
286 .create_group_session_pair(room_id, settings, own_sender_data)
287 .await
288 .map_err(|_| EventError::UnsupportedAlgorithm)?;
289
290 self.sessions.insert(outbound.clone());
291 Ok((outbound, inbound))
292 }
293
294 pub async fn get_or_create_outbound_session(
295 &self,
296 room_id: &RoomId,
297 settings: EncryptionSettings,
298 own_sender_data: SenderData,
299 ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
300 let outbound_session = self.sessions.get_or_load(room_id).await;
301
302 if let Some(s) = outbound_session {
305 if s.expired() || s.invalidated() {
306 self.create_outbound_group_session(room_id, settings, own_sender_data)
307 .await
308 .map(|(o, i)| (o, i.into()))
309 } else {
310 Ok((s, None))
311 }
312 } else {
313 self.create_outbound_group_session(room_id, settings, own_sender_data)
314 .await
315 .map(|(o, i)| (o, i.into()))
316 }
317 }
318
319 async fn encrypt_session_for(
326 store: Arc<CryptoStoreWrapper>,
327 group_session: OutboundGroupSession,
328 devices: Vec<DeviceData>,
329 ) -> OlmResult<(
330 EncryptForDevicesResult,
331 BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
332 )> {
333 pub struct DeviceResult {
335 device: DeviceData,
336 maybe_encrypted_room_key: MaybeEncryptedRoomKey,
337 }
338
339 let mut result_builder = EncryptForDevicesResultBuilder::default();
340 let mut share_infos = BTreeMap::new();
341
342 let encrypt = |store: Arc<CryptoStoreWrapper>,
345 device: DeviceData,
346 session: OutboundGroupSession| async move {
347 let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
348
349 Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
350 };
351
352 let tasks: Vec<_> = devices
353 .iter()
354 .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
355 .collect();
356
357 let results = join_all(tasks).await;
358
359 for result in results {
360 let result = result.expect("Encryption task panicked")?;
361
362 match result.maybe_encrypted_room_key {
363 MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
364 result_builder.on_successful_encryption(&result.device, *used_session, message);
365
366 let user_id = result.device.user_id().to_owned();
367 let device_id = result.device.device_id().to_owned();
368 share_infos
369 .entry(user_id)
370 .or_insert_with(BTreeMap::new)
371 .insert(device_id, *share_info);
372 }
373 MaybeEncryptedRoomKey::MissingSession => {
374 result_builder.on_missing_session(result.device);
375 }
376 }
377 }
378
379 Ok((result_builder.into_result(), share_infos))
380 }
381
382 #[instrument(skip_all)]
389 pub async fn collect_session_recipients(
390 &self,
391 users: impl Iterator<Item = &UserId>,
392 settings: &EncryptionSettings,
393 outbound: &OutboundGroupSession,
394 ) -> OlmResult<CollectRecipientsResult> {
395 share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
396 }
397
398 async fn encrypt_request(
399 store: Arc<CryptoStoreWrapper>,
400 chunk: Vec<DeviceData>,
401 outbound: OutboundGroupSession,
402 sessions: GroupSessionCache,
403 ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
404 let (result, share_infos) =
405 Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
406
407 if let Some(request) = result.to_device_request {
408 let id = request.txn_id.clone();
409 outbound.add_request(id.clone(), request.into(), share_infos);
410 sessions.mark_as_being_shared(id, outbound.clone());
411 }
412
413 Ok((result.updated_olm_sessions, result.no_olm_devices))
414 }
415
416 pub(crate) fn session_cache(&self) -> GroupSessionCache {
417 self.sessions.clone()
418 }
419
420 async fn maybe_rotate_group_session(
421 &self,
422 should_rotate: bool,
423 room_id: &RoomId,
424 outbound: OutboundGroupSession,
425 encryption_settings: EncryptionSettings,
426 changes: &mut Changes,
427 own_device: Option<Device>,
428 ) -> OlmResult<OutboundGroupSession> {
429 Ok(if should_rotate {
430 let old_session_id = outbound.session_id();
431
432 let (outbound, mut inbound) = self
433 .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
434 .await?;
435
436 let own_sender_data = if let Some(device) = own_device {
440 SenderDataFinder::find_using_device_data(
441 &self.store,
442 device.inner.clone(),
443 &inbound,
444 )
445 .await?
446 } else {
447 error!("Unable to find our own device!");
448 SenderData::unknown()
449 };
450 inbound.sender_data = own_sender_data;
451
452 changes.outbound_group_sessions.push(outbound.clone());
453 changes.inbound_group_sessions.push(inbound);
454
455 debug!(
456 old_session_id = old_session_id,
457 session_id = outbound.session_id(),
458 "A user or device has left the room since we last sent a \
459 message, or the encryption settings have changed. Rotating the \
460 room key.",
461 );
462
463 outbound
464 } else {
465 outbound
466 })
467 }
468
469 async fn encrypt_for_devices(
470 &self,
471 recipient_devices: Vec<DeviceData>,
472 group_session: &OutboundGroupSession,
473 changes: &mut Changes,
474 ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
475 if !recipient_devices.is_empty() {
477 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
478
479 changes.outbound_group_sessions = vec![group_session.clone()];
482
483 let message_index = group_session.message_index().await;
484
485 info!(
486 ?recipients,
487 message_index,
488 room_id = ?group_session.room_id(),
489 session_id = group_session.session_id(),
490 "Trying to encrypt a room key",
491 );
492 }
493
494 let tasks: Vec<_> = recipient_devices
499 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
500 .map(|chunk| {
501 spawn(Self::encrypt_request(
502 self.store.crypto_store(),
503 chunk.to_vec(),
504 group_session.clone(),
505 self.sessions.clone(),
506 ))
507 })
508 .collect();
509
510 let mut withheld_devices = Vec::new();
511
512 for result in join_all(tasks).await {
517 let result = result.expect("Encryption task panicked");
518
519 let (used_sessions, failed_no_olm) = result?;
520
521 changes.sessions.extend(used_sessions);
522 withheld_devices.extend(failed_no_olm);
523 }
524
525 Ok(withheld_devices)
526 }
527
528 fn is_withheld_to(
529 &self,
530 group_session: &OutboundGroupSession,
531 device: &DeviceData,
532 code: &WithheldCode,
533 ) -> bool {
534 if code == &WithheldCode::NoOlm {
552 device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
553 } else {
554 group_session.sharing_view().is_withheld_to(device, code)
555 }
556 }
557
558 fn handle_withheld_devices(
559 &self,
560 group_session: &OutboundGroupSession,
561 withheld_devices: Vec<(DeviceData, WithheldCode)>,
562 ) -> OlmResult<()> {
563 let to_content = |code| {
565 let content = group_session.withheld_code(code);
566 Raw::new(&content).expect("We can always serialize a withheld content info").cast()
567 };
568
569 let chunk_to_request = |chunk| {
572 let mut messages = BTreeMap::new();
573 let mut share_infos = BTreeMap::new();
574
575 for (device, code) in chunk {
576 let device: DeviceData = device;
577 let code: WithheldCode = code;
578
579 let user_id = device.user_id().to_owned();
580 let device_id = device.device_id().to_owned();
581
582 let share_info = ShareInfo::new_withheld(code.to_owned());
583 let content = to_content(code);
584
585 messages
586 .entry(user_id.to_owned())
587 .or_insert_with(BTreeMap::new)
588 .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
589
590 share_infos
591 .entry(user_id)
592 .or_insert_with(BTreeMap::new)
593 .insert(device_id, share_info);
594 }
595
596 let txn_id = TransactionId::new();
597
598 let request = ToDeviceRequest {
599 event_type: ToDeviceEventType::from("m.room_key.withheld"),
600 txn_id,
601 messages,
602 };
603
604 (request, share_infos)
605 };
606
607 let result: Vec<_> = withheld_devices
608 .into_iter()
609 .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
610 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
611 .into_iter()
612 .map(chunk_to_request)
613 .collect();
614
615 for (request, share_info) in result {
616 if !request.messages.is_empty() {
617 let txn_id = request.txn_id.to_owned();
618 group_session.add_request(txn_id.to_owned(), request.into(), share_info);
619
620 self.sessions.mark_as_being_shared(txn_id, group_session.clone());
621 }
622 }
623
624 Ok(())
625 }
626
627 fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
628 for request in requests {
629 let message_list = Self::to_device_request_to_log_list(request);
630 info!(
631 request_id = ?request.txn_id,
632 ?message_list,
633 "Created batch of to-device messages of type {}",
634 request.event_type
635 );
636 }
637 }
638
639 fn to_device_request_to_log_list(
643 request: &Arc<ToDeviceRequest>,
644 ) -> Vec<(String, String, String)> {
645 #[derive(serde::Deserialize)]
646 struct ContentStub<'a> {
647 #[serde(borrow, default, rename = "org.matrix.msgid")]
648 message_id: Option<&'a str>,
649 }
650
651 let mut result: Vec<(String, String, String)> = Vec::new();
652
653 for (user_id, device_map) in &request.messages {
654 for (device, content) in device_map {
655 let message_id: Option<&str> = content
656 .deserialize_as_unchecked::<ContentStub<'_>>()
657 .expect("We should be able to deserialize the content we generated")
658 .message_id;
659
660 result.push((
661 message_id.unwrap_or("<undefined>").to_owned(),
662 user_id.to_string(),
663 device.to_string(),
664 ));
665 }
666 }
667 result
668 }
669
670 #[instrument(skip(self, users, encryption_settings), fields(session_id))]
681 pub async fn share_room_key(
682 &self,
683 room_id: &RoomId,
684 users: impl Iterator<Item = &UserId>,
685 encryption_settings: impl Into<EncryptionSettings>,
686 ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
687 trace!("Checking if a room key needs to be shared");
688
689 let account = self.store.static_account();
690 let device = self.store.get_device(account.user_id(), account.device_id()).await?;
691
692 let encryption_settings = encryption_settings.into();
693 let mut changes = Changes::default();
694
695 let (outbound, inbound) = self
697 .get_or_create_outbound_session(
698 room_id,
699 encryption_settings.clone(),
700 SenderData::unknown(),
701 )
702 .await?;
703 tracing::Span::current().record("session_id", outbound.session_id());
704
705 if let Some(mut inbound) = inbound {
708 let own_sender_data = if let Some(device) = &device {
712 SenderDataFinder::find_using_device_data(
713 &self.store,
714 device.inner.clone(),
715 &inbound,
716 )
717 .await?
718 } else {
719 error!("Unable to find our own device!");
720 SenderData::unknown()
721 };
722 inbound.sender_data = own_sender_data;
723
724 changes.outbound_group_sessions.push(outbound.clone());
725 changes.inbound_group_sessions.push(inbound);
726 }
727
728 let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
732 self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
733
734 let outbound = self
735 .maybe_rotate_group_session(
736 should_rotate,
737 room_id,
738 outbound,
739 encryption_settings,
740 &mut changes,
741 device,
742 )
743 .await?;
744
745 let devices: Vec<_> = devices
748 .into_iter()
749 .flat_map(|(_, d)| {
750 d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
751 ShareState::NotShared => true,
752 ShareState::Shared { message_index: _, olm_wedging_index } => {
753 olm_wedging_index < d.olm_wedging_index
760 }
761 _ => false,
762 })
763 })
764 .collect();
765
766 let unable_to_encrypt_devices =
772 self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
773
774 withheld_devices.extend(unable_to_encrypt_devices);
776
777 self.handle_withheld_devices(&outbound, withheld_devices)?;
780
781 let requests = outbound.pending_requests();
785
786 if requests.is_empty() {
787 if !outbound.shared() {
788 debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
789
790 outbound.mark_as_shared();
791 changes.outbound_group_sessions.push(outbound.clone());
792 }
793 } else {
794 Self::log_room_key_sharing_result(&requests)
795 }
796
797 if !changes.is_empty() {
799 let session_count = changes.sessions.len();
800
801 self.store.save_changes(changes).await?;
802
803 trace!(
804 session_count = session_count,
805 "Stored the changed sessions after encrypting an room key"
806 );
807 }
808
809 Ok(requests)
810 }
811
812 #[instrument(skip(self, bundle_data))]
824 pub async fn share_room_key_bundle_data(
825 &self,
826 user_id: &UserId,
827 collect_strategy: &CollectStrategy,
828 bundle_data: RoomKeyBundleContent,
829 ) -> OlmResult<Vec<ToDeviceRequest>> {
830 let collect_strategy = match collect_strategy {
832 CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
833 warn!(
834 "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
835 for room key history sharing",
836 );
837 &CollectStrategy::IdentityBasedStrategy
838 }
839 CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
840 collect_strategy
841 }
842 };
843
844 let mut changes = Changes::default();
845
846 let CollectRecipientsResult { devices, .. } =
847 share_strategy::collect_recipients_for_share_strategy(
848 &self.store,
849 iter::once(user_id),
850 collect_strategy,
851 None,
852 )
853 .await?;
854
855 let devices = devices.into_values().flatten().collect();
856 let event_type = bundle_data.event_type().to_owned();
857 let (requests, _) = self
858 .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
859 .await?;
860
861 if !changes.is_empty() {
865 let session_count = changes.sessions.len();
866
867 self.store.save_changes(changes).await?;
868
869 trace!(
870 session_count = session_count,
871 "Stored the changed sessions after encrypting an room key"
872 );
873 }
874
875 Ok(requests)
876 }
877
878 pub(crate) async fn encrypt_content_for_devices(
885 &self,
886 recipient_devices: Vec<DeviceData>,
887 event_type: &str,
888 content: impl Serialize + Clone + Send + 'static,
889 changes: &mut Changes,
890 ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
891 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
892 info!(?recipients, "Encrypting content of type {}", event_type);
893
894 let tasks: Vec<_> = recipient_devices
899 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
900 .map(|chunk| {
901 spawn(
902 encrypt_content_for_devices(
903 self.store.crypto_store(),
904 event_type.to_owned(),
905 content.clone(),
906 chunk.to_vec(),
907 )
908 .in_current_span(),
909 )
910 })
911 .collect();
912
913 let mut no_olm_devices = Vec::new();
914 let mut to_device_requests = Vec::new();
915
916 for result in join_all(tasks).await {
921 let result = result.expect("Encryption task panicked")?;
922 if let Some(request) = result.to_device_request {
923 to_device_requests.push(request);
924 }
925 changes.sessions.extend(result.updated_olm_sessions);
926 no_olm_devices.extend(result.no_olm_devices);
927 }
928
929 Ok((to_device_requests, no_olm_devices))
930 }
931}
932
933async fn encrypt_content_for_devices(
942 store: Arc<CryptoStoreWrapper>,
943 event_type: String,
944 content: impl Serialize + Clone + Send + 'static,
945 devices: Vec<DeviceData>,
946) -> OlmResult<EncryptForDevicesResult> {
947 let mut result_builder = EncryptForDevicesResultBuilder::default();
948
949 async fn encrypt(
950 store: Arc<CryptoStoreWrapper>,
951 device: DeviceData,
952 event_type: String,
953 bundle_data: impl Serialize,
954 ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
955 device.encrypt(store.as_ref(), &event_type, bundle_data).await
956 }
957
958 let tasks = devices.iter().map(|device| {
959 spawn(
960 encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
961 .in_current_span(),
962 )
963 });
964
965 let results = join_all(tasks).await;
966
967 for (device, result) in zip(devices, results) {
968 let encryption_result = result.expect("Encryption task panicked");
969
970 match encryption_result {
971 Ok((used_session, message)) => {
972 result_builder.on_successful_encryption(&device, used_session, message.cast());
973 }
974 Err(OlmError::MissingSession) => {
975 result_builder.on_missing_session(device);
977 }
978 Err(e) => return Err(e),
979 }
980 }
981
982 Ok(result_builder.into_result())
983}
984
/// Outcome of encrypting a to-device payload for a list of devices.
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// A single request carrying every successfully encrypted message, or
    /// `None` if no device could be encrypted for.
    to_device_request: Option<ToDeviceRequest>,

    /// Devices that have no Olm session, each paired with
    /// [`WithheldCode::NoOlm`].
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// Olm sessions that were used for encryption; callers queue them for
    /// saving.
    updated_olm_sessions: Vec<Session>,
}
1000
/// Accumulates per-device encryption outcomes into an
/// [`EncryptForDevicesResult`].
#[derive(Debug, Default)]
struct EncryptForDevicesResultBuilder {
    /// Encrypted messages, keyed by user ID and then by device.
    messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,

    /// Devices for which no Olm session was available.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// Olm sessions used by successful encryptions.
    updated_olm_sessions: Vec<Session>,
}
1014
1015impl EncryptForDevicesResultBuilder {
1016 pub fn on_successful_encryption(
1020 &mut self,
1021 device: &DeviceData,
1022 used_session: Session,
1023 message: Raw<AnyToDeviceEventContent>,
1024 ) {
1025 self.updated_olm_sessions.push(used_session);
1026
1027 self.messages
1028 .entry(device.user_id().to_owned())
1029 .or_default()
1030 .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
1031 }
1032
1033 pub fn on_missing_session(&mut self, device: DeviceData) {
1035 self.no_olm_devices.push((device, WithheldCode::NoOlm));
1036 }
1037
1038 pub fn into_result(self) -> EncryptForDevicesResult {
1041 let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
1042 self;
1043
1044 let mut encrypt_for_devices_result = EncryptForDevicesResult {
1045 to_device_request: None,
1046 updated_olm_sessions,
1047 no_olm_devices,
1048 };
1049
1050 if !messages.is_empty() {
1051 let request = ToDeviceRequest {
1052 event_type: ToDeviceEventType::RoomEncrypted,
1053 txn_id: TransactionId::new(),
1054 messages,
1055 };
1056 trace!(
1057 recipient_count = request.message_count(),
1058 transaction_id = ?request.txn_id,
1059 "Created a to-device request carrying room keys",
1060 );
1061 encrypt_for_devices_result.to_device_request = Some(request);
1062 }
1063
1064 encrypt_for_devices_result
1065 }
1066}
1067
1068fn recipient_list_to_users_and_devices(
1069 recipient_devices: &[DeviceData],
1070) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1071 #[allow(unknown_lints, clippy::unwrap_or_default)] recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1073 acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1074 acc
1075 })
1076}
1077
1078#[cfg(test)]
1079mod tests {
1080 use std::{
1081 collections::{BTreeMap, BTreeSet},
1082 iter,
1083 ops::Deref,
1084 sync::Arc,
1085 };
1086
1087 use assert_matches2::assert_let;
1088 use matrix_sdk_common::deserialized_responses::{ProcessedToDeviceEvent, WithheldCode};
1089 use matrix_sdk_test::{async_test, ruma_response_from_json};
1090 use ruma::{
1091 DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1092 api::client::{
1093 keys::{claim_keys, get_keys, upload_keys},
1094 to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1095 },
1096 device_id,
1097 events::room::{
1098 EncryptedFileInit, JsonWebKey, JsonWebKeyInit, history_visibility::HistoryVisibility,
1099 },
1100 owned_room_id, room_id,
1101 serde::Base64,
1102 to_device::DeviceIdOrAllDevices,
1103 user_id,
1104 };
1105 use serde_json::{Value, json};
1106
1107 use crate::{
1108 DecryptionSettings, EncryptionSettings, LocalTrust, OlmMachine, TrustRequirement,
1109 identities::DeviceData,
1110 machine::{
1111 EncryptionSyncChanges, test_helpers::get_machine_pair_with_setup_sessions_test_helper,
1112 },
1113 olm::{Account, SenderData},
1114 session_manager::{CollectStrategy, group_sessions::CollectRecipientsResult},
1115 types::{
1116 DeviceKeys, EventEncryptionAlgorithm,
1117 events::{
1118 room::encrypted::EncryptedToDeviceEvent,
1119 room_key_bundle::RoomKeyBundleContent,
1120 room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1121 },
1122 requests::ToDeviceRequest,
1123 },
1124 };
1125
1126 fn alice_id() -> &'static UserId {
1127 user_id!("@alice:example.org")
1128 }
1129
1130 fn alice_device_id() -> &'static DeviceId {
1131 device_id!("JLAFKJWSCS")
1132 }
1133
1134 fn keys_query_response() -> get_keys::v3::Response {
1136 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1137 let data: Value = serde_json::from_slice(data).unwrap();
1138 ruma_response_from_json(&data)
1139 }
1140
1141 fn bob_keys_query_response() -> get_keys::v3::Response {
1142 let data = json!({
1143 "device_keys": {
1144 "@bob:localhost": {
1145 "BOBDEVICE": {
1146 "user_id": "@bob:localhost",
1147 "device_id": "BOBDEVICE",
1148 "algorithms": [
1149 "m.olm.v1.curve25519-aes-sha2",
1150 "m.megolm.v1.aes-sha2",
1151 "m.megolm.v2.aes-sha2"
1152 ],
1153 "keys": {
1154 "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1155 "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1156 },
1157 "signatures": {
1158 "@bob:localhost": {
1159 "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1160 }
1161 }
1162 }
1163 }
1164 }
1165 });
1166 ruma_response_from_json(&data)
1167 }
1168
1169 fn bob_one_time_key() -> claim_keys::v3::Response {
1172 let data = json!({
1173 "failures": {},
1174 "one_time_keys":{
1175 "@bob:localhost":{
1176 "BOBDEVICE":{
1177 "signed_curve25519:AAAAAAAAAAA": {
1178 "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1179 "signatures":{
1180 "@bob:localhost":{
1181 "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1182 }
1183 }
1184 }
1185 }
1186 }
1187 }
1188 });
1189 ruma_response_from_json(&data)
1190 }
1191
1192 fn keys_claim_response() -> claim_keys::v3::Response {
1195 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1196 let data: Value = serde_json::from_slice(data).unwrap();
1197 ruma_response_from_json(&data)
1198 }
1199
1200 async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1201 let keys_query = keys_query_response();
1202 let txn_id = TransactionId::new();
1203
1204 let machine = OlmMachine::new(user_id, device_id).await;
1205
1206 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1208 let (txn_id, _keys_claim_request) = machine
1209 .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1210 .await
1211 .unwrap()
1212 .unwrap();
1213 let keys_claim = keys_claim_response();
1214 machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1215
1216 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1218 let (txn_id, _keys_claim_request) = machine
1219 .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1220 .await
1221 .unwrap()
1222 .unwrap();
1223 machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1224
1225 machine
1226 }
1227
1228 async fn machine() -> OlmMachine {
1229 machine_with_user_test_helper(alice_id(), alice_device_id()).await
1230 }
1231
1232 async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1233 let machine = machine().await;
1234 let room_id = room_id!("!test:localhost");
1235 let keys_claim = keys_claim_response();
1236
1237 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1238 let requests =
1239 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1240
1241 let outbound =
1242 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1243
1244 assert!(!outbound.pending_requests().is_empty());
1245 assert!(!outbound.shared());
1246
1247 let response = ToDeviceResponse::new();
1248 for request in requests {
1249 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1250 }
1251
1252 assert!(outbound.shared());
1253 assert!(outbound.pending_requests().is_empty());
1254
1255 machine
1256 }
1257
1258 #[async_test]
1259 async fn test_sharing() {
1260 let machine = machine().await;
1261 let room_id = room_id!("!test:localhost");
1262 let keys_claim = keys_claim_response();
1263
1264 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1265
1266 let requests =
1267 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1268
1269 let event_count: usize = requests
1270 .iter()
1271 .filter(|r| r.event_type == "m.room.encrypted".into())
1272 .map(|r| r.message_count())
1273 .sum();
1274
1275 assert_eq!(event_count, 148);
1279
1280 let withheld_count: usize = requests
1281 .iter()
1282 .filter(|r| r.event_type == "m.room_key.withheld".into())
1283 .map(|r| r.message_count())
1284 .sum();
1285 assert_eq!(withheld_count, 2);
1286 }
1287
1288 fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1289 requests
1290 .iter()
1291 .filter(|r| r.event_type == "m.room_key.withheld".into())
1292 .map(|r| {
1293 let mut count = 0;
1294 for message in r.messages.values() {
1296 message.iter().for_each(|(_, content)| {
1297 let withheld: RoomKeyWithheldContent =
1298 content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1299
1300 if let MegolmV1AesSha2(content) = withheld
1301 && content.withheld_code() == code
1302 {
1303 count += 1;
1304 }
1305 })
1306 }
1307 count
1308 })
1309 .sum()
1310 }
1311
1312 #[async_test]
1313 async fn test_no_olm_sent_once() {
1314 let machine = machine().await;
1315 let keys_claim = keys_claim_response();
1316
1317 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1318
1319 let first_room_id = room_id!("!test:localhost");
1320
1321 let requests = machine
1322 .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1323 .await
1324 .unwrap();
1325
1326 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1328 assert_eq!(withheld_count, 2);
1329
1330 let new_requests = machine
1333 .share_room_key(first_room_id, users, EncryptionSettings::default())
1334 .await
1335 .unwrap();
1336 let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1337 assert_eq!(withheld_count, 2);
1339
1340 let response = ToDeviceResponse::new();
1341 for request in requests {
1342 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1343 }
1344
1345 let second_room_id = room_id!("!other:localhost");
1348 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1349 let requests = machine
1350 .share_room_key(second_room_id, users, EncryptionSettings::default())
1351 .await
1352 .unwrap();
1353
1354 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1355 assert_eq!(withheld_count, 0);
1356
1357 }
1360
1361 #[async_test]
1362 async fn test_ratcheted_sharing() {
1363 let machine = machine_with_shared_room_key_test_helper().await;
1364
1365 let room_id = room_id!("!test:localhost");
1366 let late_joiner = user_id!("@bob:localhost");
1367 let keys_claim = keys_claim_response();
1368
1369 let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1370 users.insert(late_joiner);
1371
1372 let requests = machine
1373 .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1374 .await
1375 .unwrap();
1376
1377 let event_count: usize = requests
1378 .iter()
1379 .filter(|r| r.event_type == "m.room.encrypted".into())
1380 .map(|r| r.message_count())
1381 .sum();
1382 let outbound =
1383 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1384
1385 assert_eq!(event_count, 1);
1386 assert!(!outbound.pending_requests().is_empty());
1387 }
1388
1389 #[async_test]
1390 async fn test_changing_encryption_settings() {
1391 let machine = machine_with_shared_room_key_test_helper().await;
1392 let room_id = room_id!("!test:localhost");
1393 let keys_claim = keys_claim_response();
1394
1395 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1396 let outbound =
1397 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1398
1399 let CollectRecipientsResult { should_rotate, .. } = machine
1400 .inner
1401 .group_session_manager
1402 .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1403 .await
1404 .unwrap();
1405
1406 assert!(!should_rotate);
1407
1408 let settings = EncryptionSettings {
1409 history_visibility: HistoryVisibility::Invited,
1410 ..Default::default()
1411 };
1412
1413 let CollectRecipientsResult { should_rotate, .. } = machine
1414 .inner
1415 .group_session_manager
1416 .collect_session_recipients(users.clone(), &settings, &outbound)
1417 .await
1418 .unwrap();
1419
1420 assert!(should_rotate);
1421
1422 let settings = EncryptionSettings {
1423 algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1424 ..Default::default()
1425 };
1426
1427 let CollectRecipientsResult { should_rotate, .. } = machine
1428 .inner
1429 .group_session_manager
1430 .collect_session_recipients(users, &settings, &outbound)
1431 .await
1432 .unwrap();
1433
1434 assert!(should_rotate);
1435 }
1436
1437 #[async_test]
1438 async fn test_key_recipient_collecting() {
1439 let user_id = user_id!("@example:localhost");
1442 let device_id = device_id!("TESTDEVICE");
1443 let room_id = room_id!("!test:localhost");
1444
1445 let machine = machine_with_user_test_helper(user_id, device_id).await;
1446
1447 let (outbound, _) = machine
1448 .inner
1449 .group_session_manager
1450 .get_or_create_outbound_session(
1451 room_id,
1452 EncryptionSettings::default(),
1453 SenderData::unknown(),
1454 )
1455 .await
1456 .expect("We should be able to create a new session");
1457 let history_visibility = HistoryVisibility::Joined;
1458 let settings = EncryptionSettings { history_visibility, ..Default::default() };
1459
1460 let users = [user_id].into_iter();
1461
1462 let CollectRecipientsResult { devices: recipients, .. } = machine
1463 .inner
1464 .group_session_manager
1465 .collect_session_recipients(users, &settings, &outbound)
1466 .await
1467 .expect("We should be able to collect the session recipients");
1468
1469 assert!(!recipients[user_id].is_empty());
1470
1471 assert!(
1473 !recipients[user_id]
1474 .iter()
1475 .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1476 );
1477
1478 let settings = EncryptionSettings {
1479 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1480 ..Default::default()
1481 };
1482 let users = [user_id].into_iter();
1483
1484 let CollectRecipientsResult { devices: recipients, .. } = machine
1485 .inner
1486 .group_session_manager
1487 .collect_session_recipients(users, &settings, &outbound)
1488 .await
1489 .expect("We should be able to collect the session recipients");
1490
1491 assert!(recipients[user_id].is_empty());
1492
1493 let device_id = "AFGUOBTZWM".into();
1494 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1495 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1496 let users = [user_id].into_iter();
1497
1498 let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1499 machine
1500 .inner
1501 .group_session_manager
1502 .collect_session_recipients(users, &settings, &outbound)
1503 .await
1504 .expect("We should be able to collect the session recipients");
1505
1506 assert!(
1507 recipients[user_id]
1508 .iter()
1509 .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1510 );
1511
1512 let devices = machine.get_user_devices(user_id, None).await.unwrap();
1513 devices
1514 .devices()
1515 .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1517 .for_each(|d| {
1518 if d.is_blacklisted() {
1519 assert!(withheld.iter().any(|(dev, w)| {
1520 dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1521 }));
1522 } else if !d.is_verified() {
1523 assert!(withheld.iter().any(|(dev, w)| {
1525 dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1526 }));
1527 }
1528 });
1529
1530 assert_eq!(149, withheld.len());
1531 }
1532
1533 #[async_test]
1534 async fn test_sharing_withheld_only_trusted() {
1535 let machine = machine().await;
1536 let room_id = room_id!("!test:localhost");
1537 let keys_claim = keys_claim_response();
1538
1539 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1540 let settings = EncryptionSettings {
1541 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1542 ..Default::default()
1543 };
1544
1545 let user_id = user_id!("@example:localhost");
1547 let device_id = "MWFXPINOAO".into();
1548 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1549 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1550 machine
1551 .get_device(user_id, "MWVTUXDNNM".into(), None)
1552 .await
1553 .unwrap()
1554 .unwrap()
1555 .set_local_trust(LocalTrust::BlackListed)
1556 .await
1557 .unwrap();
1558
1559 let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1560
1561 let room_key_count =
1563 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1564
1565 assert_eq!(1, room_key_count);
1566
1567 let withheld_count =
1568 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1569 assert_eq!(1, withheld_count);
1571
1572 let event_count: usize = requests
1573 .iter()
1574 .filter(|r| r.event_type == "m.room_key.withheld".into())
1575 .map(|r| r.message_count())
1576 .sum();
1577
1578 assert_eq!(event_count, 149);
1580
1581 let has_blacklist =
1583 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1584 let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1585 let content = &r.messages[user_id][&device_key];
1586 let withheld: RoomKeyWithheldContent =
1587 content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1588 if let MegolmV1AesSha2(content) = withheld {
1589 content.withheld_code() == WithheldCode::Blacklisted
1590 } else {
1591 false
1592 }
1593 });
1594
1595 assert!(has_blacklist);
1596 }
1597
1598 #[async_test]
1599 async fn test_no_olm_withheld_only_sent_once() {
1600 let keys_query = keys_query_response();
1601 let txn_id = TransactionId::new();
1602
1603 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1604
1605 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1606 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1607
1608 let first_room = room_id!("!test:localhost");
1609 let second_room = room_id!("!test2:localhost");
1610 let bob_id = user_id!("@bob:localhost");
1611
1612 let settings = EncryptionSettings::default();
1613 let users = [bob_id];
1614
1615 let requests = machine
1616 .share_room_key(first_room, users.into_iter(), settings.to_owned())
1617 .await
1618 .unwrap();
1619
1620 let withheld_count =
1622 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1623
1624 assert_eq!(withheld_count, 1);
1625 assert_eq!(requests.len(), 1);
1626
1627 let second_requests =
1630 machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1631
1632 let withheld_count =
1633 second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1634
1635 assert_eq!(withheld_count, 0);
1636 assert_eq!(second_requests.len(), 0);
1637
1638 let response = ToDeviceResponse::new();
1639
1640 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1641
1642 assert!(!device.was_withheld_code_sent());
1645
1646 for request in requests {
1647 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1648 }
1649
1650 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1651
1652 assert!(device.was_withheld_code_sent());
1653 }
1654
1655 #[async_test]
1656 async fn test_resend_session_after_unwedging() {
1657 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1658 assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
1659 let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
1660 OneTimeKeyAlgorithm::SignedCurve25519,
1661 UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
1662 )]));
1663 machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
1664
1665 let room_id = room_id!("!test:localhost");
1666
1667 let bob_id = user_id!("@bob:localhost");
1668 let bob_account = Account::new(bob_id);
1669 let keys_query_data = json!({
1670 "device_keys": {
1671 "@bob:localhost": {
1672 bob_account.device_id.clone(): bob_account.device_keys()
1673 }
1674 }
1675 });
1676 let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
1677 let txn_id = TransactionId::new();
1678 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1679
1680 let alice_device_keys =
1681 device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
1682 let mut alice_otks = device_keys_request.one_time_keys.iter();
1683 let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
1684
1685 {
1686 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1688 let mut session = bob_account
1689 .create_outbound_session(
1690 &alice_device,
1691 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1692 bob_account.device_keys(),
1693 )
1694 .unwrap();
1695 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1696
1697 let to_device =
1698 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1699
1700 let sync_changes = EncryptionSyncChanges {
1702 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1703 changed_devices: &Default::default(),
1704 one_time_keys_counts: &Default::default(),
1705 unused_fallback_keys: None,
1706 next_batch_token: None,
1707 };
1708
1709 let decryption_settings =
1710 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1711
1712 let (decrypted, _) =
1713 machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1714
1715 assert_eq!(1, decrypted.len());
1716 }
1717
1718 {
1720 let requests = machine
1721 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1722 .await
1723 .unwrap();
1724
1725 let event_count: usize = requests
1727 .iter()
1728 .filter(|r| r.event_type == "m.room.encrypted".into())
1729 .map(|r| r.message_count())
1730 .sum();
1731 assert_eq!(event_count, 1);
1732
1733 let response = ToDeviceResponse::new();
1734 for request in requests {
1735 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1736 }
1737 }
1738
1739 {
1742 let requests = machine
1743 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1744 .await
1745 .unwrap();
1746
1747 let event_count: usize = requests
1748 .iter()
1749 .filter(|r| r.event_type == "m.room.encrypted".into())
1750 .map(|r| r.message_count())
1751 .sum();
1752 assert_eq!(event_count, 0);
1753 }
1754
1755 {
1757 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1758 let mut session = bob_account
1759 .create_outbound_session(
1760 &alice_device,
1761 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1762 bob_account.device_keys(),
1763 )
1764 .unwrap();
1765 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1766
1767 let to_device =
1768 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1769
1770 let sync_changes = EncryptionSyncChanges {
1772 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1773 changed_devices: &Default::default(),
1774 one_time_keys_counts: &Default::default(),
1775 unused_fallback_keys: None,
1776 next_batch_token: None,
1777 };
1778
1779 let decryption_settings =
1780 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1781
1782 let (decrypted, _) =
1783 machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1784
1785 assert_eq!(1, decrypted.len());
1786 }
1787
1788 {
1790 let requests = machine
1791 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1792 .await
1793 .unwrap();
1794
1795 let event_count: usize = requests
1796 .iter()
1797 .filter(|r| r.event_type == "m.room.encrypted".into())
1798 .map(|r| r.message_count())
1799 .sum();
1800 assert_eq!(event_count, 1);
1801
1802 let response = ToDeviceResponse::new();
1803 for request in requests {
1804 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1805 }
1806 }
1807
1808 {
1811 let requests = machine
1812 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1813 .await
1814 .unwrap();
1815
1816 let event_count: usize = requests
1817 .iter()
1818 .filter(|r| r.event_type == "m.room.encrypted".into())
1819 .map(|r| r.message_count())
1820 .sum();
1821 assert_eq!(event_count, 0);
1822 }
1823 }
1824
1825 #[async_test]
1826 async fn test_room_key_bundle_sharing() {
1827 let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1828 user_id!("@alice:localhost"),
1829 user_id!("@bob:localhost"),
1830 false,
1831 )
1832 .await;
1833
1834 let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1836 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1837
1838 let content = RoomKeyBundleContent {
1839 room_id: owned_room_id!("!room:id"),
1840 file: (EncryptedFileInit {
1841 url: OwnedMxcUri::from("test"),
1842 key: JsonWebKey::from(JsonWebKeyInit {
1843 kty: "oct".to_owned(),
1844 key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
1845 alg: "A256CTR".to_owned(),
1846 #[allow(clippy::unnecessary_to_owned)]
1847 k: Base64::new(vec![0u8; 0]),
1848 ext: true,
1849 }),
1850 iv: Base64::new(vec![0u8; 0]),
1851 hashes: Default::default(),
1852 v: "".to_owned(),
1853 })
1854 .into(),
1855 };
1856
1857 let requests = alice
1858 .share_room_key_bundle_data(
1859 bob.user_id(),
1860 &CollectStrategy::OnlyTrustedDevices,
1861 content,
1862 )
1863 .await
1864 .unwrap();
1865
1866 let requests: Vec<_> =
1868 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1869 let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1870 assert_eq!(message_count, 1);
1871
1872 let bob_message = requests[0]
1874 .messages
1875 .get(bob.user_id())
1876 .unwrap()
1877 .get(&(bob.device_id().to_owned().into()))
1878 .unwrap();
1879 let to_device = EncryptedToDeviceEvent::new(
1880 alice.user_id().to_owned(),
1881 bob_message.deserialize_as_unchecked().unwrap(),
1882 );
1883
1884 let sync_changes = EncryptionSyncChanges {
1885 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1886 changed_devices: &Default::default(),
1887 one_time_keys_counts: &Default::default(),
1888 unused_fallback_keys: None,
1889 next_batch_token: None,
1890 };
1891
1892 let decryption_settings =
1893 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1894
1895 let (decrypted, _) =
1896 bob.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1897 assert_eq!(1, decrypted.len());
1898 use crate::types::events::EventType;
1899 assert_let!(
1900 ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1901 );
1902 assert_eq!(
1903 raw.get_field::<String>("type").unwrap().unwrap(),
1904 RoomKeyBundleContent::EVENT_TYPE,
1905 );
1906 }
1907}