matrix_sdk/authentication/oidc/
registrations.rs1use std::{
24 collections::HashMap,
25 fs,
26 fs::File,
27 io::{BufReader, BufWriter},
28 path::{Path, PathBuf},
29};
30
31use mas_oidc_client::types::registration::{
32 ClientMetadata, ClientMetadataVerificationError, VerifiedClientMetadata,
33};
34pub use oauth2::ClientId;
35use serde::{Deserialize, Serialize};
36use url::Url;
37
38#[derive(Debug, thiserror::Error)]
40pub enum OidcRegistrationsError {
41 #[error("Failed to use the supplied registrations file path.")]
43 InvalidFilePath,
44 #[error("Failed to save the registration data {0}.")]
46 SaveFailure(#[source] Box<dyn std::error::Error + Send + Sync>),
47}
48
49#[derive(Debug)]
51pub struct OidcRegistrations {
52 file_path: PathBuf,
54 pub(super) verified_metadata: VerifiedClientMetadata,
57 static_registrations: HashMap<Url, ClientId>,
60}
61
62#[derive(Debug, Serialize)]
64struct FrozenRegistrationData {
65 metadata: VerifiedClientMetadata,
67 dynamic_registrations: HashMap<Url, ClientId>,
70}
71
72#[derive(Debug, Deserialize)]
75struct UnvalidatedRegistrationData {
76 metadata: ClientMetadata,
78 dynamic_registrations: HashMap<Url, ClientId>,
81}
82
83impl UnvalidatedRegistrationData {
84 fn validate(&self) -> Result<FrozenRegistrationData, ClientMetadataVerificationError> {
86 let verified_metadata = match self.metadata.clone().validate() {
87 Ok(metadata) => metadata,
88 Err(e) => {
89 tracing::warn!("Failed to validate stored metadata.");
90 return Err(e);
91 }
92 };
93
94 Ok(FrozenRegistrationData {
95 metadata: verified_metadata,
96 dynamic_registrations: self.dynamic_registrations.clone(),
97 })
98 }
99}
100
101impl OidcRegistrations {
103 pub fn new(
118 registrations_file: &Path,
119 metadata: VerifiedClientMetadata,
120 static_registrations: HashMap<Url, ClientId>,
121 ) -> Result<Self, OidcRegistrationsError> {
122 let parent = registrations_file.parent().ok_or(OidcRegistrationsError::InvalidFilePath)?;
123 fs::create_dir_all(parent).map_err(|_| OidcRegistrationsError::InvalidFilePath)?;
124
125 Ok(OidcRegistrations {
126 file_path: registrations_file.to_owned(),
127 verified_metadata: metadata,
128 static_registrations,
129 })
130 }
131
132 pub fn client_id(&self, issuer: &Url) -> Option<ClientId> {
135 let mut data = self.read_or_generate_registration_data();
136 data.dynamic_registrations.extend(self.static_registrations.clone());
137 data.dynamic_registrations.get(issuer).cloned()
138 }
139
140 pub fn set_and_write_client_id(
143 &self,
144 client_id: ClientId,
145 issuer: Url,
146 ) -> Result<(), OidcRegistrationsError> {
147 let mut data = self.read_or_generate_registration_data();
148 data.dynamic_registrations.insert(issuer, client_id);
149
150 let writer = BufWriter::new(
151 File::create(&self.file_path)
152 .map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))?,
153 );
154 serde_json::to_writer(writer, &data)
155 .map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))
156 }
157
158 fn read_or_generate_registration_data(&self) -> FrozenRegistrationData {
160 let try_read_previous = || {
161 let reader = BufReader::new(
162 File::open(&self.file_path)
163 .map_err(|error| {
164 tracing::warn!("Failed to load registrations file: {error}");
165 })
166 .ok()?,
167 );
168
169 let registration_data: UnvalidatedRegistrationData = serde_json::from_reader(reader)
170 .map_err(|error| {
171 tracing::warn!("Failed to deserialize registrations file: {error}");
172 })
173 .ok()?;
174
175 let registration_data = registration_data
176 .validate()
177 .map_err(|error| {
178 tracing::warn!("Failed to validate registration data: {error}");
179 })
180 .ok()?;
181
182 if registration_data.metadata != self.verified_metadata {
183 tracing::warn!("Metadata mismatch, ignoring any stored registrations.");
184 return None;
185 }
186
187 Some(registration_data)
188 };
189
190 try_read_previous().unwrap_or_else(|| {
191 tracing::warn!("Generating new registration data");
192 FrozenRegistrationData {
193 metadata: self.verified_metadata.clone(),
194 dynamic_registrations: Default::default(),
195 }
196 })
197 }
198}
199
#[cfg(test)]
mod tests {
    use mas_oidc_client::types::registration::Localized;
    use tempfile::tempdir;

    use super::*;

    #[test]
    fn test_oidc_registrations() {
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let registrations =
            OidcRegistrations::new(&registrations_file, oidc_metadata, static_registrations)
                .unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url), None);

        registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id));
        assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id.clone()));

        // A fresh instance with the same metadata must pick up the registration
        // persisted by the first one.
        let reloaded = OidcRegistrations::new(
            &registrations_file,
            mock_metadata("Example".to_owned()),
            HashMap::new(),
        )
        .unwrap();
        assert_eq!(reloaded.client_id(&dynamic_url), Some(dynamic_id));
    }

    #[test]
    fn test_change_of_metadata() {
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let registrations = OidcRegistrations::new(
            &registrations_file,
            oidc_metadata,
            static_registrations.clone(),
        )
        .unwrap();
        registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));

        let new_oidc_metadata = mock_metadata("New App".to_owned());

        let registrations =
            OidcRegistrations::new(&registrations_file, new_oidc_metadata, static_registrations)
                .unwrap();

        // Dynamic registrations made with different metadata are discarded,
        // while static registrations are unaffected.
        assert_eq!(registrations.client_id(&dynamic_url), None);
        assert_eq!(registrations.client_id(&static_url), Some(static_id));
    }

    #[test]
    fn test_static_registrations_take_precedence() {
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let static_registrations = HashMap::from([(url.clone(), static_id.clone())]);

        let registrations = OidcRegistrations::new(
            &registrations_file,
            mock_metadata("Example".to_owned()),
            static_registrations,
        )
        .unwrap();
        registrations.set_and_write_client_id(dynamic_id, url.clone()).unwrap();

        assert_eq!(registrations.client_id(&url), Some(static_id));
    }

    /// Builds valid client metadata that differs only by `client_name`.
    fn mock_metadata(client_name: String) -> VerifiedClientMetadata {
        let callback_url = Url::parse("https://example.org/login/callback").unwrap();
        let client_name = Some(Localized::new(client_name, None));

        ClientMetadata {
            redirect_uris: Some(vec![callback_url]),
            client_name,
            ..Default::default()
        }
        .validate()
        .unwrap()
    }
}