matrix_sdk/utils/
local_server.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A server that binds to a random port on localhost and waits for a `GET` HTTP
16//! request.
17//!
18//! # Example
19//!
20//! ```no_run
21//! # async {
22//! use matrix_sdk::utils::local_server::LocalServerBuilder;
23//! # let open_uri = |uri: url::Url| {};
24//! # let parse_query_string = |query: &str| {};
25//!
26//! let (uri, server_handle) = LocalServerBuilder::new().spawn().await?;
27//!
28//! open_uri(uri);
29//!
30//! if let Some(query_string) = server_handle.await {
31//!     parse_query_string(&query_string);
32//! }
33//!
34//! # anyhow::Ok(()) };
35//! ```
36
37use std::{
38    convert::Infallible,
39    fmt,
40    future::IntoFuture,
41    io,
42    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
43    ops::{Deref, Range},
44    sync::Arc,
45};
46
47use axum::{body::Body, response::IntoResponse, routing::any_service};
48use http::{header, HeaderValue, Method, Request, StatusCode};
49use matrix_sdk_base::{boxed_into_future, locks::Mutex};
50use rand::{thread_rng, Rng};
51use tokio::{net::TcpListener, sync::oneshot};
52use tower::service_fn;
53use url::Url;
54
/// Range of ports from which the server randomly picks one to bind to, unless
/// overridden with [`LocalServerBuilder::port_range()`].
const DEFAULT_PORT_RANGE: Range<u16> = 20_000..30_000;
/// Default value of [`LocalServerBuilder::bind_tries()`].
const DEFAULT_BIND_TRIES: u8 = 10;
60
61/// Builder for a server that binds on a random port on localhost and waits for
62/// a `GET` HTTP request.
63///
64/// The server is spawned when calling [`LocalServerBuilder::spawn()`].
65///
66/// The query string of the URI where the end-user is redirected is available by
67/// `.await`ing the [`LocalServerRedirectHandle`], in case it should receive
68/// parameters.
69///
70/// The server is shutdown when [`LocalServerRedirectHandle`] is dropped. It can
71/// also be shutdown manually with a [`LocalServerShutdownHandle`] obtained from
72/// [`LocalServerRedirectHandle::shutdown_handle()`].
73#[derive(Debug, Default, Clone)]
74pub struct LocalServerBuilder {
75    ip_address: Option<LocalServerIpAddress>,
76    port_range: Option<Range<u16>>,
77    bind_tries: Option<u8>,
78    response: Option<LocalServerResponse>,
79}
80
81impl LocalServerBuilder {
82    /// Construct a [`LocalServerBuilder`] using the default settings.
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Set the IP address that the server should to bind to.
88    ///
89    /// Defaults to [`LocalServerIpAddress::LocalhostAny`].
90    pub fn ip_address(mut self, ip_address: LocalServerIpAddress) -> Self {
91        self.ip_address = Some(ip_address);
92        self
93    }
94
95    /// The default range of ports the server will try to bind to randomly.
96    ///
97    /// Care should be taken not to bind to a [port blocked by browsers].
98    ///
99    /// Defaults to ports in the `20000..30000` range.
100    ///
101    /// [port blocked by browsers]: https://fetch.spec.whatwg.org/#port-blocking
102    pub fn port_range(mut self, range: Range<u16>) -> Self {
103        self.port_range = Some(range);
104        self
105    }
106
107    /// The number of times the server will try to bind to a random port on
108    /// localhost.
109    ///
110    /// Since random ports might already be taken, this setting allows to try to
111    /// bind to several random ports before giving up.
112    ///
113    /// Defaults to `10`.
114    pub fn bind_tries(mut self, tries: u8) -> Self {
115        self.bind_tries = Some(tries);
116        self
117    }
118
119    /// Set the content of the page that the end user will see when they a
120    /// redirected to the server's URI.
121    ///
122    /// Defaults to a plain text page with a generic message.
123    pub fn response(mut self, response: LocalServerResponse) -> Self {
124        self.response = Some(response);
125        self
126    }
127
128    /// Spawn the server.
129    ///
130    /// Returns the [`Url`] where the server is listening, and a
131    /// [`LocalServerRedirectHandle`] to `await` the redirect or to shutdown
132    /// the server. Returns an error if the server could not be bound to a port
133    /// on localhost.
134    pub async fn spawn(self) -> Result<(Url, LocalServerRedirectHandle), io::Error> {
135        let Self { ip_address, port_range, bind_tries, response } = self;
136
137        // Bind a TCP listener to a random port.
138        let listener = {
139            let ip_addresses = ip_address.unwrap_or_default().ip_addresses();
140            let port_range = port_range.unwrap_or(DEFAULT_PORT_RANGE);
141            let bind_tries = bind_tries.unwrap_or(DEFAULT_BIND_TRIES);
142            let mut n = 0u8;
143
144            loop {
145                let port = thread_rng().gen_range(port_range.clone());
146                let socket_addresses =
147                    ip_addresses.iter().map(|ip| SocketAddr::new(*ip, port)).collect::<Vec<_>>();
148
149                match TcpListener::bind(socket_addresses.as_slice()).await {
150                    Ok(l) => {
151                        break l;
152                    }
153                    Err(_) if n < bind_tries => {
154                        n += 1;
155                    }
156                    Err(e) => {
157                        return Err(e);
158                    }
159                }
160            }
161        };
162
163        let socket_address =
164            listener.local_addr().expect("bound TCP listener should have an address");
165        let uri = Url::parse(&format!("http://{socket_address}/"))
166            .expect("socket address should parse as a URI host");
167
168        // The channel used to shutdown the server when we are done with it.
169        let (shutdown_signal_sender, shutdown_signal_receiver) = oneshot::channel::<()>();
170        // The channel used to transmit the data received a the redirect URL.
171        let (data_sender, data_receiver) = oneshot::channel::<Option<QueryString>>();
172        let data_sender_mutex = Arc::new(Mutex::new(Some(data_sender)));
173
174        // Set up the server.
175        let router = any_service(service_fn(move |request: Request<_>| {
176            let data_sender_mutex = data_sender_mutex.clone();
177            let response = response.clone();
178
179            async move {
180                // Reject methods others than HEAD or GET.
181                if request.method() != Method::HEAD && request.method() != Method::GET {
182                    return Ok::<_, Infallible>(StatusCode::METHOD_NOT_ALLOWED.into_response());
183                }
184
185                // We only need to get the first response so we consume the transmitter the
186                // first time.
187                if let Some(data_sender) = data_sender_mutex.lock().take() {
188                    let _ =
189                        data_sender.send(request.uri().query().map(|s| QueryString(s.to_owned())));
190                }
191
192                Ok(response.unwrap_or_default().into_response())
193            }
194        }));
195
196        let server = axum::serve(listener, router)
197            .with_graceful_shutdown(async {
198                shutdown_signal_receiver.await.ok();
199            })
200            .into_future();
201
202        // Spawn the server.
203        tokio::spawn(server);
204
205        Ok((
206            uri,
207            LocalServerRedirectHandle {
208                data_receiver: Some(data_receiver),
209                shutdown_signal_sender: Arc::new(Mutex::new(Some(shutdown_signal_sender))),
210            },
211        ))
212    }
213}
214
215/// A handle to wait for the end-user to be redirected to a server spawned by
216/// [`LocalServerBuilder`].
217///
218/// Constructed with [`LocalServerBuilder::spawn()`].
219///
220/// `await`ing this type returns the query string of the URI where the end-user
221/// is redirected.
222///
223/// The server is shutdown when this handle is dropped. It can also be shutdown
224/// manually with a [`LocalServerShutdownHandle`] obtained from
225/// [`LocalServerRedirectHandle::shutdown_handle()`].
226#[allow(missing_debug_implementations)]
227pub struct LocalServerRedirectHandle {
228    /// The receiver to receive the query string.
229    data_receiver: Option<oneshot::Receiver<Option<QueryString>>>,
230
231    /// The sender used to send the signal to shutdown the server.
232    shutdown_signal_sender: Arc<Mutex<Option<oneshot::Sender<()>>>>,
233}
234
235impl LocalServerRedirectHandle {
236    /// Get a [`LocalServerShutdownHandle`].
237    pub fn shutdown_handle(&self) -> LocalServerShutdownHandle {
238        LocalServerShutdownHandle(self.shutdown_signal_sender.clone())
239    }
240}
241
242impl Drop for LocalServerRedirectHandle {
243    fn drop(&mut self) {
244        if let Some(sender) = self.shutdown_signal_sender.lock().take() {
245            let _ = sender.send(());
246        }
247    }
248}
249
250impl IntoFuture for LocalServerRedirectHandle {
251    type Output = Option<QueryString>;
252    boxed_into_future!();
253
254    fn into_future(self) -> Self::IntoFuture {
255        Box::pin(async move {
256            let mut this = self;
257
258            let data_receiver =
259                this.data_receiver.take().expect("data receiver is set during construction");
260            data_receiver.await.ok().flatten()
261        })
262    }
263}
264
265#[cfg(not(tarpaulin_include))]
266impl fmt::Debug for LocalServerRedirectHandle {
267    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268        f.debug_struct("LocalServerRedirectHandle").finish_non_exhaustive()
269    }
270}
271
272/// A handle to shutdown a server spawned by [`LocalServerBuilder`].
273///
274/// Constructed with [`LocalServerRedirectHandle::shutdown_handle()`].
275///
276/// Calling [`LocalServerShutdownHandle::shutdown()`] will shutdown the
277/// server before the end-user is redirected to it.
278#[derive(Clone)]
279#[allow(missing_debug_implementations)]
280pub struct LocalServerShutdownHandle(Arc<Mutex<Option<oneshot::Sender<()>>>>);
281
282impl LocalServerShutdownHandle {
283    /// Shutdown the local redirect server.
284    ///
285    /// This is a noop if the server was already shutdown.
286    pub fn shutdown(self) {
287        if let Some(sender) = self.0.lock().take() {
288            let _ = sender.send(());
289        }
290    }
291}
292
293#[cfg(not(tarpaulin_include))]
294impl fmt::Debug for LocalServerShutdownHandle {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        f.debug_struct("LocalServerShutdownHandle").finish_non_exhaustive()
297    }
298}
299
/// The IP address the server spawned by [`LocalServerBuilder`] binds to.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LocalServerIpAddress {
    /// The IPv4 loopback address, `127.0.0.1`.
    Localhostv4,

    /// The IPv6 loopback address, `::1`.
    Localhostv6,

    /// A loopback address, either IPv4 or IPv6.
    ///
    /// This is the default.
    #[default]
    LocalhostAny,

    /// Any other IP address.
    Custom(IpAddr),
}

impl LocalServerIpAddress {
    /// The candidate addresses to bind to, in the order they should be tried.
    fn ip_addresses(self) -> Vec<IpAddr> {
        let loopback_v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let loopback_v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);

        match self {
            Self::Localhostv4 => vec![loopback_v4],
            Self::Localhostv6 => vec![loopback_v6],
            Self::LocalhostAny => vec![loopback_v4, loopback_v6],
            Self::Custom(ip) => vec![ip],
        }
    }
}
330
/// The page that the end user sees when they are redirected to the local
/// server's URI.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LocalServerResponse {
    /// A body served with a `text/plain; charset=utf-8` content type.
    PlainText(String),

    /// A body served with a `text/html; charset=utf-8` content type.
    Html(String),
}
341
342impl LocalServerResponse {
343    /// Convert this body into an HTTP response.
344    fn into_response(self) -> http::Response<Body> {
345        let (content_type, body) = match self {
346            Self::PlainText(body) => {
347                (HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), body)
348            }
349            Self::Html(body) => (HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), body),
350        };
351
352        let mut response = Body::from(body).into_response();
353        response.headers_mut().insert(header::CONTENT_TYPE, content_type);
354
355        response
356    }
357}
358
359impl Default for LocalServerResponse {
360    fn default() -> Self {
361        LocalServerResponse::PlainText(
362            "The authorization step is complete. You can close this page.".to_owned(),
363        )
364    }
365}
366
/// The query string of a URI.
///
/// A thin newtype around a `String`, so query strings cannot be confused with
/// other strings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueryString(pub String);

impl AsRef<str> for QueryString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl Deref for QueryString {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
386
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use http::header;
    use matrix_sdk_test::async_test;
    use url::Url;

    use crate::{
        assert_let_timeout,
        utils::local_server::{LocalServerBuilder, LocalServerIpAddress, LocalServerResponse},
    };

    /// Send a `GET` request to `uri` with a fresh HTTP client, panicking if the
    /// request fails.
    async fn send_get(uri: &Url) -> reqwest::Response {
        reqwest::Client::new().get(uri.as_str()).send().await.unwrap()
    }

    #[async_test]
    async fn test_local_server_builder_no_query() {
        let (uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();

        send_get(&uri).await;

        // Without a query string, the handle resolves to `None`.
        assert_let_timeout!(None = server_handle);
    }

    #[async_test]
    async fn test_local_server_builder_with_query() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        send_get(&uri).await;

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    // The tests below pin a single port with a single bind attempt, so they fail
    // if that port is already in use on the machine running them.

    #[async_test]
    async fn test_local_server_builder_with_ipv4_and_port() {
        let builder = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv4)
            .port_range(3000..3001)
            .bind_tries(1);
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(3000));

        send_get(&uri).await;

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_ipv6_and_port() {
        let builder = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv6)
            .port_range(10000..10001)
            .bind_tries(1);
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("[::1]"));
        assert_eq!(uri.port(), Some(10000));

        send_get(&uri).await;

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_ip_and_port() {
        let builder = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Custom(Ipv4Addr::LOCALHOST.into()))
            .port_range(10040..10041)
            .bind_tries(1);
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(10040));

        send_get(&uri).await;

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_plain_text_response() {
        let text = "Hello world!";
        let builder =
            LocalServerBuilder::new().response(LocalServerResponse::PlainText(text.to_owned()));
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        let response = send_get(&uri).await;

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), text);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_html_response() {
        let html = "<html><body><h1>Hello world!</h1></body></html>";
        let builder = LocalServerBuilder::new().response(LocalServerResponse::Html(html.to_owned()));
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        let response = send_get(&uri).await;

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/html; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), html);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_early_shutdown() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        // Shutting down before any request makes the request fail...
        server_handle.shutdown_handle().shutdown();

        reqwest::Client::new().get(uri.as_str()).send().await.unwrap_err();

        // ...and the handle resolve to `None`.
        assert_let_timeout!(None = server_handle);
    }
}