matrix_sdk/http_client/mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    any::type_name,
17    borrow::Cow,
18    fmt::Debug,
19    num::NonZeroUsize,
20    sync::{
21        Arc,
22        atomic::{AtomicU64, Ordering},
23    },
24    time::Duration,
25};
26
27use bytes::{Bytes, BytesMut};
28use bytesize::ByteSize;
29use eyeball::SharedObservable;
30use http::Method;
31use matrix_sdk_base::SendOutsideWasm;
32use ruma::api::{
33    OutgoingRequest, SupportedVersions,
34    auth_scheme::{AuthScheme, SendAccessToken},
35    error::{FromHttpResponseError, IntoHttpError},
36    path_builder,
37};
38use tokio::sync::{Semaphore, SemaphorePermit};
39use tracing::{debug, field::debug, instrument, trace};
40
41use crate::{HttpResult, config::RequestConfig, error::HttpError};
42
43#[cfg(not(target_family = "wasm"))]
44mod native;
45#[cfg(target_family = "wasm")]
46mod wasm;
47
48#[cfg(not(target_family = "wasm"))]
49pub(crate) use native::HttpSettings;
50
/// Default timeout for outgoing HTTP requests: 30 seconds.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
52
53#[derive(Clone, Debug)]
54struct MaybeSemaphore(Arc<Option<Semaphore>>);
55
56#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
57struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
58
59impl MaybeSemaphore {
60    fn new(max: Option<NonZeroUsize>) -> Self {
61        let inner = max.map(|i| Semaphore::new(i.into()));
62        MaybeSemaphore(Arc::new(inner))
63    }
64
65    async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
66        match self.0.as_ref() {
67            Some(inner) => {
68                // This can only ever error if the semaphore was closed,
69                // which we never do, so we can safely ignore any error case
70                MaybeSemaphorePermit(inner.acquire().await.ok())
71            }
72            None => MaybeSemaphorePermit(None),
73        }
74    }
75}
76
77#[derive(Clone, Debug)]
78pub(crate) struct HttpClient {
79    pub(crate) inner: reqwest::Client,
80    pub(crate) request_config: RequestConfig,
81    concurrent_request_semaphore: MaybeSemaphore,
82    next_request_id: Arc<AtomicU64>,
83}
84
85impl HttpClient {
86    pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
87        HttpClient {
88            inner,
89            request_config,
90            concurrent_request_semaphore: MaybeSemaphore::new(
91                request_config.max_concurrent_requests,
92            ),
93            next_request_id: AtomicU64::new(0).into(),
94        }
95    }
96
97    fn get_request_id(&self) -> String {
98        let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
99        format!("REQ-{request_id}")
100    }
101
102    fn serialize_request<R>(
103        &self,
104        request: R,
105        config: RequestConfig,
106        homeserver: String,
107        access_token: Option<&str>,
108        path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
109    ) -> Result<http::Request<Bytes>, IntoHttpError>
110    where
111        R: OutgoingRequest + Debug,
112        for<'a> R::Authentication: AuthScheme<Input<'a> = SendAccessToken<'a>>,
113    {
114        trace!(request_type = type_name::<R>(), "Serializing request");
115
116        let send_access_token = match access_token {
117            Some(access_token) => match (config.force_auth, config.skip_auth) {
118                (true, true) | (true, false) => SendAccessToken::Always(access_token),
119                (false, true) => SendAccessToken::None,
120                (false, false) => SendAccessToken::IfRequired(access_token),
121            },
122            None => SendAccessToken::None,
123        };
124
125        let request = request
126            .try_into_http_request::<BytesMut>(&homeserver, send_access_token, path_builder_input)?
127            .map(|body| body.freeze());
128
129        Ok(request)
130    }
131
132    #[allow(clippy::too_many_arguments)]
133    #[instrument(
134        skip(self, request, config, homeserver, access_token, path_builder_input, send_progress),
135        fields(
136            uri,
137            method,
138            request_id,
139            request_size,
140            request_duration,
141            status,
142            response_size,
143            sentry_event_id
144        )
145    )]
146    pub async fn send<R>(
147        &self,
148        request: R,
149        config: Option<RequestConfig>,
150        homeserver: String,
151        access_token: Option<&str>,
152        path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
153        send_progress: SharedObservable<TransmissionProgress>,
154    ) -> Result<R::IncomingResponse, HttpError>
155    where
156        R: OutgoingRequest + Debug,
157        for<'a> R::Authentication: AuthScheme<Input<'a> = SendAccessToken<'a>>,
158        HttpError: From<FromHttpResponseError<R::EndpointError>>,
159    {
160        let config = match config {
161            Some(config) => config,
162            None => self.request_config,
163        };
164
165        // Keep some local variables in a separate scope so the compiler doesn't include
166        // them in the future type. https://github.com/rust-lang/rust/issues/57478
167        let request = {
168            let request_id = self.get_request_id();
169            let span = tracing::Span::current();
170
171            // At this point in the code, the config isn't behind an Option anymore, that's
172            // why we record it here, instead of in the #[instrument] macro.
173            span.record("config", debug(config)).record("request_id", request_id);
174
175            let request = self
176                .serialize_request(request, config, homeserver, access_token, path_builder_input)
177                .map_err(HttpError::IntoHttp)?;
178
179            let method = request.method();
180
181            let mut uri_parts = request.uri().clone().into_parts();
182            if let Some(path_and_query) = &mut uri_parts.path_and_query {
183                *path_and_query =
184                    path_and_query.path().try_into().expect("path is valid PathAndQuery");
185            }
186            let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
187
188            span.record("method", debug(method)).record("uri", uri.to_string());
189
190            // POST, PUT, PATCH are the only methods that are reasonably used
191            // in conjunction with request bodies
192            if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
193                let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
194                span.record(
195                    "request_size",
196                    ByteSize(request_size).display().si_short().to_string(),
197                );
198            }
199
200            request
201        };
202
203        // will be automatically dropped at the end of this function
204        let _handle = self.concurrent_request_semaphore.acquire().await;
205
206        // There's a bunch of state in send_request, factor out a pinned inner
207        // future to reduce this size of futures that await this function.
208        match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
209            Ok(response) => {
210                debug!("Got response");
211                Ok(response)
212            }
213            Err(e) => {
214                debug!("Error while sending request: {e:?}");
215                Err(e)
216            }
217        }
218    }
219}
220
/// Byte-level progress of a payload being sent or received.
///
/// Both counters start at zero when built through [`Default`].
#[derive(Debug, Default, Clone, Copy)]
pub struct TransmissionProgress {
    /// Number of bytes transferred so far.
    pub current: usize,
    /// Size of the whole payload, in bytes.
    pub total: usize,
}
229
230async fn response_to_http_response(
231    mut response: reqwest::Response,
232) -> Result<http::Response<Bytes>, reqwest::Error> {
233    let status = response.status();
234
235    let mut http_builder = http::Response::builder().status(status);
236    let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
237
238    for (k, v) in response.headers_mut().drain() {
239        if let Some(key) = k {
240            headers.insert(key, v);
241        }
242    }
243
244    let body = response.bytes().await?;
245
246    Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
247}
248
249/// Marker trait to identify the authentication schemes that the
250/// [`Client`](crate::Client) supports.
251///
252/// This trait can also be implemented for custom
253/// [`PathBuilder`](path_builder::PathBuilder)s if necessary.
254pub trait SupportedPathBuilder: path_builder::PathBuilder {
255    fn get_path_builder_input(
256        client: &crate::Client,
257        skip_auth: bool,
258    ) -> impl Future<Output = HttpResult<Self::Input<'static>>> + SendOutsideWasm;
259}
260
261impl SupportedPathBuilder for path_builder::VersionHistory {
262    async fn get_path_builder_input(
263        client: &crate::Client,
264        skip_auth: bool,
265    ) -> HttpResult<Cow<'static, SupportedVersions>> {
266        if skip_auth {
267            let cached_versions = client.get_cached_versions().await;
268
269            let versions = if let Some(versions) = cached_versions {
270                versions
271            } else {
272                // If we're skipping auth we might not get all the supported features, so just
273                // fetch the versions and don't cache them.
274                let request_config = RequestConfig::default().retry_limit(5).skip_auth();
275                let response = client.fetch_server_versions(Some(request_config)).await?;
276
277                response.as_supported_versions()
278            };
279
280            Ok(Cow::Owned(versions))
281        } else {
282            client.supported_versions().await.map(Cow::Owned)
283        }
284    }
285}
286
impl SupportedPathBuilder for path_builder::SinglePath {
    /// A single-path endpoint needs no input to build its path (its input type is
    /// `()`), so this always succeeds and ignores both arguments.
    async fn get_path_builder_input(_client: &crate::Client, _skip_auth: bool) -> HttpResult<()> {
        Ok(())
    }
}
292
// Tests for the concurrent-request limit enforced by `HttpClient::send`.
// NOTE(review): both tests rely on wall-clock sleeps to let requests reach the
// mock server, so they may be sensitive to slow CI machines.
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            Arc,
            atomic::{AtomicU8, Ordering},
        },
        time::Duration,
    };

    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        Mock, Request, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// With `max_concurrent_requests` set to 5, only 5 of 10 concurrent `whoami`
    /// requests may reach the server while the first ones are still pending.
    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        // Counts how many `whoami` requests the mock server has received.
        let counter = Arc::new(AtomicU8::new(0));
        let inner_counter = counter.clone();

        // Presumably needed so the client can resolve the supported versions
        // before building the `whoami` path.
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                inner_counter.fetch_add(1, Ordering::SeqCst);
                // we stall the requests
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        // Fire 10 requests at once; the stalled responses keep the admitted ones
        // holding their permits.
        let bg_task = spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // give it some time to issue the requests
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        bg_task.abort();
    }

    /// Without `max_concurrent_requests`, all 254 concurrent `whoami` requests
    /// reach the server.
    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        // 254 requests stay below `u8::MAX` (255), so this counter cannot wrap.
        let counter = Arc::new(AtomicU8::new(0));
        let inner_counter = counter.clone();

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                inner_counter.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(&server)
            .await;

        let bg_task = spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // give it some time to issue the requests
        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
        bg_task.abort();
    }
}