matrix_sdk/authentication/oauth/
http_client.rs
1use matrix_sdk_base::BoxFuture;
18use oauth2::{
19 AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError,
20};
21
22#[derive(Debug, Clone)]
24pub(super) struct OAuthHttpClient {
25 pub(super) inner: reqwest::Client,
26 #[cfg(test)]
31 pub(super) insecure_rewrite_https_to_http: bool,
32}
33
34impl<'c> AsyncHttpClient<'c> for OAuthHttpClient {
35 type Error = HttpClientError<reqwest::Error>;
36
37 type Future = BoxFuture<'c, Result<HttpResponse, Self::Error>>;
38
39 fn call(&'c self, request: HttpRequest) -> Self::Future {
40 Box::pin(async move {
41 #[cfg(test)]
42 let request = if self.insecure_rewrite_https_to_http
43 && request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
44 {
45 let mut request = request;
46
47 let mut uri_parts = request.uri().clone().into_parts();
48 uri_parts.scheme = Some(http::uri::Scheme::HTTP);
49 *request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
50 .expect("reconstructing URI from parts should work");
51
52 request
53 } else {
54 request
55 };
56
57 let response = self.inner.call(request).await?;
58
59 Ok(response)
60 })
61 }
62}
63
64pub(super) fn check_http_response_status_code<T: ErrorResponse + 'static>(
66 http_response: &HttpResponse,
67) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
68 if http_response.status().as_u16() < 400 {
69 return Ok(());
70 }
71
72 let reason = http_response.body().as_slice();
73 let error = if reason.is_empty() {
74 RequestTokenError::Other("server returned an empty error response".to_owned())
75 } else {
76 match serde_json::from_slice(reason) {
77 Ok(error) => RequestTokenError::ServerResponse(error),
78 Err(error) => RequestTokenError::Other(error.to_string()),
79 }
80 };
81
82 Err(error)
83}
84
85pub(super) fn check_http_response_json_content_type<T: ErrorResponse + 'static>(
87 http_response: &HttpResponse,
88) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
89 let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else {
90 return Ok(());
91 };
92
93 if content_type
94 .to_str()
95 .is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str()))
98 {
99 Ok(())
100 } else {
101 Err(RequestTokenError::Other(format!(
102 "unexpected response Content-Type: {content_type:?}, should be `{}`",
103 mime::APPLICATION_JSON
104 )))
105 }
106}
107
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use oauth2::{basic::BasicErrorResponse, RequestTokenError};

    use super::{check_http_response_json_content_type, check_http_response_status_code};

    /// Builds a response with the given status code, optional `Content-Type` header and body.
    fn make_response(
        status: u16,
        content_type: Option<&str>,
        body: &[u8],
    ) -> http::Response<Vec<u8>> {
        let mut builder = http::Response::builder().status(status);
        if let Some(content_type) = content_type {
            builder = builder.header(http::header::CONTENT_TYPE, content_type);
        }
        builder.body(body.to_vec()).unwrap()
    }

    #[test]
    fn test_check_http_response_status_code() {
        // A success status is accepted.
        let response = make_response(200, None, b"");
        assert_matches!(check_http_response_status_code::<BasicErrorResponse>(&response), Ok(()));

        // An error status with an empty or non-JSON body maps to `Other`.
        let unparseable_bodies: [&[u8]; 2] = [b"", b"invalid error format"];
        for body in unparseable_bodies {
            let response = make_response(404, None, body);
            assert_matches!(
                check_http_response_status_code::<BasicErrorResponse>(&response),
                Err(RequestTokenError::Other(_))
            );
        }

        // An error status with a well-formed error body maps to `ServerResponse`.
        let response = make_response(404, None, br#"{"error": "invalid_request"}"#);
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&response),
            Err(RequestTokenError::ServerResponse(_))
        );
    }

    #[test]
    fn test_check_http_response_json_content_type() {
        // JSON content types (with or without parameters) and a missing header are accepted.
        let accepted = [Some("application/json"), Some("application/json; charset=utf-8"), None];
        for content_type in accepted {
            let response = make_response(200, content_type, b"{}");
            assert_matches!(
                check_http_response_json_content_type::<BasicErrorResponse>(&response),
                Ok(())
            );
        }

        // Any other content type is rejected.
        let response = make_response(
            200,
            Some("text/html"),
            b"<html><body><h1>HTML!</h1></body></html>",
        );
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&response),
            Err(RequestTokenError::Other(_))
        );
    }
}