matrix_sdk/authentication/oauth/http_client.rs

1// Copyright 2025 Kévin Commaille
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! HTTP client and helpers for making OAuth 2.0 requests.
16
17use matrix_sdk_base::BoxFuture;
18use oauth2::{
19    AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError,
20};
21
22/// An HTTP client for making OAuth 2.0 requests.
23#[derive(Debug, Clone)]
24pub(super) struct OAuthHttpClient {
25    pub(super) inner: reqwest::Client,
26    /// Rewrite HTTPS requests to use HTTP instead.
27    ///
28    /// This is a workaround to bypass some checks that require an HTTPS URL,
29    /// but we can only mock HTTP URLs.
30    #[cfg(test)]
31    pub(super) insecure_rewrite_https_to_http: bool,
32}
33
34impl<'c> AsyncHttpClient<'c> for OAuthHttpClient {
35    type Error = HttpClientError<reqwest::Error>;
36
37    type Future = BoxFuture<'c, Result<HttpResponse, Self::Error>>;
38
39    fn call(&'c self, request: HttpRequest) -> Self::Future {
40        Box::pin(async move {
41            #[cfg(test)]
42            let request = if self.insecure_rewrite_https_to_http
43                && request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
44            {
45                let mut request = request;
46
47                let mut uri_parts = request.uri().clone().into_parts();
48                uri_parts.scheme = Some(http::uri::Scheme::HTTP);
49                *request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
50                    .expect("reconstructing URI from parts should work");
51
52                request
53            } else {
54                request
55            };
56
57            let response = self.inner.call(request).await?;
58
59            Ok(response)
60        })
61    }
62}
63
64/// Check the status code of the given HTTP response to identify errors.
65pub(super) fn check_http_response_status_code<T: ErrorResponse + 'static>(
66    http_response: &HttpResponse,
67) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
68    if http_response.status().as_u16() < 400 {
69        return Ok(());
70    }
71
72    let reason = http_response.body().as_slice();
73    let error = if reason.is_empty() {
74        RequestTokenError::Other("server returned an empty error response".to_owned())
75    } else {
76        match serde_json::from_slice(reason) {
77            Ok(error) => RequestTokenError::ServerResponse(error),
78            Err(error) => RequestTokenError::Other(error.to_string()),
79        }
80    };
81
82    Err(error)
83}
84
85/// Check that the server returned a response with a JSON `Content-Type`.
86pub(super) fn check_http_response_json_content_type<T: ErrorResponse + 'static>(
87    http_response: &HttpResponse,
88) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
89    let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else {
90        return Ok(());
91    };
92
93    if content_type
94        .to_str()
95        // Check only the beginning of the content type, because there might be extra
96        // parameters, like a charset.
97        .is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str()))
98    {
99        Ok(())
100    } else {
101        Err(RequestTokenError::Other(format!(
102            "unexpected response Content-Type: {content_type:?}, should be `{}`",
103            mime::APPLICATION_JSON
104        )))
105    }
106}
107
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use oauth2::{basic::BasicErrorResponse, RequestTokenError};

    use super::{check_http_response_json_content_type, check_http_response_status_code};

    /// Build an HTTP response with the given status code, optional
    /// `Content-Type` header, and body.
    fn make_response(
        status: u16,
        content_type: Option<&str>,
        body: &[u8],
    ) -> http::Response<Vec<u8>> {
        let builder = http::Response::builder().status(status);
        let builder = match content_type {
            Some(value) => builder.header(http::header::CONTENT_TYPE, value),
            None => builder,
        };
        builder.body(body.to_vec()).unwrap()
    }

    #[test]
    fn test_check_http_response_status_code() {
        let check = check_http_response_status_code::<BasicErrorResponse>;

        // A 200 status with an empty body is accepted.
        assert_matches!(check(&make_response(200, None, b"")), Ok(()));

        // An error status with an empty body is reported as `Other`.
        assert_matches!(check(&make_response(404, None, b"")), Err(RequestTokenError::Other(_)));

        // An error status whose body is not a valid error response is `Other`.
        assert_matches!(
            check(&make_response(404, None, b"invalid error format")),
            Err(RequestTokenError::Other(_))
        );

        // An error status with a well-formed error body is `ServerResponse`.
        assert_matches!(
            check(&make_response(404, None, br#"{"error": "invalid_request"}"#)),
            Err(RequestTokenError::ServerResponse(_))
        );
    }

    #[test]
    fn test_check_http_response_json_content_type() {
        let check = check_http_response_json_content_type::<BasicErrorResponse>;

        // Plain JSON content type.
        assert_matches!(check(&make_response(200, Some("application/json"), b"{}")), Ok(()));

        // JSON content type with a charset parameter.
        assert_matches!(
            check(&make_response(200, Some("application/json; charset=utf-8"), b"{}")),
            Ok(())
        );

        // Missing content type.
        assert_matches!(check(&make_response(200, None, b"{}")), Ok(()));

        // Non-JSON content type.
        assert_matches!(
            check(&make_response(
                200,
                Some("text/html"),
                b"<html><body><h1>HTML!</h1></body></html>",
            )),
            Err(RequestTokenError::Other(_))
        );
    }
}