matrix_sdk/authentication/oidc/
registrations.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! OpenID Connect client registration management.
16//!
17//! This module provides a way to persist OIDC client registrations outside of
18//! the state store. This is useful when using a `Client` with an in-memory
19//! store or when different store paths are used for multi-account support
20//! within the same app, and those accounts need to share the same OIDC client
21//! registration.
22
23use std::{
24    collections::HashMap,
25    fs,
26    fs::File,
27    io::{BufReader, BufWriter},
28    path::{Path, PathBuf},
29};
30
31use mas_oidc_client::types::registration::{
32    ClientMetadata, ClientMetadataVerificationError, VerifiedClientMetadata,
33};
34pub use oauth2::ClientId;
35use serde::{Deserialize, Serialize};
36use url::Url;
37
38/// Errors related to persisting OIDC registrations.
39#[derive(Debug, thiserror::Error)]
40pub enum OidcRegistrationsError {
41    /// The supplied registrations file path is invalid.
42    #[error("Failed to use the supplied registrations file path.")]
43    InvalidFilePath,
44    /// An error occurred whilst saving the registration data.
45    #[error("Failed to save the registration data {0}.")]
46    SaveFailure(#[source] Box<dyn std::error::Error + Send + Sync>),
47}
48
49/// The data needed to restore an OpenID Connect session.
50#[derive(Debug)]
51pub struct OidcRegistrations {
52    /// The path of the file where the registrations are stored.
53    file_path: PathBuf,
54    /// The hash for the metadata used to register the client.
55    /// This is used to check if the client needs to be re-registered.
56    pub(super) verified_metadata: VerifiedClientMetadata,
57    /// Pre-configured registrations for use with issuers that don't support
58    /// dynamic client registration.
59    static_registrations: HashMap<Url, ClientId>,
60}
61
62/// The underlying data serialized into the registration file.
63#[derive(Debug, Serialize)]
64struct FrozenRegistrationData {
65    /// The hash for the metadata used to register the client.
66    metadata: VerifiedClientMetadata,
67    /// All of the registrations this client has made as a HashMap of issuer URL
68    /// (as a string) to client ID (as a string).
69    dynamic_registrations: HashMap<Url, ClientId>,
70}
71
72/// The deserialize data from the registration file. This data needs to be
73/// validated before it can be used.
74#[derive(Debug, Deserialize)]
75struct UnvalidatedRegistrationData {
76    /// The hash for the metadata used to register the client.
77    metadata: ClientMetadata,
78    /// All of the registrations this client has made as a HashMap of issuer URL
79    /// (as a string) to client ID (as a string).
80    dynamic_registrations: HashMap<Url, ClientId>,
81}
82
83impl UnvalidatedRegistrationData {
84    /// Validates the registration data, returning a `FrozenRegistrationData`.
85    fn validate(&self) -> Result<FrozenRegistrationData, ClientMetadataVerificationError> {
86        let verified_metadata = match self.metadata.clone().validate() {
87            Ok(metadata) => metadata,
88            Err(e) => {
89                tracing::warn!("Failed to validate stored metadata.");
90                return Err(e);
91            }
92        };
93
94        Ok(FrozenRegistrationData {
95            metadata: verified_metadata,
96            dynamic_registrations: self.dynamic_registrations.clone(),
97        })
98    }
99}
100
101/// Manages the storage of OIDC registrations.
102impl OidcRegistrations {
103    /// Creates a new registration store.
104    ///
105    /// # Arguments
106    ///
107    /// * `registrations_file` - A file path where the registrations will be
108    ///   stored. This previously took a directory and stored the registrations
109    ///   with the path `supplied_directory/oidc/registrations.json`.
110    ///
111    /// * `metadata` - The metadata used to register the client. If this
112    ///   changes, any stored registrations will be lost so the client can
113    ///   re-register with the new data.
114    ///
115    /// * `static_registrations` - Pre-configured registrations for use with
116    ///   issuers that don't support dynamic client registration.
117    pub fn new(
118        registrations_file: &Path,
119        metadata: VerifiedClientMetadata,
120        static_registrations: HashMap<Url, ClientId>,
121    ) -> Result<Self, OidcRegistrationsError> {
122        let parent = registrations_file.parent().ok_or(OidcRegistrationsError::InvalidFilePath)?;
123        fs::create_dir_all(parent).map_err(|_| OidcRegistrationsError::InvalidFilePath)?;
124
125        Ok(OidcRegistrations {
126            file_path: registrations_file.to_owned(),
127            verified_metadata: metadata,
128            static_registrations,
129        })
130    }
131
132    /// Returns the client ID registered for a particular issuer or None if a
133    /// registration hasn't been made.
134    pub fn client_id(&self, issuer: &Url) -> Option<ClientId> {
135        let mut data = self.read_or_generate_registration_data();
136        data.dynamic_registrations.extend(self.static_registrations.clone());
137        data.dynamic_registrations.get(issuer).cloned()
138    }
139
140    /// Stores a new client ID registration for a particular issuer. If a client
141    /// ID has already been stored, this will overwrite the old value.
142    pub fn set_and_write_client_id(
143        &self,
144        client_id: ClientId,
145        issuer: Url,
146    ) -> Result<(), OidcRegistrationsError> {
147        let mut data = self.read_or_generate_registration_data();
148        data.dynamic_registrations.insert(issuer, client_id);
149
150        let writer = BufWriter::new(
151            File::create(&self.file_path)
152                .map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))?,
153        );
154        serde_json::to_writer(writer, &data)
155            .map_err(|e| OidcRegistrationsError::SaveFailure(Box::new(e)))
156    }
157
158    /// Returns the underlying registration data, or generates a new one.
159    fn read_or_generate_registration_data(&self) -> FrozenRegistrationData {
160        let try_read_previous = || {
161            let reader = BufReader::new(
162                File::open(&self.file_path)
163                    .map_err(|error| {
164                        tracing::warn!("Failed to load registrations file: {error}");
165                    })
166                    .ok()?,
167            );
168
169            let registration_data: UnvalidatedRegistrationData = serde_json::from_reader(reader)
170                .map_err(|error| {
171                    tracing::warn!("Failed to deserialize registrations file: {error}");
172                })
173                .ok()?;
174
175            let registration_data = registration_data
176                .validate()
177                .map_err(|error| {
178                    tracing::warn!("Failed to validate registration data: {error}");
179                })
180                .ok()?;
181
182            if registration_data.metadata != self.verified_metadata {
183                tracing::warn!("Metadata mismatch, ignoring any stored registrations.");
184                return None;
185            }
186
187            Some(registration_data)
188        };
189
190        try_read_previous().unwrap_or_else(|| {
191            tracing::warn!("Generating new registration data");
192            FrozenRegistrationData {
193                metadata: self.verified_metadata.clone(),
194                dynamic_registrations: Default::default(),
195            }
196        })
197    }
198}
199
#[cfg(test)]
mod tests {
    use mas_oidc_client::types::registration::Localized;
    use tempfile::tempdir;

    use super::*;

    #[test]
    fn test_oidc_registrations() {
        // Given a fresh registration store with a single static registration.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let registrations =
            OidcRegistrations::new(&registrations_file, oidc_metadata, static_registrations)
                .unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url), None);

        // When a dynamic registration is added.
        registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();

        // Then the dynamic registration should be stored and the static registration
        // should be unaffected.
        assert_eq!(registrations.client_id(&static_url), Some(static_id));
        assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));
    }

    #[test]
    fn test_change_of_metadata() {
        // Given a single registration with an example app name.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let registrations = OidcRegistrations::new(
            &registrations_file,
            oidc_metadata,
            static_registrations.clone(),
        )
        .unwrap();
        registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();

        assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));

        // When the app name changes.
        let new_oidc_metadata = mock_metadata("New App".to_owned());

        let registrations =
            OidcRegistrations::new(&registrations_file, new_oidc_metadata, static_registrations)
                .unwrap();

        // Then the dynamic registrations are cleared.
        assert_eq!(registrations.client_id(&dynamic_url), None);
        assert_eq!(registrations.client_id(&static_url), Some(static_id));
    }

    #[test]
    fn test_static_registration_takes_precedence() {
        // Given a store with a static registration for an issuer.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let issuer = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(issuer.clone(), static_id.clone());

        let registrations = OidcRegistrations::new(
            &registrations_file,
            mock_metadata("Example".to_owned()),
            static_registrations,
        )
        .unwrap();

        // When a dynamic registration is stored for the same issuer.
        registrations.set_and_write_client_id(dynamic_id, issuer.clone()).unwrap();

        // Then the static registration still wins.
        assert_eq!(registrations.client_id(&issuer), Some(static_id));
    }

    #[test]
    fn test_registrations_persist_across_instances() {
        // Given a dynamic registration written by one store.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let issuer = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let writer_store = OidcRegistrations::new(
            &registrations_file,
            mock_metadata("Example".to_owned()),
            HashMap::new(),
        )
        .unwrap();
        writer_store.set_and_write_client_id(dynamic_id.clone(), issuer.clone()).unwrap();

        // When a new store is opened on the same file with the same metadata.
        let reader_store = OidcRegistrations::new(
            &registrations_file,
            mock_metadata("Example".to_owned()),
            HashMap::new(),
        )
        .unwrap();

        // Then it sees the registration that was written to disk.
        assert_eq!(reader_store.client_id(&issuer), Some(dynamic_id));
    }

    #[test]
    fn test_corrupted_file_is_replaced() {
        // Given a registrations file that doesn't contain valid JSON.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let registrations = OidcRegistrations::new(
            &registrations_file,
            mock_metadata("Example".to_owned()),
            HashMap::new(),
        )
        .unwrap();
        // `new` has created the parent directory, so the file can be written.
        std::fs::write(&registrations_file, b"not json").unwrap();

        let issuer = Url::parse("https://example.org").unwrap();

        // Then no registration is found.
        assert_eq!(registrations.client_id(&issuer), None);

        // And storing a registration overwrites the corrupted file.
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());
        registrations.set_and_write_client_id(dynamic_id.clone(), issuer.clone()).unwrap();
        assert_eq!(registrations.client_id(&issuer), Some(dynamic_id));
    }

    fn mock_metadata(client_name: String) -> VerifiedClientMetadata {
        let callback_url = Url::parse("https://example.org/login/callback").unwrap();
        let client_name = Some(Localized::new(client_name, None));

        ClientMetadata {
            redirect_uris: Some(vec![callback_url]),
            client_name,
            ..Default::default()
        }
        .validate()
        .unwrap()
    }
}