matrix_sdk/client/builder/
homeserver_config.rsuse ruma::{
api::{
client::discovery::{discover_homeserver, get_supported_versions},
MatrixVersion,
},
OwnedServerName, ServerName,
};
use tracing::debug;
use url::Url;
use crate::{
config::RequestConfig, http_client::HttpClient, sanitize_server_name, ClientBuildError,
HttpError,
};
#[derive(Clone, Debug)]
pub(super) enum HomeserverConfig {
HomeserverUrl(String),
ServerName { server: OwnedServerName, protocol: UrlScheme },
ServerNameOrHomeserverUrl(String),
}
#[derive(Clone, Copy, Debug)]
pub(super) enum UrlScheme {
Http,
Https,
}
pub(super) struct HomeserverDiscoveryResult {
pub server: Option<Url>,
pub homeserver: Url,
pub well_known: Option<discover_homeserver::Response>,
pub supported_versions: Option<get_supported_versions::Response>,
}
impl HomeserverConfig {
pub async fn discover(
&self,
http_client: &HttpClient,
) -> Result<HomeserverDiscoveryResult, ClientBuildError> {
Ok(match self {
Self::HomeserverUrl(url) => {
let homeserver = Url::parse(url)?;
HomeserverDiscoveryResult {
server: None, homeserver,
well_known: None,
supported_versions: None,
}
}
Self::ServerName { server, protocol } => {
let (server, well_known) =
discover_homeserver(server, protocol, http_client).await?;
HomeserverDiscoveryResult {
server: Some(server),
homeserver: Url::parse(&well_known.homeserver.base_url)?,
well_known: Some(well_known),
supported_versions: None,
}
}
Self::ServerNameOrHomeserverUrl(server_name_or_url) => {
let (server, homeserver, well_known, supported_versions) =
discover_homeserver_from_server_name_or_url(
server_name_or_url.to_owned(),
http_client,
)
.await?;
HomeserverDiscoveryResult { server, homeserver, well_known, supported_versions }
}
})
}
}
async fn discover_homeserver_from_server_name_or_url(
mut server_name_or_url: String,
http_client: &HttpClient,
) -> Result<
(
Option<Url>,
Url,
Option<discover_homeserver::Response>,
Option<get_supported_versions::Response>,
),
ClientBuildError,
> {
let mut discovery_error: Option<ClientBuildError> = None;
let sanitize_result = sanitize_server_name(&server_name_or_url);
if let Ok(server_name) = sanitize_result.as_ref() {
let protocol = if server_name_or_url.starts_with("http://") {
UrlScheme::Http
} else {
UrlScheme::Https
};
match discover_homeserver(server_name, &protocol, http_client).await {
Ok((server, well_known)) => {
return Ok((
Some(server),
Url::parse(&well_known.homeserver.base_url)?,
Some(well_known),
None,
));
}
Err(e) => {
debug!(error = %e, "Well-known discovery failed.");
discovery_error = Some(e);
server_name_or_url = match protocol {
UrlScheme::Http => format!("http://{server_name}"),
UrlScheme::Https => format!("https://{server_name}"),
}
}
}
}
if let Ok(homeserver_url) = Url::parse(&server_name_or_url) {
match get_supported_versions(&homeserver_url, http_client).await {
Ok(response) => {
return Ok((None, homeserver_url, None, Some(response)));
}
Err(e) => {
debug!(error = %e, "Checking supported versions failed.");
}
}
}
Err(discovery_error.unwrap_or(ClientBuildError::InvalidServerName))
}
async fn discover_homeserver(
server_name: &ServerName,
protocol: &UrlScheme,
http_client: &HttpClient,
) -> Result<(Url, discover_homeserver::Response), ClientBuildError> {
debug!("Trying to discover the homeserver");
let server = Url::parse(&match protocol {
UrlScheme::Http => format!("http://{server_name}"),
UrlScheme::Https => format!("https://{server_name}"),
})?;
let well_known = http_client
.send(
discover_homeserver::Request::new(),
Some(RequestConfig::short_retry()),
server.to_string(),
None,
&[MatrixVersion::V1_0],
Default::default(),
)
.await
.map_err(|e| match e {
HttpError::Api(err) => ClientBuildError::AutoDiscovery(err),
err => ClientBuildError::Http(err),
})?;
debug!(homeserver_url = well_known.homeserver.base_url, "Discovered the homeserver");
Ok((server, well_known))
}
pub(super) async fn get_supported_versions(
homeserver_url: &Url,
http_client: &HttpClient,
) -> Result<get_supported_versions::Response, HttpError> {
http_client
.send(
get_supported_versions::Request::new(),
Some(RequestConfig::short_retry()),
homeserver_url.to_string(),
None,
&[MatrixVersion::V1_0],
Default::default(),
)
.await
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use matrix_sdk_test::async_test;
use ruma::OwnedServerName;
use serde_json::json;
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};
use super::*;
use crate::http_client::HttpSettings;
#[async_test]
async fn test_url() {
let http_client =
HttpClient::new(HttpSettings::default().make_client().unwrap(), Default::default());
let result = HomeserverConfig::HomeserverUrl("https://matrix-client.matrix.org".to_owned())
.discover(&http_client)
.await
.unwrap();
assert_eq!(result.server, None);
assert_eq!(result.homeserver, Url::parse("https://matrix-client.matrix.org").unwrap());
assert!(result.well_known.is_none());
assert!(result.supported_versions.is_none());
}
#[async_test]
async fn test_server_name() {
let http_client =
HttpClient::new(HttpSettings::default().make_client().unwrap(), Default::default());
let server = MockServer::start().await;
let homeserver = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/.well-known/matrix/client"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"m.homeserver": {
"base_url": homeserver.uri(),
},
})))
.mount(&server)
.await;
let result = HomeserverConfig::ServerName {
server: OwnedServerName::try_from(server.address().to_string()).unwrap(),
protocol: UrlScheme::Http,
}
.discover(&http_client)
.await
.unwrap();
assert_eq!(result.server, Some(Url::parse(&server.uri()).unwrap()));
assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
assert!(result.well_known.is_some());
assert!(result.supported_versions.is_none());
}
#[async_test]
async fn test_server_name_or_url_with_name() {
let http_client =
HttpClient::new(HttpSettings::default().make_client().unwrap(), Default::default());
let server = MockServer::start().await;
let homeserver = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/.well-known/matrix/client"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"m.homeserver": {
"base_url": homeserver.uri(),
},
})))
.mount(&server)
.await;
let result = HomeserverConfig::ServerNameOrHomeserverUrl(server.uri().to_string())
.discover(&http_client)
.await
.unwrap();
assert_eq!(result.server, Some(Url::parse(&server.uri()).unwrap()));
assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
assert!(result.well_known.is_some());
assert!(result.supported_versions.is_none());
}
#[async_test]
async fn test_server_name_or_url_with_url() {
let http_client =
HttpClient::new(HttpSettings::default().make_client().unwrap(), Default::default());
let homeserver = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"versions": [],
})))
.mount(&homeserver)
.await;
let result = HomeserverConfig::ServerNameOrHomeserverUrl(homeserver.uri().to_string())
.discover(&http_client)
.await
.unwrap();
assert!(result.server.is_none());
assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
assert!(result.well_known.is_none());
assert!(result.supported_versions.is_some());
}
}