1use std::{
16 io::{self, Write},
17 net::{Ipv4Addr, Ipv6Addr},
18 path::{Path, PathBuf},
19 sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25 Client, ClientBuildError, Result, RoomState,
26 authentication::oauth::{
27 AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError, OAuthSession,
28 UrlOrQuery, UserSession,
29 error::OAuthClientRegistrationError,
30 registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
31 },
32 config::SyncSettings,
33 encryption::{CrossSigningResetAuthType, recovery::RecoveryState},
34 room::Room,
35 ruma::{
36 events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
37 serde::Raw,
38 },
39 utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
40};
41use matrix_sdk_ui::sync_service::SyncService;
42use rand::{Rng, distributions::Alphanumeric, thread_rng};
43use serde::{Deserialize, Serialize};
44use tokio::{fs, io::AsyncBufReadExt as _};
45use url::Url;
46
47#[tokio::main]
63async fn main() -> anyhow::Result<()> {
64 tracing_subscriber::fmt::init();
65
66 let data_dir =
68 dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
69 let session_file = data_dir.join("session.json");
71
72 let cli = if session_file.exists() {
73 OAuthCli::from_stored_session(session_file).await?
74 } else {
75 OAuthCli::new(&data_dir, session_file).await?
76 };
77
78 cli.run().await
79}
80
/// Print the commands understood by the interactive loop in `OAuthCli::run`.
///
/// Every command accepted by that loop is listed, including `profile`, `sessions` and
/// `reset-cross-signing`, which were previously missing from this message.
fn help() {
    // Command (with its optional arguments) and a short description.
    const COMMANDS: &[(&str, &str)] = &[
        ("whoami", "Get information about this session"),
        ("account", "Get the URL to manage this account"),
        ("profile", "Get the URL to manage the profile of this account"),
        ("sessions", "Get the URL to manage the sessions of this account"),
        ("watch [sliding?]", "Watch new incoming messages until an error occurs"),
        ("refresh", "Refresh the access token"),
        ("recover", "Recover the E2EE secrets from secret storage"),
        ("reset-cross-signing", "Reset the cross-signing identity of this account"),
        ("logout", "Log out of this account"),
        ("exit", "Exit this program"),
        ("help", "Show this message"),
    ];

    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    for (command, description) in COMMANDS {
        // Pad the command column so that all descriptions line up.
        println!("  {command:<21}  {description}");
    }
    println!();
}
94
95#[derive(Debug, Serialize, Deserialize)]
97struct ClientSession {
98 homeserver: String,
100
101 db_path: PathBuf,
103
104 passphrase: String,
106}
107
108#[derive(Debug, Serialize, Deserialize)]
110struct StoredSession {
111 client_session: ClientSession,
113
114 user_session: UserSession,
116
117 client_id: ClientId,
119}
120
121#[derive(Clone, Debug)]
123struct OAuthCli {
124 client: Client,
126
127 restored: bool,
129
130 session_file: PathBuf,
132}
133
134impl OAuthCli {
135 async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
137 println!("No previous session found, logging in…");
138
139 let (client, client_session) = build_client(data_dir).await?;
140 let cli = Self { client, restored: false, session_file };
141
142 if let Err(error) = cli.register_and_login().await {
143 if let Some(error) = error.downcast_ref::<OAuthError>()
144 && let OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported) =
145 error
146 {
147 bail!(
150 "This server doesn't support dynamic registration.\n\
151 Please select another homeserver."
152 );
153 } else {
154 return Err(error);
155 }
156 }
157
158 let full_session =
162 cli.client.oauth().full_session().expect("A logged-in client should have a session");
163
164 let serialized_session = serde_json::to_string(&StoredSession {
165 client_session,
166 user_session: full_session.user,
167 client_id: full_session.client_id,
168 })?;
169 fs::write(&cli.session_file, serialized_session).await?;
170
171 println!("Session persisted in {}", cli.session_file.to_string_lossy());
172
173 cli.setup_background_save();
174
175 Ok(cli)
176 }
177
178 async fn register_and_login(&self) -> anyhow::Result<()> {
181 let oauth = self.client.oauth();
182
183 loop {
185 let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
189
190 let OAuthAuthorizationData { url, .. } = oauth
191 .login(redirect_uri, None, Some(client_metadata().into()), None)
192 .build()
193 .await?;
194
195 let query_string =
196 use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default();
197
198 match oauth.finish_login(UrlOrQuery::Query(query_string)).await {
199 Ok(()) => {
200 let user_id = self.client.user_id().expect("Got a user ID");
201 println!("Logged in as {user_id}");
202 break;
203 }
204 Err(err) => {
205 println!("Error: failed to login: {err}");
206 println!("Please try again.\n");
207 continue;
208 }
209 }
210 }
211
212 Ok(())
213 }
214
215 async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
217 println!("Previous session found in '{}'", session_file.to_string_lossy());
218
219 let serialized_session = fs::read_to_string(&session_file).await?;
221 let StoredSession { client_session, user_session, client_id } =
222 serde_json::from_str(&serialized_session)?;
223
224 let client = Client::builder()
226 .homeserver_url(client_session.homeserver)
227 .handle_refresh_tokens()
228 .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
229 .build()
230 .await?;
231
232 println!("Restoring session for {}…", user_session.meta.user_id);
233
234 let session = OAuthSession { client_id, user: user_session };
235 client.restore_session(session).await?;
237
238 let this = Self { client, restored: true, session_file };
239
240 this.setup_background_save();
241
242 Ok(this)
243 }
244
245 async fn run(&self) -> anyhow::Result<()> {
247 help();
248
249 loop {
250 let mut input = String::new();
251
252 print!("\nEnter command: ");
253 io::stdout().flush().expect("Unable to write to stdout");
254
255 io::stdin().read_line(&mut input).expect("Unable to read user input");
256
257 let mut args = input.trim().split_ascii_whitespace();
258 let cmd = args.next();
259
260 match cmd {
261 Some("whoami") => {
262 self.whoami();
263 }
264 Some("account") => {
265 self.account(None).await;
266 }
267 Some("profile") => {
268 self.account(Some(AccountManagementActionFull::Profile)).await;
269 }
270 Some("sessions") => {
271 self.account(Some(AccountManagementActionFull::SessionsList)).await;
272 }
273 Some("watch") => match args.next() {
274 Some(sub) => {
275 if sub == "sliding" {
276 self.watch_sliding_sync().await?;
277 } else {
278 println!("unknown subcommand for watch: available is 'sliding'");
279 }
280 }
281 None => self.watch().await?,
282 },
283 Some("refresh") => {
284 self.refresh_token().await?;
285 }
286 Some("recover") => {
287 self.recover().await?;
288 }
289 Some("reset-cross-signing") => {
290 self.reset_cross_signing().await?;
291 }
292 Some("logout") => {
293 self.logout().await?;
294 break;
295 }
296 Some("exit") => {
297 break;
298 }
299 Some("help") => {
300 help();
301 }
302 Some(cmd) => {
303 println!("Error: unknown command '{cmd}'\n");
304 help();
305 }
306 None => {
307 println!("Error: no command\n");
308 help()
309 }
310 }
311 }
312
313 Ok(())
314 }
315
316 async fn recover(&self) -> anyhow::Result<()> {
317 let recovery = self.client.encryption().recovery();
318
319 println!("Please enter your recovery key:");
320
321 let mut input = String::new();
322 io::stdin().read_line(&mut input).expect("error: unable to read user input");
323
324 let input = input.trim();
325
326 recovery.recover(input).await?;
327
328 match recovery.state() {
329 RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
330 RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
331 RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
332 _ => unreachable!("We should know our recovery state by now"),
333 }
334
335 Ok(())
336 }
337
338 async fn reset_cross_signing(&self) -> Result<()> {
339 let encryption = self.client.encryption();
340
341 if let Some(handle) = encryption.reset_cross_signing().await? {
342 match handle.auth_type() {
343 CrossSigningResetAuthType::Uiaa(_) => {
344 unimplemented!(
345 "This should never happen, this is after all the OAuth 2.0 example."
346 )
347 }
348 CrossSigningResetAuthType::OAuth(o) => {
349 println!(
350 "To reset your end-to-end encryption cross-signing identity, \
351 you first need to approve it at {}",
352 o.approval_url
353 );
354 handle.auth(None).await?;
355 }
356 }
357 }
358
359 print!("Successfully reset cross-signing");
360
361 Ok(())
362 }
363
364 fn whoami(&self) {
366 let client = &self.client;
367
368 let user_id = client.user_id().expect("A logged in client has a user ID");
369 let device_id = client.device_id().expect("A logged in client has a device ID");
370 let homeserver = client.homeserver();
371
372 println!("\nUser ID: {user_id}");
373 println!("Device ID: {device_id}");
374 println!("Homeserver URL: {homeserver}");
375 }
376
377 async fn account(&self, action: Option<AccountManagementActionFull>) {
379 let Ok(Some(mut url_builder)) = self.client.oauth().fetch_account_management_url().await
380 else {
381 println!("\nThis homeserver does not provide the URL to manage your account");
382 return;
383 };
384
385 if let Some(action) = action {
386 url_builder = url_builder.action(action);
387 }
388
389 let url = url_builder.build();
390 println!("\nTo manage your account, visit: {url}");
391 }
392
393 async fn watch(&self) -> anyhow::Result<()> {
395 let client = &self.client;
396
397 if !self.restored {
401 client.sync_once(SyncSettings::default()).await.unwrap();
402 }
403
404 let handle = client.add_event_handler(on_room_message);
406
407 let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
409 while let Some(res) = sync_stream.next().await {
410 if let Err(err) = res {
411 client.remove_event_handler(handle);
412 return Err(err.into());
413 }
414 }
415
416 Ok(())
417 }
418
419 async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
422 let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
423
424 sync_service.start().await;
425
426 println!("press enter to exit the sync loop");
427
428 let mut sync_service_state = sync_service.state();
429
430 let sync_service_clone = sync_service.clone();
431 let task = tokio::spawn(async move {
432 let mut num_errors = 0;
441 let mut num_running = 0;
442
443 let mut _unused = String::new();
444 let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
445
446 loop {
447 tokio::select! {
450 res = sync_service_state.next() => {
451 if let Some(state) = res {
452 match state {
453 matrix_sdk_ui::sync_service::State::Idle
454 | matrix_sdk_ui::sync_service::State::Terminated => {
455 num_errors = 0;
456 num_running = 0;
457 }
458
459 matrix_sdk_ui::sync_service::State::Running => {
460 num_running += 1;
461 if num_running > 1 {
462 num_errors = 0;
463 }
464 }
465
466 matrix_sdk_ui::sync_service::State::Error | matrix_sdk_ui::sync_service::State::Offline => {
467 num_errors += 1;
468 num_running = 0;
469
470 if num_errors == 5 {
471 println!("ran into 5 errors in a row, terminating");
472 break;
473 }
474
475 sync_service_clone.start().await;
476 }
477 }
478 println!("New sync service state update: {state:?}");
479 } else {
480 break;
481 }
482 }
483
484 _ = stdin.read_line(&mut _unused) => {
485 println!("Stopping loop because of user request");
486 sync_service.stop().await;
487
488 break;
489 }
490 }
491 }
492 });
493
494 println!("waiting for sync service to stop...");
495 task.await.unwrap();
496
497 println!("done!");
498 Ok(())
499 }
500
501 fn setup_background_save(&self) {
506 let this = self.clone();
507 tokio::spawn(async move {
508 while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
509 match update {
510 matrix_sdk::SessionChange::UnknownToken { soft_logout } => {
511 println!("Received an unknown token error; soft logout? {soft_logout:?}");
512 }
513 matrix_sdk::SessionChange::TokensRefreshed => {
514 if let Err(err) = this.update_stored_session().await {
516 println!("Unable to store a session in the background: {err}");
517 }
518 }
519 }
520 }
521 });
522 }
523
524 async fn update_stored_session(&self) -> anyhow::Result<()> {
529 println!("Updating the stored session...");
530
531 let serialized_session = fs::read_to_string(&self.session_file).await?;
532 let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
533
534 let user_session: UserSession =
535 self.client.oauth().user_session().expect("A logged in client has a session");
536 session.user_session = user_session;
537
538 let serialized_session = serde_json::to_string(&session)?;
539 fs::write(&self.session_file, serialized_session).await?;
540
541 println!("Updating the stored session: done!");
542 Ok(())
543 }
544
545 async fn refresh_token(&self) -> anyhow::Result<()> {
547 self.client.oauth().refresh_access_token().await?;
548
549 println!("\nToken refreshed successfully");
553
554 Ok(())
555 }
556
557 async fn logout(&self) -> anyhow::Result<()> {
559 self.client.logout().await?;
561
562 let data_dir = self.session_file.parent().expect("The file has a parent directory");
564 fs::remove_dir_all(data_dir).await?;
565
566 println!("\nLogged out successfully");
567 println!("\nExiting…");
568
569 Ok(())
570 }
571}
572
573async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
577 let db_path = data_dir.join("db");
578
579 let mut rng = thread_rng();
581 let passphrase: String =
582 (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
583
584 loop {
586 let mut homeserver = String::new();
587
588 print!("\nHomeserver: ");
589 io::stdout().flush().expect("Unable to write to stdout");
590 io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
591
592 let homeserver = homeserver.trim();
593
594 println!("\nChecking homeserver…");
595
596 match Client::builder()
597 .server_name_or_homeserver_url(homeserver)
599 .handle_refresh_tokens()
601 .sqlite_store(&db_path, Some(&passphrase))
605 .build()
606 .await
607 {
608 Ok(client) => {
609 match client.oauth().server_metadata().await {
611 Ok(server_metadata) => {
612 println!(
613 "Found OAuth 2.0 server metadata with issuer: {}",
614 server_metadata.issuer
615 );
616
617 let homeserver = client.homeserver().to_string();
618 return Ok((client, ClientSession { homeserver, db_path, passphrase }));
619 }
620 Err(error) => {
621 if error.is_not_supported() {
622 println!(
623 "This homeserver doesn't advertise OAuth 2.0 server metadata."
624 );
625 } else {
626 println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
627 }
628 fs::remove_dir_all(data_dir).await?;
630 }
631 }
632 }
633 Err(error) => match &error {
634 ClientBuildError::AutoDiscovery(_)
635 | ClientBuildError::Url(_)
636 | ClientBuildError::Http(_) => {
637 println!("Error checking the homeserver: {error}");
638 println!("Please try again\n");
639 fs::remove_dir_all(data_dir).await?;
641 }
642 ClientBuildError::InvalidServerName => {
643 println!("Error: not a valid server name");
644 println!("Please try again\n");
645 }
646 _ => {
647 return Err(error.into());
649 }
650 },
651 }
652 }
653}
654
655fn client_metadata() -> Raw<ClientMetadata> {
657 let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
661 .expect("Couldn't parse IPv4 redirect URI");
662 let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
663 .expect("Couldn't parse IPv6 redirect URI");
664 let client_uri = Localized::new(
665 Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
666 .expect("Couldn't parse client URI"),
667 None,
668 );
669
670 let metadata = ClientMetadata {
671 client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
676 policy_uri: Some(client_uri.clone()),
677 tos_uri: Some(client_uri.clone()),
678 ..ClientMetadata::new(
679 ApplicationType::Native,
682 vec![OAuthGrantType::AuthorizationCode {
684 redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
685 }],
686 client_uri,
687 )
688 };
689
690 Raw::new(&metadata).expect("Couldn't serialize client metadata")
691}
692
693async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
697 println!("\nPlease authenticate yourself at: {url}\n");
698 println!("Then proceed to the authorization.\n");
699
700 server_handle.await
701}
702
703async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
705 if room.state() != RoomState::Joined {
707 return;
708 }
709 let MessageType::Text(text_content) = &event.content.msgtype else { return };
710
711 let room_name = match room.display_name().await {
712 Ok(room_name) => room_name.to_string(),
713 Err(error) => {
714 println!("Error getting room display name: {error}");
715 room.room_id().to_string()
717 }
718 };
719
720 println!("[{room_name}] {}: {}", event.sender, text_content.body)
721}