matrix_sdk/test_utils/mocks/
encryption.rs1use std::{
19 collections::BTreeMap,
20 future::Future,
21 sync::{atomic::Ordering, Arc, Mutex},
22};
23
24use assert_matches2::assert_let;
25use matrix_sdk_test::test_json;
26use ruma::{
27 api::client::{
28 keys::upload_signatures::v3::SignedKeys, to_device::send_event_to_device::v3::Messages,
29 },
30 encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
31 events::AnyToDeviceEvent,
32 owned_device_id, owned_user_id,
33 serde::Raw,
34 to_device::DeviceIdOrAllDevices,
35 CrossSigningKeyId, DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OwnedDeviceId,
36 OwnedOneTimeKeyId, OwnedUserId, UserId,
37};
38use serde_json::json;
39use wiremock::{
40 matchers::{method, path_regex},
41 Mock, MockGuard, Request, ResponseTemplate,
42};
43
44use crate::{
45 crypto::types::events::room::encrypted::EncryptedToDeviceEvent,
46 test_utils::{
47 client::MockClientBuilder,
48 mocks::{Keys, MatrixMockServer},
49 },
50 Client,
51};
52
53pub type PendingToDeviceMessages =
56 BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Vec<Raw<AnyToDeviceEvent>>>>;
57
58impl MatrixMockServer {
80 pub fn client_builder_for_crypto_end_to_end(
85 &self,
86 user_id: &UserId,
87 device_id: &DeviceId,
88 ) -> MockClientBuilder {
89 let next = self.token_counter.fetch_add(1, Ordering::Relaxed);
91 let access_token = format!("TOKEN_{next}");
92
93 {
94 let mut mappings = self.token_to_user_id_map.lock().unwrap();
95 let auth_string = format!("Bearer {access_token}");
96 mappings.insert(auth_string, user_id.to_owned());
97 }
98
99 MockClientBuilder::new(Some(&self.server.uri())).logged_in_with_token(
100 access_token,
101 user_id.to_owned(),
102 device_id.to_owned(),
103 )
104 }
105
106 pub fn exhaust_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) {
108 let mut keys = self.keys.lock().unwrap();
109 let known_otks = &mut keys.one_time_keys;
110 known_otks.entry(user_id).or_default().entry(device_id).or_default().clear();
111 }
112
113 pub async fn exchange_e2ee_identities(&self, alice: &Client, bob: &Client) {
116 let alice_user_id = alice.user_id().expect("Alice should have a user ID configured");
117 let bob_user_id = bob.user_id().expect("Bob should have a user ID configured");
118
119 alice.update_tracked_users_for_testing([bob_user_id]).await;
121
122 bob.update_tracked_users_for_testing([alice_user_id]).await;
126
127 self.mock_sync().ok_and_run(alice, |_x| {}).await;
129 self.mock_sync().ok_and_run(bob, |_x| {}).await;
130
131 self.mock_sync().ok_and_run(alice, |_x| {}).await;
134 }
135
136 pub async fn set_up_alice_and_bob_for_encryption(&self) -> (Client, Client) {
139 let alice_user_id = owned_user_id!("@alice:example.org");
140 let alice_device_id = owned_device_id!("4L1C3");
141
142 let alice = self
143 .client_builder_for_crypto_end_to_end(&alice_user_id, &alice_device_id)
144 .build()
145 .await;
146
147 let bob_user_id = owned_user_id!("@bob:example.org");
148 let bob_device_id = owned_device_id!("B0B0B0B0B");
149 let bob =
150 self.client_builder_for_crypto_end_to_end(&bob_user_id, &bob_device_id).build().await;
151
152 self.exchange_e2ee_identities(&alice, &bob).await;
153
154 (alice, bob)
155 }
156
157 pub async fn set_up_carl_for_encryption(&self, alice: &Client, bob: &Client) -> Client {
159 let carl_user_id = owned_user_id!("@carlg:example.org");
160 let carl_device_id = owned_device_id!("CARL_DEVICE");
161
162 let carl =
163 self.client_builder_for_crypto_end_to_end(&carl_user_id, &carl_device_id).build().await;
164
165 self.mock_sync().ok_and_run(&carl, |_| {}).await;
167
168 alice.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
170
171 bob.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
173
174 {
176 self.mock_sync().ok_and_run(alice, |_| {}).await;
177 self.mock_sync().ok_and_run(bob, |_| {}).await;
178 }
179
180 carl.update_tracked_users_for_testing([alice.user_id().unwrap(), bob.user_id().unwrap()])
182 .await;
183
184 self.mock_sync().ok_and_run(alice, |_| {}).await;
186
187 carl
188 }
189
190 pub async fn set_up_new_device_for_encryption(
206 &self,
207 existing_client: &Client,
208 device_id: &DeviceId,
209 clients_to_update: Vec<&Client>,
210 ) -> Client {
211 let user_id = existing_client.user_id().unwrap().to_owned();
212 let new_device_id = device_id.to_owned();
213
214 let new_client =
215 self.client_builder_for_crypto_end_to_end(&user_id, &new_device_id).build().await;
216
217 self.mock_sync().ok_and_run(&new_client, |_| {}).await;
219
220 self.mock_sync()
222 .ok_and_run(existing_client, |builder| {
223 builder.add_change_device(&user_id);
224 })
225 .await;
226
227 for client_to_update in clients_to_update {
228 self.mock_sync()
229 .ok_and_run(client_to_update, |builder| {
230 builder.add_change_device(&user_id);
231 })
232 .await;
233 }
234
235 new_client
236 }
237
238 pub async fn mock_crypto_endpoints_preset(&self) {
241 let keys = &self.keys;
242 let token_map = &self.token_to_user_id_map;
243
244 Mock::given(method("POST"))
245 .and(path_regex(r"^/_matrix/client/.*/keys/query"))
246 .respond_with(mock_keys_query(keys.clone()))
247 .mount(&self.server)
248 .await;
249
250 Mock::given(method("POST"))
251 .and(path_regex(r"^/_matrix/client/.*/keys/upload"))
252 .respond_with(mock_keys_upload(keys.clone(), token_map.clone()))
253 .mount(&self.server)
254 .await;
255
256 Mock::given(method("POST"))
257 .and(path_regex(r"^/_matrix/client/.*/keys/device_signing/upload"))
258 .respond_with(mock_keys_device_signing_upload(keys.clone()))
259 .mount(&self.server)
260 .await;
261
262 Mock::given(method("POST"))
263 .and(path_regex(r"^/_matrix/client/.*/keys/signatures/upload"))
264 .respond_with(mock_keys_signature_upload(keys.clone()))
265 .mount(&self.server)
266 .await;
267
268 Mock::given(method("POST"))
269 .and(path_regex(r"^/_matrix/client/.*/keys/claim"))
270 .respond_with(mock_keys_claimed_request(keys.clone()))
271 .mount(&self.server)
272 .await;
273 }
274
275 pub async fn mock_capture_put_to_device(
340 &self,
341 sender_user_id: &UserId,
342 ) -> (MockGuard, impl Future<Output = Raw<EncryptedToDeviceEvent>>) {
343 let (tx, rx) = tokio::sync::oneshot::channel();
344 let tx = Arc::new(Mutex::new(Some(tx)));
345
346 let sender = sender_user_id.to_owned();
347 let guard = Mock::given(method("PUT"))
348 .and(path_regex(r"^/_matrix/client/.*/sendToDevice/m.room.encrypted/.*"))
349 .respond_with(move |req: &Request| {
350 #[derive(Debug, serde::Deserialize)]
351 struct Parameters {
352 messages: Messages,
353 }
354
355 let params: Parameters = req.body_json().unwrap();
356
357 let (_, device_to_content) = params.messages.first_key_value().unwrap();
358 let content = device_to_content.first_key_value().unwrap().1;
359
360 let event = json!({
361 "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
362 "sender": sender,
363 "type": "m.room.encrypted",
364 "content": content,
365 });
366 let event: Raw<EncryptedToDeviceEvent> = serde_json::from_value(event).unwrap();
367
368 if let Ok(mut guard) = tx.lock() {
369 if let Some(tx) = guard.take() {
370 let _ = tx.send(event);
371 }
372 }
373
374 ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
375 })
376 .expect(1)
378 .named("send_to_device")
379 .mount_as_scoped(self.server())
380 .await;
381
382 let future =
383 async move { rx.await.expect("Failed to receive captured value - sender was dropped") };
384
385 (guard, future)
386 }
387
388 pub async fn mock_capture_put_to_device_then_sync_back<'a>(
406 &'a self,
407 sender_user_id: &UserId,
408 recipient: &'a Client,
409 ) -> impl Future<Output = Raw<EncryptedToDeviceEvent>> + 'a {
410 let (guard, sent_event) = self.mock_capture_put_to_device(sender_user_id).await;
411
412 async {
413 let sent_event = sent_event.await;
414 drop(guard);
415 self.mock_sync()
416 .ok_and_run(recipient, |sync_builder| {
417 sync_builder.add_to_device_event(sent_event.deserialize_as().unwrap());
418 })
419 .await;
420
421 sent_event
422 }
423 }
424
425 pub async fn capture_put_to_device_traffic(
429 &self,
430 sender_user_id: &UserId,
431 to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
432 ) -> MockGuard {
433 let sender = sender_user_id.to_owned();
434
435 Mock::given(method("PUT"))
436 .and(path_regex(r"^/_matrix/client/.*/sendToDevice/([^/]+)/.*"))
437 .respond_with(move |req: &Request| {
438 #[derive(Debug, serde::Deserialize)]
439 struct Parameters {
440 messages: Messages,
441 }
442
443 let params: Parameters = req.body_json().unwrap();
444 let messages = params.messages;
445
446 let event_type = req
448 .url
449 .path_segments()
450 .and_then(|segments| segments.rev().nth(1))
451 .expect("Event type should be captured in the path");
452
453 let mut to_device_queue = to_device_queue.lock().unwrap();
454 for (user_id, device_map) in messages.iter() {
455 for (device_id, content) in device_map.iter() {
456 assert_let!(DeviceIdOrAllDevices::DeviceId(device_id) = device_id);
457
458 let event = json!({
459 "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
460 "sender": sender,
461 "type": event_type.to_owned(),
462 "content": content,
463 });
464
465 to_device_queue
466 .entry(user_id.to_owned())
467 .or_default()
468 .entry(device_id.to_owned())
469 .or_default()
470 .push(serde_json::from_value(event).unwrap());
471 }
472 }
473
474 ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
475 })
476 .mount_as_scoped(self.server())
477 .await
478 }
479
480 pub async fn sync_back_pending_to_device_messages(
486 &self,
487 to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
488 recipient: &Client,
489 ) {
490 let messages_to_sync = {
491 let to_device_queue = to_device_queue.lock().unwrap();
492 let pending_messages = to_device_queue
493 .get(&recipient.user_id().unwrap().to_owned())
494 .and_then(|treemap| treemap.get(&recipient.device_id().unwrap().to_owned()));
495
496 pending_messages.cloned().unwrap_or_default()
497 };
498
499 for message in messages_to_sync {
500 self.mock_sync()
501 .ok_and_run(recipient, |sync_builder| {
502 sync_builder.add_to_device_event(message.deserialize_as().unwrap());
503 })
504 .await;
505 }
506 }
507}
508
509fn mock_keys_query(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
514 move |req| {
515 #[derive(Debug, serde::Deserialize)]
516 struct Parameters {
517 device_keys: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
518 }
519
520 let params: Parameters = req.body_json().unwrap();
521
522 let keys = keys.lock().unwrap();
523 let mut device_keys = keys.device.clone();
524 if !params.device_keys.is_empty() {
525 device_keys.retain(|user, key_map| {
526 if let Some(devices) = params.device_keys.get(user) {
527 if !devices.is_empty() {
528 key_map.retain(|key_id, _json| {
529 devices.iter().any(|device_id| &device_id.to_string() == key_id)
530 });
531 }
532 true
533 } else {
534 false
535 }
536 })
537 }
538
539 let master_keys = keys.master.clone();
540 let self_signing_keys = keys.self_signing.clone();
541 let user_signing_keys = keys.user_signing.clone();
542
543 ResponseTemplate::new(200).set_body_json(json!({
544 "device_keys": device_keys,
545 "master_keys": master_keys,
546 "self_signing_keys": self_signing_keys,
547 "user_signing_keys": user_signing_keys,
548 }))
549 }
550}
551
552fn mock_keys_upload(
558 keys: Arc<Mutex<Keys>>,
559 token_to_user_id_map: Arc<Mutex<BTreeMap<String, OwnedUserId>>>,
560) -> impl Fn(&Request) -> ResponseTemplate {
561 move |req: &Request| {
562 #[derive(Debug, serde::Deserialize)]
563 struct Parameters {
564 device_keys: Option<Raw<DeviceKeys>>,
565 one_time_keys: Option<BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
566 }
567 let bearer_token = req
568 .headers
569 .get(http::header::AUTHORIZATION)
570 .and_then(|header| header.to_str().ok())
571 .expect("This call should be authenticated");
572
573 let params: Parameters = req.body_json().unwrap();
574
575 let tokens = token_to_user_id_map.lock().unwrap();
576 let user_id = tokens.get(bearer_token)
578 .expect("Expect this token to be known, ensure you use `MatrixKeysServer::client_builder_for_crypto_end_to_end`")
579 .to_owned();
580
581 if let Some(new_device_keys) = params.device_keys {
582 let new_device_keys = new_device_keys.deserialize().unwrap();
583
584 let key_id = new_device_keys.device_id.to_string();
585 let mut keys = keys.lock().unwrap();
587 let devices = keys.device.entry(new_device_keys.user_id.clone()).or_default();
588
589 if let Some(device_keys) = devices.get_mut(&key_id) {
591 let mut existing = device_keys.deserialize().unwrap();
592
593 for (uid, sigs) in existing.signatures.iter_mut() {
595 if let Some(new_sigs) = new_device_keys.signatures.get(uid) {
596 sigs.extend(new_sigs.clone());
597 }
598 }
599 for (uid, sigs) in new_device_keys.signatures.iter() {
600 if !existing.signatures.contains_key(uid) {
601 existing.signatures.insert(uid.clone(), sigs.clone());
602 }
603 }
604
605 *device_keys = Raw::new(&existing).unwrap();
606 } else {
607 devices.insert(key_id, Raw::new(&new_device_keys).unwrap());
608 }
609 }
610
611 let mut keys = keys.lock().unwrap();
612
613 if let Some(otks) = params.one_time_keys {
614 for (key_id, raw_otk) in otks {
618 let otk = raw_otk.deserialize().unwrap();
619 match otk {
620 OneTimeKey::SignedKey(signed_key) => {
621 let device_id = signed_key
622 .signatures
623 .first_key_value()
624 .unwrap()
625 .1
626 .keys()
627 .next()
628 .unwrap()
629 .key_name()
630 .to_owned();
631
632 keys.one_time_keys
633 .entry(user_id.clone())
634 .or_default()
635 .entry(device_id)
636 .or_default()
637 .insert(key_id, raw_otk);
638 }
639 OneTimeKey::Key(_) => {
640 }
642 _ => {}
643 }
644 }
645 }
646
647 let otk_count = keys.one_time_keys.get(&user_id).map(|m| m.len()).unwrap_or(0);
648 ResponseTemplate::new(200).set_body_json(json!({
649 "one_time_key_counts": {
650 "signed_curve25519": otk_count,
651 }
652 }))
653 }
654}
655
656fn mock_keys_device_signing_upload(
664 keys: Arc<Mutex<Keys>>,
665) -> impl Fn(&Request) -> ResponseTemplate {
666 move |req: &Request| {
667 #[derive(Debug, serde::Deserialize)]
669 struct Parameters {
670 master_key: Option<Raw<CrossSigningKey>>,
671 self_signing_key: Option<Raw<CrossSigningKey>>,
672 user_signing_key: Option<Raw<CrossSigningKey>>,
673 }
674
675 let params: Parameters = req.body_json().unwrap();
676 assert!(params.master_key.is_some());
677 assert!(params.self_signing_key.is_some());
678 assert!(params.user_signing_key.is_some());
679
680 let mut keys = keys.lock().unwrap();
681
682 if let Some(key) = params.master_key {
683 let deserialized = key.deserialize().unwrap();
684 let user_id = deserialized.user_id;
685 keys.master.insert(user_id, key);
686 }
687
688 if let Some(key) = params.self_signing_key {
689 let deserialized = key.deserialize().unwrap();
690 let user_id = deserialized.user_id;
691 keys.self_signing.insert(user_id, key);
692 }
693
694 if let Some(key) = params.user_signing_key {
695 let deserialized = key.deserialize().unwrap();
696 let user_id = deserialized.user_id;
697 keys.user_signing.insert(user_id, key);
698 }
699
700 ResponseTemplate::new(200).set_body_json(json!({}))
701 }
702}
703
704fn mock_keys_signature_upload(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
708 move |req: &Request| {
709 #[derive(Debug, serde::Deserialize)]
710 #[serde(transparent)]
711 struct Parameters(BTreeMap<OwnedUserId, SignedKeys>);
712
713 let params: Parameters = req.body_json().unwrap();
714
715 let mut keys = keys.lock().unwrap();
716
717 for (user, signed_keys) in params.0 {
718 for (key_id, raw_key) in signed_keys.iter() {
719 if let Some(existing_master_key) = keys.master.get_mut(&user) {
721 let mut existing = existing_master_key.deserialize().unwrap();
722
723 let target = CrossSigningKeyId::from_parts(
724 ruma::SigningKeyAlgorithm::Ed25519,
725 key_id.try_into().unwrap(),
726 );
727
728 if existing.keys.contains_key(&target) {
729 let param: CrossSigningKey = serde_json::from_str(raw_key.get()).unwrap();
730
731 for (uid, sigs) in existing.signatures.iter_mut() {
732 if let Some(new_sigs) = param.signatures.get(uid) {
733 sigs.extend(new_sigs.clone());
734 }
735 }
736 for (uid, sigs) in param.signatures.iter() {
737 if !existing.signatures.contains_key(uid) {
738 existing.signatures.insert(uid.clone(), sigs.clone());
739 }
740 }
741
742 *existing_master_key = Raw::new(&existing).unwrap();
744 continue;
745 }
746 }
747
748 let known_devices = keys.device.entry(user.clone()).or_default();
752 let device_keys = known_devices
753 .get_mut(key_id)
754 .expect("trying to add a signature for a missing key");
755
756 let param: DeviceKeys = serde_json::from_str(raw_key.get()).unwrap();
757
758 let mut existing: DeviceKeys = device_keys.deserialize().unwrap();
759
760 for (uid, sigs) in existing.signatures.iter_mut() {
761 if let Some(new_sigs) = param.signatures.get(uid) {
762 sigs.extend(new_sigs.clone());
763 }
764 }
765 for (uid, sigs) in param.signatures.iter() {
766 if !existing.signatures.contains_key(uid) {
767 existing.signatures.insert(uid.clone(), sigs.clone());
768 }
769 }
770
771 *device_keys = Raw::new(&existing).unwrap();
772 }
773 }
774
775 ResponseTemplate::new(200).set_body_json(json!({
776 "failures": {}
777 }))
778 }
779}
780
781fn mock_keys_claimed_request(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
782 move |req: &Request| {
783 #[derive(Debug, serde::Deserialize)]
785 struct Parameters {
786 one_time_keys: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>>,
787 }
788
789 let params: Parameters = req.body_json().unwrap();
790
791 let mut keys = keys.lock().unwrap();
792 let known_otks = &mut keys.one_time_keys;
793
794 let mut found_one_time_keys: BTreeMap<
795 OwnedUserId,
796 BTreeMap<OwnedDeviceId, BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
797 > = BTreeMap::new();
798
799 for (user, requested_one_time_keys) in params.one_time_keys {
800 for device_id in requested_one_time_keys.keys() {
801 let device_id = device_id.clone();
802 let found_key = known_otks
803 .entry(user.clone())
804 .or_default()
805 .entry(device_id.clone())
806 .or_default()
807 .pop_first();
808 if let Some((id, raw_otk)) = found_key {
809 found_one_time_keys
810 .entry(user.clone())
811 .or_default()
812 .entry(device_id.clone())
813 .or_default()
814 .insert(id, raw_otk.clone());
815 }
816 }
817 }
818
819 ResponseTemplate::new(200).set_body_json(json!({
820 "one_time_keys" : found_one_time_keys
821 }))
822 }
823}