1use std::{
38 convert::Infallible,
39 fmt,
40 future::IntoFuture,
41 io,
42 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
43 ops::{Deref, Range},
44 sync::Arc,
45};
46
47use axum::{body::Body, response::IntoResponse, routing::any_service};
48use http::{header, HeaderValue, Method, Request, StatusCode};
49use matrix_sdk_base::{boxed_into_future, locks::Mutex};
50use rand::{thread_rng, Rng};
51use tokio::{net::TcpListener, sync::oneshot};
52use tower::service_fn;
53use url::Url;
54
/// Range of ports the local server picks a random port from when the builder
/// was not given an explicit range (`start` included, `end` excluded).
const DEFAULT_PORT_RANGE: Range<u16> = Range { start: 20_000, end: 30_000 };

/// Default value for the number of bind attempts of the local server, used
/// when the builder was not given an explicit value.
const DEFAULT_BIND_TRIES: u8 = 10;
60
61#[derive(Debug, Default, Clone)]
74pub struct LocalServerBuilder {
75 ip_address: Option<LocalServerIpAddress>,
76 port_range: Option<Range<u16>>,
77 bind_tries: Option<u8>,
78 response: Option<LocalServerResponse>,
79}
80
81impl LocalServerBuilder {
82 pub fn new() -> Self {
84 Self::default()
85 }
86
87 pub fn ip_address(mut self, ip_address: LocalServerIpAddress) -> Self {
91 self.ip_address = Some(ip_address);
92 self
93 }
94
95 pub fn port_range(mut self, range: Range<u16>) -> Self {
103 self.port_range = Some(range);
104 self
105 }
106
107 pub fn bind_tries(mut self, tries: u8) -> Self {
115 self.bind_tries = Some(tries);
116 self
117 }
118
119 pub fn response(mut self, response: LocalServerResponse) -> Self {
124 self.response = Some(response);
125 self
126 }
127
128 pub async fn spawn(self) -> Result<(Url, LocalServerRedirectHandle), io::Error> {
135 let Self { ip_address, port_range, bind_tries, response } = self;
136
137 let listener = {
139 let ip_addresses = ip_address.unwrap_or_default().ip_addresses();
140 let port_range = port_range.unwrap_or(DEFAULT_PORT_RANGE);
141 let bind_tries = bind_tries.unwrap_or(DEFAULT_BIND_TRIES);
142 let mut n = 0u8;
143
144 loop {
145 let port = thread_rng().gen_range(port_range.clone());
146 let socket_addresses =
147 ip_addresses.iter().map(|ip| SocketAddr::new(*ip, port)).collect::<Vec<_>>();
148
149 match TcpListener::bind(socket_addresses.as_slice()).await {
150 Ok(l) => {
151 break l;
152 }
153 Err(_) if n < bind_tries => {
154 n += 1;
155 }
156 Err(e) => {
157 return Err(e);
158 }
159 }
160 }
161 };
162
163 let socket_address =
164 listener.local_addr().expect("bound TCP listener should have an address");
165 let uri = Url::parse(&format!("http://{socket_address}/"))
166 .expect("socket address should parse as a URI host");
167
168 let (shutdown_signal_sender, shutdown_signal_receiver) = oneshot::channel::<()>();
170 let (data_sender, data_receiver) = oneshot::channel::<Option<QueryString>>();
172 let data_sender_mutex = Arc::new(Mutex::new(Some(data_sender)));
173
174 let router = any_service(service_fn(move |request: Request<_>| {
176 let data_sender_mutex = data_sender_mutex.clone();
177 let response = response.clone();
178
179 async move {
180 if request.method() != Method::HEAD && request.method() != Method::GET {
182 return Ok::<_, Infallible>(StatusCode::METHOD_NOT_ALLOWED.into_response());
183 }
184
185 if let Some(data_sender) = data_sender_mutex.lock().take() {
188 let _ =
189 data_sender.send(request.uri().query().map(|s| QueryString(s.to_owned())));
190 }
191
192 Ok(response.unwrap_or_default().into_response())
193 }
194 }));
195
196 let server = axum::serve(listener, router)
197 .with_graceful_shutdown(async {
198 shutdown_signal_receiver.await.ok();
199 })
200 .into_future();
201
202 tokio::spawn(server);
204
205 Ok((
206 uri,
207 LocalServerRedirectHandle {
208 data_receiver: Some(data_receiver),
209 shutdown_signal_sender: Arc::new(Mutex::new(Some(shutdown_signal_sender))),
210 },
211 ))
212 }
213}
214
/// Handle to a local server spawned with [`LocalServerBuilder::spawn()`].
///
/// Awaiting this handle resolves to the query string of the first `GET` or
/// `HEAD` request received by the server. It resolves to `None` if that request
/// had no query string, or if the server shut down before receiving a request.
/// Dropping the handle shuts the server down.
// A manual `Debug` impl exists, but it is compiled out in tarpaulin coverage
// builds (`cfg(tarpaulin_include)`), hence the lint allowance.
#[allow(missing_debug_implementations)]
pub struct LocalServerRedirectHandle {
    /// Receives the query string of the first accepted request; taken out when
    /// the handle is awaited.
    data_receiver: Option<oneshot::Receiver<Option<QueryString>>>,

    /// Sender of the shutdown signal, shared with every
    /// [`LocalServerShutdownHandle`] created from this handle. `None` once the
    /// signal has been sent.
    shutdown_signal_sender: Arc<Mutex<Option<oneshot::Sender<()>>>>,
}
234
235impl LocalServerRedirectHandle {
236 pub fn shutdown_handle(&self) -> LocalServerShutdownHandle {
238 LocalServerShutdownHandle(self.shutdown_signal_sender.clone())
239 }
240}
241
242impl Drop for LocalServerRedirectHandle {
243 fn drop(&mut self) {
244 if let Some(sender) = self.shutdown_signal_sender.lock().take() {
245 let _ = sender.send(());
246 }
247 }
248}
249
250impl IntoFuture for LocalServerRedirectHandle {
251 type Output = Option<QueryString>;
252 boxed_into_future!();
253
254 fn into_future(self) -> Self::IntoFuture {
255 Box::pin(async move {
256 let mut this = self;
257
258 let data_receiver =
259 this.data_receiver.take().expect("data receiver is set during construction");
260 data_receiver.await.ok().flatten()
261 })
262 }
263}
264
265#[cfg(not(tarpaulin_include))]
266impl fmt::Debug for LocalServerRedirectHandle {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 f.debug_struct("LocalServerRedirectHandle").finish_non_exhaustive()
269 }
270}
271
272#[derive(Clone)]
279#[allow(missing_debug_implementations)]
280pub struct LocalServerShutdownHandle(Arc<Mutex<Option<oneshot::Sender<()>>>>);
281
282impl LocalServerShutdownHandle {
283 pub fn shutdown(self) {
287 if let Some(sender) = self.0.lock().take() {
288 let _ = sender.send(());
289 }
290 }
291}
292
293#[cfg(not(tarpaulin_include))]
294impl fmt::Debug for LocalServerShutdownHandle {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 f.debug_struct("LocalServerShutdownHandle").finish_non_exhaustive()
297 }
298}
299
/// The IP address(es) a local server listens on.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum LocalServerIpAddress {
    /// The IPv4 loopback address, `127.0.0.1`.
    Localhostv4,

    /// The IPv6 loopback address, `::1`.
    Localhostv6,

    /// Both loopback addresses, tried in order: IPv4 first, then IPv6.
    ///
    /// This is the default.
    #[default]
    LocalhostAny,

    /// A custom IP address.
    Custom(IpAddr),
}

impl LocalServerIpAddress {
    /// The IP addresses to try binding to, in order of preference.
    fn ip_addresses(self) -> Vec<IpAddr> {
        let loopback_v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let loopback_v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);

        match self {
            Self::Localhostv4 => vec![loopback_v4],
            Self::Localhostv6 => vec![loopback_v6],
            Self::LocalhostAny => vec![loopback_v4, loopback_v6],
            Self::Custom(custom) => vec![custom],
        }
    }
}
330
/// The response returned by the local server to every accepted (`GET` or
/// `HEAD`) request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LocalServerResponse {
    /// A plain text body, sent with a `text/plain; charset=utf-8` content type.
    PlainText(String),

    /// An HTML body, sent with a `text/html; charset=utf-8` content type.
    Html(String),
}
341
342impl LocalServerResponse {
343 fn into_response(self) -> http::Response<Body> {
345 let (content_type, body) = match self {
346 Self::PlainText(body) => {
347 (HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), body)
348 }
349 Self::Html(body) => (HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), body),
350 };
351
352 let mut response = Body::from(body).into_response();
353 response.headers_mut().insert(header::CONTENT_TYPE, content_type);
354
355 response
356 }
357}
358
359impl Default for LocalServerResponse {
360 fn default() -> Self {
361 LocalServerResponse::PlainText(
362 "The authorization step is complete. You can close this page.".to_owned(),
363 )
364 }
365}
366
/// The query string of the request received by the local server, without the
/// leading `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryString(pub String);

impl AsRef<str> for QueryString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl Deref for QueryString {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
386
#[cfg(test)]
mod tests {
    // Each test spawns a real local server and sends it HTTP requests with
    // `reqwest`, then checks what the redirect handle resolves to.
    use std::net::Ipv4Addr;

    use http::header;
    use matrix_sdk_test::async_test;

    use crate::{
        assert_let_timeout,
        utils::local_server::{LocalServerBuilder, LocalServerIpAddress, LocalServerResponse},
    };

    // A request without a query string resolves the handle to `None`.
    #[async_test]
    async fn test_local_server_builder_no_query() {
        let (uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();

        let http_client = reqwest::Client::new();
        http_client.get(uri.as_str()).send().await.unwrap();

        assert_let_timeout!(None = server_handle);
    }

    // The query string of the request is forwarded to the handle as-is.
    #[async_test]
    async fn test_local_server_builder_with_query() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        let http_client = reqwest::Client::new();
        http_client.get(uri.as_str()).send().await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    // The tests below use a single-port range, so they fail if that port is
    // already in use on the machine running them.
    #[async_test]
    async fn test_local_server_builder_with_ipv4_and_port() {
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv4)
            .port_range(3000..3001)
            .bind_tries(1)
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(3000));

        let http_client = reqwest::Client::new();
        http_client.get(uri.as_str()).send().await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_ipv6_and_port() {
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Localhostv6)
            .port_range(10000..10001)
            .bind_tries(1)
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        // IPv6 hosts are bracketed in URLs.
        assert_eq!(uri.host_str(), Some("[::1]"));
        assert_eq!(uri.port(), Some(10000));

        let http_client = reqwest::Client::new();
        http_client.get(uri.as_str()).send().await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    #[async_test]
    async fn test_local_server_builder_with_custom_ip_and_port() {
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .ip_address(LocalServerIpAddress::Custom(Ipv4Addr::new(127, 0, 0, 1).into()))
            .port_range(10040..10041)
            .bind_tries(1)
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        assert_eq!(uri.host_str(), Some("127.0.0.1"));
        assert_eq!(uri.port(), Some(10040));

        let http_client = reqwest::Client::new();
        http_client.get(uri.as_str()).send().await.unwrap();

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    // A custom plain text response is returned with its body and content type.
    #[async_test]
    async fn test_local_server_builder_with_custom_plain_text_response() {
        let text = "Hello world!";
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .response(LocalServerResponse::PlainText(text.to_owned()))
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        let http_client = reqwest::Client::new();
        let response = http_client.get(uri.as_str()).send().await.unwrap();

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), text);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    // A custom HTML response is returned with its body and content type.
    #[async_test]
    async fn test_local_server_builder_with_custom_html_response() {
        let html = "<html><body><h1>Hello world!</h1></body></html>";
        let (mut uri, server_handle) = LocalServerBuilder::new()
            .response(LocalServerResponse::Html(html.to_owned()))
            .spawn()
            .await
            .unwrap();
        uri.set_query(Some("foo=bar"));

        let http_client = reqwest::Client::new();
        let response = http_client.get(uri.as_str()).send().await.unwrap();

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/html; charset=utf-8");
        assert_eq!(response.text().await.unwrap(), html);

        assert_let_timeout!(Some(query) = server_handle);
        assert_eq!(query.0, "foo=bar");
    }

    // After an explicit shutdown, requests are expected to fail and the handle
    // resolves to `None` because no query string was ever sent.
    #[async_test]
    async fn test_local_server_builder_early_shutdown() {
        let (mut uri, server_handle) = LocalServerBuilder::new().spawn().await.unwrap();
        uri.set_query(Some("foo=bar"));

        server_handle.shutdown_handle().shutdown();

        let http_client = reqwest::Client::new();
        http_client.get(uri.as_str()).send().await.unwrap_err();

        assert_let_timeout!(None = server_handle);
    }
}