1use std::future::IntoFuture;
16
17use eyeball::SharedObservable;
18use futures_core::Stream;
19use matrix_sdk_base::{
20 SessionMeta, boxed_into_future,
21 crypto::types::qr_login::{QrCodeData, QrCodeIntent},
22 store::RoomLoadSettings,
23};
24use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
25use ruma::{
26 OwnedDeviceId,
27 api::client::discovery::get_authorization_server_metadata::v1::AuthorizationServerMetadata,
28};
29use tracing::trace;
30use vodozemac::Curve25519PublicKey;
31#[cfg(doc)]
32use vodozemac::ecies::CheckCode;
33
34use super::{
35 DeviceAuthorizationOAuthError, QRCodeLoginError, SecureChannelError,
36 messages::{LoginFailureReason, QrAuthMessage},
37 secure_channel::{EstablishedSecureChannel, SecureChannel},
38};
39use crate::{
40 Client,
41 authentication::oauth::{
42 ClientRegistrationData, OAuth, OAuthError,
43 qrcode::{CheckCodeSender, GeneratedQrProgress, LoginProtocolType, QrProgress},
44 },
45};
46
47async fn send_unexpected_message_error(
48 channel: &mut EstablishedSecureChannel,
49) -> Result<(), SecureChannelError> {
50 channel
51 .send_json(QrAuthMessage::LoginFailure {
52 reason: LoginFailureReason::UnexpectedMessageReceived,
53 homeserver: None,
54 })
55 .await
56}
57
58async fn finish_login<Q>(
59 client: &Client,
60 mut channel: EstablishedSecureChannel,
61 registration_data: Option<&ClientRegistrationData>,
62 state: SharedObservable<LoginProgress<Q>>,
63) -> Result<(), QRCodeLoginError> {
64 let oauth = client.oauth();
65
66 trace!("Registering the client with the OAuth 2.0 authorization server.");
68 let server_metadata = register_client(&oauth, registration_data).await?;
69
70 let account = vodozemac::olm::Account::new();
73 let public_key = account.identity_keys().curve25519;
74 let device_id = public_key;
75
76 trace!("Requesting device authorization.");
79 let auth_grant_response =
80 request_device_authorization(&oauth, &server_metadata, device_id).await?;
81
82 trace!("Letting the existing device know about the device authorization grant.");
85 let message =
86 QrAuthMessage::authorization_grant_login_protocol((&auth_grant_response).into(), device_id);
87 channel.send_json(&message).await?;
88
89 match channel.receive_json().await? {
91 QrAuthMessage::LoginProtocolAccepted => (),
92 QrAuthMessage::LoginFailure { reason, homeserver } => {
93 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
94 }
95 message => {
96 send_unexpected_message_error(&mut channel).await?;
97
98 return Err(QRCodeLoginError::UnexpectedMessage {
99 expected: "m.login.protocol_accepted",
100 received: message,
101 });
102 }
103 }
104
105 let user_code = auth_grant_response.user_code();
109 state.set(LoginProgress::WaitingForToken { user_code: user_code.secret().to_owned() });
110
111 trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
114 if let Err(e) = wait_for_tokens(&oauth, &server_metadata, &auth_grant_response).await {
115 if let Some(e) = e.as_request_token_error() {
118 match e {
119 DeviceCodeErrorResponseType::AccessDenied => {
120 channel.send_json(QrAuthMessage::LoginDeclined).await?;
121 }
122 DeviceCodeErrorResponseType::ExpiredToken => {
123 channel
124 .send_json(QrAuthMessage::LoginFailure {
125 reason: LoginFailureReason::AuthorizationExpired,
126 homeserver: None,
127 })
128 .await?;
129 }
130 _ => (),
131 }
132 }
133
134 return Err(e.into());
135 }
136
137 trace!("Discovering our own user id.");
143 let whoami_response = client.whoami().await.map_err(QRCodeLoginError::UserIdDiscovery)?;
144 client
145 .base_client()
146 .activate(
147 SessionMeta {
148 user_id: whoami_response.user_id,
149 device_id: OwnedDeviceId::from(device_id.to_base64()),
150 },
151 RoomLoadSettings::default(),
152 Some(account),
153 )
154 .await
155 .map_err(|error| QRCodeLoginError::SessionTokens(error.into()))?;
156
157 client.oauth().enable_cross_process_lock().await?;
158
159 state.set(LoginProgress::SyncingSecrets);
160
161 trace!("Telling the existing device that we successfully logged in.");
163 let message = QrAuthMessage::LoginSuccess;
164 channel.send_json(&message).await?;
165
166 trace!("Waiting for the secrets bundle.");
169 let bundle = match channel.receive_json().await? {
170 QrAuthMessage::LoginSecrets(bundle) => bundle,
171 QrAuthMessage::LoginFailure { reason, homeserver } => {
172 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
173 }
174 message => {
175 send_unexpected_message_error(&mut channel).await?;
176
177 return Err(QRCodeLoginError::UnexpectedMessage {
178 expected: "m.login.secrets",
179 received: message,
180 });
181 }
182 };
183
184 client.encryption().import_secrets_bundle_impl(&bundle).await?;
187
188 client
191 .encryption()
192 .ensure_device_keys_upload()
193 .await
194 .map_err(QRCodeLoginError::DeviceKeyUpload)?;
195
196 client.encryption().spawn_initialization_task(None).await;
201 client.encryption().wait_for_e2ee_initialization_tasks().await;
202
203 trace!("successfully logged in and enabled E2EE.");
204
205 state.set(LoginProgress::Done);
207
208 Ok(())
210}
211
212async fn register_client(
216 oauth: &OAuth,
217 registration_data: Option<&ClientRegistrationData>,
218) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
219 let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
220 oauth.use_registration_data(&server_metadata, registration_data).await?;
221
222 Ok(server_metadata)
223}
224
225async fn request_device_authorization(
226 oauth: &OAuth,
227 server_metadata: &AuthorizationServerMetadata,
228 device_id: Curve25519PublicKey,
229) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
230 let response = oauth
231 .request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
232 .await?;
233 Ok(response)
234}
235
236async fn wait_for_tokens(
237 oauth: &OAuth,
238 server_metadata: &AuthorizationServerMetadata,
239 auth_response: &StandardDeviceAuthorizationResponse,
240) -> Result<(), DeviceAuthorizationOAuthError> {
241 oauth.exchange_device_code(server_metadata, auth_response).await?;
242 Ok(())
243}
244
/// The progress of a QR code login.
///
/// It is reported through the `subscribe_to_progress` methods of
/// [`LoginWithQrCode`] and [`LoginWithGeneratedQrCode`].
///
/// The type parameter `Q` carries the data specific to how the secure channel is
/// set up: [`QrProgress`] when this device scans a QR code, and
/// [`GeneratedQrProgress`] when this device generates one.
#[derive(Clone, Debug, Default)]
pub enum LoginProgress<Q> {
    /// The initial state, before any progress has been made.
    #[default]
    Starting,
    /// The secure channel with the existing device is being set up and needs to
    /// be confirmed.
    ///
    /// For [`QrProgress`] this carries the [`CheckCode`] the existing device
    /// needs to confirm the channel. For [`GeneratedQrProgress`] this is first
    /// the QR code to display, and then a [`CheckCodeSender`] through which the
    /// check code from the existing device must be supplied.
    EstablishingSecureChannel(Q),
    /// The existing device accepted the login protocol, and we are waiting for
    /// the OAuth 2.0 authorization server to give us the access token.
    WaitingForToken {
        /// The user code of the device authorization grant.
        user_code: String,
    },
    /// We are logged in, and are now receiving and importing the secrets
    /// bundle sent by the existing device.
    SyncingSecrets,
    /// The login finished: secrets were imported, device keys were uploaded,
    /// and E2EE initialization completed.
    Done,
}
268
/// Named future for logging in by scanning a QR code generated by an existing,
/// already logged-in device.
///
/// It is created through `OAuth::login_with_qr_code(..).scan(..)`. Await it to
/// run the login, and use [`LoginWithQrCode::subscribe_to_progress`] to observe
/// its progress.
#[derive(Debug)]
pub struct LoginWithQrCode<'a> {
    /// The client that will be logged in.
    client: &'a Client,
    /// Optional data used to register the client with the OAuth 2.0
    /// authorization server.
    registration_data: Option<&'a ClientRegistrationData>,
    /// The data decoded from the scanned QR code.
    qr_code_data: &'a QrCodeData,
    /// The observable progress of the login.
    state: SharedObservable<LoginProgress<QrProgress>>,
}
278
279impl LoginWithQrCode<'_> {
280 pub fn subscribe_to_progress(&self) -> impl Stream<Item = LoginProgress<QrProgress>> + use<> {
286 self.state.subscribe()
287 }
288}
289
290impl<'a> IntoFuture for LoginWithQrCode<'a> {
291 type Output = Result<(), QRCodeLoginError>;
292 boxed_into_future!(extra_bounds: 'a);
293
294 fn into_future(self) -> Self::IntoFuture {
295 Box::pin(async move {
296 let channel = self.establish_secure_channel().await?;
305
306 trace!("Established the secure channel.");
307
308 let check_code = channel.check_code().to_owned();
312 self.state.set(LoginProgress::EstablishingSecureChannel(QrProgress { check_code }));
313
314 finish_login(self.client, channel, self.registration_data, self.state).await
321 })
322 }
323}
324
325impl<'a> LoginWithQrCode<'a> {
326 pub(crate) fn new(
327 client: &'a Client,
328 qr_code_data: &'a QrCodeData,
329 registration_data: Option<&'a ClientRegistrationData>,
330 ) -> LoginWithQrCode<'a> {
331 LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
332 }
333
334 async fn establish_secure_channel(
335 &self,
336 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
337 let http_client = self.client.inner.http_client.inner.clone();
338
339 let channel = EstablishedSecureChannel::from_qr_code(
340 http_client,
341 self.qr_code_data,
342 QrCodeIntent::Login,
343 )
344 .await?;
345
346 Ok(channel)
347 }
348}
349
/// Named future for logging in by generating a QR code that an existing,
/// already logged-in device scans.
///
/// It is created through `OAuth::login_with_qr_code(..).generate()`. Await it to
/// run the login, and use [`LoginWithGeneratedQrCode::subscribe_to_progress`]
/// to receive the QR code to display and to supply the check code. If the
/// existing device names a different homeserver, the client switches to it
/// before logging in.
#[derive(Debug)]
pub struct LoginWithGeneratedQrCode<'a> {
    /// The client that will be logged in.
    client: &'a Client,
    /// Optional data used to register the client with the OAuth 2.0
    /// authorization server.
    registration_data: Option<&'a ClientRegistrationData>,
    /// The observable progress of the login.
    state: SharedObservable<LoginProgress<GeneratedQrProgress>>,
}
358
359impl LoginWithGeneratedQrCode<'_> {
360 pub fn subscribe_to_progress(
365 &self,
366 ) -> impl Stream<Item = LoginProgress<GeneratedQrProgress>> + use<> {
367 self.state.subscribe()
368 }
369}
370
371impl<'a> IntoFuture for LoginWithGeneratedQrCode<'a> {
372 type Output = Result<(), QRCodeLoginError>;
373 boxed_into_future!(extra_bounds: 'a);
374
375 fn into_future(self) -> Self::IntoFuture {
376 Box::pin(async move {
377 let mut channel = self.establish_secure_channel().await?;
380
381 trace!("Established the secure channel.");
382
383 let message = channel.receive_json().await?;
387
388 let homeserver = match message {
391 QrAuthMessage::LoginProtocols { protocols, homeserver } => {
392 if !protocols.contains(&LoginProtocolType::DeviceAuthorizationGrant) {
393 channel
394 .send_json(QrAuthMessage::LoginFailure {
395 reason: LoginFailureReason::UnsupportedProtocol,
396 homeserver: None,
397 })
398 .await?;
399
400 return Err(QRCodeLoginError::LoginFailure {
401 reason: LoginFailureReason::UnsupportedProtocol,
402 homeserver: None,
403 });
404 }
405
406 homeserver
407 }
408 _ => {
409 send_unexpected_message_error(&mut channel).await?;
410
411 return Err(QRCodeLoginError::UnexpectedMessage {
412 expected: "m.login.protocols",
413 received: message,
414 });
415 }
416 };
417
418 if self.client.homeserver() != homeserver {
421 self.client
422 .switch_homeserver_and_re_resolve_well_known(homeserver)
423 .await
424 .map_err(QRCodeLoginError::ServerReset)?;
425 }
426
427 finish_login(self.client, channel, self.registration_data, self.state).await
430 })
431 }
432}
433
434impl<'a> LoginWithGeneratedQrCode<'a> {
435 pub(crate) fn new(
436 client: &'a Client,
437 registration_data: Option<&'a ClientRegistrationData>,
438 ) -> Self {
439 Self { client, registration_data, state: Default::default() }
440 }
441
442 async fn establish_secure_channel(
443 &self,
444 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
445 let http_client = self.client.inner.http_client.clone();
446
447 let secure_channel = SecureChannel::login(http_client, &self.client.homeserver()).await?;
451
452 let qr_code_data = secure_channel.qr_code_data().clone();
456 trace!("Generated QR code.");
457 self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(
458 qr_code_data,
459 )));
460
461 let channel = secure_channel.connect().await?;
465
466 trace!("Waiting for checkcode.");
471 let (tx, rx) = tokio::sync::oneshot::channel();
472 self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
473 CheckCodeSender::new(tx),
474 )));
475
476 let check_code = rx.await.map_err(|_| SecureChannelError::CannotReceiveCheckCode)?;
480 trace!("Received check code.");
481 channel.confirm(check_code)
482 }
483}
484
485#[cfg(all(test, not(target_family = "wasm")))]
486mod test {
487 use std::time::Duration;
488
489 use assert_matches2::{assert_let, assert_matches};
490 use futures_util::StreamExt;
491 use matrix_sdk_base::crypto::types::{
492 SecretsBundle,
493 qr_login::{Msc4108IntentData, QrCodeIntentData},
494 };
495 use matrix_sdk_common::executor::spawn;
496 use matrix_sdk_test::async_test;
497 use serde_json::json;
498 use vodozemac::ecies::CheckCode;
499
500 use super::*;
501 use crate::{
502 authentication::oauth::qrcode::{
503 messages::LoginProtocolType,
504 secure_channel::{SecureChannel, test::MockedRendezvousServer},
505 },
506 config::RequestConfig,
507 http_client::HttpClient,
508 test_utils::{client::oauth::mock_client_metadata, mocks::MatrixMockServer},
509 };
510
511 enum AliceBehaviour {
512 HappyPath,
513 DeclinedProtocol,
514 UnexpectedMessage,
515 UnexpectedMessageInsteadOfSecrets,
516 RefuseSecrets,
517 LetSessionExpire,
518 }
519
520 enum TokenResponse {
522 Ok,
523 AccessDenied,
524 ExpiredToken,
525 }
526
527 fn secrets_bundle() -> SecretsBundle {
528 let json = json!({
529 "cross_signing": {
530 "master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
531 "self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
532 "user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
533 },
534 });
535
536 serde_json::from_value(json).expect("We should be able to deserialize a secrets bundle")
537 }
538
539 async fn grant_login(
542 alice: SecureChannel,
543 check_code_receiver: tokio::sync::oneshot::Receiver<CheckCode>,
544 behavior: AliceBehaviour,
545 ) {
546 let alice = alice.connect().await.expect("Alice should be able to connect the channel");
547
548 let check_code =
549 check_code_receiver.await.expect("We should receive the check code from bob");
550
551 let mut alice = alice
552 .confirm(check_code.to_digit())
553 .expect("Alice should be able to confirm the secure channel");
554
555 let message = alice
556 .receive_json()
557 .await
558 .expect("Alice should be able to receive the initial message from Bob");
559
560 assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
561 assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);
562
563 let message = match behavior {
564 AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
565 reason: LoginFailureReason::UnsupportedProtocol,
566 homeserver: None,
567 },
568 AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
569 _ => QrAuthMessage::LoginProtocolAccepted,
570 };
571
572 alice.send_json(message).await.unwrap();
573
574 let message: QrAuthMessage = alice.receive_json().await.unwrap();
575 assert_let!(QrAuthMessage::LoginSuccess = message);
576
577 let message = match behavior {
578 AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
579 AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
580 reason: LoginFailureReason::DeviceNotFound,
581 homeserver: None,
582 },
583 _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
584 };
585
586 alice.send_json(message).await.unwrap();
587 }
588
589 #[async_test]
590 async fn test_qr_login() {
591 let server = MatrixMockServer::new().await;
592 let rendezvous_server =
593 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
594 let (sender, receiver) = tokio::sync::oneshot::channel();
595
596 let oauth_server = server.oauth();
597 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
598 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
599 oauth_server
600 .mock_device_authorization()
601 .ok()
602 .expect(1)
603 .named("device_authorization")
604 .mount()
605 .await;
606 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
607
608 server.mock_versions().ok().expect(1..).named("versions").mount().await;
609 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
610 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
611 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
612
613 let client = HttpClient::new(reqwest::Client::new(), Default::default());
614 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
615 .await
616 .expect("Alice should be able to create a secure channel.");
617
618 assert_let!(
619 QrCodeIntentData::Msc4108 {
620 data: Msc4108IntentData::Reciprocate { server_name },
621 ..
622 } = &alice.qr_code_data().intent_data()
623 );
624
625 let bob = Client::builder()
626 .server_name_or_homeserver_url(server_name)
627 .request_config(RequestConfig::new().disable_retry())
628 .build()
629 .await
630 .expect("We should be able to build the Client object from the URL in the QR code");
631
632 let qr_code = alice.qr_code_data().clone();
633
634 let oauth = bob.oauth();
635 let registration_data = mock_client_metadata().into();
636 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
637 let mut updates = login_bob.subscribe_to_progress();
638
639 let updates_task = spawn(async move {
640 let mut sender = Some(sender);
641
642 while let Some(update) = updates.next().await {
643 match update {
644 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
645 sender
646 .take()
647 .expect("The establishing secure channel update should be received only once")
648 .send(check_code)
649 .expect("Bob should be able to send the check code to Alice");
650 }
651 LoginProgress::Done => break,
652 _ => (),
653 }
654 }
655 });
656 let alice_task =
657 spawn(async { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
658
659 login_bob.await.expect("Bob should be able to login");
661 alice_task.await.expect("Alice should have completed it's task successfully");
662 updates_task.await.unwrap();
663
664 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
665 let own_identity =
666 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
667
668 assert!(own_identity.is_verified());
669 }
670
671 async fn grant_login_with_generated_qr(
672 alice: &Client,
673 qr_receiver: tokio::sync::oneshot::Receiver<QrCodeData>,
674 cctx_receiver: tokio::sync::oneshot::Receiver<CheckCodeSender>,
675 behavior: AliceBehaviour,
676 ) {
677 let qr_code_data = qr_receiver.await.expect("Alice should receive the QR code");
678
679 let mut channel = EstablishedSecureChannel::from_qr_code(
680 alice.inner.http_client.inner.clone(),
681 &qr_code_data,
682 QrCodeIntent::Reciprocate,
683 )
684 .await
685 .expect("Alice should be able to establish the secure channel");
686
687 trace!("Established the secure channel.");
688
689 let check_code = channel.check_code().to_digit();
692
693 let check_code_sender =
694 cctx_receiver.await.expect("Alice should receive the CheckCodeSender");
695
696 check_code_sender
697 .send(check_code)
698 .await
699 .expect("Alice should be able to send the check code to Bob");
700
701 let message = QrAuthMessage::LoginProtocols {
703 protocols: vec![LoginProtocolType::DeviceAuthorizationGrant],
704 homeserver: alice.homeserver(),
705 };
706 channel
707 .send_json(message)
708 .await
709 .expect("Alice should be able to send the `m.login.protocols` message to Bob");
710
711 let message: QrAuthMessage = channel
713 .receive_json()
714 .await
715 .expect("Alice should be able to receive the `m.login.protocol` message from Bob");
716 assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
717 assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);
718
719 let message = match behavior {
721 AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
722 reason: LoginFailureReason::UnsupportedProtocol,
723 homeserver: None,
724 },
725 AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
726 _ => QrAuthMessage::LoginProtocolAccepted,
727 };
728 channel
729 .send_json(message)
730 .await
731 .expect("Alice should be able to send the `m.login.protocol_accepted` message to Bob");
732
733 let message: QrAuthMessage = channel
734 .receive_json()
735 .await
736 .expect("Alice should be able to receive the `m.login.success` message from Bob");
737 assert_let!(QrAuthMessage::LoginSuccess = message);
738
739 let message = match behavior {
741 AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
742 AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
743 reason: LoginFailureReason::DeviceNotFound,
744 homeserver: None,
745 },
746 _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
747 };
748 channel
749 .send_json(message)
750 .await
751 .expect("Alice should be able to send the `m.login.secrets` message to Bob");
752 }
753
754 #[async_test]
755 async fn test_generated_qr_login() {
756 let server = MatrixMockServer::new().await;
757 let rendezvous_server =
758 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
759 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
760 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
761
762 let oauth_server = server.oauth();
763 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
764 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
765 oauth_server
766 .mock_device_authorization()
767 .ok()
768 .expect(1)
769 .named("device_authorization")
770 .mount()
771 .await;
772 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
773
774 server.mock_versions().ok().expect(1..).named("versions").mount().await;
775 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
776 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
777 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
778
779 let homeserver_url = rendezvous_server.homeserver_url.clone();
780
781 let alice = server.client_builder().logged_in_with_oauth().build().await;
784 assert!(alice.session_meta().is_some(), "Alice should be logged in");
785
786 let bob = Client::builder()
788 .server_name_or_homeserver_url(&homeserver_url)
789 .request_config(RequestConfig::new().disable_retry())
790 .build()
791 .await
792 .expect("Should be able to create a client for Bob");
793
794 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
795 .await
796 .expect("Bob should be able to create a secure channel");
797
798 assert_matches!(
799 secure_channel.qr_code_data().intent_data(),
800 QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. }
801 );
802
803 let registration_data = mock_client_metadata().into();
804 let bob_oauth = bob.oauth();
805 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
806 let mut bob_updates = bob_login.subscribe_to_progress();
807
808 let updates_task = spawn(async move {
809 let mut qr_sender = Some(qr_sender);
810 let mut cctx_sender = Some(cctx_sender);
811
812 while let Some(update) = bob_updates.next().await {
813 match update {
814 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
815 qr_sender
816 .take()
817 .expect("The establishing secure channel update with a qr code should be received only once")
818 .send(qr)
819 .expect("Bob should be able to send the qr code code to Alice");
820 }
821 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
822 cctx,
823 )) => {
824 cctx_sender
825 .take()
826 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
827 .send(cctx)
828 .expect("Bob should be able to send the qr code code to Alice");
829 }
830 LoginProgress::Done => break,
831 _ => (),
832 }
833 }
834 });
835
836 let alice_task = spawn(async move {
837 grant_login_with_generated_qr(
838 &alice,
839 qr_receiver,
840 cctx_receiver,
841 AliceBehaviour::HappyPath,
842 )
843 .await
844 });
845
846 bob_login.await.expect("Bob should be able to login");
848 alice_task.await.expect("Alice should have completed it's task successfully");
849 updates_task.await.unwrap();
850
851 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
852 let own_identity =
853 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
854
855 assert!(own_identity.is_verified());
856 }
857
858 #[async_test]
859 async fn test_generated_qr_login_with_homeserver_swap() {
860 let initial_server = MatrixMockServer::new().await;
861 let rendezvous_server =
862 MockedRendezvousServer::new(initial_server.server(), "abcdEFG12345", Duration::MAX)
863 .await;
864 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
865 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
866
867 let login_server = MatrixMockServer::new().await;
868 let oauth_server = login_server.oauth();
869 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
870 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
871 oauth_server
872 .mock_device_authorization()
873 .ok()
874 .expect(1)
875 .named("device_authorization")
876 .mount()
877 .await;
878 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
879
880 initial_server.mock_versions().ok().expect(1..).named("versions").mount().await;
881
882 login_server.mock_well_known().ok().expect(1).named("well_known").mount().await;
883 login_server.mock_versions().ok().expect(1..).named("versions").mount().await;
884 login_server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
885 login_server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
886 login_server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
887
888 let rendezvous_homeserver_url = rendezvous_server.homeserver_url.clone();
889
890 let alice = login_server.client_builder().logged_in_with_oauth().build().await;
893 assert!(alice.session_meta().is_some(), "Alice should be logged in");
894
895 let bob = Client::builder()
897 .server_name_or_homeserver_url(&rendezvous_homeserver_url)
898 .request_config(RequestConfig::new().disable_retry())
899 .build()
900 .await
901 .expect("Should be able to create a client for Bob");
902
903 let secure_channel =
904 SecureChannel::login(bob.inner.http_client.clone(), &rendezvous_homeserver_url)
905 .await
906 .expect("Bob should be able to create a secure channel");
907
908 assert_matches!(
909 secure_channel.qr_code_data().intent_data(),
910 QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. }
911 );
912
913 let initial_server_url = initial_server.server().uri().parse().unwrap();
915 assert_eq!(bob.homeserver(), initial_server_url);
916 let login_server_url = login_server.server().uri().parse().unwrap();
917 assert_eq!(alice.homeserver(), login_server_url);
918 assert_ne!(initial_server_url, login_server_url);
919
920 let registration_data = mock_client_metadata().into();
921 let bob_oauth = bob.oauth();
922 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
923 let mut bob_updates = bob_login.subscribe_to_progress();
924
925 let updates_task = spawn(async move {
926 let mut qr_sender = Some(qr_sender);
927 let mut cctx_sender = Some(cctx_sender);
928
929 while let Some(update) = bob_updates.next().await {
930 match update {
931 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
932 qr_sender
933 .take()
934 .expect("The establishing secure channel update with a qr code should be received only once")
935 .send(qr)
936 .expect("Bob should be able to send the qr code code to Alice");
937 }
938 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
939 cctx,
940 )) => {
941 cctx_sender
942 .take()
943 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
944 .send(cctx)
945 .expect("Bob should be able to send the qr code code to Alice");
946 }
947 LoginProgress::Done => break,
948 _ => (),
949 }
950 }
951 });
952
953 let alice_task = spawn(async move {
954 grant_login_with_generated_qr(
955 &alice,
956 qr_receiver,
957 cctx_receiver,
958 AliceBehaviour::HappyPath,
959 )
960 .await
961 });
962
963 bob_login.await.expect("Bob should be able to login");
965 alice_task.await.expect("Alice should have completed it's task successfully");
966 updates_task.await.unwrap();
967
968 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
969 let own_identity =
970 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
971
972 assert!(own_identity.is_verified());
973
974 assert_eq!(bob.homeserver(), login_server_url);
976 }
977
978 async fn test_failure(
979 token_response: TokenResponse,
980 alice_behavior: AliceBehaviour,
981 ) -> Result<(), QRCodeLoginError> {
982 let server = MatrixMockServer::new().await;
983 let expiration = match alice_behavior {
984 AliceBehaviour::LetSessionExpire => Duration::from_secs(2),
985 _ => Duration::MAX,
986 };
987 let rendezvous_server =
988 MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration).await;
989 let (sender, receiver) = tokio::sync::oneshot::channel();
990
991 let oauth_server = server.oauth();
992 let expected_calls = match alice_behavior {
993 AliceBehaviour::LetSessionExpire => 0,
994 _ => 1,
995 };
996 oauth_server
997 .mock_server_metadata()
998 .ok()
999 .expect(expected_calls)
1000 .named("server_metadata")
1001 .mount()
1002 .await;
1003 oauth_server
1004 .mock_registration()
1005 .ok()
1006 .expect(expected_calls)
1007 .named("registration")
1008 .mount()
1009 .await;
1010 oauth_server
1011 .mock_device_authorization()
1012 .ok()
1013 .expect(expected_calls)
1014 .named("device_authorization")
1015 .mount()
1016 .await;
1017
1018 let token_mock = oauth_server.mock_token();
1019 let token_mock = match token_response {
1020 TokenResponse::Ok => token_mock.ok(),
1021 TokenResponse::AccessDenied => token_mock.access_denied(),
1022 TokenResponse::ExpiredToken => token_mock.expired_token(),
1023 };
1024 token_mock.named("token").mount().await;
1025
1026 server.mock_versions().ok().named("versions").mount().await;
1027 server.mock_who_am_i().ok().named("whoami").mount().await;
1028
1029 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1030 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
1031 .await
1032 .expect("Alice should be able to create a secure channel.");
1033
1034 assert_let!(
1035 QrCodeIntentData::Msc4108 {
1036 data: Msc4108IntentData::Reciprocate { server_name },
1037 ..
1038 } = &alice.qr_code_data().intent_data()
1039 );
1040
1041 let bob = Client::builder()
1042 .server_name_or_homeserver_url(server_name)
1043 .request_config(RequestConfig::new().disable_retry())
1044 .build()
1045 .await
1046 .expect("We should be able to build the Client object from the URL in the QR code");
1047
1048 let qr_code = alice.qr_code_data().clone();
1049
1050 let oauth = bob.oauth();
1051 let registration_data = mock_client_metadata().into();
1052 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1053 let mut updates = login_bob.subscribe_to_progress();
1054
1055 let _updates_task = spawn(async move {
1056 let mut sender = Some(sender);
1057
1058 while let Some(update) = updates.next().await {
1059 match update {
1060 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1061 sender
1062 .take()
1063 .expect("The establishing secure channel update should be received only once")
1064 .send(check_code)
1065 .expect("Bob should be able to send the check code to Alice");
1066 }
1067 LoginProgress::Done => break,
1068 _ => (),
1069 }
1070 }
1071 });
1072
1073 if !matches!(alice_behavior, AliceBehaviour::LetSessionExpire) {
1074 let _alice_task =
1075 spawn(async move { grant_login(alice, receiver, alice_behavior).await });
1076 }
1077
1078 login_bob.await
1079 }
1080
1081 async fn test_generated_failure(
1082 token_response: TokenResponse,
1083 alice_behavior: AliceBehaviour,
1084 ) -> Result<(), QRCodeLoginError> {
1085 let server = MatrixMockServer::new().await;
1086 let expiration = match alice_behavior {
1087 AliceBehaviour::LetSessionExpire => Duration::from_secs(2),
1088 _ => Duration::MAX,
1089 };
1090 let rendezvous_server =
1091 MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration).await;
1092
1093 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
1094 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
1095
1096 let oauth_server = server.oauth();
1097 let expected_calls = match alice_behavior {
1098 AliceBehaviour::LetSessionExpire => 0,
1099 _ => 1,
1100 };
1101 oauth_server
1102 .mock_server_metadata()
1103 .ok()
1104 .expect(expected_calls)
1105 .named("server_metadata")
1106 .mount()
1107 .await;
1108 oauth_server
1109 .mock_registration()
1110 .ok()
1111 .expect(expected_calls)
1112 .named("registration")
1113 .mount()
1114 .await;
1115 oauth_server
1116 .mock_device_authorization()
1117 .ok()
1118 .expect(expected_calls)
1119 .named("device_authorization")
1120 .mount()
1121 .await;
1122
1123 let token_mock = oauth_server.mock_token();
1124 let token_mock = match token_response {
1125 TokenResponse::Ok => token_mock.ok(),
1126 TokenResponse::AccessDenied => token_mock.access_denied(),
1127 TokenResponse::ExpiredToken => token_mock.expired_token(),
1128 };
1129 token_mock.named("token").mount().await;
1130
1131 server.mock_versions().ok().named("versions").mount().await;
1132 server.mock_who_am_i().ok().named("whoami").mount().await;
1133
1134 let homeserver_url = rendezvous_server.homeserver_url.clone();
1135
1136 let alice = server.client_builder().logged_in_with_oauth().build().await;
1139 assert!(alice.session_meta().is_some(), "Alice should be logged in");
1140
1141 let bob = Client::builder()
1143 .server_name_or_homeserver_url(&homeserver_url)
1144 .request_config(RequestConfig::new().disable_retry())
1145 .build()
1146 .await
1147 .expect("Should be able to create a client for Bob");
1148
1149 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
1150 .await
1151 .expect("Bob should be able to create a secure channel");
1152
1153 assert_matches!(
1154 secure_channel.qr_code_data().intent_data(),
1155 QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. }
1156 );
1157
1158 let registration_data = mock_client_metadata().into();
1159 let bob_oauth = bob.oauth();
1160 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
1161 let mut bob_updates = bob_login.subscribe_to_progress();
1162
1163 let _updates_task = spawn(async move {
1164 let mut qr_sender = Some(qr_sender);
1165 let mut cctx_sender = Some(cctx_sender);
1166
1167 while let Some(update) = bob_updates.next().await {
1168 match update {
1169 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
1170 qr_sender
1171 .take()
1172 .expect("The establishing secure channel update with a qr code should be received only once")
1173 .send(qr)
1174 .expect("Bob should be able to send the qr code code to Alice");
1175 }
1176 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
1177 cctx,
1178 )) => {
1179 cctx_sender
1180 .take()
1181 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
1182 .send(cctx)
1183 .expect("Bob should be able to send the qr code code to Alice");
1184 }
1185 LoginProgress::Done => break,
1186 _ => (),
1187 }
1188 }
1189 });
1190
1191 if !matches!(alice_behavior, AliceBehaviour::LetSessionExpire) {
1192 let _alice_task = spawn(async move {
1193 grant_login_with_generated_qr(&alice, qr_receiver, cctx_receiver, alice_behavior)
1194 .await
1195 });
1196 }
1197
1198 bob_login.await
1199 }
1200
1201 #[async_test]
1202 async fn test_qr_login_refused_access_token() {
1203 let result = test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1204
1205 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1206 assert_eq!(
1207 e.as_request_token_error(),
1208 Some(&DeviceCodeErrorResponseType::AccessDenied),
1209 "The server should have told us that access has been denied."
1210 );
1211 }
1212
1213 #[async_test]
1214 async fn test_generated_qr_login_refused_access_token() {
1215 let result =
1216 test_generated_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1217
1218 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1219 assert_eq!(
1220 e.as_request_token_error(),
1221 Some(&DeviceCodeErrorResponseType::AccessDenied),
1222 "The server should have told us that access has been denied."
1223 );
1224 }
1225
1226 #[async_test]
1227 async fn test_qr_login_expired_token() {
1228 let result = test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1229
1230 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1231 assert_eq!(
1232 e.as_request_token_error(),
1233 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1234 "The server should have told us that access has been denied."
1235 );
1236 }
1237
1238 #[async_test]
1239 async fn test_generated_qr_login_expired_token() {
1240 let result =
1241 test_generated_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1242
1243 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1244 assert_eq!(
1245 e.as_request_token_error(),
1246 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1247 "The server should have told us that access has been denied."
1248 );
1249 }
1250
1251 #[async_test]
1252 async fn test_qr_login_declined_protocol() {
1253 let result = test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1254
1255 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1256 assert_eq!(
1257 reason,
1258 LoginFailureReason::UnsupportedProtocol,
1259 "Alice should have told us that the protocol is unsupported."
1260 );
1261 }
1262
1263 #[async_test]
1264 async fn test_generated_qr_login_declined_protocol() {
1265 let result =
1266 test_generated_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1267
1268 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1269 assert_eq!(
1270 reason,
1271 LoginFailureReason::UnsupportedProtocol,
1272 "Alice should have told us that the protocol is unsupported."
1273 );
1274 }
1275
1276 #[async_test]
1277 async fn test_qr_login_unexpected_message() {
1278 let result = test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1279
1280 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1281 assert_eq!(expected, "m.login.protocol_accepted");
1282 }
1283
1284 #[async_test]
1285 async fn test_generated_qr_login_unexpected_message() {
1286 let result =
1287 test_generated_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1288
1289 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1290 assert_eq!(expected, "m.login.protocol_accepted");
1291 }
1292
1293 #[async_test]
1294 async fn test_qr_login_unexpected_message_instead_of_secrets() {
1295 let result =
1296 test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets)
1297 .await;
1298
1299 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1300 assert_eq!(expected, "m.login.secrets");
1301 }
1302
1303 #[async_test]
1304 async fn test_generated_qr_login_unexpected_message_instead_of_secrets() {
1305 let result = test_generated_failure(
1306 TokenResponse::Ok,
1307 AliceBehaviour::UnexpectedMessageInsteadOfSecrets,
1308 )
1309 .await;
1310
1311 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1312 assert_eq!(expected, "m.login.secrets");
1313 }
1314
1315 #[async_test]
1316 async fn test_qr_login_refuse_secrets() {
1317 let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1318
1319 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1320 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1321 }
1322
1323 #[async_test]
1324 async fn test_generated_qr_login_refuse_secrets() {
1325 let result = test_generated_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1326
1327 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1328 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1329 }
1330
1331 #[async_test]
1332 async fn test_qr_login_session_expired() {
1333 let result = test_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire).await;
1334
1335 assert_matches!(result, Err(QRCodeLoginError::NotFound));
1336 }
1337
1338 #[async_test]
1339 async fn test_generated_qr_login_session_expired() {
1340 let result =
1341 test_generated_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire).await;
1342
1343 assert_matches!(result, Err(QRCodeLoginError::NotFound));
1344 }
1345
1346 #[async_test]
1347 async fn test_device_authorization_endpoint_missing() {
1348 let server = MatrixMockServer::new().await;
1349 let rendezvous_server =
1350 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
1351 let (sender, receiver) = tokio::sync::oneshot::channel();
1352
1353 let oauth_server = server.oauth();
1354 oauth_server
1355 .mock_server_metadata()
1356 .ok_without_device_authorization()
1357 .expect(1)
1358 .named("server_metadata")
1359 .mount()
1360 .await;
1361 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1362
1363 server.mock_versions().ok().named("versions").mount().await;
1364 server.mock_who_am_i().ok().named("whoami").mount().await;
1365
1366 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1367 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
1368 .await
1369 .expect("Alice should be able to create a secure channel.");
1370
1371 assert_let!(
1372 QrCodeIntentData::Msc4108 {
1373 data: Msc4108IntentData::Reciprocate { server_name },
1374 ..
1375 } = &alice.qr_code_data().intent_data()
1376 );
1377
1378 let bob = Client::builder()
1379 .server_name_or_homeserver_url(server_name)
1380 .request_config(RequestConfig::new().disable_retry())
1381 .build()
1382 .await
1383 .expect("We should be able to build the Client object from the URL in the QR code");
1384
1385 let qr_code = alice.qr_code_data().clone();
1386
1387 let oauth = bob.oauth();
1388 let registration_data = mock_client_metadata().into();
1389 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1390 let mut updates = login_bob.subscribe_to_progress();
1391
1392 let _updates_task = spawn(async move {
1393 let mut sender = Some(sender);
1394
1395 while let Some(update) = updates.next().await {
1396 match update {
1397 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1398 sender
1399 .take()
1400 .expect("The establishing secure channel update should be received only once")
1401 .send(check_code)
1402 .expect("Bob should be able to send the check code to Alice");
1403 }
1404 LoginProgress::Done => break,
1405 _ => (),
1406 }
1407 }
1408 });
1409 let _alice_task =
1410 spawn(async move { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
1411 let error = login_bob.await.unwrap_err();
1412
1413 assert_matches!(
1414 error,
1415 QRCodeLoginError::OAuth(DeviceAuthorizationOAuthError::NoDeviceAuthorizationEndpoint)
1416 );
1417 }
1418}