1use std::future::IntoFuture;
16
17use eyeball::SharedObservable;
18use futures_core::Stream;
19use matrix_sdk_base::{
20 SessionMeta, boxed_into_future,
21 crypto::types::qr_login::{QrCodeData, QrCodeMode},
22 store::RoomLoadSettings,
23};
24use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
25use ruma::{
26 OwnedDeviceId,
27 api::client::discovery::get_authorization_server_metadata::v1::AuthorizationServerMetadata,
28};
29use tracing::trace;
30use vodozemac::{Curve25519PublicKey, ecies::CheckCode};
31
32use super::{
33 DeviceAuthorizationOAuthError, QRCodeLoginError, SecureChannelError,
34 messages::{LoginFailureReason, QrAuthMessage},
35 secure_channel::{EstablishedSecureChannel, SecureChannel},
36};
37use crate::{
38 Client,
39 authentication::oauth::{
40 ClientRegistrationData, OAuth, OAuthError,
41 qrcode::{CheckCodeSender, GeneratedQrProgress, LoginProtocolType},
42 },
43};
44
45async fn send_unexpected_message_error(
46 channel: &mut EstablishedSecureChannel,
47) -> Result<(), SecureChannelError> {
48 channel
49 .send_json(QrAuthMessage::LoginFailure {
50 reason: LoginFailureReason::UnexpectedMessageReceived,
51 homeserver: None,
52 })
53 .await
54}
55
56async fn finish_login<Q>(
57 client: &Client,
58 mut channel: EstablishedSecureChannel,
59 registration_data: Option<&ClientRegistrationData>,
60 state: SharedObservable<LoginProgress<Q>>,
61) -> Result<(), QRCodeLoginError> {
62 let oauth = client.oauth();
63
64 trace!("Registering the client with the OAuth 2.0 authorization server.");
66 let server_metadata = register_client(&oauth, registration_data).await?;
67
68 let account = vodozemac::olm::Account::new();
71 let public_key = account.identity_keys().curve25519;
72 let device_id = public_key;
73
74 trace!("Requesting device authorization.");
77 let auth_grant_response =
78 request_device_authorization(&oauth, &server_metadata, device_id).await?;
79
80 trace!("Letting the existing device know about the device authorization grant.");
83 let message =
84 QrAuthMessage::authorization_grant_login_protocol((&auth_grant_response).into(), device_id);
85 channel.send_json(&message).await?;
86
87 match channel.receive_json().await? {
89 QrAuthMessage::LoginProtocolAccepted => (),
90 QrAuthMessage::LoginFailure { reason, homeserver } => {
91 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
92 }
93 message => {
94 send_unexpected_message_error(&mut channel).await?;
95
96 return Err(QRCodeLoginError::UnexpectedMessage {
97 expected: "m.login.protocol_accepted",
98 received: message,
99 });
100 }
101 }
102
103 let user_code = auth_grant_response.user_code();
107 state.set(LoginProgress::WaitingForToken { user_code: user_code.secret().to_owned() });
108
109 trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
112 if let Err(e) = wait_for_tokens(&oauth, &server_metadata, &auth_grant_response).await {
113 if let Some(e) = e.as_request_token_error() {
116 match e {
117 DeviceCodeErrorResponseType::AccessDenied => {
118 channel.send_json(QrAuthMessage::LoginDeclined).await?;
119 }
120 DeviceCodeErrorResponseType::ExpiredToken => {
121 channel
122 .send_json(QrAuthMessage::LoginFailure {
123 reason: LoginFailureReason::AuthorizationExpired,
124 homeserver: None,
125 })
126 .await?;
127 }
128 _ => (),
129 }
130 }
131
132 return Err(e.into());
133 }
134
135 trace!("Discovering our own user id.");
141 let whoami_response = client.whoami().await.map_err(QRCodeLoginError::UserIdDiscovery)?;
142 client
143 .base_client()
144 .activate(
145 SessionMeta {
146 user_id: whoami_response.user_id,
147 device_id: OwnedDeviceId::from(device_id.to_base64()),
148 },
149 RoomLoadSettings::default(),
150 Some(account),
151 )
152 .await
153 .map_err(|error| QRCodeLoginError::SessionTokens(error.into()))?;
154
155 client.oauth().enable_cross_process_lock().await?;
156
157 state.set(LoginProgress::SyncingSecrets);
158
159 trace!("Telling the existing device that we successfully logged in.");
161 let message = QrAuthMessage::LoginSuccess;
162 channel.send_json(&message).await?;
163
164 trace!("Waiting for the secrets bundle.");
167 let bundle = match channel.receive_json().await? {
168 QrAuthMessage::LoginSecrets(bundle) => bundle,
169 QrAuthMessage::LoginFailure { reason, homeserver } => {
170 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
171 }
172 message => {
173 send_unexpected_message_error(&mut channel).await?;
174
175 return Err(QRCodeLoginError::UnexpectedMessage {
176 expected: "m.login.secrets",
177 received: message,
178 });
179 }
180 };
181
182 client.encryption().import_secrets_bundle(&bundle).await?;
185
186 client
189 .encryption()
190 .ensure_device_keys_upload()
191 .await
192 .map_err(QRCodeLoginError::DeviceKeyUpload)?;
193
194 client.encryption().spawn_initialization_task(None).await;
199 client.encryption().wait_for_e2ee_initialization_tasks().await;
200
201 trace!("successfully logged in and enabled E2EE.");
202
203 state.set(LoginProgress::Done);
205
206 Ok(())
208}
209
210async fn register_client(
214 oauth: &OAuth,
215 registration_data: Option<&ClientRegistrationData>,
216) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
217 let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
218 oauth.use_registration_data(&server_metadata, registration_data).await?;
219
220 Ok(server_metadata)
221}
222
223async fn request_device_authorization(
224 oauth: &OAuth,
225 server_metadata: &AuthorizationServerMetadata,
226 device_id: Curve25519PublicKey,
227) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
228 let response = oauth
229 .request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
230 .await?;
231 Ok(response)
232}
233
234async fn wait_for_tokens(
235 oauth: &OAuth,
236 server_metadata: &AuthorizationServerMetadata,
237 auth_response: &StandardDeviceAuthorizationResponse,
238) -> Result<(), DeviceAuthorizationOAuthError> {
239 oauth.exchange_device_code(server_metadata, auth_response).await?;
240 Ok(())
241}
242
/// The progress of a QR code login, as published by the
/// `subscribe_to_progress()` methods of the login futures in this module.
///
/// `Q` carries the data specific to the secure channel phase of a particular
/// login flow, e.g. [`QrProgress`] or [`GeneratedQrProgress`].
#[derive(Clone, Debug, Default)]
pub enum LoginProgress<Q> {
    /// The login has not progressed yet. This is the initial state.
    #[default]
    Starting,
    /// The secure channel with the other device is being established.
    ///
    /// The payload carries the flow-specific data needed during this phase.
    EstablishingSecureChannel(Q),
    /// The other device accepted the device authorization grant, and we are
    /// now waiting for the OAuth 2.0 authorization server to issue tokens.
    WaitingForToken {
        /// The user code taken from the device authorization response.
        user_code: String,
    },
    /// The session is active, and we are about to receive the secrets bundle
    /// from the other device.
    SyncingSecrets,
    /// The login, including the E2EE initialization, finished successfully.
    Done,
}
266
/// Secure channel progress data for [`LoginWithQrCode`], i.e. the flow where
/// this device scans a QR code shown by the other device.
#[derive(Clone, Debug)]
pub struct QrProgress {
    /// The check code of the secure channel that was just established.
    ///
    /// The other device needs it, in its digit form (`CheckCode::to_digit`),
    /// to confirm the channel on its side.
    pub check_code: CheckCode,
}
278
/// Named future for logging in by scanning a QR code shown by an existing
/// device.
///
/// Awaiting it runs the whole login. Progress can be observed through
/// [`LoginWithQrCode::subscribe_to_progress`].
#[derive(Debug)]
pub struct LoginWithQrCode<'a> {
    /// The client that is being logged in.
    client: &'a Client,
    /// Optional data used when registering the client with the OAuth 2.0
    /// authorization server.
    registration_data: Option<&'a ClientRegistrationData>,
    /// The data of the scanned QR code.
    qr_code_data: &'a QrCodeData,
    /// The observable login progress.
    state: SharedObservable<LoginProgress<QrProgress>>,
}
288
289impl LoginWithQrCode<'_> {
290 pub fn subscribe_to_progress(&self) -> impl Stream<Item = LoginProgress<QrProgress>> + use<> {
296 self.state.subscribe()
297 }
298}
299
300impl<'a> IntoFuture for LoginWithQrCode<'a> {
301 type Output = Result<(), QRCodeLoginError>;
302 boxed_into_future!(extra_bounds: 'a);
303
304 fn into_future(self) -> Self::IntoFuture {
305 Box::pin(async move {
306 let channel = self.establish_secure_channel().await?;
315
316 trace!("Established the secure channel.");
317
318 let check_code = channel.check_code().to_owned();
322 self.state.set(LoginProgress::EstablishingSecureChannel(QrProgress { check_code }));
323
324 finish_login(self.client, channel, self.registration_data, self.state).await
331 })
332 }
333}
334
335impl<'a> LoginWithQrCode<'a> {
336 pub(crate) fn new(
337 client: &'a Client,
338 qr_code_data: &'a QrCodeData,
339 registration_data: Option<&'a ClientRegistrationData>,
340 ) -> LoginWithQrCode<'a> {
341 LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
342 }
343
344 async fn establish_secure_channel(
345 &self,
346 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
347 let http_client = self.client.inner.http_client.inner.clone();
348
349 let channel = EstablishedSecureChannel::from_qr_code(
350 http_client,
351 self.qr_code_data,
352 QrCodeMode::Login,
353 )
354 .await?;
355
356 Ok(channel)
357 }
358}
359
/// Named future for logging in by generating a QR code that an existing
/// device scans.
///
/// Awaiting it runs the whole login. Progress, including the generated QR code
/// data, can be observed through
/// [`LoginWithGeneratedQrCode::subscribe_to_progress`].
#[derive(Debug)]
pub struct LoginWithGeneratedQrCode<'a> {
    /// The client that is being logged in.
    client: &'a Client,
    /// Optional data used when registering the client with the OAuth 2.0
    /// authorization server.
    registration_data: Option<&'a ClientRegistrationData>,
    /// The observable login progress.
    state: SharedObservable<LoginProgress<GeneratedQrProgress>>,
}
368
369impl LoginWithGeneratedQrCode<'_> {
370 pub fn subscribe_to_progress(
375 &self,
376 ) -> impl Stream<Item = LoginProgress<GeneratedQrProgress>> + use<> {
377 self.state.subscribe()
378 }
379}
380
381impl<'a> IntoFuture for LoginWithGeneratedQrCode<'a> {
382 type Output = Result<(), QRCodeLoginError>;
383 boxed_into_future!(extra_bounds: 'a);
384
385 fn into_future(self) -> Self::IntoFuture {
386 Box::pin(async move {
387 let mut channel = self.establish_secure_channel().await?;
390
391 trace!("Established the secure channel.");
392
393 let message = channel.receive_json().await?;
397
398 let homeserver = match message {
401 QrAuthMessage::LoginProtocols { protocols, homeserver } => {
402 if !protocols.contains(&LoginProtocolType::DeviceAuthorizationGrant) {
403 channel
404 .send_json(QrAuthMessage::LoginFailure {
405 reason: LoginFailureReason::UnsupportedProtocol,
406 homeserver: None,
407 })
408 .await?;
409
410 return Err(QRCodeLoginError::LoginFailure {
411 reason: LoginFailureReason::UnsupportedProtocol,
412 homeserver: None,
413 });
414 }
415
416 homeserver
417 }
418 _ => {
419 send_unexpected_message_error(&mut channel).await?;
420
421 return Err(QRCodeLoginError::UnexpectedMessage {
422 expected: "m.login.protocols",
423 received: message,
424 });
425 }
426 };
427
428 if self.client.homeserver() != homeserver {
431 self.client
432 .switch_homeserver_and_re_resolve_well_known(homeserver)
433 .await
434 .map_err(QRCodeLoginError::ServerReset)?;
435 }
436
437 finish_login(self.client, channel, self.registration_data, self.state).await
440 })
441 }
442}
443
444impl<'a> LoginWithGeneratedQrCode<'a> {
445 pub(crate) fn new(
446 client: &'a Client,
447 registration_data: Option<&'a ClientRegistrationData>,
448 ) -> Self {
449 Self { client, registration_data, state: Default::default() }
450 }
451
452 async fn establish_secure_channel(
453 &self,
454 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
455 let http_client = self.client.inner.http_client.clone();
456
457 let secure_channel = SecureChannel::login(http_client, &self.client.homeserver()).await?;
461
462 let qr_code_data = secure_channel.qr_code_data().clone();
466 trace!("Generated QR code.");
467 self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(
468 qr_code_data,
469 )));
470
471 let channel = secure_channel.connect().await?;
475
476 trace!("Waiting for checkcode.");
481 let (tx, rx) = tokio::sync::oneshot::channel();
482 self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
483 CheckCodeSender::new(tx),
484 )));
485
486 let check_code = rx.await.map_err(|_| SecureChannelError::CannotReceiveCheckCode)?;
490 trace!("Received check code.");
491 channel.confirm(check_code)
492 }
493}
494
495#[cfg(all(test, not(target_family = "wasm")))]
496mod test {
497 use assert_matches2::{assert_let, assert_matches};
498 use futures_util::{StreamExt, join};
499 use matrix_sdk_base::crypto::types::{SecretsBundle, qr_login::QrCodeModeData};
500 use matrix_sdk_common::executor::spawn;
501 use matrix_sdk_test::async_test;
502 use serde_json::json;
503
504 use super::*;
505 use crate::{
506 authentication::oauth::qrcode::{
507 messages::LoginProtocolType,
508 secure_channel::{SecureChannel, test::MockedRendezvousServer},
509 },
510 config::RequestConfig,
511 http_client::HttpClient,
512 test_utils::{client::oauth::mock_client_metadata, mocks::MatrixMockServer},
513 };
514
    /// How the simulated existing device ("Alice") behaves while granting the
    /// login in `grant_login` / `grant_login_with_generated_qr`.
    enum AliceBehaviour {
        /// Accept the login protocol and send the secrets bundle.
        HappyPath,
        /// Answer the login protocol with an `UnsupportedProtocol` login failure.
        DeclinedProtocol,
        /// Answer the login protocol with `QrAuthMessage::LoginDeclined`.
        UnexpectedMessage,
        /// Answer `LoginSuccess` with `QrAuthMessage::LoginDeclined` instead of the secrets.
        UnexpectedMessageInsteadOfSecrets,
        /// Answer `LoginSuccess` with a `DeviceNotFound` login failure.
        RefuseSecrets,
    }
522
    /// How the mocked OAuth 2.0 token endpoint responds in the failure tests.
    enum TokenResponse {
        /// Mounted with `mock_token().ok()`.
        Ok,
        /// Mounted with `mock_token().access_denied()`.
        AccessDenied,
        /// Mounted with `mock_token().expired_token()`.
        ExpiredToken,
    }
529
530 fn secrets_bundle() -> SecretsBundle {
531 let json = json!({
532 "cross_signing": {
533 "master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
534 "self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
535 "user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
536 },
537 });
538
539 serde_json::from_value(json).expect("We should be able to deserialize a secrets bundle")
540 }
541
    /// Plays the existing device ("Alice") in the scanned-QR-code flow,
    /// behaving according to `behavior`.
    ///
    /// `check_code_receiver` delivers the check code Bob publishes in his
    /// progress updates. Alice uses it to confirm the secure channel.
    async fn grant_login(
        alice: SecureChannel,
        check_code_receiver: tokio::sync::oneshot::Receiver<CheckCode>,
        behavior: AliceBehaviour,
    ) {
        let alice = alice.connect().await.expect("Alice should be able to connect the channel");

        let check_code =
            check_code_receiver.await.expect("We should receive the check code from bob");

        let mut alice = alice
            .confirm(check_code.to_digit())
            .expect("Alice should be able to confirm the secure channel");

        // Bob's first message proposes the device authorization grant.
        let message = alice
            .receive_json()
            .await
            .expect("Alice should be able to receive the initial message from Bob");

        assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
        assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);

        // Answer the protocol proposal according to the requested behaviour.
        let message = match behavior {
            AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::UnsupportedProtocol,
                homeserver: None,
            },
            AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
            _ => QrAuthMessage::LoginProtocolAccepted,
        };

        alice.send_json(message).await.unwrap();

        // NOTE(review): for the two behaviours above Bob aborts the login, so this is
        // presumably never a `LoginSuccess`; the assertion then panics inside the detached
        // task spawned by `test_failure`.
        let message: QrAuthMessage = alice.receive_json().await.unwrap();
        assert_let!(QrAuthMessage::LoginSuccess = message);

        // Answer Bob's successful login with the secrets bundle, or misbehave.
        let message = match behavior {
            AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
            AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::DeviceNotFound,
                homeserver: None,
            },
            _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
        };

        alice.send_json(message).await.unwrap();
    }
593
    /// Happy path of the scanned-QR-code flow.
    ///
    /// Bob scans Alice's QR code and logs in through the device authorization
    /// grant. He ends up with complete cross-signing and a verified own identity.
    #[async_test]
    async fn test_qr_login() {
        let server = MatrixMockServer::new().await;
        let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
        // Forwards the check code from Bob's progress updates to Alice.
        let (sender, receiver) = tokio::sync::oneshot::channel();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
        oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
        oauth_server
            .mock_device_authorization()
            .ok()
            .expect(1)
            .named("device_authorization")
            .mount()
            .await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        server.mock_versions().ok().expect(1..).named("versions").mount().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
        server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
        server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;

        // Alice creates the QR code by reciprocating a secure channel.
        let client = HttpClient::new(reqwest::Client::new(), Default::default());
        let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
            .await
            .expect("Alice should be able to create a secure channel.");

        assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);

        let bob = Client::builder()
            .server_name_or_homeserver_url(server_name)
            .request_config(RequestConfig::new().disable_retry())
            .build()
            .await
            .expect("We should be able to build the Client object from the URL in the QR code");

        let qr_code = alice.qr_code_data().clone();

        let oauth = bob.oauth();
        let registration_data = mock_client_metadata().into();
        let login_bob = oauth.login_with_qr_code(Some(&registration_data)).scan(&qr_code);
        // Subscribe before awaiting the login so no update is missed.
        let mut updates = login_bob.subscribe_to_progress();

        let updates_task = spawn(async move {
            let mut sender = Some(sender);

            while let Some(update) = updates.next().await {
                match update {
                    LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
                        sender
                            .take()
                            .expect("The establishing secure channel update should be received only once")
                            .send(check_code)
                            .expect("Bob should be able to send the check code to Alice");
                    }
                    LoginProgress::Done => break,
                    _ => (),
                }
            }
        });
        let alice_task =
            spawn(async { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });

        join!(
            async {
                login_bob.await.expect("Bob should be able to login");
            },
            async {
                alice_task.await.expect("Alice should have completed it's task successfully");
            },
            async { updates_task.await.unwrap() }
        );

        assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
        let own_identity =
            bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();

        assert!(own_identity.is_verified());
    }
674
    /// Plays the existing, logged-in device ("Alice") in the generated-QR-code
    /// flow, behaving according to `behavior`.
    ///
    /// `qr_receiver` delivers the QR code Bob generated. `cctx_receiver`
    /// delivers the [`CheckCodeSender`] Bob publishes after the channel is
    /// connected. Alice sends her check code through it so Bob can confirm
    /// the channel.
    async fn grant_login_with_generated_qr(
        alice: &Client,
        qr_receiver: tokio::sync::oneshot::Receiver<QrCodeData>,
        cctx_receiver: tokio::sync::oneshot::Receiver<CheckCodeSender>,
        behavior: AliceBehaviour,
    ) {
        let qr_code_data = qr_receiver.await.expect("Alice should receive the QR code");

        // "Scanning" the QR code: Alice reciprocates the channel Bob created.
        let mut channel = EstablishedSecureChannel::from_qr_code(
            alice.inner.http_client.inner.clone(),
            &qr_code_data,
            QrCodeMode::Reciprocate,
        )
        .await
        .expect("Alice should be able to establish the secure channel");

        trace!("Established the secure channel.");

        let check_code = channel.check_code().to_digit();

        let check_code_sender =
            cctx_receiver.await.expect("Alice should receive the CheckCodeSender");

        check_code_sender
            .send(check_code)
            .await
            .expect("Alice should be able to send the check code to Bob");

        // Offer the device authorization grant and tell Bob which homeserver to use.
        let message = QrAuthMessage::LoginProtocols {
            protocols: vec![LoginProtocolType::DeviceAuthorizationGrant],
            homeserver: alice.homeserver(),
        };
        channel
            .send_json(message)
            .await
            .expect("Alice should be able to send the `m.login.protocols` message to Bob");

        let message: QrAuthMessage = channel
            .receive_json()
            .await
            .expect("Alice should be able to receive the `m.login.protocol` message from Bob");
        assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
        assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);

        // Answer the protocol proposal according to the requested behaviour.
        let message = match behavior {
            AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::UnsupportedProtocol,
                homeserver: None,
            },
            AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
            _ => QrAuthMessage::LoginProtocolAccepted,
        };
        channel
            .send_json(message)
            .await
            .expect("Alice should be able to send the `m.login.protocol_accepted` message to Bob");

        let message: QrAuthMessage = channel
            .receive_json()
            .await
            .expect("Alice should be able to receive the `m.login.success` message from Bob");
        assert_let!(QrAuthMessage::LoginSuccess = message);

        // Answer Bob's successful login with the secrets bundle, or misbehave.
        let message = match behavior {
            AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
            AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::DeviceNotFound,
                homeserver: None,
            },
            _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
        };
        channel
            .send_json(message)
            .await
            .expect("Alice should be able to send the `m.login.secrets` message to Bob");
    }
757
758 #[async_test]
759 async fn test_generated_qr_login() {
760 let server = MatrixMockServer::new().await;
761 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
762 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
763 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
764
765 let oauth_server = server.oauth();
766 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
767 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
768 oauth_server
769 .mock_device_authorization()
770 .ok()
771 .expect(1)
772 .named("device_authorization")
773 .mount()
774 .await;
775 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
776
777 server.mock_versions().ok().expect(1..).named("versions").mount().await;
778 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
779 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
780 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
781
782 let homeserver_url = rendezvous_server.homeserver_url.clone();
783
784 let alice = server.client_builder().logged_in_with_oauth().build().await;
787 assert!(alice.session_meta().is_some(), "Alice should be logged in");
788
789 let bob = Client::builder()
791 .server_name_or_homeserver_url(&homeserver_url)
792 .request_config(RequestConfig::new().disable_retry())
793 .build()
794 .await
795 .expect("Should be able to create a client for Bob");
796
797 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
798 .await
799 .expect("Bob should be able to create a secure channel");
800
801 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
802
803 let registration_data = mock_client_metadata().into();
804 let bob_oauth = bob.oauth();
805 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
806 let mut bob_updates = bob_login.subscribe_to_progress();
807
808 let updates_task = spawn(async move {
809 let mut qr_sender = Some(qr_sender);
810 let mut cctx_sender = Some(cctx_sender);
811
812 while let Some(update) = bob_updates.next().await {
813 match update {
814 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
815 qr_sender
816 .take()
817 .expect("The establishing secure channel update with a qr code should be received only once")
818 .send(qr)
819 .expect("Bob should be able to send the qr code code to Alice");
820 }
821 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
822 cctx,
823 )) => {
824 cctx_sender
825 .take()
826 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
827 .send(cctx)
828 .expect("Bob should be able to send the qr code code to Alice");
829 }
830 LoginProgress::Done => break,
831 _ => (),
832 }
833 }
834 });
835
836 let alice_task = spawn(async move {
837 grant_login_with_generated_qr(
838 &alice,
839 qr_receiver,
840 cctx_receiver,
841 AliceBehaviour::HappyPath,
842 )
843 .await
844 });
845
846 join!(
847 async { bob_login.await.expect("Bob should be able to login") },
848 async { alice_task.await.expect("Alice should have completed it's task successfully") },
849 async { updates_task.await.unwrap() }
850 );
851
852 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
853 let own_identity =
854 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
855
856 assert!(own_identity.is_verified());
857 }
858
859 #[async_test]
860 async fn test_generated_qr_login_with_homeserver_swap() {
861 let server = MatrixMockServer::new().await;
862 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
863 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
864 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
865
866 let login_server = MatrixMockServer::new().await;
867 let oauth_server = login_server.oauth();
868 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
869 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
870 oauth_server
871 .mock_device_authorization()
872 .ok()
873 .expect(1)
874 .named("device_authorization")
875 .mount()
876 .await;
877 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
878
879 server.mock_versions().ok().expect(1..).named("versions").mount().await;
880
881 login_server.mock_well_known().ok().expect(1).named("well_known").mount().await;
882 login_server.mock_versions().ok().expect(1..).named("versions").mount().await;
883 login_server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
884 login_server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
885 login_server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
886
887 let homeserver_url = rendezvous_server.homeserver_url.clone();
888
889 let alice = login_server.client_builder().logged_in_with_oauth().build().await;
892 assert!(alice.session_meta().is_some(), "Alice should be logged in");
893
894 let bob = Client::builder()
896 .server_name_or_homeserver_url(&homeserver_url)
897 .request_config(RequestConfig::new().disable_retry())
898 .build()
899 .await
900 .expect("Should be able to create a client for Bob");
901
902 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
903 .await
904 .expect("Bob should be able to create a secure channel");
905
906 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
907
908 let registration_data = mock_client_metadata().into();
909 let bob_oauth = bob.oauth();
910 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
911 let mut bob_updates = bob_login.subscribe_to_progress();
912
913 let updates_task = spawn(async move {
914 let mut qr_sender = Some(qr_sender);
915 let mut cctx_sender = Some(cctx_sender);
916
917 while let Some(update) = bob_updates.next().await {
918 match update {
919 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
920 qr_sender
921 .take()
922 .expect("The establishing secure channel update with a qr code should be received only once")
923 .send(qr)
924 .expect("Bob should be able to send the qr code code to Alice");
925 }
926 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
927 cctx,
928 )) => {
929 cctx_sender
930 .take()
931 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
932 .send(cctx)
933 .expect("Bob should be able to send the qr code code to Alice");
934 }
935 LoginProgress::Done => break,
936 _ => (),
937 }
938 }
939 });
940
941 let alice_task = spawn(async move {
942 grant_login_with_generated_qr(
943 &alice,
944 qr_receiver,
945 cctx_receiver,
946 AliceBehaviour::HappyPath,
947 )
948 .await
949 });
950
951 join!(
952 async { bob_login.await.expect("Bob should be able to login") },
953 async { alice_task.await.expect("Alice should have completed it's task successfully") },
954 async { updates_task.await.unwrap() }
955 );
956
957 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
958 let own_identity =
959 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
960
961 assert!(own_identity.is_verified());
962 }
963
    /// Runs the scanned-QR-code login with the token endpoint answering as
    /// `token_response` and Alice behaving as `alice_behavior`.
    ///
    /// Returns Bob's login result.
    ///
    /// The progress and Alice tasks are spawned and their handles are never
    /// awaited. NOTE(review): a panic inside them is therefore presumably not
    /// propagated to the caller.
    async fn test_failure(
        token_response: TokenResponse,
        alice_behavior: AliceBehaviour,
    ) -> Result<(), QRCodeLoginError> {
        let server = MatrixMockServer::new().await;
        let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
        // Forwards the check code from Bob's progress updates to Alice.
        let (sender, receiver) = tokio::sync::oneshot::channel();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
        oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
        oauth_server
            .mock_device_authorization()
            .ok()
            .expect(1)
            .named("device_authorization")
            .mount()
            .await;

        let token_mock = oauth_server.mock_token();
        let token_mock = match token_response {
            TokenResponse::Ok => token_mock.ok(),
            TokenResponse::AccessDenied => token_mock.access_denied(),
            TokenResponse::ExpiredToken => token_mock.expired_token(),
        };
        token_mock.named("token").mount().await;

        server.mock_versions().ok().named("versions").mount().await;
        server.mock_who_am_i().ok().named("whoami").mount().await;

        let client = HttpClient::new(reqwest::Client::new(), Default::default());
        let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
            .await
            .expect("Alice should be able to create a secure channel.");

        assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);

        let bob = Client::builder()
            .server_name_or_homeserver_url(server_name)
            .request_config(RequestConfig::new().disable_retry())
            .build()
            .await
            .expect("We should be able to build the Client object from the URL in the QR code");

        let qr_code = alice.qr_code_data().clone();

        let oauth = bob.oauth();
        let registration_data = mock_client_metadata().into();
        let login_bob = oauth.login_with_qr_code(Some(&registration_data)).scan(&qr_code);
        // Subscribe before awaiting the login so no update is missed.
        let mut updates = login_bob.subscribe_to_progress();

        let _updates_task = spawn(async move {
            let mut sender = Some(sender);

            while let Some(update) = updates.next().await {
                match update {
                    LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
                        sender
                            .take()
                            .expect("The establishing secure channel update should be received only once")
                            .send(check_code)
                            .expect("Bob should be able to send the check code to Alice");
                    }
                    LoginProgress::Done => break,
                    _ => (),
                }
            }
        });
        let _alice_task = spawn(async move { grant_login(alice, receiver, alice_behavior).await });
        login_bob.await
    }
1035
1036 async fn test_generated_failure(
1037 token_response: TokenResponse,
1038 alice_behavior: AliceBehaviour,
1039 ) -> Result<(), QRCodeLoginError> {
1040 let server = MatrixMockServer::new().await;
1041 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
1042 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
1043 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
1044
1045 let oauth_server = server.oauth();
1046 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
1047 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1048 oauth_server
1049 .mock_device_authorization()
1050 .ok()
1051 .expect(1)
1052 .named("device_authorization")
1053 .mount()
1054 .await;
1055
1056 let token_mock = oauth_server.mock_token();
1057 let token_mock = match token_response {
1058 TokenResponse::Ok => token_mock.ok(),
1059 TokenResponse::AccessDenied => token_mock.access_denied(),
1060 TokenResponse::ExpiredToken => token_mock.expired_token(),
1061 };
1062 token_mock.named("token").mount().await;
1063
1064 server.mock_versions().ok().named("versions").mount().await;
1065 server.mock_who_am_i().ok().named("whoami").mount().await;
1066
1067 let homeserver_url = rendezvous_server.homeserver_url.clone();
1068
1069 let alice = server.client_builder().logged_in_with_oauth().build().await;
1072 assert!(alice.session_meta().is_some(), "Alice should be logged in");
1073
1074 let bob = Client::builder()
1076 .server_name_or_homeserver_url(&homeserver_url)
1077 .request_config(RequestConfig::new().disable_retry())
1078 .build()
1079 .await
1080 .expect("Should be able to create a client for Bob");
1081
1082 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
1083 .await
1084 .expect("Bob should be able to create a secure channel");
1085
1086 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
1087
1088 let registration_data = mock_client_metadata().into();
1089 let bob_oauth = bob.oauth();
1090 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
1091 let mut bob_updates = bob_login.subscribe_to_progress();
1092
1093 let _updates_task = spawn(async move {
1094 let mut qr_sender = Some(qr_sender);
1095 let mut cctx_sender = Some(cctx_sender);
1096
1097 while let Some(update) = bob_updates.next().await {
1098 match update {
1099 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
1100 qr_sender
1101 .take()
1102 .expect("The establishing secure channel update with a qr code should be received only once")
1103 .send(qr)
1104 .expect("Bob should be able to send the qr code code to Alice");
1105 }
1106 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
1107 cctx,
1108 )) => {
1109 cctx_sender
1110 .take()
1111 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
1112 .send(cctx)
1113 .expect("Bob should be able to send the qr code code to Alice");
1114 }
1115 LoginProgress::Done => break,
1116 _ => (),
1117 }
1118 }
1119 });
1120
1121 let _alice_task = spawn(async move {
1122 grant_login_with_generated_qr(&alice, qr_receiver, cctx_receiver, alice_behavior).await
1123 });
1124 bob_login.await
1125 }
1126
1127 #[async_test]
1128 async fn test_qr_login_refused_access_token() {
1129 let result = test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1130
1131 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1132 assert_eq!(
1133 e.as_request_token_error(),
1134 Some(&DeviceCodeErrorResponseType::AccessDenied),
1135 "The server should have told us that access has been denied."
1136 );
1137 }
1138
1139 #[async_test]
1140 async fn test_generated_qr_login_refused_access_token() {
1141 let result =
1142 test_generated_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1143
1144 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1145 assert_eq!(
1146 e.as_request_token_error(),
1147 Some(&DeviceCodeErrorResponseType::AccessDenied),
1148 "The server should have told us that access has been denied."
1149 );
1150 }
1151
1152 #[async_test]
1153 async fn test_qr_login_expired_token() {
1154 let result = test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1155
1156 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1157 assert_eq!(
1158 e.as_request_token_error(),
1159 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1160 "The server should have told us that access has been denied."
1161 );
1162 }
1163
1164 #[async_test]
1165 async fn test_generated_qr_login_expired_token() {
1166 let result =
1167 test_generated_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1168
1169 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1170 assert_eq!(
1171 e.as_request_token_error(),
1172 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1173 "The server should have told us that access has been denied."
1174 );
1175 }
1176
1177 #[async_test]
1178 async fn test_qr_login_declined_protocol() {
1179 let result = test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1180
1181 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1182 assert_eq!(
1183 reason,
1184 LoginFailureReason::UnsupportedProtocol,
1185 "Alice should have told us that the protocol is unsupported."
1186 );
1187 }
1188
1189 #[async_test]
1190 async fn test_generated_qr_login_declined_protocol() {
1191 let result =
1192 test_generated_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1193
1194 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1195 assert_eq!(
1196 reason,
1197 LoginFailureReason::UnsupportedProtocol,
1198 "Alice should have told us that the protocol is unsupported."
1199 );
1200 }
1201
1202 #[async_test]
1203 async fn test_qr_login_unexpected_message() {
1204 let result = test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1205
1206 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1207 assert_eq!(expected, "m.login.protocol_accepted");
1208 }
1209
1210 #[async_test]
1211 async fn test_generated_qr_login_unexpected_message() {
1212 let result =
1213 test_generated_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1214
1215 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1216 assert_eq!(expected, "m.login.protocol_accepted");
1217 }
1218
1219 #[async_test]
1220 async fn test_qr_login_unexpected_message_instead_of_secrets() {
1221 let result =
1222 test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets)
1223 .await;
1224
1225 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1226 assert_eq!(expected, "m.login.secrets");
1227 }
1228
1229 #[async_test]
1230 async fn test_generated_qr_login_unexpected_message_instead_of_secrets() {
1231 let result = test_generated_failure(
1232 TokenResponse::Ok,
1233 AliceBehaviour::UnexpectedMessageInsteadOfSecrets,
1234 )
1235 .await;
1236
1237 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1238 assert_eq!(expected, "m.login.secrets");
1239 }
1240
1241 #[async_test]
1242 async fn test_qr_login_refuse_secrets() {
1243 let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1244
1245 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1246 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1247 }
1248
1249 #[async_test]
1250 async fn test_generated_qr_login_refuse_secrets() {
1251 let result = test_generated_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1252
1253 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1254 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1255 }
1256
1257 #[async_test]
1258 async fn test_device_authorization_endpoint_missing() {
1259 let server = MatrixMockServer::new().await;
1260 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
1261 let (sender, receiver) = tokio::sync::oneshot::channel();
1262
1263 let oauth_server = server.oauth();
1264 oauth_server
1265 .mock_server_metadata()
1266 .ok_without_device_authorization()
1267 .expect(1)
1268 .named("server_metadata")
1269 .mount()
1270 .await;
1271 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1272
1273 server.mock_versions().ok().named("versions").mount().await;
1274 server.mock_who_am_i().ok().named("whoami").mount().await;
1275
1276 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1277 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
1278 .await
1279 .expect("Alice should be able to create a secure channel.");
1280
1281 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
1282
1283 let bob = Client::builder()
1284 .server_name_or_homeserver_url(server_name)
1285 .request_config(RequestConfig::new().disable_retry())
1286 .build()
1287 .await
1288 .expect("We should be able to build the Client object from the URL in the QR code");
1289
1290 let qr_code = alice.qr_code_data().clone();
1291
1292 let oauth = bob.oauth();
1293 let registration_data = mock_client_metadata().into();
1294 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1295 let mut updates = login_bob.subscribe_to_progress();
1296
1297 let _updates_task = spawn(async move {
1298 let mut sender = Some(sender);
1299
1300 while let Some(update) = updates.next().await {
1301 match update {
1302 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1303 sender
1304 .take()
1305 .expect("The establishing secure channel update should be received only once")
1306 .send(check_code)
1307 .expect("Bob should be able to send the check code to Alice");
1308 }
1309 LoginProgress::Done => break,
1310 _ => (),
1311 }
1312 }
1313 });
1314 let _alice_task =
1315 spawn(async move { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
1316 let error = login_bob.await.unwrap_err();
1317
1318 assert_matches!(
1319 error,
1320 QRCodeLoginError::OAuth(DeviceAuthorizationOAuthError::NoDeviceAuthorizationEndpoint)
1321 );
1322 }
1323}