1use std::{
16 any::type_name,
17 borrow::Cow,
18 fmt::Debug,
19 num::NonZeroUsize,
20 sync::{
21 Arc,
22 atomic::{AtomicU64, Ordering},
23 },
24 time::Duration,
25};
26
27use bytes::{Bytes, BytesMut};
28use bytesize::ByteSize;
29use eyeball::SharedObservable;
30use http::Method;
31use matrix_sdk_base::SendOutsideWasm;
32use ruma::api::{
33 OutgoingRequest, SupportedVersions,
34 auth_scheme::{AuthScheme, SendAccessToken},
35 error::{FromHttpResponseError, IntoHttpError},
36 path_builder,
37};
38use tokio::sync::{Semaphore, SemaphorePermit};
39use tracing::{debug, field::debug, instrument, trace};
40
41use crate::{HttpResult, config::RequestConfig, error::HttpError};
42
43#[cfg(not(target_family = "wasm"))]
44mod native;
45#[cfg(target_family = "wasm")]
46mod wasm;
47
48#[cfg(not(target_family = "wasm"))]
49pub(crate) use native::HttpSettings;
50
/// Default request timeout: 30 seconds.
///
/// NOTE(review): where this default is applied is not visible in this module —
/// presumably in the transport/`RequestConfig` setup; confirm there.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::new(30, 0);
52
53#[derive(Clone, Debug)]
54struct MaybeSemaphore(Arc<Option<Semaphore>>);
55
56#[allow(dead_code)] struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
58
59impl MaybeSemaphore {
60 fn new(max: Option<NonZeroUsize>) -> Self {
61 let inner = max.map(|i| Semaphore::new(i.into()));
62 MaybeSemaphore(Arc::new(inner))
63 }
64
65 async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
66 match self.0.as_ref() {
67 Some(inner) => {
68 MaybeSemaphorePermit(inner.acquire().await.ok())
71 }
72 None => MaybeSemaphorePermit(None),
73 }
74 }
75}
76
77#[derive(Clone, Debug)]
78pub(crate) struct HttpClient {
79 pub(crate) inner: reqwest::Client,
80 pub(crate) request_config: RequestConfig,
81 concurrent_request_semaphore: MaybeSemaphore,
82 next_request_id: Arc<AtomicU64>,
83}
84
85impl HttpClient {
86 pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
87 HttpClient {
88 inner,
89 request_config,
90 concurrent_request_semaphore: MaybeSemaphore::new(
91 request_config.max_concurrent_requests,
92 ),
93 next_request_id: AtomicU64::new(0).into(),
94 }
95 }
96
97 fn get_request_id(&self) -> String {
98 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
99 format!("REQ-{request_id}")
100 }
101
102 fn serialize_request<R>(
103 &self,
104 request: R,
105 config: RequestConfig,
106 homeserver: String,
107 access_token: Option<&str>,
108 path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
109 ) -> Result<http::Request<Bytes>, IntoHttpError>
110 where
111 R: OutgoingRequest + Debug,
112 for<'a> R::Authentication: AuthScheme<Input<'a> = SendAccessToken<'a>>,
113 {
114 trace!(request_type = type_name::<R>(), "Serializing request");
115
116 let send_access_token = match access_token {
117 Some(access_token) => match (config.force_auth, config.skip_auth) {
118 (true, true) | (true, false) => SendAccessToken::Always(access_token),
119 (false, true) => SendAccessToken::None,
120 (false, false) => SendAccessToken::IfRequired(access_token),
121 },
122 None => SendAccessToken::None,
123 };
124
125 let request = request
126 .try_into_http_request::<BytesMut>(&homeserver, send_access_token, path_builder_input)?
127 .map(|body| body.freeze());
128
129 Ok(request)
130 }
131
132 #[allow(clippy::too_many_arguments)]
133 #[instrument(
134 skip(self, request, config, homeserver, access_token, path_builder_input, send_progress),
135 fields(
136 uri,
137 method,
138 request_id,
139 request_size,
140 request_duration,
141 status,
142 response_size,
143 sentry_event_id
144 )
145 )]
146 pub async fn send<R>(
147 &self,
148 request: R,
149 config: Option<RequestConfig>,
150 homeserver: String,
151 access_token: Option<&str>,
152 path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
153 send_progress: SharedObservable<TransmissionProgress>,
154 ) -> Result<R::IncomingResponse, HttpError>
155 where
156 R: OutgoingRequest + Debug,
157 for<'a> R::Authentication: AuthScheme<Input<'a> = SendAccessToken<'a>>,
158 HttpError: From<FromHttpResponseError<R::EndpointError>>,
159 {
160 let config = match config {
161 Some(config) => config,
162 None => self.request_config,
163 };
164
165 let request = {
168 let request_id = self.get_request_id();
169 let span = tracing::Span::current();
170
171 span.record("config", debug(config)).record("request_id", request_id);
174
175 let request = self
176 .serialize_request(request, config, homeserver, access_token, path_builder_input)
177 .map_err(HttpError::IntoHttp)?;
178
179 let method = request.method();
180
181 let mut uri_parts = request.uri().clone().into_parts();
182 if let Some(path_and_query) = &mut uri_parts.path_and_query {
183 *path_and_query =
184 path_and_query.path().try_into().expect("path is valid PathAndQuery");
185 }
186 let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
187
188 span.record("method", debug(method)).record("uri", uri.to_string());
189
190 if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
193 let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
194 span.record(
195 "request_size",
196 ByteSize(request_size).display().si_short().to_string(),
197 );
198 }
199
200 request
201 };
202
203 let _handle = self.concurrent_request_semaphore.acquire().await;
205
206 match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
209 Ok(response) => {
210 debug!("Got response");
211 Ok(response)
212 }
213 Err(e) => {
214 debug!("Error while sending request: {e:?}");
215 Err(e)
216 }
217 }
218 }
219}
220
/// A snapshot of how far a transmission has progressed.
///
/// NOTE(review): the unit of `current` and `total` is chosen by whoever updates
/// the observable (presumably bytes) — confirm against the transport code.
#[derive(Debug, Default, Clone, Copy)]
pub struct TransmissionProgress {
    /// Amount transmitted so far.
    pub current: usize,

    /// Total amount expected to be transmitted.
    pub total: usize,
}
229
230async fn response_to_http_response(
231 mut response: reqwest::Response,
232) -> Result<http::Response<Bytes>, reqwest::Error> {
233 let status = response.status();
234
235 let mut http_builder = http::Response::builder().status(status);
236 let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
237
238 for (k, v) in response.headers_mut().drain() {
239 if let Some(key) = k {
240 headers.insert(key, v);
241 }
242 }
243
244 let body = response.bytes().await?;
245
246 Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
247}
248
249pub trait SupportedPathBuilder: path_builder::PathBuilder {
255 fn get_path_builder_input(
256 client: &crate::Client,
257 skip_auth: bool,
258 ) -> impl Future<Output = HttpResult<Self::Input<'static>>> + SendOutsideWasm;
259}
260
261impl SupportedPathBuilder for path_builder::VersionHistory {
262 async fn get_path_builder_input(
263 client: &crate::Client,
264 skip_auth: bool,
265 ) -> HttpResult<Cow<'static, SupportedVersions>> {
266 if skip_auth {
267 let cached_versions = client.get_cached_versions().await;
268
269 let versions = if let Some(versions) = cached_versions {
270 versions
271 } else {
272 let request_config = RequestConfig::default().retry_limit(5).skip_auth();
275 let response = client.fetch_server_versions(Some(request_config)).await?;
276
277 response.as_supported_versions()
278 };
279
280 Ok(Cow::Owned(versions))
281 } else {
282 client.supported_versions().await.map(Cow::Owned)
283 }
284 }
285}
286
287impl SupportedPathBuilder for path_builder::SinglePath {
288 async fn get_path_builder_input(_client: &crate::Client, _skip_auth: bool) -> HttpResult<()> {
289 Ok(())
290 }
291}
292
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            Arc,
            atomic::{AtomicU8, Ordering},
        },
        time::Duration,
    };

    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        Mock, Request, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// With `max_concurrent_requests` set to 5, only 5 of 10 concurrent `whoami`
    /// calls may reach the server while the rest wait for a permit.
    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        // Counts how many `whoami` requests actually reached the mock server.
        let counter = Arc::new(AtomicU8::new(0));
        let inner_counter = counter.clone();

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                inner_counter.fetch_add(1, Ordering::SeqCst);
                // Delay far beyond the test's duration so no request completes
                // and frees its permit while we're observing the counter.
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // Give the permitted requests time to reach the server.
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        bg_task.abort();
    }

    /// Without a concurrency limit, all 254 concurrent `whoami` calls reach the
    /// server. 254 stays within the range of the `AtomicU8` counter.
    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        // Counts how many `whoami` requests actually reached the mock server.
        let counter = Arc::new(AtomicU8::new(0));
        let inner_counter = counter.clone();

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                inner_counter.fetch_add(1, Ordering::SeqCst);
                // Keep every request pending for the duration of the test.
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // Give all requests time to reach the server.
        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
        bg_task.abort();
    }
}