matrix_sdk/authentication/oauth/registration_store.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! OAuth 2.0 client registration store.
16//!
17//! This module provides a way to persist OAuth 2.0 client registrations outside
18//! of the state store. This is useful when using a `Client` with an in-memory
19//! store or when different store paths are used for multi-account support
20//! within the same app, and those accounts need to share the same OAuth 2.0
21//! client registration.
22
23use std::{collections::HashMap, io::ErrorKind, path::PathBuf};
24
25use oauth2::ClientId;
26use ruma::serde::Raw;
27use serde::{Deserialize, Serialize};
28use tokio::fs;
29use url::Url;
30
31use super::ClientMetadata;
32
/// Errors that can occur when using the [`OAuthRegistrationStore`].
#[derive(Debug, thiserror::Error)]
pub enum OAuthRegistrationStoreError {
    /// The supplied path is not a file path.
    ///
    /// Returned by [`OAuthRegistrationStore::new()`], for example when the
    /// supplied path has no parent directory component.
    #[error("supplied registrations path is not a file path")]
    NotAFilePath,
    /// An error occurred when reading from or writing to the file.
    ///
    /// The message is forwarded unchanged from the underlying
    /// [`std::io::Error`], which converts into this variant via `?`.
    #[error(transparent)]
    File(#[from] std::io::Error),
    /// An error occurred when serializing the registration data before writing
    /// it to the file.
    #[error("failed to serialize registration data: {0}")]
    IntoJson(serde_json::Error),
    /// An error occurred when deserializing the registration data read from
    /// the file.
    #[error("failed to deserialize registration data: {0}")]
    FromJson(serde_json::Error),
}
49
/// An API to store and restore OAuth 2.0 client registrations.
///
/// This stores dynamic client registrations in a file, and accepts "static"
/// client registrations via
/// [`OAuthRegistrationStore::with_static_registrations()`], for servers that
/// don't support dynamic client registration.
///
/// If the client metadata passed to this API changes, the registrations
/// previously stored in the file are invalidated, so that the client can
/// re-register with the new metadata.
///
/// The purpose of storing client IDs outside of the state store, separately
/// from the user's session, is to allow reusing the same client ID across
/// user sessions on the same server.
#[derive(Debug)]
pub struct OAuthRegistrationStore {
    /// The path of the file where the registrations are stored.
    pub(super) file_path: PathBuf,
    /// The metadata used to register the client.
    /// This is used to check if the client needs to be re-registered.
    pub(super) metadata: Raw<ClientMetadata>,
    /// Pre-configured registrations, keyed by issuer URL, for use with issuers
    /// that don't support dynamic client registration. These are held in
    /// memory only and are never written to `file_path`.
    static_registrations: Option<HashMap<Url, ClientId>>,
}
75
/// The underlying data serialized into the registration file.
///
/// The field names below are the JSON keys used in the file, so renaming them
/// changes the on-disk format and makes previously written files unreadable.
#[derive(Debug, Serialize, Deserialize)]
struct FrozenRegistrationData {
    /// The metadata used to register the client.
    metadata: Raw<ClientMetadata>,
    /// All of the registrations this client has made as a HashMap of issuer URL
    /// to client ID.
    dynamic_registrations: HashMap<Url, ClientId>,
}
85
86impl OAuthRegistrationStore {
87    /// Creates a new registration store.
88    ///
89    /// This method creates the `file`'s parent directory if it doesn't exist.
90    ///
91    /// # Arguments
92    ///
93    /// * `file` - A file path where the registrations will be stored. This
94    ///   previously took a directory and stored the registrations with the path
95    ///   `supplied_directory/oidc/registrations.json`.
96    ///
97    /// * `metadata` - The metadata used to register the client. If this changes
98    ///   compared to the value stored in the file, any stored registrations
99    ///   will be invalidated so the client can re-register with the new data.
100    pub async fn new(
101        file: PathBuf,
102        metadata: Raw<ClientMetadata>,
103    ) -> Result<Self, OAuthRegistrationStoreError> {
104        let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
105        fs::create_dir_all(parent).await?;
106
107        Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
108    }
109
110    /// Add static registrations to the store.
111    ///
112    /// Static registrations are used for servers that don't support dynamic
113    /// registration but provide a client ID out-of-band.
114    ///
115    /// These registrations are not stored in the file and must be provided each
116    /// time.
117    pub fn with_static_registrations(
118        mut self,
119        static_registrations: HashMap<Url, ClientId>,
120    ) -> Self {
121        self.static_registrations = Some(static_registrations);
122        self
123    }
124
125    /// Returns the client ID registered for a particular issuer or `None` if a
126    /// registration hasn't been made.
127    ///
128    /// # Arguments
129    ///
130    /// * `issuer` - The issuer to look up.
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if the file could not be read, or if the data in the
135    /// file could not be deserialized.
136    pub async fn client_id(
137        &self,
138        issuer: &Url,
139    ) -> Result<Option<ClientId>, OAuthRegistrationStoreError> {
140        if let Some(client_id) =
141            self.static_registrations.as_ref().and_then(|registrations| registrations.get(issuer))
142        {
143            return Ok(Some(client_id.clone()));
144        }
145
146        let data = self.read_registration_data().await?;
147        Ok(data.and_then(|mut data| data.dynamic_registrations.remove(issuer)))
148    }
149
150    /// Stores a new client ID registration for a particular issuer.
151    ///
152    /// If a client ID has already been stored for the given issuer, this will
153    /// overwrite the old value.
154    ///
155    /// # Arguments
156    ///
157    /// * `client_id` - The client ID obtained after registration.
158    ///
159    /// * `issuer` - The issuer associated with the client ID.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if the file could not be read from or written to, or if
164    /// the data in the file could not be (de)serialized.
165    pub async fn set_and_write_client_id(
166        &self,
167        client_id: ClientId,
168        issuer: Url,
169    ) -> Result<(), OAuthRegistrationStoreError> {
170        let mut data = self.read_registration_data().await?.unwrap_or_else(|| {
171            tracing::info!("Generating new OAuth 2.0 client registration data");
172            FrozenRegistrationData {
173                metadata: self.metadata.clone(),
174                dynamic_registrations: Default::default(),
175            }
176        });
177        data.dynamic_registrations.insert(issuer, client_id);
178
179        let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?;
180        fs::write(&self.file_path, contents).await?;
181
182        Ok(())
183    }
184
185    /// The persisted registration data.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if the file could not be read, or if the data in the
190    /// file could not be deserialized.
191    async fn read_registration_data(
192        &self,
193    ) -> Result<Option<FrozenRegistrationData>, OAuthRegistrationStoreError> {
194        let contents = match fs::read(&self.file_path).await {
195            Ok(contents) => contents,
196            Err(error) => {
197                if error.kind() == ErrorKind::NotFound {
198                    // The file doesn't exist so there is no data.
199                    return Ok(None);
200                }
201
202                // Forward the error.
203                return Err(error.into());
204            }
205        };
206
207        let registration_data: FrozenRegistrationData =
208            serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?;
209
210        if registration_data.metadata.json().get() != self.metadata.json().get() {
211            tracing::info!("Metadata mismatch, ignoring any stored registrations.");
212            Ok(None)
213        } else {
214            Ok(Some(registration_data))
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use tempfile::tempdir;

    use super::*;
    use crate::authentication::oauth::registration::{ApplicationType, Localized, OAuthGrantType};

    #[async_test]
    async fn test_oauth_registration_store() {
        // Given a fresh registration store with a single static registration.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oauth").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let registrations = OAuthRegistrationStore::new(registrations_file, oidc_metadata)
            .await
            .unwrap()
            .with_static_registrations(static_registrations);

        assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None);

        // When a dynamic registration is added.
        registrations
            .set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
            .await
            .unwrap();

        // Then the dynamic registration should be stored and the static registration
        // should be unaffected.
        assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id));
        assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));
    }

    #[async_test]
    async fn test_change_of_metadata() {
        // Given a single registration with an example app name.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oidc").join("registrations.json");

        let static_url = Url::parse("https://example.com").unwrap();
        let static_id = ClientId::new("static_client_id".to_owned());
        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let oidc_metadata = mock_metadata("Example".to_owned());

        let mut static_registrations = HashMap::new();
        static_registrations.insert(static_url.clone(), static_id.clone());

        let registrations = OAuthRegistrationStore::new(registrations_file.clone(), oidc_metadata)
            .await
            .unwrap()
            .with_static_registrations(static_registrations.clone());
        registrations
            .set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
            .await
            .unwrap();

        assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone()));
        assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));

        // When the app name changes.
        let new_oidc_metadata = mock_metadata("New App".to_owned());

        let registrations = OAuthRegistrationStore::new(registrations_file, new_oidc_metadata)
            .await
            .unwrap()
            .with_static_registrations(static_registrations);

        // Then the dynamic registrations are cleared.
        assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None);
        assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id));
    }

    #[async_test]
    async fn test_registrations_persist_across_store_instances() {
        // Given a dynamic registration written by one store.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oauth").join("registrations.json");

        let dynamic_url = Url::parse("https://example.org").unwrap();
        let dynamic_id = ClientId::new("dynamic_client_id".to_owned());

        let first_store = OAuthRegistrationStore::new(
            registrations_file.clone(),
            mock_metadata("Example".to_owned()),
        )
        .await
        .unwrap();
        first_store
            .set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
            .await
            .unwrap();

        // When a new store is opened on the same file with the same metadata.
        let second_store =
            OAuthRegistrationStore::new(registrations_file, mock_metadata("Example".to_owned()))
                .await
                .unwrap();

        // Then the previously written registration is found.
        assert_eq!(second_store.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));
    }

    #[async_test]
    async fn test_overwrite_client_id() {
        // Given registrations for two issuers.
        let dir = tempdir().unwrap();
        let registrations_file = dir.path().join("oauth").join("registrations.json");

        let issuer = Url::parse("https://example.org").unwrap();
        let other_issuer = Url::parse("https://example.net").unwrap();
        let first_id = ClientId::new("first_client_id".to_owned());
        let second_id = ClientId::new("second_client_id".to_owned());
        let other_id = ClientId::new("other_client_id".to_owned());

        let registrations =
            OAuthRegistrationStore::new(registrations_file, mock_metadata("Example".to_owned()))
                .await
                .unwrap();
        registrations.set_and_write_client_id(first_id, issuer.clone()).await.unwrap();
        registrations
            .set_and_write_client_id(other_id.clone(), other_issuer.clone())
            .await
            .unwrap();

        // When one issuer's client ID is replaced.
        registrations.set_and_write_client_id(second_id.clone(), issuer.clone()).await.unwrap();

        // Then the new value wins and the other issuer is unaffected.
        assert_eq!(registrations.client_id(&issuer).await.unwrap(), Some(second_id));
        assert_eq!(registrations.client_id(&other_issuer).await.unwrap(), Some(other_id));
    }

    #[async_test]
    async fn test_path_without_parent_is_rejected() {
        // An empty path has no parent component, so it can't be a file path.
        let result =
            OAuthRegistrationStore::new(PathBuf::new(), mock_metadata("Example".to_owned())).await;

        assert!(matches!(result, Err(OAuthRegistrationStoreError::NotAFilePath)));
    }

    fn mock_metadata(client_name: String) -> Raw<ClientMetadata> {
        let callback_url = Url::parse("https://example.org/login/callback").unwrap();
        let client_uri = Url::parse("https://example.org/").unwrap();

        let mut metadata = ClientMetadata::new(
            ApplicationType::Web,
            vec![OAuthGrantType::AuthorizationCode { redirect_uris: vec![callback_url] }],
            Localized::new(client_uri, None),
        );
        metadata.client_name = Some(Localized::new(client_name, None));

        Raw::new(&metadata).unwrap()
    }
}