1use std::{
19 collections::BTreeMap,
20 future::Future,
21 sync::{Arc, Mutex, atomic::Ordering},
22};
23
24use assert_matches2::assert_let;
25use matrix_sdk_base::crypto::types::events::room::encrypted::EncryptedToDeviceEvent;
26use matrix_sdk_test::test_json;
27use ruma::{
28 CrossSigningKeyId, DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OwnedDeviceId,
29 OwnedOneTimeKeyId, OwnedUserId, UserId,
30 api::client::{
31 keys::upload_signatures::v3::SignedKeys, to_device::send_event_to_device::v3::Messages,
32 },
33 encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
34 events::AnyToDeviceEvent,
35 owned_device_id, owned_user_id,
36 serde::Raw,
37 to_device::DeviceIdOrAllDevices,
38};
39use serde_json::json;
40use tracing::Instrument;
41use wiremock::{
42 Mock, MockGuard, Request, ResponseTemplate,
43 matchers::{method, path_regex},
44};
45
46use crate::{
47 Client,
48 test_utils::{
49 client::MockClientBuilder,
50 mocks::{Keys, MatrixMockServer},
51 },
52};
53
/// Captured to-device messages waiting to be synced back to their recipients.
///
/// Keyed by recipient user ID, then recipient device ID; each `Vec` holds the
/// events in the order they were captured. Filled by
/// [`MatrixMockServer::capture_put_to_device_traffic`] and read by
/// [`MatrixMockServer::sync_back_pending_to_device_messages`].
pub type PendingToDeviceMessages =
    BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Vec<Raw<AnyToDeviceEvent>>>>;
58
59impl MatrixMockServer {
81 pub fn client_builder_for_crypto_end_to_end(
86 &self,
87 user_id: &UserId,
88 device_id: &DeviceId,
89 ) -> MockClientBuilder {
90 let next = self.token_counter.fetch_add(1, Ordering::Relaxed);
92 let access_token = format!("TOKEN_{next}");
93
94 {
95 let mut mappings = self.token_to_user_id_map.lock().unwrap();
96 let auth_string = format!("Bearer {access_token}");
97 mappings.insert(auth_string, user_id.to_owned());
98 }
99
100 MockClientBuilder::new(Some(&self.server.uri())).logged_in_with_token(
101 access_token,
102 user_id.to_owned(),
103 device_id.to_owned(),
104 )
105 }
106
107 pub fn exhaust_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) {
109 let mut keys = self.keys.lock().unwrap();
110 let known_otks = &mut keys.one_time_keys;
111 known_otks.entry(user_id).or_default().entry(device_id).or_default().clear();
112 }
113
114 pub async fn exchange_e2ee_identities(&self, alice: &Client, bob: &Client) {
117 let alice_user_id = alice.user_id().expect("Alice should have a user ID configured");
118 let bob_user_id = bob.user_id().expect("Bob should have a user ID configured");
119
120 let alice_span = tracing::info_span!("alice", user_id=%alice_user_id);
121 let bob_span = tracing::info_span!("bob", user_id=%bob_user_id);
122
123 alice.update_tracked_users_for_testing([bob_user_id]).instrument(alice_span.clone()).await;
125
126 bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
130
131 self.mock_sync().ok_and_run(alice, |_x| {}).instrument(alice_span.clone()).await;
133 self.mock_sync().ok_and_run(bob, |_x| {}).instrument(bob_span).await;
134
135 self.mock_sync().ok_and_run(alice, |_x| {}).instrument(alice_span).await;
138 }
139
140 pub async fn set_up_alice_and_bob_for_encryption(&self) -> (Client, Client) {
143 let alice_user_id = owned_user_id!("@alice:example.org");
144 let alice_device_id = owned_device_id!("4L1C3");
145
146 let alice = self
147 .client_builder_for_crypto_end_to_end(&alice_user_id, &alice_device_id)
148 .build()
149 .await;
150
151 let bob_user_id = owned_user_id!("@bob:example.org");
152 let bob_device_id = owned_device_id!("B0B0B0B0B");
153 let bob =
154 self.client_builder_for_crypto_end_to_end(&bob_user_id, &bob_device_id).build().await;
155
156 self.exchange_e2ee_identities(&alice, &bob).await;
157
158 (alice, bob)
159 }
160
161 pub async fn set_up_carl_for_encryption(&self, alice: &Client, bob: &Client) -> Client {
163 let carl_user_id = owned_user_id!("@carlg:example.org");
164 let carl_device_id = owned_device_id!("CARL_DEVICE");
165
166 let carl =
167 self.client_builder_for_crypto_end_to_end(&carl_user_id, &carl_device_id).build().await;
168
169 self.mock_sync().ok_and_run(&carl, |_| {}).await;
171
172 alice.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
174
175 bob.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
177
178 {
180 self.mock_sync().ok_and_run(alice, |_| {}).await;
181 self.mock_sync().ok_and_run(bob, |_| {}).await;
182 }
183
184 carl.update_tracked_users_for_testing([alice.user_id().unwrap(), bob.user_id().unwrap()])
186 .await;
187
188 self.mock_sync().ok_and_run(alice, |_| {}).await;
190
191 carl
192 }
193
194 pub async fn set_up_new_device_for_encryption(
210 &self,
211 existing_client: &Client,
212 device_id: &DeviceId,
213 clients_to_update: Vec<&Client>,
214 ) -> Client {
215 let user_id = existing_client.user_id().unwrap().to_owned();
216 let new_device_id = device_id.to_owned();
217
218 let new_client =
219 self.client_builder_for_crypto_end_to_end(&user_id, &new_device_id).build().await;
220
221 self.mock_sync().ok_and_run(&new_client, |_| {}).await;
223
224 self.mock_sync()
226 .ok_and_run(existing_client, |builder| {
227 builder.add_change_device(&user_id);
228 })
229 .await;
230
231 for client_to_update in clients_to_update {
232 self.mock_sync()
233 .ok_and_run(client_to_update, |builder| {
234 builder.add_change_device(&user_id);
235 })
236 .await;
237 }
238
239 new_client
240 }
241
242 pub async fn mock_crypto_endpoints_preset(&self) {
245 let keys = &self.keys;
246 let token_map = &self.token_to_user_id_map;
247
248 Mock::given(method("POST"))
249 .and(path_regex(r"^/_matrix/client/.*/keys/query"))
250 .respond_with(mock_keys_query(keys.clone()))
251 .mount(&self.server)
252 .await;
253
254 Mock::given(method("POST"))
255 .and(path_regex(r"^/_matrix/client/.*/keys/upload"))
256 .respond_with(mock_keys_upload(keys.clone(), token_map.clone()))
257 .mount(&self.server)
258 .await;
259
260 Mock::given(method("POST"))
261 .and(path_regex(r"^/_matrix/client/.*/keys/device_signing/upload"))
262 .respond_with(mock_keys_device_signing_upload(keys.clone()))
263 .mount(&self.server)
264 .await;
265
266 Mock::given(method("POST"))
267 .and(path_regex(r"^/_matrix/client/.*/keys/signatures/upload"))
268 .respond_with(mock_keys_signature_upload(keys.clone()))
269 .mount(&self.server)
270 .await;
271
272 Mock::given(method("POST"))
273 .and(path_regex(r"^/_matrix/client/.*/keys/claim"))
274 .respond_with(mock_keys_claimed_request(keys.clone()))
275 .mount(&self.server)
276 .await;
277 }
278
279 pub async fn mock_capture_put_to_device(
344 &self,
345 sender_user_id: &UserId,
346 ) -> (MockGuard, impl Future<Output = Raw<EncryptedToDeviceEvent>> + use<>) {
347 let (tx, rx) = tokio::sync::oneshot::channel();
348 let tx = Arc::new(Mutex::new(Some(tx)));
349
350 let sender = sender_user_id.to_owned();
351 let guard = Mock::given(method("PUT"))
352 .and(path_regex(r"^/_matrix/client/.*/sendToDevice/m.room.encrypted/.*"))
353 .respond_with(move |req: &Request| {
354 #[derive(Debug, serde::Deserialize)]
355 struct Parameters {
356 messages: Messages,
357 }
358
359 let params: Parameters = req.body_json().unwrap();
360
361 let (_, device_to_content) = params.messages.first_key_value().unwrap();
362 let content = device_to_content.first_key_value().unwrap().1;
363
364 let event = json!({
365 "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
366 "sender": sender,
367 "type": "m.room.encrypted",
368 "content": content,
369 });
370 let event: Raw<EncryptedToDeviceEvent> = serde_json::from_value(event).unwrap();
371
372 if let Ok(mut guard) = tx.lock()
373 && let Some(tx) = guard.take()
374 {
375 let _ = tx.send(event);
376 }
377
378 ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
379 })
380 .expect(1)
382 .named("send_to_device")
383 .mount_as_scoped(self.server())
384 .await;
385
386 let future =
387 async move { rx.await.expect("Failed to receive captured value - sender was dropped") };
388
389 (guard, future)
390 }
391
392 pub async fn mock_capture_put_to_device_then_sync_back<'a>(
410 &'a self,
411 sender_user_id: &UserId,
412 recipient: &'a Client,
413 ) -> impl Future<Output = Raw<EncryptedToDeviceEvent>> + 'a {
414 let (guard, sent_event) = self.mock_capture_put_to_device(sender_user_id).await;
415
416 async {
417 let sent_event = sent_event.await;
418 drop(guard);
419 self.mock_sync()
420 .ok_and_run(recipient, |sync_builder| {
421 sync_builder.add_to_device_event(sent_event.deserialize_as().unwrap());
422 })
423 .await;
424
425 sent_event
426 }
427 }
428
429 pub async fn capture_put_to_device_traffic(
433 &self,
434 sender_user_id: &UserId,
435 to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
436 ) -> MockGuard {
437 let sender = sender_user_id.to_owned();
438
439 Mock::given(method("PUT"))
440 .and(path_regex(r"^/_matrix/client/.*/sendToDevice/([^/]+)/.*"))
441 .respond_with(move |req: &Request| {
442 #[derive(Debug, serde::Deserialize)]
443 struct Parameters {
444 messages: Messages,
445 }
446
447 let params: Parameters = req.body_json().unwrap();
448 let messages = params.messages;
449
450 let event_type = req
452 .url
453 .path_segments()
454 .and_then(|segments| segments.rev().nth(1))
455 .expect("Event type should be captured in the path");
456
457 let mut to_device_queue = to_device_queue.lock().unwrap();
458 for (user_id, device_map) in messages.iter() {
459 for (device_id, content) in device_map.iter() {
460 assert_let!(DeviceIdOrAllDevices::DeviceId(device_id) = device_id);
461
462 let event = json!({
463 "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
464 "sender": sender,
465 "type": event_type.to_owned(),
466 "content": content,
467 });
468
469 to_device_queue
470 .entry(user_id.to_owned())
471 .or_default()
472 .entry(device_id.to_owned())
473 .or_default()
474 .push(serde_json::from_value(event).unwrap());
475 }
476 }
477
478 ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
479 })
480 .mount_as_scoped(self.server())
481 .await
482 }
483
484 pub async fn sync_back_pending_to_device_messages(
490 &self,
491 to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
492 recipient: &Client,
493 ) {
494 let messages_to_sync = {
495 let to_device_queue = to_device_queue.lock().unwrap();
496 let pending_messages = to_device_queue
497 .get(&recipient.user_id().unwrap().to_owned())
498 .and_then(|treemap| treemap.get(&recipient.device_id().unwrap().to_owned()));
499
500 pending_messages.cloned().unwrap_or_default()
501 };
502
503 for message in messages_to_sync {
504 self.mock_sync()
505 .ok_and_run(recipient, |sync_builder| {
506 sync_builder.add_to_device_event(message.deserialize_as().unwrap());
507 })
508 .await;
509 }
510 }
511}
512
513fn mock_keys_query(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
518 move |req| {
519 #[derive(Debug, serde::Deserialize)]
520 struct Parameters {
521 device_keys: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
522 }
523
524 let params: Parameters = req.body_json().unwrap();
525
526 let keys = keys.lock().unwrap();
527 let mut device_keys = keys.device.clone();
528 if !params.device_keys.is_empty() {
529 device_keys.retain(|user, key_map| {
530 if let Some(devices) = params.device_keys.get(user) {
531 if !devices.is_empty() {
532 key_map.retain(|key_id, _json| {
533 devices.iter().any(|device_id| &device_id.to_string() == key_id)
534 });
535 }
536 true
537 } else {
538 false
539 }
540 })
541 }
542
543 let master_keys = keys.master.clone();
544 let self_signing_keys = keys.self_signing.clone();
545 let user_signing_keys = keys.user_signing.clone();
546
547 ResponseTemplate::new(200).set_body_json(json!({
548 "device_keys": device_keys,
549 "master_keys": master_keys,
550 "self_signing_keys": self_signing_keys,
551 "user_signing_keys": user_signing_keys,
552 }))
553 }
554}
555
556fn mock_keys_upload(
562 keys: Arc<Mutex<Keys>>,
563 token_to_user_id_map: Arc<Mutex<BTreeMap<String, OwnedUserId>>>,
564) -> impl Fn(&Request) -> ResponseTemplate {
565 move |req: &Request| {
566 #[derive(Debug, serde::Deserialize)]
567 struct Parameters {
568 device_keys: Option<Raw<DeviceKeys>>,
569 one_time_keys: Option<BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
570 }
571 let bearer_token = req
572 .headers
573 .get(http::header::AUTHORIZATION)
574 .and_then(|header| header.to_str().ok())
575 .expect("This call should be authenticated");
576
577 let params: Parameters = req.body_json().unwrap();
578
579 let tokens = token_to_user_id_map.lock().unwrap();
580 let user_id = tokens.get(bearer_token)
582 .expect("Expect this token to be known, ensure you use `MatrixKeysServer::client_builder_for_crypto_end_to_end`")
583 .to_owned();
584
585 if let Some(new_device_keys) = params.device_keys {
586 let new_device_keys = new_device_keys.deserialize().unwrap();
587
588 let key_id = new_device_keys.device_id.to_string();
589 let mut keys = keys.lock().unwrap();
591 let devices = keys.device.entry(new_device_keys.user_id.clone()).or_default();
592
593 if let Some(device_keys) = devices.get_mut(&key_id) {
595 let mut existing = device_keys.deserialize().unwrap();
596
597 for (uid, sigs) in existing.signatures.iter_mut() {
599 if let Some(new_sigs) = new_device_keys.signatures.get(uid) {
600 sigs.extend(new_sigs.clone());
601 }
602 }
603 for (uid, sigs) in new_device_keys.signatures.iter() {
604 if !existing.signatures.contains_key(uid) {
605 existing.signatures.insert(uid.clone(), sigs.clone());
606 }
607 }
608
609 *device_keys = Raw::new(&existing).unwrap();
610 } else {
611 devices.insert(key_id, Raw::new(&new_device_keys).unwrap());
612 }
613 }
614
615 let mut keys = keys.lock().unwrap();
616
617 if let Some(otks) = params.one_time_keys {
618 for (key_id, raw_otk) in otks {
622 let otk = raw_otk.deserialize().unwrap();
623 match otk {
624 OneTimeKey::SignedKey(signed_key) => {
625 let device_id = signed_key
626 .signatures
627 .first_key_value()
628 .unwrap()
629 .1
630 .keys()
631 .next()
632 .unwrap()
633 .key_name()
634 .to_owned();
635
636 keys.one_time_keys
637 .entry(user_id.clone())
638 .or_default()
639 .entry(device_id)
640 .or_default()
641 .insert(key_id, raw_otk);
642 }
643 OneTimeKey::Key(_) => {
644 }
646 _ => {}
647 }
648 }
649 }
650
651 let otk_count = keys.one_time_keys.get(&user_id).map(|m| m.len()).unwrap_or(0);
652 ResponseTemplate::new(200).set_body_json(json!({
653 "one_time_key_counts": {
654 "signed_curve25519": otk_count,
655 }
656 }))
657 }
658}
659
660fn mock_keys_device_signing_upload(
668 keys: Arc<Mutex<Keys>>,
669) -> impl Fn(&Request) -> ResponseTemplate {
670 move |req: &Request| {
671 #[derive(Debug, serde::Deserialize)]
673 struct Parameters {
674 master_key: Option<Raw<CrossSigningKey>>,
675 self_signing_key: Option<Raw<CrossSigningKey>>,
676 user_signing_key: Option<Raw<CrossSigningKey>>,
677 }
678
679 let params: Parameters = req.body_json().unwrap();
680 assert!(params.master_key.is_some());
681 assert!(params.self_signing_key.is_some());
682 assert!(params.user_signing_key.is_some());
683
684 let mut keys = keys.lock().unwrap();
685
686 if let Some(key) = params.master_key {
687 let deserialized = key.deserialize().unwrap();
688 let user_id = deserialized.user_id;
689 keys.master.insert(user_id, key);
690 }
691
692 if let Some(key) = params.self_signing_key {
693 let deserialized = key.deserialize().unwrap();
694 let user_id = deserialized.user_id;
695 keys.self_signing.insert(user_id, key);
696 }
697
698 if let Some(key) = params.user_signing_key {
699 let deserialized = key.deserialize().unwrap();
700 let user_id = deserialized.user_id;
701 keys.user_signing.insert(user_id, key);
702 }
703
704 ResponseTemplate::new(200).set_body_json(json!({}))
705 }
706}
707
708fn mock_keys_signature_upload(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
712 move |req: &Request| {
713 #[derive(Debug, serde::Deserialize)]
714 #[serde(transparent)]
715 struct Parameters(BTreeMap<OwnedUserId, SignedKeys>);
716
717 let params: Parameters = req.body_json().unwrap();
718
719 let mut keys = keys.lock().unwrap();
720
721 for (user, signed_keys) in params.0 {
722 for (key_id, raw_key) in signed_keys.iter() {
723 if let Some(existing_master_key) = keys.master.get_mut(&user) {
725 let mut existing = existing_master_key.deserialize().unwrap();
726
727 let target = CrossSigningKeyId::from_parts(
728 ruma::SigningKeyAlgorithm::Ed25519,
729 key_id.try_into().unwrap(),
730 );
731
732 if existing.keys.contains_key(&target) {
733 let param: CrossSigningKey = serde_json::from_str(raw_key.get()).unwrap();
734
735 for (uid, sigs) in existing.signatures.iter_mut() {
736 if let Some(new_sigs) = param.signatures.get(uid) {
737 sigs.extend(new_sigs.clone());
738 }
739 }
740 for (uid, sigs) in param.signatures.iter() {
741 if !existing.signatures.contains_key(uid) {
742 existing.signatures.insert(uid.clone(), sigs.clone());
743 }
744 }
745
746 *existing_master_key = Raw::new(&existing).unwrap();
748 continue;
749 }
750 }
751
752 let known_devices = keys.device.entry(user.clone()).or_default();
756 let device_keys = known_devices
757 .get_mut(key_id)
758 .expect("trying to add a signature for a missing key");
759
760 let param: DeviceKeys = serde_json::from_str(raw_key.get()).unwrap();
761
762 let mut existing: DeviceKeys = device_keys.deserialize().unwrap();
763
764 for (uid, sigs) in existing.signatures.iter_mut() {
765 if let Some(new_sigs) = param.signatures.get(uid) {
766 sigs.extend(new_sigs.clone());
767 }
768 }
769 for (uid, sigs) in param.signatures.iter() {
770 if !existing.signatures.contains_key(uid) {
771 existing.signatures.insert(uid.clone(), sigs.clone());
772 }
773 }
774
775 *device_keys = Raw::new(&existing).unwrap();
776 }
777 }
778
779 ResponseTemplate::new(200).set_body_json(json!({
780 "failures": {}
781 }))
782 }
783}
784
785fn mock_keys_claimed_request(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
786 move |req: &Request| {
787 #[derive(Debug, serde::Deserialize)]
789 struct Parameters {
790 one_time_keys: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>>,
791 }
792
793 let params: Parameters = req.body_json().unwrap();
794
795 let mut keys = keys.lock().unwrap();
796 let known_otks = &mut keys.one_time_keys;
797
798 let mut found_one_time_keys: BTreeMap<
799 OwnedUserId,
800 BTreeMap<OwnedDeviceId, BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
801 > = BTreeMap::new();
802
803 for (user, requested_one_time_keys) in params.one_time_keys {
804 for device_id in requested_one_time_keys.keys() {
805 let device_id = device_id.clone();
806 let found_key = known_otks
807 .entry(user.clone())
808 .or_default()
809 .entry(device_id.clone())
810 .or_default()
811 .pop_first();
812 if let Some((id, raw_otk)) = found_key {
813 found_one_time_keys
814 .entry(user.clone())
815 .or_default()
816 .entry(device_id.clone())
817 .or_default()
818 .insert(id, raw_otk.clone());
819 }
820 }
821 }
822
823 ResponseTemplate::new(200).set_body_json(json!({
824 "one_time_keys" : found_one_time_keys
825 }))
826 }
827}