1use std::{
16 fmt::Debug,
17 mem,
18 sync::atomic::{AtomicU64, Ordering},
19 time::Duration,
20};
21
22use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
23use bytes::Bytes;
24use bytesize::ByteSize;
25use eyeball::SharedObservable;
26use http::header::CONTENT_LENGTH;
27use reqwest::{tls, Certificate};
28use ruma::api::{error::FromHttpResponseError, IncomingResponse, OutgoingRequest};
29use tracing::{debug, info, warn};
30
31use super::{response_to_http_response, HttpClient, TransmissionProgress, DEFAULT_REQUEST_TIMEOUT};
32use crate::{
33 config::RequestConfig,
34 error::{HttpError, RetryKind},
35};
36
37impl HttpClient {
38 pub(super) async fn send_request<R>(
39 &self,
40 request: http::Request<Bytes>,
41 config: RequestConfig,
42 send_progress: SharedObservable<TransmissionProgress>,
43 ) -> Result<R::IncomingResponse, HttpError>
44 where
45 R: OutgoingRequest + Debug,
46 HttpError: From<FromHttpResponseError<R::EndpointError>>,
47 {
48 let backoff =
49 ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() };
50 let retry_count = AtomicU64::new(1);
51
52 let send_request = || {
53 let send_progress = send_progress.clone();
54 async {
55 debug!(num_attempt = retry_count.load(Ordering::SeqCst), "Sending request");
56
57 let stop = if let Some(retry_limit) = config.retry_limit {
58 retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit
59 } else {
60 false
61 };
62
63 let error_type = |err: HttpError| {
65 if stop {
66 RetryError::Permanent(err)
67 } else {
68 let has_retry_limit = config.retry_limit.is_some();
69 match err.retry_kind() {
70 RetryKind::Transient { retry_after } => {
71 RetryError::Transient { err, retry_after }
72 }
73 RetryKind::Permanent => RetryError::Permanent(err),
74 RetryKind::NetworkFailure => {
75 if has_retry_limit {
80 RetryError::Transient { err, retry_after: None }
81 } else {
82 RetryError::Permanent(err)
83 }
84 }
85 }
86 }
87 };
88
89 let response = send_request(&self.inner, &request, config.timeout, send_progress)
90 .await
91 .map_err(error_type)?;
92
93 let status_code = response.status();
94 let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
95 tracing::Span::current()
96 .record("status", status_code.as_u16())
97 .record("response_size", response_size.to_string_as(true));
98
99 for (header_name, header_value) in response.headers() {
102 let header_name = header_name.as_str().to_lowercase();
103
104 if header_name == "x-sentry-event-id" {
107 tracing::Span::current()
108 .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
109 }
110 }
111
112 R::IncomingResponse::try_from_http_response(response)
113 .map_err(|e| error_type(HttpError::from(e)))
114 }
115 };
116
117 retry::<_, HttpError, _, _, _>(backoff, send_request).await
118 }
119}
120
121#[cfg(not(target_arch = "wasm32"))]
122#[derive(Clone, Debug)]
123pub(crate) struct HttpSettings {
124 pub(crate) disable_ssl_verification: bool,
125 pub(crate) proxy: Option<String>,
126 pub(crate) user_agent: Option<String>,
127 pub(crate) timeout: Duration,
128 pub(crate) additional_root_certificates: Vec<Certificate>,
129 pub(crate) disable_built_in_root_certificates: bool,
130}
131
132#[cfg(not(target_arch = "wasm32"))]
133impl Default for HttpSettings {
134 fn default() -> Self {
135 Self {
136 disable_ssl_verification: false,
137 proxy: None,
138 user_agent: None,
139 timeout: DEFAULT_REQUEST_TIMEOUT,
140 additional_root_certificates: Default::default(),
141 disable_built_in_root_certificates: false,
142 }
143 }
144}
145
146#[cfg(not(target_arch = "wasm32"))]
147impl HttpSettings {
148 pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
150 let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
151 let mut http_client = reqwest::Client::builder()
152 .user_agent(user_agent)
153 .timeout(self.timeout)
154 .min_tls_version(tls::Version::TLS_1_2);
157
158 if self.disable_ssl_verification {
159 warn!("SSL verification disabled in the HTTP client!");
160 http_client = http_client.danger_accept_invalid_certs(true)
161 }
162
163 if !self.additional_root_certificates.is_empty() {
164 info!(
165 "Adding {} additional root certificates to the HTTP client",
166 self.additional_root_certificates.len()
167 );
168
169 for cert in &self.additional_root_certificates {
170 http_client = http_client.add_root_certificate(cert.clone());
171 }
172 }
173
174 if self.disable_built_in_root_certificates {
175 info!("Built-in root certificates disabled in the HTTP client.");
176 http_client = http_client.tls_built_in_root_certs(false);
177 }
178
179 if let Some(p) = &self.proxy {
180 info!(proxy_url = p, "Setting the proxy for the HTTP client");
181 http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
182 }
183
184 Ok(http_client.build()?)
185 }
186}
187
188pub(super) async fn send_request(
189 client: &reqwest::Client,
190 request: &http::Request<Bytes>,
191 timeout: Duration,
192 send_progress: SharedObservable<TransmissionProgress>,
193) -> Result<http::Response<Bytes>, HttpError> {
194 use std::convert::Infallible;
195
196 use futures_util::stream;
197
198 let request = clone_request(request);
199 let request = {
200 let mut request = if send_progress.subscriber_count() != 0 {
201 let content_length = request.body().len();
202 send_progress.update(|p| p.total += content_length);
203
204 tokio::task::yield_now().await;
208
209 let mut req = reqwest::Request::try_from(request.map(|body| {
210 let chunks = stream::iter(BytesChunks::new(body, 8192).map(
211 move |chunk| -> Result<_, Infallible> {
212 send_progress.update(|p| p.current += chunk.len());
213 Ok(chunk)
214 },
215 ));
216 reqwest::Body::wrap_stream(chunks)
217 }))?;
218
219 req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
223
224 req
225 } else {
226 reqwest::Request::try_from(request)?
227 };
228
229 *request.timeout_mut() = Some(timeout);
230 request
231 };
232
233 let response = client.execute(request).await?;
234 Ok(response_to_http_response(response).await?)
235}
236
237fn clone_request(request: &http::Request<Bytes>) -> http::Request<Bytes> {
240 let mut builder = http::Request::builder()
241 .version(request.version())
242 .method(request.method())
243 .uri(request.uri());
244 *builder.headers_mut().unwrap() = request.headers().clone();
245 builder.body(request.body().clone()).unwrap()
246}
247
248struct BytesChunks {
249 bytes: Bytes,
250 size: usize,
251}
252
253impl BytesChunks {
254 fn new(bytes: Bytes, size: usize) -> Self {
255 assert_ne!(size, 0);
256 Self { bytes, size }
257 }
258}
259
260impl Iterator for BytesChunks {
261 type Item = Bytes;
262
263 fn next(&mut self) -> Option<Self::Item> {
264 if self.bytes.is_empty() {
265 None
266 } else if self.bytes.len() < self.size {
267 Some(mem::take(&mut self.bytes))
268 } else {
269 Some(self.bytes.split_to(self.size))
270 }
271 }
272}
273
274#[cfg(test)]
275mod tests {
276 use bytes::Bytes;
277
278 use super::BytesChunks;
279
280 #[test]
281 fn test_bytes_chunks() {
282 let bytes = Bytes::new();
283 assert!(BytesChunks::new(bytes, 1).collect::<Vec<_>>().is_empty());
284
285 let bytes = Bytes::from_iter([1, 2]);
286 assert_eq!(BytesChunks::new(bytes, 2).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);
287
288 let bytes = Bytes::from_iter([1, 2]);
289 assert_eq!(BytesChunks::new(bytes, 3).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);
290
291 let bytes = Bytes::from_iter([1, 2, 3]);
292 assert_eq!(
293 BytesChunks::new(bytes, 1).collect::<Vec<_>>(),
294 [Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])]
295 );
296
297 let bytes = Bytes::from_iter([1, 2, 3]);
298 assert_eq!(
299 BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
300 [Bytes::from_iter([1, 2]), Bytes::from_iter([3])]
301 );
302
303 let bytes = Bytes::from_iter([1, 2, 3, 4]);
304 assert_eq!(
305 BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
306 [Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])]
307 );
308 }
309}