example_oauth_cli/
main.rs

1// Copyright 2023 Kévin Commaille.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    io::{self, Write},
17    net::{Ipv4Addr, Ipv6Addr},
18    path::{Path, PathBuf},
19    sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25    Client, ClientBuildError, Result, RoomState,
26    authentication::oauth::{
27        AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError, OAuthSession,
28        UrlOrQuery, UserSession,
29        error::OAuthClientRegistrationError,
30        registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
31    },
32    config::SyncSettings,
33    encryption::{CrossSigningResetAuthType, recovery::RecoveryState},
34    room::Room,
35    ruma::{
36        events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
37        serde::Raw,
38    },
39    utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
40};
41use matrix_sdk_ui::sync_service::SyncService;
42use rand::{Rng, distributions::Alphanumeric, thread_rng};
43use serde::{Deserialize, Serialize};
44use tokio::{fs, io::AsyncBufReadExt as _};
45use url::Url;
46
47/// A command-line tool to demonstrate the steps requiring an interaction with
48/// an OAuth 2.0 authorization server for a Matrix client, using the
49/// Authorization Code flow.
50///
51/// You can test this against one of the servers from the OIDC playground:
52/// <https://github.com/element-hq/oidc-playground>.
53///
54/// To use this, just run `cargo run -p example-oauth-cli`, and everything
55/// is interactive after that. You might want to set the `RUST_LOG` environment
56/// variable to `warn` to reduce the noise in the logs. The program exits
57/// whenever an unexpected error occurs.
58///
59/// To reset the login, simply use the `logout` command or delete the folder
60/// containing the session file, the location is shown in the logs. Note that
61/// the database must be deleted too as it can't be reused.
62#[tokio::main]
63async fn main() -> anyhow::Result<()> {
64    tracing_subscriber::fmt::init();
65
66    // The folder containing this example's data.
67    let data_dir =
68        dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
69    // The file where the session is persisted.
70    let session_file = data_dir.join("session.json");
71
72    let cli = if session_file.exists() {
73        OAuthCli::from_stored_session(session_file).await?
74    } else {
75        OAuthCli::new(&data_dir, session_file).await?
76    };
77
78    cli.run().await
79}
80
/// Print the commands that are available once the client is logged in.
///
/// This list must stay in sync with the commands handled by `OAuthCli::run`.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!("  whoami                 Get information about this session");
    println!("  account                Get the URL to manage this account");
    println!("  profile                Get the URL to manage the profile of this account");
    println!("  sessions               Get the URL to list the sessions of this account");
    println!("  watch [sliding?]       Watch new incoming messages until an error occurs");
    println!("  refresh                Refresh the access token");
    println!("  recover                Recover the E2EE secrets from secret storage");
    println!("  reset-cross-signing    Reset the cross-signing identity of this account");
    println!("  logout                 Log out of this account");
    println!("  exit                   Exit this program");
    println!("  help                   Show this message\n");
}
94
95/// The data needed to re-build a client.
96#[derive(Debug, Serialize, Deserialize)]
97struct ClientSession {
98    /// The URL of the homeserver of the user.
99    homeserver: String,
100
101    /// The path of the database.
102    db_path: PathBuf,
103
104    /// The passphrase of the database.
105    passphrase: String,
106}
107
108/// The full session to persist.
109#[derive(Debug, Serialize, Deserialize)]
110struct StoredSession {
111    /// The data to re-build the client.
112    client_session: ClientSession,
113
114    /// The OAuth 2.0 user session.
115    user_session: UserSession,
116
117    /// The OAuth 2.0 client ID.
118    client_id: ClientId,
119}
120
121/// An OAuth 2.0 CLI.
122#[derive(Clone, Debug)]
123struct OAuthCli {
124    /// The Matrix client.
125    client: Client,
126
127    /// Whether this is a restored client.
128    restored: bool,
129
130    /// The path to the file storing the session.
131    session_file: PathBuf,
132}
133
134impl OAuthCli {
135    /// Create a new session by logging in.
136    async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
137        println!("No previous session found, logging in…");
138
139        let (client, client_session) = build_client(data_dir).await?;
140        let cli = Self { client, restored: false, session_file };
141
142        if let Err(error) = cli.register_and_login().await {
143            if let Some(error) = error.downcast_ref::<OAuthError>()
144                && let OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported) =
145                    error
146            {
147                // This would require to register with the authorization server manually, which
148                // we don't support here.
149                bail!(
150                    "This server doesn't support dynamic registration.\n\
151                     Please select another homeserver."
152                );
153            } else {
154                return Err(error);
155            }
156        }
157
158        // Persist the session to reuse it later.
159        // This is not very secure, for simplicity. If the system provides a way of
160        // storing secrets securely, it should be used instead.
161        let full_session =
162            cli.client.oauth().full_session().expect("A logged-in client should have a session");
163
164        let serialized_session = serde_json::to_string(&StoredSession {
165            client_session,
166            user_session: full_session.user,
167            client_id: full_session.client_id,
168        })?;
169        fs::write(&cli.session_file, serialized_session).await?;
170
171        println!("Session persisted in {}", cli.session_file.to_string_lossy());
172
173        cli.setup_background_save();
174
175        Ok(cli)
176    }
177
178    /// Register the client and log in the user via the OAuth 2.0 Authorization
179    /// Code flow.
180    async fn register_and_login(&self) -> anyhow::Result<()> {
181        let oauth = self.client.oauth();
182
183        // We create a loop here so the user can retry if an error happens.
184        loop {
185            // Here we spawn a server to listen on the loopback interface. Another option
186            // would be to register a custom URI scheme with the system and handle
187            // the redirect when the custom URI scheme is opened.
188            let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
189
190            let OAuthAuthorizationData { url, .. } = oauth
191                .login(redirect_uri, None, Some(client_metadata().into()), None)
192                .build()
193                .await?;
194
195            let query_string =
196                use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default();
197
198            match oauth.finish_login(UrlOrQuery::Query(query_string)).await {
199                Ok(()) => {
200                    let user_id = self.client.user_id().expect("Got a user ID");
201                    println!("Logged in as {user_id}");
202                    break;
203                }
204                Err(err) => {
205                    println!("Error: failed to login: {err}");
206                    println!("Please try again.\n");
207                    continue;
208                }
209            }
210        }
211
212        Ok(())
213    }
214
215    /// Restore a previous session from a file.
216    async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
217        println!("Previous session found in '{}'", session_file.to_string_lossy());
218
219        // The session was serialized as JSON in a file.
220        let serialized_session = fs::read_to_string(&session_file).await?;
221        let StoredSession { client_session, user_session, client_id } =
222            serde_json::from_str(&serialized_session)?;
223
224        // Build the client with the previous settings from the session.
225        let client = Client::builder()
226            .homeserver_url(client_session.homeserver)
227            .handle_refresh_tokens()
228            .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
229            .build()
230            .await?;
231
232        println!("Restoring session for {}…", user_session.meta.user_id);
233
234        let session = OAuthSession { client_id, user: user_session };
235        // Restore the Matrix user session.
236        client.restore_session(session).await?;
237
238        let this = Self { client, restored: true, session_file };
239
240        this.setup_background_save();
241
242        Ok(this)
243    }
244
245    /// Run the main program.
246    async fn run(&self) -> anyhow::Result<()> {
247        help();
248
249        loop {
250            let mut input = String::new();
251
252            print!("\nEnter command: ");
253            io::stdout().flush().expect("Unable to write to stdout");
254
255            io::stdin().read_line(&mut input).expect("Unable to read user input");
256
257            let mut args = input.trim().split_ascii_whitespace();
258            let cmd = args.next();
259
260            match cmd {
261                Some("whoami") => {
262                    self.whoami();
263                }
264                Some("account") => {
265                    self.account(None).await;
266                }
267                Some("profile") => {
268                    self.account(Some(AccountManagementActionFull::Profile)).await;
269                }
270                Some("sessions") => {
271                    self.account(Some(AccountManagementActionFull::SessionsList)).await;
272                }
273                Some("watch") => match args.next() {
274                    Some(sub) => {
275                        if sub == "sliding" {
276                            self.watch_sliding_sync().await?;
277                        } else {
278                            println!("unknown subcommand for watch: available is 'sliding'");
279                        }
280                    }
281                    None => self.watch().await?,
282                },
283                Some("refresh") => {
284                    self.refresh_token().await?;
285                }
286                Some("recover") => {
287                    self.recover().await?;
288                }
289                Some("reset-cross-signing") => {
290                    self.reset_cross_signing().await?;
291                }
292                Some("logout") => {
293                    self.logout().await?;
294                    break;
295                }
296                Some("exit") => {
297                    break;
298                }
299                Some("help") => {
300                    help();
301                }
302                Some(cmd) => {
303                    println!("Error: unknown command '{cmd}'\n");
304                    help();
305                }
306                None => {
307                    println!("Error: no command\n");
308                    help()
309                }
310            }
311        }
312
313        Ok(())
314    }
315
316    async fn recover(&self) -> anyhow::Result<()> {
317        let recovery = self.client.encryption().recovery();
318
319        println!("Please enter your recovery key:");
320
321        let mut input = String::new();
322        io::stdin().read_line(&mut input).expect("error: unable to read user input");
323
324        let input = input.trim();
325
326        recovery.recover(input).await?;
327
328        match recovery.state() {
329            RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
330            RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
331            RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
332            _ => unreachable!("We should know our recovery state by now"),
333        }
334
335        Ok(())
336    }
337
338    async fn reset_cross_signing(&self) -> Result<()> {
339        let encryption = self.client.encryption();
340
341        if let Some(handle) = encryption.reset_cross_signing().await? {
342            match handle.auth_type() {
343                CrossSigningResetAuthType::Uiaa(_) => {
344                    unimplemented!(
345                        "This should never happen, this is after all the OAuth 2.0 example."
346                    )
347                }
348                CrossSigningResetAuthType::OAuth(o) => {
349                    println!(
350                        "To reset your end-to-end encryption cross-signing identity, \
351                        you first need to approve it at {}",
352                        o.approval_url
353                    );
354                    handle.auth(None).await?;
355                }
356            }
357        }
358
359        print!("Successfully reset cross-signing");
360
361        Ok(())
362    }
363
364    /// Get information about this session.
365    fn whoami(&self) {
366        let client = &self.client;
367
368        let user_id = client.user_id().expect("A logged in client has a user ID");
369        let device_id = client.device_id().expect("A logged in client has a device ID");
370        let homeserver = client.homeserver();
371
372        println!("\nUser ID: {user_id}");
373        println!("Device ID: {device_id}");
374        println!("Homeserver URL: {homeserver}");
375    }
376
377    /// Get the account management URL.
378    async fn account(&self, action: Option<AccountManagementActionFull>) {
379        let Ok(Some(mut url_builder)) = self.client.oauth().fetch_account_management_url().await
380        else {
381            println!("\nThis homeserver does not provide the URL to manage your account");
382            return;
383        };
384
385        if let Some(action) = action {
386            url_builder = url_builder.action(action);
387        }
388
389        let url = url_builder.build();
390        println!("\nTo manage your account, visit: {url}");
391    }
392
393    /// Watch incoming messages.
394    async fn watch(&self) -> anyhow::Result<()> {
395        let client = &self.client;
396
397        // If this is a new client, ignore previous messages to not fill the logs.
398        // Note that this might not work as intended, the initial sync might have failed
399        // in a previous session.
400        if !self.restored {
401            client.sync_once(SyncSettings::default()).await.unwrap();
402        }
403
404        // Listen to room messages.
405        let handle = client.add_event_handler(on_room_message);
406
407        // Sync.
408        let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
409        while let Some(res) = sync_stream.next().await {
410            if let Err(err) = res {
411                client.remove_event_handler(handle);
412                return Err(err.into());
413            }
414        }
415
416        Ok(())
417    }
418
419    /// This watches for incoming responses using the high-level sliding sync
420    /// helpers (`SyncService`).
421    async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
422        let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
423
424        sync_service.start().await;
425
426        println!("press enter to exit the sync loop");
427
428        let mut sync_service_state = sync_service.state();
429
430        let sync_service_clone = sync_service.clone();
431        let task = tokio::spawn(async move {
432            // Only fail after getting 5 errors in a row. When we're in an always-refail
433            // scenario, we move from the Error to the Running state for a bit
434            // until we fail again, so we need to track both failure state and
435            // running state, hence `num_errors` and `num_running`:
436            // - if we failed and num_running was 1, then this is a failure following a
437            //   failure.
438            // - otherwise, we recovered from the failure and we can plain continue.
439
440            let mut num_errors = 0;
441            let mut num_running = 0;
442
443            let mut _unused = String::new();
444            let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
445
446            loop {
447                // Concurrently wait for an update from the sync service OR for the user to
448                // press enter and leave early.
449                tokio::select! {
450                    res = sync_service_state.next() => {
451                        if let Some(state) = res {
452                            match state {
453                                matrix_sdk_ui::sync_service::State::Idle
454                                | matrix_sdk_ui::sync_service::State::Terminated => {
455                                    num_errors = 0;
456                                    num_running = 0;
457                                }
458
459                                matrix_sdk_ui::sync_service::State::Running => {
460                                    num_running += 1;
461                                    if num_running > 1 {
462                                        num_errors = 0;
463                                    }
464                                }
465
466                                matrix_sdk_ui::sync_service::State::Error | matrix_sdk_ui::sync_service::State::Offline => {
467                                    num_errors += 1;
468                                    num_running = 0;
469
470                                    if num_errors == 5 {
471                                        println!("ran into 5 errors in a row, terminating");
472                                        break;
473                                    }
474
475                                    sync_service_clone.start().await;
476                                }
477                            }
478                            println!("New sync service state update: {state:?}");
479                        } else {
480                            break;
481                        }
482                    }
483
484                    _ = stdin.read_line(&mut _unused) => {
485                        println!("Stopping loop because of user request");
486                        sync_service.stop().await;
487
488                        break;
489                    }
490                }
491            }
492        });
493
494        println!("waiting for sync service to stop...");
495        task.await.unwrap();
496
497        println!("done!");
498        Ok(())
499    }
500
501    /// Sets up this client so that it automatically saves the session onto disk
502    /// whenever there are new tokens that have been received.
503    ///
504    /// This should always be set up whenever automatic refresh is happening.
505    fn setup_background_save(&self) {
506        let this = self.clone();
507        tokio::spawn(async move {
508            while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
509                match update {
510                    matrix_sdk::SessionChange::UnknownToken { soft_logout } => {
511                        println!("Received an unknown token error; soft logout? {soft_logout:?}");
512                    }
513                    matrix_sdk::SessionChange::TokensRefreshed => {
514                        // The tokens have been refreshed, persist them to disk.
515                        if let Err(err) = this.update_stored_session().await {
516                            println!("Unable to store a session in the background: {err}");
517                        }
518                    }
519                }
520            }
521        });
522    }
523
524    /// Update the session stored on the system.
525    ///
526    /// This should be called everytime the access token (and possibly refresh
527    /// token) has changed.
528    async fn update_stored_session(&self) -> anyhow::Result<()> {
529        println!("Updating the stored session...");
530
531        let serialized_session = fs::read_to_string(&self.session_file).await?;
532        let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
533
534        let user_session: UserSession =
535            self.client.oauth().user_session().expect("A logged in client has a session");
536        session.user_session = user_session;
537
538        let serialized_session = serde_json::to_string(&session)?;
539        fs::write(&self.session_file, serialized_session).await?;
540
541        println!("Updating the stored session: done!");
542        Ok(())
543    }
544
545    /// Refresh the access token.
546    async fn refresh_token(&self) -> anyhow::Result<()> {
547        self.client.oauth().refresh_access_token().await?;
548
549        // The session will automatically be refreshed because of the task persisting
550        // the full session upon refresh in `setup_background_save`.
551
552        println!("\nToken refreshed successfully");
553
554        Ok(())
555    }
556
557    /// Log out from this session.
558    async fn logout(&self) -> anyhow::Result<()> {
559        // Log out via OAuth 2.0.
560        self.client.logout().await?;
561
562        // Delete the stored session and database.
563        let data_dir = self.session_file.parent().expect("The file has a parent directory");
564        fs::remove_dir_all(data_dir).await?;
565
566        println!("\nLogged out successfully");
567        println!("\nExiting…");
568
569        Ok(())
570    }
571}
572
573/// Build a new client.
574///
575/// Returns the client and the data required to restore the client.
576async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
577    let db_path = data_dir.join("db");
578
579    // Generate a random passphrase.
580    let mut rng = thread_rng();
581    let passphrase: String =
582        (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
583
584    // We create a loop here so the user can retry if an error happens.
585    loop {
586        let mut homeserver = String::new();
587
588        print!("\nHomeserver: ");
589        io::stdout().flush().expect("Unable to write to stdout");
590        io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
591
592        let homeserver = homeserver.trim();
593
594        println!("\nChecking homeserver…");
595
596        match Client::builder()
597            // Try autodiscovery or test the URL.
598            .server_name_or_homeserver_url(homeserver)
599            // Make sure to automatically refresh tokens if needs be.
600            .handle_refresh_tokens()
601            // We use the sqlite store, which is available by default. This is the crucial part to
602            // persist the encryption setup.
603            // Note that other store backends are available and you can even implement your own.
604            .sqlite_store(&db_path, Some(&passphrase))
605            .build()
606            .await
607        {
608            Ok(client) => {
609                // Check if the homeserver advertises OAuth 2.0 server metadata.
610                match client.oauth().server_metadata().await {
611                    Ok(server_metadata) => {
612                        println!(
613                            "Found OAuth 2.0 server metadata with issuer: {}",
614                            server_metadata.issuer
615                        );
616
617                        let homeserver = client.homeserver().to_string();
618                        return Ok((client, ClientSession { homeserver, db_path, passphrase }));
619                    }
620                    Err(error) => {
621                        if error.is_not_supported() {
622                            println!(
623                                "This homeserver doesn't advertise OAuth 2.0 server metadata."
624                            );
625                        } else {
626                            println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
627                        }
628                        // The client already initialized the store so we need to remove it.
629                        fs::remove_dir_all(data_dir).await?;
630                    }
631                }
632            }
633            Err(error) => match &error {
634                ClientBuildError::AutoDiscovery(_)
635                | ClientBuildError::Url(_)
636                | ClientBuildError::Http(_) => {
637                    println!("Error checking the homeserver: {error}");
638                    println!("Please try again\n");
639                    // The client already initialized the store so we need to remove it.
640                    fs::remove_dir_all(data_dir).await?;
641                }
642                ClientBuildError::InvalidServerName => {
643                    println!("Error: not a valid server name");
644                    println!("Please try again\n");
645                }
646                _ => {
647                    // Forward other errors, it's unlikely we can retry with a different outcome.
648                    return Err(error.into());
649                }
650            },
651        }
652    }
653}
654
655/// Generate the OAuth 2.0 client metadata.
656fn client_metadata() -> Raw<ClientMetadata> {
657    // Native clients should be able to register the IPv4 and IPv6 loopback
658    // interfaces and then point to any port when needing a redirect URI. An
659    // alternative is to use a custom URI scheme registered with the OS.
660    let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
661        .expect("Couldn't parse IPv4 redirect URI");
662    let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
663        .expect("Couldn't parse IPv6 redirect URI");
664    let client_uri = Localized::new(
665        Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
666            .expect("Couldn't parse client URI"),
667        None,
668    );
669
670    let metadata = ClientMetadata {
671        // The following fields should be displayed in the OAuth 2.0 authorization server's
672        // web UI as part of the process to get the user's consent. It means that these
673        // should contain real data so the user can make sure that they allow the proper
674        // application. We are cheating here because this is an example.
675        client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
676        policy_uri: Some(client_uri.clone()),
677        tos_uri: Some(client_uri.clone()),
678        ..ClientMetadata::new(
679            // This is a native application (in contrast to a web application, that runs in a
680            // browser).
681            ApplicationType::Native,
682            // We are going to use the Authorization Code flow.
683            vec![OAuthGrantType::AuthorizationCode {
684                redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
685            }],
686            client_uri,
687        )
688    };
689
690    Raw::new(&metadata).expect("Couldn't serialize client metadata")
691}
692
693/// Open the authorization URL and wait for it to be complete.
694///
695/// Returns the code to obtain the access token.
696async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
697    println!("\nPlease authenticate yourself at: {url}\n");
698    println!("Then proceed to the authorization.\n");
699
700    server_handle.await
701}
702
703/// Handle room messages.
704async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
705    // We only want to log text messages in joined rooms.
706    if room.state() != RoomState::Joined {
707        return;
708    }
709    let MessageType::Text(text_content) = &event.content.msgtype else { return };
710
711    let room_name = match room.display_name().await {
712        Ok(room_name) => room_name.to_string(),
713        Err(error) => {
714            println!("Error getting room display name: {error}");
715            // Let's fallback to the room ID.
716            room.room_id().to_string()
717        }
718    };
719
720    println!("[{room_name}] {}: {}", event.sender, text_content.body)
721}