matrix_sdk/authentication/oauth/
cross_process.rs1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5 store::{LockableCryptoStore, Store},
6 CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9 CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key under which the hash of the current session tokens is persisted in the
/// crypto store's custom values.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";

/// Digest of a set of session tokens, as produced by `compute_session_hash`.
///
/// Only this digest (never the tokens themselves) is written to the store and
/// compared between processes.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Renders the digest as lowercase hexadecimal prefixed with `0x`.
    ///
    /// An empty digest renders as the empty string, without any prefix.
    fn to_hex(&self) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";

        let bytes = &self.0;
        if bytes.is_empty() {
            return String::new();
        }

        let mut rendered = String::with_capacity(2 + 2 * bytes.len());
        rendered.push_str("0x");
        for &byte in bytes {
            // High nibble first, then low nibble.
            rendered.push(char::from(DIGITS[usize::from(byte >> 4)]));
            rendered.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        rendered
    }
}

/// Shows the digest in hex form, e.g. `SessionHash("0xab")`.
impl std::fmt::Debug for SessionHash {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = self.to_hex();
        f.debug_tuple("SessionHash").field(&hex).finish()
    }
}
51
52fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54 let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55 if let Some(refresh_token) = &tokens.refresh_token {
56 hash = hash.chain_update(refresh_token.as_bytes());
57 }
58 SessionHash(hash.finalize().to_vec())
59}
60
61#[derive(Clone)]
62pub(super) struct CrossProcessRefreshManager {
63 store: Store,
64 store_lock: CrossProcessStoreLock<LockableCryptoStore>,
65 known_session_hash: Arc<Mutex<Option<SessionHash>>>,
66}
67
68impl CrossProcessRefreshManager {
69 pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71 Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72 }
73
74 pub async fn spin_lock(
80 &self,
81 ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82 trace!("Waiting for intra-process lock...");
85 let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87 trace!("Waiting for inter-process lock...");
90 let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92 let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95 let db_hash = current_db_session_bytes.map(SessionHash);
96
97 let hash_mismatch = match (&db_hash, &*prev_hash) {
98 (None, _) => false,
99 (Some(_), None) => true,
100 (Some(db), Some(known)) => db != known,
101 };
102
103 trace!(hash_mismatch, ?prev_hash, ?db_hash);
104
105 let guard = CrossProcessRefreshLockGuard {
106 hash_guard: prev_hash,
107 _store_guard: store_guard,
108 hash_mismatch,
109 db_hash,
110 store: self.store.clone(),
111 };
112
113 Ok(guard)
114 }
115
116 pub async fn restore_session(&self, tokens: &SessionTokens) {
117 let prev_tokens_hash = compute_session_hash(tokens);
118 *self.known_session_hash.lock().await = Some(prev_tokens_hash);
119 }
120
121 pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
122 self.store
123 .remove_custom_value(OIDC_SESSION_HASH_KEY)
124 .await
125 .map_err(CrossProcessRefreshLockError::StoreError)?;
126 *self.known_session_hash.lock().await = None;
127 Ok(())
128 }
129}
130
131pub(super) struct CrossProcessRefreshLockGuard {
132 hash_guard: OwnedMutexGuard<Option<SessionHash>>,
135
136 _store_guard: CrossProcessStoreLockGuard,
138
139 store: Store,
142
143 pub hash_mismatch: bool,
152
153 db_hash: Option<SessionHash>,
157}
158
159impl CrossProcessRefreshLockGuard {
160 fn save_in_memory(&mut self, hash: SessionHash) {
162 *self.hash_guard = Some(hash);
163 }
164
165 async fn save_in_database(
167 &self,
168 hash: &SessionHash,
169 ) -> Result<(), CrossProcessRefreshLockError> {
170 self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
171 Ok(())
172 }
173
174 pub async fn save_in_memory_and_db(
178 &mut self,
179 tokens: &SessionTokens,
180 ) -> Result<(), CrossProcessRefreshLockError> {
181 let hash = compute_session_hash(tokens);
182 self.save_in_database(&hash).await?;
183 self.save_in_memory(hash);
184 Ok(())
185 }
186
187 pub async fn handle_mismatch(
190 &mut self,
191 trusted_tokens: &SessionTokens,
192 ) -> Result<(), CrossProcessRefreshLockError> {
193 let new_hash = compute_session_hash(trusted_tokens);
194 trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
195
196 if let Some(db_hash) = &self.db_hash {
197 if new_hash != *db_hash {
198 tracing::error!("error: DB and trusted disagree. Overriding in DB.");
202 self.save_in_database(&new_hash).await?;
203 }
204 }
205
206 self.save_in_memory(new_hash);
207 Ok(())
208 }
209}
210
211#[derive(Debug, Error)]
214pub enum CrossProcessRefreshLockError {
215 #[error(transparent)]
217 StoreError(#[from] CryptoStoreError),
218
219 #[error(transparent)]
221 LockError(#[from] LockStoreError),
222
223 #[error("the previous stored hash isn't a valid integer")]
225 InvalidPreviousHash,
226
227 #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
229 MissingLock,
230
231 #[error(
233 "reload session callback must be set with Client::set_session_callbacks() \
234 for the cross-process lock to work"
235 )]
236 MissingReloadSession,
237
238 #[error(
240 "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
241 )]
242 DuplicatedLock,
243}
244
#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_family = "wasm")))]
mod tests {
    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::{compute_session_hash, CrossProcessRefreshManager, SessionHash};
    use crate::{
        test_utils::{
            client::{
                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
                oauth::mock_session, MockClientBuilder,
            },
            mocks::MatrixMockServer,
        },
        Error,
    };

    /// Takes `manager`'s refresh lock and checks what it observes: the hash
    /// stored in the database, the hash known in memory, and the mismatch
    /// flag. The guard is released before returning.
    async fn assert_lock_state(
        manager: &CrossProcessRefreshManager,
        expected_db_hash: Option<&SessionHash>,
        expected_known_hash: Option<&SessionHash>,
        expected_mismatch: bool,
    ) -> anyhow::Result<()> {
        let guard = manager.spin_lock().await?;
        assert_eq!(guard.db_hash.as_ref(), expected_db_hash);
        assert_eq!(guard.hash_guard.as_ref(), expected_known_hash);
        assert_eq!(guard.hash_mismatch, expected_mismatch);
        Ok(())
    }

    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // The client persists its state in an on-disk sqlite store.
        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new(None)
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        let reloaded_tokens = tokens.clone();
        client.set_session_callbacks(
            Box::new(move |_| Ok(reloaded_tokens.clone())),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let expected_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        // Restoring records the session hash in memory...
        let known_hash = xp_manager.known_session_hash.lock().await.clone();
        assert_eq!(known_hash.as_ref(), Some(&expected_hash));

        // ...and in the database, so taking the lock reports no mismatch.
        assert_lock_state(xp_manager, Some(&expected_hash), Some(&expected_hash), false)
            .await
            .unwrap();

        Ok(())
    }

    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .registered_with_oauth()
            .build()
            .await;
        let oauth = client.oauth();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Install tokens directly in the auth context before loading the
        // session.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        let expected_meta = SessionMeta {
            user_id: owned_user_id!("@joe:example.org"),
            device_id: owned_device_id!("D3V1C31D"),
        };
        assert_eq!(*session_meta, expected_meta);

        let xp_manager =
            oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
        let expected_hash = compute_session_hash(&session_tokens);
        assert_lock_state(xp_manager, Some(&expected_hash), Some(&expected_hash), false).await?;

        Ok(())
    }

    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // Two concurrent refreshes must hit the token endpoint only once.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth();

        let expected_hash = compute_session_hash(&mock_session_tokens_with_refresh());

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        join_all([oauth.refresh_access_token(), oauth.refresh_access_token()])
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;

        let xp_manager =
            oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
        assert_lock_state(xp_manager, Some(&expected_hash), Some(&expected_hash), false).await?;

        Ok(())
    }

    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // Only one of the clients below should actually refresh the tokens.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();
        let prev_hash = compute_session_hash(&prev_tokens);
        let next_hash = compute_session_hash(&next_tokens);

        // Every client below shares the same sqlite store directory.
        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // This client enables the lock now but restores its session later.
        let unrestored_client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // A third client performs the one and only refresh.
            let client3 = server
                .client_builder()
                .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
                .unlogged()
                .build()
                .await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            assert_lock_state(xp_manager, Some(&next_hash), Some(&next_hash), false).await?;
        }

        {
            // Restoring the stale session reloads the trusted tokens through
            // the callback, which lines memory up with the database.
            let reloaded_tokens = next_tokens.clone();
            unrestored_client.set_session_callbacks(
                Box::new(move |_| Ok(reloaded_tokens.clone())),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            unrestored_oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            let xp_manager = unrestored_oauth
                .ctx()
                .cross_process_token_refresh_manager
                .get()
                .context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            // The clients go away while the guard is still alive.
            drop(unrestored_oauth);
            drop(unrestored_client);
        }

        let xp_manager =
            oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;

        // The first client still knows the previous tokens, whereas the
        // database already holds the new ones.
        assert_lock_state(xp_manager, Some(&next_hash), Some(&prev_hash), true).await?;

        let reloaded_tokens = next_tokens.clone();
        client.set_session_callbacks(
            Box::new(move |_| Ok(reloaded_tokens.clone())),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        oauth.refresh_access_token().await?;

        // After the refresh the mismatch has been resolved.
        assert_lock_state(xp_manager, Some(&next_hash), Some(&next_hash), false).await?;

        Ok(())
    }

    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oauth
            .restore_session(
                mock_session(mock_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        oauth.logout().await.unwrap();

        // Logging out clears the hash both in the database and in memory.
        let xp_manager =
            oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
        assert_lock_state(xp_manager, None, None, false).await?;

        Ok(())
    }

    #[test]
    fn test_session_hash_to_hex() {
        assert_eq!(SessionHash(vec![]).to_hex(), "");

        let digest = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(digest.to_hex(), "0x133742deadcafe");
    }
}