matrix_sdk/authentication/oidc/
cross_process.rs

1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5    store::{LockableCryptoStore, Store},
6    CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9    CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key of the custom value, in the store, under which the hash of the current
/// session tokens is saved, read back and removed.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
/// Newtype marking a byte vector as the hash of a set of session tokens.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Render the hash as a lowercase hexadecimal string.
    ///
    /// A non-empty hash is prefixed with `0x`; an empty hash renders as the
    /// empty string (no prefix at all).
    fn to_hex(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

        let bytes = &self.0;
        if bytes.is_empty() {
            return String::new();
        }

        let mut out = String::with_capacity(2 + 2 * bytes.len());
        out.push_str("0x");
        for &byte in bytes {
            // Only a stable textual format is needed here. Bytes are emitted
            // in storage order, and within each byte the most significant
            // nibble comes first (the usual way of writing hex).
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}

impl std::fmt::Debug for SessionHash {
    /// Show the hash in its hexadecimal form rather than as a list of bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = self.to_hex();
        f.debug_tuple("SessionHash").field(&hex).finish()
    }
}
51
52/// Compute a hash uniquely identifying the OIDC session tokens.
53fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54    let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55    if let Some(refresh_token) = &tokens.refresh_token {
56        hash = hash.chain_update(refresh_token.as_bytes());
57    }
58    SessionHash(hash.finalize().to_vec())
59}
60
/// Groups what is needed to take the token-refresh lock: the store, the
/// cross-process store lock, and the session hash last known to this process.
#[derive(Clone)]
pub(super) struct CrossProcessRefreshManager {
    /// Store in which the session hash is kept as a custom value.
    store: Store,
    /// Cross-process lock acquired by `spin_lock`.
    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
    /// Session hash last known in memory, if any. The mutex around it is also
    /// used as the intra-process lock in `spin_lock`.
    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
}
67
68impl CrossProcessRefreshManager {
69    /// Create a new `CrossProcessRefreshManager`.
70    pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71        Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72    }
73
74    /// Wait for up to 60 seconds to get a cross-process store lock, then either
75    /// timeout (as an error) or return a lock guard.
76    ///
77    /// The guard also contains information useful to react upon another
78    /// background refresh having happened in the database already.
79    pub async fn spin_lock(
80        &self,
81    ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82        // Acquire the intra-process mutex, to avoid multiple requests across threads in
83        // the current process.
84        trace!("Waiting for intra-process lock...");
85        let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87        // Acquire the cross-process mutex, to avoid multiple requests across different
88        // processus.
89        trace!("Waiting for inter-process lock...");
90        let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92        // Read the previous session hash in the database.
93        let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95        let db_hash = current_db_session_bytes.map(SessionHash);
96
97        let hash_mismatch = match (&db_hash, &*prev_hash) {
98            (None, _) => false,
99            (Some(_), None) => true,
100            (Some(db), Some(known)) => db != known,
101        };
102
103        trace!(
104            "Hash mismatch? {:?} (prev. known={:?}, db={:?})",
105            hash_mismatch,
106            *prev_hash,
107            db_hash
108        );
109
110        let guard = CrossProcessRefreshLockGuard {
111            hash_guard: prev_hash,
112            _store_guard: store_guard,
113            hash_mismatch,
114            db_hash,
115            store: self.store.clone(),
116        };
117
118        Ok(guard)
119    }
120
121    pub async fn restore_session(&self, tokens: &SessionTokens) {
122        let prev_tokens_hash = compute_session_hash(tokens);
123        *self.known_session_hash.lock().await = Some(prev_tokens_hash);
124    }
125
126    pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
127        self.store
128            .remove_custom_value(OIDC_SESSION_HASH_KEY)
129            .await
130            .map_err(CrossProcessRefreshLockError::StoreError)?;
131        *self.known_session_hash.lock().await = None;
132        Ok(())
133    }
134}
135
/// Guard built by `CrossProcessRefreshManager::spin_lock`.
///
/// It owns both the guard of the in-memory hash mutex and the guard of the
/// cross-process store lock, plus the information gathered while locking.
pub(super) struct CrossProcessRefreshLockGuard {
    /// The hash for the latest session, either the one we knew, or the latest
    /// one read from the database, if it was more up to date.
    hash_guard: OwnedMutexGuard<Option<SessionHash>>,

    /// Cross-process lock being held.
    _store_guard: CrossProcessStoreLockGuard,

    /// Reference to the underlying store, for storing the hash of the latest
    /// known session (as a custom value).
    store: Store,

    /// Do the in-memory hash and database hash mismatch?
    ///
    /// If so, this indicates that another process may have refreshed the token
    /// in the background.
    ///
    /// We don't consider it a mismatch if there was no previous value in the
    /// database. We do consider it a mismatch if there was no in-memory
    /// value known, but one was known in the database.
    pub hash_mismatch: bool,

    /// Session hash previously stored in the DB, as read when the lock was
    /// taken.
    ///
    /// Used for debugging and testing purposes.
    db_hash: Option<SessionHash>,
}
163
164impl CrossProcessRefreshLockGuard {
165    /// Updates the `SessionTokens` hash in-memory only.
166    fn save_in_memory(&mut self, hash: SessionHash) {
167        *self.hash_guard = Some(hash);
168    }
169
170    /// Updates the `SessionTokens` hash in the database only.
171    async fn save_in_database(
172        &self,
173        hash: &SessionHash,
174    ) -> Result<(), CrossProcessRefreshLockError> {
175        self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
176        Ok(())
177    }
178
179    /// Updates the `SessionTokens` hash in both memory and database.
180    ///
181    /// Must be called after a successful refresh.
182    pub async fn save_in_memory_and_db(
183        &mut self,
184        tokens: &SessionTokens,
185    ) -> Result<(), CrossProcessRefreshLockError> {
186        let hash = compute_session_hash(tokens);
187        self.save_in_database(&hash).await?;
188        self.save_in_memory(hash);
189        Ok(())
190    }
191
192    /// Handle a mismatch by making sure values in the database and memory match
193    /// tokens we trust.
194    pub async fn handle_mismatch(
195        &mut self,
196        trusted_tokens: &SessionTokens,
197    ) -> Result<(), CrossProcessRefreshLockError> {
198        let new_hash = compute_session_hash(trusted_tokens);
199        trace!("Trusted OIDC tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
200
201        if let Some(db_hash) = &self.db_hash {
202            if new_hash != *db_hash {
203                // That should never happen, unless we got into an impossible situation!
204                // In this case, we assume the value returned by the callback is always
205                // correct, so override that in the database too.
206                tracing::error!("error: DB and trusted disagree. Overriding in DB.");
207                self.save_in_database(&new_hash).await?;
208            }
209        }
210
211        self.save_in_memory(new_hash);
212        Ok(())
213    }
214}
215
216/// An error that happened when interacting with the cross-process store lock
217/// during a token refresh.
218#[derive(Debug, Error)]
219pub enum CrossProcessRefreshLockError {
220    /// Underlying error caused by the store.
221    #[error(transparent)]
222    StoreError(#[from] CryptoStoreError),
223
224    /// The locking itself failed.
225    #[error(transparent)]
226    LockError(#[from] LockStoreError),
227
228    /// The previous hash isn't valid.
229    #[error("the previous stored hash isn't a valid integer")]
230    InvalidPreviousHash,
231
232    /// The lock hasn't been set up.
233    #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
234    MissingLock,
235
236    /// Cross-process lock was set, but without session callbacks.
237    #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
238    MissingReloadSession,
239
240    /// The store has been created twice.
241    #[error(
242        "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
243    )]
244    DuplicatedLock,
245}
246
247#[cfg(all(test, feature = "e2e-encryption"))]
248mod tests {
249
250    use anyhow::Context as _;
251    use futures_util::future::join_all;
252    use matrix_sdk_base::SessionMeta;
253    use matrix_sdk_test::async_test;
254    use ruma::{owned_device_id, owned_user_id};
255
256    use super::compute_session_hash;
257    use crate::{
258        authentication::oidc::cross_process::SessionHash,
259        test_utils::{
260            client::{
261                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
262                oauth::mock_session, MockClientBuilder,
263            },
264            mocks::MatrixMockServer,
265        },
266        Error,
267    };
268
    // Restoring a session must make the hash of its tokens known to the
    // cross-process manager, so that taking the lock afterwards reports the
    // same hash from the database and no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // Create a client that will use sqlite databases.

        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new("https://example.org".to_owned())
            .sqlite_store(&tmp_dir)
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oidc().enable_cross_process_refresh_lock("test".to_owned()).await?;

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client
            .oidc()
            .restore_session(mock_session(
                tokens.clone(),
                "https://oidc.example.com/issuer".to_owned(),
            ))
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oidc = client.oidc();
        let xp_manager = oidc.ctx().cross_process_token_refresh_manager.get().unwrap();

        {
            // The in-memory hash is the one of the restored tokens. The mutex
            // guard is scoped so it is released before `spin_lock` below.
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            // Taking the lock sees the same hash in the database.
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }
320
    // Finishing a login must fill in the session meta and leave the
    // cross-process lock with the hash of the current session tokens.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .sqlite_store(&tmp_dir)
            .registered_with_oauth(server.server().uri())
            .build()
            .await;
        let oidc = client.oidc();

        // Enable cross-process lock.
        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Simulate we've done finalize_authorization / restore_session before.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        // Now, finishing the login will get the user and device ids.
        oidc.finish_login().await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
368
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        // This tests that refresh token works, and that it doesn't cause multiple token
        // refreshes whenever one spawns two refreshes around the same time.

        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // The token endpoint mock expects exactly one call, even though two
        // refreshes are requested below.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oidc = client.oidc();

        let next_tokens = mock_session_tokens_with_refresh();

        // Enable cross-process lock.
        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        oidc.restore_session(mock_session(
            mock_prev_session_tokens_with_refresh(),
            server.server().uri(),
        ))
        .await?;

        // Immediately try to refresh the access token twice in parallel.
        for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await {
            result?;
        }

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
415
    // Several clients share the same sqlite store: one of them refreshes the
    // tokens, and the others must notice it through the cross-process lock,
    // while the token endpoint mock expects exactly one call overall.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        let issuer = server.server().uri();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // Create the first client.
        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

        let oidc = client.oidc();
        oidc.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oidc.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?;

        // Create a second client, without restoring it, to test that a token update
        // before restoration doesn't cause new issues.
        let unrestored_client =
            server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let unrestored_oidc = unrestored_client.oidc();
        unrestored_oidc.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Create a third client that will run a refresh while the other two are doing
            // nothing.
            let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

            let oidc3 = client3.oidc();
            oidc3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oidc3.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?;

            // Run a refresh in the third client; this will invalidate the tokens from the
            // first client.
            oidc3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            // Reading from the cross-process lock for the third client only shows the new
            // tokens.
            let xp_manager =
                oidc3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        {
            // Restoring the client that was not restored yet will work Just Fine.
            let oidc = unrestored_oidc;

            unrestored_client.set_session_callbacks(
                Box::new({
                    // This is only called because of extra checks in the code.
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oidc.restore_session(mock_session(prev_tokens.clone(), issuer)).await?;

            // And this client is now aware of the latest tokens.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oidc);
            drop(unrestored_client);
        }

        {
            // Back to the first client: the database holds the hash of the new tokens
            // while this client still knows the previous one, so the next attempt to
            // take the lock will result in a mismatch.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        oidc.refresh_access_token().await?;

        {
            // The next attempt to take the lock isn't a mismatch.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
535
    // Logging out must clear the session hash both in memory and in the
    // database.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oidc = client.oidc().insecure_rewrite_https_to_http();

        // Enable cross-process lock.
        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        let tokens = mock_session_tokens_with_refresh();
        oidc.restore_session(mock_session(tokens.clone(), server.server().uri())).await?;

        oidc.logout().await?;

        {
            // The cross process lock has been correctly updated, and all the hashes are
            // empty after a logout.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
576
577    #[test]
578    fn test_session_hash_to_hex() {
579        let hash = SessionHash(vec![]);
580        assert_eq!(hash.to_hex(), "");
581
582        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
583        assert_eq!(hash.to_hex(), "0x133742deadcafe");
584    }
585}