#![allow(clippy::module_name_repetitions)]
use std::collections::HashMap;
use axum::response::{Html, IntoResponse, Redirect, Response};
use mas_data_model::AuthorizationGrant;
use mas_templates::{FormPostContext, Templates};
use oauth2_types::requests::ResponseMode;
use serde::Serialize;
use thiserror::Error;
use url::Url;
/// The way response parameters are handed back to the redirect URI.
#[derive(Debug, Clone)]
enum CallbackDestinationMode {
    /// Parameters are written into the query string of the redirect URI.
    Query {
        /// Parameters that were already present in the redirect URI's query
        /// string; they are kept aside here so they can be re-serialized
        /// together with the response parameters.
        existing_params: HashMap<String, String>,
    },

    /// Parameters are URL-encoded into the fragment of the redirect URI.
    Fragment,

    /// Parameters are rendered through the `form_post` template and returned
    /// as an HTML response.
    FormPost,
}
/// A validated target for sending an authorization response back to a client.
///
/// Built with [`CallbackDestination::try_new`] (or via `TryFrom` on an
/// authorization grant) and consumed by [`CallbackDestination::go`].
#[derive(Debug, Clone)]
pub struct CallbackDestination {
    /// How the response parameters are delivered.
    mode: CallbackDestinationMode,

    /// The redirect URI after validation: it never carries a fragment, and in
    /// query mode its original query string has been removed (the parsed
    /// parameters live in `mode` instead).
    safe_redirect_uri: Url,

    /// Optional `state` value, serialized alongside the response parameters
    /// when present.
    state: Option<String>,
}
/// Reasons why a [`CallbackDestination`] could not be constructed.
#[derive(Debug, Error)]
pub enum IntoCallbackDestinationError {
    /// The supplied redirect URI already contains a fragment component.
    #[error("Redirect URI can't have a fragment")]
    RedirectUriFragmentNotAllowed,

    /// The query string already present on the redirect URI could not be
    /// deserialized into key/value pairs.
    #[error("Existing query parameters are not valid")]
    RedirectUriInvalidQueryParams(#[from] serde_urlencoded::de::Error),

    /// The response mode is not one of query, fragment or form_post.
    #[error("Requested response_mode is not supported")]
    UnsupportedResponseMode,
}
/// Failures that can occur while turning a [`CallbackDestination`] into an
/// HTTP response.
#[derive(Debug, Error)]
pub enum CallbackDestinationError {
    /// Rendering the `form_post` template failed.
    #[error("Failed to render the form_post template")]
    FormPostRender(#[from] mas_templates::TemplateError),

    /// The response parameters could not be URL-encoded.
    #[error("Failed to serialize parameters query string")]
    ParamsSerialization(#[from] serde_urlencoded::ser::Error),
}
/// Derives the callback destination from the response mode, redirect URI and
/// state recorded on an authorization grant.
impl TryFrom<&AuthorizationGrant> for CallbackDestination {
    type Error = IntoCallbackDestinationError;

    /// Clones the grant's redirect URI and state, then delegates validation to
    /// [`CallbackDestination::try_new`].
    fn try_from(value: &AuthorizationGrant) -> Result<Self, Self::Error> {
        // The grant is only borrowed, so the owned pieces have to be cloned.
        let redirect_uri = value.redirect_uri.clone();
        let state = value.state.clone();

        Self::try_new(&value.response_mode, redirect_uri, state)
    }
}
impl CallbackDestination {
    /// Validates a redirect URI against the requested response mode and builds
    /// the corresponding destination.
    ///
    /// In query mode, any query string already present on `redirect_uri` is
    /// parsed and removed from the URL, so that it can later be serialized
    /// again together with the response parameters.
    ///
    /// # Errors
    ///
    /// * [`IntoCallbackDestinationError::RedirectUriFragmentNotAllowed`] if
    ///   `redirect_uri` has a fragment (checked before anything else).
    /// * [`IntoCallbackDestinationError::RedirectUriInvalidQueryParams`] if,
    ///   in query mode, the existing query string cannot be deserialized.
    /// * [`IntoCallbackDestinationError::UnsupportedResponseMode`] for any
    ///   response mode other than query, fragment or form_post.
    pub fn try_new(
        mode: &ResponseMode,
        mut redirect_uri: Url,
        state: Option<String>,
    ) -> Result<Self, IntoCallbackDestinationError> {
        if redirect_uri.fragment().is_some() {
            return Err(IntoCallbackDestinationError::RedirectUriFragmentNotAllowed);
        }

        let destination_mode = match mode {
            ResponseMode::Query => {
                // Lift the pre-existing parameters out of the URL; an absent
                // query string simply yields an empty map.
                let existing_params = match redirect_uri.query() {
                    Some(raw_query) => serde_urlencoded::from_str(raw_query)?,
                    None => HashMap::default(),
                };
                redirect_uri.set_query(None);

                CallbackDestinationMode::Query { existing_params }
            }
            ResponseMode::Fragment => CallbackDestinationMode::Fragment,
            ResponseMode::FormPost => CallbackDestinationMode::FormPost,
            _ => return Err(IntoCallbackDestinationError::UnsupportedResponseMode),
        };

        Ok(Self {
            mode: destination_mode,
            safe_redirect_uri: redirect_uri,
            state,
        })
    }

    /// Consumes the destination and produces the HTTP response that delivers
    /// `params` (plus the stored `state`, if any) to the redirect URI.
    ///
    /// * Query mode: the previously extracted query parameters, the state and
    ///   `params` are URL-encoded into the query string, and a redirect to
    ///   the resulting URL is returned.
    /// * Fragment mode: the state and `params` are URL-encoded into the
    ///   fragment, and a redirect to the resulting URL is returned.
    /// * Form-post mode: the redirect URI, state and `params` are passed to
    ///   the `form_post` template, and the rendered output is returned as an
    ///   HTML response.
    ///
    /// # Errors
    ///
    /// * [`CallbackDestinationError::ParamsSerialization`] if URL-encoding the
    ///   parameters fails (query and fragment modes).
    /// * [`CallbackDestinationError::FormPostRender`] if the template fails to
    ///   render (form-post mode).
    pub async fn go<T: Serialize + Send + Sync>(
        self,
        templates: &Templates,
        params: T,
    ) -> Result<Response, CallbackDestinationError> {
        /// Everything that gets serialized into the response, flattened into
        /// a single set of key/value pairs.
        #[derive(Serialize)]
        struct AllParams<'s, T> {
            #[serde(flatten, skip_serializing_if = "Option::is_none")]
            existing: Option<&'s HashMap<String, String>>,
            #[serde(skip_serializing_if = "Option::is_none")]
            state: Option<String>,
            #[serde(flatten)]
            params: T,
        }

        let Self {
            mode,
            safe_redirect_uri: mut target,
            state,
        } = self;

        let response = match mode {
            CallbackDestinationMode::Query { existing_params } => {
                let encoded = serde_urlencoded::to_string(AllParams {
                    existing: Some(&existing_params),
                    state,
                    params,
                })?;
                target.set_query(Some(&encoded));

                Redirect::to(target.as_str()).into_response()
            }
            CallbackDestinationMode::Fragment => {
                let encoded = serde_urlencoded::to_string(AllParams {
                    existing: None,
                    state,
                    params,
                })?;
                target.set_fragment(Some(&encoded));

                Redirect::to(target.as_str()).into_response()
            }
            CallbackDestinationMode::FormPost => {
                let payload = AllParams {
                    existing: None,
                    state,
                    params,
                };
                let context = FormPostContext::new(target, payload);
                let page = templates.render_form_post(&context)?;

                Html(page).into_response()
            }
        };

        Ok(response)
    }
}