matrix_sdk_crypto/session_manager/sessions.rs

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    collections::{BTreeMap, BTreeSet},
    sync::Arc,
    time::Duration,
};

use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
use ruma::{
    DeviceId, OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedServerName,
    OwnedTransactionId, OwnedUserId, SecondsSinceUnixEpoch, ServerName, TransactionId, UserId,
    api::client::keys::claim_keys::v3::{
        Request as KeysClaimRequest, Response as KeysClaimResponse,
    },
    assign,
    events::dummy::ToDeviceDummyEventContent,
};
use tracing::{debug, error, info, instrument, warn};
use vodozemac::Curve25519PublicKey;

use crate::{
    DeviceData,
    error::OlmResult,
    gossiping::GossipMachine,
    store::{Result as StoreResult, Store, types::Changes},
    types::{
        EventEncryptionAlgorithm,
        events::EventType,
        requests::{OutgoingRequest, ToDeviceRequest},
    },
};

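/// Manages the establishment of 1-to-1 Olm sessions with other devices.
///
/// The manager works out which devices we are missing Olm sessions with,
/// produces the corresponding `/keys/claim` requests, and turns the claimed
/// one-time keys into new outbound sessions. It also keeps track of devices
/// that are "wedged" (stuck with a broken session) and of devices or servers
/// that recently failed to hand out keys.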
#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
    store: Store,

    /// If there is an active /keys/claim request, its details.
    ///
    /// This is used when processing the response, so that we can spot missing
    /// users/devices.
    ///
    /// According to the doc on [`crate::OlmMachine::get_missing_sessions`],
    /// there should only be one such request active at a time, so we only need
    /// to keep a record of the most recent.
    current_key_claim_request: Arc<StdRwLock<Option<(OwnedTransactionId, KeysClaimRequest)>>>,

    /// A map of user/devices that we need to automatically claim keys for.
    /// Submodules can insert user/device pairs into this map and the
    /// user/device pairs will be added to the list of users when
    /// [`get_missing_sessions`](#method.get_missing_sessions) is called.
    users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    wedged_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    key_request_machine: GossipMachine,
    outgoing_to_device_requests: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>>,

    /// Servers that have previously appeared in the `failures` section of a
    /// `/keys/claim` response.
    ///
    /// See also [`crate::identities::IdentityManager::failures`].
    failures: FailuresCache<OwnedServerName>,

    failed_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, FailuresCache<OwnedDeviceId>>>>,
}

impl SessionManager {
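    // KEY_CLAIM_TIMEOUT is sent as the `timeout` field of the `/keys/claim`
    // request and bounds how long the homeserver may spend contacting remote
    // servers. UNWEDGING_INTERVAL rate-limits how often we try to replace the
    // session of a device that we consider wedged.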
    const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
    const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);

    pub fn new(
        users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
        key_request_machine: GossipMachine,
        store: Store,
    ) -> Self {
        Self {
            store,
            current_key_claim_request: Default::default(),
            key_request_machine,
            users_for_key_claim,
            wedged_devices: Default::default(),
            outgoing_to_device_requests: Default::default(),
            failures: Default::default(),
            failed_devices: Default::default(),
        }
    }

    /// Mark the outgoing request as sent.
    pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
        self.outgoing_to_device_requests.write().remove(id);
    }

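    /// Mark the device that owns the given curve25519 key as wedged.
    ///
    /// If the most recent Olm session with the device is older than
    /// [`Self::UNWEDGING_INTERVAL`], the device is queued up for a new
    /// `/keys/claim` so that a fresh session can be established and a dummy
    /// message sent over it.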
    pub async fn mark_device_as_wedged(
        &self,
        sender: &UserId,
        curve_key: Curve25519PublicKey,
    ) -> OlmResult<()> {
        if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await?
            && let Some(session) = device.get_most_recent_session().await?
        {
            info!(sender_key = ?curve_key, "Marking session to be unwedged");

            let creation_time = Duration::from_secs(session.creation_time.get().into());
            let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());

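            // Only unwedge if the newest session is older than
            // UNWEDGING_INTERVAL; this rate-limits how often we recreate
            // sessions with a device whose sessions appear to be broken.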
            let should_unwedge = now
                .checked_sub(creation_time)
                .map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
                .unwrap_or(true);

            if should_unwedge {
                self.users_for_key_claim
                    .write()
                    .entry(device.user_id().to_owned())
                    .or_default()
                    .insert(device.device_id().into());
                self.wedged_devices
                    .write()
                    .entry(device.user_id().to_owned())
                    .or_default()
                    .insert(device.device_id().into());
            }
        }

        Ok(())
    }

    #[allow(dead_code)]
    pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
        self.wedged_devices
            .read()
            .get(device.user_id())
            .is_some_and(|d| d.contains(device.device_id()))
    }

    /// Check if the session was created to unwedge a Device.
    ///
    /// If the device was wedged this will queue up a dummy to-device message.
    async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
        if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id))
            && let Some(device) = self.store.get_device(user_id, device_id).await?
        {
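            // An encrypted `m.dummy` to-device message carries no useful
            // payload, but it makes the recipient switch over to the newly
            // established session, replacing the wedged one.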
            let (_, content) = device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;

            let event_type = content.event_type().to_owned();

            let request = ToDeviceRequest::new(
                device.user_id(),
                device.device_id().to_owned(),
                &event_type,
                content.cast(),
            );

            let request = OutgoingRequest {
                request_id: request.txn_id.clone(),
                request: Arc::new(request.into()),
            };

            self.outgoing_to_device_requests.write().insert(request.request_id.clone(), request);
        }

        Ok(())
    }

    /// Get a key claiming request for the user/device pairs that we are
    /// missing Olm sessions for.
    ///
    /// Returns None if no key claiming request needs to be sent out.
    ///
    /// Sessions need to be established between devices so group sessions for a
    /// room can be shared with them.
    ///
    /// This should be called every time a group session needs to be shared as
    /// well as between sync calls. After a sync some devices may request room
    /// keys without us having a valid Olm session with them, making it
    /// impossible to serve the room key request; it's therefore necessary to
    /// check for missing sessions between syncs as well.
    ///
    /// **Note**: Care should be taken that only one such request is in flight
    /// at a time, e.g. by using a lock.
    ///
    /// The response of a successful key claiming request needs to be passed to
    /// the `OlmMachine` via [`receive_keys_claim_response`].
    ///
    /// # Arguments
    ///
    /// `users` - The list of users whose devices we should check for missing
    /// Olm sessions. This can be an empty iterator when calling this method
    /// between sync requests.
    ///
    /// [`receive_keys_claim_response`]: #method.receive_keys_claim_response
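    ///
    /// A simplified sketch of the expected call flow (the transport helper
    /// `send_keys_claim_request` is illustrative, not part of this module):
    ///
    /// ```ignore
    /// if let Some((txn_id, request)) = manager.get_missing_sessions(users).await? {
    ///     // Send the request to the homeserver as a `POST /keys/claim`...
    ///     let response = send_keys_claim_request(&request).await?;
    ///     // ...then feed the response back so new Olm sessions get created.
    ///     manager.receive_keys_claim_response(&txn_id, &response).await?;
    /// }
    /// ```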
    pub async fn get_missing_sessions(
        &self,
        users: impl Iterator<Item = &UserId>,
    ) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
        let mut missing_session_devices_by_user: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
        let mut timed_out_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();

        let unfailed_users = users.filter(|u| !self.failures.contains(u.server_name()));

        // Get the current list of devices for each user.
        let devices_by_user = Box::pin(
            self.key_request_machine
                .identity_manager()
                .get_user_devices_for_encryption(unfailed_users),
        )
        .await?;

        #[derive(Debug, Default)]
        struct UserFailedDeviceInfo {
            non_olm_devices: BTreeMap<OwnedDeviceId, Vec<EventEncryptionAlgorithm>>,
            bad_key_devices: BTreeSet<OwnedDeviceId>,
        }

        let mut failed_devices_by_user: BTreeMap<_, UserFailedDeviceInfo> = BTreeMap::new();

        for (user_id, user_devices) in devices_by_user {
            for (device_id, device) in user_devices {
                if !device.supports_olm() {
                    failed_devices_by_user
                        .entry(user_id.clone())
                        .or_default()
                        .non_olm_devices
                        .insert(device_id, Vec::from(device.algorithms()));
                } else if let Some(sender_key) = device.curve25519_key() {
                    let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;

                    let is_missing = if let Some(sessions) = sessions {
                        sessions.lock().await.is_empty()
                    } else {
                        true
                    };

                    let is_timed_out = self.is_user_timed_out(&user_id, &device_id);

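                    // Devices that recently failed to hand out a one-time key
                    // are only logged until their entry in the failures cache
                    // expires; everything else that is missing a session gets
                    // a key claimed for it.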
                    if is_missing && is_timed_out {
                        timed_out_devices_by_user
                            .entry(user_id.to_owned())
                            .or_default()
                            .insert(device_id);
                    } else if is_missing && !is_timed_out {
                        missing_session_devices_by_user
                            .entry(user_id.to_owned())
                            .or_default()
                            .insert(device_id, OneTimeKeyAlgorithm::SignedCurve25519);
                    }
                } else {
                    failed_devices_by_user
                        .entry(user_id.clone())
                        .or_default()
                        .bad_key_devices
                        .insert(device_id);
                }
            }
        }

        // Add the devices for which we, for some reason, automatically need to
        // create an Olm session.
        for (user, device_ids) in self.users_for_key_claim.read().iter() {
            missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
                device_ids
                    .iter()
                    .map(|device_id| (device_id.clone(), OneTimeKeyAlgorithm::SignedCurve25519)),
            );
        }

        if tracing::level_enabled!(tracing::Level::DEBUG) {
            // Reformat the map to skip the encryption algorithm, which isn't very useful.
            let missing_session_devices_by_user = missing_session_devices_by_user
                .iter()
                .map(|(user_id, devices)| (user_id, devices.keys().collect::<BTreeSet<_>>()))
                .collect::<BTreeMap<_, _>>();
            debug!(
                ?missing_session_devices_by_user,
                ?timed_out_devices_by_user,
                "Collected user/device pairs that are missing an Olm session"
            );
        }

        if !failed_devices_by_user.is_empty() {
            warn!(
                ?failed_devices_by_user,
                "Can't establish an Olm session with some devices due to missing Olm support or bad keys",
            );
        }

        let result = if missing_session_devices_by_user.is_empty() {
            None
        } else {
            Some((
                TransactionId::new(),
                assign!(KeysClaimRequest::new(missing_session_devices_by_user), {
                    timeout: Some(Self::KEY_CLAIM_TIMEOUT),
                }),
            ))
        };

        // Stash the details of the request so that we can refer to it when
        // handling the response.
        *(self.current_key_claim_request.write()) = result.clone();
        Ok(result)
    }

    fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
        self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
    }

    /// This method will try to figure out for which devices a one-time key was
    /// requested but is not present in the response.
    ///
    /// As per the [spec], if a user/device pair does not have any one-time
    /// keys on the homeserver, the server will just omit the user/device pair
    /// from the response:
    ///
    /// > If the homeserver could be reached, but the user or device was
    /// > unknown, no failure is recorded. Instead, the corresponding user
    /// > or device is missing from the one_time_keys result.
    ///
    /// The user/device pairs which are missing from the response are put into
    /// the failures cache so that we don't retry claiming a one-time key right
    /// away the next time the user tries to send a message.
    ///
    /// [spec]: https://spec.matrix.org/unstable/client-server-api/#post_matrixclientv3keysclaim
    fn handle_otk_exhaustion_failure(
        &self,
        request_id: &TransactionId,
        failed_servers: &BTreeSet<OwnedServerName>,
        one_time_keys: &BTreeMap<
            &OwnedUserId,
            BTreeMap<&OwnedDeviceId, BTreeSet<&OwnedOneTimeKeyId>>,
        >,
    ) {
        // First check that the response is for the request we were expecting.
        let request = {
            let mut guard = self.current_key_claim_request.write();
            let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());

            if Some(request_id) == expected_request_id {
                // We have a confirmed match. Clear the expectation, but hang onto the details
                // of the request.
                guard.take().map(|(_, request)| request)
            } else {
                warn!(
                    ?request_id,
                    ?expected_request_id,
                    "Received a `/keys/claim` response for the wrong request"
                );
                None
            }
        };

        // If we were able to pair this response with a request, look for devices that
        // were present in the request but did not elicit a successful response.
        if let Some(request) = request {
            let devices_in_response: BTreeSet<_> = one_time_keys
                .iter()
                .flat_map(|(user_id, device_key_map)| {
                    device_key_map
                        .keys()
                        .map(|device_id| (*user_id, *device_id))
                        .collect::<BTreeSet<_>>()
                })
                .collect();

            let devices_in_request: BTreeSet<(_, _)> = request
                .one_time_keys
                .iter()
                .flat_map(|(user_id, device_key_map)| {
                    device_key_map
                        .keys()
                        .map(|device_id| (user_id, device_id))
                        .collect::<BTreeSet<_>>()
                })
                .collect();

            let missing_devices: BTreeSet<_> = devices_in_request
                .difference(&devices_in_response)
                .filter(|(user_id, _)| {
                    // Skip over users whose homeservers were in the "failed servers" list: we don't
                    // want to mark individual devices as broken *as well as* the server.
                    !failed_servers.contains(user_id.server_name())
                })
                .collect();

            if !missing_devices.is_empty() {
                let mut missing_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();

                for &(user_id, device_id) in missing_devices {
                    missing_devices_by_user.entry(user_id).or_default().insert(device_id.clone());
                }

                warn!(
                    ?missing_devices_by_user,
                    "Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
                );

                let mut failed_devices_lock = self.failed_devices.write();

                for (user_id, device_set) in missing_devices_by_user {
                    failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
                }
            }
        }
    }

    /// Receive a successful key claim response and create new Olm sessions with
    /// the claimed keys.
    ///
    /// # Arguments
    ///
    /// * `request_id` - The unique id of the request that was sent out. This is
    ///   needed to couple the response with the sent out request.
    ///
    /// * `response` - The response containing the claimed one-time keys.
    #[instrument(skip(self, response))]
    pub async fn receive_keys_claim_response(
        &self,
        request_id: &TransactionId,
        response: &KeysClaimResponse,
    ) -> OlmResult<()> {
        // Collect the (user_id, device_id, device_key_id) triples for logging purposes.
        let one_time_keys: BTreeMap<_, BTreeMap<_, BTreeSet<_>>> = response
            .one_time_keys
            .iter()
            .map(|(user_id, device_map)| {
                (
                    user_id,
                    device_map
                        .iter()
                        .map(|(device_id, key_map)| {
                            (device_id, key_map.keys().collect::<BTreeSet<_>>())
                        })
                        .collect::<BTreeMap<_, _>>(),
                )
            })
            .collect();

        debug!(?request_id, ?one_time_keys, failures = ?response.failures, "Received a `/keys/claim` response");

        // Collect all the servers in the `failures` field of the response.
        let failed_servers: BTreeSet<_> = response
            .failures
            .keys()
            .filter_map(|s| ServerName::parse(s).ok())
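            // Skip our own homeserver: it just answered this request, and
            // putting it into the failures cache would suppress key claims
            // for every local user.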
            .filter(|s| s != self.store.static_account().user_id.server_name())
            .collect();
        let successful_servers = response.one_time_keys.keys().map(|u| u.server_name());

        // Add the user/device pairs that don't have any one-time keys to the failures
        // cache.
        self.handle_otk_exhaustion_failure(request_id, &failed_servers, &one_time_keys);
        // Add the failed servers to the failures cache.
        self.failures.extend(failed_servers);
        // Remove the servers we successfully contacted from the failures cache.
        self.failures.remove(successful_servers);

        // Finally, create some 1-to-1 sessions.
        self.create_sessions(response).await
    }

    /// Create new Olm sessions for the requested devices.
    ///
    /// # Arguments
    ///
    ///  * `response` - The response of the `/keys/claim` request, containing a
    ///    map from (user ID, device ID) pairs to the one-time key to use for
    ///    each device we should create a session for.
    pub(crate) async fn create_sessions(&self, response: &KeysClaimResponse) -> OlmResult<()> {
        struct SessionInfo {
            session_id: String,
            algorithm: EventEncryptionAlgorithm,
            fallback_key_used: bool,
        }

        #[cfg(not(tarpaulin_include))]
        impl std::fmt::Debug for SessionInfo {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(
                    f,
                    "session_id: {}, algorithm: {}, fallback_key_used: {}",
                    self.session_id, self.algorithm, self.fallback_key_used
                )
            }
        }

        let mut changes = Changes::default();
        let mut new_sessions: BTreeMap<&UserId, BTreeMap<&DeviceId, SessionInfo>> = BTreeMap::new();
        let mut store_transaction = self.store.transaction().await;

        for (user_id, user_devices) in &response.one_time_keys {
            for (device_id, key_map) in user_devices {
                let device = match self.store.get_device_data(user_id, device_id).await {
                    Ok(Some(d)) => d,
                    Ok(None) => {
                        warn!(
                            ?user_id,
                            ?device_id,
                            "Tried to create an Olm session but the device is unknown",
                        );
                        continue;
                    }
                    Err(e) => {
                        warn!(
                            ?user_id, ?device_id, error = ?e,
                            "Tried to create an Olm session, but we can't \
                            fetch the device from the store",
                        );
                        continue;
                    }
                };

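                // Create the outbound Olm session using the one-time key(s)
                // claimed for this device, passing along our own device keys.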
                let account = store_transaction.account().await?;
                let device_keys = self.store.get_own_device().await?.as_device_keys().clone();
                let session = match account.create_outbound_session(&device, key_map, device_keys) {
                    Ok(s) => s,
                    Err(e) => {
                        warn!(
                            ?user_id, ?device_id, error = ?e,
                            "Error creating Olm session"
                        );

                        self.failed_devices
                            .write()
                            .entry(user_id.to_owned())
                            .or_default()
                            .insert(device_id.to_owned());

                        continue;
                    }
                };

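                // Now that a session exists, room key requests from this
                // device that we previously couldn't serve can be retried.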
                self.key_request_machine.retry_keyshare(user_id, device_id);

                if let Err(e) = self.check_if_unwedged(user_id, device_id).await {
                    error!(?user_id, ?device_id, "Error while treating an unwedged device: {e:?}");
                }

                let session_info = SessionInfo {
                    session_id: session.session_id().to_owned(),
                    algorithm: session.algorithm().await,
                    fallback_key_used: session.created_using_fallback_key,
                };

                changes.sessions.push(session);
                new_sessions.entry(user_id).or_default().insert(device_id, session_info);
            }
        }

        store_transaction.commit().await?;
        self.store.save_changes(changes).await?;
        info!(sessions = ?new_sessions, "Established new Olm sessions");

        for (user, device_map) in new_sessions {
            if let Some(user_cache) = self.failed_devices.read().get(user) {
                user_cache.remove(device_map.into_keys());
            }
        }

        let store_cache = self.store.cache().await?;
        match self.key_request_machine.collect_incoming_key_requests(&store_cache).await {
            Ok(sessions) => {
                let changes = Changes { sessions, ..Default::default() };
                self.store.save_changes(changes).await?
            }
            // We don't propagate the error here since the next sync will retry
            // this.
            Err(e) => {
                warn!(error = ?e, "Error while trying to collect the incoming secret requests")
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};

    use matrix_sdk_common::{executor::spawn, locks::RwLock as StdRwLock};
    use matrix_sdk_test::{async_test, ruma_response_from_json};
    use ruma::{
        DeviceId, OwnedUserId, UserId,
        api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
        owned_server_name, user_id,
    };
    use serde_json::json;
    use tokio::sync::Mutex;
    use tracing::info;

    use super::SessionManager;
    use crate::{
        gossiping::GossipMachine,
        identities::{DeviceData, IdentityManager},
        olm::{Account, PrivateCrossSigningIdentity},
        session_manager::GroupSessionCache,
        store::{
            CryptoStoreWrapper, MemoryStore, Store,
            types::{Changes, DeviceChanges, PendingChanges},
        },
        verification::VerificationMachine,
    };

    fn user_id() -> &'static UserId {
        user_id!("@example:localhost")
    }

    fn device_id() -> &'static DeviceId {
        device_id!("DEVICEID")
    }

    fn bob_account() -> Account {
        Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"))
    }

    fn keys_claim_with_failure() -> KeyClaimResponse {
        let response = json!({
            "one_time_keys": {},
            "failures": {
                "example.org": {
                    "errcode": "M_RESOURCE_LIMIT_EXCEEDED",
                    "error": "Not yet ready to retry",
                }
            }
        });
        ruma_response_from_json(&response)
    }

    fn keys_claim_without_failure() -> KeyClaimResponse {
        let response = json!({
            "one_time_keys": {
                "@alice:example.org": {},
            },
            "failures": {},
        });
        ruma_response_from_json(&response)
    }

    async fn session_manager_test_helper() -> (SessionManager, IdentityManager) {
        let user_id = user_id();
        let device_id = device_id();

        let account = Account::with_device_id(user_id, device_id);
        let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, MemoryStore::new()));
        let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
        let verification = VerificationMachine::new(
            account.static_data().clone(),
            identity.clone(),
            store.clone(),
        );

        let store = Store::new(account.static_data().clone(), identity, store, verification);
        let device = DeviceData::from_account(&account);
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();

        let session_cache = GroupSessionCache::new(store.clone());
        let identity_manager = IdentityManager::new(store.clone());

        let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
        let key_request = GossipMachine::new(
            store.clone(),
            identity_manager.clone(),
            session_cache,
            users_for_key_claim.clone(),
        );

        (SessionManager::new(users_for_key_claim, key_request, store), identity_manager)
    }

    #[async_test]
    async fn test_session_creation() {
        let (manager, _identity_manager) = session_manager_test_helper().await;
        let mut bob = bob_account();

        let bob_device = DeviceData::from_account(&bob);

        manager.store.save_device_data(&[bob_device]).await.unwrap();

        let (txn_id, request) =
            manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();

        assert!(request.one_time_keys.contains_key(bob.user_id()));

        bob.generate_one_time_keys(1);
        let one_time = bob.signed_one_time_keys();
        assert!(!one_time.is_empty());
        bob.mark_keys_as_published();

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(bob.user_id().to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(bob.device_id().to_owned(), one_time);

        let response = KeyClaimResponse::new(one_time_keys);

        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
    }

    #[async_test]
    async fn test_session_creation_waits_for_keys_query() {
        let (manager, identity_manager) = session_manager_test_helper().await;

        // start a `/keys/query` request. At this point, we are only interested in our
        // own devices.
        let (key_query_txn_id, key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        info!("Initial key query: {:?}", key_query_request);

        // now bob turns up, and we start tracking his devices...
        let bob = bob_account();
        let bob_device = DeviceData::from_account(&bob);
        {
            let cache = manager.store.cache().await.unwrap();
            identity_manager
                .key_query_manager
                .synced(&cache)
                .await
                .unwrap()
                .update_tracked_users(iter::once(bob.user_id()))
                .await
                .unwrap();
        }

        // ... and start off an attempt to get the missing sessions. This should block
        // for now.
        let missing_sessions_task = {
            let manager = manager.clone();
            let bob_user_id = bob.user_id().to_owned();

            #[allow(unknown_lints, clippy::redundant_async_block)] // false positive
            spawn(
                async move { manager.get_missing_sessions(iter::once(bob_user_id.deref())).await },
            )
        };

        // the initial `/keys/query` completes, and we start another
        let response_json =
            json!({ "device_keys": { manager.store.static_account().user_id.to_owned(): {}}});
        let response = ruma_response_from_json(&response_json);
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        let (key_query_txn_id, key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        info!("Second key query: {:?}", key_query_request);

        // that second request completes with info on bob's device
        let response_json = json!({ "device_keys": { bob.user_id(): {
            bob_device.device_id(): bob_device.as_device_keys()
        }}});
        let response = ruma_response_from_json(&response_json);
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // the missing_sessions_task should now finally complete, with a claim
        // including bob's device
        let (_, keys_claim_request) = missing_sessions_task.await.unwrap().unwrap().unwrap();
        info!("Key claim request: {:?}", keys_claim_request.one_time_keys);
        let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap();
        assert!(bob_key_claims.contains_key(bob_device.device_id()));
    }

    #[async_test]
    async fn test_session_creation_does_not_wait_for_keys_query_on_failed_server() {
        let (manager, identity_manager) = session_manager_test_helper().await;

        // We start tracking Bob's devices.
        let other_user_id = OwnedUserId::try_from("@bob:example.com").unwrap();
        {
            let cache = manager.store.cache().await.unwrap();
            identity_manager
                .key_query_manager
                .synced(&cache)
                .await
                .unwrap()
                .update_tracked_users(iter::once(other_user_id.as_ref()))
                .await
                .unwrap();
        }

        // Do a `/keys/query` request, in which Bob's server is a failure.
        let (key_query_txn_id, _key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        let response = ruma_response_from_json(
            &json!({ "device_keys": {}, "failures": { other_user_id.server_name(): "unreachable" }}),
        );
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // Now, an attempt to get the missing sessions should now *not* block. We use a
        // timeout so that we can detect the call blocking.
        let result = tokio::time::timeout(
            Duration::from_millis(10),
            manager.get_missing_sessions(iter::once(other_user_id.as_ref())),
        )
        .await
        .expect("get_missing_sessions blocked rather than completing quickly")
        .expect("get_missing_sessions returned an error");

        assert!(result.is_none(), "get_missing_sessions returned Some(...)");
    }

    // This test doesn't run on macos because we're modifying the session
    // creation time so we can get around the UNWEDGING_INTERVAL.
    #[async_test]
    #[cfg(target_os = "linux")]
    async fn test_session_unwedging() {
        use ruma::{SecondsSinceUnixEpoch, time::SystemTime};

        let (manager, _identity_manager) = session_manager_test_helper().await;
        let mut bob = bob_account();

        let (_, mut session) = manager
            .store
            .with_transaction(|mut tr| async {
                let manager_account = tr.account().await.unwrap();
                let res = bob.create_session_for_test_helper(manager_account).await;
                Ok((tr, res))
            })
            .await
            .unwrap();

        let bob_device = DeviceData::from_account(&bob);
        let time = SystemTime::now() - Duration::from_secs(3601);
        session.creation_time = SecondsSinceUnixEpoch::from_system_time(time).unwrap();

        let devices = std::slice::from_ref(&bob_device);
        manager.store.save_device_data(devices).await.unwrap();
        manager.store.save_sessions(&[session]).await.unwrap();

        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());

        let curve_key = bob_device.curve25519_key().unwrap();

        assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
        assert!(!manager.is_device_wedged(&bob_device));
        manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
        assert!(manager.is_device_wedged(&bob_device));
        assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));

        let (txn_id, request) =
            manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();

        assert!(request.one_time_keys.contains_key(bob.user_id()));

        bob.generate_one_time_keys(1);
        let one_time = bob.signed_one_time_keys();
        assert!(!one_time.is_empty());
        bob.mark_keys_as_published();

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(bob.user_id().to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(bob.device_id().to_owned(), one_time);

        let response = KeyClaimResponse::new(one_time_keys);

        assert!(manager.outgoing_to_device_requests.read().is_empty());

        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        assert!(!manager.is_device_wedged(&bob_device));
        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
        assert!(!manager.outgoing_to_device_requests.read().is_empty())
    }

    #[async_test]
    async fn test_failure_handling() {
        let alice = user_id!("@alice:example.org");
        let alice_account = Account::with_device_id(alice, "DEVICEID".into());
        let alice_device = DeviceData::from_account(&alice_account);

        let (manager, _identity_manager) = session_manager_test_helper().await;

        manager.store.save_device_data(&[alice_device]).await.unwrap();

        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        manager.receive_keys_claim_response(&txn_id, &keys_claim_with_failure()).await.unwrap();
        assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());

        // expire the failure
        manager.failures.expire(&owned_server_name!("example.org"));

        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        manager.receive_keys_claim_response(&txn_id, &keys_claim_without_failure()).await.unwrap();
    }

    #[async_test]
    async fn test_failed_devices_handling() {
        // Alice is missing altogether
        test_invalid_claim_response(json!({
            "one_time_keys": {},
            "failures": {},
        }))
        .await;

        // Alice is present but with no devices
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {}
            },
            "failures": {},
        }))
        .await;

        // Alice's device is present but with no keys
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {
                    "DEVICEID": {}
                }
            },
            "failures": {},
        }))
        .await;

        // Alice's device is present with a bad signature
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {
                    "DEVICEID": {
                        "signed_curve25519:AAAAAA": {
                            "fallback": true,
                            "key": "1sra5GVo1ONz478aQybxSEeHTSo2xq0Z+Q3Yzqvp3A4",
                            "signatures": {
                                "@example:morpheus.localhost": {
                                    "ed25519:YAFLBLXAUK": "Zwk90fJhZWOYGNOgtOswZ6RSOGeTjTi/h2dMpyB0CR6EVtvTra0WJtp32ntifrxtwD710y2F3pe5Oyrm7jngCQ"
                                }
                            }
                        }
                    }
                }
            },
            "failures": {},
        })).await;
    }

    /// Helper for test_failed_devices_handling.
    ///
    /// Takes an invalid /keys/claim response for Alice's device DEVICEID and
    /// checks that it is handled correctly. (The device should be marked as
    /// 'failed'; and once that failure is expired and a valid one-time key is
    /// received, the device should no longer be marked as failed.)
    async fn test_invalid_claim_response(response_json: serde_json::Value) {
        let response = ruma_response_from_json(&response_json);

        let alice = user_id!("@alice:example.org");
        let mut alice_account = Account::with_device_id(alice, "DEVICEID".into());
        let alice_device = DeviceData::from_account(&alice_account);

        let (manager, _identity_manager) = session_manager_test_helper().await;
        manager.store.save_device_data(&[alice_device]).await.unwrap();

        // Since we don't have a session with Alice yet, the machine will try to claim
        // some keys for alice.
        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        // We receive a response with an invalid one-time key, this will mark Alice as
        // timed out.
        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
        // Since alice is timed out, we won't claim keys for her.
        assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());

        alice_account.generate_one_time_keys(1);
        let one_time = alice_account.signed_one_time_keys();
        assert!(!one_time.is_empty());

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(alice.to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(alice_account.device_id().to_owned(), one_time);

        // Now we expire Alice's timeout, and receive a valid one-time key for her.
        manager
            .failed_devices
            .write()
            .get(alice)
            .unwrap()
            .expire(&alice_account.device_id().to_owned());
        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        let response = KeyClaimResponse::new(one_time_keys);
        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        // Alice isn't timed out anymore.
        assert!(
            manager
                .failed_devices
                .read()
                .get(alice)
                .unwrap()
                .failure_count(alice_account.device_id())
                .is_none()
        );
    }
}