matrix_sdk/client/builder/homeserver_config.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ruma::{
16    api::{
17        client::discovery::{discover_homeserver, get_supported_versions},
18        MatrixVersion,
19    },
20    OwnedServerName, ServerName,
21};
22use tracing::debug;
23use url::Url;
24
25use crate::{
26    config::RequestConfig, http_client::HttpClient, sanitize_server_name, ClientBuildError,
27    HttpError,
28};
29
30/// Configuration for the homeserver.
31#[derive(Clone, Debug)]
32pub(super) enum HomeserverConfig {
33    /// A homeserver name URL, including the protocol.
34    HomeserverUrl(String),
35
36    /// A server name, with the protocol put apart.
37    ServerName { server: OwnedServerName, protocol: UrlScheme },
38
39    /// A server name with or without the protocol (it will fallback to `https`
40    /// if absent), or a homeserver URL.
41    ServerNameOrHomeserverUrl(String),
42}
43
44/// A simple helper to represent `http` or `https` in a URL.
45#[derive(Clone, Copy, Debug)]
46pub(super) enum UrlScheme {
47    Http,
48    Https,
49}
50
51/// The `Ok` result for `HomeserverConfig::discover`.
52pub(super) struct HomeserverDiscoveryResult {
53    pub server: Option<Url>,
54    pub homeserver: Url,
55    pub supported_versions: Option<get_supported_versions::Response>,
56}
57
58impl HomeserverConfig {
59    pub async fn discover(
60        &self,
61        http_client: &HttpClient,
62    ) -> Result<HomeserverDiscoveryResult, ClientBuildError> {
63        Ok(match self {
64            Self::HomeserverUrl(url) => {
65                let homeserver = Url::parse(url)?;
66
67                HomeserverDiscoveryResult {
68                    server: None, // We can't know the `server` if we only have a `homeserver`.
69                    homeserver,
70                    supported_versions: None,
71                }
72            }
73
74            Self::ServerName { server, protocol } => {
75                let (server, well_known) =
76                    discover_homeserver(server, protocol, http_client).await?;
77
78                HomeserverDiscoveryResult {
79                    server: Some(server),
80                    homeserver: Url::parse(&well_known.homeserver.base_url)?,
81                    supported_versions: None,
82                }
83            }
84
85            Self::ServerNameOrHomeserverUrl(server_name_or_url) => {
86                let (server, homeserver, supported_versions) =
87                    discover_homeserver_from_server_name_or_url(
88                        server_name_or_url.to_owned(),
89                        http_client,
90                    )
91                    .await?;
92
93                HomeserverDiscoveryResult { server, homeserver, supported_versions }
94            }
95        })
96    }
97}
98
99/// Discovers a homeserver from a server name or a URL.
100///
101/// Tries well-known discovery and checking if the URL points to a homeserver.
102async fn discover_homeserver_from_server_name_or_url(
103    mut server_name_or_url: String,
104    http_client: &HttpClient,
105) -> Result<(Option<Url>, Url, Option<get_supported_versions::Response>), ClientBuildError> {
106    let mut discovery_error: Option<ClientBuildError> = None;
107
108    // Attempt discovery as a server name first.
109    let sanitize_result = sanitize_server_name(&server_name_or_url);
110
111    if let Ok(server_name) = sanitize_result.as_ref() {
112        let protocol = if server_name_or_url.starts_with("http://") {
113            UrlScheme::Http
114        } else {
115            UrlScheme::Https
116        };
117
118        match discover_homeserver(server_name, &protocol, http_client).await {
119            Ok((server, well_known)) => {
120                return Ok((Some(server), Url::parse(&well_known.homeserver.base_url)?, None));
121            }
122            Err(e) => {
123                debug!(error = %e, "Well-known discovery failed.");
124                discovery_error = Some(e);
125
126                // Check if the server name points to a homeserver.
127                server_name_or_url = match protocol {
128                    UrlScheme::Http => format!("http://{server_name}"),
129                    UrlScheme::Https => format!("https://{server_name}"),
130                }
131            }
132        }
133    }
134
135    // When discovery fails, or the input isn't a valid server name, fallback to
136    // trying a homeserver URL.
137    if let Ok(homeserver_url) = Url::parse(&server_name_or_url) {
138        // Make sure the URL is definitely for a homeserver.
139        match get_supported_versions(&homeserver_url, http_client).await {
140            Ok(response) => {
141                return Ok((None, homeserver_url, Some(response)));
142            }
143            Err(e) => {
144                debug!(error = %e, "Checking supported versions failed.");
145            }
146        }
147    }
148
149    Err(discovery_error.unwrap_or(ClientBuildError::InvalidServerName))
150}
151
152/// Discovers a homeserver by looking up the well-known at the supplied server
153/// name.
154async fn discover_homeserver(
155    server_name: &ServerName,
156    protocol: &UrlScheme,
157    http_client: &HttpClient,
158) -> Result<(Url, discover_homeserver::Response), ClientBuildError> {
159    debug!("Trying to discover the homeserver");
160
161    let server = Url::parse(&match protocol {
162        UrlScheme::Http => format!("http://{server_name}"),
163        UrlScheme::Https => format!("https://{server_name}"),
164    })?;
165
166    let well_known = http_client
167        .send(
168            discover_homeserver::Request::new(),
169            Some(RequestConfig::short_retry()),
170            server.to_string(),
171            None,
172            &[MatrixVersion::V1_0],
173            Default::default(),
174        )
175        .await
176        .map_err(|e| match e {
177            HttpError::Api(err) => ClientBuildError::AutoDiscovery(err),
178            err => ClientBuildError::Http(err),
179        })?;
180
181    debug!(homeserver_url = well_known.homeserver.base_url, "Discovered the homeserver");
182
183    Ok((server, well_known))
184}
185
186pub(super) async fn get_supported_versions(
187    homeserver_url: &Url,
188    http_client: &HttpClient,
189) -> Result<get_supported_versions::Response, HttpError> {
190    http_client
191        .send(
192            get_supported_versions::Request::new(),
193            Some(RequestConfig::short_retry()),
194            homeserver_url.to_string(),
195            None,
196            &[MatrixVersion::V1_0],
197            Default::default(),
198        )
199        .await
200}
201
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use matrix_sdk_test::async_test;
    use ruma::OwnedServerName;
    use serde_json::json;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::http_client::HttpSettings;

    /// Builds an `HttpClient` with default settings, shared by every test below.
    fn make_http_client() -> HttpClient {
        HttpClient::new(HttpSettings::default().make_client().unwrap(), Default::default())
    }

    #[async_test]
    async fn test_url() {
        let http_client = make_http_client();

        let result = HomeserverConfig::HomeserverUrl("https://matrix-client.matrix.org".to_owned())
            .discover(&http_client)
            .await
            .unwrap();

        assert_eq!(result.server, None);
        assert_eq!(result.homeserver, Url::parse("https://matrix-client.matrix.org").unwrap());
        assert!(result.supported_versions.is_none());
    }

    #[async_test]
    async fn test_server_name() {
        let http_client = make_http_client();

        let server = MockServer::start().await;
        let homeserver = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "m.homeserver": {
                    "base_url": homeserver.uri(),
                },
            })))
            .mount(&server)
            .await;

        let result = HomeserverConfig::ServerName {
            server: OwnedServerName::try_from(server.address().to_string()).unwrap(),
            protocol: UrlScheme::Http,
        }
        .discover(&http_client)
        .await
        .unwrap();

        assert_eq!(result.server, Some(Url::parse(&server.uri()).unwrap()));
        assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
        assert!(result.supported_versions.is_none());
    }

    #[async_test]
    async fn test_server_name_or_url_with_name() {
        let http_client = make_http_client();

        let server = MockServer::start().await;
        let homeserver = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "m.homeserver": {
                    "base_url": homeserver.uri(),
                },
            })))
            .mount(&server)
            .await;

        let result = HomeserverConfig::ServerNameOrHomeserverUrl(server.uri())
            .discover(&http_client)
            .await
            .unwrap();

        assert_eq!(result.server, Some(Url::parse(&server.uri()).unwrap()));
        assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
        assert!(result.supported_versions.is_none());
    }

    #[async_test]
    async fn test_server_name_or_url_with_url() {
        let http_client = make_http_client();

        let homeserver = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "versions": [],
            })))
            .mount(&homeserver)
            .await;

        let result = HomeserverConfig::ServerNameOrHomeserverUrl(homeserver.uri())
            .discover(&http_client)
            .await
            .unwrap();

        assert!(result.server.is_none());
        assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
        assert!(result.supported_versions.is_some());
    }

    #[async_test]
    async fn test_server_name_or_url_with_invalid_input() {
        let http_client = make_http_client();

        // An empty string is neither a valid server name nor a parseable URL.
        // No request is sent, and the dedicated error is returned.
        let result =
            HomeserverConfig::ServerNameOrHomeserverUrl(String::new()).discover(&http_client).await;

        assert!(matches!(result, Err(ClientBuildError::InvalidServerName)));
    }

    #[async_test]
    async fn test_server_name_or_url_without_homeserver() {
        let http_client = make_http_client();

        // Nothing is mounted, so both the well-known lookup and the
        // supported-versions probe are rejected, and discovery must fail.
        let server = MockServer::start().await;

        let result = HomeserverConfig::ServerNameOrHomeserverUrl(server.uri())
            .discover(&http_client)
            .await;

        assert!(result.is_err());
    }
}