Skip to main content

matrix_sdk/http_client/
mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    any::type_name,
17    borrow::Cow,
18    fmt::Debug,
19    num::NonZeroUsize,
20    sync::{
21        Arc,
22        atomic::{AtomicU64, Ordering},
23    },
24    time::Duration,
25};
26
27use bytes::{Bytes, BytesMut};
28use bytesize::ByteSize;
29use eyeball::SharedObservable;
30use http::Method;
31use matrix_sdk_base::SendOutsideWasm;
32use ruma::api::{
33    OutgoingRequest, SupportedVersions,
34    auth_scheme::{self, AuthScheme, SendAccessToken},
35    error::{FromHttpResponseError, IntoHttpError},
36    path_builder,
37};
38use tokio::sync::{Semaphore, SemaphorePermit};
39use tracing::{debug, error, field::debug, instrument, trace};
40
41use crate::{HttpResult, client::caches::CachedValue, config::RequestConfig, error::HttpError};
42
43#[cfg(not(target_family = "wasm"))]
44mod native;
45#[cfg(target_family = "wasm")]
46mod wasm;
47
48#[cfg(not(target_family = "wasm"))]
49pub(crate) use native::HttpSettings;
50
/// Default timeout for HTTP requests: thirty seconds.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
52
53#[derive(Clone, Debug)]
54struct MaybeSemaphore(Arc<Option<Semaphore>>);
55
56#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
57struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
58
59impl MaybeSemaphore {
60    fn new(max: Option<NonZeroUsize>) -> Self {
61        let inner = max.map(|i| Semaphore::new(i.into()));
62        MaybeSemaphore(Arc::new(inner))
63    }
64
65    async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
66        match self.0.as_ref() {
67            Some(inner) => {
68                // This can only ever error if the semaphore was closed,
69                // which we never do, so we can safely ignore any error case
70                MaybeSemaphorePermit(inner.acquire().await.ok())
71            }
72            None => MaybeSemaphorePermit(None),
73        }
74    }
75}
76
77#[derive(Clone, Debug)]
78pub(crate) struct HttpClient {
79    pub(crate) inner: reqwest::Client,
80    pub(crate) request_config: RequestConfig,
81    concurrent_request_semaphore: MaybeSemaphore,
82    next_request_id: Arc<AtomicU64>,
83}
84
85impl HttpClient {
86    pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
87        HttpClient {
88            inner,
89            request_config,
90            concurrent_request_semaphore: MaybeSemaphore::new(
91                request_config.max_concurrent_requests,
92            ),
93            next_request_id: AtomicU64::new(0).into(),
94        }
95    }
96
97    fn get_request_id(&self) -> String {
98        let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
99        format!("REQ-{request_id}")
100    }
101
102    fn serialize_request<R>(
103        &self,
104        request: R,
105        config: RequestConfig,
106        homeserver: String,
107        access_token: Option<&str>,
108        path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
109    ) -> Result<http::Request<Bytes>, IntoHttpError>
110    where
111        R: OutgoingRequest + Debug,
112        R::Authentication: SupportedAuthScheme,
113    {
114        trace!(request_type = type_name::<R>(), "Serializing request");
115
116        let send_access_token = match access_token {
117            Some(access_token) => match (config.force_auth, config.skip_auth) {
118                (true, true) | (true, false) => SendAccessToken::Always(access_token),
119                (false, true) => SendAccessToken::None,
120                (false, false) => SendAccessToken::IfRequired(access_token),
121            },
122            None => SendAccessToken::None,
123        };
124        let authentication_input = R::Authentication::authentication_input(send_access_token);
125
126        let request = request
127            .try_into_http_request::<BytesMut>(
128                &homeserver,
129                authentication_input,
130                path_builder_input,
131            )?
132            .map(|body| body.freeze());
133
134        Ok(request)
135    }
136
137    #[allow(clippy::too_many_arguments)]
138    #[instrument(
139        skip(self, request, config, homeserver, access_token, path_builder_input, send_progress),
140        fields(
141            uri,
142            method,
143            request_id,
144            request_size,
145            request_duration,
146            status,
147            response_size,
148            sentry_event_id
149        )
150    )]
151    pub async fn send<R>(
152        &self,
153        request: R,
154        config: Option<RequestConfig>,
155        homeserver: String,
156        access_token: Option<&str>,
157        path_builder_input: <R::PathBuilder as path_builder::PathBuilder>::Input<'_>,
158        send_progress: SharedObservable<TransmissionProgress>,
159    ) -> Result<R::IncomingResponse, HttpError>
160    where
161        R: OutgoingRequest + Debug,
162        R::Authentication: SupportedAuthScheme,
163        HttpError: From<FromHttpResponseError<R::EndpointError>>,
164    {
165        let config = match config {
166            Some(config) => config,
167            None => self.request_config,
168        };
169
170        // Keep some local variables in a separate scope so the compiler doesn't include
171        // them in the future type. https://github.com/rust-lang/rust/issues/57478
172        let request = {
173            let request_id = self.get_request_id();
174            let span = tracing::Span::current();
175
176            // At this point in the code, the config isn't behind an Option anymore, that's
177            // why we record it here, instead of in the #[instrument] macro.
178            span.record("config", debug(config)).record("request_id", request_id);
179
180            let request = self
181                .serialize_request(request, config, homeserver, access_token, path_builder_input)
182                .map_err(HttpError::IntoHttp)?;
183
184            let method = request.method();
185
186            let mut uri_parts = request.uri().clone().into_parts();
187
188            // Erase the query parameters for the sake of secrecy (in case a token is
189            // present).
190            if let Some(path_and_query) = &mut uri_parts.path_and_query {
191                *path_and_query =
192                    path_and_query.path().try_into().expect("path is valid PathAndQuery");
193            }
194
195            let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
196
197            span.record("method", debug(method)).record("uri", uri.to_string());
198
199            // POST, PUT, PATCH are the only methods that are reasonably used
200            // in conjunction with request bodies
201            if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
202                let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
203                span.record(
204                    "request_size",
205                    ByteSize(request_size).display().si_short().to_string(),
206                );
207            }
208
209            request
210        };
211
212        // will be automatically dropped at the end of this function
213        let _handle = self.concurrent_request_semaphore.acquire().await;
214
215        // There's a bunch of state in send_request, factor out a pinned inner
216        // future to reduce this size of futures that await this function.
217        match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
218            Ok(response) => {
219                debug!("Got response");
220                Ok(response)
221            }
222            Err(e) => {
223                error!("Error while sending request: {e:?}");
224                Err(e)
225            }
226        }
227    }
228}
229
/// Progress of sending or receiving a payload.
///
/// Both counters are byte counts; the [`Default`] value is `0` out of `0`.
#[derive(Debug, Default, Clone, Copy)]
pub struct TransmissionProgress {
    /// Number of bytes transferred so far.
    pub current: usize,
    /// Total number of bytes to transfer.
    pub total: usize,
}
238
239async fn response_to_http_response(
240    mut response: reqwest::Response,
241) -> Result<http::Response<Bytes>, reqwest::Error> {
242    let status = response.status();
243
244    let mut http_builder = http::Response::builder().status(status);
245    let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
246
247    for (k, v) in response.headers_mut().drain() {
248        if let Some(key) = k {
249            headers.insert(key, v);
250        }
251    }
252
253    let body = response.bytes().await?;
254
255    Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
256}
257
258/// Marker trait to identify the authentication schemes that the
259/// [`Client`](crate::Client) supports.
260///
261/// This trait can also be implemented for custom
262/// [`PathBuilder`](path_builder::PathBuilder)s if necessary.
263pub trait SupportedAuthScheme: AuthScheme {
264    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_>;
265}
266
267impl SupportedAuthScheme for auth_scheme::NoAccessToken {
268    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
269        access_token
270    }
271}
272
273impl SupportedAuthScheme for auth_scheme::AccessToken {
274    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
275        access_token
276    }
277}
278
279impl SupportedAuthScheme for auth_scheme::AccessTokenOptional {
280    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
281        access_token
282    }
283}
284
285impl SupportedAuthScheme for auth_scheme::AppserviceToken {
286    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
287        access_token
288    }
289}
290
291impl SupportedAuthScheme for auth_scheme::AppserviceTokenOptional {
292    fn authentication_input(access_token: SendAccessToken<'_>) -> Self::Input<'_> {
293        access_token
294    }
295}
296
297impl SupportedAuthScheme for auth_scheme::NoAuthentication {
298    fn authentication_input(_access_token: SendAccessToken<'_>) -> Self::Input<'_> {}
299}
300
301/// Marker trait to identify the path builders that the
302/// [`Client`](crate::Client) supports.
303///
304/// This trait can also be implemented for custom
305/// [`PathBuilder`](path_builder::PathBuilder)s if necessary.
306pub trait SupportedPathBuilder: path_builder::PathBuilder {
307    fn get_path_builder_input(
308        client: &crate::Client,
309        skip_auth: bool,
310    ) -> impl Future<Output = HttpResult<Self::Input<'static>>> + SendOutsideWasm;
311}
312
313impl SupportedPathBuilder for path_builder::VersionHistory {
314    async fn get_path_builder_input(
315        client: &crate::Client,
316        skip_auth: bool,
317    ) -> HttpResult<Cow<'static, SupportedVersions>> {
318        // We always enable "failsafe" mode for the GET /versions requests in this
319        // function. It disables trying to refresh the access token for those requests,
320        // to avoid possible deadlocks.
321
322        if !client.auth_ctx().has_valid_access_token() {
323            // Get the value in the cache without waiting. If the lock is not available, we
324            // are in the middle of refreshing the cache so waiting for it would result in a
325            // deadlock.
326            if let Ok(CachedValue::Cached(versions)) =
327                client.inner.caches.supported_versions.try_read().as_deref()
328            {
329                return Ok(Cow::Owned(versions.clone()));
330            }
331
332            // The request will skip auth so we might not get all the supported features, so
333            // just fetch the supported versions and don't cache them.
334            let response = client.fetch_server_versions_inner(true, None).await?;
335
336            Ok(Cow::Owned(response.as_supported_versions()))
337        } else if skip_auth {
338            let cached_versions = client.get_cached_supported_versions().await;
339
340            let versions = if let Some(versions) = cached_versions {
341                versions
342            } else {
343                // If we're skipping auth we might not get all the supported features, so just
344                // fetch the versions and don't cache them.
345                let request_config = RequestConfig::default().retry_limit(5).skip_auth();
346                let response =
347                    client.fetch_server_versions_inner(true, Some(request_config)).await?;
348
349                response.as_supported_versions()
350            };
351
352            Ok(Cow::Owned(versions))
353        } else {
354            client.supported_versions_inner(true).await.map(Cow::Owned)
355        }
356    }
357}
358
359impl SupportedPathBuilder for path_builder::SinglePath {
360    async fn get_path_builder_input(_client: &crate::Client, _skip_auth: bool) -> HttpResult<()> {
361        Ok(())
362    }
363}
364
#[cfg(feature = "rustls-tls")]
pub mod rustls {
    //! Functions for configuring the default [`CryptoProvider`] when using
    //! `reqwest` with `rustls`'s implementation of TLS.

    use rustls::crypto::CryptoProvider;

    /// The default [`CryptoProvider`] preferred by this crate.
    ///
    /// Typically, the default [`CryptoProvider`] for `rustls`
    /// is `aws-lc-rs`, but due to licensing issues this crate
    /// prefers `ring`.
    pub fn default_crypto_provider() -> CryptoProvider {
        rustls::crypto::ring::default_provider()
    }

    /// Installs [`default_crypto_provider`] as the default [`CryptoProvider`],
    /// unless a default has already been installed.
    ///
    /// The `rustls-tls` flag enables the `rustls` implementation of TLS for
    /// `reqwest`, but without a [`CryptoProvider`]. This means that no
    /// default provider is installed, which will cause `reqwest::Client::new()`
    /// to panic.
    ///
    /// This function installs the preferred default provider given by
    /// [`default_crypto_provider`], if no default has previously been
    /// installed. If one is already installed, the call leaves it in place and
    /// does nothing, so calling this more than once is harmless.
    pub fn install_default_crypto_provider_if_none_installed() {
        if default_crypto_provider().install_default().is_ok() {
            // This log message seems to cause `nextest` to get confused,
            // so it won't be printed when running tests.
            #[cfg(not(test))]
            tracing::info!("No rustls crypto provider set, setting default provider to ring.");
        }
    }

    /// Install a default [`CryptoProvider`] for `rustls`, if one isn't already
    /// installed. This uses [`ctor`] to run before any tests.
    #[cfg(test)]
    macro_rules! install_default_crypto_provider_for_tests {
        () => {
            #[cfg(not(target_family = "wasm"))]
            #[ctor::ctor]
            fn install_default_crypto_provider_for_tests() {
                $crate::http_client::rustls::install_default_crypto_provider_if_none_installed();
            }
        };
    }

    #[cfg(test)]
    pub(crate) use install_default_crypto_provider_for_tests;
}
414
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            Arc,
            atomic::{AtomicU8, Ordering},
        },
        time::Duration,
    };

    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        Mock, MockServer, Request, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// Mounts the mocks shared by both tests.
    ///
    /// The `/versions` endpoint answers immediately. The `whoami` endpoint bumps
    /// `hits` on every request and then stalls for a minute, so that requests
    /// stay in flight for the whole test.
    async fn mount_stalling_whoami(server: &MockServer, hits: Arc<AtomicU8>) {
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                hits.fetch_add(1, Ordering::SeqCst);
                // Stall the response so the request keeps occupying its slot.
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(server)
            .await;
    }

    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (builder, server) = test_client_builder_with_server().await;
        let client = builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();
        set_client_session(&client).await;

        let hits = Arc::new(AtomicU8::new(0));
        mount_stalling_whoami(&server, Arc::clone(&hits)).await;

        let background = spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // Leave the spawned task enough time to fire off its requests.
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            hits.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        background.abort();
    }

    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (builder, server) = test_client_builder_with_server().await;
        let client = builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();
        set_client_session(&client).await;

        let hits = Arc::new(AtomicU8::new(0));
        mount_stalling_whoami(&server, Arc::clone(&hits)).await;

        let background = spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // Leave the spawned task enough time to fire off its requests.
        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(hits.load(Ordering::SeqCst), 254, "Not all requests passed through");
        background.abort();
    }
}
518
519        assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
520        bg_task.abort();
521    }
522}