1use std::{
16 any::type_name,
17 fmt::Debug,
18 num::NonZeroUsize,
19 sync::{
20 atomic::{AtomicU64, Ordering},
21 Arc,
22 },
23 time::Duration,
24};
25
26use bytes::{Bytes, BytesMut};
27use bytesize::ByteSize;
28use eyeball::SharedObservable;
29use http::Method;
30use ruma::api::{
31 error::{FromHttpResponseError, IntoHttpError},
32 AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
33};
34use tokio::sync::{Semaphore, SemaphorePermit};
35use tracing::{debug, field::debug, instrument, trace};
36
37use crate::{config::RequestConfig, error::HttpError};
38
39#[cfg(not(target_arch = "wasm32"))]
40mod native;
41#[cfg(target_arch = "wasm32")]
42mod wasm;
43
44#[cfg(not(target_arch = "wasm32"))]
45pub(crate) use native::HttpSettings;
46
/// Timeout used for HTTP requests that are not given a more specific one, e.g. those sent
/// through the `tower::Service` implementation of [`HttpClient`] (30 seconds).
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
48
49#[derive(Clone, Debug)]
50struct MaybeSemaphore(Arc<Option<Semaphore>>);
51
52#[allow(dead_code)] struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
54
55impl MaybeSemaphore {
56 fn new(max: Option<NonZeroUsize>) -> Self {
57 let inner = max.map(|i| Semaphore::new(i.into()));
58 MaybeSemaphore(Arc::new(inner))
59 }
60
61 async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
62 match self.0.as_ref() {
63 Some(inner) => {
64 MaybeSemaphorePermit(inner.acquire().await.ok())
67 }
68 None => MaybeSemaphorePermit(None),
69 }
70 }
71}
72
73#[derive(Clone, Debug)]
74pub(crate) struct HttpClient {
75 pub(crate) inner: reqwest::Client,
76 pub(crate) request_config: RequestConfig,
77 concurrent_request_semaphore: MaybeSemaphore,
78 next_request_id: Arc<AtomicU64>,
79}
80
81impl HttpClient {
82 pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
83 HttpClient {
84 inner,
85 request_config,
86 concurrent_request_semaphore: MaybeSemaphore::new(
87 request_config.max_concurrent_requests,
88 ),
89 next_request_id: AtomicU64::new(0).into(),
90 }
91 }
92
93 fn get_request_id(&self) -> String {
94 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
95 format!("REQ-{request_id}")
96 }
97
98 fn serialize_request<R>(
99 &self,
100 request: R,
101 config: RequestConfig,
102 homeserver: String,
103 access_token: Option<&str>,
104 server_versions: &[MatrixVersion],
105 ) -> Result<http::Request<Bytes>, IntoHttpError>
106 where
107 R: OutgoingRequest + Debug,
108 {
109 trace!(request_type = type_name::<R>(), "Serializing request");
110
111 let server_versions = if config.force_matrix_version.is_some() {
112 config.force_matrix_version.as_slice()
113 } else {
114 server_versions
115 };
116
117 let send_access_token = match access_token {
118 Some(access_token) => {
119 if config.force_auth {
120 SendAccessToken::Always(access_token)
121 } else {
122 SendAccessToken::IfRequired(access_token)
123 }
124 }
125 None => SendAccessToken::None,
126 };
127
128 let request = request
129 .try_into_http_request::<BytesMut>(&homeserver, send_access_token, server_versions)?
130 .map(|body| body.freeze());
131
132 Ok(request)
133 }
134
135 #[allow(clippy::too_many_arguments)]
136 #[instrument(
137 skip(self, request, config, homeserver, access_token, send_progress),
138 fields(
139 config,
140 uri,
141 method,
142 request_size,
143 request_body,
144 request_id,
145 status,
146 response_size,
147 sentry_event_id,
148 )
149 )]
150 pub async fn send<R>(
151 &self,
152 request: R,
153 config: Option<RequestConfig>,
154 homeserver: String,
155 access_token: Option<&str>,
156 server_versions: &[MatrixVersion],
157 send_progress: SharedObservable<TransmissionProgress>,
158 ) -> Result<R::IncomingResponse, HttpError>
159 where
160 R: OutgoingRequest + Debug,
161 HttpError: From<FromHttpResponseError<R::EndpointError>>,
162 {
163 let config = match config {
164 Some(config) => config,
165 None => self.request_config,
166 };
167
168 let request = {
171 let request_id = self.get_request_id();
172 let span = tracing::Span::current();
173
174 span.record("config", debug(config)).record("request_id", request_id);
177
178 let auth_scheme = R::METADATA.authentication;
179 match auth_scheme {
180 AuthScheme::AccessToken
181 | AuthScheme::AccessTokenOptional
182 | AuthScheme::AppserviceToken
183 | AuthScheme::None => {}
184 AuthScheme::ServerSignatures => {
185 return Err(HttpError::NotClientRequest);
186 }
187 }
188
189 let request = self
190 .serialize_request(request, config, homeserver, access_token, server_versions)
191 .map_err(HttpError::IntoHttp)?;
192
193 let method = request.method();
194
195 let mut uri_parts = request.uri().clone().into_parts();
196 if let Some(path_and_query) = &mut uri_parts.path_and_query {
197 *path_and_query =
198 path_and_query.path().try_into().expect("path is valid PathAndQuery");
199 }
200 let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
201
202 span.record("method", debug(method)).record("uri", uri.to_string());
203
204 if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
207 let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
208 span.record("request_size", ByteSize(request_size).to_string_as(true));
209 }
210
211 request
212 };
213
214 let _handle = self.concurrent_request_semaphore.acquire().await;
216
217 match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
220 Ok(response) => {
221 debug!("Got response");
222 Ok(response)
223 }
224 Err(e) => {
225 debug!("Error while sending request: {e:?}");
226 Err(e)
227 }
228 }
229 }
230}
231
/// A snapshot of how far a transfer has progressed.
///
/// [`HttpClient::send`] accepts a [`SharedObservable`] of this type so callers can follow the
/// progress of a request while it is being sent.
#[derive(Clone, Copy, Debug, Default)]
pub struct TransmissionProgress {
    /// Amount transferred so far (presumably bytes; the unit is set by the code that updates
    /// the observable, which is not visible here).
    pub current: usize,
    /// Total amount expected to be transferred, in the same unit as `current`.
    pub total: usize,
}
240
241async fn response_to_http_response(
242 mut response: reqwest::Response,
243) -> Result<http::Response<Bytes>, reqwest::Error> {
244 let status = response.status();
245
246 let mut http_builder = http::Response::builder().status(status);
247 let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
248
249 for (k, v) in response.headers_mut().drain() {
250 if let Some(key) = k {
251 headers.insert(key, v);
252 }
253 }
254
255 let body = response.bytes().await?;
256
257 Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
258}
259
/// Exposes [`HttpClient`] as a `tower` service for raw HTTP requests, presumably for use by
/// the OIDC integration this impl is feature-gated on.
#[cfg(feature = "experimental-oidc")]
impl tower::Service<http::Request<Bytes>> for HttpClient {
    type Response = http::Response<Bytes>;
    type Error = tower::BoxError;
    type Future = matrix_sdk_base::BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Never applies back-pressure: the service always reports itself as ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Sends `req` with [`DEFAULT_REQUEST_TIMEOUT`] and a default progress observable.
    ///
    /// Unlike [`HttpClient::send`], this path neither waits for a concurrency permit nor uses
    /// the client's `RequestConfig`.
    fn call(&mut self, req: http::Request<Bytes>) -> Self::Future {
        // Cloned so the returned future owns its client and can be `'static`.
        let client = self.inner.clone();

        Box::pin(async move {
            native::send_request(&client, &req, DEFAULT_REQUEST_TIMEOUT, Default::default())
                .await
                .map_err(tower::BoxError::from)
        })
    }
}
284
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// Mounts the endpoints both concurrency tests rely on.
    ///
    /// `/versions` answers immediately. Every `/account/whoami` request that reaches the server
    /// increments `counter` and is answered only after 60 seconds, far longer than either test
    /// runs. As a result, each request that got through keeps its concurrency permit for the
    /// whole test.
    async fn mount_stalling_whoami(server: &MockServer, counter: Arc<AtomicUsize>) {
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(server)
            .await;

        Mock::given(method("GET"))
            .and(path("/_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                counter.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(server)
            .await;
    }

    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        // `usize` rather than `u8` so the count cannot silently wrap if more requests are sent.
        let counter = Arc::new(AtomicUsize::new(0));
        mount_stalling_whoami(&server, counter.clone()).await;

        let bg_task = tokio::spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // Give the permitted requests time to reach the server; the other five must still be
        // waiting for a permit when we look at the counter.
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        bg_task.abort();
    }

    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let counter = Arc::new(AtomicUsize::new(0));
        mount_stalling_whoami(&server, counter.clone()).await;

        let bg_task = tokio::spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // Without a limit every request should reach the server within this window.
        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
        bg_task.abort();
    }
}