matrix_sdk/authentication/oauth/cross_process.rs

1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5    CryptoStoreError,
6    store::{LockableCryptoStore, Store},
7};
8use matrix_sdk_common::cross_process_lock::{
9    CrossProcessLock, CrossProcessLockError, CrossProcessLockGuard,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Name of the custom value, in the crypto store, under which the hash of the
/// current session tokens is persisted and shared between processes.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
22/// Newtype to identify that a value is a session tokens' hash.
23#[derive(Clone, PartialEq, Eq)]
24struct SessionHash(Vec<u8>);
25
26impl SessionHash {
27    fn to_hex(&self) -> String {
28        const CHARS: &[char; 16] =
29            &['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'];
30        let mut res = String::with_capacity(2 * self.0.len() + 2);
31        if !self.0.is_empty() {
32            res.push('0');
33            res.push('x');
34        }
35        for &c in &self.0 {
36            // We don't really care about little vs big endianness, since we only need a
37            // stable format, so we pick one: little endian (print high bits
38            // first).
39            res.push(CHARS[(c >> 4) as usize]);
40            res.push(CHARS[(c & 0b1111) as usize]);
41        }
42        res
43    }
44}
45
46impl std::fmt::Debug for SessionHash {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_tuple("SessionHash").field(&self.to_hex()).finish()
49    }
50}
51
52/// Compute a hash uniquely identifying the OAuth 2.0 session tokens.
53fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54    let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55    if let Some(refresh_token) = &tokens.refresh_token {
56        hash = hash.chain_update(refresh_token.as_bytes());
57    }
58    SessionHash(hash.finalize().to_vec())
59}
60
61#[derive(Clone)]
62pub(super) struct CrossProcessRefreshManager {
63    store: Store,
64    store_lock: CrossProcessLock<LockableCryptoStore>,
65    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
66}
67
68impl CrossProcessRefreshManager {
69    /// Create a new `CrossProcessRefreshManager`.
70    pub fn new(store: Store, lock: CrossProcessLock<LockableCryptoStore>) -> Self {
71        Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72    }
73
74    /// Wait for up to 60 seconds to get a cross-process store lock, then either
75    /// timeout (as an error) or return a lock guard.
76    ///
77    /// The guard also contains information useful to react upon another
78    /// background refresh having happened in the database already.
79    pub async fn spin_lock(
80        &self,
81    ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82        // Acquire the intra-process mutex, to avoid multiple requests across threads in
83        // the current process.
84        trace!("Waiting for intra-process lock...");
85        let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87        // Acquire the cross-process mutex, to avoid multiple requests across different
88        // processus.
89        trace!("Waiting for inter-process lock...");
90        let store_guard = self
91            .store_lock
92            .spin_lock(Some(60000))
93            .await
94            .map_err(|err| {
95                CrossProcessRefreshLockError::LockError(CrossProcessLockError::TryLock(Box::new(
96                    err,
97                )))
98            })?
99            .map_err(|err| CrossProcessRefreshLockError::LockError(err.into()))?;
100
101        // Read the previous session hash in the database.
102        let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
103
104        let db_hash = current_db_session_bytes.map(SessionHash);
105
106        let hash_mismatch = match (&db_hash, &*prev_hash) {
107            (None, _) => false,
108            (Some(_), None) => true,
109            (Some(db), Some(known)) => db != known,
110        };
111
112        trace!(hash_mismatch, ?prev_hash, ?db_hash);
113
114        let guard = CrossProcessRefreshLockGuard {
115            hash_guard: prev_hash,
116            _store_guard: store_guard.into_guard(),
117            hash_mismatch,
118            db_hash,
119            store: self.store.clone(),
120        };
121
122        Ok(guard)
123    }
124
125    pub async fn restore_session(&self, tokens: &SessionTokens) {
126        let prev_tokens_hash = compute_session_hash(tokens);
127        *self.known_session_hash.lock().await = Some(prev_tokens_hash);
128    }
129
130    pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
131        self.store
132            .remove_custom_value(OIDC_SESSION_HASH_KEY)
133            .await
134            .map_err(CrossProcessRefreshLockError::StoreError)?;
135        *self.known_session_hash.lock().await = None;
136        Ok(())
137    }
138}
139
140pub(super) struct CrossProcessRefreshLockGuard {
141    /// The hash for the latest session, either the one we knew, or the latest
142    /// one read from the database, if it was more up to date.
143    hash_guard: OwnedMutexGuard<Option<SessionHash>>,
144
145    /// Cross-process lock being hold.
146    _store_guard: CrossProcessLockGuard,
147
148    /// Reference to the underlying store, for storing the hash of the latest
149    /// known session (as a custom value).
150    store: Store,
151
152    /// Do the in-memory hash and database hash mismatch?
153    ///
154    /// If so, this indicates that another process may have refreshed the token
155    /// in the background.
156    ///
157    /// We don't consider it a mismatch if there was no previous value in the
158    /// database. We do consider it a mismatch if there was no in-memory
159    /// value known, but one was known in the database.
160    pub hash_mismatch: bool,
161
162    /// Session hash previously stored in the DB.
163    ///
164    /// Used for debugging and testing purposes.
165    db_hash: Option<SessionHash>,
166}
167
168impl CrossProcessRefreshLockGuard {
169    /// Updates the `SessionTokens` hash in-memory only.
170    fn save_in_memory(&mut self, hash: SessionHash) {
171        *self.hash_guard = Some(hash);
172    }
173
174    /// Updates the `SessionTokens` hash in the database only.
175    async fn save_in_database(
176        &self,
177        hash: &SessionHash,
178    ) -> Result<(), CrossProcessRefreshLockError> {
179        self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
180        Ok(())
181    }
182
183    /// Updates the `SessionTokens` hash in both memory and database.
184    ///
185    /// Must be called after a successful refresh.
186    pub async fn save_in_memory_and_db(
187        &mut self,
188        tokens: &SessionTokens,
189    ) -> Result<(), CrossProcessRefreshLockError> {
190        let hash = compute_session_hash(tokens);
191        self.save_in_database(&hash).await?;
192        self.save_in_memory(hash);
193        Ok(())
194    }
195
196    /// Handle a mismatch by making sure values in the database and memory match
197    /// tokens we trust.
198    pub async fn handle_mismatch(
199        &mut self,
200        trusted_tokens: &SessionTokens,
201    ) -> Result<(), CrossProcessRefreshLockError> {
202        let new_hash = compute_session_hash(trusted_tokens);
203        trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
204
205        if let Some(db_hash) = &self.db_hash
206            && new_hash != *db_hash
207        {
208            // That should never happen, unless we got into an impossible situation!
209            // In this case, we assume the value returned by the callback is always
210            // correct, so override that in the database too.
211            tracing::error!("error: DB and trusted disagree. Overriding in DB.");
212            self.save_in_database(&new_hash).await?;
213        }
214
215        self.save_in_memory(new_hash);
216        Ok(())
217    }
218}
219
220/// An error that happened when interacting with the cross-process store lock
221/// during a token refresh.
222#[derive(Debug, Error)]
223pub enum CrossProcessRefreshLockError {
224    /// Underlying error caused by the store.
225    #[error(transparent)]
226    StoreError(#[from] CryptoStoreError),
227
228    /// The locking itself failed.
229    #[error(transparent)]
230    LockError(#[from] CrossProcessLockError),
231
232    /// The previous hash isn't valid.
233    #[error("the previous stored hash isn't a valid integer")]
234    InvalidPreviousHash,
235
236    /// The lock hasn't been set up.
237    #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
238    MissingLock,
239
240    /// Cross-process lock was set, but without session callbacks.
241    #[error(
242        "reload session callback must be set with Client::set_session_callbacks() \
243         for the cross-process lock to work"
244    )]
245    MissingReloadSession,
246
247    /// The store has been created twice.
248    #[error(
249        "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
250    )]
251    DuplicatedLock,
252}
253
#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_family = "wasm")))]
mod tests {
    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings};
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::compute_session_hash;
    use crate::{
        Error,
        authentication::oauth::cross_process::SessionHash,
        test_utils::{
            client::{
                MockClientBuilder, mock_prev_session_tokens_with_refresh,
                mock_session_tokens_with_refresh, oauth::mock_session,
            },
            mocks::MatrixMockServer,
        },
    };

    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // Create a client that will use sqlite databases.
        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new(None)
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        {
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }

    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .registered_with_oauth()
            .build()
            .await;
        let oauth = client.oauth();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Simulate we've done finalize_authorization / restore_session before.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        // Now, finishing logging will get the user ID.
        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        // This tests that refresh token works, and that it doesn't cause multiple token
        // refreshes whenever one spawns two refreshes around the same time.

        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth();

        let next_tokens = mock_session_tokens_with_refresh();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        // Immediately try to refresh the access token twice in parallel.
        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // Create the first client.
        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // Create a second client, without restoring it, to test that a token update
        // before restoration doesn't cause new issues.
        let unrestored_client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Create a third client that will run a refresh while the other two are doing
            // nothing.
            let client3 = server
                .client_builder()
                .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
                .unlogged()
                .build()
                .await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // Run a refresh in the third client; this will invalidate the tokens from the
            // first client.
            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            // Reading from the cross-process lock for the third client only shows the new
            // tokens.
            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        {
            // Restoring the client that was not restored yet will work Just Fine.
            let oauth = unrestored_oauth;

            unrestored_client.set_session_callbacks(
                Box::new({
                    // This is only called because of extra checks in the code.
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // And this client is now aware of the latest tokens.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it from the first client will result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        oauth.refresh_access_token().await?;

        {
            // The next attempt to take the lock isn't a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        let tokens = mock_session_tokens_with_refresh();
        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;

        oauth.logout().await.unwrap();

        {
            // The cross process lock has been correctly updated, and all the hashes are
            // empty after a logout.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[test]
    fn test_session_hash_to_hex() {
        let hash = SessionHash(vec![]);
        assert_eq!(hash.to_hex(), "");

        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(hash.to_hex(), "0x133742deadcafe");

        // Every byte is rendered with two digits, including leading zeroes.
        let hash = SessionHash(vec![0x00, 0x0f, 0xf0]);
        assert_eq!(hash.to_hex(), "0x000ff0");
    }
}