use axum::{
extract::{Path, Query, State},
response::{IntoResponse, Redirect},
};
use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
};
use mas_data_model::UpstreamOAuthProvider;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder;
use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
BoxClock, BoxRepository, BoxRng,
};
use thiserror::Error;
use ulid::Ulid;
use super::{cache::LazyProviderInfos, UpstreamSessionsCookie};
use crate::{
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
views::shared::OptionalPostAuthAction,
};
#[derive(Debug, Error)]
pub(crate) enum RouteError {
#[error("Provider not found")]
ProviderNotFound,
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let response = match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
};
(SentryEventID::from(event_id), response).into_response()
}
}
#[tracing::instrument(
name = "handlers.upstream_oauth2.authorize.get",
fields(upstream_oauth_provider.id = %provider_id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(metadata_cache): State<MetadataCache>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
cookie_jar: CookieJar,
Path(provider_id): Path<Ulid>,
Query(query): Query<OptionalPostAuthAction>,
) -> Result<impl IntoResponse, RouteError> {
let provider = repo
.upstream_oauth_provider()
.lookup(provider_id)
.await?
.filter(UpstreamOAuthProvider::enabled)
.ok_or(RouteError::ProviderNotFound)?;
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
lazy_metadata.maybe_discover().await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
let data = AuthorizationRequestData::new(
provider.client_id.clone(),
provider.scope.clone(),
redirect_uri,
);
let data = if let Some(methods) = lazy_metadata.pkce_methods().await? {
data.with_code_challenge_methods_supported(methods)
} else {
data
};
let (mut url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
lazy_metadata.authorization_endpoint().await?.clone(),
data,
&mut rng,
)?;
{
let mut params = url.query_pairs_mut();
for (key, value) in &provider.additional_authorization_parameters {
params.append_pair(key, value);
}
}
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
data.state.clone(),
data.code_challenge_verifier,
data.nonce,
)
.await?;
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
.add(session.id, provider.id, data.state, query.post_auth_action)
.save(cookie_jar, &clock);
repo.save().await?;
Ok((cookie_jar, Redirect::temporary(url.as_str())))
}