use axum::{extract::State, response::IntoResponse, Json};
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{Policy, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::{
ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
},
};
use psl::Psl;
use rand::distributions::{Alphanumeric, DistString};
use thiserror::Error;
use tracing::info;
use url::Url;
use crate::impl_from_error_for_route;
/// Errors that the client registration handler can return.
///
/// Each variant is turned into an HTTP response by the `IntoResponse`
/// implementation for this type.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
/// An opaque, boxed error; answered with `500 Internal Server Error` and
/// `ClientErrorCode::ServerError`.
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync>),
/// The request body was rejected by the JSON extractor.
#[error(transparent)]
JsonExtract(#[from] axum::extract::rejection::JsonRejection),
/// The submitted client metadata failed verification.
#[error("invalid client metadata")]
InvalidClientMetadata(#[from] ClientMetadataVerificationError),
/// A URL in the metadata has a host that is only a public suffix. The
/// payload is the name of the offending metadata field.
#[error("{0} is a public suffix, not a valid domain")]
UrlIsPublicSuffix(&'static str),
/// The policy evaluation reported the registration as invalid; carries the
/// violations it returned.
#[error("denied by the policy: {0:?}")]
PolicyDenied(Vec<Violation>),
}
// Conversions from lower-level error types into `RouteError`; the handler
// below relies on them when applying `?` to repository, policy and encryption
// calls.
// NOTE(review): presumably each of these maps to `RouteError::Internal` —
// confirm against the definition of `impl_from_error_for_route!`.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let response = match self {
Self::Internal(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
)
.into_response(),
Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidClientMetadata)
.with_description(e.to_string()),
),
)
.into_response(),
Self::JsonExtract(_) => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
)
.into_response(),
Self::InvalidClientMetadata(
ClientMetadataVerificationError::MissingRedirectUris
| ClientMetadataVerificationError::RedirectUriWithFragment(_),
) => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
)
.into_response(),
Self::InvalidClientMetadata(e) => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidClientMetadata)
.with_description(e.to_string()),
),
)
.into_response(),
Self::UrlIsPublicSuffix("redirect_uri") => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidRedirectUri)
.with_description("redirect_uri is not using a valid domain".to_owned()),
),
)
.into_response(),
Self::UrlIsPublicSuffix(field) => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidClientMetadata)
.with_description(format!("{field} is not using a valid domain")),
),
)
.into_response(),
Self::PolicyDenied(violations) => {
let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
ClientErrorCode::InvalidRedirectUri
} else {
ClientErrorCode::InvalidClientMetadata
};
let collected = &violations
.iter()
.map(|v| v.msg.clone())
.collect::<Vec<String>>();
let joined = collected.join("; ");
(
StatusCode::BAD_REQUEST,
Json(ClientError::from(code).with_description(joined)),
)
.into_response()
}
};
(SentryEventID::from(event_id), response).into_response()
}
}
/// Tells whether the host of `url` is nothing more than a known public
/// suffix.
///
/// A URL without a host is handled as if its host were empty. Returns `false`
/// when no suffix can be found for the host, or when the suffix found is not
/// a known one.
fn host_is_public_suffix(url: &Url) -> bool {
    let host_bytes = url.host_str().unwrap_or_default().as_bytes();

    match psl::List.suffix(host_bytes) {
        // The host is flagged when it is at most one byte longer than its
        // suffix: either the bare suffix (`com`) or the suffix with a single
        // leading byte in front of it (`.com`).
        Some(public_suffix) if public_suffix.is_known() => {
            host_bytes.len() <= public_suffix.as_bytes().len() + 1
        }
        _ => false,
    }
}
/// Tells whether any of the variants of a localized URL has a host which is
/// only a public suffix, as determined by `host_is_public_suffix`.
fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
    for (_language, variant) in url.iter() {
        if host_is_public_suffix(variant) {
            return true;
        }
    }

    false
}
#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
mut policy: Policy,
State(encrypter): State<Encrypter>,
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
) -> Result<impl IntoResponse, RouteError> {
let Json(body) = body?;
info!(?body, "Client registration");
let metadata = body.validate()?;
if let Some(client_uri) = &metadata.client_uri {
if localised_url_has_public_suffix(client_uri) {
return Err(RouteError::UrlIsPublicSuffix("client_uri"));
}
}
if let Some(logo_uri) = &metadata.logo_uri {
if localised_url_has_public_suffix(logo_uri) {
return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
}
}
if let Some(policy_uri) = &metadata.policy_uri {
if localised_url_has_public_suffix(policy_uri) {
return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
}
}
if let Some(tos_uri) = &metadata.tos_uri {
if localised_url_has_public_suffix(tos_uri) {
return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
}
}
if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
if host_is_public_suffix(initiate_login_uri) {
return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
}
}
for redirect_uri in metadata.redirect_uris() {
if host_is_public_suffix(redirect_uri) {
return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
}
}
let res = policy.evaluate_client_registration(&metadata).await?;
if !res.valid() {
return Err(RouteError::PolicyDenied(res.violations));
}
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
Some(
OAuthClientAuthenticationMethod::ClientSecretJwt
| OAuthClientAuthenticationMethod::ClientSecretPost
| OAuthClientAuthenticationMethod::ClientSecretBasic,
) => {
let client_secret = Alphanumeric.sample_string(&mut rng, 20);
let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
(Some(client_secret), Some(encrypted_client_secret))
}
_ => (None, None),
};
let client = repo
.oauth2_client()
.add(
&mut rng,
&clock,
metadata.redirect_uris().to_vec(),
encrypted_client_secret,
metadata.application_type.clone(),
metadata.grant_types().to_vec(),
metadata.contacts.clone().unwrap_or_default(),
metadata
.client_name
.clone()
.map(Localized::to_non_localized),
metadata.logo_uri.clone().map(Localized::to_non_localized),
metadata.client_uri.clone().map(Localized::to_non_localized),
metadata.policy_uri.clone().map(Localized::to_non_localized),
metadata.tos_uri.clone().map(Localized::to_non_localized),
metadata.jwks_uri.clone(),
metadata.jwks.clone(),
metadata.id_token_signed_response_alg.clone(),
metadata.userinfo_signed_response_alg.clone(),
metadata.token_endpoint_auth_method.clone(),
metadata.token_endpoint_auth_signing_alg.clone(),
metadata.initiate_login_uri.clone(),
)
.await?;
repo.save().await?;
let response = ClientRegistrationResponse {
client_id: client.client_id,
client_secret,
client_id_issued_at: Some(client.id.datetime().into()),
client_secret_expires_at: None,
};
Ok((StatusCode::CREATED, Json(response)))
}
// Tests for the public suffix helper and for the registration endpoint.
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use mas_router::SimpleRoute;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::ClientRegistrationResponse,
};
use sqlx::PgPool;
use url::Url;
use crate::{
oauth2::registration::host_is_public_suffix,
test_utils::{setup, RequestBuilderExt, ResponseExt, TestState},
};
/// Checks which hosts `host_is_public_suffix` flags.
#[test]
fn test_public_suffix_list() {
// Parses the URL (panicking on invalid input) and runs the check on it
fn url_is_public_suffix(url: &str) -> bool {
host_is_public_suffix(&Url::parse(url).unwrap())
}
// Hosts which are only a public suffix, with or without a leading or
// trailing dot
assert!(url_is_public_suffix("https://.com"));
assert!(url_is_public_suffix("https://.com."));
assert!(url_is_public_suffix("https://co.uk"));
assert!(url_is_public_suffix("https://github.io"));
// Domains registered under a public suffix must not be flagged, even
// with a one-character label
assert!(!url_is_public_suffix("https://example.com"));
assert!(!url_is_public_suffix("https://example.com."));
assert!(!url_is_public_suffix("https://x.com"));
assert!(!url_is_public_suffix("https://x.com."));
assert!(!url_is_public_suffix("https://matrix-org.github.io"));
// Single-label internal hostnames and custom-scheme URIs must not be
// flagged either
assert!(!url_is_public_suffix("http://localhost"));
assert!(!url_is_public_suffix("org.matrix:/callback"));
assert!(!url_is_public_suffix("http://somerandominternaldomain"));
}
/// Checks the error codes returned for various invalid registration
/// requests.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_registration_error(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
// The body is not JSON
let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
.body("this is not a json".to_owned())
.unwrap();
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::InvalidRequest);
// The body is JSON, but `client_uri` is not a valid URI
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "this is not a uri",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
// A `web` application with an `http://` redirect URI is expected to be
// rejected with the redirect URI error code
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"application_type": "web",
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["http://this-is-insecure.com/"],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
// The `id_token` response type combined with only the
// `authorization_code` grant type is expected to be rejected as invalid
// metadata
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["id_token"],
"grant_types": ["authorization_code"],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
// URLs using a public suffix as host: both `client_uri` and the redirect
// URI are affected, but `client_uri` is checked first by the handler, so
// it is the one being reported
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://github.io/",
"redirect_uris": ["https://github.io/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "client_secret_basic",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
assert_eq!(
response.error_description.unwrap(),
"client_uri is not using a valid domain"
);
// A public suffix used only in a localized variant of `client_uri` must
// be rejected as well
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"client_uri#fr-FR": "https://github.io/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "client_secret_basic",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
assert_eq!(
response.error_description.unwrap(),
"client_uri is not using a valid domain"
);
}
/// Checks that valid registrations succeed, and that a client secret is
/// only returned when the authentication method calls for one.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_registration(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
// With the `none` authentication method, no client secret must be
// returned
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "none",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
assert!(response.client_secret.is_none());
// With the `client_secret_basic` authentication method, a client secret
// must be returned
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "client_secret_basic",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
assert!(response.client_secret.is_some());
}
}