1use std::{
16 io::{self, Write},
17 net::{Ipv4Addr, Ipv6Addr},
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25 authentication::oauth::{
26 error::OAuthClientRegistrationError,
27 registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
28 AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError,
29 OAuthRegistrationStore, OAuthSession, UrlOrQuery, UserSession,
30 },
31 config::SyncSettings,
32 encryption::{recovery::RecoveryState, CrossSigningResetAuthType},
33 room::Room,
34 ruma::{
35 events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
36 serde::Raw,
37 },
38 utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
39 Client, ClientBuildError, Result, RoomState,
40};
41use matrix_sdk_ui::sync_service::SyncService;
42use rand::{distributions::Alphanumeric, thread_rng, Rng};
43use serde::{Deserialize, Serialize};
44use tokio::{fs, io::AsyncBufReadExt as _};
45use url::Url;
46
47#[tokio::main]
63async fn main() -> anyhow::Result<()> {
64 tracing_subscriber::fmt::init();
65
66 let data_dir =
68 dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
69 let session_file = data_dir.join("session.json");
71
72 let cli = if session_file.exists() {
73 OAuthCli::from_stored_session(session_file).await?
74 } else {
75 OAuthCli::new(&data_dir, session_file).await?
76 };
77
78 cli.run().await
79}
80
/// Usage text printed by [`help`].
///
/// Lists every command understood by the interactive command loop, including
/// `profile`, `sessions` and `reset-cross-signing`.
const HELP_TEXT: &str = "\
Usage: [command] [args…]

Commands:
  whoami                Get information about this session
  account               Get the URL to manage this account
  profile               Get the URL to manage the profile of this account
  sessions              Get the URL to manage the sessions of this account
  watch [sliding?]      Watch new incoming messages until an error occurs
  refresh               Refresh the access token
  recover               Recover the E2EE secrets from secret storage
  reset-cross-signing   Reset the cross-signing identity of this account
  logout                Log out of this account
  exit                  Exit this program
  help                  Show this message
";

/// Print the list of available commands.
fn help() {
    // `HELP_TEXT` ends with a newline, so `println!` leaves a blank line after it.
    println!("{HELP_TEXT}");
}
94
95#[derive(Debug, Serialize, Deserialize)]
97struct ClientSession {
98 homeserver: String,
100
101 db_path: PathBuf,
103
104 passphrase: String,
106}
107
108#[derive(Debug, Serialize, Deserialize)]
110struct StoredSession {
111 client_session: ClientSession,
113
114 user_session: UserSession,
116
117 client_id: ClientId,
119}
120
121#[derive(Clone, Debug)]
123struct OAuthCli {
124 client: Client,
126
127 restored: bool,
129
130 session_file: PathBuf,
132}
133
134impl OAuthCli {
135 async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
137 println!("No previous session found, logging in…");
138
139 let (client, client_session) = build_client(data_dir).await?;
140 let cli = Self { client, restored: false, session_file };
141
142 if let Err(error) = cli.register_and_login(data_dir).await {
143 if error.downcast_ref::<OAuthError>().is_some_and(|error| {
144 matches!(
145 error,
146 OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported)
147 )
148 }) {
149 bail!(
152 "This server doesn't support dynamic registration.\n\
153 Please select another homeserver."
154 );
155 } else {
156 return Err(error);
157 }
158 }
159
160 let full_session =
164 cli.client.oauth().full_session().expect("A logged-in client should have a session");
165
166 let serialized_session = serde_json::to_string(&StoredSession {
167 client_session,
168 user_session: full_session.user,
169 client_id: full_session.client_id,
170 })?;
171 fs::write(&cli.session_file, serialized_session).await?;
172
173 println!("Session persisted in {}", cli.session_file.to_string_lossy());
174
175 cli.setup_background_save();
176
177 Ok(cli)
178 }
179
180 async fn register_and_login(&self, data_dir: &Path) -> anyhow::Result<()> {
183 let oauth = self.client.oauth();
184
185 loop {
187 let registration_file =
193 data_dir.parent().expect("directory has a parent").join("oauth_registrations");
194 let registration_store =
195 OAuthRegistrationStore::new(registration_file, client_metadata()).await?;
196
197 let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
201
202 let OAuthAuthorizationData { url, .. } =
203 oauth.login(registration_store.into(), redirect_uri, None).build().await?;
204
205 let query_string =
206 use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default();
207
208 match oauth.finish_login(UrlOrQuery::Query(query_string)).await {
209 Ok(()) => {
210 let user_id = self.client.user_id().expect("Got a user ID");
211 println!("Logged in as {user_id}");
212 break;
213 }
214 Err(err) => {
215 println!("Error: failed to login: {err}");
216 println!("Please try again.\n");
217 continue;
218 }
219 }
220 }
221
222 Ok(())
223 }
224
225 async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
227 println!("Previous session found in '{}'", session_file.to_string_lossy());
228
229 let serialized_session = fs::read_to_string(&session_file).await?;
231 let StoredSession { client_session, user_session, client_id } =
232 serde_json::from_str(&serialized_session)?;
233
234 let client = Client::builder()
236 .homeserver_url(client_session.homeserver)
237 .handle_refresh_tokens()
238 .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
239 .build()
240 .await?;
241
242 println!("Restoring session for {}…", user_session.meta.user_id);
243
244 let session = OAuthSession { client_id, user: user_session };
245 client.restore_session(session).await?;
247
248 let this = Self { client, restored: true, session_file };
249
250 this.setup_background_save();
251
252 Ok(this)
253 }
254
255 async fn run(&self) -> anyhow::Result<()> {
257 help();
258
259 loop {
260 let mut input = String::new();
261
262 print!("\nEnter command: ");
263 io::stdout().flush().expect("Unable to write to stdout");
264
265 io::stdin().read_line(&mut input).expect("Unable to read user input");
266
267 let mut args = input.trim().split_ascii_whitespace();
268 let cmd = args.next();
269
270 match cmd {
271 Some("whoami") => {
272 self.whoami();
273 }
274 Some("account") => {
275 self.account(None).await;
276 }
277 Some("profile") => {
278 self.account(Some(AccountManagementActionFull::Profile)).await;
279 }
280 Some("sessions") => {
281 self.account(Some(AccountManagementActionFull::SessionsList)).await;
282 }
283 Some("watch") => match args.next() {
284 Some(sub) => {
285 if sub == "sliding" {
286 self.watch_sliding_sync().await?;
287 } else {
288 println!("unknown subcommand for watch: available is 'sliding'");
289 }
290 }
291 None => self.watch().await?,
292 },
293 Some("refresh") => {
294 self.refresh_token().await?;
295 }
296 Some("recover") => {
297 self.recover().await?;
298 }
299 Some("reset-cross-signing") => {
300 self.reset_cross_signing().await?;
301 }
302 Some("logout") => {
303 self.logout().await?;
304 break;
305 }
306 Some("exit") => {
307 break;
308 }
309 Some("help") => {
310 help();
311 }
312 Some(cmd) => {
313 println!("Error: unknown command '{cmd}'\n");
314 help();
315 }
316 None => {
317 println!("Error: no command\n");
318 help()
319 }
320 };
321 }
322
323 Ok(())
324 }
325
326 async fn recover(&self) -> anyhow::Result<()> {
327 let recovery = self.client.encryption().recovery();
328
329 println!("Please enter your recovery key:");
330
331 let mut input = String::new();
332 io::stdin().read_line(&mut input).expect("error: unable to read user input");
333
334 let input = input.trim();
335
336 recovery.recover(input).await?;
337
338 match recovery.state() {
339 RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
340 RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
341 RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
342 _ => unreachable!("We should know our recovery state by now"),
343 }
344
345 Ok(())
346 }
347
348 async fn reset_cross_signing(&self) -> Result<()> {
349 let encryption = self.client.encryption();
350
351 if let Some(handle) = encryption.reset_cross_signing().await? {
352 match handle.auth_type() {
353 CrossSigningResetAuthType::Uiaa(_) => {
354 unimplemented!(
355 "This should never happen, this is after all the OAuth 2.0 example."
356 )
357 }
358 CrossSigningResetAuthType::OAuth(o) => {
359 println!(
360 "To reset your end-to-end encryption cross-signing identity, \
361 you first need to approve it at {}",
362 o.approval_url
363 );
364 handle.auth(None).await?;
365 }
366 }
367 }
368
369 print!("Successfully reset cross-signing");
370
371 Ok(())
372 }
373
374 fn whoami(&self) {
376 let client = &self.client;
377 let oauth = client.oauth();
378
379 let user_id = client.user_id().expect("A logged in client has a user ID");
380 let device_id = client.device_id().expect("A logged in client has a device ID");
381 let homeserver = client.homeserver();
382 let issuer = oauth.issuer().expect("A logged in OAuth 2.0 client has an issuer");
383
384 println!("\nUser ID: {user_id}");
385 println!("Device ID: {device_id}");
386 println!("Homeserver URL: {homeserver}");
387 println!("OAuth 2.0 authorization server: {issuer}");
388 }
389
390 async fn account(&self, action: Option<AccountManagementActionFull>) {
392 let mut url_builder = match self.client.oauth().fetch_account_management_url().await {
393 Ok(Some(url_builder)) => url_builder,
394 _ => {
395 println!("\nThis homeserver does not provide the URL to manage your account");
396 return;
397 }
398 };
399
400 if let Some(action) = action {
401 url_builder = url_builder.action(action);
402 }
403
404 let url = url_builder.build();
405 println!("\nTo manage your account, visit: {url}");
406 }
407
408 async fn watch(&self) -> anyhow::Result<()> {
410 let client = &self.client;
411
412 if !self.restored {
416 client.sync_once(SyncSettings::default()).await.unwrap();
417 }
418
419 let handle = client.add_event_handler(on_room_message);
421
422 let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
424 while let Some(res) = sync_stream.next().await {
425 if let Err(err) = res {
426 client.remove_event_handler(handle);
427 return Err(err.into());
428 }
429 }
430
431 Ok(())
432 }
433
434 async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
437 let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
438
439 sync_service.start().await;
440
441 println!("press enter to exit the sync loop");
442
443 let mut sync_service_state = sync_service.state();
444
445 let sync_service_clone = sync_service.clone();
446 let task = tokio::spawn(async move {
447 let mut num_errors = 0;
456 let mut num_running = 0;
457
458 let mut _unused = String::new();
459 let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
460
461 loop {
462 tokio::select! {
465 res = sync_service_state.next() => {
466 if let Some(state) = res {
467 match state {
468 matrix_sdk_ui::sync_service::State::Idle
469 | matrix_sdk_ui::sync_service::State::Terminated => {
470 num_errors = 0;
471 num_running = 0;
472 }
473
474 matrix_sdk_ui::sync_service::State::Running => {
475 num_running += 1;
476 if num_running > 1 {
477 num_errors = 0;
478 }
479 }
480
481 matrix_sdk_ui::sync_service::State::Error | matrix_sdk_ui::sync_service::State::Offline => {
482 num_errors += 1;
483 num_running = 0;
484
485 if num_errors == 5 {
486 println!("ran into 5 errors in a row, terminating");
487 break;
488 }
489
490 sync_service_clone.start().await;
491 }
492 }
493 println!("New sync service state update: {state:?}");
494 } else {
495 break;
496 }
497 }
498
499 _ = stdin.read_line(&mut _unused) => {
500 println!("Stopping loop because of user request");
501 sync_service.stop().await;
502
503 break;
504 }
505 }
506 }
507 });
508
509 println!("waiting for sync service to stop...");
510 task.await.unwrap();
511
512 println!("done!");
513 Ok(())
514 }
515
516 fn setup_background_save(&self) {
521 let this = self.clone();
522 tokio::spawn(async move {
523 while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
524 match update {
525 matrix_sdk::SessionChange::UnknownToken { soft_logout } => {
526 println!("Received an unknown token error; soft logout? {soft_logout:?}");
527 }
528 matrix_sdk::SessionChange::TokensRefreshed => {
529 if let Err(err) = this.update_stored_session().await {
531 println!("Unable to store a session in the background: {err}");
532 }
533 }
534 }
535 }
536 });
537 }
538
539 async fn update_stored_session(&self) -> anyhow::Result<()> {
544 println!("Updating the stored session...");
545
546 let serialized_session = fs::read_to_string(&self.session_file).await?;
547 let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
548
549 let user_session: UserSession =
550 self.client.oauth().user_session().expect("A logged in client has a session");
551 session.user_session = user_session;
552
553 let serialized_session = serde_json::to_string(&session)?;
554 fs::write(&self.session_file, serialized_session).await?;
555
556 println!("Updating the stored session: done!");
557 Ok(())
558 }
559
560 async fn refresh_token(&self) -> anyhow::Result<()> {
562 self.client.oauth().refresh_access_token().await?;
563
564 println!("\nToken refreshed successfully");
568
569 Ok(())
570 }
571
572 async fn logout(&self) -> anyhow::Result<()> {
574 self.client.oauth().logout().await?;
576
577 let data_dir = self.session_file.parent().expect("The file has a parent directory");
579 fs::remove_dir_all(data_dir).await?;
580
581 println!("\nLogged out successfully");
582 println!("\nExiting…");
583
584 Ok(())
585 }
586}
587
588async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
592 let db_path = data_dir.join("db");
593
594 let mut rng = thread_rng();
596 let passphrase: String =
597 (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
598
599 loop {
601 let mut homeserver = String::new();
602
603 print!("\nHomeserver: ");
604 io::stdout().flush().expect("Unable to write to stdout");
605 io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
606
607 let homeserver = homeserver.trim();
608
609 println!("\nChecking homeserver…");
610
611 match Client::builder()
612 .server_name_or_homeserver_url(homeserver)
614 .handle_refresh_tokens()
616 .sqlite_store(&db_path, Some(&passphrase))
620 .build()
621 .await
622 {
623 Ok(client) => {
624 match client.oauth().server_metadata().await {
626 Ok(server_metadata) => {
627 println!(
628 "Found OAuth 2.0 server metadata with issuer: {}",
629 server_metadata.issuer
630 );
631
632 let homeserver = client.homeserver().to_string();
633 return Ok((client, ClientSession { homeserver, db_path, passphrase }));
634 }
635 Err(error) => {
636 if error.is_not_supported() {
637 println!(
638 "This homeserver doesn't advertise OAuth 2.0 server metadata."
639 );
640 } else {
641 println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
642 }
643 fs::remove_dir_all(data_dir).await?;
645 }
646 }
647 }
648 Err(error) => match &error {
649 ClientBuildError::AutoDiscovery(_)
650 | ClientBuildError::Url(_)
651 | ClientBuildError::Http(_) => {
652 println!("Error checking the homeserver: {error}");
653 println!("Please try again\n");
654 fs::remove_dir_all(data_dir).await?;
656 }
657 ClientBuildError::InvalidServerName => {
658 println!("Error: not a valid server name");
659 println!("Please try again\n");
660 }
661 _ => {
662 return Err(error.into());
664 }
665 },
666 }
667 }
668}
669
670fn client_metadata() -> Raw<ClientMetadata> {
672 let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
676 .expect("Couldn't parse IPv4 redirect URI");
677 let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
678 .expect("Couldn't parse IPv6 redirect URI");
679 let client_uri = Localized::new(
680 Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
681 .expect("Couldn't parse client URI"),
682 None,
683 );
684
685 let metadata = ClientMetadata {
686 client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
691 policy_uri: Some(client_uri.clone()),
692 tos_uri: Some(client_uri.clone()),
693 ..ClientMetadata::new(
694 ApplicationType::Native,
697 vec![OAuthGrantType::AuthorizationCode {
699 redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
700 }],
701 client_uri,
702 )
703 };
704
705 Raw::new(&metadata).expect("Couldn't serialize client metadata")
706}
707
708async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
712 println!("\nPlease authenticate yourself at: {url}\n");
713 println!("Then proceed to the authorization.\n");
714
715 server_handle.await
716}
717
718async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
720 if room.state() != RoomState::Joined {
722 return;
723 }
724 let MessageType::Text(text_content) = &event.content.msgtype else { return };
725
726 let room_name = match room.display_name().await {
727 Ok(room_name) => room_name.to_string(),
728 Err(error) => {
729 println!("Error getting room display name: {error}");
730 room.room_id().to_string()
732 }
733 };
734
735 println!("[{room_name}] {}: {}", event.sender, text_content.body)
736}