1mod share_strategy;
16
17use std::{
18 collections::{BTreeMap, BTreeSet},
19 fmt::Debug,
20 sync::Arc,
21};
22
23use futures_util::future::join_all;
24use itertools::Itertools;
25use matrix_sdk_common::{
26 deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
27};
28use ruma::{
29 events::{AnyMessageLikeEventContent, ToDeviceEventType},
30 serde::Raw,
31 to_device::DeviceIdOrAllDevices,
32 OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
33};
34pub(crate) use share_strategy::CollectRecipientsResult;
35pub use share_strategy::CollectStrategy;
36use tracing::{debug, error, info, instrument, trace};
37
38use crate::{
39 error::{EventError, MegolmResult, OlmResult},
40 identities::device::MaybeEncryptedRoomKey,
41 olm::{
42 InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
43 ShareInfo, ShareState,
44 },
45 store::{Changes, CryptoStoreWrapper, Result as StoreResult, Store},
46 types::{events::room::encrypted::RoomEncryptedEventContent, requests::ToDeviceRequest},
47 Device, DeviceData, EncryptionSettings, OlmError,
48};
49
50#[derive(Clone, Debug)]
51pub(crate) struct GroupSessionCache {
52 store: Store,
53 sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
54 sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
57}
58
59impl GroupSessionCache {
60 pub(crate) fn new(store: Store) -> Self {
61 Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
62 }
63
64 pub(crate) fn insert(&self, session: OutboundGroupSession) {
65 self.sessions.write().insert(session.room_id().to_owned(), session);
66 }
67
68 pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
75 if let Some(s) = self.sessions.read().get(room_id) {
78 return Some(s.clone());
79 }
80
81 match self.store.get_outbound_group_session(room_id).await {
82 Ok(Some(s)) => {
83 {
84 let mut sessions_being_shared = self.sessions_being_shared.write();
85 for request_id in s.pending_request_ids() {
86 sessions_being_shared.insert(request_id, s.clone());
87 }
88 }
89
90 self.sessions.write().insert(room_id.to_owned(), s.clone());
91
92 Some(s)
93 }
94 Ok(None) => None,
95 Err(e) => {
96 error!("Couldn't restore an outbound group session: {e:?}");
97 None
98 }
99 }
100 }
101
102 fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
109 self.sessions.read().get(room_id).cloned()
110 }
111
112 fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
114 self.sessions.read().values().any(|s| s.is_withheld_to(device, code))
115 }
116
117 fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
118 self.sessions_being_shared.write().remove(id)
119 }
120
121 fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
122 self.sessions_being_shared.write().insert(id, session);
123 }
124}
125
/// Manages this account's outbound group sessions.
///
/// It creates and rotates sessions, encrypts room events with them, and
/// produces the to-device requests that share (or withhold) their room keys.
#[derive(Debug, Clone)]
pub(crate) struct GroupSessionManager {
    /// Store used to look up devices and persist sessions and other changes.
    store: Store,
    /// Cache of outbound sessions and of the requests currently sharing them.
    sessions: GroupSessionCache,
}
135
136impl GroupSessionManager {
137 const MAX_TO_DEVICE_MESSAGES: usize = 250;
138
139 pub fn new(store: Store) -> Self {
140 Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
141 }
142
143 pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
144 if let Some(s) = self.sessions.get(room_id) {
145 s.invalidate_session();
146
147 let mut changes = Changes::default();
148 changes.outbound_group_sessions.push(s.clone());
149 self.store.save_changes(changes).await?;
150
151 Ok(true)
152 } else {
153 Ok(false)
154 }
155 }
156
157 pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
158 let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
159 return Ok(());
160 };
161
162 let no_olm = session.mark_request_as_sent(request_id);
163
164 let mut changes = Changes::default();
165
166 for (user_id, devices) in &no_olm {
167 for device_id in devices {
168 let device = self.store.get_device(user_id, device_id).await;
169
170 if let Ok(Some(device)) = device {
171 device.mark_withheld_code_as_sent();
172 changes.devices.changed.push(device.inner.clone());
173 } else {
174 error!(
175 ?request_id,
176 "Marking to-device no olm as sent but device not found, might \
177 have been deleted?"
178 );
179 }
180 }
181 }
182
183 changes.outbound_group_sessions.push(session.clone());
184 self.store.save_changes(changes).await
185 }
186
    /// Test helper: return the cached outbound session of `room_id`, if any.
    /// The store is not consulted.
    #[cfg(test)]
    pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
        self.sessions.get(room_id)
    }
191
192 pub async fn encrypt(
193 &self,
194 room_id: &RoomId,
195 event_type: &str,
196 content: &Raw<AnyMessageLikeEventContent>,
197 ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
198 let session =
199 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
200
201 assert!(!session.expired(), "Session expired");
202
203 let content = session.encrypt(event_type, content).await;
204
205 let mut changes = Changes::default();
206 changes.outbound_group_sessions.push(session);
207 self.store.save_changes(changes).await?;
208
209 Ok(content)
210 }
211
212 pub async fn create_outbound_group_session(
216 &self,
217 room_id: &RoomId,
218 settings: EncryptionSettings,
219 own_sender_data: SenderData,
220 ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
221 let (outbound, inbound) = self
222 .store
223 .static_account()
224 .create_group_session_pair(room_id, settings, own_sender_data)
225 .await
226 .map_err(|_| EventError::UnsupportedAlgorithm)?;
227
228 self.sessions.insert(outbound.clone());
229 Ok((outbound, inbound))
230 }
231
232 pub async fn get_or_create_outbound_session(
233 &self,
234 room_id: &RoomId,
235 settings: EncryptionSettings,
236 own_sender_data: SenderData,
237 ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
238 let outbound_session = self.sessions.get_or_load(room_id).await;
239
240 if let Some(s) = outbound_session {
243 if s.expired() || s.invalidated() {
244 self.create_outbound_group_session(room_id, settings, own_sender_data)
245 .await
246 .map(|(o, i)| (o, i.into()))
247 } else {
248 Ok((s, None))
249 }
250 } else {
251 self.create_outbound_group_session(room_id, settings, own_sender_data)
252 .await
253 .map(|(o, i)| (o, i.into()))
254 }
255 }
256
257 async fn encrypt_session_for(
260 store: Arc<CryptoStoreWrapper>,
261 group_session: OutboundGroupSession,
262 devices: Vec<DeviceData>,
263 ) -> OlmResult<(
264 OwnedTransactionId,
265 ToDeviceRequest,
266 BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
267 Vec<Session>,
268 Vec<(DeviceData, WithheldCode)>,
269 )> {
270 pub struct DeviceResult {
272 device: DeviceData,
273 maybe_encrypted_room_key: MaybeEncryptedRoomKey,
274 }
275
276 let mut messages = BTreeMap::new();
277 let mut changed_sessions = Vec::new();
278 let mut share_infos = BTreeMap::new();
279 let mut withheld_devices = Vec::new();
280
281 let encrypt = |store: Arc<CryptoStoreWrapper>,
284 device: DeviceData,
285 session: OutboundGroupSession| async move {
286 let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
287
288 Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
289 };
290
291 let tasks: Vec<_> = devices
292 .iter()
293 .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
294 .collect();
295
296 let results = join_all(tasks).await;
297
298 for result in results {
299 let result = result.expect("Encryption task panicked")?;
300
301 match result.maybe_encrypted_room_key {
302 MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
303 changed_sessions.push(used_session);
304
305 let user_id = result.device.user_id().to_owned();
306 let device_id = result.device.device_id().to_owned();
307
308 messages
309 .entry(user_id.to_owned())
310 .or_insert_with(BTreeMap::new)
311 .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), message);
312
313 share_infos
314 .entry(user_id)
315 .or_insert_with(BTreeMap::new)
316 .insert(device_id, share_info);
317 }
318 MaybeEncryptedRoomKey::Withheld { code } => {
319 withheld_devices.push((result.device, code));
320 }
321 }
322 }
323
324 let txn_id = TransactionId::new();
325 let request = ToDeviceRequest {
326 event_type: ToDeviceEventType::RoomEncrypted,
327 txn_id: txn_id.to_owned(),
328 messages,
329 };
330
331 Ok((txn_id, request, share_infos, changed_sessions, withheld_devices))
332 }
333
    /// Work out the sharing plan for the room key of `outbound` among the
    /// devices of `users`, according to `settings`.
    ///
    /// The result lists the devices that should receive the key, the devices
    /// that should get a withheld notice instead, and whether the session
    /// should be rotated. This is a thin instrumented wrapper around
    /// `share_strategy::collect_session_recipients`.
    #[instrument(skip_all)]
    pub async fn collect_session_recipients(
        &self,
        users: impl Iterator<Item = &UserId>,
        settings: &EncryptionSettings,
        outbound: &OutboundGroupSession,
    ) -> OlmResult<CollectRecipientsResult> {
        share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
    }
349
350 async fn encrypt_request(
351 store: Arc<CryptoStoreWrapper>,
352 chunk: Vec<DeviceData>,
353 outbound: OutboundGroupSession,
354 sessions: GroupSessionCache,
355 ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
356 let (id, request, share_infos, used_sessions, no_olm) =
357 Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
358
359 if !request.messages.is_empty() {
360 trace!(
361 recipient_count = request.message_count(),
362 transaction_id = ?id,
363 "Created a to-device request carrying a room_key"
364 );
365
366 outbound.add_request(id.clone(), request.into(), share_infos);
367 sessions.mark_as_being_shared(id, outbound.clone());
368 }
369
370 Ok((used_sessions, no_olm))
371 }
372
    /// Return a clone of the session cache. The session maps live behind
    /// `Arc`s, so the clone shares them with this manager.
    pub(crate) fn session_cache(&self) -> GroupSessionCache {
        self.sessions.clone()
    }
376
377 async fn maybe_rotate_group_session(
378 &self,
379 should_rotate: bool,
380 room_id: &RoomId,
381 outbound: OutboundGroupSession,
382 encryption_settings: EncryptionSettings,
383 changes: &mut Changes,
384 own_device: Option<Device>,
385 ) -> OlmResult<OutboundGroupSession> {
386 Ok(if should_rotate {
387 let old_session_id = outbound.session_id();
388
389 let (outbound, mut inbound) = self
390 .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
391 .await?;
392
393 let own_sender_data = if let Some(device) = own_device {
397 SenderDataFinder::find_using_device_data(
398 &self.store,
399 device.inner.clone(),
400 &inbound,
401 )
402 .await?
403 } else {
404 error!("Unable to find our own device!");
405 SenderData::unknown()
406 };
407 inbound.sender_data = own_sender_data;
408
409 changes.outbound_group_sessions.push(outbound.clone());
410 changes.inbound_group_sessions.push(inbound);
411
412 debug!(
413 old_session_id = old_session_id,
414 session_id = outbound.session_id(),
415 "A user or device has left the room since we last sent a \
416 message, or the encryption settings have changed. Rotating the \
417 room key.",
418 );
419
420 outbound
421 } else {
422 outbound
423 })
424 }
425
426 async fn encrypt_for_devices(
427 &self,
428 recipient_devices: Vec<DeviceData>,
429 group_session: &OutboundGroupSession,
430 changes: &mut Changes,
431 ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
432 if !recipient_devices.is_empty() {
434 #[allow(unknown_lints, clippy::unwrap_or_default)] let recipients = recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
436 acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
437 acc
438 });
439
440 changes.outbound_group_sessions = vec![group_session.clone()];
443
444 let message_index = group_session.message_index().await;
445
446 info!(
447 ?recipients,
448 message_index,
449 room_id = ?group_session.room_id(),
450 session_id = group_session.session_id(),
451 "Trying to encrypt a room key",
452 );
453 }
454
455 let tasks: Vec<_> = recipient_devices
460 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
461 .map(|chunk| {
462 spawn(Self::encrypt_request(
463 self.store.crypto_store(),
464 chunk.to_vec(),
465 group_session.clone(),
466 self.sessions.clone(),
467 ))
468 })
469 .collect();
470
471 let mut withheld_devices = Vec::new();
472
473 for result in join_all(tasks).await {
478 let result = result.expect("Encryption task panicked");
479
480 let (used_sessions, failed_no_olm) = result?;
481
482 changes.sessions.extend(used_sessions);
483 withheld_devices.extend(failed_no_olm);
484 }
485
486 Ok(withheld_devices)
487 }
488
489 fn is_withheld_to(
490 &self,
491 group_session: &OutboundGroupSession,
492 device: &DeviceData,
493 code: &WithheldCode,
494 ) -> bool {
495 if code == &WithheldCode::NoOlm {
513 device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
514 } else {
515 group_session.is_withheld_to(device, code)
516 }
517 }
518
    /// Queue `m.room_key.withheld` to-device requests on `group_session` for
    /// `withheld_devices`.
    ///
    /// Devices for which `is_withheld_to` says the code was already handled
    /// are skipped. The remaining devices are split into requests of at most
    /// `MAX_TO_DEVICE_MESSAGES` messages. Each request is added to the session,
    /// with a withheld `ShareInfo` per device, and registered in the cache as
    /// being shared. This function never returns an error.
    fn handle_withheld_devices(
        &self,
        group_session: &OutboundGroupSession,
        withheld_devices: Vec<(DeviceData, WithheldCode)>,
    ) -> OlmResult<()> {
        // Convert a withheld code into serialized to-device content for this session.
        let to_content = |code| {
            let content = group_session.withheld_code(code);
            Raw::new(&content).expect("We can always serialize a withheld content info").cast()
        };

        // Build one withheld to-device request, plus the matching share infos,
        // for a chunk of (device, code) pairs.
        let chunk_to_request = |chunk| {
            let mut messages = BTreeMap::new();
            let mut share_infos = BTreeMap::new();

            for (device, code) in chunk {
                // Explicit types pin down the item type of the untyped `chunk`
                // parameter (presumably needed for inference).
                let device: DeviceData = device;
                let code: WithheldCode = code;

                let user_id = device.user_id().to_owned();
                let device_id = device.device_id().to_owned();

                let share_info = ShareInfo::new_withheld(code.to_owned());
                let content = to_content(code);

                messages
                    .entry(user_id.to_owned())
                    .or_insert_with(BTreeMap::new)
                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);

                share_infos
                    .entry(user_id)
                    .or_insert_with(BTreeMap::new)
                    .insert(device_id, share_info);
            }

            let txn_id = TransactionId::new();

            let request = ToDeviceRequest {
                event_type: ToDeviceEventType::from("m.room_key.withheld"),
                txn_id,
                messages,
            };

            (request, share_infos)
        };

        // Filter and build every request before adding any of them to the
        // session. `is_withheld_to` consults session state that `add_request`
        // below may change, so this order is deliberate.
        let result: Vec<_> = withheld_devices
            .into_iter()
            .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
            .into_iter()
            .map(chunk_to_request)
            .collect();

        // Attach each request to the session so it is persisted with it, and
        // remember which session the transaction belongs to.
        for (request, share_info) in result {
            if !request.messages.is_empty() {
                let txn_id = request.txn_id.to_owned();
                group_session.add_request(txn_id.to_owned(), request.into(), share_info);

                self.sessions.mark_as_being_shared(txn_id, group_session.clone());
            }
        }

        Ok(())
    }
587
588 fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
589 for request in requests {
590 let message_list = Self::to_device_request_to_log_list(request);
591 info!(
592 request_id = ?request.txn_id,
593 ?message_list,
594 "Created batch of to-device messages of type {}",
595 request.event_type
596 );
597 }
598 }
599
600 fn to_device_request_to_log_list(
604 request: &Arc<ToDeviceRequest>,
605 ) -> Vec<(String, String, String)> {
606 #[derive(serde::Deserialize)]
607 struct ContentStub<'a> {
608 #[serde(borrow, default, rename = "org.matrix.msgid")]
609 message_id: Option<&'a str>,
610 }
611
612 let mut result: Vec<(String, String, String)> = Vec::new();
613
614 for (user_id, device_map) in &request.messages {
615 for (device, content) in device_map {
616 let message_id: Option<&str> = content
617 .deserialize_as::<ContentStub<'_>>()
618 .expect("We should be able to deserialize the content we generated")
619 .message_id;
620
621 result.push((
622 message_id.unwrap_or("<undefined>").to_owned(),
623 user_id.to_string(),
624 device.to_string(),
625 ));
626 }
627 }
628 result
629 }
630
    /// Share the room key of the room's outbound group session with the
    /// devices of `users`.
    ///
    /// Steps:
    /// - Reuse or create the outbound session, rotating it if
    ///   `collect_session_recipients` asks for it.
    /// - Encrypt the key for every collected device that either never
    ///   received it or whose Olm wedging index has grown since.
    /// - Queue withheld notices for the devices that don't get the key.
    /// - Persist everything collected along the way.
    ///
    /// Returns all pending to-device requests of the session. The caller is
    /// expected to send them and report back through `mark_request_as_sent`.
    /// If nothing is pending, the session is marked as shared.
    ///
    /// # Errors
    ///
    /// Propagates store errors and errors from creating or encrypting
    /// sessions.
    #[instrument(skip(self, users, encryption_settings), fields(session_id))]
    pub async fn share_room_key(
        &self,
        room_id: &RoomId,
        users: impl Iterator<Item = &UserId>,
        encryption_settings: impl Into<EncryptionSettings>,
    ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
        trace!("Checking if a room key needs to be shared");

        // Our own device; used to derive sender data for new inbound sessions.
        let account = self.store.static_account();
        let device = self.store.get_device(account.user_id(), account.device_id()).await?;

        let encryption_settings = encryption_settings.into();
        let mut changes = Changes::default();

        // Reuse the room's outbound session if it is still usable, otherwise
        // create a new pair.
        let (outbound, inbound) = self
            .get_or_create_outbound_session(
                room_id,
                encryption_settings.clone(),
                SenderData::unknown(),
            )
            .await?;
        tracing::Span::current().record("session_id", outbound.session_id());

        // An inbound session is only returned when a new pair was created:
        // fill in its sender data and queue both halves for persisting.
        if let Some(mut inbound) = inbound {
            let own_sender_data = if let Some(device) = &device {
                SenderDataFinder::find_using_device_data(
                    &self.store,
                    device.inner.clone(),
                    &inbound,
                )
                .await?
            } else {
                error!("Unable to find our own device!");
                SenderData::unknown()
            };
            inbound.sender_data = own_sender_data;

            changes.outbound_group_sessions.push(outbound.clone());
            changes.inbound_group_sessions.push(inbound);
        }

        // Decide who receives the key, who gets a withheld notice instead,
        // and whether the session has to be rotated.
        let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
            self.collect_session_recipients(users, &encryption_settings, &outbound).await?;

        let outbound = self
            .maybe_rotate_group_session(
                should_rotate,
                room_id,
                outbound,
                encryption_settings,
                &mut changes,
                device,
            )
            .await?;

        // Keep only devices that still need the key.
        let devices: Vec<_> = devices
            .into_iter()
            .flat_map(|(_, d)| {
                d.into_iter().filter(|d| match outbound.is_shared_with(d) {
                    ShareState::NotShared => true,
                    ShareState::Shared { message_index: _, olm_wedging_index } => {
                        // Re-share if the device's Olm wedging index increased
                        // since the key was shared with it.
                        olm_wedging_index < d.olm_wedging_index
                    }
                    // Any other sharing state: don't share again.
                    _ => false,
                })
            })
            .collect();

        // Devices we could not encrypt for join the withheld list.
        let unable_to_encrypt_devices =
            self.encrypt_for_devices(devices, &outbound, &mut changes).await?;

        withheld_devices.extend(unable_to_encrypt_devices);

        self.handle_withheld_devices(&outbound, withheld_devices)?;

        // Room-key and withheld requests are all attached to the session.
        let requests = outbound.pending_requests();

        if requests.is_empty() {
            if !outbound.shared() {
                debug!("The room key doesn't need to be shared with anyone. Marking as shared.");

                outbound.mark_as_shared();
                changes.outbound_group_sessions.push(outbound.clone());
            }
        } else {
            Self::log_room_key_sharing_result(&requests)
        }

        // Persist sessions, devices and anything else collected above.
        if !changes.is_empty() {
            let session_count = changes.sessions.len();

            self.store.save_changes(changes).await?;

            trace!(
                session_count = session_count,
                "Stored the changed sessions after encrypting an room key"
            );
        }

        Ok(requests)
    }
772}
773
774#[cfg(test)]
775mod tests {
776 use std::{
777 collections::{BTreeMap, BTreeSet},
778 iter,
779 ops::Deref,
780 sync::Arc,
781 };
782
783 use assert_matches2::assert_let;
784 use matrix_sdk_common::deserialized_responses::WithheldCode;
785 use matrix_sdk_test::{async_test, ruma_response_from_json};
786 use ruma::{
787 api::client::{
788 keys::{claim_keys, get_keys, upload_keys},
789 to_device::send_event_to_device::v3::Response as ToDeviceResponse,
790 },
791 device_id,
792 events::room::history_visibility::HistoryVisibility,
793 room_id,
794 to_device::DeviceIdOrAllDevices,
795 user_id, DeviceId, OneTimeKeyAlgorithm, TransactionId, UInt, UserId,
796 };
797 use serde_json::{json, Value};
798
799 use crate::{
800 identities::DeviceData,
801 machine::EncryptionSyncChanges,
802 olm::{Account, SenderData},
803 session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
804 types::{
805 events::{
806 room::encrypted::EncryptedToDeviceEvent,
807 room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
808 },
809 requests::ToDeviceRequest,
810 DeviceKeys, EventEncryptionAlgorithm,
811 },
812 EncryptionSettings, LocalTrust, OlmMachine,
813 };
814
    /// User ID of the account behind the default test machine.
    fn alice_id() -> &'static UserId {
        user_id!("@alice:example.org")
    }

    /// Device ID of the default test machine's own device.
    fn alice_device_id() -> &'static DeviceId {
        device_id!("JLAFKJWSCS")
    }
822
    /// Load the `/keys/query` response fixture shared with the crypto benchmarks.
    fn keys_query_response() -> get_keys::v3::Response {
        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
        let data: Value = serde_json::from_slice(data).unwrap();
        ruma_response_from_json(&data)
    }
829
    /// `/keys/query` response announcing the single signed device `BOBDEVICE`
    /// of `@bob:localhost`.
    fn bob_keys_query_response() -> get_keys::v3::Response {
        let data = json!({
            "device_keys": {
                "@bob:localhost": {
                    "BOBDEVICE": {
                        "user_id": "@bob:localhost",
                        "device_id": "BOBDEVICE",
                        "algorithms": [
                            "m.olm.v1.curve25519-aes-sha2",
                            "m.megolm.v1.aes-sha2",
                            "m.megolm.v2.aes-sha2"
                        ],
                        "keys": {
                            "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
                            "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
                        },
                        "signatures": {
                            "@bob:localhost": {
                                "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
                            }
                        }
                    }
                }
            }
        });
        ruma_response_from_json(&data)
    }
857
    /// `/keys/claim` response containing one signed one-time key for
    /// `@bob:localhost`'s `BOBDEVICE`.
    fn bob_one_time_key() -> claim_keys::v3::Response {
        let data = json!({
            "failures": {},
            "one_time_keys":{
                "@bob:localhost":{
                    "BOBDEVICE":{
                        "signed_curve25519:AAAAAAAAAAA": {
                            "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
                            "signatures":{
                                "@bob:localhost":{
                                    "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
                                }
                            }
                        }
                    }
                }
            }
        });
        ruma_response_from_json(&data)
    }
880
    /// Load the `/keys/claim` response fixture shared with the crypto benchmarks.
    fn keys_claim_response() -> claim_keys::v3::Response {
        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
        let data: Value = serde_json::from_slice(data).unwrap();
        ruma_response_from_json(&data)
    }
888
    /// Create an `OlmMachine` for `user_id`/`device_id` that has processed the
    /// benchmark key query and key claim fixtures for `@example:localhost`,
    /// plus Bob's device keys and one-time key.
    async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
        let keys_query = keys_query_response();
        let txn_id = TransactionId::new();

        let machine = OlmMachine::new(user_id, device_id).await;

        // Feed the device list from the key query fixture.
        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
        // Claim one-time keys for `@example:localhost`'s devices and feed the
        // fixture response (presumably establishing Olm sessions with them).
        let (txn_id, _keys_claim_request) = machine
            .get_missing_sessions(iter::once(user_id!("@example:localhost")))
            .await
            .unwrap()
            .unwrap();
        let keys_claim = keys_claim_response();
        machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();

        // NOTE(review): the key-claim transaction ID is reused for Bob's key
        // query response; presumably the machine dispatches on the response
        // type rather than on the ID. Confirm.
        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
        // Same claim flow for Bob's single device.
        let (txn_id, _keys_claim_request) = machine
            .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
            .await
            .unwrap()
            .unwrap();
        machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();

        machine
    }
916
    /// Default test machine: Alice's device with the fixture keys loaded.
    async fn machine() -> OlmMachine {
        machine_with_user_test_helper(alice_id(), alice_device_id()).await
    }
920
    /// Default test machine that has already shared a room key for
    /// `!test:localhost` with all fixture users.
    ///
    /// All resulting requests are marked as sent, so the session ends up
    /// shared with nothing pending.
    async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
        let machine = machine().await;
        let room_id = room_id!("!test:localhost");
        let keys_claim = keys_claim_response();

        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
        let requests =
            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();

        let outbound =
            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();

        // Before sending: requests pending, session not shared yet.
        assert!(!outbound.pending_requests().is_empty());
        assert!(!outbound.shared());

        let response = ToDeviceResponse::new();
        for request in requests {
            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
        }

        // After sending everything: shared, nothing pending.
        assert!(outbound.shared());
        assert!(outbound.pending_requests().is_empty());

        machine
    }
946
    /// Sharing a room key with all fixture users yields 148 encrypted room-key
    /// messages and 2 withheld messages.
    #[async_test]
    async fn test_sharing() {
        let machine = machine().await;
        let room_id = room_id!("!test:localhost");
        let keys_claim = keys_claim_response();

        let users = keys_claim.one_time_keys.keys().map(Deref::deref);

        let requests =
            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();

        // Count individual messages, not requests.
        let event_count: usize = requests
            .iter()
            .filter(|r| r.event_type == "m.room.encrypted".into())
            .map(|r| r.message_count())
            .sum();

        assert_eq!(event_count, 148);

        // Per test_no_olm_sent_once, these two carry the `NoOlm` code.
        let withheld_count: usize = requests
            .iter()
            .filter(|r| r.event_type == "m.room_key.withheld".into())
            .map(|r| r.message_count())
            .sum();
        assert_eq!(withheld_count, 2);
    }
976
977 fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
978 requests
979 .iter()
980 .filter(|r| r.event_type == "m.room_key.withheld".into())
981 .map(|r| {
982 let mut count = 0;
983 for message in r.messages.values() {
985 message.iter().for_each(|(_, content)| {
986 let withheld: RoomKeyWithheldContent =
987 content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
988
989 if let MegolmV1AesSha2(content) = withheld {
990 if content.withheld_code() == code {
991 count += 1;
992 }
993 }
994 })
995 }
996 count
997 })
998 .sum()
999 }
1000
    /// `NoOlm` withheld notices are not duplicated by re-sharing, and are not
    /// sent again for another room once they were delivered.
    #[async_test]
    async fn test_no_olm_sent_once() {
        let machine = machine().await;
        let keys_claim = keys_claim_response();

        let users = keys_claim.one_time_keys.keys().map(Deref::deref);

        let first_room_id = room_id!("!test:localhost");

        let requests = machine
            .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
            .await
            .unwrap();

        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
        assert_eq!(withheld_count, 2);

        // Re-sharing before sending returns the same pending requests; no
        // additional `NoOlm` notices are queued.
        let new_requests = machine
            .share_room_key(first_room_id, users, EncryptionSettings::default())
            .await
            .unwrap();
        let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
        assert_eq!(withheld_count, 2);

        let response = ToDeviceResponse::new();
        for request in requests {
            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
        }

        // Once delivered, a different room's session must not resend `NoOlm`.
        let second_room_id = room_id!("!other:localhost");
        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
        let requests = machine
            .share_room_key(second_room_id, users, EncryptionSettings::default())
            .await
            .unwrap();

        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
        assert_eq!(withheld_count, 0);
    }
1049
    /// After a key was fully shared, adding a late joiner (Bob) only produces
    /// an encrypted message for that one new device.
    #[async_test]
    async fn test_ratcheted_sharing() {
        let machine = machine_with_shared_room_key_test_helper().await;

        let room_id = room_id!("!test:localhost");
        let late_joiner = user_id!("@bob:localhost");
        let keys_claim = keys_claim_response();

        let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
        users.insert(late_joiner);

        let requests = machine
            .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
            .await
            .unwrap();

        let event_count: usize = requests
            .iter()
            .filter(|r| r.event_type == "m.room.encrypted".into())
            .map(|r| r.message_count())
            .sum();
        let outbound =
            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();

        assert_eq!(event_count, 1);
        assert!(!outbound.pending_requests().is_empty());
    }
1077
    /// Unchanged settings don't request a rotation. Changing the history
    /// visibility or the algorithm does.
    #[async_test]
    async fn test_changing_encryption_settings() {
        let machine = machine_with_shared_room_key_test_helper().await;
        let room_id = room_id!("!test:localhost");
        let keys_claim = keys_claim_response();

        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
        let outbound =
            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();

        // Same settings the session was created with: no rotation.
        let CollectRecipientsResult { should_rotate, .. } = machine
            .inner
            .group_session_manager
            .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
            .await
            .unwrap();

        assert!(!should_rotate);

        // Different history visibility: rotation.
        let settings = EncryptionSettings {
            history_visibility: HistoryVisibility::Invited,
            ..Default::default()
        };

        let CollectRecipientsResult { should_rotate, .. } = machine
            .inner
            .group_session_manager
            .collect_session_recipients(users.clone(), &settings, &outbound)
            .await
            .unwrap();

        assert!(should_rotate);

        // Different algorithm: rotation.
        let settings = EncryptionSettings {
            algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
            ..Default::default()
        };

        let CollectRecipientsResult { should_rotate, .. } = machine
            .inner
            .group_session_manager
            .collect_session_recipients(users, &settings, &outbound)
            .await
            .unwrap();

        assert!(should_rotate);
    }
1125
    /// Recipient collection for our own user.
    ///
    /// - Our own device is never a recipient.
    /// - `OnlyTrustedDevices` yields no recipients until a device is verified.
    /// - Skipped devices are reported as withheld with the matching code.
    #[async_test]
    async fn test_key_recipient_collecting() {
        let user_id = user_id!("@example:localhost");
        let device_id = device_id!("TESTDEVICE");
        let room_id = room_id!("!test:localhost");

        let machine = machine_with_user_test_helper(user_id, device_id).await;

        let (outbound, _) = machine
            .inner
            .group_session_manager
            .get_or_create_outbound_session(
                room_id,
                EncryptionSettings::default(),
                SenderData::unknown(),
            )
            .await
            .expect("We should be able to create a new session");
        let history_visibility = HistoryVisibility::Joined;
        let settings = EncryptionSettings { history_visibility, ..Default::default() };

        let users = [user_id].into_iter();

        let CollectRecipientsResult { devices: recipients, .. } = machine
            .inner
            .group_session_manager
            .collect_session_recipients(users, &settings, &outbound)
            .await
            .expect("We should be able to collect the session recipients");

        assert!(!recipients[user_id].is_empty());

        // Our own device must not be among the recipients.
        assert!(!recipients[user_id]
            .iter()
            .any(|d| d.user_id() == user_id && d.device_id() == device_id));

        let settings = EncryptionSettings {
            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
            ..Default::default()
        };
        let users = [user_id].into_iter();

        let CollectRecipientsResult { devices: recipients, .. } = machine
            .inner
            .group_session_manager
            .collect_session_recipients(users, &settings, &outbound)
            .await
            .expect("We should be able to collect the session recipients");

        // No device is trusted yet.
        assert!(recipients[user_id].is_empty());

        // Verify one device locally; it should now be a recipient.
        let device_id = "AFGUOBTZWM".into();
        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
        device.set_local_trust(LocalTrust::Verified).await.unwrap();
        let users = [user_id].into_iter();

        let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
            machine
                .inner
                .group_session_manager
                .collect_session_recipients(users, &settings, &outbound)
                .await
                .expect("We should be able to collect the session recipients");

        assert!(recipients[user_id]
            .iter()
            .any(|d| d.user_id() == user_id && d.device_id() == device_id));

        // Every other device is withheld with the code matching its trust state.
        let devices = machine.get_user_devices(user_id, None).await.unwrap();
        devices
            .devices()
            .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
            .for_each(|d| {
                if d.is_blacklisted() {
                    assert!(withheld.iter().any(|(dev, w)| {
                        dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
                    }));
                } else if !d.is_verified() {
                    assert!(withheld.iter().any(|(dev, w)| {
                        dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
                    }));
                }
            });

        assert_eq!(149, withheld.len());
    }
1217
    /// With `OnlyTrustedDevices`, one verified device gets the key and all
    /// others get withheld notices; the blacklisted device gets `Blacklisted`.
    #[async_test]
    async fn test_sharing_withheld_only_trusted() {
        let machine = machine().await;
        let room_id = room_id!("!test:localhost");
        let keys_claim = keys_claim_response();

        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
        let settings = EncryptionSettings {
            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
            ..Default::default()
        };

        // Verify one device and blacklist another.
        let user_id = user_id!("@example:localhost");
        let device_id = "MWFXPINOAO".into();
        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
        device.set_local_trust(LocalTrust::Verified).await.unwrap();
        machine
            .get_device(user_id, "MWVTUXDNNM".into(), None)
            .await
            .unwrap()
            .unwrap()
            .set_local_trust(LocalTrust::BlackListed)
            .await
            .unwrap();

        let requests = machine.share_room_key(room_id, users, settings).await.unwrap();

        // These count requests, not messages.
        let room_key_count =
            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();

        assert_eq!(1, room_key_count);

        let withheld_count =
            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
        assert_eq!(1, withheld_count);

        // The single withheld request carries 149 messages.
        let event_count: usize = requests
            .iter()
            .filter(|r| r.event_type == "m.room_key.withheld".into())
            .map(|r| r.message_count())
            .sum();

        assert_eq!(event_count, 149);

        // The blacklisted device must get the `Blacklisted` code.
        let has_blacklist =
            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
                let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
                let content = &r.messages[user_id][&device_key];
                let withheld: RoomKeyWithheldContent =
                    content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
                if let MegolmV1AesSha2(content) = withheld {
                    content.withheld_code() == WithheldCode::Blacklisted
                } else {
                    false
                }
            });

        assert!(has_blacklist);
    }
1282
1283 #[async_test]
1284 async fn test_no_olm_withheld_only_sent_once() {
1285 let keys_query = keys_query_response();
1286 let txn_id = TransactionId::new();
1287
1288 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1289
1290 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1291 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1292
1293 let first_room = room_id!("!test:localhost");
1294 let second_room = room_id!("!test2:localhost");
1295 let bob_id = user_id!("@bob:localhost");
1296
1297 let settings = EncryptionSettings::default();
1298 let users = [bob_id];
1299
1300 let requests = machine
1301 .share_room_key(first_room, users.into_iter(), settings.to_owned())
1302 .await
1303 .unwrap();
1304
1305 let withheld_count =
1307 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1308
1309 assert_eq!(withheld_count, 1);
1310 assert_eq!(requests.len(), 1);
1311
1312 let second_requests =
1315 machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1316
1317 let withheld_count =
1318 second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1319
1320 assert_eq!(withheld_count, 0);
1321 assert_eq!(second_requests.len(), 0);
1322
1323 let response = ToDeviceResponse::new();
1324
1325 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1326
1327 assert!(!device.was_withheld_code_sent());
1330
1331 for request in requests {
1332 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1333 }
1334
1335 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1336
1337 assert!(device.was_withheld_code_sent());
1338 }
1339
1340 #[async_test]
1341 async fn test_resend_session_after_unwedging() {
1342 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1343 assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
1344 let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
1345 OneTimeKeyAlgorithm::SignedCurve25519,
1346 UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
1347 )]));
1348 machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
1349
1350 let room_id = room_id!("!test:localhost");
1351
1352 let bob_id = user_id!("@bob:localhost");
1353 let bob_account = Account::new(bob_id);
1354 let keys_query_data = json!({
1355 "device_keys": {
1356 "@bob:localhost": {
1357 bob_account.device_id.clone(): bob_account.device_keys()
1358 }
1359 }
1360 });
1361 let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
1362 let txn_id = TransactionId::new();
1363 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1364
1365 let alice_device_keys =
1366 device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
1367 let mut alice_otks = device_keys_request.one_time_keys.iter();
1368 let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
1369
1370 {
1371 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1373 let mut session = bob_account
1374 .create_outbound_session(
1375 &alice_device,
1376 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1377 bob_account.device_keys(),
1378 )
1379 .unwrap();
1380 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1381
1382 let to_device =
1383 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1384
1385 let sync_changes = EncryptionSyncChanges {
1387 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1388 changed_devices: &Default::default(),
1389 one_time_keys_counts: &Default::default(),
1390 unused_fallback_keys: None,
1391 next_batch_token: None,
1392 };
1393 let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();
1394
1395 assert_eq!(1, decrypted.len());
1396 }
1397
1398 {
1400 let requests = machine
1401 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1402 .await
1403 .unwrap();
1404
1405 let event_count: usize = requests
1407 .iter()
1408 .filter(|r| r.event_type == "m.room.encrypted".into())
1409 .map(|r| r.message_count())
1410 .sum();
1411 assert_eq!(event_count, 1);
1412
1413 let response = ToDeviceResponse::new();
1414 for request in requests {
1415 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1416 }
1417 }
1418
1419 {
1422 let requests = machine
1423 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1424 .await
1425 .unwrap();
1426
1427 let event_count: usize = requests
1428 .iter()
1429 .filter(|r| r.event_type == "m.room.encrypted".into())
1430 .map(|r| r.message_count())
1431 .sum();
1432 assert_eq!(event_count, 0);
1433 }
1434
1435 {
1437 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1438 let mut session = bob_account
1439 .create_outbound_session(
1440 &alice_device,
1441 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1442 bob_account.device_keys(),
1443 )
1444 .unwrap();
1445 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1446
1447 let to_device =
1448 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1449
1450 let sync_changes = EncryptionSyncChanges {
1452 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1453 changed_devices: &Default::default(),
1454 one_time_keys_counts: &Default::default(),
1455 unused_fallback_keys: None,
1456 next_batch_token: None,
1457 };
1458 let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();
1459
1460 assert_eq!(1, decrypted.len());
1461 }
1462
1463 {
1465 let requests = machine
1466 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1467 .await
1468 .unwrap();
1469
1470 let event_count: usize = requests
1471 .iter()
1472 .filter(|r| r.event_type == "m.room.encrypted".into())
1473 .map(|r| r.message_count())
1474 .sum();
1475 assert_eq!(event_count, 1);
1476
1477 let response = ToDeviceResponse::new();
1478 for request in requests {
1479 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1480 }
1481 }
1482
1483 {
1486 let requests = machine
1487 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1488 .await
1489 .unwrap();
1490
1491 let event_count: usize = requests
1492 .iter()
1493 .filter(|r| r.event_type == "m.room.encrypted".into())
1494 .map(|r| r.message_count())
1495 .sum();
1496 assert_eq!(event_count, 0);
1497 }
1498 }
1499}