1use std::{
16 any::type_name,
17 fmt::Debug,
18 num::NonZeroUsize,
19 sync::{
20 Arc,
21 atomic::{AtomicU64, Ordering},
22 },
23 time::Duration,
24};
25
26use bytes::{Bytes, BytesMut};
27use bytesize::ByteSize;
28use eyeball::SharedObservable;
29use http::Method;
30use ruma::api::{
31 OutgoingRequest, SendAccessToken, SupportedVersions, auth_scheme,
32 error::{FromHttpResponseError, IntoHttpError},
33};
34use tokio::sync::{Semaphore, SemaphorePermit};
35use tracing::{debug, field::debug, instrument, trace};
36
37use crate::{config::RequestConfig, error::HttpError};
38
39#[cfg(not(target_family = "wasm"))]
40mod native;
41#[cfg(target_family = "wasm")]
42mod wasm;
43
44#[cfg(not(target_family = "wasm"))]
45pub(crate) use native::HttpSettings;
46
/// Default request timeout: thirty seconds (expressed in milliseconds).
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
48
49#[derive(Clone, Debug)]
50struct MaybeSemaphore(Arc<Option<Semaphore>>);
51
52#[allow(dead_code)] struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
54
55impl MaybeSemaphore {
56 fn new(max: Option<NonZeroUsize>) -> Self {
57 let inner = max.map(|i| Semaphore::new(i.into()));
58 MaybeSemaphore(Arc::new(inner))
59 }
60
61 async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
62 match self.0.as_ref() {
63 Some(inner) => {
64 MaybeSemaphorePermit(inner.acquire().await.ok())
67 }
68 None => MaybeSemaphorePermit(None),
69 }
70 }
71}
72
73#[derive(Clone, Debug)]
74pub(crate) struct HttpClient {
75 pub(crate) inner: reqwest::Client,
76 pub(crate) request_config: RequestConfig,
77 concurrent_request_semaphore: MaybeSemaphore,
78 next_request_id: Arc<AtomicU64>,
79}
80
81impl HttpClient {
82 pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
83 HttpClient {
84 inner,
85 request_config,
86 concurrent_request_semaphore: MaybeSemaphore::new(
87 request_config.max_concurrent_requests,
88 ),
89 next_request_id: AtomicU64::new(0).into(),
90 }
91 }
92
93 fn get_request_id(&self) -> String {
94 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
95 format!("REQ-{request_id}")
96 }
97
98 fn serialize_request<R>(
99 &self,
100 request: R,
101 config: RequestConfig,
102 homeserver: String,
103 access_token: Option<&str>,
104 supported_versions: &SupportedVersions,
105 ) -> Result<http::Request<Bytes>, IntoHttpError>
106 where
107 R: OutgoingRequest + Debug,
108 {
109 trace!(request_type = type_name::<R>(), "Serializing request");
110
111 let send_access_token = match access_token {
112 Some(access_token) => {
113 if config.force_auth {
114 SendAccessToken::Always(access_token)
115 } else {
116 SendAccessToken::IfRequired(access_token)
117 }
118 }
119 None => SendAccessToken::None,
120 };
121
122 let request = request
123 .try_into_http_request::<BytesMut>(&homeserver, send_access_token, supported_versions)?
124 .map(|body| body.freeze());
125
126 Ok(request)
127 }
128
129 #[allow(clippy::too_many_arguments)]
130 #[instrument(
131 skip(self, request, config, homeserver, access_token, supported_versions, send_progress),
132 fields(
133 uri,
134 method,
135 request_id,
136 request_size,
137 request_duration,
138 status,
139 response_size,
140 sentry_event_id
141 )
142 )]
143 pub async fn send<R>(
144 &self,
145 request: R,
146 config: Option<RequestConfig>,
147 homeserver: String,
148 access_token: Option<&str>,
149 supported_versions: &SupportedVersions,
150 send_progress: SharedObservable<TransmissionProgress>,
151 ) -> Result<R::IncomingResponse, HttpError>
152 where
153 R: OutgoingRequest + Debug,
154 R::Authentication: SupportedAuthScheme,
155 HttpError: From<FromHttpResponseError<R::EndpointError>>,
156 {
157 let config = match config {
158 Some(config) => config,
159 None => self.request_config,
160 };
161
162 let request = {
165 let request_id = self.get_request_id();
166 let span = tracing::Span::current();
167
168 span.record("config", debug(config)).record("request_id", request_id);
171
172 let request = self
173 .serialize_request(request, config, homeserver, access_token, supported_versions)
174 .map_err(HttpError::IntoHttp)?;
175
176 let method = request.method();
177
178 let mut uri_parts = request.uri().clone().into_parts();
179 if let Some(path_and_query) = &mut uri_parts.path_and_query {
180 *path_and_query =
181 path_and_query.path().try_into().expect("path is valid PathAndQuery");
182 }
183 let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
184
185 span.record("method", debug(method)).record("uri", uri.to_string());
186
187 if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
190 let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
191 span.record(
192 "request_size",
193 ByteSize(request_size).display().si_short().to_string(),
194 );
195 }
196
197 request
198 };
199
200 let _handle = self.concurrent_request_semaphore.acquire().await;
202
203 match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
206 Ok(response) => {
207 debug!("Got response");
208 Ok(response)
209 }
210 Err(e) => {
211 debug!("Error while sending request: {e:?}");
212 Err(e)
213 }
214 }
215 }
216}
217
/// Progress of a payload transmission, as published through the
/// `send_progress` observable handed to `HttpClient::send`.
///
/// NOTE(review): the unit of both fields is chosen by the platform-specific
/// `send_request` implementations (presumably bytes) — confirm there.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TransmissionProgress {
    /// Amount transferred so far (see the type-level note on units).
    pub current: usize,
    /// Total amount expected (see the type-level note on units).
    pub total: usize,
}
226
227async fn response_to_http_response(
228 mut response: reqwest::Response,
229) -> Result<http::Response<Bytes>, reqwest::Error> {
230 let status = response.status();
231
232 let mut http_builder = http::Response::builder().status(status);
233 let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
234
235 for (k, v) in response.headers_mut().drain() {
236 if let Some(key) = k {
237 headers.insert(key, v);
238 }
239 }
240
241 let body = response.bytes().await?;
242
243 Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
244}
245
246pub trait SupportedAuthScheme: auth_scheme::AuthScheme {}
252
253impl SupportedAuthScheme for auth_scheme::NoAuthentication {}
254
255impl SupportedAuthScheme for auth_scheme::AccessToken {}
256
257impl SupportedAuthScheme for auth_scheme::AccessTokenOptional {}
258
259impl SupportedAuthScheme for auth_scheme::AppserviceToken {}
260
261impl SupportedAuthScheme for auth_scheme::AppserviceTokenOptional {}
262
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            Arc,
            atomic::{AtomicU8, Ordering},
        },
        time::Duration,
    };

    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        Mock, Request, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// Mock for `GET /_matrix/client/versions`, answered immediately.
    fn versions_mock() -> Mock {
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
    }

    /// Mock for the `whoami` endpoint: bumps `hits` for every request that
    /// reaches the server, then delays the reply by 60 seconds so that each
    /// accepted request stays in flight for the remainder of the test.
    fn stalled_whoami_mock(hits: Arc<AtomicU8>) -> Mock {
        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                hits.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
    }

    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let hits = Arc::new(AtomicU8::new(0));

        versions_mock().mount(&server).await;
        stalled_whoami_mock(Arc::clone(&hits)).mount(&server).await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // Let the background task push as many requests as the limit allows.
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            hits.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        bg_task.abort();
    }

    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let hits = Arc::new(AtomicU8::new(0));

        versions_mock().mount(&server).await;
        stalled_whoami_mock(Arc::clone(&hits)).mount(&server).await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // Give the background task time to get every request to the server.
        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(hits.load(Ordering::SeqCst), 254, "Not all requests passed through");
        bg_task.abort();
    }
}