1use std::{
16 convert::Infallible,
17 future::IntoFuture,
18 io::{self, Write},
19 ops::Range,
20 path::{Path, PathBuf},
21 sync::{Arc, Mutex},
22};
23
24use anyhow::{anyhow, bail};
25use axum::{
26 http::{Method, Request, StatusCode},
27 response::IntoResponse,
28 routing::any_service,
29};
30use futures_util::StreamExt;
31use matrix_sdk::{
32 authentication::oidc::{
33 registrations::ClientId,
34 requests::account_management::AccountManagementActionFull,
35 types::{
36 iana::oauth::OAuthClientAuthenticationMethod,
37 oidc::ApplicationType,
38 registration::{ClientMetadata, Localized, VerifiedClientMetadata},
39 requests::GrantType,
40 },
41 AuthorizationCode, AuthorizationResponse, OidcAuthorizationData, OidcSession, UserSession,
42 },
43 config::SyncSettings,
44 encryption::{recovery::RecoveryState, CrossSigningResetAuthType},
45 room::Room,
46 ruma::events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
47 Client, ClientBuildError, Result, RoomState,
48};
49use matrix_sdk_ui::sync_service::SyncService;
50use rand::{distributions::Alphanumeric, thread_rng, Rng};
51use serde::{Deserialize, Serialize};
52use tokio::{fs, io::AsyncBufReadExt as _, net::TcpListener, sync::oneshot};
53use tower::service_fn;
54use url::Url;
55
56#[tokio::main]
72async fn main() -> anyhow::Result<()> {
73 tracing_subscriber::fmt::init();
74
75 let data_dir =
77 dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oidc_cli");
78 let session_file = data_dir.join("session.json");
80
81 let cli = if session_file.exists() {
82 OidcCli::from_stored_session(session_file).await?
83 } else {
84 OidcCli::new(&data_dir, session_file).await?
85 };
86
87 cli.run().await
88}
89
/// Prints the list of commands understood by the interactive command loop.
///
/// Every command dispatched by `OidcCli::run` is listed here, including the
/// account-management shortcuts and the cross-signing reset.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!("  whoami                Get information about this session");
    println!("  account               Get the URL to manage this account");
    println!("  profile               Get the URL to manage the profile of this account");
    println!("  sessions              Get the URL to manage the sessions of this account");
    println!("  watch [sliding?]      Watch new incoming messages until an error occurs");
    println!("  refresh               Refresh the access token");
    println!("  recover               Recover the E2EE secrets from secret storage");
    println!("  reset-cross-signing   Reset the cross-signing identity of this account");
    println!("  logout                Log out of this account");
    println!("  exit                  Exit this program");
    println!("  help                  Show this message\n");
}
103
104#[derive(Debug, Serialize, Deserialize)]
106struct ClientSession {
107 homeserver: String,
109
110 db_path: PathBuf,
112
113 passphrase: String,
115}
116
117#[derive(Debug, Serialize, Deserialize)]
119struct Credentials {
120 client_id: String,
122}
123
124#[derive(Debug, Serialize, Deserialize)]
126struct StoredSession {
127 client_session: ClientSession,
129
130 user_session: UserSession,
132
133 client_credentials: Credentials,
135}
136
137#[derive(Clone, Debug)]
139struct OidcCli {
140 client: Client,
142
143 restored: bool,
145
146 session_file: PathBuf,
148}
149
150impl OidcCli {
151 async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
153 println!("No previous session found, logging in…");
154
155 let (client, client_session) = build_client(data_dir).await?;
156 let cli = Self { client, restored: false, session_file };
157
158 let client_id = cli.register_client().await?;
159 cli.login().await?;
160
161 let user_session =
166 cli.client.oidc().user_session().expect("A logged-in client should have a session");
167
168 let client_credentials = Credentials { client_id };
174
175 let serialized_session = serde_json::to_string(&StoredSession {
176 client_session,
177 user_session,
178 client_credentials,
179 })?;
180 fs::write(&cli.session_file, serialized_session).await?;
181
182 println!("Session persisted in {}", cli.session_file.to_string_lossy());
183
184 cli.setup_background_save();
185
186 Ok(cli)
187 }
188
189 async fn register_client(&self) -> anyhow::Result<String> {
193 let oidc = self.client.oidc();
194
195 let provider_metadata = oidc.provider_metadata().await?;
196
197 if provider_metadata.registration_endpoint.is_none() {
198 bail!(
201 "This provider doesn't support dynamic registration.\n\
202 Please select another homeserver."
203 );
204 }
205
206 let metadata = client_metadata();
207
208 let res = oidc.register_client(metadata.clone(), None).await?;
214
215 println!("\nRegistered successfully");
216
217 Ok(res.client_id)
218 }
219
220 async fn login(&self) -> anyhow::Result<()> {
222 let oidc = self.client.oidc();
223
224 loop {
226 let (redirect_uri, data_rx, signal_tx) = spawn_local_server().await?;
230
231 let OidcAuthorizationData { url, state } =
232 oidc.login(redirect_uri, None)?.build().await?;
233
234 let authorization_code = match use_auth_url(&url, &state, data_rx, signal_tx).await {
235 Ok(code) => code,
236 Err(err) => {
237 oidc.abort_authorization(&state).await;
238 return Err(err);
239 }
240 };
241
242 let res = oidc.finish_authorization(authorization_code).await;
243
244 if let Err(err) = res {
245 println!("Error: failed to login: {err}");
246 println!("Please try again.\n");
247 continue;
248 }
249
250 match oidc.finish_login().await {
251 Ok(()) => {
252 let user_id = self.client.user_id().expect("Got a user ID");
253 println!("Logged in as {user_id}");
254 break;
255 }
256 Err(err) => {
257 println!("Error: failed to finish login: {err}");
258 println!("Please try again.\n");
259 continue;
260 }
261 }
262 }
263
264 Ok(())
265 }
266
267 async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
269 println!("Previous session found in '{}'", session_file.to_string_lossy());
270
271 let serialized_session = fs::read_to_string(&session_file).await?;
273 let StoredSession { client_session, user_session, client_credentials } =
274 serde_json::from_str(&serialized_session)?;
275
276 let client = Client::builder()
278 .homeserver_url(client_session.homeserver)
279 .handle_refresh_tokens()
280 .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
281 .build()
282 .await?;
283
284 println!("Restoring session for {}…", user_session.meta.user_id);
285
286 let session = OidcSession {
287 client_id: ClientId(client_credentials.client_id),
288 metadata: client_metadata(),
289 user: user_session,
290 };
291 client.restore_session(session).await?;
293
294 let this = Self { client, restored: true, session_file };
295
296 this.setup_background_save();
297
298 Ok(this)
299 }
300
301 async fn run(&self) -> anyhow::Result<()> {
303 help();
304
305 loop {
306 let mut input = String::new();
307
308 print!("\nEnter command: ");
309 io::stdout().flush().expect("Unable to write to stdout");
310
311 io::stdin().read_line(&mut input).expect("Unable to read user input");
312
313 let mut args = input.trim().split_ascii_whitespace();
314 let cmd = args.next();
315
316 match cmd {
317 Some("whoami") => {
318 self.whoami();
319 }
320 Some("account") => {
321 self.account(None).await;
322 }
323 Some("profile") => {
324 self.account(Some(AccountManagementActionFull::Profile)).await;
325 }
326 Some("sessions") => {
327 self.account(Some(AccountManagementActionFull::SessionsList)).await;
328 }
329 Some("watch") => match args.next() {
330 Some(sub) => {
331 if sub == "sliding" {
332 self.watch_sliding_sync().await?;
333 } else {
334 println!("unknown subcommand for watch: available is 'sliding'");
335 }
336 }
337 None => self.watch().await?,
338 },
339 Some("refresh") => {
340 self.refresh_token().await?;
341 }
342 Some("recover") => {
343 self.recover().await?;
344 }
345 Some("reset-cross-signing") => {
346 self.reset_cross_signing().await?;
347 }
348 Some("logout") => {
349 self.logout().await?;
350 break;
351 }
352 Some("exit") => {
353 break;
354 }
355 Some("help") => {
356 help();
357 }
358 Some(cmd) => {
359 println!("Error: unknown command '{cmd}'\n");
360 help();
361 }
362 None => {
363 println!("Error: no command\n");
364 help()
365 }
366 };
367 }
368
369 Ok(())
370 }
371
372 async fn recover(&self) -> anyhow::Result<()> {
373 let recovery = self.client.encryption().recovery();
374
375 println!("Please enter your recovery key:");
376
377 let mut input = String::new();
378 io::stdin().read_line(&mut input).expect("error: unable to read user input");
379
380 let input = input.trim();
381
382 recovery.recover(input).await?;
383
384 match recovery.state() {
385 RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
386 RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
387 RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
388 _ => unreachable!("We should know our recovery state by now"),
389 }
390
391 Ok(())
392 }
393
394 async fn reset_cross_signing(&self) -> Result<()> {
395 let encryption = self.client.encryption();
396
397 if let Some(handle) = encryption.reset_cross_signing().await? {
398 match handle.auth_type() {
399 CrossSigningResetAuthType::Uiaa(_) => {
400 unimplemented!("This should never happen, this is after all the OIDC example.")
401 }
402 CrossSigningResetAuthType::Oidc(o) => {
403 println!(
404 "To reset your end-to-end encryption cross-signing identity, \
405 you first need to approve it at {}",
406 o.approval_url
407 );
408 handle.auth(None).await?;
409 }
410 }
411 }
412
413 print!("Successfully reset cross-signing");
414
415 Ok(())
416 }
417
418 fn whoami(&self) {
420 let client = &self.client;
421 let oidc = client.oidc();
422
423 let user_id = client.user_id().expect("A logged in client has a user ID");
424 let device_id = client.device_id().expect("A logged in client has a device ID");
425 let homeserver = client.homeserver();
426 let issuer = oidc.issuer().expect("A logged in OIDC client has an issuer");
427
428 println!("\nUser ID: {user_id}");
429 println!("Device ID: {device_id}");
430 println!("Homeserver URL: {homeserver}");
431 println!("OpenID Connect provider: {issuer}");
432 }
433
434 async fn account(&self, action: Option<AccountManagementActionFull>) {
436 match self.client.oidc().fetch_account_management_url(action).await {
437 Ok(Some(url)) => {
438 println!("\nTo manage your account, visit: {url}");
439 }
440 _ => {
441 println!("\nThis homeserver does not provide the URL to manage your account")
442 }
443 }
444 }
445
446 async fn watch(&self) -> anyhow::Result<()> {
448 let client = &self.client;
449
450 if !self.restored {
454 client.sync_once(SyncSettings::default()).await.unwrap();
455 }
456
457 let handle = client.add_event_handler(on_room_message);
459
460 let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
462 while let Some(res) = sync_stream.next().await {
463 if let Err(err) = res {
464 client.remove_event_handler(handle);
465 return Err(err.into());
466 }
467 }
468
469 Ok(())
470 }
471
472 async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
475 let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
476
477 sync_service.start().await;
478
479 println!("press enter to exit the sync loop");
480
481 let mut sync_service_state = sync_service.state();
482
483 let sync_service_clone = sync_service.clone();
484 let task = tokio::spawn(async move {
485 let mut num_errors = 0;
494 let mut num_running = 0;
495
496 let mut _unused = String::new();
497 let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
498
499 loop {
500 tokio::select! {
503 res = sync_service_state.next() => {
504 if let Some(state) = res {
505 match state {
506 matrix_sdk_ui::sync_service::State::Idle
507 | matrix_sdk_ui::sync_service::State::Terminated => {
508 num_errors = 0;
509 num_running = 0;
510 }
511
512 matrix_sdk_ui::sync_service::State::Running => {
513 num_running += 1;
514 if num_running > 1 {
515 num_errors = 0;
516 }
517 }
518
519 matrix_sdk_ui::sync_service::State::Error | matrix_sdk_ui::sync_service::State::Offline => {
520 num_errors += 1;
521 num_running = 0;
522
523 if num_errors == 5 {
524 println!("ran into 5 errors in a row, terminating");
525 break;
526 }
527
528 sync_service_clone.start().await;
529 }
530 }
531 println!("New sync service state update: {state:?}");
532 } else {
533 break;
534 }
535 }
536
537 _ = stdin.read_line(&mut _unused) => {
538 println!("Stopping loop because of user request");
539 sync_service.stop().await;
540
541 break;
542 }
543 }
544 }
545 });
546
547 println!("waiting for sync service to stop...");
548 task.await.unwrap();
549
550 println!("done!");
551 Ok(())
552 }
553
554 fn setup_background_save(&self) {
559 let this = self.clone();
560 tokio::spawn(async move {
561 while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
562 match update {
563 matrix_sdk::SessionChange::UnknownToken { soft_logout } => {
564 println!("Received an unknown token error; soft logout? {soft_logout:?}");
565 }
566 matrix_sdk::SessionChange::TokensRefreshed => {
567 if let Err(err) = this.update_stored_session().await {
569 println!("Unable to store a session in the background: {err}");
570 }
571 }
572 }
573 }
574 });
575 }
576
577 async fn update_stored_session(&self) -> anyhow::Result<()> {
582 println!("Updating the stored session...");
583
584 let serialized_session = fs::read_to_string(&self.session_file).await?;
585 let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
586
587 let user_session =
588 self.client.oidc().user_session().expect("A logged in client has a session");
589 session.user_session = user_session;
590
591 let serialized_session = serde_json::to_string(&session)?;
592 fs::write(&self.session_file, serialized_session).await?;
593
594 println!("Updating the stored session: done!");
595 Ok(())
596 }
597
598 async fn refresh_token(&self) -> anyhow::Result<()> {
600 self.client.oidc().refresh_access_token().await?;
601
602 println!("\nToken refreshed successfully");
606
607 Ok(())
608 }
609
610 async fn logout(&self) -> anyhow::Result<()> {
612 let url_builder = self.client.oidc().logout().await?;
614
615 let data_dir = self.session_file.parent().expect("The file has a parent directory");
617 fs::remove_dir_all(data_dir).await?;
618
619 println!("\nLogged out successfully");
620
621 if let Some(url_builder) = url_builder {
622 let data = url_builder.build()?;
623 println!(
624 "\nTo log out from your account in the provider's interface, visit: {}",
625 data.url
626 );
627 }
628
629 println!("\nExiting…");
630
631 Ok(())
632 }
633}
634
635async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
640 let db_path = data_dir.join("db");
641
642 let mut rng = thread_rng();
644 let passphrase: String =
645 (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
646
647 loop {
649 let mut homeserver = String::new();
650
651 print!("\nHomeserver: ");
652 io::stdout().flush().expect("Unable to write to stdout");
653 io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
654
655 let homeserver = homeserver.trim();
656
657 println!("\nChecking homeserver…");
658
659 match Client::builder()
660 .server_name_or_homeserver_url(homeserver)
662 .handle_refresh_tokens()
664 .sqlite_store(&db_path, Some(&passphrase))
668 .build()
669 .await
670 {
671 Ok(client) => {
672 match client.oidc().provider_metadata().await {
674 Ok(server_metadata) => {
675 println!(
676 "Found OAuth 2.0 server metadata with issuer: {}",
677 server_metadata.issuer()
678 );
679
680 let homeserver = client.homeserver().to_string();
681 return Ok((client, ClientSession { homeserver, db_path, passphrase }));
682 }
683 Err(error) => {
684 if error.is_not_supported() {
685 println!(
686 "This homeserver doesn't advertise OAuth 2.0 server metadata."
687 );
688 } else {
689 println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
690 }
691 fs::remove_dir_all(data_dir).await?;
693 }
694 }
695 }
696 Err(error) => match &error {
697 ClientBuildError::AutoDiscovery(_)
698 | ClientBuildError::Url(_)
699 | ClientBuildError::Http(_) => {
700 println!("Error checking the homeserver: {error}");
701 println!("Please try again\n");
702 fs::remove_dir_all(data_dir).await?;
704 }
705 ClientBuildError::InvalidServerName => {
706 println!("Error: not a valid server name");
707 println!("Please try again\n");
708 }
709 _ => {
710 return Err(error.into());
712 }
713 },
714 }
715 }
716}
717
718fn client_metadata() -> VerifiedClientMetadata {
725 let redirect_uri = Url::parse("http://127.0.0.1").expect("Couldn't parse redirect URI");
726 let client_uri = Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
727 .expect("Couldn't parse client URI");
728
729 ClientMetadata {
730 application_type: Some(ApplicationType::Native),
732 redirect_uris: Some(vec![redirect_uri]),
736 grant_types: Some(vec![GrantType::RefreshToken, GrantType::AuthorizationCode]),
739 token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None),
742 client_name: Some(Localized::new("matrix-rust-sdk-oidc-cli".to_owned(), [])),
747 contacts: Some(vec!["root@127.0.0.1".to_owned()]),
748 client_uri: Some(Localized::new(client_uri.clone(), [])),
749 policy_uri: Some(Localized::new(client_uri.clone(), [])),
750 tos_uri: Some(Localized::new(client_uri, [])),
751 ..Default::default()
752 }
753 .validate()
754 .unwrap()
755}
756
757async fn use_auth_url(
761 url: &Url,
762 state: &str,
763 data_rx: oneshot::Receiver<String>,
764 signal_tx: oneshot::Sender<()>,
765) -> anyhow::Result<AuthorizationCode> {
766 println!("\nPlease authenticate yourself at: {url}\n");
767 println!("Then proceed to the authorization.\n");
768
769 let response_query = data_rx.await?;
770 signal_tx.send(()).expect("Receiver is still alive");
771
772 let code = match AuthorizationResponse::parse_query(&response_query)? {
773 AuthorizationResponse::Success(code) => code,
774 AuthorizationResponse::Error(err) => {
775 let err = err.error;
776 return Err(anyhow!("{}: {:?}", err.error, err.error_description));
777 }
778 };
779
780 if code.state != state {
785 bail!("State strings don't match")
786 }
787
788 Ok(code)
789}
790
791async fn spawn_local_server(
797) -> anyhow::Result<(Url, oneshot::Receiver<String>, oneshot::Sender<()>)> {
798 const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000;
803 const SSO_SERVER_BIND_TRIES: u8 = 10;
805
806 let (signal_tx, signal_rx) = oneshot::channel::<()>();
808 let (data_tx, data_rx) = oneshot::channel::<String>();
810 let data_tx_mutex = Arc::new(Mutex::new(Some(data_tx)));
811
812 let mut redirect_url = Url::parse("http://127.0.0.1:0/")
814 .expect("Couldn't parse good known loopback interface URL");
815
816 let listener = {
818 let host = redirect_url.host_str().expect("The redirect URL doesn't have a host");
819 let mut n = 0u8;
820
821 loop {
822 let port = thread_rng().gen_range(SSO_SERVER_BIND_RANGE);
823 match TcpListener::bind((host, port)).await {
824 Ok(l) => {
825 redirect_url
826 .set_port(Some(port))
827 .expect("Could not set new port on redirect URL");
828 break l;
829 }
830 Err(_) if n < SSO_SERVER_BIND_TRIES => {
831 n += 1;
832 }
833 Err(e) => {
834 return Err(e.into());
835 }
836 }
837 }
838 };
839
840 let router = any_service(service_fn(move |request: Request<_>| {
842 let data_tx_mutex = data_tx_mutex.clone();
843 async move {
844 if request.method() != Method::HEAD && request.method() != Method::GET {
846 return Ok::<_, Infallible>(StatusCode::METHOD_NOT_ALLOWED.into_response());
847 }
848
849 if let Some(data_tx) = data_tx_mutex.lock().unwrap().take() {
852 let query_string = request.uri().query().unwrap_or_default();
853
854 data_tx.send(query_string.to_owned()).expect("The receiver is still alive");
855 }
856
857 Ok("The authorization step is complete. You can close this page and go back to the oidc-cli.".into_response())
858 }
859 }));
860
861 let server = axum::serve(listener, router)
862 .with_graceful_shutdown(async {
863 signal_rx.await.ok();
864 })
865 .into_future();
866
867 tokio::spawn(server);
868
869 Ok((redirect_url, data_rx, signal_tx))
870}
871
872async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
874 if room.state() != RoomState::Joined {
876 return;
877 }
878 let MessageType::Text(text_content) = &event.content.msgtype else { return };
879
880 let room_name = match room.display_name().await {
881 Ok(room_name) => room_name.to_string(),
882 Err(error) => {
883 println!("Error getting room display name: {error}");
884 room.room_id().to_string()
886 }
887 };
888
889 println!("[{room_name}] {}: {}", event.sender, text_content.body)
890}