use std::convert::Infallible;
use aide::OperationIo;
use axum::{
extract::FromRequestParts,
response::{IntoResponse, Response},
Json,
};
use axum_extra::TypedHeader;
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::{Session, User};
use mas_storage::{BoxClock, BoxRepository, RepositoryError};
use ulid::Ulid;
use super::response::ErrorResponse;
use crate::BoundActivityTracker;
#[derive(Debug, thiserror::Error)]
pub enum Rejection {
#[error("Missing authorization header")]
MissingAuthorizationHeader,
#[error("Invalid authorization header")]
InvalidAuthorizationHeader,
#[error("Couldn't load the database repository")]
RepositorySetup(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Invalid repository operation")]
Repository(#[from] RepositoryError),
#[error("Unknown access token")]
UnknownAccessToken,
#[error("Access token expired")]
TokenExpired,
#[error("Access token revoked")]
SessionRevoked,
#[error("User locked")]
UserLocked,
#[error("Failed to load session {0}")]
LoadSession(Ulid),
#[error("Failed to load user {0}")]
LoadUser(Ulid),
#[error("Missing urn:mas:admin scope")]
MissingScope,
}
impl Rejection {
    /// Returns the HTTP status code used when this rejection becomes a response.
    ///
    /// * `400 Bad Request` — the `Authorization` header is missing or malformed.
    /// * `401 Unauthorized` — the token is unknown or not valid, the session is
    ///   revoked, the user is locked, or the admin scope is missing.
    /// * `500 Internal Server Error` — the repository failed, or the token's
    ///   session / the session's user could not be found.
    fn status_code(&self) -> StatusCode {
        // Deliberately exhaustive (no `_` arm): adding a variant must force an
        // explicit decision here instead of silently defaulting to a 500.
        match self {
            Self::InvalidAuthorizationHeader | Self::MissingAuthorizationHeader => {
                StatusCode::BAD_REQUEST
            }
            Self::UnknownAccessToken
            | Self::TokenExpired
            | Self::SessionRevoked
            | Self::UserLocked
            | Self::MissingScope => StatusCode::UNAUTHORIZED,
            Self::RepositorySetup(_)
            | Self::Repository(_)
            | Self::LoadSession(_)
            | Self::LoadUser(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
/// Renders a [`Rejection`] as a JSON `ErrorResponse` body, paired with the
/// status code chosen by `Rejection::status_code`.
impl IntoResponse for Rejection {
    fn into_response(self) -> Response {
        let body = Json(ErrorResponse::from_error(&self));
        (self.status_code(), body).into_response()
    }
}
/// Context extracted for a request authenticated with a bearer access token
/// whose session carries the `urn:mas:admin` scope.
///
/// Produced by this type's `FromRequestParts` implementation.
/// `#[non_exhaustive]` prevents code outside this crate from constructing it
/// directly.
#[non_exhaustive]
#[derive(OperationIo)]
#[aide(input)]
pub struct CallContext {
    /// Repository used during extraction, handed on for further queries.
    pub repo: BoxRepository,
    /// Clock used during extraction to check the access token's validity.
    pub clock: BoxClock,
    /// The user owning the session, or `None` if the session has no user.
    pub user: Option<User>,
    /// The OAuth 2.0 session the presented access token belongs to.
    pub session: Session,
}
/// Builds a [`CallContext`] from the request's `Authorization: Bearer` header.
///
/// The bearer token must resolve to an access token and session in the
/// repository. If the session has a user, that user must be valid. The
/// session must be valid, the token must be valid at the current time, and
/// the session scope must contain `urn:mas:admin`. Session activity is
/// recorded only once all of these checks have passed.
///
/// # Errors
///
/// Returns a [`Rejection`] when the header is missing or malformed, the
/// repository cannot be set up or queried, the token/session/user cannot be
/// found, or any validity or scope check fails.
#[async_trait::async_trait]
impl<S> FromRequestParts<S> for CallContext
where
    S: Send + Sync,
    BoundActivityTracker: FromRequestParts<S, Rejection = Infallible>,
    BoxRepository: FromRequestParts<S>,
    BoxClock: FromRequestParts<S, Rejection = Infallible>,
    <BoxRepository as FromRequestParts<S>>::Rejection:
        Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
    type Rejection = Rejection;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // These extractors have `Infallible` rejections; the empty match
        // proves to the compiler that the error branch cannot happen.
        let activity_tracker = BoundActivityTracker::from_request_parts(parts, state)
            .await
            .unwrap_or_else(|e| match e {});
        let clock = BoxClock::from_request_parts(parts, state)
            .await
            .unwrap_or_else(|e| match e {});

        let mut repo = BoxRepository::from_request_parts(parts, state)
            .await
            .map_err(|e| Rejection::RepositorySetup(e.into()))?;

        // Distinguish "no header at all" (missing) from "header present but
        // not a parseable bearer token" (invalid).
        let token = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
            .await
            .map_err(|e| {
                if e.is_missing() {
                    Rejection::MissingAuthorizationHeader
                } else {
                    Rejection::InvalidAuthorizationHeader
                }
            })?;
        let token = token.token();

        // Resolve the bearer token to its access token...
        let token = repo
            .oauth2_access_token()
            .find_by_token(token)
            .await?
            .ok_or(Rejection::UnknownAccessToken)?;

        // ...and the access token to its session.
        let session = repo
            .oauth2_session()
            .lookup(token.session_id)
            .await?
            .ok_or_else(|| Rejection::LoadSession(token.session_id))?;

        // A session is not necessarily tied to a user.
        let user = if let Some(user_id) = session.user_id {
            let user = repo
                .user()
                .lookup(user_id)
                .await?
                .ok_or_else(|| Rejection::LoadUser(user_id))?;
            Some(user)
        } else {
            None
        };

        if let Some(user) = &user {
            if !user.is_valid() {
                return Err(Rejection::UserLocked);
            }
        }

        if !session.is_valid() {
            return Err(Rejection::SessionRevoked);
        }

        if !token.is_valid(clock.now()) {
            return Err(Rejection::TokenExpired);
        }

        if !session.scope.contains("urn:mas:admin") {
            return Err(Rejection::MissingScope);
        }

        // Record activity only for requests that passed every check, so that
        // revoked sessions, expired tokens or tokens without the admin scope
        // do not refresh the session's activity.
        activity_tracker
            .record_oauth2_session(&clock, &session)
            .await;

        Ok(Self {
            repo,
            clock,
            user,
            session,
        })
    }
}