1use std::sync::Arc;
23#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5 store::{LockableCryptoStore, Store},
6 CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9 CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
1516use crate::SessionTokens;
/// Name of the custom value, in the crypto store, under which the hash of the
/// most recently known session tokens is persisted.
///
/// Every process sharing the store reads and writes this same key, which is how
/// one process can notice that another one refreshed the tokens.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
/// Newtype marking a byte vector as the hash of a set of session tokens.
///
/// The bytes are the raw digest produced by `compute_session_hash`, and they are
/// the exact bytes persisted under `OIDC_SESSION_HASH_KEY`.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Renders the hash as a lowercase hexadecimal string prefixed with `0x`.
    ///
    /// Bytes are printed in storage order. Each byte becomes two hex digits, with
    /// the most significant nibble first. An empty hash renders as the empty
    /// string, without the `0x` prefix.
    fn to_hex(&self) -> String {
        use std::fmt::Write as _;

        if self.0.is_empty() {
            return String::new();
        }

        let mut res = String::with_capacity(2 + 2 * self.0.len());
        res.push_str("0x");
        for byte in &self.0 {
            // `{:02x}` zero-pads to two digits and emits the high nibble first.
            // Writing into a `String` cannot fail, so the `fmt::Result` is ignored.
            let _ = write!(res, "{byte:02x}");
        }
        res
    }
}

impl std::fmt::Debug for SessionHash {
    /// Shows the hash as `SessionHash("0x…")` instead of a raw list of bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SessionHash").field(&self.to_hex()).finish()
    }
}
5152/// Compute a hash uniquely identifying the OAuth 2.0 session tokens.
53fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55if let Some(refresh_token) = &tokens.refresh_token {
56 hash = hash.chain_update(refresh_token.as_bytes());
57 }
58 SessionHash(hash.finalize().to_vec())
59}
6061#[derive(Clone)]
62pub(super) struct CrossProcessRefreshManager {
63 store: Store,
64 store_lock: CrossProcessStoreLock<LockableCryptoStore>,
65 known_session_hash: Arc<Mutex<Option<SessionHash>>>,
66}
6768impl CrossProcessRefreshManager {
69/// Create a new `CrossProcessRefreshManager`.
70pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72 }
7374/// Wait for up to 60 seconds to get a cross-process store lock, then either
75 /// timeout (as an error) or return a lock guard.
76 ///
77 /// The guard also contains information useful to react upon another
78 /// background refresh having happened in the database already.
79pub async fn spin_lock(
80&self,
81 ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82// Acquire the intra-process mutex, to avoid multiple requests across threads in
83 // the current process.
84trace!("Waiting for intra-process lock...");
85let prev_hash = self.known_session_hash.clone().lock_owned().await;
8687// Acquire the cross-process mutex, to avoid multiple requests across different
88 // processus.
89trace!("Waiting for inter-process lock...");
90let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
9192// Read the previous session hash in the database.
93let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
9495let db_hash = current_db_session_bytes.map(SessionHash);
9697let hash_mismatch = match (&db_hash, &*prev_hash) {
98 (None, _) => false,
99 (Some(_), None) => true,
100 (Some(db), Some(known)) => db != known,
101 };
102103trace!(
104"Hash mismatch? {:?} (prev. known={:?}, db={:?})",
105 hash_mismatch,
106*prev_hash,
107 db_hash
108 );
109110let guard = CrossProcessRefreshLockGuard {
111 hash_guard: prev_hash,
112 _store_guard: store_guard,
113 hash_mismatch,
114 db_hash,
115 store: self.store.clone(),
116 };
117118Ok(guard)
119 }
120121pub async fn restore_session(&self, tokens: &SessionTokens) {
122let prev_tokens_hash = compute_session_hash(tokens);
123*self.known_session_hash.lock().await = Some(prev_tokens_hash);
124 }
125126pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
127self.store
128 .remove_custom_value(OIDC_SESSION_HASH_KEY)
129 .await
130.map_err(CrossProcessRefreshLockError::StoreError)?;
131*self.known_session_hash.lock().await = None;
132Ok(())
133 }
134}
135136pub(super) struct CrossProcessRefreshLockGuard {
137/// The hash for the latest session, either the one we knew, or the latest
138 /// one read from the database, if it was more up to date.
139hash_guard: OwnedMutexGuard<Option<SessionHash>>,
140141/// Cross-process lock being hold.
142_store_guard: CrossProcessStoreLockGuard,
143144/// Reference to the underlying store, for storing the hash of the latest
145 /// known session (as a custom value).
146store: Store,
147148/// Do the in-memory hash and database hash mismatch?
149 ///
150 /// If so, this indicates that another process may have refreshed the token
151 /// in the background.
152 ///
153 /// We don't consider it a mismatch if there was no previous value in the
154 /// database. We do consider it a mismatch if there was no in-memory
155 /// value known, but one was known in the database.
156pub hash_mismatch: bool,
157158/// Session hash previously stored in the DB.
159 ///
160 /// Used for debugging and testing purposes.
161db_hash: Option<SessionHash>,
162}
163164impl CrossProcessRefreshLockGuard {
165/// Updates the `SessionTokens` hash in-memory only.
166fn save_in_memory(&mut self, hash: SessionHash) {
167*self.hash_guard = Some(hash);
168 }
169170/// Updates the `SessionTokens` hash in the database only.
171async fn save_in_database(
172&self,
173 hash: &SessionHash,
174 ) -> Result<(), CrossProcessRefreshLockError> {
175self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
176Ok(())
177 }
178179/// Updates the `SessionTokens` hash in both memory and database.
180 ///
181 /// Must be called after a successful refresh.
182pub async fn save_in_memory_and_db(
183&mut self,
184 tokens: &SessionTokens,
185 ) -> Result<(), CrossProcessRefreshLockError> {
186let hash = compute_session_hash(tokens);
187self.save_in_database(&hash).await?;
188self.save_in_memory(hash);
189Ok(())
190 }
191192/// Handle a mismatch by making sure values in the database and memory match
193 /// tokens we trust.
194pub async fn handle_mismatch(
195&mut self,
196 trusted_tokens: &SessionTokens,
197 ) -> Result<(), CrossProcessRefreshLockError> {
198let new_hash = compute_session_hash(trusted_tokens);
199trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
200201if let Some(db_hash) = &self.db_hash {
202if new_hash != *db_hash {
203// That should never happen, unless we got into an impossible situation!
204 // In this case, we assume the value returned by the callback is always
205 // correct, so override that in the database too.
206tracing::error!("error: DB and trusted disagree. Overriding in DB.");
207self.save_in_database(&new_hash).await?;
208 }
209 }
210211self.save_in_memory(new_hash);
212Ok(())
213 }
214}
215216/// An error that happened when interacting with the cross-process store lock
217/// during a token refresh.
218#[derive(Debug, Error)]
219pub enum CrossProcessRefreshLockError {
220/// Underlying error caused by the store.
221#[error(transparent)]
222StoreError(#[from] CryptoStoreError),
223224/// The locking itself failed.
225#[error(transparent)]
226LockError(#[from] LockStoreError),
227228/// The previous hash isn't valid.
229#[error("the previous stored hash isn't a valid integer")]
230InvalidPreviousHash,
231232/// The lock hasn't been set up.
233#[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
234MissingLock,
235236/// Cross-process lock was set, but without session callbacks.
237#[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
238MissingReloadSession,
239240/// The store has been created twice.
241#[error(
242"the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
243)]
244DuplicatedLock,
245}
#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_arch = "wasm32")))]
mod tests {
    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::compute_session_hash;
    use crate::{
        authentication::oauth::cross_process::SessionHash,
        test_utils::{
            client::{
                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
                oauth::mock_session, MockClientBuilder,
            },
            mocks::MatrixMockServer,
        },
        Error,
    };

    // Checks that after restoring a session, the manager knows the session's hash
    // in memory and the same hash is found in the database, with no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // Create a client that will use sqlite databases.

        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new("https://example.org".to_owned())
            .sqlite_store(&tmp_dir)
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(
                mock_session(tokens.clone(), "https://oauth.example.com/issuer"),
                RoomLoadSettings::default(),
            )
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        {
            // Scoped so the in-process mutex is released before `spin_lock` below
            // tries to take it again.
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }

    // Checks that loading the session after login leaves the in-memory and
    // database hashes in sync with the session tokens.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .sqlite_store(&tmp_dir)
            .registered_with_oauth(server.server().uri())
            .build()
            .await;
        let oauth = client.oauth();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Simulate that finalize_authorization / restore_session was done before.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        // Now, finishing the login will get the user ID (through the mocked whoami
        // endpoint).
        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        // This tests that refreshing the access token works, and that starting two
        // refreshes around the same time only causes a single token request (the
        // token endpoint mock below expects exactly one call).

        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth();

        let next_tokens = mock_session_tokens_with_refresh();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh(), server.server().uri()),
                RoomLoadSettings::default(),
            )
            .await?;

        // Immediately try to refresh the access token twice in parallel.
        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Three clients share one sqlite store. One of them refreshes the tokens, and
    // the test checks how the other two observe that refresh through the stored
    // session hash.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        let issuer = server.server().uri();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // Exactly one token request is allowed across all clients in this test.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // Create the first client.
        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(
                mock_session(prev_tokens.clone(), issuer.clone()),
                RoomLoadSettings::default(),
            )
            .await?;

        // Create a second client, without restoring it, to test that a token update
        // before restoration doesn't cause new issues.
        let unrestored_client =
            server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Create a third client that will run a refresh while the other two are doing
            // nothing.
            let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(
                    mock_session(prev_tokens.clone(), issuer.clone()),
                    RoomLoadSettings::default(),
                )
                .await?;

            // Run a refresh in the third client; this will invalidate the tokens known
            // by the first client.
            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            // Reading from the cross-process lock for the third client only shows the new
            // tokens.
            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        {
            // Restoring the client that was not restored yet will work Just Fine.
            let oauth = unrestored_oauth;

            unrestored_client.set_session_callbacks(
                Box::new({
                    // This is only called because of extra checks in the code.
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oauth
                .restore_session(
                    mock_session(prev_tokens.clone(), issuer),
                    RoomLoadSettings::default(),
                )
                .await?;

            // And this client is now aware of the latest tokens. It was restored with
            // the previous tokens, yet it ends up with the new hash, presumably because
            // the reload callback registered above supplied the new tokens.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        {
            // The first client still only knows the previous tokens in memory, while
            // the database holds the hash of the new ones. So the next attempt to take
            // the lock from the first client results in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // The token endpoint was already hit once by the third client, and its mock
        // expects a single call. Presumably this refresh resolves the mismatch through
        // the reload callback instead of requesting new tokens.
        oauth.refresh_access_token().await?;

        {
            // The next attempt to take the lock isn't a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Checks that logging out clears the session hash both in memory and in the
    // database.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        let tokens = mock_session_tokens_with_refresh();
        oauth
            .restore_session(
                mock_session(tokens.clone(), server.server().uri()),
                RoomLoadSettings::default(),
            )
            .await?;

        oauth.logout().await.unwrap();

        {
            // The cross process lock has been correctly updated, and all the hashes are
            // empty after a logout.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Pins the hex rendering used by `SessionHash`'s `Debug` output.
    #[test]
    fn test_session_hash_to_hex() {
        let hash = SessionHash(vec![]);
        assert_eq!(hash.to_hex(), "");

        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(hash.to_hex(), "0x133742deadcafe");
    }
}