Skip to main content

example_oauth_cli/
main.rs

1// Copyright 2023 Kévin Commaille.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    io::{self, Write},
17    net::{Ipv4Addr, Ipv6Addr},
18    path::{Path, PathBuf},
19    sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25    Client, ClientBuildError, Result, RoomState,
26    authentication::oauth::{
27        ClientId, OAuthAuthorizationData, OAuthError, OAuthSession, UserSession,
28        error::OAuthClientRegistrationError,
29        registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
30    },
31    config::SyncSettings,
32    encryption::{CrossSigningResetAuthType, recovery::RecoveryState},
33    room::Room,
34    ruma::{
35        api::client::{
36            discovery::get_authorization_server_metadata::v1::AccountManagementActionData, uiaa,
37        },
38        events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
39        serde::Raw,
40    },
41    utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
42};
43use matrix_sdk_ui::sync_service::SyncService;
44use rand::{RngExt, distr::Alphanumeric, rng};
45use serde::{Deserialize, Serialize};
46use tokio::{fs, io::AsyncBufReadExt as _};
47use url::Url;
48
49/// A command-line tool to demonstrate the steps requiring an interaction with
50/// an OAuth 2.0 authorization server for a Matrix client, using the
51/// Authorization Code flow.
52///
53/// You can test this against one of the servers from the OIDC playground:
54/// <https://github.com/element-hq/oidc-playground>.
55///
56/// To use this, just run `cargo run -p example-oauth-cli`, and everything
57/// is interactive after that. You might want to set the `RUST_LOG` environment
58/// variable to `warn` to reduce the noise in the logs. The program exits
59/// whenever an unexpected error occurs.
60///
61/// To reset the login, simply use the `logout` command or delete the folder
62/// containing the session file, the location is shown in the logs. Note that
63/// the database must be deleted too as it can't be reused.
64#[tokio::main]
65async fn main() -> anyhow::Result<()> {
66    tracing_subscriber::fmt::init();
67
68    // The folder containing this example's data.
69    let data_dir =
70        dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
71    // The file where the session is persisted.
72    let session_file = data_dir.join("session.json");
73
74    let cli = if session_file.exists() {
75        OAuthCli::from_stored_session(session_file).await?
76    } else {
77        OAuthCli::new(&data_dir, session_file).await?
78    };
79
80    cli.run().await
81}
82
/// Print the commands available once the client is logged in.
///
/// Keep this list in sync with the commands dispatched by `OAuthCli::run`.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!("  whoami                 Get information about this session");
    println!("  account                Get the URL to manage this account");
    println!("  profile                Get the URL to manage the profile of this account");
    println!("  devices                Get the URL to manage the devices of this account");
    println!("  watch [sliding?]       Watch new incoming messages until an error occurs");
    println!("  refresh                Refresh the access token");
    println!("  recover                Recover the E2EE secrets from secret storage");
    println!("  reset-cross-signing    Reset the cross-signing identity of this account");
    println!("  logout                 Log out of this account");
    println!("  exit                   Exit this program");
    println!("  help                   Show this message\n");
}
96
/// The data needed to re-build a client.
///
/// It is serialized as JSON, as part of [`StoredSession`], into the session
/// file. Note that this means the database passphrase is written to disk in
/// clear text.
#[derive(Debug, Serialize, Deserialize)]
struct ClientSession {
    /// The URL of the homeserver of the user, as resolved by the client when
    /// it was first built (possibly after auto-discovery from a server name).
    homeserver: String,

    /// The path of the SQLite store of the client.
    db_path: PathBuf,

    /// The passphrase given to the SQLite store when opening the database.
    passphrase: String,
}
109
/// The full session to persist.
///
/// Serialized as JSON into the session file once the login succeeds, read back
/// when the program starts again, and rewritten with the new user session
/// whenever the tokens are refreshed.
#[derive(Debug, Serialize, Deserialize)]
struct StoredSession {
    /// The data to re-build the client.
    client_session: ClientSession,

    /// The OAuth 2.0 user session: its metadata (including the user ID) and the
    /// current tokens.
    user_session: UserSession,

    /// The OAuth 2.0 client ID, obtained when this client registered with the
    /// authorization server.
    client_id: ClientId,
}
122
/// An OAuth 2.0 CLI.
///
/// It is cloned to be moved into background tasks (see the session persistence
/// task spawned by `setup_background_save`).
#[derive(Clone, Debug)]
struct OAuthCli {
    /// The Matrix client.
    client: Client,

    /// Whether this client was restored from a stored session, rather than
    /// freshly logged in. A fresh client skips the messages of its initial sync
    /// when watching.
    restored: bool,

    /// The path to the file storing the session.
    session_file: PathBuf,
}
135
136impl OAuthCli {
137    /// Create a new session by logging in.
138    async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
139        println!("No previous session found, logging in…");
140
141        let (client, client_session) = build_client(data_dir).await?;
142        let cli = Self { client, restored: false, session_file };
143
144        if let Err(error) = cli.register_and_login().await {
145            if let Some(error) = error.downcast_ref::<OAuthError>()
146                && let OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported) =
147                    error
148            {
149                // This would require to register with the authorization server manually, which
150                // we don't support here.
151                bail!(
152                    "This server doesn't support dynamic registration.\n\
153                     Please select another homeserver."
154                );
155            } else {
156                return Err(error);
157            }
158        }
159
160        // Persist the session to reuse it later.
161        // This is not very secure, for simplicity. If the system provides a way of
162        // storing secrets securely, it should be used instead.
163        let full_session =
164            cli.client.oauth().full_session().expect("A logged-in client should have a session");
165
166        let serialized_session = serde_json::to_string(&StoredSession {
167            client_session,
168            user_session: full_session.user,
169            client_id: full_session.client_id,
170        })?;
171        fs::write(&cli.session_file, serialized_session).await?;
172
173        println!("Session persisted in {}", cli.session_file.to_string_lossy());
174
175        cli.setup_background_save();
176
177        Ok(cli)
178    }
179
180    /// Register the client and log in the user via the OAuth 2.0 Authorization
181    /// Code flow.
182    async fn register_and_login(&self) -> anyhow::Result<()> {
183        let oauth = self.client.oauth();
184
185        // We create a loop here so the user can retry if an error happens.
186        loop {
187            // Here we spawn a server to listen on the loopback interface. Another option
188            // would be to register a custom URI scheme with the system and handle
189            // the redirect when the custom URI scheme is opened.
190            let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
191
192            let OAuthAuthorizationData { url, .. } = oauth
193                .login(redirect_uri, None, Some(client_metadata().into()), None)
194                .build()
195                .await?;
196
197            let Some(query_string) = use_auth_url(&url, server_handle).await else {
198                println!("Error: failed to login: missing query string on the redirect URL");
199                println!("Please try again.\n");
200                continue;
201            };
202
203            match oauth.finish_login(query_string.into()).await {
204                Ok(()) => {
205                    let user_id = self.client.user_id().expect("Got a user ID");
206                    println!("Logged in as {user_id}");
207                    break;
208                }
209                Err(err) => {
210                    println!("Error: failed to login: {err}");
211                    println!("Please try again.\n");
212                    continue;
213                }
214            }
215        }
216
217        Ok(())
218    }
219
220    /// Restore a previous session from a file.
221    async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
222        println!("Previous session found in '{}'", session_file.to_string_lossy());
223
224        // The session was serialized as JSON in a file.
225        let serialized_session = fs::read_to_string(&session_file).await?;
226        let StoredSession { client_session, user_session, client_id } =
227            serde_json::from_str(&serialized_session)?;
228
229        // Build the client with the previous settings from the session.
230        let client = Client::builder()
231            .homeserver_url(client_session.homeserver)
232            .handle_refresh_tokens()
233            .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
234            .build()
235            .await?;
236
237        println!("Restoring session for {}…", user_session.meta.user_id);
238
239        let session = OAuthSession { client_id, user: user_session };
240        // Restore the Matrix user session.
241        client.restore_session(session).await?;
242
243        let this = Self { client, restored: true, session_file };
244
245        this.setup_background_save();
246
247        Ok(this)
248    }
249
250    /// Run the main program.
251    async fn run(&self) -> anyhow::Result<()> {
252        help();
253
254        loop {
255            let mut input = String::new();
256
257            print!("\nEnter command: ");
258            io::stdout().flush().expect("Unable to write to stdout");
259
260            io::stdin().read_line(&mut input).expect("Unable to read user input");
261
262            let mut args = input.trim().split_ascii_whitespace();
263            let cmd = args.next();
264
265            match cmd {
266                Some("whoami") => {
267                    self.whoami();
268                }
269                Some("account") => {
270                    self.account(None).await;
271                }
272                Some("profile") => {
273                    self.account(Some(AccountManagementActionData::Profile)).await;
274                }
275                Some("devices") => {
276                    self.account(Some(AccountManagementActionData::DevicesList)).await;
277                }
278                Some("watch") => match args.next() {
279                    Some(sub) => {
280                        if sub == "sliding" {
281                            self.watch_sliding_sync().await?;
282                        } else {
283                            println!("unknown subcommand for watch: available is 'sliding'");
284                        }
285                    }
286                    None => self.watch().await?,
287                },
288                Some("refresh") => {
289                    self.refresh_token().await?;
290                }
291                Some("recover") => {
292                    self.recover().await?;
293                }
294                Some("reset-cross-signing") => {
295                    self.reset_cross_signing().await?;
296                }
297                Some("logout") => {
298                    self.logout().await?;
299                    break;
300                }
301                Some("exit") => {
302                    break;
303                }
304                Some("help") => {
305                    help();
306                }
307                Some(cmd) => {
308                    println!("Error: unknown command '{cmd}'\n");
309                    help();
310                }
311                None => {
312                    println!("Error: no command\n");
313                    help()
314                }
315            }
316        }
317
318        Ok(())
319    }
320
321    async fn recover(&self) -> anyhow::Result<()> {
322        let recovery = self.client.encryption().recovery();
323
324        println!("Please enter your recovery key:");
325
326        let mut input = String::new();
327        io::stdin().read_line(&mut input).expect("error: unable to read user input");
328
329        let input = input.trim();
330
331        recovery.recover(input).await?;
332
333        match recovery.state() {
334            RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
335            RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
336            RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
337            _ => unreachable!("We should know our recovery state by now"),
338        }
339
340        Ok(())
341    }
342
343    async fn reset_cross_signing(&self) -> Result<()> {
344        let encryption = self.client.encryption();
345
346        if let Some(handle) = encryption.reset_cross_signing().await? {
347            match handle.auth_type() {
348                CrossSigningResetAuthType::Uiaa(_) => {
349                    unimplemented!(
350                        "This should never happen, this is after all the OAuth 2.0 example."
351                    )
352                }
353                CrossSigningResetAuthType::OAuth(o) => {
354                    println!(
355                        "To reset your end-to-end encryption cross-signing identity, \
356                        you first need to approve it at {}",
357                        o.approval_url
358                    );
359
360                    let mut oauth_data = uiaa::OAuth::new();
361                    oauth_data.session = o.session.clone();
362                    handle.auth(Some(uiaa::AuthData::OAuth(oauth_data))).await?;
363                }
364            }
365        }
366
367        print!("Successfully reset cross-signing");
368
369        Ok(())
370    }
371
372    /// Get information about this session.
373    fn whoami(&self) {
374        let client = &self.client;
375
376        let user_id = client.user_id().expect("A logged in client has a user ID");
377        let device_id = client.device_id().expect("A logged in client has a device ID");
378        let homeserver = client.homeserver();
379
380        println!("\nUser ID: {user_id}");
381        println!("Device ID: {device_id}");
382        println!("Homeserver URL: {homeserver}");
383    }
384
385    /// Get the account management URL.
386    async fn account(&self, action: Option<AccountManagementActionData<'_>>) {
387        let Ok(server_metadata) = self.client.oauth().cached_server_metadata().await else {
388            println!("\nCould not retrieve the server metadata");
389            return;
390        };
391
392        let url = if let Some(action) = action {
393            server_metadata.account_management_url_with_action(action)
394        } else {
395            server_metadata.account_management_uri
396        };
397
398        let Some(url) = url else {
399            println!("\nThis homeserver does not provide the URL to manage your account");
400            return;
401        };
402
403        println!("\nTo manage your account, visit: {url}");
404    }
405
406    /// Watch incoming messages.
407    async fn watch(&self) -> anyhow::Result<()> {
408        let client = &self.client;
409
410        // If this is a new client, ignore previous messages to not fill the logs.
411        // Note that this might not work as intended, the initial sync might have failed
412        // in a previous session.
413        if !self.restored {
414            client.sync_once(SyncSettings::default()).await.unwrap();
415        }
416
417        // Listen to room messages.
418        let handle = client.add_event_handler(on_room_message);
419
420        // Sync.
421        let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
422        while let Some(res) = sync_stream.next().await {
423            if let Err(err) = res {
424                client.remove_event_handler(handle);
425                return Err(err.into());
426            }
427        }
428
429        Ok(())
430    }
431
432    /// This watches for incoming responses using the high-level sliding sync
433    /// helpers (`SyncService`).
434    async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
435        let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
436
437        sync_service.start().await;
438
439        println!("press enter to exit the sync loop");
440
441        let mut sync_service_state = sync_service.state();
442
443        let sync_service_clone = sync_service.clone();
444        let task = tokio::spawn(async move {
445            // Only fail after getting 5 errors in a row. When we're in an always-refail
446            // scenario, we move from the Error to the Running state for a bit
447            // until we fail again, so we need to track both failure state and
448            // running state, hence `num_errors` and `num_running`:
449            // - if we failed and num_running was 1, then this is a failure following a
450            //   failure.
451            // - otherwise, we recovered from the failure and we can plain continue.
452
453            let mut num_errors = 0;
454            let mut num_running = 0;
455
456            let mut _unused = String::new();
457            let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
458
459            loop {
460                // Concurrently wait for an update from the sync service OR for the user to
461                // press enter and leave early.
462                tokio::select! {
463                    res = sync_service_state.next() => {
464                        if let Some(state) = res {
465                            match state {
466                                matrix_sdk_ui::sync_service::State::Idle
467                                | matrix_sdk_ui::sync_service::State::Terminated => {
468                                    num_errors = 0;
469                                    num_running = 0;
470                                }
471
472                                matrix_sdk_ui::sync_service::State::Running => {
473                                    num_running += 1;
474                                    if num_running > 1 {
475                                        num_errors = 0;
476                                    }
477                                }
478
479                                matrix_sdk_ui::sync_service::State::Error(_) | matrix_sdk_ui::sync_service::State::Offline => {
480                                    num_errors += 1;
481                                    num_running = 0;
482
483                                    if num_errors == 5 {
484                                        println!("ran into 5 errors in a row, terminating");
485                                        break;
486                                    }
487
488                                    sync_service_clone.start().await;
489                                }
490                            }
491                            println!("New sync service state update: {state:?}");
492                        } else {
493                            break;
494                        }
495                    }
496
497                    _ = stdin.read_line(&mut _unused) => {
498                        println!("Stopping loop because of user request");
499                        sync_service.stop().await;
500
501                        break;
502                    }
503                }
504            }
505        });
506
507        println!("waiting for sync service to stop...");
508        task.await.unwrap();
509
510        println!("done!");
511        Ok(())
512    }
513
514    /// Sets up this client so that it automatically saves the session onto disk
515    /// whenever there are new tokens that have been received.
516    ///
517    /// This should always be set up whenever automatic refresh is happening.
518    fn setup_background_save(&self) {
519        let this = self.clone();
520        tokio::spawn(async move {
521            while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
522                match update {
523                    matrix_sdk::SessionChange::UnknownToken(unknown_token) => {
524                        println!(
525                            "Received an unknown token error; soft logout? {:?}",
526                            unknown_token.soft_logout
527                        );
528                    }
529                    matrix_sdk::SessionChange::TokensRefreshed => {
530                        // The tokens have been refreshed, persist them to disk.
531                        if let Err(err) = this.update_stored_session().await {
532                            println!("Unable to store a session in the background: {err}");
533                        }
534                    }
535                }
536            }
537        });
538    }
539
540    /// Update the session stored on the system.
541    ///
542    /// This should be called everytime the access token (and possibly refresh
543    /// token) has changed.
544    async fn update_stored_session(&self) -> anyhow::Result<()> {
545        println!("Updating the stored session...");
546
547        let serialized_session = fs::read_to_string(&self.session_file).await?;
548        let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
549
550        let user_session: UserSession =
551            self.client.oauth().user_session().expect("A logged in client has a session");
552        session.user_session = user_session;
553
554        let serialized_session = serde_json::to_string(&session)?;
555        fs::write(&self.session_file, serialized_session).await?;
556
557        println!("Updating the stored session: done!");
558        Ok(())
559    }
560
561    /// Refresh the access token.
562    async fn refresh_token(&self) -> anyhow::Result<()> {
563        self.client.oauth().refresh_access_token().await?;
564
565        // The session will automatically be refreshed because of the task persisting
566        // the full session upon refresh in `setup_background_save`.
567
568        println!("\nToken refreshed successfully");
569
570        Ok(())
571    }
572
573    /// Log out from this session.
574    async fn logout(&self) -> anyhow::Result<()> {
575        // Log out via OAuth 2.0.
576        self.client.logout().await?;
577
578        // Delete the stored session and database.
579        let data_dir = self.session_file.parent().expect("The file has a parent directory");
580        fs::remove_dir_all(data_dir).await?;
581
582        println!("\nLogged out successfully");
583        println!("\nExiting…");
584
585        Ok(())
586    }
587}
588
589/// Build a new client.
590///
591/// Returns the client and the data required to restore the client.
592async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
593    let db_path = data_dir.join("db");
594
595    // Generate a random passphrase.
596    let mut rng = rng();
597    let passphrase: String =
598        (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
599
600    // We create a loop here so the user can retry if an error happens.
601    loop {
602        let mut homeserver = String::new();
603
604        print!("\nHomeserver: ");
605        io::stdout().flush().expect("Unable to write to stdout");
606        io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
607
608        let homeserver = homeserver.trim();
609
610        println!("\nChecking homeserver…");
611
612        match Client::builder()
613            // Try autodiscovery or test the URL.
614            .server_name_or_homeserver_url(homeserver)
615            // Make sure to automatically refresh tokens if needs be.
616            .handle_refresh_tokens()
617            // We use the sqlite store, which is available by default. This is the crucial part to
618            // persist the encryption setup.
619            // Note that other store backends are available and you can even implement your own.
620            .sqlite_store(&db_path, Some(&passphrase))
621            .build()
622            .await
623        {
624            Ok(client) => {
625                // Check if the homeserver advertises OAuth 2.0 server metadata.
626                match client.oauth().server_metadata().await {
627                    Ok(server_metadata) => {
628                        println!(
629                            "Found OAuth 2.0 server metadata with issuer: {}",
630                            server_metadata.issuer
631                        );
632
633                        let homeserver = client.homeserver().to_string();
634                        return Ok((client, ClientSession { homeserver, db_path, passphrase }));
635                    }
636                    Err(error) => {
637                        if error.is_not_supported() {
638                            println!(
639                                "This homeserver doesn't advertise OAuth 2.0 server metadata."
640                            );
641                        } else {
642                            println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
643                        }
644                        // The client already initialized the store so we need to remove it.
645                        fs::remove_dir_all(data_dir).await?;
646                    }
647                }
648            }
649            Err(error) => match &error {
650                ClientBuildError::AutoDiscovery(_)
651                | ClientBuildError::Url(_)
652                | ClientBuildError::Http(_) => {
653                    println!("Error checking the homeserver: {error}");
654                    println!("Please try again\n");
655                    // The client already initialized the store so we need to remove it.
656                    fs::remove_dir_all(data_dir).await?;
657                }
658                ClientBuildError::InvalidServerName => {
659                    println!("Error: not a valid server name");
660                    println!("Please try again\n");
661                }
662                _ => {
663                    // Forward other errors, it's unlikely we can retry with a different outcome.
664                    return Err(error.into());
665                }
666            },
667        }
668    }
669}
670
671/// Generate the OAuth 2.0 client metadata.
672fn client_metadata() -> Raw<ClientMetadata> {
673    // Native clients should be able to register the IPv4 and IPv6 loopback
674    // interfaces and then point to any port when needing a redirect URI. An
675    // alternative is to use a custom URI scheme registered with the OS.
676    let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
677        .expect("Couldn't parse IPv4 redirect URI");
678    let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
679        .expect("Couldn't parse IPv6 redirect URI");
680    let client_uri = Localized::new(
681        Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
682            .expect("Couldn't parse client URI"),
683        None,
684    );
685
686    let metadata = ClientMetadata {
687        // The following fields should be displayed in the OAuth 2.0 authorization server's
688        // web UI as part of the process to get the user's consent. It means that these
689        // should contain real data so the user can make sure that they allow the proper
690        // application. We are cheating here because this is an example.
691        client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
692        policy_uri: Some(client_uri.clone()),
693        tos_uri: Some(client_uri.clone()),
694        ..ClientMetadata::new(
695            // This is a native application (in contrast to a web application, that runs in a
696            // browser).
697            ApplicationType::Native,
698            // We are going to use the Authorization Code flow.
699            vec![OAuthGrantType::AuthorizationCode {
700                redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
701            }],
702            client_uri,
703        )
704    };
705
706    Raw::new(&metadata).expect("Couldn't serialize client metadata")
707}
708
709/// Open the authorization URL and wait for it to be complete.
710///
711/// Returns the code to obtain the access token.
712async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
713    println!("\nPlease authenticate yourself at: {url}\n");
714    println!("Then proceed to the authorization.\n");
715
716    server_handle.await
717}
718
719/// Handle room messages.
720async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
721    // We only want to log text messages in joined rooms.
722    if room.state() != RoomState::Joined {
723        return;
724    }
725    let MessageType::Text(text_content) = &event.content.msgtype else { return };
726
727    let room_name = match room.display_name().await {
728        Ok(room_name) => room_name.to_string(),
729        Err(error) => {
730            println!("Error getting room display name: {error}");
731            // Let's fallback to the room ID.
732            room.room_id().to_string()
733        }
734    };
735
736    println!("[{room_name}] {}: {}", event.sender, text_content.body)
737}