1use std::{
16 any::type_name,
17 borrow::Cow,
18 fmt::Debug,
19 num::NonZeroUsize,
20 sync::{
21 Arc,
22 atomic::{AtomicU64, Ordering},
23 },
24 time::Duration,
25};
26
27use bytes::{Bytes, BytesMut};
28use bytesize::ByteSize;
29use eyeball::SharedObservable;
30use http::Method;
31use matrix_sdk_base::SendOutsideWasm;
32use ruma::api::{
33 OutgoingRequest, SupportedVersions,
34 auth_scheme::{self, AuthScheme, SendAccessToken},
35 error::{FromHttpResponseError, IntoHttpError},
36 path_builder,
37};
38use tokio::sync::{Semaphore, SemaphorePermit};
39use tracing::{debug, error, field::debug, instrument, trace};
40
41use crate::{HttpResult, client::caches::CachedValue, config::RequestConfig, error::HttpError};
42
43#[cfg(not(target_family = "wasm"))]
44mod native;
45#[cfg(target_family = "wasm")]
46mod wasm;
47
48#[cfg(not(target_family = "wasm"))]
49pub(crate) use native::HttpSettings;
50
/// Default timeout for a single HTTP request: thirty seconds.
///
/// NOTE(review): presumably applied by the platform-specific (`native`/`wasm`)
/// clients when no other timeout is configured — confirm there.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
52
/// Optional limiter on the number of requests in flight at the same time.
///
/// `None` means no limit is enforced. The semaphore sits behind an [`Arc`], so
/// every clone of this value (and of an `HttpClient` holding it) shares the
/// same limit.
#[derive(Clone, Debug)]
struct MaybeSemaphore(Arc<Option<Semaphore>>);
55
56#[allow(dead_code)] struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
58
59impl MaybeSemaphore {
60 fn new(max: Option<NonZeroUsize>) -> Self {
61 let inner = max.map(|i| Semaphore::new(i.into()));
62 MaybeSemaphore(Arc::new(inner))
63 }
64
65 async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
66 match self.0.as_ref() {
67 Some(inner) => {
68 MaybeSemaphorePermit(inner.acquire().await.ok())
71 }
72 None => MaybeSemaphorePermit(None),
73 }
74 }
75}
76
/// Wrapper around a [`reqwest::Client`] that serializes ruma requests, applies a
/// default [`RequestConfig`], optionally caps the number of concurrent requests
/// and tags every request with an identifier for tracing.
#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
    /// The underlying HTTP client performing the requests.
    pub(crate) inner: reqwest::Client,
    /// Configuration used by `send` when the caller does not provide one.
    pub(crate) request_config: RequestConfig,
    /// Concurrency limiter built from `request_config.max_concurrent_requests`;
    /// shared between clones.
    concurrent_request_semaphore: MaybeSemaphore,
    /// Counter used to build `REQ-<n>` request IDs; shared between clones via
    /// the [`Arc`].
    next_request_id: Arc<AtomicU64>,
}
84
85impl HttpClient {
86 pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
87 HttpClient {
88 inner,
89 request_config,
90 concurrent_request_semaphore: MaybeSemaphore::new(
91 request_config.max_concurrent_requests,
92 ),
93 next_request_id: AtomicU64::new(0).into(),
94 }
95 }
96
97 fn get_request_id(&self) -> String {
98 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
99 format!("REQ-{request_id}")
100 }
101
102 fn serialize_request<R>(
103 &self,
104 request: R,
105 config: RequestConfig,
106 homeserver: String,
107 access_token: Option<&str>,
108 path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
109 ) -> Result<http::Request<Bytes>, IntoHttpError>
110 where
111 R: OutgoingRequest + Debug,
112 R::Authentication: SupportedAuthScheme,
113 {
114 trace!(request_type = type_name::<R>(), "Serializing request");
115
116 let send_access_token = match access_token {
117 Some(access_token) => match (config.force_auth, config.skip_auth) {
118 (true, true) | (true, false) => SendAccessToken::Always(access_token),
119 (false, true) => SendAccessToken::None,
120 (false, false) => SendAccessToken::IfRequired(access_token),
121 },
122 None => SendAccessToken::None,
123 };
124 let authentication_input = R::Authentication::authentication_input(send_access_token);
125
126 let request = request
127 .try_into_http_request::<BytesMut>(
128 &homeserver,
129 authentication_input,
130 path_builder_input,
131 )?
132 .map(|body| body.freeze());
133
134 Ok(request)
135 }
136
137 #[allow(clippy::too_many_arguments)]
138 #[instrument(
139 skip(self, request, config, homeserver, access_token, path_builder_input, send_progress),
140 fields(
141 uri,
142 method,
143 request_id,
144 request_size,
145 request_duration,
146 status,
147 response_size,
148 sentry_event_id
149 )
150 )]
151 pub async fn send<R>(
152 &self,
153 request: R,
154 config: Option<RequestConfig>,
155 homeserver: String,
156 access_token: Option<&str>,
157 path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
158 send_progress: SharedObservable<TransmissionProgress>,
159 ) -> Result<R::IncomingResponse, HttpError>
160 where
161 R: OutgoingRequest + Debug,
162 R::Authentication: SupportedAuthScheme,
163 HttpError: From<FromHttpResponseError<R::EndpointError>>,
164 {
165 let config = match config {
166 Some(config) => config,
167 None => self.request_config,
168 };
169
170 let request = {
173 let request_id = self.get_request_id();
174 let span = tracing::Span::current();
175
176 span.record("config", debug(config)).record("request_id", request_id);
179
180 let request = self
181 .serialize_request(request, config, homeserver, access_token, path_builder_input)
182 .map_err(HttpError::IntoHttp)?;
183
184 let method = request.method();
185
186 let mut uri_parts = request.uri().clone().into_parts();
187
188 if let Some(path_and_query) = &mut uri_parts.path_and_query {
191 *path_and_query =
192 path_and_query.path().try_into().expect("path is valid PathAndQuery");
193 }
194
195 let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
196
197 span.record("method", debug(method)).record("uri", uri.to_string());
198
199 if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
202 let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
203 span.record(
204 "request_size",
205 ByteSize(request_size).display().si_short().to_string(),
206 );
207 }
208
209 request
210 };
211
212 let _handle = self.concurrent_request_semaphore.acquire().await;
214
215 match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
218 Ok(response) => {
219 debug!("Got response");
220 Ok(response)
221 }
222 Err(e) => {
223 error!("Error while sending request: {e:?}");
224 Err(e)
225 }
226 }
227 }
228}
229
/// Progress of a payload transfer, as published through the
/// `SharedObservable<TransmissionProgress>` handed to `HttpClient::send`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TransmissionProgress {
    /// Amount transferred so far (presumably in bytes — confirm against the
    /// native/wasm senders that update it).
    pub current: usize,
    /// Total amount to transfer, in the same unit as `current`.
    pub total: usize,
}
238
239async fn response_to_http_response(
240 mut response: reqwest::Response,
241) -> Result<http::Response<Bytes>, reqwest::Error> {
242 let status = response.status();
243
244 let mut http_builder = http::Response::builder().status(status);
245 let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
246
247 for (k, v) in response.headers_mut().drain() {
248 if let Some(key) = k {
249 headers.insert(key, v);
250 }
251 }
252
253 let body = response.bytes().await?;
254
255 Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
256}
257
/// An [`AuthScheme`] for which this client knows how to build the
/// authentication input from its access-token decision.
pub trait SupportedAuthScheme: AuthScheme {
    /// Converts the client's decision about sending the access token into the
    /// input this scheme expects when the HTTP request is built.
    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_>;
}
266
267impl SupportedAuthScheme for auth_scheme::NoAccessToken {
268 fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
269 access_token
270 }
271}
272
273impl SupportedAuthScheme for auth_scheme::AccessToken {
274 fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
275 access_token
276 }
277}
278
279impl SupportedAuthScheme for auth_scheme::AccessTokenOptional {
280 fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
281 access_token
282 }
283}
284
285impl SupportedAuthScheme for auth_scheme::AppserviceToken {
286 fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
287 access_token
288 }
289}
290
291impl SupportedAuthScheme for auth_scheme::AppserviceTokenOptional {
292 fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
293 access_token
294 }
295}
296
297impl SupportedAuthScheme for auth_scheme::NoAuthentication {
298 fn authentication_input(_access_token: SendAccessToken<'_>) -> Self::Input<'_> {}
299}
300
/// A [`path_builder::PathBuilder`] for which this client knows how to obtain
/// the input used to select the endpoint path.
pub trait SupportedPathBuilder: path_builder::PathBuilder {
    /// Computes the path builder input for `client`.
    ///
    /// `skip_auth` tells the implementation whether the request being built
    /// skips authentication (presumably `RequestConfig::skip_auth` — confirm at
    /// the call sites).
    fn get_path_builder_input(
        client: &crate::Client,
        skip_auth: bool,
    ) -> impl Future<Output = HttpResult<Self::Input<'static>>> + SendOutsideWasm;
}
312
impl SupportedPathBuilder for path_builder::VersionHistory {
    /// Resolves the server's supported versions, which select the endpoint path.
    ///
    /// - Without a valid access token: use the cached versions if the cache can be
    ///   read without waiting, otherwise fetch them (this function does not write
    ///   the fetched value to the cache itself).
    /// - With `skip_auth`: use the cached versions if any, otherwise fetch them
    ///   with a dedicated config (`retry_limit(5)`, auth skipped).
    /// - Otherwise: defer to the client's regular `supported_versions_inner`.
    async fn get_path_builder_input(
        client: &crate::Client,
        skip_auth: bool,
    ) -> HttpResult<Cow<'static, SupportedVersions>> {
        // NOTE(review): the leading `true` passed to the helpers below is
        // presumably a "failsafe" flag — confirm against their definitions.
        if !client.auth_ctx().has_valid_access_token() {
            // `try_read` does not wait for the lock: if it is unavailable
            // (presumably because the cache is being refreshed), fall through to a
            // fetch. The read guard is a temporary of this `if let`, so it is
            // dropped before the `.await` below — keep it that way when editing.
            if let Ok(CachedValue::Cached(versions)) =
                client.inner.caches.supported_versions.try_read().as_deref()
            {
                return Ok(Cow::Owned(versions.clone()));
            }

            // No config override (`None`) here, unlike the `skip_auth` branch.
            let response = client.fetch_server_versions_inner(true, None).await?;

            Ok(Cow::Owned(response.as_supported_versions()))
        } else if skip_auth {
            let cached_versions = client.get_cached_supported_versions().await;

            let versions = if let Some(versions) = cached_versions {
                versions
            } else {
                // Nothing cached: fetch with an explicit config that skips auth and
                // sets `retry_limit(5)`.
                let request_config = RequestConfig::default().retry_limit(5).skip_auth();
                let response =
                    client.fetch_server_versions_inner(true, Some(request_config)).await?;

                response.as_supported_versions()
            };

            Ok(Cow::Owned(versions))
        } else {
            // Authenticated request: use the client's regular lookup.
            client.supported_versions_inner(true).await.map(Cow::Owned)
        }
    }
}
358
359impl SupportedPathBuilder for path_builder::SinglePath {
360 async fn get_path_builder_input(_client: &crate::Client, _skip_auth: bool) -> HttpResult<()> {
361 Ok(())
362 }
363}
364
/// Helpers to select and install the `rustls` crypto provider used for TLS.
#[cfg(feature = "rustls-tls")]
pub mod rustls {
    use rustls::crypto::CryptoProvider;

    /// Returns the crypto provider this crate uses by default, backed by `ring`.
    pub fn default_crypto_provider() -> CryptoProvider {
        rustls::crypto::ring::default_provider()
    }

    /// Installs [`default_crypto_provider`] as rustls' default provider unless a
    /// default has already been installed, in which case the existing one is
    /// kept.
    pub fn install_default_crypto_provider_if_none_installed() {
        // `install_default` fails when a default provider already exists; that
        // case is intentionally ignored.
        if default_crypto_provider().install_default().is_ok() {
            // Only logged outside of test builds.
            #[cfg(not(test))]
            tracing::info!("No rustls crypto provider set, setting default provider to ring.");
        }
    }

    /// Expands to a `#[ctor::ctor]` function that calls
    /// [`install_default_crypto_provider_if_none_installed`] when the test binary
    /// starts. On wasm targets the generated function is compiled out.
    #[cfg(test)]
    macro_rules! install_default_crypto_provider_for_tests {
        () => {
            #[cfg(not(target_family = "wasm"))]
            #[ctor::ctor]
            fn install_default_crypto_provider_for_tests() {
                $crate::http_client::rustls::install_default_crypto_provider_if_none_installed();
            }
        };
    }

    // Re-exported so other test modules in the crate can invoke the macro by path.
    #[cfg(test)]
    pub(crate) use install_default_crypto_provider_for_tests;
}
414
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            Arc,
            atomic::{AtomicU8, Ordering},
        },
        time::Duration,
    };

    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        Mock, Request, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// Polls `counter` until it reaches at least `target` or `limit` elapses.
    ///
    /// Hitting the time limit is not an error here: the caller's assertion
    /// reports the value that was actually observed.
    async fn wait_for_count(counter: &AtomicU8, target: u8, limit: Duration) {
        let reached = async {
            while counter.load(Ordering::SeqCst) < target {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        };
        let _ = tokio::time::timeout(limit, reached).await;
    }

    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let counter = Arc::new(AtomicU8::new(0));
        let inner_counter = counter.clone();

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                inner_counter.fetch_add(1, Ordering::SeqCst);
                // Never answer in time, so every request keeps its permit.
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // First wait until the limit is reached (tolerating a slow start), then
        // leave time for any request that would wrongly bypass the limit to arrive.
        wait_for_count(&counter, 5, Duration::from_secs(5)).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        bg_task.abort();
    }

    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let counter = Arc::new(AtomicU8::new(0));
        let inner_counter = counter.clone();

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                inner_counter.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // Poll instead of sleeping a fixed amount: every request should get
        // through, however long the scheduler takes to start them all.
        wait_for_count(&counter, 254, Duration::from_secs(10)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
        bg_task.abort();
    }
}