use std::{collections::HashMap, sync::Arc};
use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
};
use mas_http::HttpService;
use mas_iana::oauth::PkceCodeChallengeMethod;
use mas_oidc_client::error::DiscoveryError;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
use oauth2_types::oidc::VerifiedProviderMetadata;
use tokio::sync::RwLock;
use url::Url;
/// Lazily-resolved view of an upstream provider's endpoints and metadata.
///
/// Endpoint overrides configured on the provider are returned as-is; the
/// discovery document is only requested (through the shared
/// [`MetadataCache`]) the first time a value that needs it is accessed, and
/// is then kept for the lifetime of this value.
pub struct LazyProviderInfos<'a> {
    /// Provider whose configuration and overrides are being resolved.
    provider: &'a UpstreamOAuthProvider,

    /// Shared cache the discovery document is fetched through.
    cache: &'a MetadataCache,

    /// HTTP client handed to the cache when a fetch is needed.
    http_service: &'a HttpService,

    /// Discovery document, once it has been loaded by this instance.
    loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
}
impl<'a> LazyProviderInfos<'a> {
    /// Creates a lazy view over `provider`.
    ///
    /// Nothing is fetched here: the discovery document is requested from
    /// `cache` (using `http_service`) only when a value that needs it is first
    /// accessed.
    pub fn new(
        cache: &'a MetadataCache,
        provider: &'a UpstreamOAuthProvider,
        http_service: &'a HttpService,
    ) -> Self {
        Self {
            cache,
            provider,
            http_service,
            loaded_metadata: None,
        }
    }

    /// Loads the provider metadata unless discovery is disabled.
    ///
    /// Returns `Ok(None)` when the provider's discovery mode is
    /// [`UpstreamOAuthProviderDiscoveryMode::Disabled`].
    ///
    /// # Errors
    ///
    /// Propagates any other [`DiscoveryError`] raised while fetching the
    /// metadata.
    pub async fn maybe_discover(
        &mut self,
    ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
        match self.load().await {
            Ok(metadata) => Ok(Some(metadata)),
            Err(DiscoveryError::Disabled) => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Returns the provider metadata, fetching it through the cache on the
    /// first call and reusing the loaded copy afterwards.
    ///
    /// # Errors
    ///
    /// Returns [`DiscoveryError::Disabled`] if discovery is disabled for the
    /// provider, or the error returned by the cache if fetching fails.
    async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
        // `take` + `insert` yields a reference into `self.loaded_metadata`
        // without an `unwrap`. A conditional early `return` of a borrow
        // followed by an assignment would be rejected by the borrow checker.
        let metadata = match self.loaded_metadata.take() {
            Some(metadata) => metadata,
            None => {
                let verify = match self.provider.discovery_mode {
                    UpstreamOAuthProviderDiscoveryMode::Oidc => true,
                    UpstreamOAuthProviderDiscoveryMode::Insecure => false,
                    UpstreamOAuthProviderDiscoveryMode::Disabled => {
                        return Err(DiscoveryError::Disabled);
                    }
                };

                self.cache
                    .get(self.http_service, &self.provider.issuer, verify)
                    .await?
            }
        };

        Ok(self.loaded_metadata.insert(metadata))
    }

    /// Returns the JWKS URI, preferring the provider's configured override
    /// over the discovered value.
    ///
    /// # Errors
    ///
    /// Returns a [`DiscoveryError`] if no override is set and the metadata
    /// cannot be loaded (including when discovery is disabled).
    pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
        if let Some(jwks_uri) = &self.provider.jwks_uri_override {
            return Ok(jwks_uri);
        }

        Ok(self.load().await?.jwks_uri())
    }

    /// Returns the authorization endpoint, preferring the provider's configured
    /// override over the discovered value.
    ///
    /// # Errors
    ///
    /// Returns a [`DiscoveryError`] if no override is set and the metadata
    /// cannot be loaded (including when discovery is disabled).
    pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
        if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
            return Ok(authorization_endpoint);
        }

        Ok(self.load().await?.authorization_endpoint())
    }

    /// Returns the token endpoint, preferring the provider's configured
    /// override over the discovered value.
    ///
    /// # Errors
    ///
    /// Returns a [`DiscoveryError`] if no override is set and the metadata
    /// cannot be loaded (including when discovery is disabled).
    pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
        if let Some(token_endpoint) = &self.provider.token_endpoint_override {
            return Ok(token_endpoint);
        }

        Ok(self.load().await?.token_endpoint())
    }

    /// Returns the PKCE methods to use for this provider.
    ///
    /// - `Auto`: the `code_challenge_methods_supported` advertised in the
    ///   discovered metadata, or `None` if discovery is disabled or the field
    ///   is absent.
    /// - `S256`: always `[S256]`, without fetching metadata.
    /// - `Disabled`: always `None`.
    ///
    /// # Errors
    ///
    /// In `Auto` mode, returns a [`DiscoveryError`] if fetching the metadata
    /// fails for a reason other than discovery being disabled.
    pub async fn pkce_methods(
        &mut self,
    ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
        let methods = match self.provider.pkce_mode {
            UpstreamOAuthProviderPkceMode::Auto => self
                .maybe_discover()
                .await?
                .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
            UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
            UpstreamOAuthProviderPkceMode::Disabled => None,
        };

        Ok(methods)
    }
}
/// Map from issuer URL to its discovered provider metadata, shared between
/// clones of the owning [`MetadataCache`].
type IssuerMetadataMap = Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>;

/// In-memory cache of upstream provider discovery documents, keyed by issuer.
///
/// Metadata fetched through verified discovery and metadata fetched through
/// insecure discovery live in separate maps, so an entry obtained in one mode
/// is never served for a lookup in the other. Clones share the same maps.
// Silences clippy's `module_name_repetitions` lint for this type name.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Default)]
pub struct MetadataCache {
    /// Entries fetched with verified discovery.
    cache: IssuerMetadataMap,
    /// Entries fetched with insecure discovery.
    insecure_cache: IssuerMetadataMap,
}
impl MetadataCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Fetches the metadata of every enabled provider once, then spawns a
    /// background task that re-fetches every cached entry each `interval`.
    ///
    /// Providers whose discovery is disabled are skipped. Fetch failures
    /// during the warm-up are logged and do not abort it.
    ///
    /// # Errors
    ///
    /// Returns the repository error if listing the enabled providers fails.
    #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
    pub async fn warm_up_and_run<R: RepositoryAccess>(
        &self,
        http_service: HttpService,
        interval: std::time::Duration,
        repository: &mut R,
    ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
        let enabled = repository.upstream_oauth_provider().all_enabled().await?;

        for provider in enabled {
            let verify = match provider.discovery_mode {
                // Nothing to pre-fetch when discovery is turned off.
                UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
            };

            let outcome = self.fetch(&http_service, &provider.issuer, verify).await;
            if let Err(e) = outcome {
                tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
            }
        }

        let background = self.clone();
        let handle = tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                background.refresh_all(&http_service).await;
            }
        });

        Ok(handle)
    }

    /// Runs discovery for `issuer` and stores the result, replacing any
    /// previous entry for that issuer.
    ///
    /// With `verify` set, the verified discovery flow is used and the result
    /// is stored in the verified map; otherwise the insecure flow is used and
    /// the result is stored in the insecure map. Nothing is stored on error.
    #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
    async fn fetch(
        &self,
        http_service: &HttpService,
        issuer: &str,
        verify: bool,
    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
        let (discovered, store) = if verify {
            let found =
                mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
            (found, &self.cache)
        } else {
            let found =
                mas_oidc_client::requests::discovery::insecure_discover(http_service, issuer)
                    .await?;
            (found, &self.insecure_cache)
        };

        let shared = Arc::new(discovered);
        store
            .write()
            .await
            .insert(issuer.to_owned(), Arc::clone(&shared));

        Ok(shared)
    }

    /// Returns the metadata for `issuer`, fetching and caching it on a miss.
    ///
    /// Verified and insecure entries are looked up in separate maps, selected
    /// by `verify`.
    ///
    /// # Errors
    ///
    /// Returns a [`DiscoveryError`] if the entry is not cached and fetching it
    /// fails.
    #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
    pub async fn get(
        &self,
        http_service: &HttpService,
        issuer: &str,
        verify: bool,
    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
        let store = if verify {
            &self.cache
        } else {
            &self.insecure_cache
        };

        // Scope the read guard so it is released before `fetch` takes the
        // write lock on the same map.
        let hit = {
            let entries = store.read().await;
            entries.get(issuer).cloned()
        };

        match hit {
            Some(metadata) => Ok(metadata),
            None => self.fetch(http_service, issuer, verify).await,
        }
    }

    /// Re-fetches every issuer currently cached: verified entries first, then
    /// insecure ones. Failures are logged and leave the existing entry as is.
    #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
    async fn refresh_all(&self, http_service: &HttpService) {
        for (store, verify) in [(&self.cache, true), (&self.insecure_cache, false)] {
            // Snapshot the keys so no lock is held while fetching.
            let issuers: Vec<String> = store.read().await.keys().cloned().collect();

            for issuer in issuers {
                if let Err(e) = self.fetch(http_service, &issuer, verify).await {
                    tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::too_many_lines)]

    use std::sync::atomic::{AtomicUsize, Ordering};

    use hyper::{body::Bytes, Request, Response, StatusCode};
    use mas_data_model::UpstreamOAuthProviderClaimsImports;
    use mas_http::BoxCloneSyncService;
    use mas_iana::oauth::OAuthClientAuthenticationMethod;
    use mas_storage::{clock::MockClock, Clock};
    use oauth2_types::scope::{Scope, OPENID};
    use tower::BoxError;
    use ulid::Ulid;

    use super::*;
    use crate::test_utils::setup;

    /// Discovery document served for `valid.example.com`.
    const VALID_METADATA: &[u8] = br#"{
        "issuer": "https://valid.example.com/",
        "authorization_endpoint": "https://valid.example.com/authorize",
        "token_endpoint": "https://valid.example.com/token",
        "jwks_uri": "https://valid.example.com/jwks",
        "response_types_supported": [
            "code"
        ],
        "grant_types_supported": [
            "authorization_code"
        ],
        "subject_types_supported": [
            "public"
        ],
        "id_token_signing_alg_values_supported": [
            "RS256"
        ],
        "scopes_supported": [
            "openid",
            "profile",
            "email"
        ]
    }"#;

    /// Discovery document served for `insecure.example.com` (plain `http://`).
    const INSECURE_METADATA: &[u8] = br#"{
        "issuer": "http://insecure.example.com/",
        "authorization_endpoint": "http://insecure.example.com/authorize",
        "token_endpoint": "http://insecure.example.com/token",
        "jwks_uri": "http://insecure.example.com/jwks",
        "response_types_supported": [
            "code"
        ],
        "grant_types_supported": [
            "authorization_code"
        ],
        "subject_types_supported": [
            "public"
        ],
        "id_token_signing_alg_values_supported": [
            "RS256"
        ],
        "scopes_supported": [
            "openid",
            "profile",
            "email"
        ]
    }"#;

    /// Builds a mock HTTP service that answers by request authority.
    ///
    /// It serves [`VALID_METADATA`] or [`INSECURE_METADATA`], or an empty body
    /// for any other host, and increments `calls` on every request.
    fn mock_discovery_service(calls: &Arc<AtomicUsize>) -> HttpService {
        let calls = Arc::clone(calls);
        let handler = move |req: Request<Bytes>| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                let body = match req.uri().authority().unwrap().as_str() {
                    "valid.example.com" => Bytes::from_static(VALID_METADATA),
                    "insecure.example.com" => Bytes::from_static(INSECURE_METADATA),
                    _ => Bytes::default(),
                };
                let mut response = Response::new(body);
                *response.status_mut() = StatusCode::OK;
                Ok::<_, BoxError>(response)
            }
        };

        BoxCloneSyncService::new(tower::service_fn(handler))
    }

    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let calls = Arc::new(AtomicUsize::new(0));
        let service = mock_discovery_service(&calls);
        let cache = MetadataCache::new();

        // An unknown host serves an empty body, so discovery fails.
        cache
            .get(&service, "https://inexistant.example.com/", true)
            .await
            .unwrap_err();
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // First lookup fetches, second one is served from the cache.
        cache
            .get(&service, "https://valid.example.com/", true)
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        cache
            .get(&service, "https://valid.example.com/", true)
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        // Same behaviour for insecure discovery.
        cache
            .get(&service, "http://insecure.example.com/", false)
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        cache
            .get(&service, "http://insecure.example.com/", false)
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // The insecure entry is not reused for a verified lookup.
        cache
            .get(&service, "http://insecure.example.com/", true)
            .await
            .unwrap_err();
        assert_eq!(calls.load(Ordering::SeqCst), 4);

        // Refreshing re-fetches both cached entries.
        cache.refresh_all(&service).await;
        assert_eq!(calls.load(Ordering::SeqCst), 6);
    }

    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();
        let calls = Arc::new(AtomicUsize::new(0));
        let service = mock_discovery_service(&calls);
        let clock = MockClock::default();

        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: "https://valid.example.com/".to_owned(),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            token_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
        };

        // Discovery only happens on first use, then the metadata is reused.
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
            assert_eq!(calls.load(Ordering::SeqCst), 0);
            lazy_metadata.maybe_discover().await.unwrap();
            assert_eq!(calls.load(Ordering::SeqCst), 1);
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://valid.example.com/authorize"
            );
        }

        // Overrides are returned without triggering discovery.
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://valid.example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://valid.example.com/authorize_override"
                        .parse()
                        .unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://valid.example.com/token_override".parse().unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
            assert_eq!(
                lazy_metadata.jwks_uri().await.unwrap().as_str(),
                "https://valid.example.com/jwks_override"
            );
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://valid.example.com/authorize_override"
            );
            assert_eq!(
                lazy_metadata.token_endpoint().await.unwrap().as_str(),
                "https://valid.example.com/token_override"
            );
            assert_eq!(calls.load(Ordering::SeqCst), 1);
        }

        // Verified discovery fails for the `http://` issuer.
        {
            let provider = UpstreamOAuthProvider {
                issuer: "http://insecure.example.com/".to_owned(),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            assert_eq!(calls.load(Ordering::SeqCst), 2);
        }

        // Insecure discovery succeeds for the same issuer.
        {
            let provider = UpstreamOAuthProvider {
                issuer: "http://insecure.example.com/".to_owned(),
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "http://insecure.example.com/authorize"
            );
            assert_eq!(calls.load(Ordering::SeqCst), 3);
        }

        // With discovery disabled, only overrides are available.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://valid.example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://valid.example.com/authorize_override"
            );
            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            assert_eq!(calls.load(Ordering::SeqCst), 3);
        }
    }
}