1use std::{
38 convert::Infallible,
39 fmt,
40 future::IntoFuture,
41 io,
42 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
43 ops::{Deref, Range},
44 sync::Arc,
45};
46
47use axum::{body::Body, response::IntoResponse, routing::any_service};
48use http::{header, HeaderValue, Method, Request, StatusCode};
49use matrix_sdk_base::{boxed_into_future, locks::Mutex};
50use matrix_sdk_common::executor::spawn;
51use rand::{thread_rng, Rng};
52use tokio::{net::TcpListener, sync::oneshot};
53use tower::service_fn;
54use url::Url;
55
/// Range a random port is drawn from when the builder has no explicit range.
const DEFAULT_PORT_RANGE: Range<u16> = Range { start: 20000, end: 30000 };
/// Number of bind retries when the builder has no explicit count.
const DEFAULT_BIND_TRIES: u8 = 10;
61
62#[derive(Debug, Default, Clone)]
75pub struct LocalServerBuilder {
76 ip_address: Option<LocalServerIpAddress>,
77 port_range: Option<Range<u16>>,
78 bind_tries: Option<u8>,
79 response: Option<LocalServerResponse>,
80}
81
82impl LocalServerBuilder {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn ip_address(mut self, ip_address: LocalServerIpAddress) -> Self {
92 self.ip_address = Some(ip_address);
93 self
94 }
95
96 pub fn port_range(mut self, range: Range<u16>) -> Self {
104 self.port_range = Some(range);
105 self
106 }
107
108 pub fn bind_tries(mut self, tries: u8) -> Self {
116 self.bind_tries = Some(tries);
117 self
118 }
119
120 pub fn response(mut self, response: LocalServerResponse) -> Self {
125 self.response = Some(response);
126 self
127 }
128
129 pub async fn spawn(self) -> Result<(Url, LocalServerRedirectHandle), io::Error> {
136 let Self { ip_address, port_range, bind_tries, response } = self;
137
138 let listener = {
140 let ip_addresses = ip_address.unwrap_or_default().ip_addresses();
141 let port_range = port_range.unwrap_or(DEFAULT_PORT_RANGE);
142 let bind_tries = bind_tries.unwrap_or(DEFAULT_BIND_TRIES);
143 let mut n = 0u8;
144
145 loop {
146 let port = thread_rng().gen_range(port_range.clone());
147 let socket_addresses =
148 ip_addresses.iter().map(|ip| SocketAddr::new(*ip, port)).collect::<Vec<_>>();
149
150 match TcpListener::bind(socket_addresses.as_slice()).await {
151 Ok(l) => {
152 break l;
153 }
154 Err(_) if n < bind_tries => {
155 n += 1;
156 }
157 Err(e) => {
158 return Err(e);
159 }
160 }
161 }
162 };
163
164 let socket_address =
165 listener.local_addr().expect("bound TCP listener should have an address");
166 let uri = Url::parse(&format!("http://{socket_address}/"))
167 .expect("socket address should parse as a URI host");
168
169 let (shutdown_signal_sender, shutdown_signal_receiver) = oneshot::channel::<()>();
171 let (data_sender, data_receiver) = oneshot::channel::<Option<QueryString>>();
173 let data_sender_mutex = Arc::new(Mutex::new(Some(data_sender)));
174
175 let router = any_service(service_fn(move |request: Request<_>| {
177 let data_sender_mutex = data_sender_mutex.clone();
178 let response = response.clone();
179
180 async move {
181 if request.method() != Method::HEAD && request.method() != Method::GET {
183 return Ok::<_, Infallible>(StatusCode::METHOD_NOT_ALLOWED.into_response());
184 }
185
186 if let Some(data_sender) = data_sender_mutex.lock().take() {
189 let _ =
190 data_sender.send(request.uri().query().map(|s| QueryString(s.to_owned())));
191 }
192
193 Ok(response.unwrap_or_default().into_response())
194 }
195 }));
196
197 let server = axum::serve(listener, router)
198 .with_graceful_shutdown(async {
199 shutdown_signal_receiver.await.ok();
200 })
201 .into_future();
202
203 spawn(server);
205
206 Ok((
207 uri,
208 LocalServerRedirectHandle {
209 data_receiver: Some(data_receiver),
210 shutdown_signal_sender: Arc::new(Mutex::new(Some(shutdown_signal_sender))),
211 },
212 ))
213 }
214}
215
216#[allow(missing_debug_implementations)]
228pub struct LocalServerRedirectHandle {
229 data_receiver: Option<oneshot::Receiver<Option<QueryString>>>,
231
232 shutdown_signal_sender: Arc<Mutex<Option<oneshot::Sender<()>>>>,
234}
235
236impl LocalServerRedirectHandle {
237 pub fn shutdown_handle(&self) -> LocalServerShutdownHandle {
239 LocalServerShutdownHandle(self.shutdown_signal_sender.clone())
240 }
241}
242
243impl Drop for LocalServerRedirectHandle {
244 fn drop(&mut self) {
245 if let Some(sender) = self.shutdown_signal_sender.lock().take() {
246 let _ = sender.send(());
247 }
248 }
249}
250
251impl IntoFuture for LocalServerRedirectHandle {
252 type Output = Option<QueryString>;
253 boxed_into_future!();
254
255 fn into_future(self) -> Self::IntoFuture {
256 Box::pin(async move {
257 let mut this = self;
258
259 let data_receiver =
260 this.data_receiver.take().expect("data receiver is set during construction");
261 data_receiver.await.ok().flatten()
262 })
263 }
264}
265
266#[cfg(not(tarpaulin_include))]
267impl fmt::Debug for LocalServerRedirectHandle {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 f.debug_struct("LocalServerRedirectHandle").finish_non_exhaustive()
270 }
271}
272
273#[derive(Clone)]
280#[allow(missing_debug_implementations)]
281pub struct LocalServerShutdownHandle(Arc<Mutex<Option<oneshot::Sender<()>>>>);
282
283impl LocalServerShutdownHandle {
284 pub fn shutdown(self) {
288 if let Some(sender) = self.0.lock().take() {
289 let _ = sender.send(());
290 }
291 }
292}
293
294#[cfg(not(tarpaulin_include))]
295impl fmt::Debug for LocalServerShutdownHandle {
296 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297 f.debug_struct("LocalServerShutdownHandle").finish_non_exhaustive()
298 }
299}
300
/// The IP address(es) a [`LocalServerBuilder`] tries to listen on.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LocalServerIpAddress {
    /// The IPv4 loopback address, `127.0.0.1`.
    Localhostv4,

    /// The IPv6 loopback address, `::1`.
    Localhostv6,

    /// Both loopback addresses, IPv4 first and then IPv6.
    ///
    /// This is the default.
    #[default]
    LocalhostAny,

    /// A custom IP address.
    Custom(IpAddr),
}

impl LocalServerIpAddress {
    /// The candidate addresses to bind to, in the order they should be tried.
    fn ip_addresses(self) -> Vec<IpAddr> {
        let loopback_v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let loopback_v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);

        match self {
            Self::Localhostv4 => vec![loopback_v4],
            Self::Localhostv6 => vec![loopback_v6],
            Self::LocalhostAny => vec![loopback_v4, loopback_v6],
            Self::Custom(address) => vec![address],
        }
    }
}
331
/// The content of the page served to every `GET` or `HEAD` request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LocalServerResponse {
    /// Served with a `text/plain; charset=utf-8` content type.
    PlainText(String),

    /// Served with a `text/html; charset=utf-8` content type.
    Html(String),
}
342
343impl LocalServerResponse {
344 fn into_response(self) -> http::Response<Body> {
346 let (content_type, body) = match self {
347 Self::PlainText(body) => {
348 (HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), body)
349 }
350 Self::Html(body) => (HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), body),
351 };
352
353 let mut response = Body::from(body).into_response();
354 response.headers_mut().insert(header::CONTENT_TYPE, content_type);
355
356 response
357 }
358}
359
360impl Default for LocalServerResponse {
361 fn default() -> Self {
362 LocalServerResponse::PlainText(
363 "The authorization step is complete. You can close this page.".to_owned(),
364 )
365 }
366}
367
/// The query string of the URL of a request received by the local server,
/// without the leading `?`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct QueryString(pub String);

impl AsRef<str> for QueryString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl Deref for QueryString {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
387
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use http::header;
    use matrix_sdk_test::async_test;

    use crate::{
        assert_let_timeout,
        utils::local_server::{LocalServerBuilder, LocalServerIpAddress, LocalServerResponse},
    };

    /// Send a `GET` request to `uri` with a fresh HTTP client.
    async fn get(uri: &str) -> Result<reqwest::Response, reqwest::Error> {
        reqwest::Client::new().get(uri).send().await
    }

    #[async_test]
    async fn test_local_server_builder_no_query() {
        let (uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();

        get(uri.as_str()).await.unwrap();

        // A request without a query resolves the handle to `None`.
        assert_let_timeout!(None = server_handle);
    }

    #[async_test]
    async fn test_local_server_builder_with_query() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_ipv4_and_port() {
        // A one-port range pins the port that gets bound.
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv4)
            .port_range(3000..3001)
            .bind_tries(1)
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(3000));

        get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_ipv6_and_port() {
        // A one-port range pins the port that gets bound.
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv6)
            .port_range(10000..10001)
            .bind_tries(1)
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("[::1]"));
        assert_eq!(uri.port(), Some(10000));

        get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_ip_and_port() {
        // A one-port range pins the port that gets bound.
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Custom(Ipv4Addr::new(127, 0, 0, 1).into()))
            .port_range(10040..10041)
            .bind_tries(1)
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(10040));

        get(uri.as_str()).await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_plain_text_response() {
        let text = "Hello world!";
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .response(LocalServerResponse::PlainText(text.to_owned()))
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        let response = get(uri.as_str()).await.unwrap();

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), text);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_html_response() {
        let html = "<html><body><h1>Hello world!</h1></body></html>";
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .response(LocalServerResponse::Html(html.to_owned()))
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        let response = get(uri.as_str()).await.unwrap();

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/html; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), html);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_early_shutdown() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        server_handle.shutdown_handle().shutdown();

        // Once shut down, the server no longer answers requests...
        get(uri.as_str()).await.unwrap_err();

        // ...and the handle resolves to `None`.
        assert_let_timeout!(None = server_handle);
    }
}