matrix_sdk/authentication/oidc/
cross_process.rs1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5 store::{LockableCryptoStore, Store},
6 CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9 CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key under which the hash of the current OIDC session tokens is kept in the
/// crypto store's custom values.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";

/// Hash of a set of session tokens, as produced by `compute_session_hash`.
///
/// Only this digest is compared and stored, not the tokens themselves.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Renders the hash as lowercase hexadecimal prefixed with `0x`.
    ///
    /// An empty hash renders as the empty string, without the prefix.
    fn to_hex(&self) -> String {
        if self.0.is_empty() {
            return String::new();
        }

        let mut hex = String::with_capacity(2 + 2 * self.0.len());
        hex.push_str("0x");
        // Each byte yields its high nibble, then its low nibble.
        hex.extend(self.0.iter().flat_map(|&byte| [byte >> 4, byte & 0x0f]).map(|nibble| {
            char::from_digit(u32::from(nibble), 16).expect("a nibble is always below 16")
        }));
        hex
    }
}

impl std::fmt::Debug for SessionHash {
    /// Shows the hash as `SessionHash("0x…")` rather than as a list of bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = self.to_hex();
        f.debug_tuple("SessionHash").field(&hex).finish()
    }
}

52fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54 let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55 if let Some(refresh_token) = &tokens.refresh_token {
56 hash = hash.chain_update(refresh_token.as_bytes());
57 }
58 SessionHash(hash.finalize().to_vec())
59}
60
61#[derive(Clone)]
62pub(super) struct CrossProcessRefreshManager {
63 store: Store,
64 store_lock: CrossProcessStoreLock<LockableCryptoStore>,
65 known_session_hash: Arc<Mutex<Option<SessionHash>>>,
66}
67
68impl CrossProcessRefreshManager {
69 pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71 Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72 }
73
74 pub async fn spin_lock(
80 &self,
81 ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82 trace!("Waiting for intra-process lock...");
85 let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87 trace!("Waiting for inter-process lock...");
90 let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92 let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95 let db_hash = current_db_session_bytes.map(SessionHash);
96
97 let hash_mismatch = match (&db_hash, &*prev_hash) {
98 (None, _) => false,
99 (Some(_), None) => true,
100 (Some(db), Some(known)) => db != known,
101 };
102
103 trace!(
104 "Hash mismatch? {:?} (prev. known={:?}, db={:?})",
105 hash_mismatch,
106 *prev_hash,
107 db_hash
108 );
109
110 let guard = CrossProcessRefreshLockGuard {
111 hash_guard: prev_hash,
112 _store_guard: store_guard,
113 hash_mismatch,
114 db_hash,
115 store: self.store.clone(),
116 };
117
118 Ok(guard)
119 }
120
121 pub async fn restore_session(&self, tokens: &SessionTokens) {
122 let prev_tokens_hash = compute_session_hash(tokens);
123 *self.known_session_hash.lock().await = Some(prev_tokens_hash);
124 }
125
126 pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
127 self.store
128 .remove_custom_value(OIDC_SESSION_HASH_KEY)
129 .await
130 .map_err(CrossProcessRefreshLockError::StoreError)?;
131 *self.known_session_hash.lock().await = None;
132 Ok(())
133 }
134}
135
136pub(super) struct CrossProcessRefreshLockGuard {
137 hash_guard: OwnedMutexGuard<Option<SessionHash>>,
140
141 _store_guard: CrossProcessStoreLockGuard,
143
144 store: Store,
147
148 pub hash_mismatch: bool,
157
158 db_hash: Option<SessionHash>,
162}
163
164impl CrossProcessRefreshLockGuard {
165 fn save_in_memory(&mut self, hash: SessionHash) {
167 *self.hash_guard = Some(hash);
168 }
169
170 async fn save_in_database(
172 &self,
173 hash: &SessionHash,
174 ) -> Result<(), CrossProcessRefreshLockError> {
175 self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
176 Ok(())
177 }
178
179 pub async fn save_in_memory_and_db(
183 &mut self,
184 tokens: &SessionTokens,
185 ) -> Result<(), CrossProcessRefreshLockError> {
186 let hash = compute_session_hash(tokens);
187 self.save_in_database(&hash).await?;
188 self.save_in_memory(hash);
189 Ok(())
190 }
191
192 pub async fn handle_mismatch(
195 &mut self,
196 trusted_tokens: &SessionTokens,
197 ) -> Result<(), CrossProcessRefreshLockError> {
198 let new_hash = compute_session_hash(trusted_tokens);
199 trace!("Trusted OIDC tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
200
201 if let Some(db_hash) = &self.db_hash {
202 if new_hash != *db_hash {
203 tracing::error!("error: DB and trusted disagree. Overriding in DB.");
207 self.save_in_database(&new_hash).await?;
208 }
209 }
210
211 self.save_in_memory(new_hash);
212 Ok(())
213 }
214}
215
216#[derive(Debug, Error)]
219pub enum CrossProcessRefreshLockError {
220 #[error(transparent)]
222 StoreError(#[from] CryptoStoreError),
223
224 #[error(transparent)]
226 LockError(#[from] LockStoreError),
227
228 #[error("the previous stored hash isn't a valid integer")]
230 InvalidPreviousHash,
231
232 #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
234 MissingLock,
235
236 #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
238 MissingReloadSession,
239
240 #[error(
242 "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
243 )]
244 DuplicatedLock,
245}
246
#[cfg(all(test, feature = "e2e-encryption"))]
mod tests {
    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::SessionMeta;
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::compute_session_hash;
    use crate::{
        authentication::oidc::cross_process::SessionHash,
        test_utils::{
            client::{
                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
                oauth::mock_session, MockClientBuilder,
            },
            mocks::MatrixMockServer,
        },
        Error,
    };

    /// Restoring a session records the tokens' hash in memory and in the
    /// database, so the lock taken right afterwards sees no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        let store_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new("https://example.org".to_owned())
            .sqlite_store(&store_dir)
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oidc().enable_cross_process_refresh_lock("test".to_owned()).await?;

        let reload_tokens = tokens.clone();
        client.set_session_callbacks(
            Box::new(move |_| Ok(reload_tokens.clone())),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let expected_hash = compute_session_hash(&tokens);
        client
            .oidc()
            .restore_session(mock_session(
                tokens.clone(),
                "https://oidc.example.com/issuer".to_owned(),
            ))
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oidc = client.oidc();
        let manager = oidc.ctx().cross_process_token_refresh_manager.get().unwrap();

        // The in-memory hash was set by the restore.
        {
            let known = manager.known_session_hash.lock().await;
            assert_eq!(known.as_ref(), Some(&expected_hash));
        }

        // The database agrees with memory.
        {
            let guard = manager.spin_lock().await.unwrap();
            assert!(!guard.hash_mismatch);
            assert_eq!(guard.db_hash.as_ref(), Some(&expected_hash));
        }

        Ok(())
    }

    /// Finishing a login stores the current tokens' hash both in memory and in
    /// the database.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let store_dir = tempfile::tempdir()?;
        let client = server
            .client_builder()
            .sqlite_store(&store_dir)
            .registered_with_oauth(server.server().uri())
            .build()
            .await;
        let oidc = client.oidc();

        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(tokens.clone());

        oidc.finish_login().await?;

        let meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            let manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = manager.spin_lock().await?;
            let expected_hash = compute_session_hash(&tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&expected_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&expected_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Two concurrent refreshes in the same process hit the token endpoint only
    /// once, and leave memory and database in agreement.
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let store_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&store_dir).unlogged().build().await;
        let oidc = client.oidc();

        let next_tokens = mock_session_tokens_with_refresh();

        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oidc.restore_session(mock_session(
            mock_prev_session_tokens_with_refresh(),
            server.server().uri(),
        ))
        .await?;

        let outcomes = join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await;
        outcomes.into_iter().collect::<Result<Vec<_>, _>>()?;

        {
            let manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = manager.spin_lock().await?;
            let expected_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&expected_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&expected_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Three clients share one store.
    ///
    /// A refresh by one client is detected as a mismatch by a client that
    /// still knows the old tokens. A client restored afterwards picks up the
    /// new tokens through the reload callback.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        let issuer = server.server().uri();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();
        let prev_hash = compute_session_hash(&prev_tokens);
        let next_hash = compute_session_hash(&next_tokens);

        let store_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&store_dir).unlogged().build().await;

        let oidc = client.oidc();
        oidc.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oidc.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?;

        // A second client on the same store, whose session is restored only
        // after the refresh below.
        let unrestored_client =
            server.client_builder().sqlite_store(&store_dir).unlogged().build().await;
        let unrestored_oidc = unrestored_client.oidc();
        unrestored_oidc.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        // A third client refreshes the tokens, updating the shared hash.
        {
            let client3 =
                server.client_builder().sqlite_store(&store_dir).unlogged().build().await;

            let oidc3 = client3.oidc();
            oidc3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oidc3.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?;

            oidc3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            let manager =
                oidc3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = manager.spin_lock().await?;
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);
        }

        // The late client restores with the old tokens, but ends up with the
        // new ones supplied by its reload callback.
        {
            let reload_tokens = next_tokens.clone();
            unrestored_client.set_session_callbacks(
                Box::new(move |_| Ok(reload_tokens.clone())),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            unrestored_oidc.restore_session(mock_session(prev_tokens.clone(), issuer)).await?;

            let manager = unrestored_oidc
                .ctx()
                .cross_process_token_refresh_manager
                .get()
                .context("must have lock")?;
            let guard = manager.spin_lock().await?;
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(unrestored_oidc);
            drop(unrestored_client);
        }

        // The first client still only knows the old tokens: mismatch.
        {
            let manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = manager.spin_lock().await?;
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&prev_hash));
            assert!(guard.hash_mismatch);
        }

        let reload_tokens = next_tokens.clone();
        client.set_session_callbacks(
            Box::new(move |_| Ok(reload_tokens.clone())),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // Refreshing resolves the mismatch without another token request.
        oidc.refresh_access_token().await?;

        {
            let manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = manager.spin_lock().await?;
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Logging out clears the session hash in memory and in the database.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let store_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&store_dir).unlogged().build().await;
        let oidc = client.oidc().insecure_rewrite_https_to_http();

        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let tokens = mock_session_tokens_with_refresh();
        oidc.restore_session(mock_session(tokens, server.server().uri())).await?;

        oidc.logout().await?;

        {
            let manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    #[test]
    fn test_session_hash_to_hex() {
        assert_eq!(SessionHash(vec![]).to_hex(), "");
        assert_eq!(
            SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]).to_hex(),
            "0x133742deadcafe"
        );
    }
}