matrix_sdk/utils/
local_server.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A server that binds to a random port on localhost and waits for a `GET` HTTP
16//! request.
17//!
18//! # Example
19//!
20//! ```no_run
21//! # async {
22//! use matrix_sdk::utils::local_server::LocalServerBuilder;
23//! # let open_uri = |uri: url::Url| {};
24//! # let parse_query_string = |query: &str| {};
25//!
26//! let (uri, server_handle) = LocalServerBuilder::new().spawn().await?;
27//!
28//! open_uri(uri);
29//!
30//! if let Some(query_string) = server_handle.await {
31//!     parse_query_string(&query_string);
32//! }
33//!
34//! # anyhow::Ok(()) };
35//! ```
36
37use std::{
38    convert::Infallible,
39    fmt,
40    future::IntoFuture,
41    io,
42    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
43    ops::{Deref, Range},
44    sync::Arc,
45};
46
47use axum::{body::Body, response::IntoResponse, routing::any_service};
48use http::{header, HeaderValue, Method, Request, StatusCode};
49use matrix_sdk_base::{boxed_into_future, locks::Mutex};
50use matrix_sdk_common::executor::spawn;
51use rand::{thread_rng, Rng};
52use tokio::{net::TcpListener, sync::oneshot};
53use tower::service_fn;
54use url::Url;
55
/// Ports the server picks from at random when no custom range is configured.
///
/// The end of the range is exclusive.
const DEFAULT_PORT_RANGE: Range<u16> = 20_000..30_000;
/// How many times the server tries to bind to a random port when no custom
/// value is configured.
const DEFAULT_BIND_TRIES: u8 = 10;
61
62/// Builder for a server that binds on a random port on localhost and waits for
63/// a `GET` HTTP request.
64///
65/// The server is spawned when calling [`LocalServerBuilder::spawn()`].
66///
67/// The query string of the URI where the end-user is redirected is available by
68/// `.await`ing the [`LocalServerRedirectHandle`], in case it should receive
69/// parameters.
70///
71/// The server is shutdown when [`LocalServerRedirectHandle`] is dropped. It can
72/// also be shutdown manually with a [`LocalServerShutdownHandle`] obtained from
73/// [`LocalServerRedirectHandle::shutdown_handle()`].
74#[derive(Debug, Default, Clone)]
75pub struct LocalServerBuilder {
76    ip_address: Option<LocalServerIpAddress>,
77    port_range: Option<Range<u16>>,
78    bind_tries: Option<u8>,
79    response: Option<LocalServerResponse>,
80}
81
82impl LocalServerBuilder {
83    /// Construct a [`LocalServerBuilder`] using the default settings.
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Set the IP address that the server should to bind to.
89    ///
90    /// Defaults to [`LocalServerIpAddress::LocalhostAny`].
91    pub fn ip_address(mut self, ip_address: LocalServerIpAddress) -> Self {
92        self.ip_address = Some(ip_address);
93        self
94    }
95
96    /// The default range of ports the server will try to bind to randomly.
97    ///
98    /// Care should be taken not to bind to a [port blocked by browsers].
99    ///
100    /// Defaults to ports in the `20000..30000` range.
101    ///
102    /// [port blocked by browsers]: https://fetch.spec.whatwg.org/#port-blocking
103    pub fn port_range(mut self, range: Range<u16>) -> Self {
104        self.port_range = Some(range);
105        self
106    }
107
108    /// The number of times the server will try to bind to a random port on
109    /// localhost.
110    ///
111    /// Since random ports might already be taken, this setting allows to try to
112    /// bind to several random ports before giving up.
113    ///
114    /// Defaults to `10`.
115    pub fn bind_tries(mut self, tries: u8) -> Self {
116        self.bind_tries = Some(tries);
117        self
118    }
119
120    /// Set the content of the page that the end user will see when they a
121    /// redirected to the server's URI.
122    ///
123    /// Defaults to a plain text page with a generic message.
124    pub fn response(mut self, response: LocalServerResponse) -> Self {
125        self.response = Some(response);
126        self
127    }
128
129    /// Spawn the server.
130    ///
131    /// Returns the [`Url`] where the server is listening, and a
132    /// [`LocalServerRedirectHandle`] to `await` the redirect or to shutdown
133    /// the server. Returns an error if the server could not be bound to a port
134    /// on localhost.
135    pub async fn spawn(self) -> Result<(Url, LocalServerRedirectHandle), io::Error> {
136        let Self { ip_address, port_range, bind_tries, response } = self;
137
138        // Bind a TCP listener to a random port.
139        let listener = {
140            let ip_addresses = ip_address.unwrap_or_default().ip_addresses();
141            let port_range = port_range.unwrap_or(DEFAULT_PORT_RANGE);
142            let bind_tries = bind_tries.unwrap_or(DEFAULT_BIND_TRIES);
143            let mut n = 0u8;
144
145            loop {
146                let port = thread_rng().gen_range(port_range.clone());
147                let socket_addresses =
148                    ip_addresses.iter().map(|ip| SocketAddr::new(*ip, port)).collect::<Vec<_>>();
149
150                match TcpListener::bind(socket_addresses.as_slice()).await {
151                    Ok(l) => {
152                        break l;
153                    }
154                    Err(_) if n < bind_tries => {
155                        n += 1;
156                    }
157                    Err(e) => {
158                        return Err(e);
159                    }
160                }
161            }
162        };
163
164        let socket_address =
165            listener.local_addr().expect("bound TCP listener should have an address");
166        let uri = Url::parse(&format!("http://{socket_address}/"))
167            .expect("socket address should parse as a URI host");
168
169        // The channel used to shutdown the server when we are done with it.
170        let (shutdown_signal_sender, shutdown_signal_receiver) = oneshot::channel::<()>();
171        // The channel used to transmit the data received a the redirect URL.
172        let (data_sender, data_receiver) = oneshot::channel::<Option<QueryString>>();
173        let data_sender_mutex = Arc::new(Mutex::new(Some(data_sender)));
174
175        // Set up the server.
176        let router = any_service(service_fn(move |request: Request<_>| {
177            let data_sender_mutex = data_sender_mutex.clone();
178            let response = response.clone();
179
180            async move {
181                // Reject methods others than HEAD or GET.
182                if request.method() != Method::HEAD && request.method() != Method::GET {
183                    return Ok::<_, Infallible>(StatusCode::METHOD_NOT_ALLOWED.into_response());
184                }
185
186                // We only need to get the first response so we consume the transmitter the
187                // first time.
188                if let Some(data_sender) = data_sender_mutex.lock().take() {
189                    let _ =
190                        data_sender.send(request.uri().query().map(|s| QueryString(s.to_owned())));
191                }
192
193                Ok(response.unwrap_or_default().into_response())
194            }
195        }));
196
197        let server = axum::serve(listener, router)
198            .with_graceful_shutdown(async {
199                shutdown_signal_receiver.await.ok();
200            })
201            .into_future();
202
203        // Spawn the server.
204        spawn(server);
205
206        Ok((
207            uri,
208            LocalServerRedirectHandle {
209                data_receiver: Some(data_receiver),
210                shutdown_signal_sender: Arc::new(Mutex::new(Some(shutdown_signal_sender))),
211            },
212        ))
213    }
214}
215
216/// A handle to wait for the end-user to be redirected to a server spawned by
217/// [`LocalServerBuilder`].
218///
219/// Constructed with [`LocalServerBuilder::spawn()`].
220///
221/// `await`ing this type returns the query string of the URI where the end-user
222/// is redirected.
223///
224/// The server is shutdown when this handle is dropped. It can also be shutdown
225/// manually with a [`LocalServerShutdownHandle`] obtained from
226/// [`LocalServerRedirectHandle::shutdown_handle()`].
227#[allow(missing_debug_implementations)]
228pub struct LocalServerRedirectHandle {
229    /// The receiver to receive the query string.
230    data_receiver: Option<oneshot::Receiver<Option<QueryString>>>,
231
232    /// The sender used to send the signal to shutdown the server.
233    shutdown_signal_sender: Arc<Mutex<Option<oneshot::Sender<()>>>>,
234}
235
236impl LocalServerRedirectHandle {
237    /// Get a [`LocalServerShutdownHandle`].
238    pub fn shutdown_handle(&self) -> LocalServerShutdownHandle {
239        LocalServerShutdownHandle(self.shutdown_signal_sender.clone())
240    }
241}
242
243impl Drop for LocalServerRedirectHandle {
244    fn drop(&mut self) {
245        if let Some(sender) = self.shutdown_signal_sender.lock().take() {
246            let _ = sender.send(());
247        }
248    }
249}
250
251impl IntoFuture for LocalServerRedirectHandle {
252    type Output = Option<QueryString>;
253    boxed_into_future!();
254
255    fn into_future(self) -> Self::IntoFuture {
256        Box::pin(async move {
257            let mut this = self;
258
259            let data_receiver =
260                this.data_receiver.take().expect("data receiver is set during construction");
261            data_receiver.await.ok().flatten()
262        })
263    }
264}
265
266#[cfg(not(tarpaulin_include))]
267impl fmt::Debug for LocalServerRedirectHandle {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        f.debug_struct("LocalServerRedirectHandle").finish_non_exhaustive()
270    }
271}
272
273/// A handle to shutdown a server spawned by [`LocalServerBuilder`].
274///
275/// Constructed with [`LocalServerRedirectHandle::shutdown_handle()`].
276///
277/// Calling [`LocalServerShutdownHandle::shutdown()`] will shutdown the
278/// server before the end-user is redirected to it.
279#[derive(Clone)]
280#[allow(missing_debug_implementations)]
281pub struct LocalServerShutdownHandle(Arc<Mutex<Option<oneshot::Sender<()>>>>);
282
283impl LocalServerShutdownHandle {
284    /// Shutdown the local redirect server.
285    ///
286    /// This is a noop if the server was already shutdown.
287    pub fn shutdown(self) {
288        if let Some(sender) = self.0.lock().take() {
289            let _ = sender.send(());
290        }
291    }
292}
293
294#[cfg(not(tarpaulin_include))]
295impl fmt::Debug for LocalServerShutdownHandle {
296    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297        f.debug_struct("LocalServerShutdownHandle").finish_non_exhaustive()
298    }
299}
300
/// Which IP address the [`LocalServerBuilder`] should bind the server to.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum LocalServerIpAddress {
    /// Use the IPv4 localhost address, `127.0.0.1`.
    Localhostv4,

    /// Use the IPv6 localhost address, `::1`.
    Localhostv6,

    /// Use localhost, on either an IPv4 or an IPv6 address.
    ///
    /// This is the default value.
    #[default]
    LocalhostAny,

    /// Use a custom IP address.
    Custom(IpAddr),
}

impl LocalServerIpAddress {
    /// List the candidate addresses for this setting.
    ///
    /// For [`LocalServerIpAddress::LocalhostAny`] the IPv4 address comes
    /// before the IPv6 one.
    fn ip_addresses(self) -> Vec<IpAddr> {
        let localhost_v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let localhost_v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);

        match self {
            Self::Localhostv4 => vec![localhost_v4],
            Self::Localhostv6 => vec![localhost_v6],
            Self::LocalhostAny => vec![localhost_v4, localhost_v6],
            Self::Custom(custom) => vec![custom],
        }
    }
}
331
/// The page shown to the end user when they are redirected to the local
/// server's URI.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalServerResponse {
    /// A body served as plain text.
    PlainText(String),

    /// A body served as HTML.
    Html(String),
}
342
343impl LocalServerResponse {
344    /// Convert this body into an HTTP response.
345    fn into_response(self) -> http::Response<Body> {
346        let (content_type, body) = match self {
347            Self::PlainText(body) => {
348                (HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), body)
349            }
350            Self::Html(body) => (HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), body),
351        };
352
353        let mut response = Body::from(body).into_response();
354        response.headers_mut().insert(header::CONTENT_TYPE, content_type);
355
356        response
357    }
358}
359
360impl Default for LocalServerResponse {
361    fn default() -> Self {
362        LocalServerResponse::PlainText(
363            "The authorization step is complete. You can close this page.".to_owned(),
364        )
365    }
366}
367
/// The query string of a URI.
///
/// This newtype only exists to give a `String` holding a query string a
/// dedicated type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryString(pub String);

impl AsRef<str> for QueryString {
    /// Borrow the query string as a `str`.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl Deref for QueryString {
    type Target = str;

    /// Dereference to the inner query string.
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}
387
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use http::header;
    use matrix_sdk_test::async_test;

    use crate::{
        assert_let_timeout,
        utils::local_server::{LocalServerBuilder, LocalServerIpAddress, LocalServerResponse},
    };

    /// The query string appended to the server's URI in the tests.
    const QUERY: &str = "foo=bar";

    /// Send a `GET` request to the given URI with a fresh HTTP client.
    async fn send_get(uri: &str) -> Result<reqwest::Response, reqwest::Error> {
        reqwest::Client::new().get(uri).send().await
    }

    #[async_test]
    async fn test_local_server_builder_no_query() {
        let (uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();

        send_get(uri.as_str()).await.unwrap();

        // Without a query in the request, the handle resolves to `None`.
        assert_let_timeout!(None = server_handle);
    }

    #[async_test]
    async fn test_local_server_builder_with_query() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        send_get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, QUERY);
    }

    #[async_test]
    async fn test_local_server_builder_with_ipv4_and_port() {
        let builder = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv4)
            .port_range(3000..3001)
            .bind_tries(1);
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(3000));

        send_get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, QUERY);
    }

    #[async_test]
    async fn test_local_server_builder_with_ipv6_and_port() {
        let builder = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv6)
            .port_range(10000..10001)
            .bind_tries(1);
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        assert_eq!(uri.host_str(), Some("[::1]"));
        assert_eq!(uri.port(), Some(10000));

        send_get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, QUERY);
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_ip_and_port() {
        let builder = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Custom(Ipv4Addr::new(127, 0, 0, 1).into()))
            .port_range(10040..10041)
            .bind_tries(1);
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(10040));

        send_get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, QUERY);
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_plain_text_response() {
        let text = "Hello world!";
        let builder =
            LocalServerBuilder::new().response(LocalServerResponse::PlainText(text.to_owned()));
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        let response = send_get(uri.as_str()).await.unwrap();

        // The custom body must be served as plain text.
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), text);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, QUERY);
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_html_response() {
        let html = "<html><body><h1>Hello world!</h1></body></html>";
        let builder =
            LocalServerBuilder::new().response(LocalServerResponse::Html(html.to_owned()));
        let (mut uri, server_handle) = builder.spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        let response = send_get(uri.as_str()).await.unwrap();

        // The custom body must be served as HTML.
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/html; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), html);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, QUERY);
    }

    #[async_test]
    async fn test_local_server_builder_early_shutdown() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some(QUERY));

        // Shut the server down before any request reaches it.
        server_handle.shutdown_handle().shutdown();

        // The request is expected to fail.
        send_get(uri.as_str()).await.unwrap_err();

        assert_let_timeout!(None = server_handle);
    }
}
539}