matrix_sdk/test_utils/mocks/
encryption.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Helpers to mock a server that supports the main crypto API and have a client
16//! automatically connected to that server, for the purpose of integration
17//! tests.
18use std::{
19    collections::BTreeMap,
20    future::Future,
21    sync::{atomic::Ordering, Arc, Mutex},
22};
23
24use assert_matches2::assert_let;
25use matrix_sdk_test::test_json;
26use ruma::{
27    api::client::{
28        keys::upload_signatures::v3::SignedKeys, to_device::send_event_to_device::v3::Messages,
29    },
30    encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
31    events::AnyToDeviceEvent,
32    owned_device_id, owned_user_id,
33    serde::Raw,
34    to_device::DeviceIdOrAllDevices,
35    CrossSigningKeyId, DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OwnedDeviceId,
36    OwnedOneTimeKeyId, OwnedUserId, UserId,
37};
38use serde_json::json;
39use wiremock::{
40    matchers::{method, path_regex},
41    Mock, MockGuard, Request, ResponseTemplate,
42};
43
44use crate::{
45    crypto::types::events::room::encrypted::EncryptedToDeviceEvent,
46    test_utils::{
47        client::MockClientBuilder,
48        mocks::{Keys, MatrixMockServer},
49    },
50    Client,
51};
52
53/// Stores pending to-device messages for each user and device.
54/// To be used with [`MatrixMockServer::capture_put_to_device_traffic`].
55pub type PendingToDeviceMessages =
56    BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Vec<Raw<AnyToDeviceEvent>>>>;
57
58/// Extends the `MatrixMockServer` with useful methods to help mocking
59/// matrix crypto API and perform integration test with encryption.
60///
61/// It implements mock endpoints for the `keys/upload`, will store the uploaded
62/// devices and serves them back for incoming `keys/query`. It is also storing
63/// and claiming one-time-keys, allowing to set up working olm sessions.
64///
65/// Adds some helpers like `exhaust_one_time_keys` that allows to simulate a
66/// client running out of otks. More can be added if needed later.
67///
68/// It works like this:
69/// * Start by creating the mock server like this [`MatrixMockServer::new`].
70/// * Then mock the crypto API endpoints
71///   [`MatrixMockServer::mock_crypto_endpoints_preset`].
72/// * Create your test client using
73///   [`MatrixMockServer::client_builder_for_crypto_end_to_end`], this is
74///   important as it will set up an access token that will allow to know what
75///   client is doing what request.
76///
77/// The [`MatrixMockServer::set_up_alice_and_bob_for_encryption`] will set up
78/// two olm machines aware of each other and ready to communicate.
79impl MatrixMockServer {
80    /// Creates a new [`MockClientBuilder`] configured to use this server and
81    /// suitable for usage of the crypto API end points.
82    /// Will create a specific access token and some mapping to the associated
83    /// user_id.
84    pub fn client_builder_for_crypto_end_to_end(
85        &self,
86        user_id: &UserId,
87        device_id: &DeviceId,
88    ) -> MockClientBuilder {
89        // Create an access token and store the token to user_id mapping
90        let next = self.token_counter.fetch_add(1, Ordering::Relaxed);
91        let access_token = format!("TOKEN_{next}");
92
93        {
94            let mut mappings = self.token_to_user_id_map.lock().unwrap();
95            let auth_string = format!("Bearer {access_token}");
96            mappings.insert(auth_string, user_id.to_owned());
97        }
98
99        MockClientBuilder::new(Some(&self.server.uri())).logged_in_with_token(
100            access_token,
101            user_id.to_owned(),
102            device_id.to_owned(),
103        )
104    }
105
106    /// Makes the server forget about all the one-time-keys for that device.
107    pub fn exhaust_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) {
108        let mut keys = self.keys.lock().unwrap();
109        let known_otks = &mut keys.one_time_keys;
110        known_otks.entry(user_id).or_default().entry(device_id).or_default().clear();
111    }
112
113    /// Ensure that the given clients are aware of each others public
114    /// identities.
115    pub async fn exchange_e2ee_identities(&self, alice: &Client, bob: &Client) {
116        let alice_user_id = alice.user_id().expect("Alice should have a user ID configured");
117        let bob_user_id = bob.user_id().expect("Bob should have a user ID configured");
118
119        // Have Alice track Bob, so she queries his keys later.
120        alice.update_tracked_users_for_testing([bob_user_id]).await;
121
122        // let bob be aware of Alice keys in order to be able to decrypt custom
123        // to-device (the device keys check are deferred for `m.room.key` so this is not
124        // needed for sending room messages for example).
125        bob.update_tracked_users_for_testing([alice_user_id]).await;
126
127        // Have Alice and Bob upload their signed device keys.
128        self.mock_sync().ok_and_run(alice, |_x| {}).await;
129        self.mock_sync().ok_and_run(bob, |_x| {}).await;
130
131        // Run a sync so we do send outgoing requests, including the /keys/query for
132        // getting bob's identity.
133        self.mock_sync().ok_and_run(alice, |_x| {}).await;
134    }
135
136    /// Utility to properly setup two clients. These two clients will know about
137    /// each others (alice will have downloaded bob device keys).
138    pub async fn set_up_alice_and_bob_for_encryption(&self) -> (Client, Client) {
139        let alice_user_id = owned_user_id!("@alice:example.org");
140        let alice_device_id = owned_device_id!("4L1C3");
141
142        let alice = self
143            .client_builder_for_crypto_end_to_end(&alice_user_id, &alice_device_id)
144            .build()
145            .await;
146
147        let bob_user_id = owned_user_id!("@bob:example.org");
148        let bob_device_id = owned_device_id!("B0B0B0B0B");
149        let bob =
150            self.client_builder_for_crypto_end_to_end(&bob_user_id, &bob_device_id).build().await;
151
152        self.exchange_e2ee_identities(&alice, &bob).await;
153
154        (alice, bob)
155    }
156
157    /// Creates a third client for e2e tests.
158    pub async fn set_up_carl_for_encryption(&self, alice: &Client, bob: &Client) -> Client {
159        let carl_user_id = owned_user_id!("@carlg:example.org");
160        let carl_device_id = owned_device_id!("CARL_DEVICE");
161
162        let carl =
163            self.client_builder_for_crypto_end_to_end(&carl_user_id, &carl_device_id).build().await;
164
165        // Let carl upload it's device keys.
166        self.mock_sync().ok_and_run(&carl, |_| {}).await;
167
168        // Have Alice track Carl, so she queries his keys later.
169        alice.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
170
171        // Have Bob track Carl, so she queries his keys later.
172        bob.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
173
174        // Have Alice and Bob upload their signed device keys, and download Carl's keys.
175        {
176            self.mock_sync().ok_and_run(alice, |_| {}).await;
177            self.mock_sync().ok_and_run(bob, |_| {}).await;
178        }
179
180        // Let carl be aware of Alice and Bob keys.
181        carl.update_tracked_users_for_testing([alice.user_id().unwrap(), bob.user_id().unwrap()])
182            .await;
183
184        // A last sync for carl to get the keys.
185        self.mock_sync().ok_and_run(alice, |_| {}).await;
186
187        carl
188    }
189
190    /// Creates a new device and returns a new client for it.
191    /// The new and old clients will be aware of each other.
192    ///
193    /// # Arguments
194    ///
195    /// * `existing_client` - The original client for which a new device will be
196    ///   created
197    /// * `device_id` - The device ID to use for the new client
198    /// * `clients_to_update` - A vector of client references that should be
199    ///   notified about the new device. These clients will receive a device
200    ///   list change notification during their next sync.
201    ///
202    /// # Returns
203    ///
204    /// Returns the newly created client instance configured for the new device.
205    pub async fn set_up_new_device_for_encryption(
206        &self,
207        existing_client: &Client,
208        device_id: &DeviceId,
209        clients_to_update: Vec<&Client>,
210    ) -> Client {
211        let user_id = existing_client.user_id().unwrap().to_owned();
212        let new_device_id = device_id.to_owned();
213
214        let new_client =
215            self.client_builder_for_crypto_end_to_end(&user_id, &new_device_id).build().await;
216
217        // sync the keys
218        self.mock_sync().ok_and_run(&new_client, |_| {}).await;
219
220        // Notify existing device of a change
221        self.mock_sync()
222            .ok_and_run(existing_client, |builder| {
223                builder.add_change_device(&user_id);
224            })
225            .await;
226
227        for client_to_update in clients_to_update {
228            self.mock_sync()
229                .ok_and_run(client_to_update, |builder| {
230                    builder.add_change_device(&user_id);
231                })
232                .await;
233        }
234
235        new_client
236    }
237
238    /// Mock up the various crypto API so that it can serve back keys when
239    /// needed
240    pub async fn mock_crypto_endpoints_preset(&self) {
241        let keys = &self.keys;
242        let token_map = &self.token_to_user_id_map;
243
244        Mock::given(method("POST"))
245            .and(path_regex(r"^/_matrix/client/.*/keys/query"))
246            .respond_with(mock_keys_query(keys.clone()))
247            .mount(&self.server)
248            .await;
249
250        Mock::given(method("POST"))
251            .and(path_regex(r"^/_matrix/client/.*/keys/upload"))
252            .respond_with(mock_keys_upload(keys.clone(), token_map.clone()))
253            .mount(&self.server)
254            .await;
255
256        Mock::given(method("POST"))
257            .and(path_regex(r"^/_matrix/client/.*/keys/device_signing/upload"))
258            .respond_with(mock_keys_device_signing_upload(keys.clone()))
259            .mount(&self.server)
260            .await;
261
262        Mock::given(method("POST"))
263            .and(path_regex(r"^/_matrix/client/.*/keys/signatures/upload"))
264            .respond_with(mock_keys_signature_upload(keys.clone()))
265            .mount(&self.server)
266            .await;
267
268        Mock::given(method("POST"))
269            .and(path_regex(r"^/_matrix/client/.*/keys/claim"))
270            .respond_with(mock_keys_claimed_request(keys.clone()))
271            .mount(&self.server)
272            .await;
273    }
274
275    /// Creates a response handler for mocking encrypted to-device message
276    /// requests.
277    ///
278    /// This function creates a response handler that captures encrypted
279    /// to-device messages sent via the `/sendToDevice` endpoint.
280    ///
281    /// # Arguments
282    ///
283    /// * `sender` - The user ID of the message sender
284    ///
285    /// # Returns
286    ///
287    /// Returns a tuple containing:
288    /// - A `MockGuard` the end-point mock is scoped to this guard
289    /// - A `Future` that resolves to a `Raw<EncryptedToDeviceEvent>>`
290    ///   containing the captured encrypted to-device message.
291    ///
292    /// # Examples
293    ///
294    /// ```rust
295    /// # use ruma::{ device_id,  user_id, serde::Raw};
296    /// # use serde_json::json;
297    ///
298    /// # use matrix_sdk_test::async_test;
299    /// # use matrix_sdk::test_utils::mocks::MatrixMockServer;
300    /// #
301    /// #[async_test]
302    /// async fn test_mock_capture_put_to_device() {
303    ///     let server = MatrixMockServer::new().await;
304    ///     server.mock_crypto_endpoints_preset().await;
305    ///
306    ///     let (alice, bob) = server.set_up_alice_and_bob_for_encryption().await;
307    ///     let bob_user_id = bob.user_id().unwrap();
308    ///     let bob_device_id = bob.device_id().unwrap();
309    ///
310    ///     // From the point of view of Alice, Bob now has a device.
311    ///     let alice_bob_device = alice
312    ///         .encryption()
313    ///         .get_device(bob_user_id, bob_device_id)
314    ///         .await
315    ///         .unwrap()
316    ///         .expect("alice sees bob's device");
317    ///
318    ///     let content_raw = Raw::new(&json!({ /*...*/ })).unwrap().cast();
319    ///
320    ///     // Set up the mock to capture encrypted to-device messages
321    ///     let (guard, captured) =
322    ///         server.mock_capture_put_to_device(alice.user_id().unwrap()).await;
323    ///
324    ///     alice
325    ///         .encryption()
326    ///         .encrypt_and_send_raw_to_device(
327    ///             vec![&alice_bob_device],
328    ///             "call.keys",
329    ///             content_raw,
330    ///         )
331    ///         .await
332    ///         .unwrap();
333    ///
334    ///     // this is the captured event as sent by alice!
335    ///     let sent_event = captured.await;
336    ///     drop(guard);
337    /// }
338    /// ```
339    pub async fn mock_capture_put_to_device(
340        &self,
341        sender_user_id: &UserId,
342    ) -> (MockGuard, impl Future<Output = Raw<EncryptedToDeviceEvent>>) {
343        let (tx, rx) = tokio::sync::oneshot::channel();
344        let tx = Arc::new(Mutex::new(Some(tx)));
345
346        let sender = sender_user_id.to_owned();
347        let guard = Mock::given(method("PUT"))
348            .and(path_regex(r"^/_matrix/client/.*/sendToDevice/m.room.encrypted/.*"))
349            .respond_with(move |req: &Request| {
350                #[derive(Debug, serde::Deserialize)]
351                struct Parameters {
352                    messages: Messages,
353                }
354
355                let params: Parameters = req.body_json().unwrap();
356
357                let (_, device_to_content) = params.messages.first_key_value().unwrap();
358                let content = device_to_content.first_key_value().unwrap().1;
359
360                let event = json!({
361                    "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
362                    "sender": sender,
363                    "type": "m.room.encrypted",
364                    "content": content,
365                });
366                let event: Raw<EncryptedToDeviceEvent> = serde_json::from_value(event).unwrap();
367
368                if let Ok(mut guard) = tx.lock() {
369                    if let Some(tx) = guard.take() {
370                        let _ = tx.send(event);
371                    }
372                }
373
374                ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
375            })
376            // Should be called once
377            .expect(1)
378            .named("send_to_device")
379            .mount_as_scoped(self.server())
380            .await;
381
382        let future =
383            async move { rx.await.expect("Failed to receive captured value - sender was dropped") };
384
385        (guard, future)
386    }
387
388    /// Captures a to-device message when it is sent to the mock server and then
389    /// injects it into the recipient's sync response.
390    ///
391    /// This is a utility function that combines capturing an encrypted
392    /// to-device message and delivering it to the recipient through a sync
393    /// response. It's useful for testing end-to-end encryption scenarios
394    /// where you need to verify message delivery and processing.
395    ///
396    /// # Arguments
397    ///
398    /// * `sender_user_id` - The user ID of the message sender
399    /// * `recipient` - The client that will receive the message through sync
400    ///
401    /// # Returns
402    ///
403    /// Returns a `Future` that will resolve when the captured event has been
404    /// fed back down the recipient sync.
405    pub async fn mock_capture_put_to_device_then_sync_back<'a>(
406        &'a self,
407        sender_user_id: &UserId,
408        recipient: &'a Client,
409    ) -> impl Future<Output = Raw<EncryptedToDeviceEvent>> + 'a {
410        let (guard, sent_event) = self.mock_capture_put_to_device(sender_user_id).await;
411
412        async {
413            let sent_event = sent_event.await;
414            drop(guard);
415            self.mock_sync()
416                .ok_and_run(recipient, |sync_builder| {
417                    sync_builder.add_to_device_event(sent_event.deserialize_as().unwrap());
418                })
419                .await;
420
421            sent_event
422        }
423    }
424
425    /// Utility to capture all the `/toDevice` upload traffic and store it in
426    /// a queue to be later used with
427    /// [`MatrixMockServer::sync_back_pending_to_device_messages`].
428    pub async fn capture_put_to_device_traffic(
429        &self,
430        sender_user_id: &UserId,
431        to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
432    ) -> MockGuard {
433        let sender = sender_user_id.to_owned();
434
435        Mock::given(method("PUT"))
436            .and(path_regex(r"^/_matrix/client/.*/sendToDevice/([^/]+)/.*"))
437            .respond_with(move |req: &Request| {
438                #[derive(Debug, serde::Deserialize)]
439                struct Parameters {
440                    messages: Messages,
441                }
442
443                let params: Parameters = req.body_json().unwrap();
444                let messages = params.messages;
445
446                // Access the captured groups from the path
447                let event_type = req
448                    .url
449                    .path_segments()
450                    .and_then(|segments| segments.rev().nth(1))
451                    .expect("Event type should be captured in the path");
452
453                let mut to_device_queue = to_device_queue.lock().unwrap();
454                for (user_id, device_map) in messages.iter() {
455                    for (device_id, content) in device_map.iter() {
456                        assert_let!(DeviceIdOrAllDevices::DeviceId(device_id) = device_id);
457
458                        let event = json!({
459                            "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
460                            "sender": sender,
461                            "type": event_type.to_owned(),
462                            "content": content,
463                        });
464
465                        to_device_queue
466                            .entry(user_id.to_owned())
467                            .or_default()
468                            .entry(device_id.to_owned())
469                            .or_default()
470                            .push(serde_json::from_value(event).unwrap());
471                    }
472                }
473
474                ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
475            })
476            .mount_as_scoped(self.server())
477            .await
478    }
479
480    /// Sync the pending to-device messages for this client.
481    ///
482    /// To be used in connection with
483    /// [`MatrixMockServer::capture_put_to_device_traffic`] that is
484    /// capturing the traffic.
485    pub async fn sync_back_pending_to_device_messages(
486        &self,
487        to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
488        recipient: &Client,
489    ) {
490        let messages_to_sync = {
491            let to_device_queue = to_device_queue.lock().unwrap();
492            let pending_messages = to_device_queue
493                .get(&recipient.user_id().unwrap().to_owned())
494                .and_then(|treemap| treemap.get(&recipient.device_id().unwrap().to_owned()));
495
496            pending_messages.cloned().unwrap_or_default()
497        };
498
499        for message in messages_to_sync {
500            self.mock_sync()
501                .ok_and_run(recipient, |sync_builder| {
502                    sync_builder.add_to_device_event(message.deserialize_as().unwrap());
503                })
504                .await;
505        }
506    }
507}
508
509/// Intercepts a `/keys/query` request and mock its results as returned by an
510/// actual homeserver.
511///
512/// Supports filtering by user id, or no filters at all.
513fn mock_keys_query(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
514    move |req| {
515        #[derive(Debug, serde::Deserialize)]
516        struct Parameters {
517            device_keys: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
518        }
519
520        let params: Parameters = req.body_json().unwrap();
521
522        let keys = keys.lock().unwrap();
523        let mut device_keys = keys.device.clone();
524        if !params.device_keys.is_empty() {
525            device_keys.retain(|user, key_map| {
526                if let Some(devices) = params.device_keys.get(user) {
527                    if !devices.is_empty() {
528                        key_map.retain(|key_id, _json| {
529                            devices.iter().any(|device_id| &device_id.to_string() == key_id)
530                        });
531                    }
532                    true
533                } else {
534                    false
535                }
536            })
537        }
538
539        let master_keys = keys.master.clone();
540        let self_signing_keys = keys.self_signing.clone();
541        let user_signing_keys = keys.user_signing.clone();
542
543        ResponseTemplate::new(200).set_body_json(json!({
544            "device_keys": device_keys,
545            "master_keys": master_keys,
546            "self_signing_keys": self_signing_keys,
547            "user_signing_keys": user_signing_keys,
548        }))
549    }
550}
551
552/// Intercepts a `/keys/upload` query and mocks the behavior it would have on a
553/// real homeserver.
554///
555/// Inserts all the `DeviceKeys` into `Keys::device_keys`, or if already present
556/// in this mapping, only merge the signatures.
557fn mock_keys_upload(
558    keys: Arc<Mutex<Keys>>,
559    token_to_user_id_map: Arc<Mutex<BTreeMap<String, OwnedUserId>>>,
560) -> impl Fn(&Request) -> ResponseTemplate {
561    move |req: &Request| {
562        #[derive(Debug, serde::Deserialize)]
563        struct Parameters {
564            device_keys: Option<Raw<DeviceKeys>>,
565            one_time_keys: Option<BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
566        }
567        let bearer_token = req
568            .headers
569            .get(http::header::AUTHORIZATION)
570            .and_then(|header| header.to_str().ok())
571            .expect("This call should be authenticated");
572
573        let params: Parameters = req.body_json().unwrap();
574
575        let tokens = token_to_user_id_map.lock().unwrap();
576        // Get the user
577        let user_id = tokens.get(bearer_token)
578            .expect("Expect this token to be known, ensure you use `MatrixKeysServer::client_builder_for_crypto_end_to_end`")
579            .to_owned();
580
581        if let Some(new_device_keys) = params.device_keys {
582            let new_device_keys = new_device_keys.deserialize().unwrap();
583
584            let key_id = new_device_keys.device_id.to_string();
585            // if known_devices.contains(&key_id) {
586            let mut keys = keys.lock().unwrap();
587            let devices = keys.device.entry(new_device_keys.user_id.clone()).or_default();
588
589            // Either merge signatures if an entry is already present, or insert a new one.
590            if let Some(device_keys) = devices.get_mut(&key_id) {
591                let mut existing = device_keys.deserialize().unwrap();
592
593                // Merge signatures.
594                for (uid, sigs) in existing.signatures.iter_mut() {
595                    if let Some(new_sigs) = new_device_keys.signatures.get(uid) {
596                        sigs.extend(new_sigs.clone());
597                    }
598                }
599                for (uid, sigs) in new_device_keys.signatures.iter() {
600                    if !existing.signatures.contains_key(uid) {
601                        existing.signatures.insert(uid.clone(), sigs.clone());
602                    }
603                }
604
605                *device_keys = Raw::new(&existing).unwrap();
606            } else {
607                devices.insert(key_id, Raw::new(&new_device_keys).unwrap());
608            }
609        }
610
611        let mut keys = keys.lock().unwrap();
612
613        if let Some(otks) = params.one_time_keys {
614            // We need a trick to find out what userId|device this OTK is for.
615            // This is not part of the payload, a real server uses the access token(?)
616            // Let's look at the signatures to find out
617            for (key_id, raw_otk) in otks {
618                let otk = raw_otk.deserialize().unwrap();
619                match otk {
620                    OneTimeKey::SignedKey(signed_key) => {
621                        let device_id = signed_key
622                            .signatures
623                            .first_key_value()
624                            .unwrap()
625                            .1
626                            .keys()
627                            .next()
628                            .unwrap()
629                            .key_name()
630                            .to_owned();
631
632                        keys.one_time_keys
633                            .entry(user_id.clone())
634                            .or_default()
635                            .entry(device_id)
636                            .or_default()
637                            .insert(key_id, raw_otk);
638                    }
639                    OneTimeKey::Key(_) => {
640                        // Ignore this old algorithm,
641                    }
642                    _ => {}
643                }
644            }
645        }
646
647        let otk_count = keys.one_time_keys.get(&user_id).map(|m| m.len()).unwrap_or(0);
648        ResponseTemplate::new(200).set_body_json(json!({
649            "one_time_key_counts": {
650                "signed_curve25519": otk_count,
651            }
652        }))
653    }
654}
655
656/// Mocks a `/keys/device_signing/upload` request for bootstrapping
657/// cross-signing.
658///
659/// Assumes (and asserts) all keys are updated at the same time.
660///
661/// Saves all the different cross-signing keys into their respective fields of
662/// `Keys`.
663fn mock_keys_device_signing_upload(
664    keys: Arc<Mutex<Keys>>,
665) -> impl Fn(&Request) -> ResponseTemplate {
666    move |req: &Request| {
667        // Accept all cross-signing setups by default.
668        #[derive(Debug, serde::Deserialize)]
669        struct Parameters {
670            master_key: Option<Raw<CrossSigningKey>>,
671            self_signing_key: Option<Raw<CrossSigningKey>>,
672            user_signing_key: Option<Raw<CrossSigningKey>>,
673        }
674
675        let params: Parameters = req.body_json().unwrap();
676        assert!(params.master_key.is_some());
677        assert!(params.self_signing_key.is_some());
678        assert!(params.user_signing_key.is_some());
679
680        let mut keys = keys.lock().unwrap();
681
682        if let Some(key) = params.master_key {
683            let deserialized = key.deserialize().unwrap();
684            let user_id = deserialized.user_id;
685            keys.master.insert(user_id, key);
686        }
687
688        if let Some(key) = params.self_signing_key {
689            let deserialized = key.deserialize().unwrap();
690            let user_id = deserialized.user_id;
691            keys.self_signing.insert(user_id, key);
692        }
693
694        if let Some(key) = params.user_signing_key {
695            let deserialized = key.deserialize().unwrap();
696            let user_id = deserialized.user_id;
697            keys.user_signing.insert(user_id, key);
698        }
699
700        ResponseTemplate::new(200).set_body_json(json!({}))
701    }
702}
703
704/// Mocks a `/keys/signatures/upload` request.
705///
706/// Supports merging signatures for master keys or devices keys.
707fn mock_keys_signature_upload(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
708    move |req: &Request| {
709        #[derive(Debug, serde::Deserialize)]
710        #[serde(transparent)]
711        struct Parameters(BTreeMap<OwnedUserId, SignedKeys>);
712
713        let params: Parameters = req.body_json().unwrap();
714
715        let mut keys = keys.lock().unwrap();
716
717        for (user, signed_keys) in params.0 {
718            for (key_id, raw_key) in signed_keys.iter() {
719                // Try to find a field in keys.master.
720                if let Some(existing_master_key) = keys.master.get_mut(&user) {
721                    let mut existing = existing_master_key.deserialize().unwrap();
722
723                    let target = CrossSigningKeyId::from_parts(
724                        ruma::SigningKeyAlgorithm::Ed25519,
725                        key_id.try_into().unwrap(),
726                    );
727
728                    if existing.keys.contains_key(&target) {
729                        let param: CrossSigningKey = serde_json::from_str(raw_key.get()).unwrap();
730
731                        for (uid, sigs) in existing.signatures.iter_mut() {
732                            if let Some(new_sigs) = param.signatures.get(uid) {
733                                sigs.extend(new_sigs.clone());
734                            }
735                        }
736                        for (uid, sigs) in param.signatures.iter() {
737                            if !existing.signatures.contains_key(uid) {
738                                existing.signatures.insert(uid.clone(), sigs.clone());
739                            }
740                        }
741
742                        // Update in map.
743                        *existing_master_key = Raw::new(&existing).unwrap();
744                        continue;
745                    }
746                }
747
748                // Otherwise, try to find a field in keys.device.
749                // Either merge signatures if an entry is already present, or insert a new
750                // entry.
751                let known_devices = keys.device.entry(user.clone()).or_default();
752                let device_keys = known_devices
753                    .get_mut(key_id)
754                    .expect("trying to add a signature for a missing key");
755
756                let param: DeviceKeys = serde_json::from_str(raw_key.get()).unwrap();
757
758                let mut existing: DeviceKeys = device_keys.deserialize().unwrap();
759
760                for (uid, sigs) in existing.signatures.iter_mut() {
761                    if let Some(new_sigs) = param.signatures.get(uid) {
762                        sigs.extend(new_sigs.clone());
763                    }
764                }
765                for (uid, sigs) in param.signatures.iter() {
766                    if !existing.signatures.contains_key(uid) {
767                        existing.signatures.insert(uid.clone(), sigs.clone());
768                    }
769                }
770
771                *device_keys = Raw::new(&existing).unwrap();
772            }
773        }
774
775        ResponseTemplate::new(200).set_body_json(json!({
776            "failures": {}
777        }))
778    }
779}
780
781fn mock_keys_claimed_request(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
782    move |req: &Request| {
783        // Accept all cross-signing setups by default.
784        #[derive(Debug, serde::Deserialize)]
785        struct Parameters {
786            one_time_keys: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>>,
787        }
788
789        let params: Parameters = req.body_json().unwrap();
790
791        let mut keys = keys.lock().unwrap();
792        let known_otks = &mut keys.one_time_keys;
793
794        let mut found_one_time_keys: BTreeMap<
795            OwnedUserId,
796            BTreeMap<OwnedDeviceId, BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
797        > = BTreeMap::new();
798
799        for (user, requested_one_time_keys) in params.one_time_keys {
800            for device_id in requested_one_time_keys.keys() {
801                let device_id = device_id.clone();
802                let found_key = known_otks
803                    .entry(user.clone())
804                    .or_default()
805                    .entry(device_id.clone())
806                    .or_default()
807                    .pop_first();
808                if let Some((id, raw_otk)) = found_key {
809                    found_one_time_keys
810                        .entry(user.clone())
811                        .or_default()
812                        .entry(device_id.clone())
813                        .or_default()
814                        .insert(id, raw_otk.clone());
815                }
816            }
817        }
818
819        ResponseTemplate::new(200).set_body_json(json!({
820            "one_time_keys" : found_one_time_keys
821        }))
822    }
823}