matrix_sdk/test_utils/mocks/
encryption.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Helpers to mock a server that supports the main crypto API and have a client
16//! automatically connected to that server, for the purpose of integration
17//! tests.
18use std::{
19    collections::BTreeMap,
20    future::Future,
21    sync::{Arc, Mutex, atomic::Ordering},
22};
23
24use assert_matches2::assert_let;
25use matrix_sdk_base::crypto::types::events::room::encrypted::EncryptedToDeviceEvent;
26use matrix_sdk_test::test_json;
27use ruma::{
28    CrossSigningKeyId, DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OwnedDeviceId,
29    OwnedOneTimeKeyId, OwnedUserId, UserId,
30    api::client::{
31        keys::upload_signatures::v3::SignedKeys, to_device::send_event_to_device::v3::Messages,
32    },
33    encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
34    events::AnyToDeviceEvent,
35    owned_device_id, owned_user_id,
36    serde::Raw,
37    to_device::DeviceIdOrAllDevices,
38};
39use serde_json::json;
40use tracing::Instrument;
41use wiremock::{
42    Mock, MockGuard, Request, ResponseTemplate,
43    matchers::{method, path_regex},
44};
45
46use crate::{
47    Client,
48    test_utils::{
49        client::MockClientBuilder,
50        mocks::{Keys, MatrixMockServer},
51    },
52};
53
54/// Stores pending to-device messages for each user and device.
55/// To be used with [`MatrixMockServer::capture_put_to_device_traffic`].
56pub type PendingToDeviceMessages =
57    BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Vec<Raw<AnyToDeviceEvent>>>>;
58
59/// Extends the `MatrixMockServer` with useful methods to help mocking
60/// matrix crypto API and perform integration test with encryption.
61///
62/// It implements mock endpoints for the `keys/upload`, will store the uploaded
63/// devices and serves them back for incoming `keys/query`. It is also storing
64/// and claiming one-time-keys, allowing to set up working olm sessions.
65///
66/// Adds some helpers like `exhaust_one_time_keys` that allows to simulate a
67/// client running out of otks. More can be added if needed later.
68///
69/// It works like this:
70/// * Start by creating the mock server like this [`MatrixMockServer::new`].
71/// * Then mock the crypto API endpoints
72///   [`MatrixMockServer::mock_crypto_endpoints_preset`].
73/// * Create your test client using
74///   [`MatrixMockServer::client_builder_for_crypto_end_to_end`], this is
75///   important as it will set up an access token that will allow to know what
76///   client is doing what request.
77///
78/// The [`MatrixMockServer::set_up_alice_and_bob_for_encryption`] will set up
79/// two olm machines aware of each other and ready to communicate.
80impl MatrixMockServer {
81    /// Creates a new [`MockClientBuilder`] configured to use this server and
82    /// suitable for usage of the crypto API end points.
83    /// Will create a specific access token and some mapping to the associated
84    /// user_id.
85    pub fn client_builder_for_crypto_end_to_end(
86        &self,
87        user_id: &UserId,
88        device_id: &DeviceId,
89    ) -> MockClientBuilder {
90        // Create an access token and store the token to user_id mapping
91        let next = self.token_counter.fetch_add(1, Ordering::Relaxed);
92        let access_token = format!("TOKEN_{next}");
93
94        {
95            let mut mappings = self.token_to_user_id_map.lock().unwrap();
96            let auth_string = format!("Bearer {access_token}");
97            mappings.insert(auth_string, user_id.to_owned());
98        }
99
100        MockClientBuilder::new(Some(&self.server.uri())).logged_in_with_token(
101            access_token,
102            user_id.to_owned(),
103            device_id.to_owned(),
104        )
105    }
106
107    /// Makes the server forget about all the one-time-keys for that device.
108    pub fn exhaust_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) {
109        let mut keys = self.keys.lock().unwrap();
110        let known_otks = &mut keys.one_time_keys;
111        known_otks.entry(user_id).or_default().entry(device_id).or_default().clear();
112    }
113
114    /// Ensure that the given clients are aware of each others public
115    /// identities.
116    pub async fn exchange_e2ee_identities(&self, alice: &Client, bob: &Client) {
117        let alice_user_id = alice.user_id().expect("Alice should have a user ID configured");
118        let bob_user_id = bob.user_id().expect("Bob should have a user ID configured");
119
120        let alice_span = tracing::info_span!("alice", user_id=%alice_user_id);
121        let bob_span = tracing::info_span!("bob", user_id=%bob_user_id);
122
123        // Have Alice track Bob, so she queries his keys later.
124        alice.update_tracked_users_for_testing([bob_user_id]).instrument(alice_span.clone()).await;
125
126        // let bob be aware of Alice keys in order to be able to decrypt custom
127        // to-device (the device keys check are deferred for `m.room.key` so this is not
128        // needed for sending room messages for example).
129        bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
130
131        // Have Alice and Bob upload their signed device keys.
132        self.mock_sync().ok_and_run(alice, |_x| {}).instrument(alice_span.clone()).await;
133        self.mock_sync().ok_and_run(bob, |_x| {}).instrument(bob_span).await;
134
135        // Run a sync so we do send outgoing requests, including the /keys/query for
136        // getting bob's identity.
137        self.mock_sync().ok_and_run(alice, |_x| {}).instrument(alice_span).await;
138    }
139
140    /// Utility to properly setup two clients. These two clients will know about
141    /// each others (alice will have downloaded bob device keys).
142    pub async fn set_up_alice_and_bob_for_encryption(&self) -> (Client, Client) {
143        let alice_user_id = owned_user_id!("@alice:example.org");
144        let alice_device_id = owned_device_id!("4L1C3");
145
146        let alice = self
147            .client_builder_for_crypto_end_to_end(&alice_user_id, &alice_device_id)
148            .build()
149            .await;
150
151        let bob_user_id = owned_user_id!("@bob:example.org");
152        let bob_device_id = owned_device_id!("B0B0B0B0B");
153        let bob =
154            self.client_builder_for_crypto_end_to_end(&bob_user_id, &bob_device_id).build().await;
155
156        self.exchange_e2ee_identities(&alice, &bob).await;
157
158        (alice, bob)
159    }
160
161    /// Creates a third client for e2e tests.
162    pub async fn set_up_carl_for_encryption(&self, alice: &Client, bob: &Client) -> Client {
163        let carl_user_id = owned_user_id!("@carlg:example.org");
164        let carl_device_id = owned_device_id!("CARL_DEVICE");
165
166        let carl =
167            self.client_builder_for_crypto_end_to_end(&carl_user_id, &carl_device_id).build().await;
168
169        // Let carl upload it's device keys.
170        self.mock_sync().ok_and_run(&carl, |_| {}).await;
171
172        // Have Alice track Carl, so she queries his keys later.
173        alice.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
174
175        // Have Bob track Carl, so she queries his keys later.
176        bob.update_tracked_users_for_testing([carl.user_id().unwrap()]).await;
177
178        // Have Alice and Bob upload their signed device keys, and download Carl's keys.
179        {
180            self.mock_sync().ok_and_run(alice, |_| {}).await;
181            self.mock_sync().ok_and_run(bob, |_| {}).await;
182        }
183
184        // Let carl be aware of Alice and Bob keys.
185        carl.update_tracked_users_for_testing([alice.user_id().unwrap(), bob.user_id().unwrap()])
186            .await;
187
188        // A last sync for carl to get the keys.
189        self.mock_sync().ok_and_run(alice, |_| {}).await;
190
191        carl
192    }
193
194    /// Creates a new device and returns a new client for it.
195    /// The new and old clients will be aware of each other.
196    ///
197    /// # Arguments
198    ///
199    /// * `existing_client` - The original client for which a new device will be
200    ///   created
201    /// * `device_id` - The device ID to use for the new client
202    /// * `clients_to_update` - A vector of client references that should be
203    ///   notified about the new device. These clients will receive a device
204    ///   list change notification during their next sync.
205    ///
206    /// # Returns
207    ///
208    /// Returns the newly created client instance configured for the new device.
209    pub async fn set_up_new_device_for_encryption(
210        &self,
211        existing_client: &Client,
212        device_id: &DeviceId,
213        clients_to_update: Vec<&Client>,
214    ) -> Client {
215        let user_id = existing_client.user_id().unwrap().to_owned();
216        let new_device_id = device_id.to_owned();
217
218        let new_client =
219            self.client_builder_for_crypto_end_to_end(&user_id, &new_device_id).build().await;
220
221        // sync the keys
222        self.mock_sync().ok_and_run(&new_client, |_| {}).await;
223
224        // Notify existing device of a change
225        self.mock_sync()
226            .ok_and_run(existing_client, |builder| {
227                builder.add_change_device(&user_id);
228            })
229            .await;
230
231        for client_to_update in clients_to_update {
232            self.mock_sync()
233                .ok_and_run(client_to_update, |builder| {
234                    builder.add_change_device(&user_id);
235                })
236                .await;
237        }
238
239        new_client
240    }
241
242    /// Mock up the various crypto API so that it can serve back keys when
243    /// needed
244    pub async fn mock_crypto_endpoints_preset(&self) {
245        let keys = &self.keys;
246        let token_map = &self.token_to_user_id_map;
247
248        Mock::given(method("POST"))
249            .and(path_regex(r"^/_matrix/client/.*/keys/query"))
250            .respond_with(mock_keys_query(keys.clone()))
251            .mount(&self.server)
252            .await;
253
254        Mock::given(method("POST"))
255            .and(path_regex(r"^/_matrix/client/.*/keys/upload"))
256            .respond_with(mock_keys_upload(keys.clone(), token_map.clone()))
257            .mount(&self.server)
258            .await;
259
260        Mock::given(method("POST"))
261            .and(path_regex(r"^/_matrix/client/.*/keys/device_signing/upload"))
262            .respond_with(mock_keys_device_signing_upload(keys.clone()))
263            .mount(&self.server)
264            .await;
265
266        Mock::given(method("POST"))
267            .and(path_regex(r"^/_matrix/client/.*/keys/signatures/upload"))
268            .respond_with(mock_keys_signature_upload(keys.clone()))
269            .mount(&self.server)
270            .await;
271
272        Mock::given(method("POST"))
273            .and(path_regex(r"^/_matrix/client/.*/keys/claim"))
274            .respond_with(mock_keys_claimed_request(keys.clone()))
275            .mount(&self.server)
276            .await;
277    }
278
279    /// Creates a response handler for mocking encrypted to-device message
280    /// requests.
281    ///
282    /// This function creates a response handler that captures encrypted
283    /// to-device messages sent via the `/sendToDevice` endpoint.
284    ///
285    /// # Arguments
286    ///
287    /// * `sender` - The user ID of the message sender
288    ///
289    /// # Returns
290    ///
291    /// Returns a tuple containing:
292    /// - A `MockGuard` the end-point mock is scoped to this guard
293    /// - A `Future` that resolves to a `Raw<EncryptedToDeviceEvent>>`
294    ///   containing the captured encrypted to-device message.
295    ///
296    /// # Examples
297    ///
298    /// ```rust
299    /// # use ruma::{ device_id,  user_id, serde::Raw};
300    /// # use serde_json::json;
301    ///
302    /// # use matrix_sdk_test::async_test;
303    /// # use matrix_sdk::test_utils::mocks::MatrixMockServer;
304    /// #
305    /// #[async_test]
306    /// async fn test_mock_capture_put_to_device() {
307    ///     let server = MatrixMockServer::new().await;
308    ///     server.mock_crypto_endpoints_preset().await;
309    ///
310    ///     let (alice, bob) = server.set_up_alice_and_bob_for_encryption().await;
311    ///     let bob_user_id = bob.user_id().unwrap();
312    ///     let bob_device_id = bob.device_id().unwrap();
313    ///
314    ///     // From the point of view of Alice, Bob now has a device.
315    ///     let alice_bob_device = alice
316    ///         .encryption()
317    ///         .get_device(bob_user_id, bob_device_id)
318    ///         .await
319    ///         .unwrap()
320    ///         .expect("alice sees bob's device");
321    ///
322    ///     let content_raw = Raw::new(&json!({ /*...*/ })).unwrap().cast();
323    ///
324    ///     // Set up the mock to capture encrypted to-device messages
325    ///     let (guard, captured) =
326    ///         server.mock_capture_put_to_device(alice.user_id().unwrap()).await;
327    ///
328    ///     alice
329    ///         .encryption()
330    ///         .encrypt_and_send_raw_to_device(
331    ///             vec![&alice_bob_device],
332    ///             "call.keys",
333    ///             content_raw,
334    ///         )
335    ///         .await
336    ///         .unwrap();
337    ///
338    ///     // this is the captured event as sent by alice!
339    ///     let sent_event = captured.await;
340    ///     drop(guard);
341    /// }
342    /// ```
343    pub async fn mock_capture_put_to_device(
344        &self,
345        sender_user_id: &UserId,
346    ) -> (MockGuard, impl Future<Output = Raw<EncryptedToDeviceEvent>> + use<>) {
347        let (tx, rx) = tokio::sync::oneshot::channel();
348        let tx = Arc::new(Mutex::new(Some(tx)));
349
350        let sender = sender_user_id.to_owned();
351        let guard = Mock::given(method("PUT"))
352            .and(path_regex(r"^/_matrix/client/.*/sendToDevice/m.room.encrypted/.*"))
353            .respond_with(move |req: &Request| {
354                #[derive(Debug, serde::Deserialize)]
355                struct Parameters {
356                    messages: Messages,
357                }
358
359                let params: Parameters = req.body_json().unwrap();
360
361                let (_, device_to_content) = params.messages.first_key_value().unwrap();
362                let content = device_to_content.first_key_value().unwrap().1;
363
364                let event = json!({
365                    "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
366                    "sender": sender,
367                    "type": "m.room.encrypted",
368                    "content": content,
369                });
370                let event: Raw<EncryptedToDeviceEvent> = serde_json::from_value(event).unwrap();
371
372                if let Ok(mut guard) = tx.lock()
373                    && let Some(tx) = guard.take()
374                {
375                    let _ = tx.send(event);
376                }
377
378                ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
379            })
380            // Should be called once
381            .expect(1)
382            .named("send_to_device")
383            .mount_as_scoped(self.server())
384            .await;
385
386        let future =
387            async move { rx.await.expect("Failed to receive captured value - sender was dropped") };
388
389        (guard, future)
390    }
391
392    /// Captures a to-device message when it is sent to the mock server and then
393    /// injects it into the recipient's sync response.
394    ///
395    /// This is a utility function that combines capturing an encrypted
396    /// to-device message and delivering it to the recipient through a sync
397    /// response. It's useful for testing end-to-end encryption scenarios
398    /// where you need to verify message delivery and processing.
399    ///
400    /// # Arguments
401    ///
402    /// * `sender_user_id` - The user ID of the message sender
403    /// * `recipient` - The client that will receive the message through sync
404    ///
405    /// # Returns
406    ///
407    /// Returns a `Future` that will resolve when the captured event has been
408    /// fed back down the recipient sync.
409    pub async fn mock_capture_put_to_device_then_sync_back<'a>(
410        &'a self,
411        sender_user_id: &UserId,
412        recipient: &'a Client,
413    ) -> impl Future<Output = Raw<EncryptedToDeviceEvent>> + 'a {
414        let (guard, sent_event) = self.mock_capture_put_to_device(sender_user_id).await;
415
416        async {
417            let sent_event = sent_event.await;
418            drop(guard);
419            self.mock_sync()
420                .ok_and_run(recipient, |sync_builder| {
421                    sync_builder.add_to_device_event(sent_event.deserialize_as().unwrap());
422                })
423                .await;
424
425            sent_event
426        }
427    }
428
429    /// Utility to capture all the `/toDevice` upload traffic and store it in
430    /// a queue to be later used with
431    /// [`MatrixMockServer::sync_back_pending_to_device_messages`].
432    pub async fn capture_put_to_device_traffic(
433        &self,
434        sender_user_id: &UserId,
435        to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
436    ) -> MockGuard {
437        let sender = sender_user_id.to_owned();
438
439        Mock::given(method("PUT"))
440            .and(path_regex(r"^/_matrix/client/.*/sendToDevice/([^/]+)/.*"))
441            .respond_with(move |req: &Request| {
442                #[derive(Debug, serde::Deserialize)]
443                struct Parameters {
444                    messages: Messages,
445                }
446
447                let params: Parameters = req.body_json().unwrap();
448                let messages = params.messages;
449
450                // Access the captured groups from the path
451                let event_type = req
452                    .url
453                    .path_segments()
454                    .and_then(|segments| segments.rev().nth(1))
455                    .expect("Event type should be captured in the path");
456
457                let mut to_device_queue = to_device_queue.lock().unwrap();
458                for (user_id, device_map) in messages.iter() {
459                    for (device_id, content) in device_map.iter() {
460                        assert_let!(DeviceIdOrAllDevices::DeviceId(device_id) = device_id);
461
462                        let event = json!({
463                            "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
464                            "sender": sender,
465                            "type": event_type.to_owned(),
466                            "content": content,
467                        });
468
469                        to_device_queue
470                            .entry(user_id.to_owned())
471                            .or_default()
472                            .entry(device_id.to_owned())
473                            .or_default()
474                            .push(serde_json::from_value(event).unwrap());
475                    }
476                }
477
478                ResponseTemplate::new(200).set_body_json(&*test_json::EMPTY)
479            })
480            .mount_as_scoped(self.server())
481            .await
482    }
483
484    /// Sync the pending to-device messages for this client.
485    ///
486    /// To be used in connection with
487    /// [`MatrixMockServer::capture_put_to_device_traffic`] that is
488    /// capturing the traffic.
489    pub async fn sync_back_pending_to_device_messages(
490        &self,
491        to_device_queue: Arc<Mutex<PendingToDeviceMessages>>,
492        recipient: &Client,
493    ) {
494        let messages_to_sync = {
495            let to_device_queue = to_device_queue.lock().unwrap();
496            let pending_messages = to_device_queue
497                .get(&recipient.user_id().unwrap().to_owned())
498                .and_then(|treemap| treemap.get(&recipient.device_id().unwrap().to_owned()));
499
500            pending_messages.cloned().unwrap_or_default()
501        };
502
503        for message in messages_to_sync {
504            self.mock_sync()
505                .ok_and_run(recipient, |sync_builder| {
506                    sync_builder.add_to_device_event(message.deserialize_as().unwrap());
507                })
508                .await;
509        }
510    }
511}
512
513/// Intercepts a `/keys/query` request and mock its results as returned by an
514/// actual homeserver.
515///
516/// Supports filtering by user id, or no filters at all.
517fn mock_keys_query(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
518    move |req| {
519        #[derive(Debug, serde::Deserialize)]
520        struct Parameters {
521            device_keys: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
522        }
523
524        let params: Parameters = req.body_json().unwrap();
525
526        let keys = keys.lock().unwrap();
527        let mut device_keys = keys.device.clone();
528        if !params.device_keys.is_empty() {
529            device_keys.retain(|user, key_map| {
530                if let Some(devices) = params.device_keys.get(user) {
531                    if !devices.is_empty() {
532                        key_map.retain(|key_id, _json| {
533                            devices.iter().any(|device_id| &device_id.to_string() == key_id)
534                        });
535                    }
536                    true
537                } else {
538                    false
539                }
540            })
541        }
542
543        let master_keys = keys.master.clone();
544        let self_signing_keys = keys.self_signing.clone();
545        let user_signing_keys = keys.user_signing.clone();
546
547        ResponseTemplate::new(200).set_body_json(json!({
548            "device_keys": device_keys,
549            "master_keys": master_keys,
550            "self_signing_keys": self_signing_keys,
551            "user_signing_keys": user_signing_keys,
552        }))
553    }
554}
555
556/// Intercepts a `/keys/upload` query and mocks the behavior it would have on a
557/// real homeserver.
558///
559/// Inserts all the `DeviceKeys` into `Keys::device_keys`, or if already present
560/// in this mapping, only merge the signatures.
561fn mock_keys_upload(
562    keys: Arc<Mutex<Keys>>,
563    token_to_user_id_map: Arc<Mutex<BTreeMap<String, OwnedUserId>>>,
564) -> impl Fn(&Request) -> ResponseTemplate {
565    move |req: &Request| {
566        #[derive(Debug, serde::Deserialize)]
567        struct Parameters {
568            device_keys: Option<Raw<DeviceKeys>>,
569            one_time_keys: Option<BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
570        }
571        let bearer_token = req
572            .headers
573            .get(http::header::AUTHORIZATION)
574            .and_then(|header| header.to_str().ok())
575            .expect("This call should be authenticated");
576
577        let params: Parameters = req.body_json().unwrap();
578
579        let tokens = token_to_user_id_map.lock().unwrap();
580        // Get the user
581        let user_id = tokens.get(bearer_token)
582            .expect("Expect this token to be known, ensure you use `MatrixKeysServer::client_builder_for_crypto_end_to_end`")
583            .to_owned();
584
585        if let Some(new_device_keys) = params.device_keys {
586            let new_device_keys = new_device_keys.deserialize().unwrap();
587
588            let key_id = new_device_keys.device_id.to_string();
589            // if known_devices.contains(&key_id) {
590            let mut keys = keys.lock().unwrap();
591            let devices = keys.device.entry(new_device_keys.user_id.clone()).or_default();
592
593            // Either merge signatures if an entry is already present, or insert a new one.
594            if let Some(device_keys) = devices.get_mut(&key_id) {
595                let mut existing = device_keys.deserialize().unwrap();
596
597                // Merge signatures.
598                for (uid, sigs) in existing.signatures.iter_mut() {
599                    if let Some(new_sigs) = new_device_keys.signatures.get(uid) {
600                        sigs.extend(new_sigs.clone());
601                    }
602                }
603                for (uid, sigs) in new_device_keys.signatures.iter() {
604                    if !existing.signatures.contains_key(uid) {
605                        existing.signatures.insert(uid.clone(), sigs.clone());
606                    }
607                }
608
609                *device_keys = Raw::new(&existing).unwrap();
610            } else {
611                devices.insert(key_id, Raw::new(&new_device_keys).unwrap());
612            }
613        }
614
615        let mut keys = keys.lock().unwrap();
616
617        if let Some(otks) = params.one_time_keys {
618            // We need a trick to find out what userId|device this OTK is for.
619            // This is not part of the payload, a real server uses the access token(?)
620            // Let's look at the signatures to find out
621            for (key_id, raw_otk) in otks {
622                let otk = raw_otk.deserialize().unwrap();
623                match otk {
624                    OneTimeKey::SignedKey(signed_key) => {
625                        let device_id = signed_key
626                            .signatures
627                            .first_key_value()
628                            .unwrap()
629                            .1
630                            .keys()
631                            .next()
632                            .unwrap()
633                            .key_name()
634                            .to_owned();
635
636                        keys.one_time_keys
637                            .entry(user_id.clone())
638                            .or_default()
639                            .entry(device_id)
640                            .or_default()
641                            .insert(key_id, raw_otk);
642                    }
643                    OneTimeKey::Key(_) => {
644                        // Ignore this old algorithm,
645                    }
646                    _ => {}
647                }
648            }
649        }
650
651        let otk_count = keys.one_time_keys.get(&user_id).map(|m| m.len()).unwrap_or(0);
652        ResponseTemplate::new(200).set_body_json(json!({
653            "one_time_key_counts": {
654                "signed_curve25519": otk_count,
655            }
656        }))
657    }
658}
659
660/// Mocks a `/keys/device_signing/upload` request for bootstrapping
661/// cross-signing.
662///
663/// Assumes (and asserts) all keys are updated at the same time.
664///
665/// Saves all the different cross-signing keys into their respective fields of
666/// `Keys`.
667fn mock_keys_device_signing_upload(
668    keys: Arc<Mutex<Keys>>,
669) -> impl Fn(&Request) -> ResponseTemplate {
670    move |req: &Request| {
671        // Accept all cross-signing setups by default.
672        #[derive(Debug, serde::Deserialize)]
673        struct Parameters {
674            master_key: Option<Raw<CrossSigningKey>>,
675            self_signing_key: Option<Raw<CrossSigningKey>>,
676            user_signing_key: Option<Raw<CrossSigningKey>>,
677        }
678
679        let params: Parameters = req.body_json().unwrap();
680        assert!(params.master_key.is_some());
681        assert!(params.self_signing_key.is_some());
682        assert!(params.user_signing_key.is_some());
683
684        let mut keys = keys.lock().unwrap();
685
686        if let Some(key) = params.master_key {
687            let deserialized = key.deserialize().unwrap();
688            let user_id = deserialized.user_id;
689            keys.master.insert(user_id, key);
690        }
691
692        if let Some(key) = params.self_signing_key {
693            let deserialized = key.deserialize().unwrap();
694            let user_id = deserialized.user_id;
695            keys.self_signing.insert(user_id, key);
696        }
697
698        if let Some(key) = params.user_signing_key {
699            let deserialized = key.deserialize().unwrap();
700            let user_id = deserialized.user_id;
701            keys.user_signing.insert(user_id, key);
702        }
703
704        ResponseTemplate::new(200).set_body_json(json!({}))
705    }
706}
707
708/// Mocks a `/keys/signatures/upload` request.
709///
710/// Supports merging signatures for master keys or devices keys.
711fn mock_keys_signature_upload(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
712    move |req: &Request| {
713        #[derive(Debug, serde::Deserialize)]
714        #[serde(transparent)]
715        struct Parameters(BTreeMap<OwnedUserId, SignedKeys>);
716
717        let params: Parameters = req.body_json().unwrap();
718
719        let mut keys = keys.lock().unwrap();
720
721        for (user, signed_keys) in params.0 {
722            for (key_id, raw_key) in signed_keys.iter() {
723                // Try to find a field in keys.master.
724                if let Some(existing_master_key) = keys.master.get_mut(&user) {
725                    let mut existing = existing_master_key.deserialize().unwrap();
726
727                    let target = CrossSigningKeyId::from_parts(
728                        ruma::SigningKeyAlgorithm::Ed25519,
729                        key_id.try_into().unwrap(),
730                    );
731
732                    if existing.keys.contains_key(&target) {
733                        let param: CrossSigningKey = serde_json::from_str(raw_key.get()).unwrap();
734
735                        for (uid, sigs) in existing.signatures.iter_mut() {
736                            if let Some(new_sigs) = param.signatures.get(uid) {
737                                sigs.extend(new_sigs.clone());
738                            }
739                        }
740                        for (uid, sigs) in param.signatures.iter() {
741                            if !existing.signatures.contains_key(uid) {
742                                existing.signatures.insert(uid.clone(), sigs.clone());
743                            }
744                        }
745
746                        // Update in map.
747                        *existing_master_key = Raw::new(&existing).unwrap();
748                        continue;
749                    }
750                }
751
752                // Otherwise, try to find a field in keys.device.
753                // Either merge signatures if an entry is already present, or insert a new
754                // entry.
755                let known_devices = keys.device.entry(user.clone()).or_default();
756                let device_keys = known_devices
757                    .get_mut(key_id)
758                    .expect("trying to add a signature for a missing key");
759
760                let param: DeviceKeys = serde_json::from_str(raw_key.get()).unwrap();
761
762                let mut existing: DeviceKeys = device_keys.deserialize().unwrap();
763
764                for (uid, sigs) in existing.signatures.iter_mut() {
765                    if let Some(new_sigs) = param.signatures.get(uid) {
766                        sigs.extend(new_sigs.clone());
767                    }
768                }
769                for (uid, sigs) in param.signatures.iter() {
770                    if !existing.signatures.contains_key(uid) {
771                        existing.signatures.insert(uid.clone(), sigs.clone());
772                    }
773                }
774
775                *device_keys = Raw::new(&existing).unwrap();
776            }
777        }
778
779        ResponseTemplate::new(200).set_body_json(json!({
780            "failures": {}
781        }))
782    }
783}
784
785fn mock_keys_claimed_request(keys: Arc<Mutex<Keys>>) -> impl Fn(&Request) -> ResponseTemplate {
786    move |req: &Request| {
787        // Accept all cross-signing setups by default.
788        #[derive(Debug, serde::Deserialize)]
789        struct Parameters {
790            one_time_keys: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>>,
791        }
792
793        let params: Parameters = req.body_json().unwrap();
794
795        let mut keys = keys.lock().unwrap();
796        let known_otks = &mut keys.one_time_keys;
797
798        let mut found_one_time_keys: BTreeMap<
799            OwnedUserId,
800            BTreeMap<OwnedDeviceId, BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>>,
801        > = BTreeMap::new();
802
803        for (user, requested_one_time_keys) in params.one_time_keys {
804            for device_id in requested_one_time_keys.keys() {
805                let device_id = device_id.clone();
806                let found_key = known_otks
807                    .entry(user.clone())
808                    .or_default()
809                    .entry(device_id.clone())
810                    .or_default()
811                    .pop_first();
812                if let Some((id, raw_otk)) = found_key {
813                    found_one_time_keys
814                        .entry(user.clone())
815                        .or_default()
816                        .entry(device_id.clone())
817                        .or_default()
818                        .insert(id, raw_otk.clone());
819                }
820            }
821        }
822
823        ResponseTemplate::new(200).set_body_json(json!({
824            "one_time_keys" : found_one_time_keys
825        }))
826    }
827}