use axum::{extract::State, response::IntoResponse, Json};
use hyper::StatusCode;
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
http_client_factory::HttpClientFactory,
sentry::SentryEventID,
};
use mas_data_model::TokenType;
use mas_iana::oauth::OAuthTokenTypeHint;
use mas_keystore::Encrypter;
use mas_storage::{
job::{JobRepositoryExt, SyncDevicesJob},
BoxClock, BoxRepository, RepositoryAccess,
};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
requests::RevocationRequest,
};
use thiserror::Error;
use crate::{impl_from_error_for_route, BoundActivityTracker};
/// Errors that can occur while handling a token revocation request.
///
/// Every variant is converted into an HTTP response by the `IntoResponse`
/// implementation for this type.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    // ----- Request-level errors -----
    /// The request did not carry a revocation form.
    #[error("bad request")]
    BadRequest,

    // ----- Client authentication / authorization errors -----
    /// No client matches the credentials presented with the request.
    #[error("client not found")]
    ClientNotFound,

    /// The client has no token endpoint authentication method configured.
    #[error("client not allowed")]
    ClientNotAllowed,

    /// The client presented credentials, but they could not be verified.
    #[error("could not verify client credentials")]
    ClientCredentialsVerification(#[from] CredentialsVerificationError),

    /// The token belongs to a session created by a different client.
    #[error("client is unauthorized")]
    UnauthorizedClient,

    // ----- Token errors -----
    /// The `token_type_hint` is neither an access token nor a refresh token hint.
    #[error("unsupported token type")]
    UnsupportedTokenType,

    /// The token could not be resolved to a valid session it can revoke.
    #[error("unknown token")]
    UnknownToken,

    // ----- Server errors -----
    /// Wraps an underlying server-side failure (for example a repository error).
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl IntoResponse for RouteError {
    /// Render a [`RouteError`] as an HTTP response with an OAuth 2.0 error body.
    ///
    /// Mapping of variants to status and `ClientErrorCode`:
    /// - `Internal` → `500`, `ServerError`
    /// - `BadRequest` → `400`, `InvalidRequest`
    /// - `ClientNotFound`, `ClientCredentialsVerification` → `401`, `InvalidClient`
    /// - `ClientNotAllowed`, `UnauthorizedClient` → `401`, `UnauthorizedClient`
    /// - `UnsupportedTokenType` → `400`, `UnsupportedTokenType`
    /// - `UnknownToken` → empty `200 OK` (RFC 7009: revoking an invalid or
    ///   already-revoked token is not an error for the client)
    ///
    /// Only `Internal` errors are reported to Sentry, and only those responses
    /// carry a Sentry event ID.
    fn into_response(self) -> axum::response::Response {
        // Every other variant is an expected outcome of processing client input
        // (bad credentials, unsupported hint, already-revoked token, ...).
        // Reporting those would flood Sentry with noise. `UnknownToken` is even a
        // successful 200 response, so it should not produce an error event.
        let event_id = matches!(self, Self::Internal(_)).then(|| sentry::capture_error(&self));

        let response = match self {
            Self::Internal(_) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ClientError::from(ClientErrorCode::ServerError)),
            )
                .into_response(),

            Self::BadRequest => (
                StatusCode::BAD_REQUEST,
                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
            )
                .into_response(),

            Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
                StatusCode::UNAUTHORIZED,
                Json(ClientError::from(ClientErrorCode::InvalidClient)),
            )
                .into_response(),

            Self::ClientNotAllowed | Self::UnauthorizedClient => (
                StatusCode::UNAUTHORIZED,
                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
            )
                .into_response(),

            Self::UnsupportedTokenType => (
                StatusCode::BAD_REQUEST,
                Json(ClientError::from(ClientErrorCode::UnsupportedTokenType)),
            )
                .into_response(),

            // An unknown token is treated as already revoked: plain 200 OK.
            Self::UnknownToken => StatusCode::OK.into_response(),
        };

        match event_id {
            Some(event_id) => (SentryEventID::from(event_id), response).into_response(),
            None => response,
        }
    }
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<mas_data_model::TokenFormatError> for RouteError {
fn from(_e: mas_data_model::TokenFormatError) -> Self {
Self::UnknownToken
}
}
/// Handle a POST to the OAuth 2.0 token revocation endpoint.
///
/// Steps:
/// 1. authenticate the client from the request credentials;
/// 2. resolve the submitted token, guided by the optional `token_type_hint`,
///    to the OAuth 2.0 session it was issued for;
/// 3. check that the session is still valid and belongs to the calling client;
/// 4. if the session has a user, schedule a device sync job for that user;
/// 5. finish the session and persist the changes with `repo.save()`.
///
/// # Errors
///
/// A token that is malformed, not found, no longer valid, of a type that
/// doesn't match the hint, or whose session is missing or no longer valid
/// yields [`RouteError::UnknownToken`]. That variant is rendered as an empty
/// `200 OK`. Client authentication failures, a missing form, an unsupported
/// hint, or a session owned by another client produce the matching
/// [`RouteError`] variants. Repository failures are propagated with `?`.
#[tracing::instrument(
    name = "handlers.oauth2.revoke.post",
    fields(client.id = client_authorization.client_id()),
    skip_all,
    err,
)]
pub(crate) async fn post(
    clock: BoxClock,
    State(http_client_factory): State<HttpClientFactory>,
    mut repo: BoxRepository,
    activity_tracker: BoundActivityTracker,
    State(encrypter): State<Encrypter>,
    client_authorization: ClientAuthorization<RevocationRequest>,
) -> Result<impl IntoResponse, RouteError> {
    // --- Client authentication ---
    let Some(client) = client_authorization.credentials.fetch(&mut repo).await? else {
        return Err(RouteError::ClientNotFound);
    };

    let Some(auth_method) = client.token_endpoint_auth_method.as_ref() else {
        return Err(RouteError::ClientNotAllowed);
    };

    client_authorization
        .credentials
        .verify(&http_client_factory, &encrypter, auth_method, &client)
        .await?;

    let form = client_authorization.form.ok_or(RouteError::BadRequest)?;

    // --- Token → session resolution ---
    // A malformed token converts into `UnknownToken` via `From<TokenFormatError>`.
    let detected_kind = TokenType::check(&form.token)?;

    let session_id = match (form.token_type_hint, detected_kind) {
        (None | Some(OAuthTokenTypeHint::AccessToken), TokenType::AccessToken) => {
            repo.oauth2_access_token()
                .find_by_token(&form.token)
                .await?
                // A token that exists but is no longer valid counts as unknown.
                .filter(|access_token| access_token.is_valid(clock.now()))
                .ok_or(RouteError::UnknownToken)?
                .session_id
        }

        (None | Some(OAuthTokenTypeHint::RefreshToken), TokenType::RefreshToken) => {
            repo.oauth2_refresh_token()
                .find_by_token(&form.token)
                .await?
                .filter(|refresh_token| refresh_token.is_valid())
                .ok_or(RouteError::UnknownToken)?
                .session_id
        }

        // The hint disagrees with the detected token type, or the token is of
        // a kind this endpoint does not handle: report it as unknown.
        (None | Some(OAuthTokenTypeHint::AccessToken | OAuthTokenTypeHint::RefreshToken), _) => {
            return Err(RouteError::UnknownToken)
        }

        // Any other hint value is not supported by this endpoint.
        (Some(_), _) => return Err(RouteError::UnsupportedTokenType),
    };

    let session = repo
        .oauth2_session()
        .lookup(session_id)
        .await?
        .filter(|session| session.is_valid())
        .ok_or(RouteError::UnknownToken)?;

    // Only the client that owns the session may end it.
    if session.client_id != client.id {
        return Err(RouteError::UnauthorizedClient);
    }

    activity_tracker
        .record_oauth2_session(&clock, &session)
        .await;

    // User-bound sessions: schedule a device sync for the owning user.
    if let Some(user_id) = session.user_id {
        let owner = repo
            .user()
            .lookup(user_id)
            .await?
            .ok_or(RouteError::UnknownToken)?;

        repo.job().schedule_job(SyncDevicesJob::new(&owner)).await?;
    }

    // --- Revocation ---
    repo.oauth2_session().finish(&clock, session).await?;
    repo.save().await?;

    Ok(())
}
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use hyper::Request;
    use mas_data_model::{AccessToken, RefreshToken};
    use mas_router::SimpleRoute;
    use mas_storage::RepositoryAccess;
    use oauth2_types::{
        registration::ClientRegistrationResponse,
        requests::AccessTokenResponse,
        scope::{Scope, OPENID},
    };
    use sqlx::PgPool;

    use super::*;
    use crate::{
        oauth2::generate_token_pair,
        test_utils::{setup, RequestBuilderExt, ResponseExt, TestState},
    };

    /// End-to-end scenario for the revocation endpoint.
    ///
    /// Covers these cases:
    /// - revoking an access token ends its session;
    /// - revoking the same token again still answers 200;
    /// - the session's refresh token then stops working;
    /// - revoking superseded (already refreshed) tokens leaves the current
    ///   session alone;
    /// - revoking the current refresh token ends the session.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_revoke_access_token(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a client that authenticates with `client_secret_post`.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/callback"],
                "contacts": ["contact@example.com"],
                "token_endpoint_auth_method": "client_secret_post",
                "response_types": ["code"],
                "grant_types": ["authorization_code", "refresh_token"],
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let client_registration: ClientRegistrationResponse = response.json();
        let client_id = client_registration.client_id;
        let client_secret = client_registration.client_secret.unwrap();

        // Create a user, a browser session for it, and an OAuth 2.0 session
        // for the registered client on top of that browser session.
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();
        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        // Issue an access/refresh token pair for that session. The duration
        // is 5 minutes, written in microseconds.
        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();
        repo.save().await.unwrap();

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking the access token succeeds and invalidates it.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        assert!(!state.is_access_token_valid(&access_token).await);

        // Revoking the same token again still answers 200 OK: an already
        // revoked token takes the `UnknownToken` path, which renders as OK.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        // The whole session was ended, so its refresh token is rejected by the
        // token endpoint.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);

        // Second scenario: start a fresh session with a new token pair.
        let mut repo = state.repository().await.unwrap();
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();
        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();
        repo.save().await.unwrap();

        // Use the refresh token to obtain a new token pair.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        // Keep the superseded tokens around and shadow the names with the new pair.
        let old_access_token = access_token;
        let old_refresh_token = refresh_token;
        let AccessTokenResponse {
            access_token,
            refresh_token,
            ..
        } = response.json();
        assert!(state.is_access_token_valid(&access_token).await);
        assert!(!state.is_access_token_valid(&old_access_token).await);

        // Revoking the superseded access token answers OK but must not end
        // the session: the new access token stays valid.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": old_access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        assert!(state.is_access_token_valid(&access_token).await);

        // Same for the superseded refresh token. It is presumably no longer
        // valid after being used, so the request is a no-op that answers OK.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": old_refresh_token,
            "token_type_hint": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking the current refresh token ends the session, which
        // invalidates the current access token as well.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": refresh_token,
            "token_type_hint": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        assert!(!state.is_access_token_valid(&access_token).await);
    }
}