matrix_sdk/authentication/oauth/
cross_process.rs1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5 CryptoStoreError,
6 store::{LockableCryptoStore, Store},
7};
8use matrix_sdk_common::cross_process_lock::{
9 CrossProcessLock, CrossProcessLockError, CrossProcessLockGuard,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
18const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
/// Hash of a set of session tokens.
///
/// Produced by `compute_session_hash`, or rebuilt from the raw bytes read
/// back from the crypto store. The wrapped bytes are exactly what gets
/// persisted.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Renders the hash as lowercase hexadecimal prefixed with `0x`.
    ///
    /// An empty hash renders as the empty string, without any prefix.
    fn to_hex(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

        if self.0.is_empty() {
            return String::new();
        }

        // Two hex digits per byte, plus the `0x` prefix.
        let mut out = String::with_capacity(2 + 2 * self.0.len());
        out.push_str("0x");
        for &byte in &self.0 {
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}

/// Formats as `SessionHash("0x…")`, showing the hex rendering instead of the
/// raw list of bytes.
impl std::fmt::Debug for SessionHash {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SessionHash").field(&self.to_hex()).finish()
    }
}

52fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54 let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55 if let Some(refresh_token) = &tokens.refresh_token {
56 hash = hash.chain_update(refresh_token.as_bytes());
57 }
58 SessionHash(hash.finalize().to_vec())
59}
60
61#[derive(Clone)]
62pub(super) struct CrossProcessRefreshManager {
63 store: Store,
64 store_lock: CrossProcessLock<LockableCryptoStore>,
65 known_session_hash: Arc<Mutex<Option<SessionHash>>>,
66}
67
68impl CrossProcessRefreshManager {
69 pub fn new(store: Store, lock: CrossProcessLock<LockableCryptoStore>) -> Self {
71 Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72 }
73
74 pub async fn spin_lock(
80 &self,
81 ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82 trace!("Waiting for intra-process lock...");
85 let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87 trace!("Waiting for inter-process lock...");
90 let store_guard = self
91 .store_lock
92 .spin_lock(Some(60000))
93 .await
94 .map_err(|err| {
95 CrossProcessRefreshLockError::LockError(CrossProcessLockError::TryLock(Box::new(
96 err,
97 )))
98 })?
99 .map_err(|err| CrossProcessRefreshLockError::LockError(err.into()))?;
100
101 let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
103
104 let db_hash = current_db_session_bytes.map(SessionHash);
105
106 let hash_mismatch = match (&db_hash, &*prev_hash) {
107 (None, _) => false,
108 (Some(_), None) => true,
109 (Some(db), Some(known)) => db != known,
110 };
111
112 trace!(hash_mismatch, ?prev_hash, ?db_hash);
113
114 let guard = CrossProcessRefreshLockGuard {
115 hash_guard: prev_hash,
116 _store_guard: store_guard.into_guard(),
117 hash_mismatch,
118 db_hash,
119 store: self.store.clone(),
120 };
121
122 Ok(guard)
123 }
124
125 pub async fn restore_session(&self, tokens: &SessionTokens) {
126 let prev_tokens_hash = compute_session_hash(tokens);
127 *self.known_session_hash.lock().await = Some(prev_tokens_hash);
128 }
129
130 pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
131 self.store
132 .remove_custom_value(OIDC_SESSION_HASH_KEY)
133 .await
134 .map_err(CrossProcessRefreshLockError::StoreError)?;
135 *self.known_session_hash.lock().await = None;
136 Ok(())
137 }
138}
139
140pub(super) struct CrossProcessRefreshLockGuard {
141 hash_guard: OwnedMutexGuard<Option<SessionHash>>,
144
145 _store_guard: CrossProcessLockGuard,
147
148 store: Store,
151
152 pub hash_mismatch: bool,
161
162 db_hash: Option<SessionHash>,
166}
167
168impl CrossProcessRefreshLockGuard {
169 fn save_in_memory(&mut self, hash: SessionHash) {
171 *self.hash_guard = Some(hash);
172 }
173
174 async fn save_in_database(
176 &self,
177 hash: &SessionHash,
178 ) -> Result<(), CrossProcessRefreshLockError> {
179 self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
180 Ok(())
181 }
182
183 pub async fn save_in_memory_and_db(
187 &mut self,
188 tokens: &SessionTokens,
189 ) -> Result<(), CrossProcessRefreshLockError> {
190 let hash = compute_session_hash(tokens);
191 self.save_in_database(&hash).await?;
192 self.save_in_memory(hash);
193 Ok(())
194 }
195
196 pub async fn handle_mismatch(
199 &mut self,
200 trusted_tokens: &SessionTokens,
201 ) -> Result<(), CrossProcessRefreshLockError> {
202 let new_hash = compute_session_hash(trusted_tokens);
203 trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
204
205 if let Some(db_hash) = &self.db_hash
206 && new_hash != *db_hash
207 {
208 tracing::error!("error: DB and trusted disagree. Overriding in DB.");
212 self.save_in_database(&new_hash).await?;
213 }
214
215 self.save_in_memory(new_hash);
216 Ok(())
217 }
218}
219
220#[derive(Debug, Error)]
223pub enum CrossProcessRefreshLockError {
224 #[error(transparent)]
226 StoreError(#[from] CryptoStoreError),
227
228 #[error(transparent)]
230 LockError(#[from] CrossProcessLockError),
231
232 #[error("the previous stored hash isn't a valid integer")]
234 InvalidPreviousHash,
235
236 #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
238 MissingLock,
239
240 #[error(
242 "reload session callback must be set with Client::set_session_callbacks() \
243 for the cross-process lock to work"
244 )]
245 MissingReloadSession,
246
247 #[error(
249 "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
250 )]
251 DuplicatedLock,
252}
253
#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_family = "wasm")))]
mod tests {
    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings};
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::compute_session_hash;
    use crate::{
        Error,
        authentication::oauth::cross_process::SessionHash,
        test_utils::{
            client::{
                MockClientBuilder, mock_prev_session_tokens_with_refresh,
                mock_session_tokens_with_refresh, oauth::mock_session,
            },
            mocks::MatrixMockServer,
        },
    };

    /// After restoring a session, both the in-memory and the stored hash are
    /// the hash of the restored tokens, and locking reports no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new(None)
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        client.set_session_callbacks(
            Box::new({
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        // Scoped so the mutex guard is released before `spin_lock` takes the
        // same mutex below.
        {
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }

    /// Loading a session after login stores its hash in memory and in the
    /// store.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .registered_with_oauth()
            .build()
            .await;
        let oauth = client.oauth();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Two concurrent refreshes in the same client hit the token endpoint
    /// only once (`expect(1)`) and leave consistent hashes.
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth();

        let next_tokens = mock_session_tokens_with_refresh();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Several clients share one store. One of them refreshes; the others
    /// detect (or absorb) the change and end up with consistent hashes,
    /// while the token endpoint is hit only once overall (`expect(1)`).
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // Every client below uses this same SQLite store directory.
        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // Second client; its session is only restored later on.
        let unrestored_client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Third client: restores the old tokens and refreshes them, which
            // persists the hash of the new tokens in the shared store.
            let client3 = server
                .client_builder()
                .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
                .unlogged()
                .build()
                .await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        {
            // The second client restores the old tokens. Its reload callback
            // returns the new tokens, and afterwards both its in-memory hash
            // and the store hold the hash of the new tokens.
            let oauth = unrestored_oauth;

            unrestored_client.set_session_callbacks(
                Box::new({
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        {
            // The first client still knows the old hash while the store holds
            // the new one: a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // The token endpoint was already hit once by the third client, and
        // only one request is expected overall.
        oauth.refresh_access_token().await?;

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Logging out clears the session hash, both in the store and in memory.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .on_builder(|builder| builder.sqlite_store(&tmp_dir, None))
            .unlogged()
            .build()
            .await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let tokens = mock_session_tokens_with_refresh();
        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;

        oauth.logout().await.unwrap();

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[test]
    fn test_session_hash_to_hex() {
        let hash = SessionHash(vec![]);
        assert_eq!(hash.to_hex(), "");

        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(hash.to_hex(), "0x133742deadcafe");
    }
}