example_oauth_cli/
main.rs

1// Copyright 2023 Kévin Commaille.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    io::{self, Write},
17    net::{Ipv4Addr, Ipv6Addr},
18    path::{Path, PathBuf},
19    sync::Arc,
20};
21
22use anyhow::bail;
23use futures_util::StreamExt;
24use matrix_sdk::{
25    authentication::oauth::{
26        error::OAuthClientRegistrationError,
27        registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
28        AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError,
29        OAuthRegistrationStore, OAuthSession, UrlOrQuery, UserSession,
30    },
31    config::SyncSettings,
32    encryption::{recovery::RecoveryState, CrossSigningResetAuthType},
33    room::Room,
34    ruma::{
35        events::room::message::{MessageType, OriginalSyncRoomMessageEvent},
36        serde::Raw,
37    },
38    utils::local_server::{LocalServerBuilder, LocalServerRedirectHandle, QueryString},
39    Client, ClientBuildError, Result, RoomState,
40};
41use matrix_sdk_ui::sync_service::SyncService;
42use rand::{distributions::Alphanumeric, thread_rng, Rng};
43use serde::{Deserialize, Serialize};
44use tokio::{fs, io::AsyncBufReadExt as _};
45use url::Url;
46
47/// A command-line tool to demonstrate the steps requiring an interaction with
48/// an OAuth 2.0 authorization server for a Matrix client, using the
49/// Authorization Code flow.
50///
51/// You can test this against one of the servers from the OIDC playground:
52/// <https://github.com/element-hq/oidc-playground>.
53///
54/// To use this, just run `cargo run -p example-oauth-cli`, and everything
55/// is interactive after that. You might want to set the `RUST_LOG` environment
56/// variable to `warn` to reduce the noise in the logs. The program exits
57/// whenever an unexpected error occurs.
58///
59/// To reset the login, simply use the `logout` command or delete the folder
60/// containing the session file, the location is shown in the logs. Note that
61/// the database must be deleted too as it can't be reused.
62#[tokio::main]
63async fn main() -> anyhow::Result<()> {
64    tracing_subscriber::fmt::init();
65
66    // The folder containing this example's data.
67    let data_dir =
68        dirs::data_dir().expect("no data_dir directory found").join("matrix_sdk/oauth_cli");
69    // The file where the session is persisted.
70    let session_file = data_dir.join("session.json");
71
72    let cli = if session_file.exists() {
73        OAuthCli::from_stored_session(session_file).await?
74    } else {
75        OAuthCli::new(&data_dir, session_file).await?
76    };
77
78    cli.run().await
79}
80
/// Print the commands available once the client is logged in.
///
/// Every command accepted by `OAuthCli::run` is listed here, including the
/// account-management shortcuts and the cross-signing reset.
fn help() {
    println!("Usage: [command] [args…]\n");
    println!("Commands:");
    println!("  whoami                 Get information about this session");
    println!("  account                Get the URL to manage this account");
    println!("  profile                Get the URL to manage the profile of this account");
    println!("  sessions               Get the URL to manage the sessions of this account");
    println!("  watch [sliding?]       Watch new incoming messages until an error occurs");
    println!("  refresh                Refresh the access token");
    println!("  recover                Recover the E2EE secrets from secret storage");
    println!("  reset-cross-signing    Reset the cross-signing identity of this account");
    println!("  logout                 Log out of this account");
    println!("  exit                   Exit this program");
    println!("  help                   Show this message\n");
}
94
95/// The data needed to re-build a client.
96#[derive(Debug, Serialize, Deserialize)]
97struct ClientSession {
98    /// The URL of the homeserver of the user.
99    homeserver: String,
100
101    /// The path of the database.
102    db_path: PathBuf,
103
104    /// The passphrase of the database.
105    passphrase: String,
106}
107
108/// The full session to persist.
109#[derive(Debug, Serialize, Deserialize)]
110struct StoredSession {
111    /// The data to re-build the client.
112    client_session: ClientSession,
113
114    /// The OAuth 2.0 user session.
115    user_session: UserSession,
116
117    /// The OAuth 2.0 client ID.
118    client_id: ClientId,
119}
120
121/// An OAuth 2.0 CLI.
122#[derive(Clone, Debug)]
123struct OAuthCli {
124    /// The Matrix client.
125    client: Client,
126
127    /// Whether this is a restored client.
128    restored: bool,
129
130    /// The path to the file storing the session.
131    session_file: PathBuf,
132}
133
134impl OAuthCli {
135    /// Create a new session by logging in.
136    async fn new(data_dir: &Path, session_file: PathBuf) -> anyhow::Result<Self> {
137        println!("No previous session found, logging in…");
138
139        let (client, client_session) = build_client(data_dir).await?;
140        let cli = Self { client, restored: false, session_file };
141
142        if let Err(error) = cli.register_and_login(data_dir).await {
143            if error.downcast_ref::<OAuthError>().is_some_and(|error| {
144                matches!(
145                    error,
146                    OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported)
147                )
148            }) {
149                // This would require to register with the authorization server manually, which
150                // we don't support here.
151                bail!(
152                    "This server doesn't support dynamic registration.\n\
153                     Please select another homeserver."
154                );
155            } else {
156                return Err(error);
157            }
158        }
159
160        // Persist the session to reuse it later.
161        // This is not very secure, for simplicity. If the system provides a way of
162        // storing secrets securely, it should be used instead.
163        let full_session =
164            cli.client.oauth().full_session().expect("A logged-in client should have a session");
165
166        let serialized_session = serde_json::to_string(&StoredSession {
167            client_session,
168            user_session: full_session.user,
169            client_id: full_session.client_id,
170        })?;
171        fs::write(&cli.session_file, serialized_session).await?;
172
173        println!("Session persisted in {}", cli.session_file.to_string_lossy());
174
175        cli.setup_background_save();
176
177        Ok(cli)
178    }
179
180    /// Register the client and log in the user via the OAuth 2.0 Authorization
181    /// Code flow.
182    async fn register_and_login(&self, data_dir: &Path) -> anyhow::Result<()> {
183        let oauth = self.client.oauth();
184
185        // We create a loop here so the user can retry if an error happens.
186        loop {
187            // The registrations store allows to store the client ID separately
188            // and reuse it between sessions, so we don't need to
189            // re-register each time.
190            // We put it in the parent directory because we remove the data directory on
191            // logout.
192            let registration_file =
193                data_dir.parent().expect("directory has a parent").join("oauth_registrations");
194            let registration_store =
195                OAuthRegistrationStore::new(registration_file, client_metadata()).await?;
196
197            // Here we spawn a server to listen on the loopback interface. Another option
198            // would be to register a custom URI scheme with the system and handle
199            // the redirect when the custom URI scheme is opened.
200            let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
201
202            let OAuthAuthorizationData { url, .. } =
203                oauth.login(registration_store.into(), redirect_uri, None).build().await?;
204
205            let query_string =
206                use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default();
207
208            match oauth.finish_login(UrlOrQuery::Query(query_string)).await {
209                Ok(()) => {
210                    let user_id = self.client.user_id().expect("Got a user ID");
211                    println!("Logged in as {user_id}");
212                    break;
213                }
214                Err(err) => {
215                    println!("Error: failed to login: {err}");
216                    println!("Please try again.\n");
217                    continue;
218                }
219            }
220        }
221
222        Ok(())
223    }
224
225    /// Restore a previous session from a file.
226    async fn from_stored_session(session_file: PathBuf) -> anyhow::Result<Self> {
227        println!("Previous session found in '{}'", session_file.to_string_lossy());
228
229        // The session was serialized as JSON in a file.
230        let serialized_session = fs::read_to_string(&session_file).await?;
231        let StoredSession { client_session, user_session, client_id } =
232            serde_json::from_str(&serialized_session)?;
233
234        // Build the client with the previous settings from the session.
235        let client = Client::builder()
236            .homeserver_url(client_session.homeserver)
237            .handle_refresh_tokens()
238            .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
239            .build()
240            .await?;
241
242        println!("Restoring session for {}…", user_session.meta.user_id);
243
244        let session = OAuthSession { client_id, user: user_session };
245        // Restore the Matrix user session.
246        client.restore_session(session).await?;
247
248        let this = Self { client, restored: true, session_file };
249
250        this.setup_background_save();
251
252        Ok(this)
253    }
254
255    /// Run the main program.
256    async fn run(&self) -> anyhow::Result<()> {
257        help();
258
259        loop {
260            let mut input = String::new();
261
262            print!("\nEnter command: ");
263            io::stdout().flush().expect("Unable to write to stdout");
264
265            io::stdin().read_line(&mut input).expect("Unable to read user input");
266
267            let mut args = input.trim().split_ascii_whitespace();
268            let cmd = args.next();
269
270            match cmd {
271                Some("whoami") => {
272                    self.whoami();
273                }
274                Some("account") => {
275                    self.account(None).await;
276                }
277                Some("profile") => {
278                    self.account(Some(AccountManagementActionFull::Profile)).await;
279                }
280                Some("sessions") => {
281                    self.account(Some(AccountManagementActionFull::SessionsList)).await;
282                }
283                Some("watch") => match args.next() {
284                    Some(sub) => {
285                        if sub == "sliding" {
286                            self.watch_sliding_sync().await?;
287                        } else {
288                            println!("unknown subcommand for watch: available is 'sliding'");
289                        }
290                    }
291                    None => self.watch().await?,
292                },
293                Some("refresh") => {
294                    self.refresh_token().await?;
295                }
296                Some("recover") => {
297                    self.recover().await?;
298                }
299                Some("reset-cross-signing") => {
300                    self.reset_cross_signing().await?;
301                }
302                Some("logout") => {
303                    self.logout().await?;
304                    break;
305                }
306                Some("exit") => {
307                    break;
308                }
309                Some("help") => {
310                    help();
311                }
312                Some(cmd) => {
313                    println!("Error: unknown command '{cmd}'\n");
314                    help();
315                }
316                None => {
317                    println!("Error: no command\n");
318                    help()
319                }
320            };
321        }
322
323        Ok(())
324    }
325
326    async fn recover(&self) -> anyhow::Result<()> {
327        let recovery = self.client.encryption().recovery();
328
329        println!("Please enter your recovery key:");
330
331        let mut input = String::new();
332        io::stdin().read_line(&mut input).expect("error: unable to read user input");
333
334        let input = input.trim();
335
336        recovery.recover(input).await?;
337
338        match recovery.state() {
339            RecoveryState::Enabled => println!("Successfully recovered all the E2EE secrets."),
340            RecoveryState::Disabled => println!("Error recovering, recovery is disabled."),
341            RecoveryState::Incomplete => println!("Couldn't recover all E2EE secrets."),
342            _ => unreachable!("We should know our recovery state by now"),
343        }
344
345        Ok(())
346    }
347
348    async fn reset_cross_signing(&self) -> Result<()> {
349        let encryption = self.client.encryption();
350
351        if let Some(handle) = encryption.reset_cross_signing().await? {
352            match handle.auth_type() {
353                CrossSigningResetAuthType::Uiaa(_) => {
354                    unimplemented!(
355                        "This should never happen, this is after all the OAuth 2.0 example."
356                    )
357                }
358                CrossSigningResetAuthType::OAuth(o) => {
359                    println!(
360                        "To reset your end-to-end encryption cross-signing identity, \
361                        you first need to approve it at {}",
362                        o.approval_url
363                    );
364                    handle.auth(None).await?;
365                }
366            }
367        }
368
369        print!("Successfully reset cross-signing");
370
371        Ok(())
372    }
373
374    /// Get information about this session.
375    fn whoami(&self) {
376        let client = &self.client;
377        let oauth = client.oauth();
378
379        let user_id = client.user_id().expect("A logged in client has a user ID");
380        let device_id = client.device_id().expect("A logged in client has a device ID");
381        let homeserver = client.homeserver();
382        let issuer = oauth.issuer().expect("A logged in OAuth 2.0 client has an issuer");
383
384        println!("\nUser ID: {user_id}");
385        println!("Device ID: {device_id}");
386        println!("Homeserver URL: {homeserver}");
387        println!("OAuth 2.0 authorization server: {issuer}");
388    }
389
390    /// Get the account management URL.
391    async fn account(&self, action: Option<AccountManagementActionFull>) {
392        let mut url_builder = match self.client.oauth().fetch_account_management_url().await {
393            Ok(Some(url_builder)) => url_builder,
394            _ => {
395                println!("\nThis homeserver does not provide the URL to manage your account");
396                return;
397            }
398        };
399
400        if let Some(action) = action {
401            url_builder = url_builder.action(action);
402        }
403
404        let url = url_builder.build();
405        println!("\nTo manage your account, visit: {url}");
406    }
407
408    /// Watch incoming messages.
409    async fn watch(&self) -> anyhow::Result<()> {
410        let client = &self.client;
411
412        // If this is a new client, ignore previous messages to not fill the logs.
413        // Note that this might not work as intended, the initial sync might have failed
414        // in a previous session.
415        if !self.restored {
416            client.sync_once(SyncSettings::default()).await.unwrap();
417        }
418
419        // Listen to room messages.
420        let handle = client.add_event_handler(on_room_message);
421
422        // Sync.
423        let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await);
424        while let Some(res) = sync_stream.next().await {
425            if let Err(err) = res {
426                client.remove_event_handler(handle);
427                return Err(err.into());
428            }
429        }
430
431        Ok(())
432    }
433
434    /// This watches for incoming responses using the high-level sliding sync
435    /// helpers (`SyncService`).
436    async fn watch_sliding_sync(&self) -> anyhow::Result<()> {
437        let sync_service = Arc::new(SyncService::builder(self.client.clone()).build().await?);
438
439        sync_service.start().await;
440
441        println!("press enter to exit the sync loop");
442
443        let mut sync_service_state = sync_service.state();
444
445        let sync_service_clone = sync_service.clone();
446        let task = tokio::spawn(async move {
447            // Only fail after getting 5 errors in a row. When we're in an always-refail
448            // scenario, we move from the Error to the Running state for a bit
449            // until we fail again, so we need to track both failure state and
450            // running state, hence `num_errors` and `num_running`:
451            // - if we failed and num_running was 1, then this is a failure following a
452            //   failure.
453            // - otherwise, we recovered from the failure and we can plain continue.
454
455            let mut num_errors = 0;
456            let mut num_running = 0;
457
458            let mut _unused = String::new();
459            let mut stdin = tokio::io::BufReader::new(tokio::io::stdin());
460
461            loop {
462                // Concurrently wait for an update from the sync service OR for the user to
463                // press enter and leave early.
464                tokio::select! {
465                    res = sync_service_state.next() => {
466                        if let Some(state) = res {
467                            match state {
468                                matrix_sdk_ui::sync_service::State::Idle
469                                | matrix_sdk_ui::sync_service::State::Terminated => {
470                                    num_errors = 0;
471                                    num_running = 0;
472                                }
473
474                                matrix_sdk_ui::sync_service::State::Running => {
475                                    num_running += 1;
476                                    if num_running > 1 {
477                                        num_errors = 0;
478                                    }
479                                }
480
481                                matrix_sdk_ui::sync_service::State::Error | matrix_sdk_ui::sync_service::State::Offline => {
482                                    num_errors += 1;
483                                    num_running = 0;
484
485                                    if num_errors == 5 {
486                                        println!("ran into 5 errors in a row, terminating");
487                                        break;
488                                    }
489
490                                    sync_service_clone.start().await;
491                                }
492                            }
493                            println!("New sync service state update: {state:?}");
494                        } else {
495                            break;
496                        }
497                    }
498
499                    _ = stdin.read_line(&mut _unused) => {
500                        println!("Stopping loop because of user request");
501                        sync_service.stop().await;
502
503                        break;
504                    }
505                }
506            }
507        });
508
509        println!("waiting for sync service to stop...");
510        task.await.unwrap();
511
512        println!("done!");
513        Ok(())
514    }
515
516    /// Sets up this client so that it automatically saves the session onto disk
517    /// whenever there are new tokens that have been received.
518    ///
519    /// This should always be set up whenever automatic refresh is happening.
520    fn setup_background_save(&self) {
521        let this = self.clone();
522        tokio::spawn(async move {
523            while let Ok(update) = this.client.subscribe_to_session_changes().recv().await {
524                match update {
525                    matrix_sdk::SessionChange::UnknownToken { soft_logout } => {
526                        println!("Received an unknown token error; soft logout? {soft_logout:?}");
527                    }
528                    matrix_sdk::SessionChange::TokensRefreshed => {
529                        // The tokens have been refreshed, persist them to disk.
530                        if let Err(err) = this.update_stored_session().await {
531                            println!("Unable to store a session in the background: {err}");
532                        }
533                    }
534                }
535            }
536        });
537    }
538
539    /// Update the session stored on the system.
540    ///
541    /// This should be called everytime the access token (and possibly refresh
542    /// token) has changed.
543    async fn update_stored_session(&self) -> anyhow::Result<()> {
544        println!("Updating the stored session...");
545
546        let serialized_session = fs::read_to_string(&self.session_file).await?;
547        let mut session = serde_json::from_str::<StoredSession>(&serialized_session)?;
548
549        let user_session: UserSession =
550            self.client.oauth().user_session().expect("A logged in client has a session");
551        session.user_session = user_session;
552
553        let serialized_session = serde_json::to_string(&session)?;
554        fs::write(&self.session_file, serialized_session).await?;
555
556        println!("Updating the stored session: done!");
557        Ok(())
558    }
559
560    /// Refresh the access token.
561    async fn refresh_token(&self) -> anyhow::Result<()> {
562        self.client.oauth().refresh_access_token().await?;
563
564        // The session will automatically be refreshed because of the task persisting
565        // the full session upon refresh in `setup_background_save`.
566
567        println!("\nToken refreshed successfully");
568
569        Ok(())
570    }
571
572    /// Log out from this session.
573    async fn logout(&self) -> anyhow::Result<()> {
574        // Log out via OAuth 2.0.
575        self.client.oauth().logout().await?;
576
577        // Delete the stored session and database.
578        let data_dir = self.session_file.parent().expect("The file has a parent directory");
579        fs::remove_dir_all(data_dir).await?;
580
581        println!("\nLogged out successfully");
582        println!("\nExiting…");
583
584        Ok(())
585    }
586}
587
588/// Build a new client.
589///
590/// Returns the client and the data required to restore the client.
591async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> {
592    let db_path = data_dir.join("db");
593
594    // Generate a random passphrase.
595    let mut rng = thread_rng();
596    let passphrase: String =
597        (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect();
598
599    // We create a loop here so the user can retry if an error happens.
600    loop {
601        let mut homeserver = String::new();
602
603        print!("\nHomeserver: ");
604        io::stdout().flush().expect("Unable to write to stdout");
605        io::stdin().read_line(&mut homeserver).expect("Unable to read user input");
606
607        let homeserver = homeserver.trim();
608
609        println!("\nChecking homeserver…");
610
611        match Client::builder()
612            // Try autodiscovery or test the URL.
613            .server_name_or_homeserver_url(homeserver)
614            // Make sure to automatically refresh tokens if needs be.
615            .handle_refresh_tokens()
616            // We use the sqlite store, which is available by default. This is the crucial part to
617            // persist the encryption setup.
618            // Note that other store backends are available and you can even implement your own.
619            .sqlite_store(&db_path, Some(&passphrase))
620            .build()
621            .await
622        {
623            Ok(client) => {
624                // Check if the homeserver advertises OAuth 2.0 server metadata.
625                match client.oauth().server_metadata().await {
626                    Ok(server_metadata) => {
627                        println!(
628                            "Found OAuth 2.0 server metadata with issuer: {}",
629                            server_metadata.issuer
630                        );
631
632                        let homeserver = client.homeserver().to_string();
633                        return Ok((client, ClientSession { homeserver, db_path, passphrase }));
634                    }
635                    Err(error) => {
636                        if error.is_not_supported() {
637                            println!(
638                                "This homeserver doesn't advertise OAuth 2.0 server metadata."
639                            );
640                        } else {
641                            println!("Error fetching the OAuth 2.0 server metadata: {error:?}");
642                        }
643                        // The client already initialized the store so we need to remove it.
644                        fs::remove_dir_all(data_dir).await?;
645                    }
646                }
647            }
648            Err(error) => match &error {
649                ClientBuildError::AutoDiscovery(_)
650                | ClientBuildError::Url(_)
651                | ClientBuildError::Http(_) => {
652                    println!("Error checking the homeserver: {error}");
653                    println!("Please try again\n");
654                    // The client already initialized the store so we need to remove it.
655                    fs::remove_dir_all(data_dir).await?;
656                }
657                ClientBuildError::InvalidServerName => {
658                    println!("Error: not a valid server name");
659                    println!("Please try again\n");
660                }
661                _ => {
662                    // Forward other errors, it's unlikely we can retry with a different outcome.
663                    return Err(error.into());
664                }
665            },
666        }
667    }
668}
669
670/// Generate the OAuth 2.0 client metadata.
671fn client_metadata() -> Raw<ClientMetadata> {
672    // Native clients should be able to register the IPv4 and IPv6 loopback
673    // interfaces and then point to any port when needing a redirect URI. An
674    // alternative is to use a custom URI scheme registered with the OS.
675    let ipv4_localhost_uri = Url::parse(&format!("http://{}/", Ipv4Addr::LOCALHOST))
676        .expect("Couldn't parse IPv4 redirect URI");
677    let ipv6_localhost_uri = Url::parse(&format!("http://[{}]/", Ipv6Addr::LOCALHOST))
678        .expect("Couldn't parse IPv6 redirect URI");
679    let client_uri = Localized::new(
680        Url::parse("https://github.com/matrix-org/matrix-rust-sdk")
681            .expect("Couldn't parse client URI"),
682        None,
683    );
684
685    let metadata = ClientMetadata {
686        // The following fields should be displayed in the OAuth 2.0 authorization server's
687        // web UI as part of the process to get the user's consent. It means that these
688        // should contain real data so the user can make sure that they allow the proper
689        // application. We are cheating here because this is an example.
690        client_name: Some(Localized::new("matrix-rust-sdk-oauth-cli".to_owned(), [])),
691        policy_uri: Some(client_uri.clone()),
692        tos_uri: Some(client_uri.clone()),
693        ..ClientMetadata::new(
694            // This is a native application (in contrast to a web application, that runs in a
695            // browser).
696            ApplicationType::Native,
697            // We are going to use the Authorization Code flow.
698            vec![OAuthGrantType::AuthorizationCode {
699                redirect_uris: vec![ipv4_localhost_uri, ipv6_localhost_uri],
700            }],
701            client_uri,
702        )
703    };
704
705    Raw::new(&metadata).expect("Couldn't serialize client metadata")
706}
707
708/// Open the authorization URL and wait for it to be complete.
709///
710/// Returns the code to obtain the access token.
711async fn use_auth_url(url: &Url, server_handle: LocalServerRedirectHandle) -> Option<QueryString> {
712    println!("\nPlease authenticate yourself at: {url}\n");
713    println!("Then proceed to the authorization.\n");
714
715    server_handle.await
716}
717
718/// Handle room messages.
719async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
720    // We only want to log text messages in joined rooms.
721    if room.state() != RoomState::Joined {
722        return;
723    }
724    let MessageType::Text(text_content) = &event.content.msgtype else { return };
725
726    let room_name = match room.display_name().await {
727        Ok(room_name) => room_name.to_string(),
728        Err(error) => {
729            println!("Error getting room display name: {error}");
730            // Let's fallback to the room ID.
731            room.room_id().to_string()
732        }
733    };
734
735    println!("[{room_name}] {}: {}", event.sender, text_content.body)
736}