use std::collections::BTreeMap;
use matrix_sdk_base::{StateStore, StoreError};
use matrix_sdk_common::timer;
use ruma::{OwnedRoomId, UserId};
use tracing::{trace, warn};
use super::{
FrozenSlidingSync, FrozenSlidingSyncList, SlidingSync, SlidingSyncList,
SlidingSyncPositionMarkers, SlidingSyncRoom,
};
#[cfg(feature = "e2e-encryption")]
use crate::sliding_sync::FrozenSlidingSyncPos;
use crate::{sliding_sync::SlidingSyncListCachePolicy, Client, Result};
/// Build the key prefix shared by every cache entry of the sliding sync
/// instance `id` belonging to `user_id`.
///
/// The result has the shape `sliding_sync_store::{id}::{user_id}`.
pub(super) fn format_storage_key_prefix(id: &str, user_id: &UserId) -> String {
    format!("sliding_sync_store::{id}::{user_id}")
}
/// Key under which the sliding sync instance itself is cached:
/// `{storage_key}::instance`.
fn format_storage_key_for_sliding_sync(storage_key: &str) -> String {
    [storage_key, "::instance"].concat()
}
/// Key under which the list `list_name` of the instance at `storage_key` is
/// cached: `{storage_key}::list::{list_name}`.
fn format_storage_key_for_sliding_sync_list(storage_key: &str, list_name: &str) -> String {
    [storage_key, "::list::", list_name].concat()
}
/// Remove the cached entry of list `list_name` for the instance at `storage_key`.
///
/// Best-effort: any error returned by the store is discarded.
async fn invalidate_cached_list(
    storage: &dyn StateStore<Error = StoreError>,
    storage_key: &str,
    list_name: &str,
) {
    let _ = storage
        .remove_custom_value(
            format_storage_key_for_sliding_sync_list(storage_key, list_name).as_bytes(),
        )
        .await;
}
/// Best-effort wipe of everything cached for the sliding sync instance stored
/// under `storage_key`.
///
/// It removes each list entry named in `lists`, then removes the instance entry
/// from the state store. With the `e2e-encryption` feature and an `OlmMachine`
/// present, it also overwrites the instance's crypto-store entry with empty bytes.
/// All store errors are ignored.
async fn clean_storage(
    client: &Client,
    storage_key: &str,
    lists: &BTreeMap<String, SlidingSyncList>,
) {
    let state_store = client.store();
    let instance_key = format_storage_key_for_sliding_sync(storage_key);

    for name in lists.keys() {
        invalidate_cached_list(state_store, storage_key, name).await;
    }

    let _ = state_store.remove_custom_value(instance_key.as_bytes()).await;

    // The crypto-store entry is overwritten with an empty value rather than removed.
    // An empty blob fails `FrozenSlidingSyncPos` deserialization on restore, so it
    // is ignored there.
    #[cfg(feature = "e2e-encryption")]
    if let Some(machine) = &*client.olm_machine().await {
        let _ = machine.store().set_custom_value(&instance_key, Vec::new()).await;
    }
}
pub(super) async fn store_sliding_sync_state(
sliding_sync: &SlidingSync,
position: &SlidingSyncPositionMarkers,
) -> Result<()> {
let storage_key = &sliding_sync.inner.storage_key;
let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
trace!(storage_key, "Saving a `SlidingSync` to the state store");
let storage = sliding_sync.inner.client.store();
storage
.set_custom_value(
instance_storage_key.as_bytes(),
serde_json::to_vec(&FrozenSlidingSync::new(&*sliding_sync.inner.rooms.read().await))?,
)
.await?;
#[cfg(feature = "e2e-encryption")]
{
if let Some(olm_machine) = &*sliding_sync.inner.client.olm_machine().await {
let pos_blob = serde_json::to_vec(&FrozenSlidingSyncPos { pos: position.pos.clone() })?;
olm_machine.store().set_custom_value(&instance_storage_key, pos_blob).await?;
}
}
let frozen_lists = {
sliding_sync
.inner
.lists
.read()
.await
.iter()
.filter(|(_, list)| matches!(list.cache_policy(), SlidingSyncListCachePolicy::Enabled))
.map(|(list_name, list)| {
Ok((
format_storage_key_for_sliding_sync_list(storage_key, list_name),
serde_json::to_vec(&FrozenSlidingSyncList::freeze(list))?,
))
})
.collect::<Result<Vec<_>, crate::Error>>()?
};
for (storage_key_for_list, frozen_list) in frozen_lists {
trace!(storage_key_for_list, "Saving a `SlidingSyncList`");
storage.set_custom_value(storage_key_for_list.as_bytes(), frozen_list).await?;
}
Ok(())
}
/// Load the cached list `list_name` of the instance at `storage_key`.
///
/// Returns `Ok(None)` in two cases: nothing is cached, or the cached bytes no
/// longer deserialize. In the second case the stale entry is also removed
/// (best-effort).
///
/// # Errors
///
/// Propagates errors from reading the state store.
pub(super) async fn restore_sliding_sync_list(
    storage: &dyn StateStore<Error = StoreError>,
    storage_key: &str,
    list_name: &str,
) -> Result<Option<FrozenSlidingSyncList>> {
    let _timer = timer!(format!("loading list from DB {list_name}"));

    let list_key = format_storage_key_for_sliding_sync_list(storage_key, list_name);
    let raw_list = match storage.get_custom_value(list_key.as_bytes()).await? {
        Some(bytes) => bytes,
        None => {
            trace!(list_name, "failed to find the list in the cache");
            return Ok(None);
        }
    };

    if let Ok(frozen_list) = serde_json::from_slice::<FrozenSlidingSyncList>(&raw_list) {
        trace!(list_name, "successfully read the list from cache");
        return Ok(Some(frozen_list));
    }

    // Unreadable entry: treat it as obsolete and drop it so it is not retried.
    warn!(
        list_name,
        "failed to deserialize the list from the cache, it is obsolete; removing the cache entry!"
    );
    invalidate_cached_list(storage, storage_key, list_name).await;

    Ok(None)
}
/// Values recovered from the cache by `restore_sliding_sync_state`.
#[derive(Default)]
pub(super) struct RestoredFields {
    /// To-device `since` token. It comes from the crypto store's
    /// `next_batch_token` when available, otherwise from the cached
    /// `FrozenSlidingSync`.
    pub to_device_token: Option<String>,
    /// The sliding sync `pos` marker read back from the crypto store, if any.
    pub pos: Option<String>,
    /// Rooms rebuilt from the cached `FrozenSlidingSync`, keyed by room ID.
    pub rooms: BTreeMap<OwnedRoomId, SlidingSyncRoom>,
}
/// Restore the cached state of the sliding sync instance stored under `storage_key`.
///
/// - **To-device token:** with the `e2e-encryption` feature and an `OlmMachine`,
///   the crypto store's `next_batch_token` is used first. If that is absent, the
///   token stored in the cached `FrozenSlidingSync` (if any) is used.
/// - **Readable cached instance:** its rooms are rebuilt. With `e2e-encryption`,
///   the `pos` marker is also read from the crypto store, best-effort.
/// - **Unreadable cached instance:** everything cached for the instance,
///   including each list in `lists`, is wiped and `Ok(None)` is returned.
///
/// Otherwise it returns `Ok(Some(..))`. The fields may all be empty when
/// nothing was cached.
///
/// # Errors
///
/// Propagates errors from reading the crypto store's `next_batch_token` and
/// from reading the instance entry in the state store.
pub(super) async fn restore_sliding_sync_state(
    client: &Client,
    storage_key: &str,
    lists: &BTreeMap<String, SlidingSyncList>,
) -> Result<Option<RestoredFields>> {
    let _timer = timer!(format!("loading sliding sync {storage_key} state from DB"));

    let mut restored_fields = RestoredFields::default();

    // The crypto store's `next_batch_token` takes precedence as the to-device token.
    #[cfg(feature = "e2e-encryption")]
    if let Some(olm_machine) = &*client.olm_machine().await {
        match olm_machine.store().next_batch_token().await? {
            Some(token) => {
                restored_fields.to_device_token = Some(token);
            }
            // NOTE(review): despite its wording, this message is emitted when the
            // crypto store has no `next_batch_token`.
            None => trace!("No `SlidingSync` in the crypto-store cache"),
        }
    }

    let storage = client.store();
    let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);

    match storage
        .get_custom_value(instance_storage_key.as_bytes())
        .await?
        .map(|custom_value| serde_json::from_slice::<FrozenSlidingSync>(&custom_value))
    {
        Some(Ok(FrozenSlidingSync { to_device_since, rooms: frozen_rooms })) => {
            trace!("Successfully read the `SlidingSync` from the cache");

            // Fall back to the cached token only when the crypto store had none.
            if restored_fields.to_device_token.is_none() {
                restored_fields.to_device_token = to_device_since;
            }

            #[cfg(feature = "e2e-encryption")]
            {
                // Best-effort: a store error, a missing blob or an undecodable blob
                // (e.g. the empty value written by `clean_storage`) leaves `pos` as
                // `None`.
                if let Some(olm_machine) = &*client.olm_machine().await {
                    if let Ok(Some(blob)) =
                        olm_machine.store().get_custom_value(&instance_storage_key).await
                    {
                        if let Ok(frozen_pos) =
                            serde_json::from_slice::<FrozenSlidingSyncPos>(&blob)
                        {
                            trace!("Successfully read the `Sliding Sync` pos from the crypto store cache");
                            restored_fields.pos = frozen_pos.pos;
                        }
                    }
                }
            }

            restored_fields.rooms = frozen_rooms
                .into_iter()
                .map(|frozen_room| {
                    (
                        frozen_room.room_id.clone(),
                        SlidingSyncRoom::from_frozen(frozen_room, client.clone()),
                    )
                })
                .collect();
        }

        Some(Err(_)) => {
            // The cached instance can't be decoded. Wipe everything cached for it
            // (lists included) and report that nothing was restored.
            warn!(
                "failed to deserialize `SlidingSync` from the cache, it is obsolete; removing the cache entry!"
            );

            clean_storage(client, storage_key, lists).await;

            return Ok(None);
        }

        None => {
            trace!("No Sliding Sync object in the cache");
        }
    }

    Ok(Some(restored_fields))
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use assert_matches::assert_matches;
    use matrix_sdk_test::async_test;
    use ruma::owned_room_id;

    use super::{
        super::FrozenSlidingSyncRoom, clean_storage, format_storage_key_for_sliding_sync,
        format_storage_key_for_sliding_sync_list, format_storage_key_prefix,
        restore_sliding_sync_state, store_sliding_sync_state, SlidingSyncList,
    };
    use crate::{test_utils::logged_in_client, Result, SlidingSyncRoom};

    // Round trip through the state store. First, store a sliding sync with one
    // cached list (`list_foo`), one non-cached list (`list_bar`) and two rooms.
    // Then build a new instance with the same id and check what was restored.
    // Finally, `clean_storage` must remove every entry.
    //
    // NOTE(review): presumably this `allow` is for the std `RwLock` used around
    // `max_number_of_room_stream`; confirm it is still needed.
    #[allow(clippy::await_holding_lock)]
    #[async_test]
    async fn test_sliding_sync_can_be_stored_and_restored() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let store = client.store();

        // NOTE(review): these checks use the literal key `hello`, not the
        // `storage_key` computed below. They don't prove the real keys start empty.
        assert!(store
            .get_custom_value(format_storage_key_for_sliding_sync("hello").as_bytes())
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list("hello", "list_foo").as_bytes()
            )
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list("hello", "list_bar").as_bytes()
            )
            .await?
            .is_none());

        let room_id1 = owned_room_id!("!r1:matrix.org");
        let room_id2 = owned_room_id!("!r2:matrix.org");

        // First instance: populate lists and rooms, then write it to the cache.
        let storage_key = {
            let sync_id = "test-sync-id";
            let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo"))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            {
                let lists = sliding_sync.inner.lists.write().await;
                let list_foo = lists.get("list_foo").unwrap();
                list_foo.set_maximum_number_of_rooms(Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                list_bar.set_maximum_number_of_rooms(Some(1337));
            }

            {
                let mut rooms = sliding_sync.inner.rooms.write().await;
                rooms.insert(
                    room_id1.clone(),
                    SlidingSyncRoom::new(client.clone(), room_id1.clone(), None, Vec::new()),
                );
                rooms.insert(
                    room_id2.clone(),
                    SlidingSyncRoom::new(client.clone(), room_id2.clone(), None, Vec::new()),
                );
            }

            let position_guard = sliding_sync.inner.position.lock().await;
            assert!(sliding_sync.cache_to_storage(&position_guard).await.is_ok());

            storage_key
        };

        // The instance and the cached list were written; the non-cached list was not.
        assert!(store
            .get_custom_value(format_storage_key_for_sliding_sync(&storage_key).as_bytes())
            .await?
            .is_some());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_some());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        // Second instance with the same id. It should be rebuilt from the cache.
        let storage_key = {
            let sync_id = "test-sync-id";
            let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());
            let max_number_of_room_stream = Arc::new(RwLock::new(None));
            let cloned_stream = max_number_of_room_stream.clone();
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo").once_built(move |list| {
                    // The test expects the cached value not to be applied yet when
                    // `once_built` runs; the subscription taken here must then
                    // observe it (checked below).
                    assert_eq!(list.maximum_number_of_rooms(), None);

                    let mut stream = cloned_stream.write().unwrap();
                    *stream = Some(list.maximum_number_of_rooms_stream());
                    list
                }))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            // Only the cached list got its value back.
            {
                let lists = sliding_sync.inner.lists.read().await;

                let list_foo = lists.get("list_foo").unwrap();
                assert_eq!(list_foo.maximum_number_of_rooms(), Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                assert_eq!(list_bar.maximum_number_of_rooms(), None);
            }

            // Both rooms were restored.
            {
                let rooms = sliding_sync.inner.rooms.read().await;
                assert!(rooms.contains_key(&room_id1));
                assert!(rooms.contains_key(&room_id2));
            }

            // The stream subscribed in `once_built` saw the restored value.
            {
                let mut stream =
                    max_number_of_room_stream.write().unwrap().take().expect("stream must be set");
                let initial_max_number_of_rooms =
                    stream.next().await.expect("stream must have emitted something");
                assert_eq!(initial_max_number_of_rooms, Some(42));
            }

            let lists = sliding_sync.inner.lists.read().await;
            clean_storage(&client, &storage_key, &lists).await;

            storage_key
        };

        // After `clean_storage`, no entry remains in the state store.
        assert!(store
            .get_custom_value(format_storage_key_for_sliding_sync(&storage_key).as_bytes())
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        Ok(())
    }

    // `pos` round-trips through the crypto store. The to-device token falls back
    // to the state store's `FrozenSlidingSync` when the crypto store has no
    // `next_batch_token`.
    #[cfg(feature = "e2e-encryption")]
    #[async_test]
    async fn test_sliding_sync_high_level_cache_and_restore() -> Result<()> {
        use imbl::Vector;
        use ruma::owned_room_id;

        use crate::sliding_sync::FrozenSlidingSync;

        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let sync_id = "test-sync-id";
        let storage_key_prefix = format_storage_key_prefix(sync_id, client.user_id().unwrap());
        let full_storage_key = format_storage_key_for_sliding_sync(&storage_key_prefix);
        let sliding_sync = client.sliding_sync(sync_id)?.build().await?;

        // Nothing is cached yet, in either store.
        if let Some(olm_machine) = &*client.base_client().olm_machine().await {
            let store = olm_machine.store();
            assert!(store.next_batch_token().await?.is_none());
        }

        let state_store = client.store();
        assert!(state_store.get_custom_value(full_storage_key.as_bytes()).await?.is_none());

        // Persist with a known `pos`.
        let pos = "pos".to_owned();
        {
            let mut position_guard = sliding_sync.inner.position.lock().await;
            position_guard.pos = Some(pos.clone());

            store_sliding_sync_state(&sliding_sync, &position_guard).await?;
        }

        // The instance is now in the state store, without a to-device token.
        let state_store = client.store();
        assert_matches!(
            state_store.get_custom_value(full_storage_key.as_bytes()).await?,
            Some(bytes) => {
                let deserialized: FrozenSlidingSync = serde_json::from_slice(&bytes)?;
                assert!(deserialized.to_device_since.is_none());
            }
        );

        drop(sliding_sync);

        // Restoring gives back the `pos` written to the crypto store.
        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix, &[].into())
            .await?
            .expect("must have restored sliding sync fields");

        assert_eq!(restored_fields.pos.unwrap(), pos);

        // The crypto store still has no `next_batch_token`.
        {
            let olm_machine = client.base_client().olm_machine().await;
            let olm_machine = olm_machine.as_ref().unwrap();

            assert!(olm_machine.store().next_batch_token().await?.is_none());
        }

        // Replace the cached instance with one that carries a to-device token and
        // one room. With no crypto-store token, the restored token must come from it.
        let to_device_token = "to_device_token".to_owned();

        let state_store = client.store();
        state_store
            .set_custom_value(
                full_storage_key.as_bytes(),
                serde_json::to_vec(&FrozenSlidingSync {
                    to_device_since: Some(to_device_token.clone()),
                    rooms: vec![FrozenSlidingSyncRoom {
                        room_id: owned_room_id!("!r0:matrix.org"),
                        prev_batch: Some("t0ken".to_owned()),
                        timeline_queue: Vector::new(),
                    }],
                })?,
            )
            .await?;

        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix, &[].into())
            .await?
            .expect("must have restored fields");

        assert_eq!(restored_fields.to_device_token.unwrap(), to_device_token);
        assert_eq!(restored_fields.pos.unwrap(), pos);
        assert_eq!(restored_fields.rooms.len(), 1);

        Ok(())
    }
}