1use std::{future::IntoFuture, sync::Arc};
16
17use eyeball::SharedObservable;
18use futures_core::Stream;
19use matrix_sdk_base::{
20 SessionMeta, boxed_into_future,
21 crypto::types::qr_login::{QrCodeData, QrCodeMode},
22 store::RoomLoadSettings,
23};
24use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
25use ruma::{
26 OwnedDeviceId,
27 api::client::discovery::get_authorization_server_metadata::v1::AuthorizationServerMetadata,
28};
29use tokio::sync::Mutex;
30use tracing::trace;
31use vodozemac::{Curve25519PublicKey, ecies::CheckCode};
32
33use super::{
34 DeviceAuthorizationOAuthError, QRCodeLoginError, SecureChannelError,
35 messages::{LoginFailureReason, QrAuthMessage},
36 secure_channel::{EstablishedSecureChannel, SecureChannel},
37};
38use crate::{
39 Client,
40 authentication::oauth::{ClientRegistrationData, OAuth, OAuthError, qrcode::LoginProtocolType},
41};
42
43async fn send_unexpected_message_error(
44 channel: &mut EstablishedSecureChannel,
45) -> Result<(), SecureChannelError> {
46 channel
47 .send_json(QrAuthMessage::LoginFailure {
48 reason: LoginFailureReason::UnexpectedMessageReceived,
49 homeserver: None,
50 })
51 .await
52}
53
/// Runs the part of the QR code login shared by both login flows, once a
/// secure channel with the other device has been established.
///
/// In order, this:
///
/// 1. registers the client with the OAuth 2.0 authorization server,
/// 2. creates a new Olm account whose Curve25519 identity key doubles as the
///    device ID,
/// 3. requests a device authorization grant and proposes it to the other
///    device over `channel`, expecting `m.login.protocol_accepted` back,
/// 4. publishes [`LoginProgress::WaitingForToken`] and waits for the access
///    token,
/// 5. discovers our user ID, activates the session with the new account and
///    enables the cross-process lock,
/// 6. publishes [`LoginProgress::SyncingSecrets`], tells the other device the
///    login succeeded and imports the secrets bundle it sends back,
/// 7. uploads the device keys, waits for the E2EE initialization tasks and
///    publishes [`LoginProgress::Done`].
///
/// # Errors
///
/// Returns a [`QRCodeLoginError`] in three cases:
///
/// - any of the steps above fails;
/// - the other device reports a login failure;
/// - the other device sends a message other than the one expected at that
///   point. The other device is first told about the unexpected message.
async fn finish_login<Q>(
    client: &Client,
    mut channel: EstablishedSecureChannel,
    registration_data: Option<&ClientRegistrationData>,
    state: SharedObservable<LoginProgress<Q>>,
) -> Result<(), QRCodeLoginError> {
    let oauth = client.oauth();

    trace!("Registering the client with the OAuth 2.0 authorization server.");
    let server_metadata = register_client(&oauth, registration_data).await?;

    // Create the Olm account up front: its Curve25519 identity key is used as
    // the device ID we ask the authorization server for.
    let account = vodozemac::olm::Account::new();
    let public_key = account.identity_keys().curve25519;
    let device_id = public_key;

    trace!("Requesting device authorization.");
    let auth_grant_response =
        request_device_authorization(&oauth, &server_metadata, device_id).await?;

    // Tell the other device which login protocol we picked and the details it
    // needs to grant the login.
    trace!("Letting the existing device know about the device authorization grant.");
    let message =
        QrAuthMessage::authorization_grant_login_protocol((&auth_grant_response).into(), device_id);
    channel.send_json(&message).await?;

    // The other device has to accept the proposed protocol before we go on.
    match channel.receive_json().await? {
        QrAuthMessage::LoginProtocolAccepted => (),
        QrAuthMessage::LoginFailure { reason, homeserver } => {
            return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
        }
        message => {
            send_unexpected_message_error(&mut channel).await?;

            return Err(QRCodeLoginError::UnexpectedMessage {
                expected: "m.login.protocol_accepted",
                received: message,
            });
        }
    }

    // Publish the user code from the device authorization response so that
    // subscribers can display it while the login is being granted.
    let user_code = auth_grant_response.user_code();
    state.set(LoginProgress::WaitingForToken { user_code: user_code.secret().to_owned() });

    trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
    if let Err(e) = wait_for_tokens(&oauth, &server_metadata, &auth_grant_response).await {
        // Relay the token errors the other device should know about. Any other
        // error is only returned to our caller.
        // NOTE(review): if relaying fails, the `?` on `send_json` returns the
        // channel error instead of the original token error.
        if let Some(e) = e.as_request_token_error() {
            match e {
                DeviceCodeErrorResponseType::AccessDenied => {
                    channel.send_json(QrAuthMessage::LoginDeclined).await?;
                }
                DeviceCodeErrorResponseType::ExpiredToken => {
                    channel
                        .send_json(QrAuthMessage::LoginFailure {
                            reason: LoginFailureReason::AuthorizationExpired,
                            homeserver: None,
                        })
                        .await?;
                }
                _ => (),
            }
        }

        return Err(e.into());
    }

    // The token exchange does not tell us our user ID, so ask the server, then
    // activate the session using the account created above.
    trace!("Discovering our own user id.");
    let whoami_response = client.whoami().await.map_err(QRCodeLoginError::UserIdDiscovery)?;
    client
        .base_client()
        .activate(
            SessionMeta {
                user_id: whoami_response.user_id,
                device_id: OwnedDeviceId::from(device_id.to_base64()),
            },
            RoomLoadSettings::default(),
            Some(account),
        )
        .await
        .map_err(|error| QRCodeLoginError::SessionTokens(error.into()))?;

    client.oauth().enable_cross_process_lock().await?;

    state.set(LoginProgress::SyncingSecrets);

    trace!("Telling the existing device that we successfully logged in.");
    let message = QrAuthMessage::LoginSuccess;
    channel.send_json(&message).await?;

    // The other device answers `m.login.success` with the secrets bundle, or
    // with a failure.
    trace!("Waiting for the secrets bundle.");
    let bundle = match channel.receive_json().await? {
        QrAuthMessage::LoginSecrets(bundle) => bundle,
        QrAuthMessage::LoginFailure { reason, homeserver } => {
            return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
        }
        message => {
            send_unexpected_message_error(&mut channel).await?;

            return Err(QRCodeLoginError::UnexpectedMessage {
                expected: "m.login.secrets",
                received: message,
            });
        }
    };

    // Import the secrets before uploading the device keys. Presumably this is
    // so the uploaded keys can be signed with the imported cross-signing keys
    // (TODO confirm).
    client.encryption().import_secrets_bundle(&bundle).await?;

    client
        .encryption()
        .ensure_device_keys_upload()
        .await
        .map_err(QRCodeLoginError::DeviceKeyUpload)?;

    // Start the E2EE initialization and wait for it before reporting `Done`.
    client.encryption().spawn_initialization_task(None).await;
    client.encryption().wait_for_e2ee_initialization_tasks().await;

    trace!("successfully logged in and enabled E2EE.");

    state.set(LoginProgress::Done);

    Ok(())
}
207
208async fn register_client(
212 oauth: &OAuth,
213 registration_data: Option<&ClientRegistrationData>,
214) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
215 let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
216 oauth.use_registration_data(&server_metadata, registration_data).await?;
217
218 Ok(server_metadata)
219}
220
221async fn request_device_authorization(
222 oauth: &OAuth,
223 server_metadata: &AuthorizationServerMetadata,
224 device_id: Curve25519PublicKey,
225) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
226 let response = oauth
227 .request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
228 .await?;
229 Ok(response)
230}
231
232async fn wait_for_tokens(
233 oauth: &OAuth,
234 server_metadata: &AuthorizationServerMetadata,
235 auth_response: &StandardDeviceAuthorizationResponse,
236) -> Result<(), DeviceAuthorizationOAuthError> {
237 oauth.exchange_device_code(server_metadata, auth_response).await?;
238 Ok(())
239}
240
/// The progress of a QR code login, as published to progress subscribers.
///
/// `Q` carries the flow-specific data emitted while the secure channel is
/// being established: [`QrProgress`] when scanning a QR code, and
/// [`GeneratedQrProgress`] when displaying a generated one.
#[derive(Clone, Debug, Default)]
pub enum LoginProgress<Q> {
    /// The login has not made any progress yet. This is the initial state.
    #[default]
    Starting,
    /// The secure channel with the other device is being set up. `Q` holds the
    /// data needed for this step.
    EstablishingSecureChannel(Q),
    /// The other device accepted the device authorization grant, and we are
    /// waiting for the authorization server to hand out the access token.
    WaitingForToken {
        /// The user code from the device authorization response. Presumably
        /// displayed so the user can cross-check it while granting the login.
        user_code: String,
    },
    /// We are logged in and exchanging secrets with the other device.
    SyncingSecrets,
    /// The login completed and the E2EE initialization tasks have finished.
    Done,
}
264
/// Progress data for [`LoginProgress::EstablishingSecureChannel`] when logging
/// in by scanning a QR code.
#[derive(Clone, Debug)]
pub struct QrProgress {
    /// The check code of the established secure channel. The other device
    /// needs this value to confirm the channel.
    pub check_code: CheckCode,
}
276
/// Progress data for [`LoginProgress::EstablishingSecureChannel`] when logging
/// in by displaying a QR code generated on this device.
#[derive(Clone, Debug)]
pub enum GeneratedQrProgress {
    /// The QR code was generated and can be shown for the other device to scan.
    QrReady(QrCodeData),
    /// The other device connected to the channel. The check code from the other
    /// device must be submitted through the contained [`CheckCodeSender`] to
    /// finish establishing the channel.
    QrScanned(CheckCodeSender),
}
291
292#[derive(Clone, Debug)]
295pub struct CheckCodeSender {
296 inner: Arc<Mutex<Option<tokio::sync::oneshot::Sender<u8>>>>,
297}
298
299impl CheckCodeSender {
300 fn new(tx: tokio::sync::oneshot::Sender<u8>) -> Self {
301 Self { inner: Arc::new(Mutex::new(Some(tx))) }
302 }
303
304 pub async fn send(&self, check_code: u8) -> Result<(), CheckCodeSenderError> {
312 match self.inner.lock().await.take() {
313 Some(tx) => tx.send(check_code).map_err(|_| CheckCodeSenderError::CannotSend),
314 None => Err(CheckCodeSenderError::AlreadySent),
315 }
316 }
317}
318
/// Errors returned by [`CheckCodeSender::send`].
#[derive(Debug, thiserror::Error)]
pub enum CheckCodeSenderError {
    /// A check code was already submitted through this sender or one of its
    /// clones.
    #[error("check code already sent.")]
    AlreadySent,
    /// The one-shot channel refused the check code, e.g. because its receiving
    /// side was dropped.
    #[error("check code cannot be sent.")]
    CannotSend,
}
329
/// A login request driven by a QR code scanned from another device.
///
/// Await it (it implements [`IntoFuture`]) to run the login, and use
/// [`LoginWithQrCode::subscribe_to_progress`] to follow its progress.
#[derive(Debug)]
pub struct LoginWithQrCode<'a> {
    // The client that will be logged in.
    client: &'a Client,
    // Optional data used when registering with the authorization server.
    registration_data: Option<&'a ClientRegistrationData>,
    // The scanned QR code describing the channel to connect to.
    qr_code_data: &'a QrCodeData,
    // Published login progress; starts at `LoginProgress::Starting`.
    state: SharedObservable<LoginProgress<QrProgress>>,
}
339
340impl LoginWithQrCode<'_> {
341 pub fn subscribe_to_progress(&self) -> impl Stream<Item = LoginProgress<QrProgress>> + use<> {
347 self.state.subscribe()
348 }
349}
350
351impl<'a> IntoFuture for LoginWithQrCode<'a> {
352 type Output = Result<(), QRCodeLoginError>;
353 boxed_into_future!(extra_bounds: 'a);
354
355 fn into_future(self) -> Self::IntoFuture {
356 Box::pin(async move {
357 let channel = self.establish_secure_channel().await?;
366
367 trace!("Established the secure channel.");
368
369 let check_code = channel.check_code().to_owned();
373 self.state.set(LoginProgress::EstablishingSecureChannel(QrProgress { check_code }));
374
375 finish_login(self.client, channel, self.registration_data, self.state).await
382 })
383 }
384}
385
386impl<'a> LoginWithQrCode<'a> {
387 pub(crate) fn new(
388 client: &'a Client,
389 qr_code_data: &'a QrCodeData,
390 registration_data: Option<&'a ClientRegistrationData>,
391 ) -> LoginWithQrCode<'a> {
392 LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
393 }
394
395 async fn establish_secure_channel(
396 &self,
397 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
398 let http_client = self.client.inner.http_client.inner.clone();
399
400 let channel = EstablishedSecureChannel::from_qr_code(
401 http_client,
402 self.qr_code_data,
403 QrCodeMode::Login,
404 )
405 .await?;
406
407 Ok(channel)
408 }
409}
410
/// A login request that generates a QR code on this device for another device
/// to scan.
///
/// Await it (it implements [`IntoFuture`]) to run the login, and use
/// [`LoginWithGeneratedQrCode::subscribe_to_progress`] to receive the QR code,
/// the [`CheckCodeSender`], and later progress updates.
#[derive(Debug)]
pub struct LoginWithGeneratedQrCode<'a> {
    // The client that will be logged in.
    client: &'a Client,
    // Optional data used when registering with the authorization server.
    registration_data: Option<&'a ClientRegistrationData>,
    // Published login progress; starts at `LoginProgress::Starting`.
    state: SharedObservable<LoginProgress<GeneratedQrProgress>>,
}
419
420impl LoginWithGeneratedQrCode<'_> {
421 pub fn subscribe_to_progress(
426 &self,
427 ) -> impl Stream<Item = LoginProgress<GeneratedQrProgress>> + use<> {
428 self.state.subscribe()
429 }
430}
431
impl<'a> IntoFuture for LoginWithGeneratedQrCode<'a> {
    type Output = Result<(), QRCodeLoginError>;
    boxed_into_future!(extra_bounds: 'a);

    /// Establishes the secure channel and waits for the other device to offer
    /// its login protocols. If the offered homeserver differs from the client's
    /// current one, switches to it. Then runs the shared login steps.
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move {
            let mut channel = self.establish_secure_channel().await?;

            trace!("Established the secure channel.");

            // The other device starts by telling us which protocols it supports
            // and which homeserver to log in to.
            let message = channel.receive_json().await?;

            let homeserver = match message {
                QrAuthMessage::LoginProtocols { protocols, homeserver } => {
                    // Only the device authorization grant is supported here;
                    // reject the login on both sides otherwise.
                    if !protocols.contains(&LoginProtocolType::DeviceAuthorizationGrant) {
                        channel
                            .send_json(QrAuthMessage::LoginFailure {
                                reason: LoginFailureReason::UnsupportedProtocol,
                                homeserver: None,
                            })
                            .await?;

                        return Err(QRCodeLoginError::LoginFailure {
                            reason: LoginFailureReason::UnsupportedProtocol,
                            homeserver: None,
                        });
                    }

                    homeserver
                }
                _ => {
                    send_unexpected_message_error(&mut channel).await?;

                    return Err(QRCodeLoginError::UnexpectedMessage {
                        expected: "m.login.protocols",
                        received: message,
                    });
                }
            };

            // Log in on the homeserver named by the other device, which may not
            // be the one this client was built for.
            if self.client.homeserver() != homeserver {
                self.client
                    .switch_homeserver_and_re_resolve_well_known(homeserver)
                    .await
                    .map_err(QRCodeLoginError::ServerReset)?;
            }

            finish_login(self.client, channel, self.registration_data, self.state).await
        })
    }
}
494
impl<'a> LoginWithGeneratedQrCode<'a> {
    /// Creates a login request for `client` that will generate a QR code,
    /// optionally registering with `registration_data`.
    pub(crate) fn new(
        client: &'a Client,
        registration_data: Option<&'a ClientRegistrationData>,
    ) -> Self {
        Self { client, registration_data, state: Default::default() }
    }

    /// Generates a QR code, waits for the other device to connect, and confirms
    /// the channel with the check code submitted through [`CheckCodeSender`].
    ///
    /// Publishes [`GeneratedQrProgress::QrReady`] once the QR code exists, and
    /// [`GeneratedQrProgress::QrScanned`] once `connect()` has returned.
    ///
    /// # Errors
    ///
    /// Returns [`SecureChannelError::CannotReceiveCheckCode`] if the one-shot
    /// sender is dropped without a check code being sent. Other channel errors
    /// are propagated as-is.
    async fn establish_secure_channel(
        &self,
    ) -> Result<EstablishedSecureChannel, SecureChannelError> {
        let http_client = self.client.inner.http_client.clone();

        let secure_channel = SecureChannel::login(http_client, &self.client.homeserver()).await?;

        // Publish the QR code so subscribers can display it.
        let qr_code_data = secure_channel.qr_code_data().clone();
        trace!("Generated QR code.");
        self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(
            qr_code_data,
        )));

        // Presumably this waits until the other device has scanned the QR code
        // and joined the channel (TODO confirm against `SecureChannel::connect`).
        let channel = secure_channel.connect().await?;

        trace!("Waiting for checkcode.");
        // Hand the sending half to subscribers; the login resumes once the
        // check code is submitted through it.
        let (tx, rx) = tokio::sync::oneshot::channel();
        self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
            CheckCodeSender::new(tx),
        )));

        let check_code = rx.await.map_err(|_| SecureChannelError::CannotReceiveCheckCode)?;
        trace!("Received check code.");
        channel.confirm(check_code)
    }
}
545
546#[cfg(all(test, not(target_family = "wasm")))]
547mod test {
548 use assert_matches2::{assert_let, assert_matches};
549 use futures_util::{StreamExt, join};
550 use matrix_sdk_base::crypto::types::{SecretsBundle, qr_login::QrCodeModeData};
551 use matrix_sdk_common::executor::spawn;
552 use matrix_sdk_test::async_test;
553 use serde_json::json;
554
555 use super::*;
556 use crate::{
557 authentication::oauth::qrcode::{
558 messages::LoginProtocolType,
559 secure_channel::{SecureChannel, test::MockedRendezvousServer},
560 },
561 config::RequestConfig,
562 http_client::HttpClient,
563 test_utils::{client::oauth::mock_client_metadata, mocks::MatrixMockServer},
564 };
565
    /// How Alice (the existing, logged-in device) responds during a test login.
    enum AliceBehaviour {
        /// Accept the protocol and send the secrets bundle.
        HappyPath,
        /// Answer the protocol proposal with an `UnsupportedProtocol` failure.
        DeclinedProtocol,
        /// Answer the protocol proposal with `LoginDeclined` instead of an
        /// acceptance.
        UnexpectedMessage,
        /// Answer `m.login.success` with `LoginDeclined` instead of secrets.
        UnexpectedMessageInsteadOfSecrets,
        /// Answer `m.login.success` with a `DeviceNotFound` failure.
        RefuseSecrets,
    }
573
    /// Which response the mocked token endpoint returns in failure tests.
    enum TokenResponse {
        /// A successful token response.
        Ok,
        /// An `access_denied` error response.
        AccessDenied,
        /// An `expired_token` error response.
        ExpiredToken,
    }
580
581 fn secrets_bundle() -> SecretsBundle {
582 let json = json!({
583 "cross_signing": {
584 "master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
585 "self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
586 "user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
587 },
588 });
589
590 serde_json::from_value(json).expect("We should be able to deserialize a secrets bundle")
591 }
592
    /// Plays Alice's side of the scan-based flow.
    ///
    /// Alice connects the channel, confirms it with the check code received from
    /// Bob, and expects Bob's `m.login.protocol` proposal. She then answers Bob
    /// according to `behavior`.
    async fn grant_login(
        alice: SecureChannel,
        check_code_receiver: tokio::sync::oneshot::Receiver<CheckCode>,
        behavior: AliceBehaviour,
    ) {
        let alice = alice.connect().await.expect("Alice should be able to connect the channel");

        // Bob's check code is forwarded by the test's progress-updates task.
        let check_code =
            check_code_receiver.await.expect("We should receive the check code from bob");

        let mut alice = alice
            .confirm(check_code.to_digit())
            .expect("Alice should be able to confirm the secure channel");

        let message = alice
            .receive_json()
            .await
            .expect("Alice should be able to receive the initial message from Bob");

        assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
        assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);

        let message = match behavior {
            AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::UnsupportedProtocol,
                homeserver: None,
            },
            AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
            _ => QrAuthMessage::LoginProtocolAccepted,
        };

        alice.send_json(message).await.unwrap();

        // In failure scenarios Bob may already have given up, in which case
        // this unwrap panics inside the (unawaited) spawned task.
        let message: QrAuthMessage = alice.receive_json().await.unwrap();
        assert_let!(QrAuthMessage::LoginSuccess = message);

        let message = match behavior {
            AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
            AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::DeviceNotFound,
                homeserver: None,
            },
            _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
        };

        alice.send_json(message).await.unwrap();
    }
644
645 #[async_test]
646 async fn test_qr_login() {
647 let server = MatrixMockServer::new().await;
648 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
649 let (sender, receiver) = tokio::sync::oneshot::channel();
650
651 let oauth_server = server.oauth();
652 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
653 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
654 oauth_server
655 .mock_device_authorization()
656 .ok()
657 .expect(1)
658 .named("device_authorization")
659 .mount()
660 .await;
661 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
662
663 server.mock_versions().ok().expect(1..).named("versions").mount().await;
664 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
665 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
666 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
667
668 let client = HttpClient::new(reqwest::Client::new(), Default::default());
669 let alice = SecureChannel::new(client, &rendezvous_server.homeserver_url)
670 .await
671 .expect("Alice should be able to create a secure channel.");
672
673 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
674
675 let bob = Client::builder()
676 .server_name_or_homeserver_url(server_name)
677 .request_config(RequestConfig::new().disable_retry())
678 .build()
679 .await
680 .expect("We should be able to build the Client object from the URL in the QR code");
681
682 let qr_code = alice.qr_code_data().clone();
683
684 let oauth = bob.oauth();
685 let registration_data = mock_client_metadata().into();
686 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
687 let mut updates = login_bob.subscribe_to_progress();
688
689 let updates_task = spawn(async move {
690 let mut sender = Some(sender);
691
692 while let Some(update) = updates.next().await {
693 match update {
694 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
695 sender
696 .take()
697 .expect("The establishing secure channel update should be received only once")
698 .send(check_code)
699 .expect("Bob should be able to send the check code to Alice");
700 }
701 LoginProgress::Done => break,
702 _ => (),
703 }
704 }
705 });
706 let alice_task =
707 spawn(async { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
708
709 join!(
710 async {
711 login_bob.await.expect("Bob should be able to login");
712 },
713 async {
714 alice_task.await.expect("Alice should have completed it's task successfully");
715 },
716 async { updates_task.await.unwrap() }
717 );
718
719 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
720 let own_identity =
721 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
722
723 assert!(own_identity.is_verified());
724 }
725
    /// Plays Alice's side of the generated-QR-code flow.
    ///
    /// Alice:
    ///
    /// 1. scans Bob's QR code and joins the channel in reciprocate mode;
    /// 2. submits her check code through Bob's `CheckCodeSender`;
    /// 3. offers the device authorization grant on her own homeserver;
    /// 4. answers Bob according to `behavior`.
    async fn grant_login_with_generated_qr(
        alice: &Client,
        qr_receiver: tokio::sync::oneshot::Receiver<QrCodeData>,
        cctx_receiver: tokio::sync::oneshot::Receiver<CheckCodeSender>,
        behavior: AliceBehaviour,
    ) {
        let qr_code_data = qr_receiver.await.expect("Alice should receive the QR code");

        let mut channel = EstablishedSecureChannel::from_qr_code(
            alice.inner.http_client.inner.clone(),
            &qr_code_data,
            QrCodeMode::Reciprocate,
        )
        .await
        .expect("Alice should be able to establish the secure channel");

        trace!("Established the secure channel.");

        let check_code = channel.check_code().to_digit();

        // Stand in for the user entering Alice's check code on Bob's device.
        let check_code_sender =
            cctx_receiver.await.expect("Alice should receive the CheckCodeSender");

        check_code_sender
            .send(check_code)
            .await
            .expect("Alice should be able to send the check code to Bob");

        // Offer the device authorization grant and point Bob at Alice's homeserver.
        let message = QrAuthMessage::LoginProtocols {
            protocols: vec![LoginProtocolType::DeviceAuthorizationGrant],
            homeserver: alice.homeserver(),
        };
        channel
            .send_json(message)
            .await
            .expect("Alice should be able to send the `m.login.protocols` message to Bob");

        let message: QrAuthMessage = channel
            .receive_json()
            .await
            .expect("Alice should be able to receive the `m.login.protocol` message from Bob");
        assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
        assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);

        let message = match behavior {
            AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::UnsupportedProtocol,
                homeserver: None,
            },
            AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
            _ => QrAuthMessage::LoginProtocolAccepted,
        };
        channel
            .send_json(message)
            .await
            .expect("Alice should be able to send the `m.login.protocol_accepted` message to Bob");

        let message: QrAuthMessage = channel
            .receive_json()
            .await
            .expect("Alice should be able to receive the `m.login.success` message from Bob");
        assert_let!(QrAuthMessage::LoginSuccess = message);

        let message = match behavior {
            AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
            AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::DeviceNotFound,
                homeserver: None,
            },
            _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
        };
        channel
            .send_json(message)
            .await
            .expect("Alice should be able to send the `m.login.secrets` message to Bob");
    }
808
809 #[async_test]
810 async fn test_generated_qr_login() {
811 let server = MatrixMockServer::new().await;
812 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
813 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
814 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
815
816 let oauth_server = server.oauth();
817 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
818 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
819 oauth_server
820 .mock_device_authorization()
821 .ok()
822 .expect(1)
823 .named("device_authorization")
824 .mount()
825 .await;
826 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
827
828 server.mock_versions().ok().expect(1..).named("versions").mount().await;
829 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
830 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
831 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
832
833 let homeserver_url = rendezvous_server.homeserver_url.clone();
834
835 let alice = server.client_builder().logged_in_with_oauth().build().await;
838 assert!(alice.session_meta().is_some(), "Alice should be logged in");
839
840 let bob = Client::builder()
842 .server_name_or_homeserver_url(&homeserver_url)
843 .request_config(RequestConfig::new().disable_retry())
844 .build()
845 .await
846 .expect("Should be able to create a client for Bob");
847
848 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
849 .await
850 .expect("Bob should be able to create a secure channel");
851
852 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
853
854 let registration_data = mock_client_metadata().into();
855 let bob_oauth = bob.oauth();
856 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
857 let mut bob_updates = bob_login.subscribe_to_progress();
858
859 let updates_task = spawn(async move {
860 let mut qr_sender = Some(qr_sender);
861 let mut cctx_sender = Some(cctx_sender);
862
863 while let Some(update) = bob_updates.next().await {
864 match update {
865 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
866 qr_sender
867 .take()
868 .expect("The establishing secure channel update with a qr code should be received only once")
869 .send(qr)
870 .expect("Bob should be able to send the qr code code to Alice");
871 }
872 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
873 cctx,
874 )) => {
875 cctx_sender
876 .take()
877 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
878 .send(cctx)
879 .expect("Bob should be able to send the qr code code to Alice");
880 }
881 LoginProgress::Done => break,
882 _ => (),
883 }
884 }
885 });
886
887 let alice_task = spawn(async move {
888 grant_login_with_generated_qr(
889 &alice,
890 qr_receiver,
891 cctx_receiver,
892 AliceBehaviour::HappyPath,
893 )
894 .await
895 });
896
897 join!(
898 async { bob_login.await.expect("Bob should be able to login") },
899 async { alice_task.await.expect("Alice should have completed it's task successfully") },
900 async { updates_task.await.unwrap() }
901 );
902
903 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
904 let own_identity =
905 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
906
907 assert!(own_identity.is_verified());
908 }
909
910 #[async_test]
911 async fn test_generated_qr_login_with_homeserver_swap() {
912 let server = MatrixMockServer::new().await;
913 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
914 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
915 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
916
917 let login_server = MatrixMockServer::new().await;
918 let oauth_server = login_server.oauth();
919 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
920 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
921 oauth_server
922 .mock_device_authorization()
923 .ok()
924 .expect(1)
925 .named("device_authorization")
926 .mount()
927 .await;
928 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
929
930 server.mock_versions().ok().expect(1..).named("versions").mount().await;
931
932 login_server.mock_well_known().ok().expect(1).named("well_known").mount().await;
933 login_server.mock_versions().ok().expect(1..).named("versions").mount().await;
934 login_server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
935 login_server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
936 login_server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
937
938 let homeserver_url = rendezvous_server.homeserver_url.clone();
939
940 let alice = login_server.client_builder().logged_in_with_oauth().build().await;
943 assert!(alice.session_meta().is_some(), "Alice should be logged in");
944
945 let bob = Client::builder()
947 .server_name_or_homeserver_url(&homeserver_url)
948 .request_config(RequestConfig::new().disable_retry())
949 .build()
950 .await
951 .expect("Should be able to create a client for Bob");
952
953 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
954 .await
955 .expect("Bob should be able to create a secure channel");
956
957 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
958
959 let registration_data = mock_client_metadata().into();
960 let bob_oauth = bob.oauth();
961 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
962 let mut bob_updates = bob_login.subscribe_to_progress();
963
964 let updates_task = spawn(async move {
965 let mut qr_sender = Some(qr_sender);
966 let mut cctx_sender = Some(cctx_sender);
967
968 while let Some(update) = bob_updates.next().await {
969 match update {
970 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
971 qr_sender
972 .take()
973 .expect("The establishing secure channel update with a qr code should be received only once")
974 .send(qr)
975 .expect("Bob should be able to send the qr code code to Alice");
976 }
977 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
978 cctx,
979 )) => {
980 cctx_sender
981 .take()
982 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
983 .send(cctx)
984 .expect("Bob should be able to send the qr code code to Alice");
985 }
986 LoginProgress::Done => break,
987 _ => (),
988 }
989 }
990 });
991
992 let alice_task = spawn(async move {
993 grant_login_with_generated_qr(
994 &alice,
995 qr_receiver,
996 cctx_receiver,
997 AliceBehaviour::HappyPath,
998 )
999 .await
1000 });
1001
1002 join!(
1003 async { bob_login.await.expect("Bob should be able to login") },
1004 async { alice_task.await.expect("Alice should have completed it's task successfully") },
1005 async { updates_task.await.unwrap() }
1006 );
1007
1008 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
1009 let own_identity =
1010 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
1011
1012 assert!(own_identity.is_verified());
1013 }
1014
1015 async fn test_failure(
1016 token_response: TokenResponse,
1017 alice_behavior: AliceBehaviour,
1018 ) -> Result<(), QRCodeLoginError> {
1019 let server = MatrixMockServer::new().await;
1020 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
1021 let (sender, receiver) = tokio::sync::oneshot::channel();
1022
1023 let oauth_server = server.oauth();
1024 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
1025 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1026 oauth_server
1027 .mock_device_authorization()
1028 .ok()
1029 .expect(1)
1030 .named("device_authorization")
1031 .mount()
1032 .await;
1033
1034 let token_mock = oauth_server.mock_token();
1035 let token_mock = match token_response {
1036 TokenResponse::Ok => token_mock.ok(),
1037 TokenResponse::AccessDenied => token_mock.access_denied(),
1038 TokenResponse::ExpiredToken => token_mock.expired_token(),
1039 };
1040 token_mock.named("token").mount().await;
1041
1042 server.mock_versions().ok().named("versions").mount().await;
1043 server.mock_who_am_i().ok().named("whoami").mount().await;
1044
1045 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1046 let alice = SecureChannel::new(client, &rendezvous_server.homeserver_url)
1047 .await
1048 .expect("Alice should be able to create a secure channel.");
1049
1050 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
1051
1052 let bob = Client::builder()
1053 .server_name_or_homeserver_url(server_name)
1054 .request_config(RequestConfig::new().disable_retry())
1055 .build()
1056 .await
1057 .expect("We should be able to build the Client object from the URL in the QR code");
1058
1059 let qr_code = alice.qr_code_data().clone();
1060
1061 let oauth = bob.oauth();
1062 let registration_data = mock_client_metadata().into();
1063 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1064 let mut updates = login_bob.subscribe_to_progress();
1065
1066 let _updates_task = spawn(async move {
1067 let mut sender = Some(sender);
1068
1069 while let Some(update) = updates.next().await {
1070 match update {
1071 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1072 sender
1073 .take()
1074 .expect("The establishing secure channel update should be received only once")
1075 .send(check_code)
1076 .expect("Bob should be able to send the check code to Alice");
1077 }
1078 LoginProgress::Done => break,
1079 _ => (),
1080 }
1081 }
1082 });
1083 let _alice_task = spawn(async move { grant_login(alice, receiver, alice_behavior).await });
1084 login_bob.await
1085 }
1086
1087 async fn test_generated_failure(
1088 token_response: TokenResponse,
1089 alice_behavior: AliceBehaviour,
1090 ) -> Result<(), QRCodeLoginError> {
1091 let server = MatrixMockServer::new().await;
1092 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
1093 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
1094 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
1095
1096 let oauth_server = server.oauth();
1097 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
1098 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1099 oauth_server
1100 .mock_device_authorization()
1101 .ok()
1102 .expect(1)
1103 .named("device_authorization")
1104 .mount()
1105 .await;
1106
1107 let token_mock = oauth_server.mock_token();
1108 let token_mock = match token_response {
1109 TokenResponse::Ok => token_mock.ok(),
1110 TokenResponse::AccessDenied => token_mock.access_denied(),
1111 TokenResponse::ExpiredToken => token_mock.expired_token(),
1112 };
1113 token_mock.named("token").mount().await;
1114
1115 server.mock_versions().ok().named("versions").mount().await;
1116 server.mock_who_am_i().ok().named("whoami").mount().await;
1117
1118 let homeserver_url = rendezvous_server.homeserver_url.clone();
1119
1120 let alice = server.client_builder().logged_in_with_oauth().build().await;
1123 assert!(alice.session_meta().is_some(), "Alice should be logged in");
1124
1125 let bob = Client::builder()
1127 .server_name_or_homeserver_url(&homeserver_url)
1128 .request_config(RequestConfig::new().disable_retry())
1129 .build()
1130 .await
1131 .expect("Should be able to create a client for Bob");
1132
1133 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
1134 .await
1135 .expect("Bob should be able to create a secure channel");
1136
1137 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
1138
1139 let registration_data = mock_client_metadata().into();
1140 let bob_oauth = bob.oauth();
1141 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
1142 let mut bob_updates = bob_login.subscribe_to_progress();
1143
1144 let _updates_task = spawn(async move {
1145 let mut qr_sender = Some(qr_sender);
1146 let mut cctx_sender = Some(cctx_sender);
1147
1148 while let Some(update) = bob_updates.next().await {
1149 match update {
1150 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
1151 qr_sender
1152 .take()
1153 .expect("The establishing secure channel update with a qr code should be received only once")
1154 .send(qr)
1155 .expect("Bob should be able to send the qr code code to Alice");
1156 }
1157 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
1158 cctx,
1159 )) => {
1160 cctx_sender
1161 .take()
1162 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
1163 .send(cctx)
1164 .expect("Bob should be able to send the qr code code to Alice");
1165 }
1166 LoginProgress::Done => break,
1167 _ => (),
1168 }
1169 }
1170 });
1171
1172 let _alice_task = spawn(async move {
1173 grant_login_with_generated_qr(&alice, qr_receiver, cctx_receiver, alice_behavior).await
1174 });
1175 bob_login.await
1176 }
1177
1178 #[async_test]
1179 async fn test_qr_login_refused_access_token() {
1180 let result = test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1181
1182 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1183 assert_eq!(
1184 e.as_request_token_error(),
1185 Some(&DeviceCodeErrorResponseType::AccessDenied),
1186 "The server should have told us that access has been denied."
1187 );
1188 }
1189
1190 #[async_test]
1191 async fn test_generated_qr_login_refused_access_token() {
1192 let result =
1193 test_generated_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1194
1195 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1196 assert_eq!(
1197 e.as_request_token_error(),
1198 Some(&DeviceCodeErrorResponseType::AccessDenied),
1199 "The server should have told us that access has been denied."
1200 );
1201 }
1202
1203 #[async_test]
1204 async fn test_qr_login_expired_token() {
1205 let result = test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1206
1207 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1208 assert_eq!(
1209 e.as_request_token_error(),
1210 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1211 "The server should have told us that access has been denied."
1212 );
1213 }
1214
1215 #[async_test]
1216 async fn test_generated_qr_login_expired_token() {
1217 let result =
1218 test_generated_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1219
1220 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1221 assert_eq!(
1222 e.as_request_token_error(),
1223 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1224 "The server should have told us that access has been denied."
1225 );
1226 }
1227
1228 #[async_test]
1229 async fn test_qr_login_declined_protocol() {
1230 let result = test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1231
1232 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1233 assert_eq!(
1234 reason,
1235 LoginFailureReason::UnsupportedProtocol,
1236 "Alice should have told us that the protocol is unsupported."
1237 );
1238 }
1239
1240 #[async_test]
1241 async fn test_generated_qr_login_declined_protocol() {
1242 let result =
1243 test_generated_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1244
1245 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1246 assert_eq!(
1247 reason,
1248 LoginFailureReason::UnsupportedProtocol,
1249 "Alice should have told us that the protocol is unsupported."
1250 );
1251 }
1252
1253 #[async_test]
1254 async fn test_qr_login_unexpected_message() {
1255 let result = test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1256
1257 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1258 assert_eq!(expected, "m.login.protocol_accepted");
1259 }
1260
1261 #[async_test]
1262 async fn test_generated_qr_login_unexpected_message() {
1263 let result =
1264 test_generated_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1265
1266 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1267 assert_eq!(expected, "m.login.protocol_accepted");
1268 }
1269
1270 #[async_test]
1271 async fn test_qr_login_unexpected_message_instead_of_secrets() {
1272 let result =
1273 test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets)
1274 .await;
1275
1276 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1277 assert_eq!(expected, "m.login.secrets");
1278 }
1279
1280 #[async_test]
1281 async fn test_generated_qr_login_unexpected_message_instead_of_secrets() {
1282 let result = test_generated_failure(
1283 TokenResponse::Ok,
1284 AliceBehaviour::UnexpectedMessageInsteadOfSecrets,
1285 )
1286 .await;
1287
1288 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1289 assert_eq!(expected, "m.login.secrets");
1290 }
1291
1292 #[async_test]
1293 async fn test_qr_login_refuse_secrets() {
1294 let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1295
1296 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1297 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1298 }
1299
1300 #[async_test]
1301 async fn test_generated_qr_login_refuse_secrets() {
1302 let result = test_generated_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1303
1304 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1305 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1306 }
1307
1308 #[async_test]
1309 async fn test_device_authorization_endpoint_missing() {
1310 let server = MatrixMockServer::new().await;
1311 let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
1312 let (sender, receiver) = tokio::sync::oneshot::channel();
1313
1314 let oauth_server = server.oauth();
1315 oauth_server
1316 .mock_server_metadata()
1317 .ok_without_device_authorization()
1318 .expect(1)
1319 .named("server_metadata")
1320 .mount()
1321 .await;
1322 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1323
1324 server.mock_versions().ok().named("versions").mount().await;
1325 server.mock_who_am_i().ok().named("whoami").mount().await;
1326
1327 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1328 let alice = SecureChannel::new(client, &rendezvous_server.homeserver_url)
1329 .await
1330 .expect("Alice should be able to create a secure channel.");
1331
1332 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
1333
1334 let bob = Client::builder()
1335 .server_name_or_homeserver_url(server_name)
1336 .request_config(RequestConfig::new().disable_retry())
1337 .build()
1338 .await
1339 .expect("We should be able to build the Client object from the URL in the QR code");
1340
1341 let qr_code = alice.qr_code_data().clone();
1342
1343 let oauth = bob.oauth();
1344 let registration_data = mock_client_metadata().into();
1345 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1346 let mut updates = login_bob.subscribe_to_progress();
1347
1348 let _updates_task = spawn(async move {
1349 let mut sender = Some(sender);
1350
1351 while let Some(update) = updates.next().await {
1352 match update {
1353 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1354 sender
1355 .take()
1356 .expect("The establishing secure channel update should be received only once")
1357 .send(check_code)
1358 .expect("Bob should be able to send the check code to Alice");
1359 }
1360 LoginProgress::Done => break,
1361 _ => (),
1362 }
1363 }
1364 });
1365 let _alice_task =
1366 spawn(async move { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
1367 let error = login_bob.await.unwrap_err();
1368
1369 assert_matches!(
1370 error,
1371 QRCodeLoginError::OAuth(DeviceAuthorizationOAuthError::NoDeviceAuthorizationEndpoint)
1372 );
1373 }
1374}