1use std::{
16 collections::{BTreeMap, BTreeSet},
17 sync::Arc,
18 time::Duration,
19};
20
21use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
22use ruma::{
23 DeviceId, OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedServerName,
24 OwnedTransactionId, OwnedUserId, SecondsSinceUnixEpoch, ServerName, TransactionId, UserId,
25 api::client::keys::claim_keys::v3::{
26 Request as KeysClaimRequest, Response as KeysClaimResponse,
27 },
28 assign,
29 events::dummy::ToDeviceDummyEventContent,
30};
31use tracing::{debug, error, info, instrument, warn};
32use vodozemac::Curve25519PublicKey;
33
34use crate::{
35 DeviceData,
36 error::OlmResult,
37 gossiping::GossipMachine,
38 store::{Result as StoreResult, Store, types::Changes},
39 types::{
40 EventEncryptionAlgorithm,
41 events::EventType,
42 requests::{OutgoingRequest, ToDeviceRequest},
43 },
44};
45
/// Keeps track of which devices need a new Olm session and turns
/// `/keys/claim` responses into sessions.
#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
    /// Store used to look up devices and sessions and to persist the
    /// sessions that get created.
    store: Store,

    /// The `/keys/claim` request most recently returned by
    /// `get_missing_sessions`, together with its transaction ID. It is taken
    /// out again when the matching response is checked for devices that did
    /// not return a one-time key.
    current_key_claim_request: Arc<StdRwLock<Option<(OwnedTransactionId, KeysClaimRequest)>>>,

    /// Devices that are added to every `/keys/claim` request built by
    /// `get_missing_sessions`. The map is handed in through `new`, so it can
    /// be shared with other components.
    users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    /// Devices that were marked as wedged; a device is removed again once a
    /// new session has been created for it.
    wedged_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    /// Used to list the devices of users, to retry key shares after a session
    /// was created, and to collect incoming key requests.
    key_request_machine: GossipMachine,
    /// To-device requests (the encrypted `m.dummy` events created after a
    /// device got unwedged) that still need to be sent, keyed by transaction
    /// ID.
    outgoing_to_device_requests: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>>,

    /// Servers that were listed in the `failures` of a `/keys/claim`
    /// response. Users on these servers are skipped while the server is in
    /// the cache.
    failures: FailuresCache<OwnedServerName>,

    /// Per-user cache of devices for which a claim returned no one-time key
    /// or for which creating the session failed. Such devices are left out of
    /// new claim requests while they are in the cache.
    failed_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, FailuresCache<OwnedDeviceId>>>>,
}
77
78impl SessionManager {
    /// Value sent as the `timeout` of every `/keys/claim` request built by
    /// this manager.
    const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
    /// The most recent session with a device has to be older than this (one
    /// hour) before the device is queued for unwedging.
    const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);
81
82 pub fn new(
83 users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
84 key_request_machine: GossipMachine,
85 store: Store,
86 ) -> Self {
87 Self {
88 store,
89 current_key_claim_request: Default::default(),
90 key_request_machine,
91 users_for_key_claim,
92 wedged_devices: Default::default(),
93 outgoing_to_device_requests: Default::default(),
94 failures: Default::default(),
95 failed_devices: Default::default(),
96 }
97 }
98
99 pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
101 self.outgoing_to_device_requests.write().remove(id);
102 }
103
104 pub async fn mark_device_as_wedged(
105 &self,
106 sender: &UserId,
107 curve_key: Curve25519PublicKey,
108 ) -> OlmResult<()> {
109 if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await?
110 && let Some(session) = device.get_most_recent_session().await?
111 {
112 info!(sender_key = ?curve_key, "Marking session to be unwedged");
113
114 let creation_time = Duration::from_secs(session.creation_time.get().into());
115 let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
116
117 let should_unwedge = now
118 .checked_sub(creation_time)
119 .map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
120 .unwrap_or(true);
121
122 if should_unwedge {
123 self.users_for_key_claim
124 .write()
125 .entry(device.user_id().to_owned())
126 .or_default()
127 .insert(device.device_id().into());
128 self.wedged_devices
129 .write()
130 .entry(device.user_id().to_owned())
131 .or_default()
132 .insert(device.device_id().into());
133 }
134 }
135
136 Ok(())
137 }
138
139 #[allow(dead_code)]
140 pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
141 self.wedged_devices
142 .read()
143 .get(device.user_id())
144 .is_some_and(|d| d.contains(device.device_id()))
145 }
146
147 async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
151 if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id))
152 && let Some(device) = self.store.get_device(user_id, device_id).await?
153 {
154 let (_, content) = device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;
155
156 let event_type = content.event_type().to_owned();
157
158 let request = ToDeviceRequest::new(
159 device.user_id(),
160 device.device_id().to_owned(),
161 &event_type,
162 content.cast(),
163 );
164
165 let request = OutgoingRequest {
166 request_id: request.txn_id.clone(),
167 request: Arc::new(request.into()),
168 };
169
170 self.outgoing_to_device_requests.write().insert(request.request_id.clone(), request);
171 }
172
173 Ok(())
174 }
175
176 pub async fn get_missing_sessions(
204 &self,
205 users: impl Iterator<Item = &UserId>,
206 ) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
207 let mut missing_session_devices_by_user: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
208 let mut timed_out_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
209
210 let unfailed_users = users.filter(|u| !self.failures.contains(u.server_name()));
211
212 let devices_by_user = Box::pin(
214 self.key_request_machine
215 .identity_manager()
216 .get_user_devices_for_encryption(unfailed_users),
217 )
218 .await?;
219
220 #[derive(Debug, Default)]
221 struct UserFailedDeviceInfo {
222 non_olm_devices: BTreeMap<OwnedDeviceId, Vec<EventEncryptionAlgorithm>>,
223 bad_key_devices: BTreeSet<OwnedDeviceId>,
224 }
225
226 let mut failed_devices_by_user: BTreeMap<_, UserFailedDeviceInfo> = BTreeMap::new();
227
228 for (user_id, user_devices) in devices_by_user {
229 for (device_id, device) in user_devices {
230 if !device.supports_olm() {
231 failed_devices_by_user
232 .entry(user_id.clone())
233 .or_default()
234 .non_olm_devices
235 .insert(device_id, Vec::from(device.algorithms()));
236 } else if let Some(sender_key) = device.curve25519_key() {
237 let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;
238
239 let is_missing = if let Some(sessions) = sessions {
240 sessions.lock().await.is_empty()
241 } else {
242 true
243 };
244
245 let is_timed_out = self.is_user_timed_out(&user_id, &device_id);
246
247 if is_missing && is_timed_out {
248 timed_out_devices_by_user
249 .entry(user_id.to_owned())
250 .or_default()
251 .insert(device_id);
252 } else if is_missing && !is_timed_out {
253 missing_session_devices_by_user
254 .entry(user_id.to_owned())
255 .or_default()
256 .insert(device_id, OneTimeKeyAlgorithm::SignedCurve25519);
257 }
258 } else {
259 failed_devices_by_user
260 .entry(user_id.clone())
261 .or_default()
262 .bad_key_devices
263 .insert(device_id);
264 }
265 }
266 }
267
268 for (user, device_ids) in self.users_for_key_claim.read().iter() {
271 missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
272 device_ids
273 .iter()
274 .map(|device_id| (device_id.clone(), OneTimeKeyAlgorithm::SignedCurve25519)),
275 );
276 }
277
278 if tracing::level_enabled!(tracing::Level::DEBUG) {
279 let missing_session_devices_by_user = missing_session_devices_by_user
281 .iter()
282 .map(|(user_id, devices)| (user_id, devices.keys().collect::<BTreeSet<_>>()))
283 .collect::<BTreeMap<_, _>>();
284 debug!(
285 ?missing_session_devices_by_user,
286 ?timed_out_devices_by_user,
287 "Collected user/device pairs that are missing an Olm session"
288 );
289 }
290
291 if !failed_devices_by_user.is_empty() {
292 warn!(
293 ?failed_devices_by_user,
294 "Can't establish an Olm session with some devices due to missing Olm support or bad keys",
295 );
296 }
297
298 let result = if missing_session_devices_by_user.is_empty() {
299 None
300 } else {
301 Some((
302 TransactionId::new(),
303 assign!(KeysClaimRequest::new(missing_session_devices_by_user), {
304 timeout: Some(Self::KEY_CLAIM_TIMEOUT),
305 }),
306 ))
307 };
308
309 *(self.current_key_claim_request.write()) = result.clone();
312 Ok(result)
313 }
314
315 fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
316 self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
317 }
318
319 fn handle_otk_exhaustion_failure(
336 &self,
337 request_id: &TransactionId,
338 failed_servers: &BTreeSet<OwnedServerName>,
339 one_time_keys: &BTreeMap<
340 &OwnedUserId,
341 BTreeMap<&OwnedDeviceId, BTreeSet<&OwnedOneTimeKeyId>>,
342 >,
343 ) {
344 let request = {
346 let mut guard = self.current_key_claim_request.write();
347 let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());
348
349 if Some(request_id) == expected_request_id {
350 guard.take().map(|(_, request)| request)
353 } else {
354 warn!(
355 ?request_id,
356 ?expected_request_id,
357 "Received a `/keys/claim` response for the wrong request"
358 );
359 None
360 }
361 };
362
363 if let Some(request) = request {
366 let devices_in_response: BTreeSet<_> = one_time_keys
367 .iter()
368 .flat_map(|(user_id, device_key_map)| {
369 device_key_map
370 .keys()
371 .map(|device_id| (*user_id, *device_id))
372 .collect::<BTreeSet<_>>()
373 })
374 .collect();
375
376 let devices_in_request: BTreeSet<(_, _)> = request
377 .one_time_keys
378 .iter()
379 .flat_map(|(user_id, device_key_map)| {
380 device_key_map
381 .keys()
382 .map(|device_id| (user_id, device_id))
383 .collect::<BTreeSet<_>>()
384 })
385 .collect();
386
387 let missing_devices: BTreeSet<_> = devices_in_request
388 .difference(&devices_in_response)
389 .filter(|(user_id, _)| {
390 !failed_servers.contains(user_id.server_name())
393 })
394 .collect();
395
396 if !missing_devices.is_empty() {
397 let mut missing_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
398
399 for &(user_id, device_id) in missing_devices {
400 missing_devices_by_user.entry(user_id).or_default().insert(device_id.clone());
401 }
402
403 warn!(
404 ?missing_devices_by_user,
405 "Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
406 );
407
408 let mut failed_devices_lock = self.failed_devices.write();
409
410 for (user_id, device_set) in missing_devices_by_user {
411 failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
412 }
413 }
414 }
415 }
416
417 #[instrument(skip(self, response))]
427 pub async fn receive_keys_claim_response(
428 &self,
429 request_id: &TransactionId,
430 response: &KeysClaimResponse,
431 ) -> OlmResult<()> {
432 let one_time_keys: BTreeMap<_, BTreeMap<_, BTreeSet<_>>> = response
434 .one_time_keys
435 .iter()
436 .map(|(user_id, device_map)| {
437 (
438 user_id,
439 device_map
440 .iter()
441 .map(|(device_id, key_map)| {
442 (device_id, key_map.keys().collect::<BTreeSet<_>>())
443 })
444 .collect::<BTreeMap<_, _>>(),
445 )
446 })
447 .collect();
448
449 debug!(?request_id, ?one_time_keys, failures = ?response.failures, "Received a `/keys/claim` response");
450
451 let failed_servers: BTreeSet<_> = response
453 .failures
454 .keys()
455 .filter_map(|s| ServerName::parse(s).ok())
456 .filter(|s| s != self.store.static_account().user_id.server_name())
457 .collect();
458 let successful_servers = response.one_time_keys.keys().map(|u| u.server_name());
459
460 self.handle_otk_exhaustion_failure(request_id, &failed_servers, &one_time_keys);
463 self.failures.extend(failed_servers);
465 self.failures.remove(successful_servers);
467
468 self.create_sessions(response).await
470 }
471
472 pub(crate) async fn create_sessions(&self, response: &KeysClaimResponse) -> OlmResult<()> {
479 struct SessionInfo {
480 session_id: String,
481 algorithm: EventEncryptionAlgorithm,
482 fallback_key_used: bool,
483 }
484
485 #[cfg(not(tarpaulin_include))]
486 impl std::fmt::Debug for SessionInfo {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 write!(
489 f,
490 "session_id: {}, algorithm: {}, fallback_key_used: {}",
491 self.session_id, self.algorithm, self.fallback_key_used
492 )
493 }
494 }
495
496 let mut changes = Changes::default();
497 let mut new_sessions: BTreeMap<&UserId, BTreeMap<&DeviceId, SessionInfo>> = BTreeMap::new();
498 let mut store_transaction = self.store.transaction().await;
499
500 for (user_id, user_devices) in &response.one_time_keys {
501 for (device_id, key_map) in user_devices {
502 let device = match self.store.get_device_data(user_id, device_id).await {
503 Ok(Some(d)) => d,
504 Ok(None) => {
505 warn!(
506 ?user_id,
507 ?device_id,
508 "Tried to create an Olm session but the device is unknown",
509 );
510 continue;
511 }
512 Err(e) => {
513 warn!(
514 ?user_id, ?device_id, error = ?e,
515 "Tried to create an Olm session, but we can't \
516 fetch the device from the store",
517 );
518 continue;
519 }
520 };
521
522 let account = store_transaction.account().await?;
523 let device_keys = self.store.get_own_device().await?.as_device_keys().clone();
524 let session = match account.create_outbound_session(&device, key_map, device_keys) {
525 Ok(s) => s,
526 Err(e) => {
527 warn!(
528 ?user_id, ?device_id, error = ?e,
529 "Error creating Olm session"
530 );
531
532 self.failed_devices
533 .write()
534 .entry(user_id.to_owned())
535 .or_default()
536 .insert(device_id.to_owned());
537
538 continue;
539 }
540 };
541
542 self.key_request_machine.retry_keyshare(user_id, device_id);
543
544 if let Err(e) = self.check_if_unwedged(user_id, device_id).await {
545 error!(?user_id, ?device_id, "Error while treating an unwedged device: {e:?}");
546 }
547
548 let session_info = SessionInfo {
549 session_id: session.session_id().to_owned(),
550 algorithm: session.algorithm().await,
551 fallback_key_used: session.created_using_fallback_key,
552 };
553
554 changes.sessions.push(session);
555 new_sessions.entry(user_id).or_default().insert(device_id, session_info);
556 }
557 }
558
559 store_transaction.commit().await?;
560 self.store.save_changes(changes).await?;
561 info!(sessions = ?new_sessions, "Established new Olm sessions");
562
563 for (user, device_map) in new_sessions {
564 if let Some(user_cache) = self.failed_devices.read().get(user) {
565 user_cache.remove(device_map.into_keys());
566 }
567 }
568
569 let store_cache = self.store.cache().await?;
570 match self.key_request_machine.collect_incoming_key_requests(&store_cache).await {
571 Ok(sessions) => {
572 let changes = Changes { sessions, ..Default::default() };
573 self.store.save_changes(changes).await?
574 }
575 Err(e) => {
578 warn!(error = ?e, "Error while trying to collect the incoming secret requests")
579 }
580 }
581
582 Ok(())
583 }
584}
585
586#[cfg(test)]
587mod tests {
588 use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};
589
590 use matrix_sdk_common::{executor::spawn, locks::RwLock as StdRwLock};
591 use matrix_sdk_test::{async_test, ruma_response_from_json};
592 use ruma::{
593 DeviceId, OwnedUserId, UserId,
594 api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
595 owned_server_name, user_id,
596 };
597 use serde_json::json;
598 use tokio::sync::Mutex;
599 use tracing::info;
600
601 use super::SessionManager;
602 use crate::{
603 gossiping::GossipMachine,
604 identities::{DeviceData, IdentityManager},
605 olm::{Account, PrivateCrossSigningIdentity},
606 session_manager::GroupSessionCache,
607 store::{
608 CryptoStoreWrapper, MemoryStore, Store,
609 types::{Changes, DeviceChanges, PendingChanges},
610 },
611 verification::VerificationMachine,
612 };
613
    /// User ID of the account that owns the session manager under test.
    fn user_id() -> &'static UserId {
        user_id!("@example:localhost")
    }
617
    /// Device ID of the account that owns the session manager under test.
    fn device_id() -> &'static DeviceId {
        device_id!("DEVICEID")
    }
621
622 fn bob_account() -> Account {
623 Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"))
624 }
625
626 fn keys_claim_with_failure() -> KeyClaimResponse {
627 let response = json!({
628 "one_time_keys": {},
629 "failures": {
630 "example.org": {
631 "errcode": "M_RESOURCE_LIMIT_EXCEEDED",
632 "error": "Not yet ready to retry",
633 }
634 }
635 });
636 ruma_response_from_json(&response)
637 }
638
639 fn keys_claim_without_failure() -> KeyClaimResponse {
640 let response = json!({
641 "one_time_keys": {
642 "@alice:example.org": {},
643 },
644 "failures": {},
645 });
646 ruma_response_from_json(&response)
647 }
648
649 async fn session_manager_test_helper() -> (SessionManager, IdentityManager) {
650 let user_id = user_id();
651 let device_id = device_id();
652
653 let account = Account::with_device_id(user_id, device_id);
654 let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, MemoryStore::new()));
655 let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
656 let verification = VerificationMachine::new(
657 account.static_data().clone(),
658 identity.clone(),
659 store.clone(),
660 );
661
662 let store = Store::new(account.static_data().clone(), identity, store, verification);
663 let device = DeviceData::from_account(&account);
664 store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
665 store
666 .save_changes(Changes {
667 devices: DeviceChanges { new: vec![device], ..Default::default() },
668 ..Default::default()
669 })
670 .await
671 .unwrap();
672
673 let session_cache = GroupSessionCache::new(store.clone());
674 let identity_manager = IdentityManager::new(store.clone());
675
676 let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
677 let key_request = GossipMachine::new(
678 store.clone(),
679 identity_manager.clone(),
680 session_cache,
681 users_for_key_claim.clone(),
682 );
683
684 (SessionManager::new(users_for_key_claim, key_request, store), identity_manager)
685 }
686
687 #[async_test]
688 async fn test_session_creation() {
689 let (manager, _identity_manager) = session_manager_test_helper().await;
690 let mut bob = bob_account();
691
692 let bob_device = DeviceData::from_account(&bob);
693
694 manager.store.save_device_data(&[bob_device]).await.unwrap();
695
696 let (txn_id, request) =
697 manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();
698
699 assert!(request.one_time_keys.contains_key(bob.user_id()));
700
701 bob.generate_one_time_keys(1);
702 let one_time = bob.signed_one_time_keys();
703 assert!(!one_time.is_empty());
704 bob.mark_keys_as_published();
705
706 let mut one_time_keys = BTreeMap::new();
707 one_time_keys
708 .entry(bob.user_id().to_owned())
709 .or_insert_with(BTreeMap::new)
710 .insert(bob.device_id().to_owned(), one_time);
711
712 let response = KeyClaimResponse::new(one_time_keys);
713
714 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
715
716 assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
717 }
718
719 #[async_test]
720 async fn test_session_creation_waits_for_keys_query() {
721 let (manager, identity_manager) = session_manager_test_helper().await;
722
723 let (key_query_txn_id, key_query_request) =
726 identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
727 info!("Initial key query: {:?}", key_query_request);
728
729 let bob = bob_account();
731 let bob_device = DeviceData::from_account(&bob);
732 {
733 let cache = manager.store.cache().await.unwrap();
734 identity_manager
735 .key_query_manager
736 .synced(&cache)
737 .await
738 .unwrap()
739 .update_tracked_users(iter::once(bob.user_id()))
740 .await
741 .unwrap();
742 }
743
744 let missing_sessions_task = {
747 let manager = manager.clone();
748 let bob_user_id = bob.user_id().to_owned();
749
750 #[allow(unknown_lints, clippy::redundant_async_block)] spawn(
752 async move { manager.get_missing_sessions(iter::once(bob_user_id.deref())).await },
753 )
754 };
755
756 let response_json =
758 json!({ "device_keys": { manager.store.static_account().user_id.to_owned(): {}}});
759 let response = ruma_response_from_json(&response_json);
760 identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
761
762 let (key_query_txn_id, key_query_request) =
763 identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
764 info!("Second key query: {:?}", key_query_request);
765
766 let response_json = json!({ "device_keys": { bob.user_id(): {
768 bob_device.device_id(): bob_device.as_device_keys()
769 }}});
770 let response = ruma_response_from_json(&response_json);
771 identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
772
773 let (_, keys_claim_request) = missing_sessions_task.await.unwrap().unwrap().unwrap();
776 info!("Key claim request: {:?}", keys_claim_request.one_time_keys);
777 let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap();
778 assert!(bob_key_claims.contains_key(bob_device.device_id()));
779 }
780
781 #[async_test]
782 async fn test_session_creation_does_not_wait_for_keys_query_on_failed_server() {
783 let (manager, identity_manager) = session_manager_test_helper().await;
784
785 let other_user_id = OwnedUserId::try_from("@bob:example.com").unwrap();
787 {
788 let cache = manager.store.cache().await.unwrap();
789 identity_manager
790 .key_query_manager
791 .synced(&cache)
792 .await
793 .unwrap()
794 .update_tracked_users(iter::once(other_user_id.as_ref()))
795 .await
796 .unwrap();
797 }
798
799 let (key_query_txn_id, _key_query_request) =
801 identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
802 let response = ruma_response_from_json(
803 &json!({ "device_keys": {}, "failures": { other_user_id.server_name(): "unreachable" }}),
804 );
805 identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
806
807 let result = tokio::time::timeout(
810 Duration::from_millis(10),
811 manager.get_missing_sessions(iter::once(other_user_id.as_ref())),
812 )
813 .await
814 .expect("get_missing_sessions blocked rather than completing quickly")
815 .expect("get_missing_sessions returned an error");
816
817 assert!(result.is_none(), "get_missing_sessions returned Some(...)");
818 }
819
820 #[async_test]
823 #[cfg(target_os = "linux")]
824 async fn test_session_unwedging() {
825 use ruma::{SecondsSinceUnixEpoch, time::SystemTime};
826
827 let (manager, _identity_manager) = session_manager_test_helper().await;
828 let mut bob = bob_account();
829
830 let (_, mut session) = manager
831 .store
832 .with_transaction(|mut tr| async {
833 let manager_account = tr.account().await.unwrap();
834 let res = bob.create_session_for_test_helper(manager_account).await;
835 Ok((tr, res))
836 })
837 .await
838 .unwrap();
839
840 let bob_device = DeviceData::from_account(&bob);
841 let time = SystemTime::now() - Duration::from_secs(3601);
842 session.creation_time = SecondsSinceUnixEpoch::from_system_time(time).unwrap();
843
844 let devices = std::slice::from_ref(&bob_device);
845 manager.store.save_device_data(devices).await.unwrap();
846 manager.store.save_sessions(&[session]).await.unwrap();
847
848 assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
849
850 let curve_key = bob_device.curve25519_key().unwrap();
851
852 assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
853 assert!(!manager.is_device_wedged(&bob_device));
854 manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
855 assert!(manager.is_device_wedged(&bob_device));
856 assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));
857
858 let (txn_id, request) =
859 manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();
860
861 assert!(request.one_time_keys.contains_key(bob.user_id()));
862
863 bob.generate_one_time_keys(1);
864 let one_time = bob.signed_one_time_keys();
865 assert!(!one_time.is_empty());
866 bob.mark_keys_as_published();
867
868 let mut one_time_keys = BTreeMap::new();
869 one_time_keys
870 .entry(bob.user_id().to_owned())
871 .or_insert_with(BTreeMap::new)
872 .insert(bob.device_id().to_owned(), one_time);
873
874 let response = KeyClaimResponse::new(one_time_keys);
875
876 assert!(manager.outgoing_to_device_requests.read().is_empty());
877
878 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
879
880 assert!(!manager.is_device_wedged(&bob_device));
881 assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
882 assert!(!manager.outgoing_to_device_requests.read().is_empty())
883 }
884
885 #[async_test]
886 async fn test_failure_handling() {
887 let alice = user_id!("@alice:example.org");
888 let alice_account = Account::with_device_id(alice, "DEVICEID".into());
889 let alice_device = DeviceData::from_account(&alice_account);
890
891 let (manager, _identity_manager) = session_manager_test_helper().await;
892
893 manager.store.save_device_data(&[alice_device]).await.unwrap();
894
895 let (txn_id, users_for_key_claim) =
896 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
897 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
898
899 manager.receive_keys_claim_response(&txn_id, &keys_claim_with_failure()).await.unwrap();
900 assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());
901
902 manager.failures.expire(&owned_server_name!("example.org"));
904
905 let (txn_id, users_for_key_claim) =
906 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
907 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
908
909 manager.receive_keys_claim_response(&txn_id, &keys_claim_without_failure()).await.unwrap();
910 }
911
912 #[async_test]
913 async fn test_failed_devices_handling() {
914 test_invalid_claim_response(json!({
916 "one_time_keys": {},
917 "failures": {},
918 }))
919 .await;
920
921 test_invalid_claim_response(json!({
923 "one_time_keys": {
924 "@alice:example.org": {}
925 },
926 "failures": {},
927 }))
928 .await;
929
930 test_invalid_claim_response(json!({
932 "one_time_keys": {
933 "@alice:example.org": {
934 "DEVICEID": {}
935 }
936 },
937 "failures": {},
938 }))
939 .await;
940
941 test_invalid_claim_response(json!({
943 "one_time_keys": {
944 "@alice:example.org": {
945 "DEVICEID": {
946 "signed_curve25519:AAAAAA": {
947 "fallback": true,
948 "key": "1sra5GVo1ONz478aQybxSEeHTSo2xq0Z+Q3Yzqvp3A4",
949 "signatures": {
950 "@example:morpheus.localhost": {
951 "ed25519:YAFLBLXAUK": "Zwk90fJhZWOYGNOgtOswZ6RSOGeTjTi/h2dMpyB0CR6EVtvTra0WJtp32ntifrxtwD710y2F3pe5Oyrm7jngCQ"
952 }
953 }
954 }
955 }
956 }
957 },
958 "failures": {},
959 })).await;
960 }
961
962 async fn test_invalid_claim_response(response_json: serde_json::Value) {
968 let response = ruma_response_from_json(&response_json);
969
970 let alice = user_id!("@alice:example.org");
971 let mut alice_account = Account::with_device_id(alice, "DEVICEID".into());
972 let alice_device = DeviceData::from_account(&alice_account);
973
974 let (manager, _identity_manager) = session_manager_test_helper().await;
975 manager.store.save_device_data(&[alice_device]).await.unwrap();
976
977 let (txn_id, users_for_key_claim) =
980 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
981 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
982
983 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
986 assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());
988
989 alice_account.generate_one_time_keys(1);
990 let one_time = alice_account.signed_one_time_keys();
991 assert!(!one_time.is_empty());
992
993 let mut one_time_keys = BTreeMap::new();
994 one_time_keys
995 .entry(alice.to_owned())
996 .or_insert_with(BTreeMap::new)
997 .insert(alice_account.device_id().to_owned(), one_time);
998
999 manager
1001 .failed_devices
1002 .write()
1003 .get(alice)
1004 .unwrap()
1005 .expire(&alice_account.device_id().to_owned());
1006 let (txn_id, users_for_key_claim) =
1007 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
1008 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
1009
1010 let response = KeyClaimResponse::new(one_time_keys);
1011 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
1012
1013 assert!(
1015 manager
1016 .failed_devices
1017 .read()
1018 .get(alice)
1019 .unwrap()
1020 .failure_count(alice_account.device_id())
1021 .is_none()
1022 );
1023 }
1024}