use std::{
collections::HashMap,
fs,
fs::File,
io::{BufReader, BufWriter},
path::{Path, PathBuf},
};
use mas_oidc_client::types::registration::{
ClientMetadata, ClientMetadataVerificationError, VerifiedClientMetadata,
};
use serde::{Deserialize, Serialize};
use url::Url;
/// Errors that can occur when setting up or persisting OIDC registrations.
#[derive(Debug, thiserror::Error)]
pub enum OidcRegistrationsError {
/// The supplied registrations file path has no parent directory, or that
/// directory could not be created.
#[error("Failed to use the supplied registrations file path.")]
InvalidFilePath,
/// Writing the registration data to the registrations file failed; wraps
/// the underlying error.
#[error("Failed to save the registration data {0}.")]
SaveFailure(#[source] Box<dyn std::error::Error + Send + Sync>),
}
/// A client ID, wrapping its raw string value.
///
/// In this file client IDs are stored per issuer URL.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ClientId(pub String);
/// The client IDs known for OIDC issuers: a fixed set of static registrations
/// plus dynamic registrations that are persisted in a JSON file.
#[derive(Debug)]
pub struct OidcRegistrations {
/// Path of the file the dynamic registrations are read from and written to.
file_path: PathBuf,
/// The current client metadata. Registrations stored in the file are ignored
/// when the metadata saved alongside them differs from this value.
verified_metadata: VerifiedClientMetadata,
/// Pre-configured client IDs keyed by issuer URL. These take precedence over
/// dynamic registrations for the same issuer.
static_registrations: HashMap<Url, ClientId>,
}
/// Registration data whose metadata has been verified. This is the form that
/// gets serialized to the registrations file.
#[derive(Debug, Serialize)]
struct FrozenRegistrationData {
/// The verified client metadata the registrations were made with.
metadata: VerifiedClientMetadata,
/// Dynamically registered client IDs, keyed by issuer URL.
dynamic_registrations: HashMap<Url, ClientId>,
}
/// Registration data as deserialized from the registrations file, before its
/// metadata has been validated.
#[derive(Debug, Deserialize)]
struct UnvalidatedRegistrationData {
/// The stored client metadata, not yet verified.
metadata: ClientMetadata,
/// Dynamically registered client IDs, keyed by issuer URL.
dynamic_registrations: HashMap<Url, ClientId>,
}
impl UnvalidatedRegistrationData {
    /// Validates the stored metadata and returns the data as a
    /// [`FrozenRegistrationData`].
    ///
    /// The receiver is left untouched: both the metadata and the dynamic
    /// registrations are cloned into the result.
    ///
    /// # Errors
    ///
    /// Returns the [`ClientMetadataVerificationError`] reported by the
    /// metadata validation, after logging a warning.
    fn validate(&self) -> Result<FrozenRegistrationData, ClientMetadataVerificationError> {
        let metadata = self.metadata.clone().validate().map_err(|error| {
            tracing::warn!("Failed to validate stored metadata.");
            error
        })?;
        let dynamic_registrations = self.dynamic_registrations.clone();

        Ok(FrozenRegistrationData { metadata, dynamic_registrations })
    }
}
impl OidcRegistrations {
    /// Creates a new registrations store backed by `registrations_file`.
    ///
    /// The parent directory of the file is created if necessary; the file
    /// itself is not touched here.
    ///
    /// # Arguments
    ///
    /// * `registrations_file` - Path of the file used to persist dynamic
    ///   registrations.
    /// * `metadata` - The client metadata. Stored registrations made with
    ///   different metadata are ignored.
    /// * `static_registrations` - Pre-configured client IDs keyed by issuer
    ///   URL.
    ///
    /// # Errors
    ///
    /// Returns [`OidcRegistrationsError::InvalidFilePath`] if the path has no
    /// parent or the parent directory cannot be created.
    pub fn new(
        registrations_file: &Path,
        metadata: VerifiedClientMetadata,
        static_registrations: HashMap<Url, ClientId>,
    ) -> Result<Self, OidcRegistrationsError> {
        let parent = registrations_file.parent().ok_or(OidcRegistrationsError::InvalidFilePath)?;
        fs::create_dir_all(parent).map_err(|_| OidcRegistrationsError::InvalidFilePath)?;

        Ok(OidcRegistrations {
            file_path: registrations_file.to_owned(),
            verified_metadata: metadata,
            static_registrations,
        })
    }

    /// Returns the client ID registered for `issuer`, if any.
    ///
    /// Static registrations take precedence over dynamic ones. The
    /// registrations file is only read when there is no static registration
    /// for the issuer.
    pub fn client_id(&self, issuer: &Url) -> Option<ClientId> {
        if let Some(client_id) = self.static_registrations.get(issuer) {
            return Some(client_id.clone());
        }

        // The loaded data is owned, so the entry can be moved out without a clone.
        self.read_or_generate_registration_data().dynamic_registrations.remove(issuer)
    }

    /// Stores `client_id` as the dynamic registration for `issuer` and writes
    /// the updated registration data to the registrations file.
    ///
    /// Existing valid data in the file is preserved; an existing entry for the
    /// same issuer is replaced.
    ///
    /// # Errors
    ///
    /// Returns [`OidcRegistrationsError::SaveFailure`] if the file cannot be
    /// created, the data cannot be serialized, or the buffered output cannot
    /// be flushed to the file.
    pub fn set_and_write_client_id(
        &self,
        client_id: ClientId,
        issuer: Url,
    ) -> Result<(), OidcRegistrationsError> {
        use std::io::Write as _;

        let mut data = self.read_or_generate_registration_data();
        data.dynamic_registrations.insert(issuer, client_id);

        let file = File::create(&self.file_path)
            .map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))?;
        let mut writer = BufWriter::new(file);

        serde_json::to_writer(&mut writer, &data)
            .map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))?;

        // Flush explicitly: dropping a `BufWriter` also flushes, but silently
        // discards any error, which would report a failed write as a success.
        writer.flush().map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))
    }

    /// Loads the registration data from the registrations file, falling back
    /// to fresh, empty data.
    ///
    /// The stored data is discarded (with a warning logged) when the file
    /// cannot be opened or deserialized, when its metadata fails validation,
    /// or when its metadata differs from the current metadata.
    fn read_or_generate_registration_data(&self) -> FrozenRegistrationData {
        let try_read_previous = || {
            let file = File::open(&self.file_path)
                .map_err(|error| {
                    tracing::warn!("Failed to load registrations file: {error}");
                })
                .ok()?;
            let reader = BufReader::new(file);

            let registration_data: UnvalidatedRegistrationData = serde_json::from_reader(reader)
                .map_err(|error| {
                    tracing::warn!("Failed to deserialize registrations file: {error}");
                })
                .ok()?;

            let registration_data = registration_data
                .validate()
                .map_err(|error| {
                    tracing::warn!("Failed to validate registration data: {error}");
                })
                .ok()?;

            // Client IDs registered with different metadata must not be reused.
            if registration_data.metadata != self.verified_metadata {
                tracing::warn!("Metadata mismatch, ignoring any stored registrations.");
                return None;
            }

            Some(registration_data)
        };

        try_read_previous().unwrap_or_else(|| {
            tracing::warn!("Generating new registration data");
            FrozenRegistrationData {
                metadata: self.verified_metadata.clone(),
                dynamic_registrations: Default::default(),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use mas_oidc_client::types::registration::Localized;
    use tempfile::tempdir;

    use super::*;

    /// Static registrations are available immediately, and a dynamic
    /// registration becomes available once it has been written.
    #[test]
    fn test_oidc_registrations() {
        let dir = tempdir().unwrap();
        // The "oidc" directory does not exist yet; `new` must create it.
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId("dynamic_client_id".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let registrations =
            OidcRegistrations::new(&registrations_file, oidc_metadata, static_registrations)
                .unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url), None);

        registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id));
        assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));
    }

    /// Dynamic registrations stored with different metadata are ignored,
    /// while static registrations are unaffected.
    #[test]
    fn test_change_of_metadata() {
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId("dynamic_client_id".to_owned());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let registrations = OidcRegistrations::new(
            &registrations_file,
            oidc_metadata,
            static_registrations.clone(),
        )
        .unwrap();
        registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));

        // Re-open the same file with different client metadata.
        let new_oidc_metadata = mock_metadata("New App".to_owned());
        let registrations =
            OidcRegistrations::new(&registrations_file, new_oidc_metadata, static_registrations)
                .unwrap();

        assert_eq!(registrations.client_id(&dynamic_url), None);
        assert_eq!(registrations.client_id(&static_url), Some(static_id));
    }

    /// Builds verified client metadata with the given client name and a fixed
    /// redirect URI.
    fn mock_metadata(client_name: String) -> VerifiedClientMetadata {
        let callback_url = Url::parse("https://example.org/login/callback").unwrap();
        let client_name = Some(Localized::new(client_name, None));

        ClientMetadata {
            redirect_uris: Some(vec![callback_url]),
            client_name,
            ..Default::default()
        }
        .validate()
        .unwrap()
    }
}