use std::sync::Arc;
use matrix_sdk_base::crypto::{
store::{LockableCryptoStore, Store},
CryptoStoreError,
};
use matrix_sdk_common::store_locks::{
CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
};
use sha2::{Digest as _, Sha256};
use thiserror::Error;
use tokio::sync::{Mutex, OwnedMutexGuard};
use tracing::trace;
use super::OidcSessionTokens;
/// Key under which the hash of the latest session tokens is kept in the
/// crypto store's custom values.
const OIDC_SESSION_HASH_KEY: &'static str = "oidc_session_hash";
/// Digest of a set of OIDC session tokens, kept as raw bytes.
///
/// Its `Debug` output shows the bytes as hex rather than as a byte list.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Renders the digest as lowercase hexadecimal prefixed with `0x`.
    ///
    /// An empty digest renders as the empty string, with no prefix.
    fn to_hex(&self) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";

        let bytes = &self.0;
        if bytes.is_empty() {
            return String::new();
        }

        let mut out = String::with_capacity(2 + 2 * bytes.len());
        out.push_str("0x");
        // Each byte yields its high nibble, then its low nibble.
        out.extend(
            bytes
                .iter()
                .flat_map(|&byte| [byte >> 4, byte & 0x0f])
                .map(|nibble| char::from(DIGITS[usize::from(nibble)])),
        );
        out
    }
}

impl std::fmt::Debug for SessionHash {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = self.to_hex();
        f.debug_tuple("SessionHash").field(&hex).finish()
    }
}
/// Computes the SHA-256 digest identifying a set of session tokens.
///
/// The digest covers the access token's bytes, followed by the refresh
/// token's bytes when one is present.
///
/// NOTE(review): the two tokens are concatenated without a separator. Changing
/// the input layout would change every digest already persisted in stores, so
/// it is deliberately left as is.
fn compute_session_hash(tokens: &OidcSessionTokens) -> SessionHash {
    let mut hasher = Sha256::new();
    hasher.update(tokens.access_token.as_bytes());
    if let Some(refresh_token) = tokens.refresh_token.as_deref() {
        hasher.update(refresh_token.as_bytes());
    }
    SessionHash(hasher.finalize().to_vec())
}
/// Coordinates OIDC token refreshes between tasks of this process and between
/// processes sharing the same crypto store.
///
/// A hash of the latest session tokens is kept both in memory and in the
/// store. Comparing the two after taking the locks shows whether another
/// process changed the session in the meantime.
#[derive(Clone)]
pub(super) struct CrossProcessRefreshManager {
    /// Store holding the persisted session hash under `OIDC_SESSION_HASH_KEY`.
    store: Store,
    /// Lock providing mutual exclusion across processes.
    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
    /// Hash of the tokens this process last knew about (`None` if unknown).
    /// The mutex doubles as the intra-process lock taken by `spin_lock`.
    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
}

impl CrossProcessRefreshManager {
    /// Builds a manager over `store`, using `lock` for inter-process exclusion.
    ///
    /// No session hash is known in memory until one is recorded.
    pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
        let known_session_hash = Arc::new(Mutex::new(None));
        Self { store, store_lock: lock, known_session_hash }
    }

    /// Takes the intra-process lock, then the cross-process lock. It then
    /// compares the in-memory session hash with the one stored in the database.
    ///
    /// # Errors
    ///
    /// Fails if the cross-process lock cannot be acquired or the stored hash
    /// cannot be read.
    pub async fn spin_lock(
        &self,
    ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
        trace!("Waiting for intra-process lock...");
        let hash_guard = Arc::clone(&self.known_session_hash).lock_owned().await;

        trace!("Waiting for inter-process lock...");
        // NOTE(review): 60000 is presumably the maximum backoff in milliseconds
        // for the store lock's spin loop; confirm against
        // `CrossProcessStoreLock::spin_lock`.
        let store_guard = self.store_lock.spin_lock(Some(60000)).await?;

        let db_hash =
            self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?.map(SessionHash);

        // A mismatch requires a stored hash that differs from the in-memory one.
        // That includes the case where nothing is known in memory yet.
        let hash_mismatch = db_hash.is_some() && db_hash != *hash_guard;

        trace!(
            "Hash mismatch? {:?} (prev. known={:?}, db={:?})",
            hash_mismatch,
            *hash_guard,
            db_hash
        );

        Ok(CrossProcessRefreshLockGuard {
            hash_guard,
            _store_guard: store_guard,
            hash_mismatch,
            db_hash,
            store: self.store.clone(),
        })
    }

    /// Records the hash of `tokens` as the in-memory known session hash.
    ///
    /// Only memory is updated; the store is not written.
    pub async fn restore_session(&self, tokens: &OidcSessionTokens) {
        let restored_hash = compute_session_hash(tokens);
        let mut known = self.known_session_hash.lock().await;
        *known = Some(restored_hash);
    }

    /// Forgets the session hash, first in the store and then in memory.
    ///
    /// # Errors
    ///
    /// Returns `StoreError` if the stored value cannot be removed. In that case
    /// the in-memory hash is left untouched.
    pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
        self.store.remove_custom_value(OIDC_SESSION_HASH_KEY).await?;
        let mut known = self.known_session_hash.lock().await;
        *known = None;
        Ok(())
    }
}
/// Proof that both the intra-process and the cross-process refresh locks are
/// held. It carries the session-hash state observed when they were taken.
///
/// Fields are dropped in declaration order. The intra-process guard is
/// therefore released before the cross-process one, so keep the order as is.
pub(super) struct CrossProcessRefreshLockGuard {
    /// Exclusive access to the in-memory known session hash.
    hash_guard: OwnedMutexGuard<Option<SessionHash>>,
    /// Keeps the cross-process lock held for as long as this guard lives.
    _store_guard: CrossProcessStoreLockGuard,
    /// Store used to persist the session hash.
    store: Store,
    /// Whether, when the locks were taken, the store held a hash that differs
    /// from the in-memory one.
    pub hash_mismatch: bool,
    /// Session hash read from the store when the locks were taken.
    db_hash: Option<SessionHash>,
}

impl CrossProcessRefreshLockGuard {
    /// Replaces the in-memory known session hash with `hash`.
    fn save_in_memory(&mut self, hash: SessionHash) {
        *self.hash_guard = Some(hash);
    }

    /// Writes `hash` to the store under `OIDC_SESSION_HASH_KEY`.
    async fn save_in_database(
        &self,
        hash: &SessionHash,
    ) -> Result<(), CrossProcessRefreshLockError> {
        let raw_bytes = hash.0.clone();
        self.store.set_custom_value(OIDC_SESSION_HASH_KEY, raw_bytes).await?;
        Ok(())
    }

    /// Hashes `tokens`, persists the hash in the store, then records it in
    /// memory.
    ///
    /// # Errors
    ///
    /// If the store write fails, the error is returned and memory is not
    /// updated.
    pub async fn save_in_memory_and_db(
        &mut self,
        tokens: &OidcSessionTokens,
    ) -> Result<(), CrossProcessRefreshLockError> {
        let tokens_hash = compute_session_hash(tokens);
        self.save_in_database(&tokens_hash).await?;
        self.save_in_memory(tokens_hash);
        Ok(())
    }

    /// Resolves a hash mismatch by treating `trusted_tokens` as authoritative.
    ///
    /// If the store held a different hash, it is overwritten with the trusted
    /// one (and an error is logged). The trusted hash is then always recorded
    /// in memory.
    pub async fn handle_mismatch(
        &mut self,
        trusted_tokens: &OidcSessionTokens,
    ) -> Result<(), CrossProcessRefreshLockError> {
        let new_hash = compute_session_hash(trusted_tokens);
        trace!("Trusted OIDC tokens have hash {new_hash:?}; db had {:?}", self.db_hash);

        let db_disagrees = matches!(&self.db_hash, Some(db_hash) if new_hash != *db_hash);
        if db_disagrees {
            tracing::error!("error: DB and trusted disagree. Overriding in DB.");
            self.save_in_database(&new_hash).await?;
        }

        self.save_in_memory(new_hash);
        Ok(())
    }
}
#[derive(Debug, Error)]
pub enum CrossProcessRefreshLockError {
#[error(transparent)]
StoreError(#[from] CryptoStoreError),
#[error(transparent)]
LockError(#[from] LockStoreError),
#[error("the previous stored hash isn't a valid integer")]
InvalidPreviousHash,
#[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
MissingLock,
#[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
MissingReloadSession,
#[error("session tokens returned by the reload_session callback were not for OIDC")]
InvalidSessionTokens,
#[error(
"the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
)]
DuplicatedLock,
}
// Integration tests for the OIDC cross-process refresh lock. Most tests build
// clients over a temporary on-disk SQLite store and inspect the session hash
// state through `spin_lock` guards.
#[cfg(all(test, feature = "e2e-encryption"))]
mod tests {
    use std::sync::Arc;

    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::SessionMeta;
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::compute_session_hash;
    use crate::{
        oidc::{
            backend::mock::{MockImpl, ISSUER_URL},
            cross_process::SessionHash,
            tests,
            tests::mock_registered_client_data,
            Oidc, OidcSessionTokens,
        },
        test_utils::test_client_builder,
        Error,
    };

    /// After `restore_session`, the in-memory hash and the stored hash both
    /// match the restored tokens, and taking the lock reports no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        let tmp_dir = tempfile::tempdir()?;
        let client = test_client_builder(Some("https://example.org".to_owned()))
            .sqlite_store(&tmp_dir, None)
            .build()
            .await
            .unwrap();

        let tokens = OidcSessionTokens {
            access_token: "prev-access-token".to_owned(),
            refresh_token: Some("prev-refresh-token".to_owned()),
            latest_id_token: None,
        };

        client.oidc().enable_cross_process_refresh_lock("test".to_owned()).await?;

        // The reload callback hands back the same tokens; the save callback
        // must never run in this scenario.
        client.set_session_callbacks(
            Box::new({
                let tokens = tokens.clone();
                move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client.oidc().restore_session(tests::mock_session(tokens.clone())).await?;

        assert_eq!(client.oidc().session_tokens().unwrap(), tokens);

        let oidc = client.oidc();
        let xp_manager = oidc.ctx().cross_process_token_refresh_manager.get().unwrap();

        // The in-memory hash matches the restored tokens.
        {
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        // The stored hash matches as well, hence no mismatch.
        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }

    /// `finish_login` queries `/whoami` exactly once to fill the session
    /// metadata. Afterwards the lock state holds the hash of the login tokens
    /// in both the store and memory.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/_matrix/client/r0/account/whoami"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "user_id": "@joe:example.org",
                "device_id": "D3V1C31D",
            })))
            .expect(1)
            .named("`GET /whoami` good token")
            .mount(&server)
            .await;

        let tmp_dir = tempfile::tempdir()?;
        let client =
            test_client_builder(Some(server.uri())).sqlite_store(&tmp_dir, None).build().await?;

        let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) };

        let (client_credentials, client_metadata) = mock_registered_client_data();
        oidc.restore_registered_client(ISSUER_URL.to_owned(), client_metadata, client_credentials);

        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let session_tokens = OidcSessionTokens {
            access_token: "access".to_owned(),
            refresh_token: Some("refresh".to_owned()),
            latest_id_token: None,
        };
        oidc.set_session_tokens(session_tokens.clone());

        oidc.finish_login().await?;

        // Session metadata comes from the mocked `/whoami` response.
        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Two concurrent `refresh_access_token` calls on the same client result in
    /// a single backend refresh. Both the stored and in-memory hashes end up
    /// matching the new tokens.
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        let tmp_dir = tempfile::tempdir()?;
        let client = test_client_builder(Some("https://example.org".to_owned()))
            .sqlite_store(&tmp_dir, None)
            .build()
            .await?;

        let prev_tokens = OidcSessionTokens {
            access_token: "prev-access-token".to_owned(),
            refresh_token: Some("prev-refresh-token".to_owned()),
            latest_id_token: None,
        };

        let next_tokens = OidcSessionTokens {
            access_token: "next-access-token".to_owned(),
            refresh_token: Some("next-refresh-token".to_owned()),
            latest_id_token: None,
        };

        // The mock backend expects the previous refresh token and answers with
        // `next_tokens`.
        let backend = Arc::new(
            MockImpl::new()
                .next_session_tokens(next_tokens.clone())
                .expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()),
        );
        let oidc = Oidc { client: client.clone(), backend: backend.clone() };

        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;

        // Run both refreshes concurrently; each must succeed.
        for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await {
            result?;
        }

        assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);

        {
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Three clients share one on-disk store, simulating separate processes:
    /// `client1`, `unrestored_client` and `client3`.
    ///
    /// `client3` refreshes the tokens. The other two then observe the new hash
    /// in the store. Across the whole test the backend performs exactly one
    /// refresh.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let prev_tokens = OidcSessionTokens {
            access_token: "prev-access-token".to_owned(),
            refresh_token: Some("prev-refresh-token".to_owned()),
            latest_id_token: None,
        };

        let next_tokens = OidcSessionTokens {
            access_token: "next-access-token".to_owned(),
            refresh_token: Some("next-refresh-token".to_owned()),
            latest_id_token: None,
        };

        // A single mock backend is shared by every client, so `num_refreshes`
        // counts refreshes across all of them.
        let backend = Arc::new(
            MockImpl::new()
                .next_session_tokens(next_tokens.clone())
                .expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()),
        );

        // First client: restored with the previous tokens.
        let tmp_dir = tempfile::tempdir()?;
        let client = test_client_builder(Some("https://example.org".to_owned()))
            .sqlite_store(&tmp_dir, None)
            .build()
            .await?;

        let oidc = Oidc { client: client.clone(), backend: backend.clone() };
        oidc.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;

        // Second client over the same store: lock enabled, session not yet restored.
        let unrestored_client = test_client_builder(Some("https://example.org".to_owned()))
            .sqlite_store(&tmp_dir, None)
            .build()
            .await?;
        let unrestored_oidc = Oidc { client: unrestored_client.clone(), backend: backend.clone() };
        unrestored_oidc.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        // Third client: restores the previous tokens and performs the refresh,
        // which writes the new hash to the shared store.
        {
            let client3 = test_client_builder(Some("https://example.org".to_owned()))
                .sqlite_store(&tmp_dir, None)
                .build()
                .await?;

            let oidc3 = Oidc { client: client3.clone(), backend: backend.clone() };
            oidc3.enable_cross_process_refresh_lock("client3".to_owned()).await?;

            oidc3.restore_session(tests::mock_session(prev_tokens.clone())).await?;

            oidc3.refresh_access_token().await?;
            assert_eq!(oidc3.session_tokens(), Some(next_tokens.clone()));

            let xp_manager =
                oidc3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        // Second client restores the *previous* tokens, while its reload
        // callback yields the new ones.
        {
            let oidc = unrestored_oidc;

            unrestored_client.set_session_callbacks(
                Box::new({
                    let tokens = next_tokens.clone();
                    move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;

            // Both the stored and in-memory hashes match the new tokens, so the
            // restore presumably picked them up through the reload callback.
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oidc);
            drop(unrestored_client);
        }

        // First client still remembers the previous hash while the store has the
        // new one, so taking the lock reports a mismatch.
        {
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;

            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        // Give the first client a reload callback yielding the new tokens, then
        // refresh.
        client.set_session_callbacks(
            Box::new({
                let tokens = next_tokens.clone();
                move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        oidc.refresh_access_token().await?;

        // The mismatch is resolved: both hashes now match the new tokens.
        {
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        // Still only the single refresh performed by the third client.
        assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);

        Ok(())
    }

    /// `logout` revokes the access and refresh tokens (in that order) and
    /// returns no end-session builder. Afterwards both the stored and in-memory
    /// hashes are cleared.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let tmp_dir = tempfile::tempdir()?;
        let client = test_client_builder(Some("https://example.org".to_owned()))
            .sqlite_store(&tmp_dir, None)
            .build()
            .await?;

        let tokens = OidcSessionTokens {
            access_token: "prev-access-token".to_owned(),
            refresh_token: Some("prev-refresh-token".to_owned()),
            latest_id_token: None,
        };

        let backend = Arc::new(MockImpl::new());
        let oidc = Oidc { client: client.clone(), backend: backend.clone() };

        oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oidc.restore_session(tests::mock_session(tokens.clone())).await?;

        let end_session_builder = oidc.logout().await?;

        assert!(end_session_builder.is_none());

        {
            let revoked = backend.revoked_tokens.lock().unwrap();
            assert_eq!(revoked.len(), 2);
            assert_eq!(
                *revoked,
                vec![tokens.access_token.clone(), tokens.refresh_token.clone().unwrap(),]
            );
        }

        {
            let xp_manager =
                oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// `to_hex` renders nothing for an empty hash, and lowercase hex with a
    /// `0x` prefix otherwise.
    #[test]
    fn test_session_hash_to_hex() {
        let hash = SessionHash(vec![]);
        assert_eq!(hash.to_hex(), "");

        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(hash.to_hex(), "0x133742deadcafe");
    }
}