Skip to main content

example_oauth_cli/
main.rs

1// Copyright 2023 Kévin Commaille.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    io::{self, Write},
17    net::{Ipv4Addr, Ipv6Addr},
18    path::{Path, PathBuf},
19    sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25    Client, ClientBuildError, Result, RoomState,
26    authentication::oauth::{
27        ClientId, OAuthAuthorizationData, OAuthError, OAuthSession, UserSession,
28        error::OAuthClientRegistrationError,
29        registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
30    },
31    config::SyncSettings,
32    encryption::{CrossSigningResetAuthType, recovery::RecoveryState},
33    room::Room,
34    ruma::{
35        api::client::discovery::get_authorization_server_metadata::v1::AccountManagementActionData,
36        events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
37        serde::Raw,
38    },
39    utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
40};
41use matrix_sdk_ui::sync_service::SyncService;
42use rand::{Rng, distributions::Alphanumeric, thread_rng};
43use serde::{Deserialize, Serialize};
44use tokio::{fs, io::AsyncBufReadExt as _};
45use url::Url;
46
47/// A command-line tool to demonstrate the steps requiring an interaction with
48/// an OAuth 2.0 authorization server for a Matrix client, using the
49/// Authorization Code flow.
50///
51/// You can test this against one of the servers from the OIDC playground:
52/// <https://github.com/element-hq/oidc-playground>.
53///
54/// To use this, just run `cargo run -p example-oauth-cli`, and everything
55/// is interactive after that. You might want to set the `RUST_LOG` environment
56/// variable to `warn` to reduce the noise in the logs. The program exits
57/// whenever an unexpected error occurs.
58///
59/// To reset the login, simply use the `logout` command or delete the folder
60/// containing the session file, the location is shown in the logs. Note that
61/// the database must be deleted too as it can't be reused.
62#[tokio::main]
63async fn main() -> anyhow::Result<()> {
64    tracing_subscriber::fmt::init();
65
66    // The folder containing this example's data.
67    let data_dir =
68        dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
69    // The file where the session is persisted.
70    let session_file = data_dir.join("session.json");
71
72    let cli = if session_file.exists() {
73        OAuthCli::from_stored_session(session_file).await?
74    } else {
75        OAuthCli::new(&data_dir, session_file).await?
76    };
77
78    cli.run().await
79}
80
/// Print the commands available once the client is logged in.
///
/// This list must stay in sync with the commands matched in `OAuthCli::run`.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!("  whoami                 Get information about this session");
    println!("  account                Get the URL to manage this account");
    println!("  profile                Get the URL to manage the profile of this account");
    println!("  devices                Get the URL to manage the devices of this account");
    println!("  watch [sliding?]       Watch new incoming messages until an error occurs");
    println!("  refresh                Refresh the access token");
    println!("  recover                Recover the E2EE secrets from secret storage");
    println!("  reset-cross-signing    Reset the cross-signing identity of this account");
    println!("  logout                 Log out of this account");
    println!("  exit                   Exit this program");
    println!("  help                   Show this message\n");
}
94
95/// The data needed to re-build a client.
96#[derive(Debug, Serialize, Deserialize)]
97struct ClientSession {
98    /// The URL of the homeserver of the user.
99    homeserver: String,
100
101    /// The path of the database.
102    db_path: PathBuf,
103
104    /// The passphrase of the database.
105    passphrase: String,
106}
107
108/// The full session to persist.
109#[derive(Debug, Serialize, Deserialize)]
110struct StoredSession {
111    /// The data to re-build the client.
112    client_session: ClientSession,
113
114    /// The OAuth 2.0 user session.
115    user_session: UserSession,
116
117    /// The OAuth 2.0 client ID.
118    client_id: ClientId,
119}
120
121/// An OAuth 2.0 CLI.
122#[derive(Clone, Debug)]
123struct OAuthCli {
124    /// The Matrix client.
125    client: Client,
126
127    /// Whether this is a restored client.
128    restored: bool,
129
130    /// The path to the file storing the session.
131    session_file: PathBuf,
132}
133
134impl OAuthCli {
135    /// Create a new session by logging in.
136    async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
137        println!("No previous session found, logging in…");
138
139        let (client, client_session) = build_client(data_dir).await?;
140        let cli = Self { client, restored: false, session_file };
141
142        if let Err(error) = cli.register_and_login().await {
143            if let Some(error) = error.downcast_ref::<OAuthError>()
144                && let OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported) =
145                    error
146            {
147                // This would require to register with the authorization server manually, which
148                // we don't support here.
149                bail!(
150                    "This server doesn't support dynamic registration.\n\
151                     Please select another homeserver."
152                );
153            } else {
154                return Err(error);
155            }
156        }
157
158        // Persist the session to reuse it later.
159        // This is not very secure, for simplicity. If the system provides a way of
160        // storing secrets securely, it should be used instead.
161        let full_session =
162            cli.client.oauth().full_session().expect("A logged-in client should have a session");
163
164        let serialized_session = serde_json::to_string(&StoredSession {
165            client_session,
166            user_session: full_session.user,
167            client_id: full_session.client_id,
168        })?;
169        fs::write(&cli.session_file, serialized_session).await?;
170
171        println!("Session persisted in {}", cli.session_file.to_string_lossy());
172
173        cli.setup_background_save();
174
175        Ok(cli)
176    }
177
178    /// Register the client and log in the user via the OAuth 2.0 Authorization
179    /// Code flow.
180    async fn register_and_login(&self) -> anyhow::Result<()> {
181        let oauth = self.client.oauth();
182
183        // We create a loop here so the user can retry if an error happens.
184        loop {
185            // Here we spawn a server to listen on the loopback interface. Another option
186            // would be to register a custom URI scheme with the system and handle
187            // the redirect when the custom URI scheme is opened.
188            let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
189
190            let OAuthAuthorizationData { url, .. } = oauth
191                .login(redirect_uri, None, Some(client_metadata().into()), None)
192                .build()
193                .await?;
194
195            let Some(query_string) = use_auth_url(&url, server_handle).await else {
196                println!("Error: failed to login: missing query string on the redirect URL");
197                println!("Please try again.\n");
198                continue;
199            };
200
201            match oauth.finish_login(query_string.into()).await {
202                Ok(()) => {
203                    let user_id = self.client.user_id().expect("Got a user ID");
204                    println!("Logged in as {user_id}");
205                    break;
206                }
207                Err(err) => {
208                    println!("Error: failed to login: {err}");
209                    println!("Please try again.\n");
210                    continue;
211                }
212            }
213        }
214
215        Ok(())
216    }
217
218    /// Restore a previous session from a file.
219    async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
220        println!("Previous session found in '{}'", session_file.to_string_lossy());
221
222        // The session was serialized as JSON in a file.
223        let serialized_session = fs::read_to_string(&session_file).await?;
224        let StoredSession { client_session, user_session, client_id } =
225            serde_json::from_str(&serialized_session)?;
226
227        // Build the client with the previous settings from the session.
228        let client = Client::builder()
229            .homeserver_url(client_session.homeserver)
230            .handle_refresh_tokens()
231            .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
232            .build()
233            .await?;
234
235        println!("Restoring session for {}…", user_session.meta.user_id);
236
237        let session = OAuthSession { client_id, user: user_session };
238        // Restore the Matrix user session.
239        client.restore_session(session).await?;
240
241        let this = Self { client, restored: true, session_file };
242
243        this.setup_background_save();
244
245        Ok(this)
246    }
247
248    /// Run the main program.
249    async fn run(&self) -> anyhow::Result<()> {
250        help();
251
252        loop {
253            let mut input = String::new();
254
255            print!("\nEnter command: ");
256            io::stdout().flush().expect("Unable to write to stdout");
257
258            io::stdin().read_line(&mut input).expect("Unable to read user input");
259
260            let mut args = input.trim().split_ascii_whitespace();
261            let cmd = args.next();
262
263            match cmd {
264                Some("whoami") => {
265                    self.whoami();
266                }
267                Some("account") => {
268                    self.account(None).await;
269                }
270                Some("profile") => {
271                    self.account(Some(AccountManagementActionData::Profile)).await;
272                }
273                Some("devices") => {
274                    self.account(Some(AccountManagementActionData::DevicesList)).await;
275                }
276                Some("watch") => match args.next() {
277                    Some(sub) => {
278                        if sub == "sliding" {
279                            self.watch_sliding_sync().await?;
280                        } else {
281                            println!("unknown subcommand for watch: available is 'sliding'");
282                        }
283                    }
284                    None => self.watch().await?,
285                },
286                Some("refresh") => {
287                    self.refresh_token().await?;
288                }
289                Some("recover") => {
290                    self.recover().await?;
291                }
292                Some("reset-cross-signing") => {
293                    self.reset_cross_signing().await?;
294                }
295                Some("logout") => {
296                    self.logout().await?;
297                    break;
298                }
299                Some("exit") => {
300                    break;
301                }
302                Some("help") => {
303                    help();
304                }
305                Some(cmd) => {
306                    println!("Error: unknown command '{cmd}'\n");
307                    help();
308                }
309                None => {
310                    println!("Error: no command\n");
311                    help()
312                }
313            }
314        }
315
316        Ok(())
317    }
318
319    async fn recover(&self) -> anyhow::Result<()> {
320        let recovery = self.client.encryption().recovery();
321
322        println!("Please enter your recovery key:");
323
324        let mut input = String::new();
325        io::stdin().read_line(&mut input).expect("error: unable to read user input");
326
327        let input = input.trim();
328
329        recovery.recover(input).await?;
330
331        match recovery.state() {
332            RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
333            RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
334            RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
335            _ => unreachable!("We should know our recovery state by now"),
336        }
337
338        Ok(())
339    }
340
341    async fn reset_cross_signing(&self) -> Result<()> {
342        let encryption = self.client.encryption();
343
344        if let Some(handle) = encryption.reset_cross_signing().await? {
345            match handle.auth_type() {
346                CrossSigningResetAuthType::Uiaa(_) => {
347                    unimplemented!(
348                        "This should never happen, this is after all the OAuth 2.0 example."
349                    )
350                }
351                CrossSigningResetAuthType::OAuth(o) => {
352                    println!(
353                        "To reset your end-to-end encryption cross-signing identity, \
354                        you first need to approve it at {}",
355                        o.approval_url
356                    );
357                    handle.auth(None).await?;
358                }
359            }
360        }
361
362        print!("Successfully reset cross-signing");
363
364        Ok(())
365    }
366
367    /// Get information about this session.
368    fn whoami(&self) {
369        let client = &self.client;
370
371        let user_id = client.user_id().expect("A logged in client has a user ID");
372        let device_id = client.device_id().expect("A logged in client has a device ID");
373        let homeserver = client.homeserver();
374
375        println!("\nUser ID: {user_id}");
376        println!("Device ID: {device_id}");
377        println!("Homeserver URL: {homeserver}");
378    }
379
380    /// Get the account management URL.
381    async fn account(&self, action: Option<AccountManagementActionData<'_>>) {
382        let Ok(server_metadata) = self.client.oauth().cached_server_metadata().await else {
383            println!("\nCould not retrieve the server metadata");
384            return;
385        };
386
387        let url = if let Some(action) = action {
388            server_metadata.account_management_url_with_action(action)
389        } else {
390            server_metadata.account_management_uri
391        };
392
393        let Some(url) = url else {
394            println!("\nThis homeserver does not provide the URL to manage your account");
395            return;
396        };
397
398        println!("\nTo manage your account, visit: {url}");
399    }
400
401    /// Watch incoming messages.
402    async fn watch(&self) -> anyhow::Result<()> {
403        let client = &self.client;
404
405        // If this is a new client, ignore previous messages to not fill the logs.
406        // Note that this might not work as intended, the initial sync might have failed
407        // in a previous session.
408        if !self.restored {
409            client.sync_once(SyncSettings::default()).await.unwrap();
410        }
411
412        // Listen to room messages.
413        let handle = client.add_event_handler(on_room_message);
414
415        // Sync.
416        let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
417        while let Some(res) = sync_stream.next().await {
418            if let Err(err) = res {
419                client.remove_event_handler(handle);
420                return Err(err.into());
421            }
422        }
423
424        Ok(())
425    }
426
427    /// This watches for incoming responses using the high-level sliding sync
428    /// helpers (`SyncService`).
429    async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
430        let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
431
432        sync_service.start().await;
433
434        println!("press enter to exit the sync loop");
435
436        let mut sync_service_state = sync_service.state();
437
438        let sync_service_clone = sync_service.clone();
439        let task = tokio::spawn(async move {
440            // Only fail after getting 5 errors in a row. When we're in an always-refail
441            // scenario, we move from the Error to the Running state for a bit
442            // until we fail again, so we need to track both failure state and
443            // running state, hence `num_errors` and `num_running`:
444            // - if we failed and num_running was 1, then this is a failure following a
445            //   failure.
446            // - otherwise, we recovered from the failure and we can plain continue.
447
448            let mut num_errors = 0;
449            let mut num_running = 0;
450
451            let mut _unused = String::new();
452            let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
453
454            loop {
455                // Concurrently wait for an update from the sync service OR for the user to
456                // press enter and leave early.
457                tokio::select! {
458                    res = sync_service_state.next() => {
459                        if let Some(state) = res {
460                            match state {
461                                matrix_sdk_ui::sync_service::State::Idle
462                                | matrix_sdk_ui::sync_service::State::Terminated => {
463                                    num_errors = 0;
464                                    num_running = 0;
465                                }
466
467                                matrix_sdk_ui::sync_service::State::Running => {
468                                    num_running += 1;
469                                    if num_running > 1 {
470                                        num_errors = 0;
471                                    }
472                                }
473
474                                matrix_sdk_ui::sync_service::State::Error(_) | matrix_sdk_ui::sync_service::State::Offline => {
475                                    num_errors += 1;
476                                    num_running = 0;
477
478                                    if num_errors == 5 {
479                                        println!("ran into 5 errors in a row, terminating");
480                                        break;
481                                    }
482
483                                    sync_service_clone.start().await;
484                                }
485                            }
486                            println!("New sync service state update: {state:?}");
487                        } else {
488                            break;
489                        }
490                    }
491
492                    _ = stdin.read_line(&mut _unused) => {
493                        println!("Stopping loop because of user request");
494                        sync_service.stop().await;
495
496                        break;
497                    }
498                }
499            }
500        });
501
502        println!("waiting for sync service to stop...");
503        task.await.unwrap();
504
505        println!("done!");
506        Ok(())
507    }
508
509    /// Sets up this client so that it automatically saves the session onto disk
510    /// whenever there are new tokens that have been received.
511    ///
512    /// This should always be set up whenever automatic refresh is happening.
513    fn setup_background_save(&self) {
514        let this = self.clone();
515        tokio::spawn(async move {
516            while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
517                match update {
518                    matrix_sdk::SessionChange::UnknownToken(unknown_token) => {
519                        println!(
520                            "Received an unknown token error; soft logout? {:?}",
521                            unknown_token.soft_logout
522                        );
523                    }
524                    matrix_sdk::SessionChange::TokensRefreshed => {
525                        // The tokens have been refreshed, persist them to disk.
526                        if let Err(err) = this.update_stored_session().await {
527                            println!("Unable to store a session in the background: {err}");
528                        }
529                    }
530                }
531            }
532        });
533    }
534
535    /// Update the session stored on the system.
536    ///
537    /// This should be called everytime the access token (and possibly refresh
538    /// token) has changed.
539    async fn update_stored_session(&self) -> anyhow::Result<()> {
540        println!("Updating the stored session...");
541
542        let serialized_session = fs::read_to_string(&self.session_file).await?;
543        let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
544
545        let user_session: UserSession =
546            self.client.oauth().user_session().expect("A logged in client has a session");
547        session.user_session = user_session;
548
549        let serialized_session = serde_json::to_string(&session)?;
550        fs::write(&self.session_file, serialized_session).await?;
551
552        println!("Updating the stored session: done!");
553        Ok(())
554    }
555
556    /// Refresh the access token.
557    async fn refresh_token(&self) -> anyhow::Result<()> {
558        self.client.oauth().refresh_access_token().await?;
559
560        // The session will automatically be refreshed because of the task persisting
561        // the full session upon refresh in `setup_background_save`.
562
563        println!("\nToken refreshed successfully");
564
565        Ok(())
566    }
567
568    /// Log out from this session.
569    async fn logout(&self) -> anyhow::Result<()> {
570        // Log out via OAuth 2.0.
571        self.client.logout().await?;
572
573        // Delete the stored session and database.
574        let data_dir = self.session_file.parent().expect("The file has a parent directory");
575        fs::remove_dir_all(data_dir).await?;
576
577        println!("\nLogged out successfully");
578        println!("\nExiting…");
579
580        Ok(())
581    }
582}
583
584/// Build a new client.
585///
586/// Returns the client and the data required to restore the client.
587async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
588    let db_path = data_dir.join("db");
589
590    // Generate a random passphrase.
591    let mut rng = thread_rng();
592    let passphrase: String =
593        (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
594
595    // We create a loop here so the user can retry if an error happens.
596    loop {
597        let mut homeserver = String::new();
598
599        print!("\nHomeserver: ");
600        io::stdout().flush().expect("Unable to write to stdout");
601        io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
602
603        let homeserver = homeserver.trim();
604
605        println!("\nChecking homeserver…");
606
607        match Client::builder()
608            // Try autodiscovery or test the URL.
609            .server_name_or_homeserver_url(homeserver)
610            // Make sure to automatically refresh tokens if needs be.
611            .handle_refresh_tokens()
612            // We use the sqlite store, which is available by default. This is the crucial part to
613            // persist the encryption setup.
614            // Note that other store backends are available and you can even implement your own.
615            .sqlite_store(&db_path, Some(&passphrase))
616            .build()
617            .await
618        {
619            Ok(client) => {
620                // Check if the homeserver advertises OAuth 2.0 server metadata.
621                match client.oauth().server_metadata().await {
622                    Ok(server_metadata) => {
623                        println!(
624                            "Found OAuth 2.0 server metadata with issuer: {}",
625                            server_metadata.issuer
626                        );
627
628                        let homeserver = client.homeserver().to_string();
629                        return Ok((client, ClientSession { homeserver, db_path, passphrase }));
630                    }
631                    Err(error) => {
632                        if error.is_not_supported() {
633                            println!(
634                                "This homeserver doesn't advertise OAuth 2.0 server metadata."
635                            );
636                        } else {
637                            println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
638                        }
639                        // The client already initialized the store so we need to remove it.
640                        fs::remove_dir_all(data_dir).await?;
641                    }
642                }
643            }
644            Err(error) => match &error {
645                ClientBuildError::AutoDiscovery(_)
646                | ClientBuildError::Url(_)
647                | ClientBuildError::Http(_) => {
648                    println!("Error checking the homeserver: {error}");
649                    println!("Please try again\n");
650                    // The client already initialized the store so we need to remove it.
651                    fs::remove_dir_all(data_dir).await?;
652                }
653                ClientBuildError::InvalidServerName => {
654                    println!("Error: not a valid server name");
655                    println!("Please try again\n");
656                }
657                _ => {
658                    // Forward other errors, it's unlikely we can retry with a different outcome.
659                    return Err(error.into());
660                }
661            },
662        }
663    }
664}
665
666/// Generate the OAuth 2.0 client metadata.
667fn client_metadata() -> Raw<ClientMetadata> {
668    // Native clients should be able to register the IPv4 and IPv6 loopback
669    // interfaces and then point to any port when needing a redirect URI. An
670    // alternative is to use a custom URI scheme registered with the OS.
671    let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
672        .expect("Couldn't parse IPv4 redirect URI");
673    let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
674        .expect("Couldn't parse IPv6 redirect URI");
675    let client_uri = Localized::new(
676        Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
677            .expect("Couldn't parse client URI"),
678        None,
679    );
680
681    let metadata = ClientMetadata {
682        // The following fields should be displayed in the OAuth 2.0 authorization server's
683        // web UI as part of the process to get the user's consent. It means that these
684        // should contain real data so the user can make sure that they allow the proper
685        // application. We are cheating here because this is an example.
686        client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
687        policy_uri: Some(client_uri.clone()),
688        tos_uri: Some(client_uri.clone()),
689        ..ClientMetadata::new(
690            // This is a native application (in contrast to a web application, that runs in a
691            // browser).
692            ApplicationType::Native,
693            // We are going to use the Authorization Code flow.
694            vec![OAuthGrantType::AuthorizationCode {
695                redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
696            }],
697            client_uri,
698        )
699    };
700
701    Raw::new(&metadata).expect("Couldn't serialize client metadata")
702}
703
704/// Open the authorization URL and wait for it to be complete.
705///
706/// Returns the code to obtain the access token.
707async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
708    println!("\nPlease authenticate yourself at: {url}\n");
709    println!("Then proceed to the authorization.\n");
710
711    server_handle.await
712}
713
714/// Handle room messages.
715async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
716    // We only want to log text messages in joined rooms.
717    if room.state() != RoomState::Joined {
718        return;
719    }
720    let MessageType::Text(text_content) = &event.content.msgtype else { return };
721
722    let room_name = match room.display_name().await {
723        Ok(room_name) => room_name.to_string(),
724        Err(error) => {
725            println!("Error getting room display name: {error}");
726            // Let's fallback to the room ID.
727            room.room_id().to_string()
728        }
729    };
730
731    println!("[{room_name}] {}: {}", event.sender, text_content.body)
732}