matrix_sdk_crypto/session_manager/sessions.rs

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    collections::{BTreeMap, BTreeSet},
    sync::Arc,
    time::Duration,
};

use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
use ruma::{
    api::client::keys::claim_keys::v3::{
        Request as KeysClaimRequest, Response as KeysClaimResponse,
    },
    assign,
    events::dummy::ToDeviceDummyEventContent,
    DeviceId, OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedServerName,
    OwnedTransactionId, OwnedUserId, SecondsSinceUnixEpoch, ServerName, TransactionId, UserId,
};
use tracing::{debug, error, info, instrument, warn};
use vodozemac::Curve25519PublicKey;

use crate::{
    error::OlmResult,
    gossiping::GossipMachine,
    store::{Changes, Result as StoreResult, Store},
    types::{
        events::EventType,
        requests::{OutgoingRequest, ToDeviceRequest},
        EventEncryptionAlgorithm,
    },
    DeviceData,
};

#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
    store: Store,

    /// If there is an active /keys/claim request, its details.
    ///
    /// This is used when processing the response, so that we can spot missing
    /// users/devices.
    ///
    /// According to the doc on [`crate::OlmMachine::get_missing_sessions`],
    /// there should only be one such request active at a time, so we only need
    /// to keep a record of the most recent.
    current_key_claim_request: Arc<StdRwLock<Option<(OwnedTransactionId, KeysClaimRequest)>>>,

    /// A map of user/devices that we need to automatically claim keys for.
    /// Submodules can insert user/device pairs into this map and the
    /// user/device pairs will be added to the list of users when
    /// [`get_missing_sessions`](#method.get_missing_sessions) is called.
    users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    wedged_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    key_request_machine: GossipMachine,
    outgoing_to_device_requests: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>>,

    /// Servers that have previously appeared in the `failures` section of a
    /// `/keys/claim` response.
    ///
    /// See also [`crate::identities::IdentityManager::failures`].
    failures: FailuresCache<OwnedServerName>,

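    /// Devices for which a recent attempt to claim a one-time key or to
    /// create an Olm session failed, keyed by their owner's user ID, so that
    /// we back off before retrying them.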
    failed_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, FailuresCache<OwnedDeviceId>>>>,
}

impl SessionManager {
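    /// The `timeout` we attach to `/keys/claim` requests: how long the
    /// homeserver should wait when claiming keys from remote servers.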
    const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
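    /// The minimum age the most recent Olm session with a device must have
    /// before we are willing to create another session to unwedge it.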
    const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);

    pub fn new(
        users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
        key_request_machine: GossipMachine,
        store: Store,
    ) -> Self {
        Self {
            store,
            current_key_claim_request: Default::default(),
            key_request_machine,
            users_for_key_claim,
            wedged_devices: Default::default(),
            outgoing_to_device_requests: Default::default(),
            failures: Default::default(),
            failed_devices: Default::default(),
        }
    }

    /// Mark the outgoing request as sent.
    pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
        self.outgoing_to_device_requests.write().remove(id);
    }

    pub async fn mark_device_as_wedged(
        &self,
        sender: &UserId,
        curve_key: Curve25519PublicKey,
    ) -> OlmResult<()> {
        if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
            if let Some(session) = device.get_most_recent_session().await? {
                info!(sender_key = ?curve_key, "Marking session to be unwedged");

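                // Only create a new session if the most recent one is older
                // than `UNWEDGING_INTERVAL` (or has a creation time in the
                // future), so that we don't keep recreating sessions for the
                // same device.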
                let creation_time = Duration::from_secs(session.creation_time.get().into());
                let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());

                let should_unwedge = now
                    .checked_sub(creation_time)
                    .map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
                    .unwrap_or(true);

                if should_unwedge {
                    self.users_for_key_claim
                        .write()
                        .entry(device.user_id().to_owned())
                        .or_default()
                        .insert(device.device_id().into());
                    self.wedged_devices
                        .write()
                        .entry(device.user_id().to_owned())
                        .or_default()
                        .insert(device.device_id().into());
                }
            }
        }

        Ok(())
    }

    #[allow(dead_code)]
    pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
        self.wedged_devices
            .read()
            .get(device.user_id())
            .is_some_and(|d| d.contains(device.device_id()))
    }

    /// Check if the session was created to unwedge a device.
    ///
    /// If the device was wedged, this will queue up a dummy to-device message.
    async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
        if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id)) {
            if let Some(device) = self.store.get_device(user_id, device_id).await? {
                let (_, content) =
                    device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;

                let event_type = content.event_type().to_owned();

                let request = ToDeviceRequest::new(
                    device.user_id(),
                    device.device_id().to_owned(),
                    &event_type,
                    content.cast(),
                );

                let request = OutgoingRequest {
                    request_id: request.txn_id.clone(),
                    request: Arc::new(request.into()),
                };

                self.outgoing_to_device_requests
                    .write()
                    .insert(request.request_id.clone(), request);
            }
        }

        Ok(())
    }

    /// Get a key claiming request for the user/device pairs that we are
    /// missing Olm sessions for.
    ///
    /// Returns None if no key claiming request needs to be sent out.
    ///
    /// Sessions need to be established between devices so group sessions for a
    /// room can be shared with them.
    ///
    /// This should be called every time a group session needs to be shared as
    /// well as between sync calls. After a sync, some devices may request room
    /// keys without us having a valid Olm session with them, making it
    /// impossible to serve the room key request, so it's necessary to check
    /// for missing sessions between syncs as well.
    ///
    /// **Note**: Care should be taken that only one such request at a time is
    /// in flight, e.g. using a lock.
    ///
    /// The response of a successful key claiming request needs to be passed to
    /// the `OlmMachine` with [`receive_keys_claim_response`].
    ///
    /// # Arguments
    ///
    /// `users` - The list of users whose devices we should check for missing
    /// sessions. This can be an empty iterator when calling this method
    /// between sync requests.
    ///
    /// [`receive_keys_claim_response`]: #method.receive_keys_claim_response
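    ///
    /// # Examples
    ///
    /// A rough sketch of the intended call pattern; how the request is sent
    /// to the homeserver is up to the caller, and the `send_keys_claim_request`
    /// helper below is hypothetical.
    ///
    /// ```ignore
    /// if let Some((request_id, request)) = manager.get_missing_sessions(users).await? {
    ///     // Send the request to the homeserver using whatever transport is
    ///     // available to the caller.
    ///     let response = send_keys_claim_request(&request).await?;
    ///
    ///     // Hand the response back so new Olm sessions can be created.
    ///     manager.receive_keys_claim_response(&request_id, &response).await?;
    /// }
    /// ```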
    pub async fn get_missing_sessions(
        &self,
        users: impl Iterator<Item = &UserId>,
    ) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
        let mut missing_session_devices_by_user: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
        let mut timed_out_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();

        let unfailed_users = users.filter(|u| !self.failures.contains(u.server_name()));

        // Get the current list of devices for each user.
        let devices_by_user = Box::pin(
            self.key_request_machine
                .identity_manager()
                .get_user_devices_for_encryption(unfailed_users),
        )
        .await?;

        #[derive(Debug, Default)]
        struct UserFailedDeviceInfo {
            non_olm_devices: BTreeMap<OwnedDeviceId, Vec<EventEncryptionAlgorithm>>,
            bad_key_devices: BTreeSet<OwnedDeviceId>,
        }

        let mut failed_devices_by_user: BTreeMap<_, UserFailedDeviceInfo> = BTreeMap::new();

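        // Sort every device into one of four buckets: devices that don't
        // support Olm at all, devices without a valid Curve25519 key, devices
        // whose last `/keys/claim` attempt recently failed (timed out), and
        // devices that are simply missing a session and for which we should
        // claim a one-time key.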
        for (user_id, user_devices) in devices_by_user {
            for (device_id, device) in user_devices {
                if !device.supports_olm() {
                    failed_devices_by_user
                        .entry(user_id.clone())
                        .or_default()
                        .non_olm_devices
                        .insert(device_id, Vec::from(device.algorithms()));
                } else if let Some(sender_key) = device.curve25519_key() {
                    let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;

                    let is_missing = if let Some(sessions) = sessions {
                        sessions.lock().await.is_empty()
                    } else {
                        true
                    };

                    let is_timed_out = self.is_user_timed_out(&user_id, &device_id);

                    if is_missing && is_timed_out {
                        timed_out_devices_by_user
                            .entry(user_id.to_owned())
                            .or_default()
                            .insert(device_id);
                    } else if is_missing && !is_timed_out {
                        missing_session_devices_by_user
                            .entry(user_id.to_owned())
                            .or_default()
                            .insert(device_id, OneTimeKeyAlgorithm::SignedCurve25519);
                    }
                } else {
                    failed_devices_by_user
                        .entry(user_id.clone())
                        .or_default()
                        .bad_key_devices
                        .insert(device_id);
                }
            }
        }

        // Add the devices that, for whatever reason, we need to automatically
        // create an Olm session for.
        for (user, device_ids) in self.users_for_key_claim.read().iter() {
            missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
                device_ids
                    .iter()
                    .map(|device_id| (device_id.clone(), OneTimeKeyAlgorithm::SignedCurve25519)),
            );
        }

        if tracing::level_enabled!(tracing::Level::DEBUG) {
            // Reformat the map to skip the encryption algorithm, which isn't very useful.
            let missing_session_devices_by_user = missing_session_devices_by_user
                .iter()
                .map(|(user_id, devices)| (user_id, devices.keys().collect::<BTreeSet<_>>()))
                .collect::<BTreeMap<_, _>>();
            debug!(
                ?missing_session_devices_by_user,
                ?timed_out_devices_by_user,
                "Collected user/device pairs that are missing an Olm session"
            );
        }

        if !failed_devices_by_user.is_empty() {
            warn!(
                ?failed_devices_by_user,
                "Can't establish an Olm session with some devices due to missing Olm support or bad keys",
            );
        }

        let result = if missing_session_devices_by_user.is_empty() {
            None
        } else {
            Some((
                TransactionId::new(),
                assign!(KeysClaimRequest::new(missing_session_devices_by_user), {
                    timeout: Some(Self::KEY_CLAIM_TIMEOUT),
                }),
            ))
        };

        // stash the details of the request so that we can refer to it when handling the
        // response
        *(self.current_key_claim_request.write()) = result.clone();
        Ok(result)
    }

    fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
        self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
    }

    /// This method tries to figure out which devices we requested a one-time
    /// key for, but which are not present in the response.
    ///
    /// As per [spec], if a user/device pair does not have any one-time keys on
    /// the homeserver, the server will just omit the user/device pair from
    /// the response:
    ///
    /// > If the homeserver could be reached, but the user or device was
    /// > unknown, no failure is recorded. Instead, the corresponding user
    /// > or device is missing from the one_time_keys result.
    ///
    /// The user/device pairs which are missing from the response are put into
    /// the failures cache so that we don't immediately retry claiming a
    /// one-time key the next time the user tries to send a message.
    ///
    /// [spec]: https://spec.matrix.org/unstable/client-server-api/#post_matrixclientv3keysclaim
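    ///
    /// For example, if we asked for a one-time key for `@alice:example.org`'s
    /// `DEVICEID`, but the response's `one_time_keys` map omits that device
    /// and `example.org` is not listed under `failures`, then the
    /// `@alice:example.org`/`DEVICEID` pair is added to `failed_devices`.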
    fn handle_otk_exhaustion_failure(
        &self,
        request_id: &TransactionId,
        failed_servers: &BTreeSet<OwnedServerName>,
        one_time_keys: &BTreeMap<
            &OwnedUserId,
            BTreeMap<&OwnedDeviceId, BTreeSet<&OwnedOneTimeKeyId>>,
        >,
    ) {
        // First check that the response is for the request we were expecting.
        let request = {
            let mut guard = self.current_key_claim_request.write();
            let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());

            if Some(request_id) == expected_request_id {
                // We have a confirmed match. Clear the expectation, but hang onto the details
                // of the request.
                guard.take().map(|(_, request)| request)
            } else {
                warn!(
                    ?request_id,
                    ?expected_request_id,
                    "Received a `/keys/claim` response for the wrong request"
                );
                None
            }
        };

        // If we were able to pair this response with a request, look for devices that
        // were present in the request but did not elicit a successful response.
        if let Some(request) = request {
            let devices_in_response: BTreeSet<_> = one_time_keys
                .iter()
                .flat_map(|(user_id, device_key_map)| {
                    device_key_map
                        .keys()
                        .map(|device_id| (*user_id, *device_id))
                        .collect::<BTreeSet<_>>()
                })
                .collect();

            let devices_in_request: BTreeSet<(_, _)> = request
                .one_time_keys
                .iter()
                .flat_map(|(user_id, device_key_map)| {
                    device_key_map
                        .keys()
                        .map(|device_id| (user_id, device_id))
                        .collect::<BTreeSet<_>>()
                })
                .collect();

            let missing_devices: BTreeSet<_> = devices_in_request
                .difference(&devices_in_response)
                .filter(|(user_id, _)| {
                    // Skip over users whose homeservers were in the "failed servers" list: we don't
                    // want to mark individual devices as broken *as well as* the server.
                    !failed_servers.contains(user_id.server_name())
                })
                .collect();

            if !missing_devices.is_empty() {
                let mut missing_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();

                for &(user_id, device_id) in missing_devices {
                    missing_devices_by_user.entry(user_id).or_default().insert(device_id.clone());
                }

                warn!(
                    ?missing_devices_by_user,
                    "Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
                );

                let mut failed_devices_lock = self.failed_devices.write();

                for (user_id, device_set) in missing_devices_by_user {
                    failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
                }
            }
        };
    }

    /// Receive a successful key claim response and create new Olm sessions with
    /// the claimed keys.
    ///
    /// # Arguments
    ///
    /// * `request_id` - The unique id of the request that was sent out. This is
    ///   needed to couple the response with the sent out request.
    ///
    /// * `response` - The response containing the claimed one-time keys.
    #[instrument(skip(self, response))]
    pub async fn receive_keys_claim_response(
        &self,
        request_id: &TransactionId,
        response: &KeysClaimResponse,
    ) -> OlmResult<()> {
        // Collect the (user_id, device_id, device_key_id) triples for logging purposes.
        let one_time_keys: BTreeMap<_, BTreeMap<_, BTreeSet<_>>> = response
            .one_time_keys
            .iter()
            .map(|(user_id, device_map)| {
                (
                    user_id,
                    device_map
                        .iter()
                        .map(|(device_id, key_map)| {
                            (device_id, key_map.keys().collect::<BTreeSet<_>>())
                        })
                        .collect::<BTreeMap<_, _>>(),
                )
            })
            .collect();

        debug!(?request_id, ?one_time_keys, failures = ?response.failures, "Received a `/keys/claim` response");

        // Collect all the servers in the `failures` field of the response.
        let failed_servers: BTreeSet<_> = response
            .failures
            .keys()
            .filter_map(|s| ServerName::parse(s).ok())
            .filter(|s| s != self.store.static_account().user_id.server_name())
            .collect();
        let successful_servers = response.one_time_keys.keys().map(|u| u.server_name());

        // Add the user/device pairs that don't have any one-time keys to the failures
        // cache.
        self.handle_otk_exhaustion_failure(request_id, &failed_servers, &one_time_keys);
        // Add the failed servers to the failures cache.
        self.failures.extend(failed_servers);
        // Remove the servers we successfully contacted from the failures cache.
        self.failures.remove(successful_servers);

        // Finally, create some 1-to-1 sessions.
        self.create_sessions(response).await
    }

    /// Create new Olm sessions for the requested devices.
    ///
    /// # Arguments
    ///
    ///  * `response` - A `/keys/claim` response containing the claimed
    ///    one-time keys for each device we should create a session for.
    pub(crate) async fn create_sessions(&self, response: &KeysClaimResponse) -> OlmResult<()> {
        struct SessionInfo {
            session_id: String,
            algorithm: EventEncryptionAlgorithm,
            fallback_key_used: bool,
        }

        #[cfg(not(tarpaulin_include))]
        impl std::fmt::Debug for SessionInfo {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(
                    f,
                    "session_id: {}, algorithm: {}, fallback_key_used: {}",
                    self.session_id, self.algorithm, self.fallback_key_used
                )
            }
        }

        let mut changes = Changes::default();
        let mut new_sessions: BTreeMap<&UserId, BTreeMap<&DeviceId, SessionInfo>> = BTreeMap::new();
        let mut store_transaction = self.store.transaction().await;

        for (user_id, user_devices) in &response.one_time_keys {
            for (device_id, key_map) in user_devices {
                let device = match self.store.get_device_data(user_id, device_id).await {
                    Ok(Some(d)) => d,
                    Ok(None) => {
                        warn!(
                            ?user_id,
                            ?device_id,
                            "Tried to create an Olm session but the device is unknown",
                        );
                        continue;
                    }
                    Err(e) => {
                        warn!(
                            ?user_id, ?device_id, error = ?e,
                            "Tried to create an Olm session, but we can't \
                            fetch the device from the store",
                        );
                        continue;
                    }
                };

                let account = store_transaction.account().await?;
                let device_keys = self.store.get_own_device().await?.as_device_keys().clone();
                let session = match account.create_outbound_session(&device, key_map, device_keys) {
                    Ok(s) => s,
                    Err(e) => {
                        warn!(
                            ?user_id, ?device_id, error = ?e,
                            "Error creating Olm session"
                        );

                        self.failed_devices
                            .write()
                            .entry(user_id.to_owned())
                            .or_default()
                            .insert(device_id.to_owned());

                        continue;
                    }
                };

                self.key_request_machine.retry_keyshare(user_id, device_id);

                if let Err(e) = self.check_if_unwedged(user_id, device_id).await {
                    error!(?user_id, ?device_id, "Error while treating an unwedged device: {e:?}");
                }

                let session_info = SessionInfo {
                    session_id: session.session_id().to_owned(),
                    algorithm: session.algorithm().await,
                    fallback_key_used: session.created_using_fallback_key,
                };

                changes.sessions.push(session);
                new_sessions.entry(user_id).or_default().insert(device_id, session_info);
            }
        }

        store_transaction.commit().await?;
        self.store.save_changes(changes).await?;
        info!(sessions = ?new_sessions, "Established new Olm sessions");

        for (user, device_map) in new_sessions {
            if let Some(user_cache) = self.failed_devices.read().get(user) {
                user_cache.remove(device_map.into_keys());
            }
        }

        let store_cache = self.store.cache().await?;
        match self.key_request_machine.collect_incoming_key_requests(&store_cache).await {
            Ok(sessions) => {
                let changes = Changes { sessions, ..Default::default() };
                self.store.save_changes(changes).await?
            }
            // We don't propagate the error here since the next sync will retry
            // this.
            Err(e) => {
                warn!(error = ?e, "Error while trying to collect the incoming secret requests")
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};

    use matrix_sdk_common::locks::RwLock as StdRwLock;
    use matrix_sdk_test::{async_test, ruma_response_from_json};
    use ruma::{
        api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
        owned_server_name, user_id, DeviceId, OwnedUserId, UserId,
    };
    use serde_json::json;
    use tokio::sync::Mutex;
    use tracing::info;

    use super::SessionManager;
    use crate::{
        gossiping::GossipMachine,
        identities::{DeviceData, IdentityManager},
        olm::{Account, PrivateCrossSigningIdentity},
        session_manager::GroupSessionCache,
        store::{Changes, CryptoStoreWrapper, DeviceChanges, MemoryStore, PendingChanges, Store},
        verification::VerificationMachine,
    };

    fn user_id() -> &'static UserId {
        user_id!("@example:localhost")
    }

    fn device_id() -> &'static DeviceId {
        device_id!("DEVICEID")
    }

    fn bob_account() -> Account {
        Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"))
    }

    fn keys_claim_with_failure() -> KeyClaimResponse {
        let response = json!({
            "one_time_keys": {},
            "failures": {
                "example.org": {
                    "errcode": "M_RESOURCE_LIMIT_EXCEEDED",
                    "error": "Not yet ready to retry",
                }
            }
        });
        ruma_response_from_json(&response)
    }

    fn keys_claim_without_failure() -> KeyClaimResponse {
        let response = json!({
            "one_time_keys": {
                "@alice:example.org": {},
            },
            "failures": {},
        });
        ruma_response_from_json(&response)
    }

    async fn session_manager_test_helper() -> (SessionManager, IdentityManager) {
        let user_id = user_id();
        let device_id = device_id();

        let account = Account::with_device_id(user_id, device_id);
        let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, MemoryStore::new()));
        let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
        let verification = VerificationMachine::new(
            account.static_data().clone(),
            identity.clone(),
            store.clone(),
        );

        let store = Store::new(account.static_data().clone(), identity, store, verification);
        let device = DeviceData::from_account(&account);
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();

        let session_cache = GroupSessionCache::new(store.clone());
        let identity_manager = IdentityManager::new(store.clone());

        let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
        let key_request = GossipMachine::new(
            store.clone(),
            identity_manager.clone(),
            session_cache,
            users_for_key_claim.clone(),
        );

        (SessionManager::new(users_for_key_claim, key_request, store), identity_manager)
    }

    #[async_test]
    async fn test_session_creation() {
        let (manager, _identity_manager) = session_manager_test_helper().await;
        let mut bob = bob_account();

        let bob_device = DeviceData::from_account(&bob);

        manager.store.save_device_data(&[bob_device]).await.unwrap();

        let (txn_id, request) =
            manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();

        assert!(request.one_time_keys.contains_key(bob.user_id()));

        bob.generate_one_time_keys(1);
        let one_time = bob.signed_one_time_keys();
        assert!(!one_time.is_empty());
        bob.mark_keys_as_published();

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(bob.user_id().to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(bob.device_id().to_owned(), one_time);

        let response = KeyClaimResponse::new(one_time_keys);

        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
    }

    #[async_test]
    async fn test_session_creation_waits_for_keys_query() {
        let (manager, identity_manager) = session_manager_test_helper().await;

        // start a `/keys/query` request. At this point, we are only interested in our
        // own devices.
        let (key_query_txn_id, key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        info!("Initial key query: {:?}", key_query_request);

        // now bob turns up, and we start tracking his devices...
        let bob = bob_account();
        let bob_device = DeviceData::from_account(&bob);
        {
            let cache = manager.store.cache().await.unwrap();
            identity_manager
                .key_query_manager
                .synced(&cache)
                .await
                .unwrap()
                .update_tracked_users(iter::once(bob.user_id()))
                .await
                .unwrap();
        }

        // ... and start off an attempt to get the missing sessions. This should block
        // for now.
        let missing_sessions_task = {
            let manager = manager.clone();
            let bob_user_id = bob.user_id().to_owned();

            #[allow(unknown_lints, clippy::redundant_async_block)] // false positive
            tokio::spawn(async move {
                manager.get_missing_sessions(iter::once(bob_user_id.deref())).await
            })
        };

        // the initial `/keys/query` completes, and we start another
        let response_json =
            json!({ "device_keys": { manager.store.static_account().user_id.to_owned(): {}}});
        let response = ruma_response_from_json(&response_json);
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        let (key_query_txn_id, key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        info!("Second key query: {:?}", key_query_request);

        // that second request completes with info on bob's device
        let response_json = json!({ "device_keys": { bob.user_id(): {
            bob_device.device_id(): bob_device.as_device_keys()
        }}});
        let response = ruma_response_from_json(&response_json);
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // the missing_sessions_task should now finally complete, with a claim
        // including bob's device
        let (_, keys_claim_request) = missing_sessions_task.await.unwrap().unwrap().unwrap();
        info!("Key claim request: {:?}", keys_claim_request.one_time_keys);
        let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap();
        assert!(bob_key_claims.contains_key(bob_device.device_id()));
    }

    #[async_test]
    async fn test_session_creation_does_not_wait_for_keys_query_on_failed_server() {
        let (manager, identity_manager) = session_manager_test_helper().await;

        // We start tracking Bob's devices.
        let other_user_id = OwnedUserId::try_from("@bob:example.com").unwrap();
        {
            let cache = manager.store.cache().await.unwrap();
            identity_manager
                .key_query_manager
                .synced(&cache)
                .await
                .unwrap()
                .update_tracked_users(iter::once(other_user_id.as_ref()))
                .await
                .unwrap();
        }

        // Do a `/keys/query` request, in which Bob's server is a failure.
        let (key_query_txn_id, _key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        let response = ruma_response_from_json(
            &json!({ "device_keys": {}, "failures": { other_user_id.server_name(): "unreachable" }}),
        );
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // Now, an attempt to get the missing sessions should *not* block. We use a
        // timeout so that we can detect the call blocking.
        let result = tokio::time::timeout(
            Duration::from_millis(10),
            manager.get_missing_sessions(iter::once(other_user_id.as_ref())),
        )
        .await
        .expect("get_missing_sessions blocked rather than completing quickly")
        .expect("get_missing_sessions returned an error");

        assert!(result.is_none(), "get_missing_sessions returned Some(...)");
    }

    // This test doesn't run on macos because we're modifying the session
    // creation time so we can get around the UNWEDGING_INTERVAL.
    #[async_test]
    #[cfg(target_os = "linux")]
    async fn test_session_unwedging() {
        use ruma::{time::SystemTime, SecondsSinceUnixEpoch};

        let (manager, _identity_manager) = session_manager_test_helper().await;
        let mut bob = bob_account();

        let (_, mut session) = manager
            .store
            .with_transaction(|mut tr| async {
                let manager_account = tr.account().await.unwrap();
                let res = bob.create_session_for_test_helper(manager_account).await;
                Ok((tr, res))
            })
            .await
            .unwrap();

        let bob_device = DeviceData::from_account(&bob);
        let time = SystemTime::now() - Duration::from_secs(3601);
        session.creation_time = SecondsSinceUnixEpoch::from_system_time(time).unwrap();

        manager.store.save_device_data(&[bob_device.clone()]).await.unwrap();
        manager.store.save_sessions(&[session]).await.unwrap();

        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());

        let curve_key = bob_device.curve25519_key().unwrap();

        assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
        assert!(!manager.is_device_wedged(&bob_device));
        manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
        assert!(manager.is_device_wedged(&bob_device));
        assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));

        let (txn_id, request) =
            manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();

        assert!(request.one_time_keys.contains_key(bob.user_id()));

        bob.generate_one_time_keys(1);
        let one_time = bob.signed_one_time_keys();
        assert!(!one_time.is_empty());
        bob.mark_keys_as_published();

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(bob.user_id().to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(bob.device_id().to_owned(), one_time);

        let response = KeyClaimResponse::new(one_time_keys);

        assert!(manager.outgoing_to_device_requests.read().is_empty());

        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        assert!(!manager.is_device_wedged(&bob_device));
        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
        assert!(!manager.outgoing_to_device_requests.read().is_empty())
    }

    #[async_test]
    async fn test_failure_handling() {
        let alice = user_id!("@alice:example.org");
        let alice_account = Account::with_device_id(alice, "DEVICEID".into());
        let alice_device = DeviceData::from_account(&alice_account);

        let (manager, _identity_manager) = session_manager_test_helper().await;

        manager.store.save_device_data(&[alice_device]).await.unwrap();

        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        manager.receive_keys_claim_response(&txn_id, &keys_claim_with_failure()).await.unwrap();
        assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());

        // expire the failure
        manager.failures.expire(&owned_server_name!("example.org"));

        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        manager.receive_keys_claim_response(&txn_id, &keys_claim_without_failure()).await.unwrap();
    }

    #[async_test]
    async fn test_failed_devices_handling() {
        // Alice is missing altogether
        test_invalid_claim_response(json!({
            "one_time_keys": {},
            "failures": {},
        }))
        .await;

        // Alice is present but with no devices
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {}
            },
            "failures": {},
        }))
        .await;

        // Alice's device is present but with no keys
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {
                    "DEVICEID": {}
                }
            },
            "failures": {},
        }))
        .await;

        // Alice's device is present with a bad signature
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {
                    "DEVICEID": {
                        "signed_curve25519:AAAAAA": {
                            "fallback": true,
                            "key": "1sra5GVo1ONz478aQybxSEeHTSo2xq0Z+Q3Yzqvp3A4",
                            "signatures": {
                                "@example:morpheus.localhost": {
                                    "ed25519:YAFLBLXAUK": "Zwk90fJhZWOYGNOgtOswZ6RSOGeTjTi/h2dMpyB0CR6EVtvTra0WJtp32ntifrxtwD710y2F3pe5Oyrm7jngCQ"
                                }
                            }
                        }
                    }
                }
            },
            "failures": {},
        })).await;
    }

    /// Helper for failed_devices_handling.
    ///
    /// Takes an invalid /keys/claim response for Alice's device DEVICEID and
    /// checks that it is handled correctly. (The device should be marked as
    /// 'failed'; and once that failure is expired, a new key claim should be
    /// attempted and should clear the failure.)
    async fn test_invalid_claim_response(response_json: serde_json::Value) {
        let response = ruma_response_from_json(&response_json);

        let alice = user_id!("@alice:example.org");
        let mut alice_account = Account::with_device_id(alice, "DEVICEID".into());
        let alice_device = DeviceData::from_account(&alice_account);

        let (manager, _identity_manager) = session_manager_test_helper().await;
        manager.store.save_device_data(&[alice_device]).await.unwrap();

        // Since we don't have a session with Alice yet, the machine will try to claim
        // some keys for her.
        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        // We receive a response with an invalid one-time key; this will mark Alice as
        // timed out.
        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
        // Since Alice is timed out, we won't claim keys for her.
        assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());

        alice_account.generate_one_time_keys(1);
        let one_time = alice_account.signed_one_time_keys();
        assert!(!one_time.is_empty());

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(alice.to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(alice_account.device_id().to_owned(), one_time);

        // Now we expire Alice's timeout, and receive a valid one-time key for her.
        manager
            .failed_devices
            .write()
            .get(alice)
            .unwrap()
            .expire(&alice_account.device_id().to_owned());
        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        let response = KeyClaimResponse::new(one_time_keys);
        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        // Alice isn't timed out anymore.
        assert!(manager
            .failed_devices
            .read()
            .get(alice)
            .unwrap()
            .failure_count(alice_account.device_id())
            .is_none());
    }
}