1use std::{
16 io::{self, Write},
17 net::{Ipv4Addr, Ipv6Addr},
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25 Client, ClientBuildError, Result, RoomState,
26 authentication::oauth::{
27 ClientId, OAuthAuthorizationData, OAuthError, OAuthSession, UserSession,
28 error::OAuthClientRegistrationError,
29 registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
30 },
31 config::SyncSettings,
32 encryption::{CrossSigningResetAuthType, recovery::RecoveryState},
33 room::Room,
34 ruma::{
35 api::client::{
36 discovery::get_authorization_server_metadata::v1::AccountManagementActionData, uiaa,
37 },
38 events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
39 serde::Raw,
40 },
41 utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
42};
43use matrix_sdk_ui::sync_service::SyncService;
44use rand::{RngExt, distr::Alphanumeric, rng};
45use serde::{Deserialize, Serialize};
46use tokio::{fs, io::AsyncBufReadExt as _};
47use url::Url;
48
49#[tokio::main]
65async fn main() -> anyhow::Result<()> {
66 tracing_subscriber::fmt::init();
67
68 let data_dir =
70 dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
71 let session_file = data_dir.join("session.json");
73
74 let cli = if session_file.exists() {
75 OAuthCli::from_stored_session(session_file).await?
76 } else {
77 OAuthCli::new(&data_dir, session_file).await?
78 };
79
80 cli.run().await
81}
82
/// Prints the list of commands understood by the interactive prompt.
///
/// Every command dispatched by `OAuthCli::run` is listed here, including
/// `profile`, `devices` and `reset-cross-signing`.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!(" whoami Get information about this session");
    println!(" account Get the URL to manage this account");
    println!(" profile Get the account management URL for the profile");
    println!(" devices Get the account management URL for the devices list");
    println!(" watch [sliding?] Watch new incoming messages until an error occurs");
    println!(" refresh Refresh the access token");
    println!(" recover Recover the E2EE secrets from secret storage");
    println!(" reset-cross-signing Reset the cross-signing identity");
    println!(" logout Log out of this account");
    println!(" exit Exit this program");
    println!(" help Show this message\n");
}
96
97#[derive(Debug, Serialize, Deserialize)]
99struct ClientSession {
100 homeserver: String,
102
103 db_path: PathBuf,
105
106 passphrase: String,
108}
109
110#[derive(Debug, Serialize, Deserialize)]
112struct StoredSession {
113 client_session: ClientSession,
115
116 user_session: UserSession,
118
119 client_id: ClientId,
121}
122
123#[derive(Clone, Debug)]
125struct OAuthCli {
126 client: Client,
128
129 restored: bool,
131
132 session_file: PathBuf,
134}
135
136impl OAuthCli {
137 async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
139 println!("No previous session found, logging in…");
140
141 let (client, client_session) = build_client(data_dir).await?;
142 let cli = Self { client, restored: false, session_file };
143
144 if let Err(error) = cli.register_and_login().await {
145 if let Some(error) = error.downcast_ref::<OAuthError>()
146 && let OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported) =
147 error
148 {
149 bail!(
152 "This server doesn't support dynamic registration.\n\
153 Please select another homeserver."
154 );
155 } else {
156 return Err(error);
157 }
158 }
159
160 let full_session =
164 cli.client.oauth().full_session().expect("A logged-in client should have a session");
165
166 let serialized_session = serde_json::to_string(&StoredSession {
167 client_session,
168 user_session: full_session.user,
169 client_id: full_session.client_id,
170 })?;
171 fs::write(&cli.session_file, serialized_session).await?;
172
173 println!("Session persisted in {}", cli.session_file.to_string_lossy());
174
175 cli.setup_background_save();
176
177 Ok(cli)
178 }
179
180 async fn register_and_login(&self) -> anyhow::Result<()> {
183 let oauth = self.client.oauth();
184
185 loop {
187 let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
191
192 let OAuthAuthorizationData { url, .. } = oauth
193 .login(redirect_uri, None, Some(client_metadata().into()), None)
194 .build()
195 .await?;
196
197 let Some(query_string) = use_auth_url(&url, server_handle).await else {
198 println!("Error: failed to login: missing query string on the redirect URL");
199 println!("Please try again.\n");
200 continue;
201 };
202
203 match oauth.finish_login(query_string.into()).await {
204 Ok(()) => {
205 let user_id = self.client.user_id().expect("Got a user ID");
206 println!("Logged in as {user_id}");
207 break;
208 }
209 Err(err) => {
210 println!("Error: failed to login: {err}");
211 println!("Please try again.\n");
212 continue;
213 }
214 }
215 }
216
217 Ok(())
218 }
219
220 async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
222 println!("Previous session found in '{}'", session_file.to_string_lossy());
223
224 let serialized_session = fs::read_to_string(&session_file).await?;
226 let StoredSession { client_session, user_session, client_id } =
227 serde_json::from_str(&serialized_session)?;
228
229 let client = Client::builder()
231 .homeserver_url(client_session.homeserver)
232 .handle_refresh_tokens()
233 .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
234 .build()
235 .await?;
236
237 println!("Restoring session for {}…", user_session.meta.user_id);
238
239 let session = OAuthSession { client_id, user: user_session };
240 client.restore_session(session).await?;
242
243 let this = Self { client, restored: true, session_file };
244
245 this.setup_background_save();
246
247 Ok(this)
248 }
249
250 async fn run(&self) -> anyhow::Result<()> {
252 help();
253
254 loop {
255 let mut input = String::new();
256
257 print!("\nEnter command: ");
258 io::stdout().flush().expect("Unable to write to stdout");
259
260 io::stdin().read_line(&mut input).expect("Unable to read user input");
261
262 let mut args = input.trim().split_ascii_whitespace();
263 let cmd = args.next();
264
265 match cmd {
266 Some("whoami") => {
267 self.whoami();
268 }
269 Some("account") => {
270 self.account(None).await;
271 }
272 Some("profile") => {
273 self.account(Some(AccountManagementActionData::Profile)).await;
274 }
275 Some("devices") => {
276 self.account(Some(AccountManagementActionData::DevicesList)).await;
277 }
278 Some("watch") => match args.next() {
279 Some(sub) => {
280 if sub == "sliding" {
281 self.watch_sliding_sync().await?;
282 } else {
283 println!("unknown subcommand for watch: available is 'sliding'");
284 }
285 }
286 None => self.watch().await?,
287 },
288 Some("refresh") => {
289 self.refresh_token().await?;
290 }
291 Some("recover") => {
292 self.recover().await?;
293 }
294 Some("reset-cross-signing") => {
295 self.reset_cross_signing().await?;
296 }
297 Some("logout") => {
298 self.logout().await?;
299 break;
300 }
301 Some("exit") => {
302 break;
303 }
304 Some("help") => {
305 help();
306 }
307 Some(cmd) => {
308 println!("Error: unknown command '{cmd}'\n");
309 help();
310 }
311 None => {
312 println!("Error: no command\n");
313 help()
314 }
315 }
316 }
317
318 Ok(())
319 }
320
321 async fn recover(&self) -> anyhow::Result<()> {
322 let recovery = self.client.encryption().recovery();
323
324 println!("Please enter your recovery key:");
325
326 let mut input = String::new();
327 io::stdin().read_line(&mut input).expect("error: unable to read user input");
328
329 let input = input.trim();
330
331 recovery.recover(input).await?;
332
333 match recovery.state() {
334 RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
335 RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
336 RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
337 _ => unreachable!("We should know our recovery state by now"),
338 }
339
340 Ok(())
341 }
342
343 async fn reset_cross_signing(&self) -> Result<()> {
344 let encryption = self.client.encryption();
345
346 if let Some(handle) = encryption.reset_cross_signing().await? {
347 match handle.auth_type() {
348 CrossSigningResetAuthType::Uiaa(_) => {
349 unimplemented!(
350 "This should never happen, this is after all the OAuth 2.0 example."
351 )
352 }
353 CrossSigningResetAuthType::OAuth(o) => {
354 println!(
355 "To reset your end-to-end encryption cross-signing identity, \
356 you first need to approve it at {}",
357 o.approval_url
358 );
359
360 let mut oauth_data = uiaa::OAuth::new();
361 oauth_data.session = o.session.clone();
362 handle.auth(Some(uiaa::AuthData::OAuth(oauth_data))).await?;
363 }
364 }
365 }
366
367 print!("Successfully reset cross-signing");
368
369 Ok(())
370 }
371
372 fn whoami(&self) {
374 let client = &self.client;
375
376 let user_id = client.user_id().expect("A logged in client has a user ID");
377 let device_id = client.device_id().expect("A logged in client has a device ID");
378 let homeserver = client.homeserver();
379
380 println!("\nUser ID: {user_id}");
381 println!("Device ID: {device_id}");
382 println!("Homeserver URL: {homeserver}");
383 }
384
385 async fn account(&self, action: Option<AccountManagementActionData<'_>>) {
387 let Ok(server_metadata) = self.client.oauth().cached_server_metadata().await else {
388 println!("\nCould not retrieve the server metadata");
389 return;
390 };
391
392 let url = if let Some(action) = action {
393 server_metadata.account_management_url_with_action(action)
394 } else {
395 server_metadata.account_management_uri
396 };
397
398 let Some(url) = url else {
399 println!("\nThis homeserver does not provide the URL to manage your account");
400 return;
401 };
402
403 println!("\nTo manage your account, visit: {url}");
404 }
405
406 async fn watch(&self) -> anyhow::Result<()> {
408 let client = &self.client;
409
410 if !self.restored {
414 client.sync_once(SyncSettings::default()).await.unwrap();
415 }
416
417 let handle = client.add_event_handler(on_room_message);
419
420 let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
422 while let Some(res) = sync_stream.next().await {
423 if let Err(err) = res {
424 client.remove_event_handler(handle);
425 return Err(err.into());
426 }
427 }
428
429 Ok(())
430 }
431
432 async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
435 let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
436
437 sync_service.start().await;
438
439 println!("press enter to exit the sync loop");
440
441 let mut sync_service_state = sync_service.state();
442
443 let sync_service_clone = sync_service.clone();
444 let task = tokio::spawn(async move {
445 let mut num_errors = 0;
454 let mut num_running = 0;
455
456 let mut _unused = String::new();
457 let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
458
459 loop {
460 tokio::select! {
463 res = sync_service_state.next() => {
464 if let Some(state) = res {
465 match state {
466 matrix_sdk_ui::sync_service::State::Idle
467 | matrix_sdk_ui::sync_service::State::Terminated => {
468 num_errors = 0;
469 num_running = 0;
470 }
471
472 matrix_sdk_ui::sync_service::State::Running => {
473 num_running += 1;
474 if num_running > 1 {
475 num_errors = 0;
476 }
477 }
478
479 matrix_sdk_ui::sync_service::State::Error(_) | matrix_sdk_ui::sync_service::State::Offline => {
480 num_errors += 1;
481 num_running = 0;
482
483 if num_errors == 5 {
484 println!("ran into 5 errors in a row, terminating");
485 break;
486 }
487
488 sync_service_clone.start().await;
489 }
490 }
491 println!("New sync service state update: {state:?}");
492 } else {
493 break;
494 }
495 }
496
497 _ = stdin.read_line(&mut _unused) => {
498 println!("Stopping loop because of user request");
499 sync_service.stop().await;
500
501 break;
502 }
503 }
504 }
505 });
506
507 println!("waiting for sync service to stop...");
508 task.await.unwrap();
509
510 println!("done!");
511 Ok(())
512 }
513
514 fn setup_background_save(&self) {
519 let this = self.clone();
520 tokio::spawn(async move {
521 while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
522 match update {
523 matrix_sdk::SessionChange::UnknownToken(unknown_token) => {
524 println!(
525 "Received an unknown token error; soft logout? {:?}",
526 unknown_token.soft_logout
527 );
528 }
529 matrix_sdk::SessionChange::TokensRefreshed => {
530 if let Err(err) = this.update_stored_session().await {
532 println!("Unable to store a session in the background: {err}");
533 }
534 }
535 }
536 }
537 });
538 }
539
540 async fn update_stored_session(&self) -> anyhow::Result<()> {
545 println!("Updating the stored session...");
546
547 let serialized_session = fs::read_to_string(&self.session_file).await?;
548 let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
549
550 let user_session: UserSession =
551 self.client.oauth().user_session().expect("A logged in client has a session");
552 session.user_session = user_session;
553
554 let serialized_session = serde_json::to_string(&session)?;
555 fs::write(&self.session_file, serialized_session).await?;
556
557 println!("Updating the stored session: done!");
558 Ok(())
559 }
560
561 async fn refresh_token(&self) -> anyhow::Result<()> {
563 self.client.oauth().refresh_access_token().await?;
564
565 println!("\nToken refreshed successfully");
569
570 Ok(())
571 }
572
573 async fn logout(&self) -> anyhow::Result<()> {
575 self.client.logout().await?;
577
578 let data_dir = self.session_file.parent().expect("The file has a parent directory");
580 fs::remove_dir_all(data_dir).await?;
581
582 println!("\nLogged out successfully");
583 println!("\nExiting…");
584
585 Ok(())
586 }
587}
588
589async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
593 let db_path = data_dir.join("db");
594
595 let mut rng = rng();
597 let passphrase: String =
598 (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
599
600 loop {
602 let mut homeserver = String::new();
603
604 print!("\nHomeserver: ");
605 io::stdout().flush().expect("Unable to write to stdout");
606 io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
607
608 let homeserver = homeserver.trim();
609
610 println!("\nChecking homeserver…");
611
612 match Client::builder()
613 .server_name_or_homeserver_url(homeserver)
615 .handle_refresh_tokens()
617 .sqlite_store(&db_path, Some(&passphrase))
621 .build()
622 .await
623 {
624 Ok(client) => {
625 match client.oauth().server_metadata().await {
627 Ok(server_metadata) => {
628 println!(
629 "Found OAuth 2.0 server metadata with issuer: {}",
630 server_metadata.issuer
631 );
632
633 let homeserver = client.homeserver().to_string();
634 return Ok((client, ClientSession { homeserver, db_path, passphrase }));
635 }
636 Err(error) => {
637 if error.is_not_supported() {
638 println!(
639 "This homeserver doesn't advertise OAuth 2.0 server metadata."
640 );
641 } else {
642 println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
643 }
644 fs::remove_dir_all(data_dir).await?;
646 }
647 }
648 }
649 Err(error) => match &error {
650 ClientBuildError::AutoDiscovery(_)
651 | ClientBuildError::Url(_)
652 | ClientBuildError::Http(_) => {
653 println!("Error checking the homeserver: {error}");
654 println!("Please try again\n");
655 fs::remove_dir_all(data_dir).await?;
657 }
658 ClientBuildError::InvalidServerName => {
659 println!("Error: not a valid server name");
660 println!("Please try again\n");
661 }
662 _ => {
663 return Err(error.into());
665 }
666 },
667 }
668 }
669}
670
671fn client_metadata() -> Raw<ClientMetadata> {
673 let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
677 .expect("Couldn't parse IPv4 redirect URI");
678 let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
679 .expect("Couldn't parse IPv6 redirect URI");
680 let client_uri = Localized::new(
681 Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
682 .expect("Couldn't parse client URI"),
683 None,
684 );
685
686 let metadata = ClientMetadata {
687 client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
692 policy_uri: Some(client_uri.clone()),
693 tos_uri: Some(client_uri.clone()),
694 ..ClientMetadata::new(
695 ApplicationType::Native,
698 vec![OAuthGrantType::AuthorizationCode {
700 redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
701 }],
702 client_uri,
703 )
704 };
705
706 Raw::new(&metadata).expect("Couldn't serialize client metadata")
707}
708
709async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
713 println!("\nPlease authenticate yourself at: {url}\n");
714 println!("Then proceed to the authorization.\n");
715
716 server_handle.await
717}
718
719async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
721 if room.state() != RoomState::Joined {
723 return;
724 }
725 let MessageType::Text(text_content) = &event.content.msgtype else { return };
726
727 let room_name = match room.display_name().await {
728 Ok(room_name) => room_name.to_string(),
729 Err(error) => {
730 println!("Error getting room display name: {error}");
731 room.room_id().to_string()
733 }
734 };
735
736 println!("[{room_name}] {}: {}", event.sender, text_content.body)
737}