1use std::{
16 io::{self, Write},
17 net::{Ipv4Addr, Ipv6Addr},
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25 Client, ClientBuildError, Result, RoomState,
26 authentication::oauth::{
27 ClientId, OAuthAuthorizationData, OAuthError, OAuthSession, UserSession,
28 error::OAuthClientRegistrationError,
29 registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
30 },
31 config::SyncSettings,
32 encryption::{CrossSigningResetAuthType, recovery::RecoveryState},
33 room::Room,
34 ruma::{
35 api::client::discovery::get_authorization_server_metadata::v1::AccountManagementActionData,
36 events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
37 serde::Raw,
38 },
39 utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
40};
41use matrix_sdk_ui::sync_service::SyncService;
42use rand::{Rng, distributions::Alphanumeric, thread_rng};
43use serde::{Deserialize, Serialize};
44use tokio::{fs, io::AsyncBufReadExt as _};
45use url::Url;
46
47#[tokio::main]
63async fn main() -> anyhow::Result<()> {
64 tracing_subscriber::fmt::init();
65
66 let data_dir =
68 dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
69 let session_file = data_dir.join("session.json");
71
72 let cli = if session_file.exists() {
73 OAuthCli::from_stored_session(session_file).await?
74 } else {
75 OAuthCli::new(&data_dir, session_file).await?
76 };
77
78 cli.run().await
79}
80
/// Print the list of commands understood by the interactive command loop
/// ([`OAuthCli::run`]).
///
/// Every command accepted by the loop is listed, including `profile`, `devices` and
/// `reset-cross-signing`.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!("  whoami                 Get information about this session");
    println!("  account                Get the URL to manage this account");
    println!("  profile                Get the URL to manage the profile of this account");
    println!("  devices                Get the URL to manage the devices of this account");
    println!("  watch [sliding?]       Watch new incoming messages until an error occurs");
    println!("  refresh                Refresh the access token");
    println!("  recover                Recover the E2EE secrets from secret storage");
    println!("  reset-cross-signing    Reset the cross-signing identity of this account");
    println!("  logout                 Log out of this account");
    println!("  exit                   Exit this program");
    println!("  help                   Show this message\n");
}
94
/// The data needed to rebuild the [`Client`] of a previous session.
///
/// It is serialized as part of [`StoredSession`] with `serde_json` and written to the session
/// file, so every field — including the store passphrase — is stored there in plaintext.
#[derive(Debug, Serialize, Deserialize)]
struct ClientSession {
    /// The homeserver URL, taken from `Client::homeserver()` when the session was created and
    /// passed back to `homeserver_url` on restore.
    homeserver: String,

    /// The path of the client's SQLite store.
    db_path: PathBuf,

    /// The passphrase given to `sqlite_store` (randomly generated when logging in).
    passphrase: String,
}
107
/// The full content of the session file, serialized as JSON.
#[derive(Debug, Serialize, Deserialize)]
struct StoredSession {
    /// How to rebuild the client (homeserver URL and store settings).
    client_session: ClientSession,

    /// The OAuth 2.0 user session; rewritten whenever the tokens are refreshed.
    user_session: UserSession,

    /// The OAuth 2.0 client ID obtained at registration, needed to restore the session.
    client_id: ClientId,
}
120
/// The interactive OAuth 2.0 command-line client.
///
/// Cloning is used to move a handle into the background task that persists session changes.
#[derive(Clone, Debug)]
struct OAuthCli {
    /// The Matrix client.
    client: Client,

    /// Whether the session was restored from the session file rather than created by a new
    /// login; `watch` only runs an initial `sync_once` when this is `false`.
    restored: bool,

    /// The path of the JSON session file. Its parent directory also contains the client's
    /// store and is removed entirely on logout.
    session_file: PathBuf,
}
133
134impl OAuthCli {
135 async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
137 println!("No previous session found, logging in…");
138
139 let (client, client_session) = build_client(data_dir).await?;
140 let cli = Self { client, restored: false, session_file };
141
142 if let Err(error) = cli.register_and_login().await {
143 if let Some(error) = error.downcast_ref::<OAuthError>()
144 && let OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported) =
145 error
146 {
147 bail!(
150 "This server doesn't support dynamic registration.\n\
151 Please select another homeserver."
152 );
153 } else {
154 return Err(error);
155 }
156 }
157
158 let full_session =
162 cli.client.oauth().full_session().expect("A logged-in client should have a session");
163
164 let serialized_session = serde_json::to_string(&StoredSession {
165 client_session,
166 user_session: full_session.user,
167 client_id: full_session.client_id,
168 })?;
169 fs::write(&cli.session_file, serialized_session).await?;
170
171 println!("Session persisted in {}", cli.session_file.to_string_lossy());
172
173 cli.setup_background_save();
174
175 Ok(cli)
176 }
177
178 async fn register_and_login(&self) -> anyhow::Result<()> {
181 let oauth = self.client.oauth();
182
183 loop {
185 let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
189
190 let OAuthAuthorizationData { url, .. } = oauth
191 .login(redirect_uri, None, Some(client_metadata().into()), None)
192 .build()
193 .await?;
194
195 let Some(query_string) = use_auth_url(&url, server_handle).await else {
196 println!("Error: failed to login: missing query string on the redirect URL");
197 println!("Please try again.\n");
198 continue;
199 };
200
201 match oauth.finish_login(query_string.into()).await {
202 Ok(()) => {
203 let user_id = self.client.user_id().expect("Got a user ID");
204 println!("Logged in as {user_id}");
205 break;
206 }
207 Err(err) => {
208 println!("Error: failed to login: {err}");
209 println!("Please try again.\n");
210 continue;
211 }
212 }
213 }
214
215 Ok(())
216 }
217
218 async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
220 println!("Previous session found in '{}'", session_file.to_string_lossy());
221
222 let serialized_session = fs::read_to_string(&session_file).await?;
224 let StoredSession { client_session, user_session, client_id } =
225 serde_json::from_str(&serialized_session)?;
226
227 let client = Client::builder()
229 .homeserver_url(client_session.homeserver)
230 .handle_refresh_tokens()
231 .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
232 .build()
233 .await?;
234
235 println!("Restoring session for {}…", user_session.meta.user_id);
236
237 let session = OAuthSession { client_id, user: user_session };
238 client.restore_session(session).await?;
240
241 let this = Self { client, restored: true, session_file };
242
243 this.setup_background_save();
244
245 Ok(this)
246 }
247
248 async fn run(&self) -> anyhow::Result<()> {
250 help();
251
252 loop {
253 let mut input = String::new();
254
255 print!("\nEnter command: ");
256 io::stdout().flush().expect("Unable to write to stdout");
257
258 io::stdin().read_line(&mut input).expect("Unable to read user input");
259
260 let mut args = input.trim().split_ascii_whitespace();
261 let cmd = args.next();
262
263 match cmd {
264 Some("whoami") => {
265 self.whoami();
266 }
267 Some("account") => {
268 self.account(None).await;
269 }
270 Some("profile") => {
271 self.account(Some(AccountManagementActionData::Profile)).await;
272 }
273 Some("devices") => {
274 self.account(Some(AccountManagementActionData::DevicesList)).await;
275 }
276 Some("watch") => match args.next() {
277 Some(sub) => {
278 if sub == "sliding" {
279 self.watch_sliding_sync().await?;
280 } else {
281 println!("unknown subcommand for watch: available is 'sliding'");
282 }
283 }
284 None => self.watch().await?,
285 },
286 Some("refresh") => {
287 self.refresh_token().await?;
288 }
289 Some("recover") => {
290 self.recover().await?;
291 }
292 Some("reset-cross-signing") => {
293 self.reset_cross_signing().await?;
294 }
295 Some("logout") => {
296 self.logout().await?;
297 break;
298 }
299 Some("exit") => {
300 break;
301 }
302 Some("help") => {
303 help();
304 }
305 Some(cmd) => {
306 println!("Error: unknown command '{cmd}'\n");
307 help();
308 }
309 None => {
310 println!("Error: no command\n");
311 help()
312 }
313 }
314 }
315
316 Ok(())
317 }
318
319 async fn recover(&self) -> anyhow::Result<()> {
320 let recovery = self.client.encryption().recovery();
321
322 println!("Please enter your recovery key:");
323
324 let mut input = String::new();
325 io::stdin().read_line(&mut input).expect("error: unable to read user input");
326
327 let input = input.trim();
328
329 recovery.recover(input).await?;
330
331 match recovery.state() {
332 RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
333 RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
334 RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
335 _ => unreachable!("We should know our recovery state by now"),
336 }
337
338 Ok(())
339 }
340
341 async fn reset_cross_signing(&self) -> Result<()> {
342 let encryption = self.client.encryption();
343
344 if let Some(handle) = encryption.reset_cross_signing().await? {
345 match handle.auth_type() {
346 CrossSigningResetAuthType::Uiaa(_) => {
347 unimplemented!(
348 "This should never happen, this is after all the OAuth 2.0 example."
349 )
350 }
351 CrossSigningResetAuthType::OAuth(o) => {
352 println!(
353 "To reset your end-to-end encryption cross-signing identity, \
354 you first need to approve it at {}",
355 o.approval_url
356 );
357 handle.auth(None).await?;
358 }
359 }
360 }
361
362 print!("Successfully reset cross-signing");
363
364 Ok(())
365 }
366
367 fn whoami(&self) {
369 let client = &self.client;
370
371 let user_id = client.user_id().expect("A logged in client has a user ID");
372 let device_id = client.device_id().expect("A logged in client has a device ID");
373 let homeserver = client.homeserver();
374
375 println!("\nUser ID: {user_id}");
376 println!("Device ID: {device_id}");
377 println!("Homeserver URL: {homeserver}");
378 }
379
380 async fn account(&self, action: Option<AccountManagementActionData<'_>>) {
382 let Ok(server_metadata) = self.client.oauth().cached_server_metadata().await else {
383 println!("\nCould not retrieve the server metadata");
384 return;
385 };
386
387 let url = if let Some(action) = action {
388 server_metadata.account_management_url_with_action(action)
389 } else {
390 server_metadata.account_management_uri
391 };
392
393 let Some(url) = url else {
394 println!("\nThis homeserver does not provide the URL to manage your account");
395 return;
396 };
397
398 println!("\nTo manage your account, visit: {url}");
399 }
400
401 async fn watch(&self) -> anyhow::Result<()> {
403 let client = &self.client;
404
405 if !self.restored {
409 client.sync_once(SyncSettings::default()).await.unwrap();
410 }
411
412 let handle = client.add_event_handler(on_room_message);
414
415 let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
417 while let Some(res) = sync_stream.next().await {
418 if let Err(err) = res {
419 client.remove_event_handler(handle);
420 return Err(err.into());
421 }
422 }
423
424 Ok(())
425 }
426
427 async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
430 let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
431
432 sync_service.start().await;
433
434 println!("press enter to exit the sync loop");
435
436 let mut sync_service_state = sync_service.state();
437
438 let sync_service_clone = sync_service.clone();
439 let task = tokio::spawn(async move {
440 let mut num_errors = 0;
449 let mut num_running = 0;
450
451 let mut _unused = String::new();
452 let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
453
454 loop {
455 tokio::select! {
458 res = sync_service_state.next() => {
459 if let Some(state) = res {
460 match state {
461 matrix_sdk_ui::sync_service::State::Idle
462 | matrix_sdk_ui::sync_service::State::Terminated => {
463 num_errors = 0;
464 num_running = 0;
465 }
466
467 matrix_sdk_ui::sync_service::State::Running => {
468 num_running += 1;
469 if num_running > 1 {
470 num_errors = 0;
471 }
472 }
473
474 matrix_sdk_ui::sync_service::State::Error(_) | matrix_sdk_ui::sync_service::State::Offline => {
475 num_errors += 1;
476 num_running = 0;
477
478 if num_errors == 5 {
479 println!("ran into 5 errors in a row, terminating");
480 break;
481 }
482
483 sync_service_clone.start().await;
484 }
485 }
486 println!("New sync service state update: {state:?}");
487 } else {
488 break;
489 }
490 }
491
492 _ = stdin.read_line(&mut _unused) => {
493 println!("Stopping loop because of user request");
494 sync_service.stop().await;
495
496 break;
497 }
498 }
499 }
500 });
501
502 println!("waiting for sync service to stop...");
503 task.await.unwrap();
504
505 println!("done!");
506 Ok(())
507 }
508
509 fn setup_background_save(&self) {
514 let this = self.clone();
515 tokio::spawn(async move {
516 while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
517 match update {
518 matrix_sdk::SessionChange::UnknownToken(unknown_token) => {
519 println!(
520 "Received an unknown token error; soft logout? {:?}",
521 unknown_token.soft_logout
522 );
523 }
524 matrix_sdk::SessionChange::TokensRefreshed => {
525 if let Err(err) = this.update_stored_session().await {
527 println!("Unable to store a session in the background: {err}");
528 }
529 }
530 }
531 }
532 });
533 }
534
535 async fn update_stored_session(&self) -> anyhow::Result<()> {
540 println!("Updating the stored session...");
541
542 let serialized_session = fs::read_to_string(&self.session_file).await?;
543 let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
544
545 let user_session: UserSession =
546 self.client.oauth().user_session().expect("A logged in client has a session");
547 session.user_session = user_session;
548
549 let serialized_session = serde_json::to_string(&session)?;
550 fs::write(&self.session_file, serialized_session).await?;
551
552 println!("Updating the stored session: done!");
553 Ok(())
554 }
555
556 async fn refresh_token(&self) -> anyhow::Result<()> {
558 self.client.oauth().refresh_access_token().await?;
559
560 println!("\nToken refreshed successfully");
564
565 Ok(())
566 }
567
568 async fn logout(&self) -> anyhow::Result<()> {
570 self.client.logout().await?;
572
573 let data_dir = self.session_file.parent().expect("The file has a parent directory");
575 fs::remove_dir_all(data_dir).await?;
576
577 println!("\nLogged out successfully");
578 println!("\nExiting…");
579
580 Ok(())
581 }
582}
583
584async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
588 let db_path = data_dir.join("db");
589
590 let mut rng = thread_rng();
592 let passphrase: String =
593 (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
594
595 loop {
597 let mut homeserver = String::new();
598
599 print!("\nHomeserver: ");
600 io::stdout().flush().expect("Unable to write to stdout");
601 io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
602
603 let homeserver = homeserver.trim();
604
605 println!("\nChecking homeserver…");
606
607 match Client::builder()
608 .server_name_or_homeserver_url(homeserver)
610 .handle_refresh_tokens()
612 .sqlite_store(&db_path, Some(&passphrase))
616 .build()
617 .await
618 {
619 Ok(client) => {
620 match client.oauth().server_metadata().await {
622 Ok(server_metadata) => {
623 println!(
624 "Found OAuth 2.0 server metadata with issuer: {}",
625 server_metadata.issuer
626 );
627
628 let homeserver = client.homeserver().to_string();
629 return Ok((client, ClientSession { homeserver, db_path, passphrase }));
630 }
631 Err(error) => {
632 if error.is_not_supported() {
633 println!(
634 "This homeserver doesn't advertise OAuth 2.0 server metadata."
635 );
636 } else {
637 println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
638 }
639 fs::remove_dir_all(data_dir).await?;
641 }
642 }
643 }
644 Err(error) => match &error {
645 ClientBuildError::AutoDiscovery(_)
646 | ClientBuildError::Url(_)
647 | ClientBuildError::Http(_) => {
648 println!("Error checking the homeserver: {error}");
649 println!("Please try again\n");
650 fs::remove_dir_all(data_dir).await?;
652 }
653 ClientBuildError::InvalidServerName => {
654 println!("Error: not a valid server name");
655 println!("Please try again\n");
656 }
657 _ => {
658 return Err(error.into());
660 }
661 },
662 }
663 }
664}
665
666fn client_metadata() -> Raw<ClientMetadata> {
668 let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
672 .expect("Couldn't parse IPv4 redirect URI");
673 let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
674 .expect("Couldn't parse IPv6 redirect URI");
675 let client_uri = Localized::new(
676 Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
677 .expect("Couldn't parse client URI"),
678 None,
679 );
680
681 let metadata = ClientMetadata {
682 client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
687 policy_uri: Some(client_uri.clone()),
688 tos_uri: Some(client_uri.clone()),
689 ..ClientMetadata::new(
690 ApplicationType::Native,
693 vec![OAuthGrantType::AuthorizationCode {
695 redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
696 }],
697 client_uri,
698 )
699 };
700
701 Raw::new(&metadata).expect("Couldn't serialize client metadata")
702}
703
704async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
708 println!("\nPlease authenticate yourself at: {url}\n");
709 println!("Then proceed to the authorization.\n");
710
711 server_handle.await
712}
713
714async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
716 if room.state() != RoomState::Joined {
718 return;
719 }
720 let MessageType::Text(text_content) = &event.content.msgtype else { return };
721
722 let room_name = match room.display_name().await {
723 Ok(room_name) => room_name.to_string(),
724 Err(error) => {
725 println!("Error getting room display name: {error}");
726 room.room_id().to_string()
728 }
729 };
730
731 println!("[{room_name}] {}: {}", event.sender, text_content.body)
732}