matrix_sdk/authentication/oauth/
cross_process.rs

1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5    store::{LockableCryptoStore, Store},
6    CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9    CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key in the database for the custom value holding the current session tokens
/// hash.
///
/// The raw hash bytes are written, read back and removed under this key
/// through the store's custom-value accessors.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
/// Newtype to identify that a value is a session tokens' hash.
///
/// Wraps the raw hash bytes; two hashes are equal when their bytes are equal.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Render the hash as a lowercase hexadecimal string prefixed with `0x`.
    ///
    /// Bytes are emitted in storage order, each as exactly two digits with the
    /// most significant nibble first. Only a stable format is needed here, so
    /// no particular byte-order convention is implied.
    ///
    /// An empty hash renders as the empty string, without the `0x` prefix.
    fn to_hex(&self) -> String {
        use std::fmt::Write as _;

        if self.0.is_empty() {
            return String::new();
        }

        // Two digits per byte, plus the `0x` prefix.
        let mut res = String::with_capacity(2 + 2 * self.0.len());
        res.push_str("0x");
        for byte in &self.0 {
            // Formatting into a `String` cannot fail, so the result carries no
            // information worth propagating.
            let _ = write!(res, "{byte:02x}");
        }
        res
    }
}

impl std::fmt::Debug for SessionHash {
    /// Format as `SessionHash("0x…")`, showing the hexadecimal form rather
    /// than the raw byte vector.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = self.to_hex();
        f.debug_tuple("SessionHash").field(&hex).finish()
    }
}
51
52/// Compute a hash uniquely identifying the OAuth 2.0 session tokens.
53fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54    let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55    if let Some(refresh_token) = &tokens.refresh_token {
56        hash = hash.chain_update(refresh_token.as_bytes());
57    }
58    SessionHash(hash.finalize().to_vec())
59}
60
/// Entry point for taking the locks that guard a token refresh, and for
/// tracking the hash of the session tokens known to this process.
#[derive(Clone)]
pub(super) struct CrossProcessRefreshManager {
    /// Store in which the session hash is persisted as a custom value.
    store: Store,
    /// Cross-process lock acquired by `spin_lock`.
    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
    /// Hash of the session tokens last known in memory, if any.
    ///
    /// The mutex doubles as the intra-process lock: `spin_lock` acquires it
    /// and keeps it held inside the guard it returns.
    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
}
67
68impl CrossProcessRefreshManager {
69    /// Create a new `CrossProcessRefreshManager`.
70    pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71        Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72    }
73
74    /// Wait for up to 60 seconds to get a cross-process store lock, then either
75    /// timeout (as an error) or return a lock guard.
76    ///
77    /// The guard also contains information useful to react upon another
78    /// background refresh having happened in the database already.
79    pub async fn spin_lock(
80        &self,
81    ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82        // Acquire the intra-process mutex, to avoid multiple requests across threads in
83        // the current process.
84        trace!("Waiting for intra-process lock...");
85        let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87        // Acquire the cross-process mutex, to avoid multiple requests across different
88        // processus.
89        trace!("Waiting for inter-process lock...");
90        let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92        // Read the previous session hash in the database.
93        let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95        let db_hash = current_db_session_bytes.map(SessionHash);
96
97        let hash_mismatch = match (&db_hash, &*prev_hash) {
98            (None, _) => false,
99            (Some(_), None) => true,
100            (Some(db), Some(known)) => db != known,
101        };
102
103        trace!(hash_mismatch, ?prev_hash, ?db_hash);
104
105        let guard = CrossProcessRefreshLockGuard {
106            hash_guard: prev_hash,
107            _store_guard: store_guard,
108            hash_mismatch,
109            db_hash,
110            store: self.store.clone(),
111        };
112
113        Ok(guard)
114    }
115
116    pub async fn restore_session(&self, tokens: &SessionTokens) {
117        let prev_tokens_hash = compute_session_hash(tokens);
118        *self.known_session_hash.lock().await = Some(prev_tokens_hash);
119    }
120
121    pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
122        self.store
123            .remove_custom_value(OIDC_SESSION_HASH_KEY)
124            .await
125            .map_err(CrossProcessRefreshLockError::StoreError)?;
126        *self.known_session_hash.lock().await = None;
127        Ok(())
128    }
129}
130
/// Guard returned by `CrossProcessRefreshManager::spin_lock`.
///
/// It owns both the intra-process mutex guard and the cross-process store
/// lock guard, along with what was read from the database when the lock was
/// taken.
pub(super) struct CrossProcessRefreshLockGuard {
    /// The hash for the latest session, either the one we knew, or the latest
    /// one read from the database, if it was more up to date.
    ///
    /// This is the held guard of the manager's in-memory hash mutex.
    hash_guard: OwnedMutexGuard<Option<SessionHash>>,

    /// Cross-process lock being held.
    _store_guard: CrossProcessStoreLockGuard,

    /// Reference to the underlying store, for storing the hash of the latest
    /// known session (as a custom value).
    store: Store,

    /// Do the in-memory hash and database hash mismatch?
    ///
    /// If so, this indicates that another process may have refreshed the token
    /// in the background.
    ///
    /// We don't consider it a mismatch if there was no previous value in the
    /// database. We do consider it a mismatch if there was no in-memory
    /// value known, but one was known in the database.
    pub hash_mismatch: bool,

    /// Session hash that was stored in the DB when the lock was taken.
    ///
    /// Used for debugging and testing purposes.
    db_hash: Option<SessionHash>,
}
158
159impl CrossProcessRefreshLockGuard {
160    /// Updates the `SessionTokens` hash in-memory only.
161    fn save_in_memory(&mut self, hash: SessionHash) {
162        *self.hash_guard = Some(hash);
163    }
164
165    /// Updates the `SessionTokens` hash in the database only.
166    async fn save_in_database(
167        &self,
168        hash: &SessionHash,
169    ) -> Result<(), CrossProcessRefreshLockError> {
170        self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
171        Ok(())
172    }
173
174    /// Updates the `SessionTokens` hash in both memory and database.
175    ///
176    /// Must be called after a successful refresh.
177    pub async fn save_in_memory_and_db(
178        &mut self,
179        tokens: &SessionTokens,
180    ) -> Result<(), CrossProcessRefreshLockError> {
181        let hash = compute_session_hash(tokens);
182        self.save_in_database(&hash).await?;
183        self.save_in_memory(hash);
184        Ok(())
185    }
186
187    /// Handle a mismatch by making sure values in the database and memory match
188    /// tokens we trust.
189    pub async fn handle_mismatch(
190        &mut self,
191        trusted_tokens: &SessionTokens,
192    ) -> Result<(), CrossProcessRefreshLockError> {
193        let new_hash = compute_session_hash(trusted_tokens);
194        trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
195
196        if let Some(db_hash) = &self.db_hash {
197            if new_hash != *db_hash {
198                // That should never happen, unless we got into an impossible situation!
199                // In this case, we assume the value returned by the callback is always
200                // correct, so override that in the database too.
201                tracing::error!("error: DB and trusted disagree. Overriding in DB.");
202                self.save_in_database(&new_hash).await?;
203            }
204        }
205
206        self.save_in_memory(new_hash);
207        Ok(())
208    }
209}
210
211/// An error that happened when interacting with the cross-process store lock
212/// during a token refresh.
213#[derive(Debug, Error)]
214pub enum CrossProcessRefreshLockError {
215    /// Underlying error caused by the store.
216    #[error(transparent)]
217    StoreError(#[from] CryptoStoreError),
218
219    /// The locking itself failed.
220    #[error(transparent)]
221    LockError(#[from] LockStoreError),
222
223    /// The previous hash isn't valid.
224    #[error("the previous stored hash isn't a valid integer")]
225    InvalidPreviousHash,
226
227    /// The lock hasn't been set up.
228    #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
229    MissingLock,
230
231    /// Cross-process lock was set, but without session callbacks.
232    #[error(
233        "reload session callback must be set with Client::set_session_callbacks() \
234         for the cross-process lock to work"
235    )]
236    MissingReloadSession,
237
238    /// The store has been created twice.
239    #[error(
240        "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
241    )]
242    DuplicatedLock,
243}
244
245#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_family = "wasm")))]
246mod tests {
247
248    use anyhow::Context as _;
249    use futures_util::future::join_all;
250    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
251    use matrix_sdk_test::async_test;
252    use ruma::{owned_device_id, owned_user_id};
253
254    use super::compute_session_hash;
255    use crate::{
256        authentication::oauth::cross_process::SessionHash,
257        test_utils::{
258            client::{
259                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
260                oauth::mock_session, MockClientBuilder,
261            },
262            mocks::MatrixMockServer,
263        },
264        Error,
265    };
266
    // Restoring a session records the hash of its tokens in memory; taking the
    // lock afterwards finds the same hash in the database and reports no
    // mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // Create a client that will use sqlite databases.

        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new(None)
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // Hash expected both in memory and in the database once restored.
        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        {
            // Scoped so the mutex guard is released before `spin_lock` below
            // needs the same mutex.
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }
315
    // Loading the session after the tokens were set fills in the session meta
    // and leaves matching hashes in memory and in the database.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .registered_with_oauth()
            .build()
            .await;
        let oauth = client.oauth();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Simulate we've done finalize_authorization / restore_session before.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        // Now, finishing login will get the user ID.
        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
363
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        // This tests that refresh token works, and that it doesn't cause multiple token
        // refreshes whenever one spawns two refreshes around the same time.

        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // The token endpoint mock expects exactly one request, even though two
        // refreshes are started below.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth();

        // Tokens whose hash is expected in memory and in the database once the
        // refresh is done.
        let next_tokens = mock_session_tokens_with_refresh();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        // Immediately try to refresh the access token twice in parallel.
        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
416
    // Three clients share the same sqlite store: one refreshes the tokens, one
    // is only restored after that refresh, and the first one finds out about
    // the refresh through a hash mismatch the next time it takes the lock.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // The token endpoint mock expects exactly one request over the whole
        // test.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // Create the first client.
        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // Create a second client, without restoring it, to test that a token update
        // before restoration doesn't cause new issues.
        let unrestored_client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Create a third client that will run a refresh while the other two are doing
            // nothing.
            let client3 = server
                .client_builder()
                .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
                .unlogged()
                .build()
                .await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // Run a refresh in the third client; this will invalidate the tokens from the
            // first client.
            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            // Reading from the cross-process lock for the third client only shows the new
            // tokens.
            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        {
            // Restoring the client that was not restored yet will work Just Fine.
            let oauth = unrestored_oauth;

            unrestored_client.set_session_callbacks(
                Box::new({
                    // This is only called because of extra checks in the code.
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            // The session is restored with the previous tokens on purpose; the
            // assertions below expect the hash of the new ones.
            oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // And this client is now aware of the latest tokens.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        {
            // The first client still only knows the hash of the previous tokens, while
            // the database holds the hash of the new ones: taking the lock results in a
            // mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        oauth.refresh_access_token().await?;

        {
            // The next attempt to take the lock isn't a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
555
    // Logging out clears the session hash both in memory and in the database.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        let tokens = mock_session_tokens_with_refresh();
        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;

        oauth.logout().await.unwrap();

        {
            // The cross process lock has been correctly updated, and all the hashes are
            // empty after a logout.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
601
602    #[test]
603    fn test_session_hash_to_hex() {
604        let hash = SessionHash(vec![]);
605        assert_eq!(hash.to_hex(), "");
606
607        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
608        assert_eq!(hash.to_hex(), "0x133742deadcafe");
609    }
610}