matrix_sdk/http_client/
mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    any::type_name,
17    fmt::Debug,
18    num::NonZeroUsize,
19    sync::{
20        atomic::{AtomicU64, Ordering},
21        Arc,
22    },
23    time::Duration,
24};
25
26use bytes::{Bytes, BytesMut};
27use bytesize::ByteSize;
28use eyeball::SharedObservable;
29use http::Method;
30use ruma::api::{
31    error::{FromHttpResponseError, IntoHttpError},
32    AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
33};
34use tokio::sync::{Semaphore, SemaphorePermit};
35use tracing::{debug, field::debug, instrument, trace};
36
37use crate::{config::RequestConfig, error::HttpError};
38
39#[cfg(not(target_arch = "wasm32"))]
40mod native;
41#[cfg(target_arch = "wasm32")]
42mod wasm;
43
44#[cfg(not(target_arch = "wasm32"))]
45pub(crate) use native::HttpSettings;
46
/// Default timeout (30 seconds) for outgoing HTTP requests; the `tower::Service`
/// implementation in this module passes it to `native::send_request`.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
48
49#[derive(Clone, Debug)]
50struct MaybeSemaphore(Arc<Option<Semaphore>>);
51
52#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
53struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
54
55impl MaybeSemaphore {
56    fn new(max: Option<NonZeroUsize>) -> Self {
57        let inner = max.map(|i| Semaphore::new(i.into()));
58        MaybeSemaphore(Arc::new(inner))
59    }
60
61    async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
62        match self.0.as_ref() {
63            Some(inner) => {
64                // This can only ever error if the semaphore was closed,
65                // which we never do, so we can safely ignore any error case
66                MaybeSemaphorePermit(inner.acquire().await.ok())
67            }
68            None => MaybeSemaphorePermit(None),
69        }
70    }
71}
72
73#[derive(Clone, Debug)]
74pub(crate) struct HttpClient {
75    pub(crate) inner: reqwest::Client,
76    pub(crate) request_config: RequestConfig,
77    concurrent_request_semaphore: MaybeSemaphore,
78    next_request_id: Arc<AtomicU64>,
79}
80
81impl HttpClient {
82    pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
83        HttpClient {
84            inner,
85            request_config,
86            concurrent_request_semaphore: MaybeSemaphore::new(
87                request_config.max_concurrent_requests,
88            ),
89            next_request_id: AtomicU64::new(0).into(),
90        }
91    }
92
93    fn get_request_id(&self) -> String {
94        let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
95        format!("REQ-{request_id}")
96    }
97
98    fn serialize_request<R>(
99        &self,
100        request: R,
101        config: RequestConfig,
102        homeserver: String,
103        access_token: Option<&str>,
104        server_versions: &[MatrixVersion],
105    ) -> Result<http::Request<Bytes>, IntoHttpError>
106    where
107        R: OutgoingRequest + Debug,
108    {
109        trace!(request_type = type_name::<R>(), "Serializing request");
110
111        let server_versions = if config.force_matrix_version.is_some() {
112            config.force_matrix_version.as_slice()
113        } else {
114            server_versions
115        };
116
117        let send_access_token = match access_token {
118            Some(access_token) => {
119                if config.force_auth {
120                    SendAccessToken::Always(access_token)
121                } else {
122                    SendAccessToken::IfRequired(access_token)
123                }
124            }
125            None => SendAccessToken::None,
126        };
127
128        let request = request
129            .try_into_http_request::<BytesMut>(&homeserver, send_access_token, server_versions)?
130            .map(|body| body.freeze());
131
132        Ok(request)
133    }
134
135    #[allow(clippy::too_many_arguments)]
136    #[instrument(
137        skip(self, request, config, homeserver, access_token, send_progress),
138        fields(
139            config,
140            uri,
141            method,
142            request_size,
143            request_body,
144            request_id,
145            status,
146            response_size,
147            sentry_event_id,
148        )
149    )]
150    pub async fn send<R>(
151        &self,
152        request: R,
153        config: Option<RequestConfig>,
154        homeserver: String,
155        access_token: Option<&str>,
156        server_versions: &[MatrixVersion],
157        send_progress: SharedObservable<TransmissionProgress>,
158    ) -> Result<R::IncomingResponse, HttpError>
159    where
160        R: OutgoingRequest + Debug,
161        HttpError: From<FromHttpResponseError<R::EndpointError>>,
162    {
163        let config = match config {
164            Some(config) => config,
165            None => self.request_config,
166        };
167
168        // Keep some local variables in a separate scope so the compiler doesn't include
169        // them in the future type. https://github.com/rust-lang/rust/issues/57478
170        let request = {
171            let request_id = self.get_request_id();
172            let span = tracing::Span::current();
173
174            // At this point in the code, the config isn't behind an Option anymore, that's
175            // why we record it here, instead of in the #[instrument] macro.
176            span.record("config", debug(config)).record("request_id", request_id);
177
178            let auth_scheme = R::METADATA.authentication;
179            match auth_scheme {
180                AuthScheme::AccessToken
181                | AuthScheme::AccessTokenOptional
182                | AuthScheme::AppserviceToken
183                | AuthScheme::None => {}
184                AuthScheme::ServerSignatures => {
185                    return Err(HttpError::NotClientRequest);
186                }
187            }
188
189            let request = self
190                .serialize_request(request, config, homeserver, access_token, server_versions)
191                .map_err(HttpError::IntoHttp)?;
192
193            let method = request.method();
194
195            let mut uri_parts = request.uri().clone().into_parts();
196            if let Some(path_and_query) = &mut uri_parts.path_and_query {
197                *path_and_query =
198                    path_and_query.path().try_into().expect("path is valid PathAndQuery");
199            }
200            let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
201
202            span.record("method", debug(method)).record("uri", uri.to_string());
203
204            // POST, PUT, PATCH are the only methods that are reasonably used
205            // in conjunction with request bodies
206            if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
207                let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
208                span.record("request_size", ByteSize(request_size).to_string_as(true));
209            }
210
211            request
212        };
213
214        // will be automatically dropped at the end of this function
215        let _handle = self.concurrent_request_semaphore.acquire().await;
216
217        // There's a bunch of state in send_request, factor out a pinned inner
218        // future to reduce this size of futures that await this function.
219        match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
220            Ok(response) => {
221                debug!("Got response");
222                Ok(response)
223            }
224            Err(e) => {
225                debug!("Error while sending request: {e:?}");
226                Err(e)
227            }
228        }
229    }
230}
231
/// Progress of sending or receiving a payload.
///
/// Both counters are byte counts. The type is plain data: it is `Copy` and can be compared
/// for equality, for example to skip publishing an update that didn't change anything.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TransmissionProgress {
    /// How many bytes were already transferred.
    pub current: usize,
    /// How many bytes there are in total.
    pub total: usize,
}
240
241async fn response_to_http_response(
242    mut response: reqwest::Response,
243) -> Result<http::Response<Bytes>, reqwest::Error> {
244    let status = response.status();
245
246    let mut http_builder = http::Response::builder().status(status);
247    let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
248
249    for (k, v) in response.headers_mut().drain() {
250        if let Some(key) = k {
251            headers.insert(key, v);
252        }
253    }
254
255    let body = response.bytes().await?;
256
257    Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
258}
259
/// Exposes the HTTP client as a [`tower::Service`].
///
/// Requests issued through this path are handed straight to `native::send_request` using
/// the inner `reqwest::Client`, `DEFAULT_REQUEST_TIMEOUT` and a default progress observable.
/// They do not go through `HttpClient::send`, so the client's `RequestConfig` and its
/// concurrency limit are not applied.
#[cfg(feature = "experimental-oidc")]
impl tower::Service<http::Request<Bytes>> for HttpClient {
    type Response = http::Response<Bytes>;
    type Error = tower::BoxError;
    type Future = matrix_sdk_base::BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Applies no back-pressure: the service always reports itself ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Sends `req` and boxes any transport error into a [`tower::BoxError`].
    fn call(&mut self, req: http::Request<Bytes>) -> Self::Future {
        // The returned future must be `'static`, so it owns its own handle to the client.
        let client = self.inner.clone();

        Box::pin(async move {
            native::send_request(&client, &req, DEFAULT_REQUEST_TIMEOUT, Default::default())
                .await
                .map_err(tower::BoxError::from)
        })
    }
}
284
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            atomic::{AtomicU8, Ordering},
            Arc,
        },
        time::Duration,
    };

    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        matchers::{method, path},
        Mock, Request, ResponseTemplate,
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// Fires `num_requests` concurrent `whoami` calls against a mock homeserver whose
    /// `whoami` endpoint stalls for 60 seconds, waits `settle`, and reports how many of
    /// those calls had reached the server by then.
    ///
    /// `max_concurrent_requests` is passed to the client's `RequestConfig`.
    async fn stalled_whoami_hits(
        max_concurrent_requests: Option<NonZeroUsize>,
        num_requests: u8,
        settle: Duration,
    ) -> u8 {
        let (builder, server) = test_client_builder_with_server().await;
        let client = builder
            .request_config(
                RequestConfig::default().max_concurrent_requests(max_concurrent_requests),
            )
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let hits = Arc::new(AtomicU8::new(0));
        let hits_in_responder = Arc::clone(&hits);

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                hits_in_responder.fetch_add(1, Ordering::SeqCst);
                // Stall every response so no request completes while we're counting.
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        let requests = tokio::spawn(async move {
            futures_util::future::join_all((0..num_requests).map(|_| client.whoami())).await
        });

        // Give the spawned task time to issue its requests.
        tokio::time::sleep(settle).await;

        let observed = hits.load(Ordering::SeqCst);
        requests.abort();
        observed
    }

    /// With a limit of 5, only 5 of 10 stalled requests may reach the server.
    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let hits =
            stalled_whoami_hits(NonZeroUsize::new(5), 10, Duration::from_millis(300)).await;

        assert_eq!(hits, 5, "More requests passed than the limit we configured");
    }

    /// Without a limit, every one of 254 stalled requests reaches the server.
    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let hits = stalled_whoami_hits(None, 254, Duration::from_secs(1)).await;

        assert_eq!(hits, 254, "Not all requests passed through");
    }
}