matrix_sdk/http_client/native.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    fmt::Debug,
17    mem,
18    sync::atomic::{AtomicU64, Ordering},
19    time::Duration,
20};
21
22use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
23use bytes::Bytes;
24use bytesize::ByteSize;
25use eyeball::SharedObservable;
26use http::header::CONTENT_LENGTH;
27use reqwest::{tls, Certificate};
28use ruma::api::{error::FromHttpResponseError, IncomingResponse, OutgoingRequest};
29use tracing::{debug, info, warn};
30
31use super::{response_to_http_response, HttpClient, TransmissionProgress, DEFAULT_REQUEST_TIMEOUT};
32use crate::{
33    config::RequestConfig,
34    error::{HttpError, RetryKind},
35};
36
37impl HttpClient {
38    pub(super) async fn send_request<R>(
39        &self,
40        request: http::Request<Bytes>,
41        config: RequestConfig,
42        send_progress: SharedObservable<TransmissionProgress>,
43    ) -> Result<R::IncomingResponse, HttpError>
44    where
45        R: OutgoingRequest + Debug,
46        HttpError: From<FromHttpResponseError<R::EndpointError>>,
47    {
48        let backoff =
49            ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() };
50        let retry_count = AtomicU64::new(1);
51
52        let send_request = || {
53            let send_progress = send_progress.clone();
54            async {
55                debug!(num_attempt = retry_count.load(Ordering::SeqCst), "Sending request");
56
57                let stop = if let Some(retry_limit) = config.retry_limit {
58                    retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit
59                } else {
60                    false
61                };
62
63                // Turn errors into permanent errors when the retry limit is reached.
64                let error_type = |err: HttpError| {
65                    if stop {
66                        RetryError::Permanent(err)
67                    } else {
68                        let has_retry_limit = config.retry_limit.is_some();
69                        match err.retry_kind() {
70                            RetryKind::Transient { retry_after } => {
71                                RetryError::Transient { err, retry_after }
72                            }
73                            RetryKind::Permanent => RetryError::Permanent(err),
74                            RetryKind::NetworkFailure => {
75                                // If we ran into a network failure, only retry if there's some
76                                // retry limit associated to this request's configuration;
77                                // otherwise, we would end up running an infinite loop of network
78                                // requests in offline mode.
79                                if has_retry_limit {
80                                    RetryError::Transient { err, retry_after: None }
81                                } else {
82                                    RetryError::Permanent(err)
83                                }
84                            }
85                        }
86                    }
87                };
88
89                let response = send_request(&self.inner, &request, config.timeout, send_progress)
90                    .await
91                    .map_err(error_type)?;
92
93                let status_code = response.status();
94                let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
95                tracing::Span::current()
96                    .record("status", status_code.as_u16())
97                    .record("response_size", response_size.to_string_as(true));
98
99                // Record interesting headers. If you add more headers, ensure they're not
100                // confidential.
101                for (header_name, header_value) in response.headers() {
102                    let header_name = header_name.as_str().to_lowercase();
103
104                    // Header added in case of OIDC authentication failure, so we can correlate
105                    // failures with a Sentry event emitted by the OIDC authentication server.
106                    if header_name == "x-sentry-event-id" {
107                        tracing::Span::current()
108                            .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
109                    }
110                }
111
112                R::IncomingResponse::try_from_http_response(response)
113                    .map_err(|e| error_type(HttpError::from(e)))
114            }
115        };
116
117        retry::<_, HttpError, _, _, _>(backoff, send_request).await
118    }
119}
120
121#[cfg(not(target_arch = "wasm32"))]
122#[derive(Clone, Debug)]
123pub(crate) struct HttpSettings {
124    pub(crate) disable_ssl_verification: bool,
125    pub(crate) proxy: Option<String>,
126    pub(crate) user_agent: Option<String>,
127    pub(crate) timeout: Duration,
128    pub(crate) additional_root_certificates: Vec<Certificate>,
129    pub(crate) disable_built_in_root_certificates: bool,
130}
131
132#[cfg(not(target_arch = "wasm32"))]
133impl Default for HttpSettings {
134    fn default() -> Self {
135        Self {
136            disable_ssl_verification: false,
137            proxy: None,
138            user_agent: None,
139            timeout: DEFAULT_REQUEST_TIMEOUT,
140            additional_root_certificates: Default::default(),
141            disable_built_in_root_certificates: false,
142        }
143    }
144}
145
146#[cfg(not(target_arch = "wasm32"))]
147impl HttpSettings {
148    /// Build a client with the specified configuration.
149    pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
150        let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
151        let mut http_client = reqwest::Client::builder()
152            .user_agent(user_agent)
153            .timeout(self.timeout)
154            // As recommended by BCP 195.
155            // See: https://datatracker.ietf.org/doc/bcp195/
156            .min_tls_version(tls::Version::TLS_1_2);
157
158        if self.disable_ssl_verification {
159            warn!("SSL verification disabled in the HTTP client!");
160            http_client = http_client.danger_accept_invalid_certs(true)
161        }
162
163        if !self.additional_root_certificates.is_empty() {
164            info!(
165                "Adding {} additional root certificates to the HTTP client",
166                self.additional_root_certificates.len()
167            );
168
169            for cert in &self.additional_root_certificates {
170                http_client = http_client.add_root_certificate(cert.clone());
171            }
172        }
173
174        if self.disable_built_in_root_certificates {
175            info!("Built-in root certificates disabled in the HTTP client.");
176            http_client = http_client.tls_built_in_root_certs(false);
177        }
178
179        if let Some(p) = &self.proxy {
180            info!(proxy_url = p, "Setting the proxy for the HTTP client");
181            http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
182        }
183
184        Ok(http_client.build()?)
185    }
186}
187
188pub(super) async fn send_request(
189    client: &reqwest::Client,
190    request: &http::Request<Bytes>,
191    timeout: Duration,
192    send_progress: SharedObservable<TransmissionProgress>,
193) -> Result<http::Response<Bytes>, HttpError> {
194    use std::convert::Infallible;
195
196    use futures_util::stream;
197
198    let request = clone_request(request);
199    let request = {
200        let mut request = if send_progress.subscriber_count() != 0 {
201            let content_length = request.body().len();
202            send_progress.update(|p| p.total += content_length);
203
204            // Make sure any concurrent futures in the same task get a chance
205            // to also add to the progress total before the first chunks are
206            // pulled out of the body stream.
207            tokio::task::yield_now().await;
208
209            let mut req = reqwest::Request::try_from(request.map(|body| {
210                let chunks = stream::iter(BytesChunks::new(body, 8192).map(
211                    move |chunk| -> Result<_, Infallible> {
212                        send_progress.update(|p| p.current += chunk.len());
213                        Ok(chunk)
214                    },
215                ));
216                reqwest::Body::wrap_stream(chunks)
217            }))?;
218
219            // When streaming the request, reqwest / hyper doesn't know how
220            // large the body is, so it doesn't set the content-length header
221            // (required by some servers). Set it manually.
222            req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
223
224            req
225        } else {
226            reqwest::Request::try_from(request)?
227        };
228
229        *request.timeout_mut() = Some(timeout);
230        request
231    };
232
233    let response = client.execute(request).await?;
234    Ok(response_to_http_response(response).await?)
235}
236
237// Clones all request parts except the extensions which can't be cloned.
238// See also https://github.com/hyperium/http/issues/395
239fn clone_request(request: &http::Request<Bytes>) -> http::Request<Bytes> {
240    let mut builder = http::Request::builder()
241        .version(request.version())
242        .method(request.method())
243        .uri(request.uri());
244    *builder.headers_mut().unwrap() = request.headers().clone();
245    builder.body(request.body().clone()).unwrap()
246}
247
248struct BytesChunks {
249    bytes: Bytes,
250    size: usize,
251}
252
253impl BytesChunks {
254    fn new(bytes: Bytes, size: usize) -> Self {
255        assert_ne!(size, 0);
256        Self { bytes, size }
257    }
258}
259
260impl Iterator for BytesChunks {
261    type Item = Bytes;
262
263    fn next(&mut self) -> Option<Self::Item> {
264        if self.bytes.is_empty() {
265            None
266        } else if self.bytes.len() < self.size {
267            Some(mem::take(&mut self.bytes))
268        } else {
269            Some(self.bytes.split_to(self.size))
270        }
271    }
272}
273
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::BytesChunks;

    #[test]
    fn test_bytes_chunks() {
        let bytes = Bytes::new();
        assert!(BytesChunks::new(bytes, 1).collect::<Vec<_>>().is_empty());

        let bytes = Bytes::from_iter([1, 2]);
        assert_eq!(BytesChunks::new(bytes, 2).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);

        let bytes = Bytes::from_iter([1, 2]);
        assert_eq!(BytesChunks::new(bytes, 3).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);

        let bytes = Bytes::from_iter([1, 2, 3]);
        assert_eq!(
            BytesChunks::new(bytes, 1).collect::<Vec<_>>(),
            [Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])]
        );

        let bytes = Bytes::from_iter([1, 2, 3]);
        assert_eq!(
            BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
            [Bytes::from_iter([1, 2]), Bytes::from_iter([3])]
        );

        let bytes = Bytes::from_iter([1, 2, 3, 4]);
        assert_eq!(
            BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
            [Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])]
        );
    }

    /// Chunking with the size used by `send_request` must cover the whole input,
    /// in order, with only the last chunk shorter.
    #[test]
    fn test_bytes_chunks_reassemble() {
        let data: Vec<u8> = (0..=255u8).cycle().take(20_000).collect();
        let chunks: Vec<Bytes> = BytesChunks::new(Bytes::from(data.clone()), 8192).collect();

        assert_eq!(chunks.iter().map(Bytes::len).collect::<Vec<_>>(), [8192, 8192, 3616]);
        assert_eq!(chunks.concat(), data);
    }

    /// A zero chunk size can never make progress, so it is rejected up front.
    #[test]
    #[should_panic]
    fn test_bytes_chunks_zero_size_panics() {
        let _ = BytesChunks::new(Bytes::from_iter([1, 2]), 0);
    }
}