1use std::collections::BTreeMap;
8
9use matrix_sdk_base::{StateStore, StoreError};
10use matrix_sdk_common::timer;
11use ruma::{OwnedRoomId, UserId};
12use tracing::{trace, warn};
13
14use super::{
15 FrozenSlidingSync, FrozenSlidingSyncList, SlidingSync, SlidingSyncList,
16 SlidingSyncPositionMarkers, SlidingSyncRoom,
17};
18#[cfg(feature = "e2e-encryption")]
19use crate::sliding_sync::FrozenSlidingSyncPos;
20use crate::{sliding_sync::SlidingSyncListCachePolicy, Client, Result};
21
22pub(super) fn format_storage_key_prefix(id: &str, user_id: &UserId) -> String {
25 format!("sliding_sync_store::{}::{}", id, user_id)
26}
27
/// Build the key under which the `SlidingSync` instance itself is cached,
/// given the per-instance prefix produced by `format_storage_key_prefix`.
fn format_storage_key_for_sliding_sync(storage_key: &str) -> String {
    [storage_key, "::instance"].concat()
}
33
/// Build the key under which the list named `list_name` is cached, given the
/// per-instance prefix produced by `format_storage_key_prefix`.
fn format_storage_key_for_sliding_sync_list(storage_key: &str, list_name: &str) -> String {
    [storage_key, "::list::", list_name].concat()
}
39
40async fn invalidate_cached_list(
43 storage: &dyn StateStore<Error = StoreError>,
44 storage_key: &str,
45 list_name: &str,
46) {
47 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
48 let _ = storage.remove_custom_value(storage_key_for_list.as_bytes()).await;
49}
50
51async fn clean_storage(
54 client: &Client,
55 storage_key: &str,
56 lists: &BTreeMap<String, SlidingSyncList>,
57) {
58 let storage = client.store();
59 for list_name in lists.keys() {
60 invalidate_cached_list(storage, storage_key, list_name).await;
61 }
62 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
63 let _ = storage.remove_custom_value(instance_storage_key.as_bytes()).await;
64
65 #[cfg(feature = "e2e-encryption")]
66 if let Some(olm_machine) = &*client.olm_machine().await {
67 let _ = olm_machine
69 .store()
70 .set_custom_value(&instance_storage_key, "".as_bytes().to_vec())
71 .await;
72 }
73}
74
75pub(super) async fn store_sliding_sync_state(
77 sliding_sync: &SlidingSync,
78 _position: &SlidingSyncPositionMarkers,
79) -> Result<()> {
80 let storage_key = &sliding_sync.inner.storage_key;
81 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
82
83 trace!(storage_key, "Saving a `SlidingSync` to the state store");
84 let storage = sliding_sync.inner.client.store();
85
86 storage
89 .set_custom_value(
90 instance_storage_key.as_bytes(),
91 serde_json::to_vec(&FrozenSlidingSync::new(&*sliding_sync.inner.rooms.read().await))?,
92 )
93 .await?;
94
95 #[cfg(feature = "e2e-encryption")]
96 {
97 let position = _position;
98
99 if let Some(olm_machine) = &*sliding_sync.inner.client.olm_machine().await {
104 let pos_blob = serde_json::to_vec(&FrozenSlidingSyncPos { pos: position.pos.clone() })?;
105 olm_machine.store().set_custom_value(&instance_storage_key, pos_blob).await?;
106 }
107 }
108
109 let frozen_lists = {
111 sliding_sync
112 .inner
113 .lists
114 .read()
115 .await
116 .iter()
117 .filter(|(_, list)| matches!(list.cache_policy(), SlidingSyncListCachePolicy::Enabled))
118 .map(|(list_name, list)| {
119 Ok((
120 format_storage_key_for_sliding_sync_list(storage_key, list_name),
121 serde_json::to_vec(&FrozenSlidingSyncList::freeze(list))?,
122 ))
123 })
124 .collect::<Result<Vec<_>, crate::Error>>()?
125 };
126
127 for (storage_key_for_list, frozen_list) in frozen_lists {
128 trace!(storage_key_for_list, "Saving a `SlidingSyncList`");
129
130 storage.set_custom_value(storage_key_for_list.as_bytes(), frozen_list).await?;
131 }
132
133 Ok(())
134}
135
136pub(super) async fn restore_sliding_sync_list(
140 storage: &dyn StateStore<Error = StoreError>,
141 storage_key: &str,
142 list_name: &str,
143) -> Result<Option<FrozenSlidingSyncList>> {
144 let _timer = timer!(format!("loading list from DB {list_name}"));
145
146 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
147
148 match storage
149 .get_custom_value(storage_key_for_list.as_bytes())
150 .await?
151 .map(|custom_value| serde_json::from_slice::<FrozenSlidingSyncList>(&custom_value))
152 {
153 Some(Ok(frozen_list)) => {
154 trace!(list_name, "successfully read the list from cache");
156 return Ok(Some(frozen_list));
157 }
158
159 Some(Err(_)) => {
160 warn!(
166 list_name,
167 "failed to deserialize the list from the cache, it is obsolete; removing the cache entry!"
168 );
169 invalidate_cached_list(storage, storage_key, list_name).await;
171 }
172
173 None => {
174 trace!(list_name, "failed to find the list in the cache");
177 }
178 }
179
180 Ok(None)
181}
182
/// Fields of a `SlidingSync` read back from the cache by
/// `restore_sliding_sync_state`.
#[derive(Default)]
pub(super) struct RestoredFields {
    /// The to-device token: the crypto store's next batch token when one is
    /// available (end-to-end encryption only), otherwise the `to_device_since`
    /// of the cached `SlidingSync`.
    pub to_device_token: Option<String>,
    /// The `pos` read from the crypto store; in this file it is only filled
    /// when the `e2e-encryption` feature is enabled.
    pub pos: Option<String>,
    /// The cached rooms, keyed by their room ID.
    pub rooms: BTreeMap<OwnedRoomId, SlidingSyncRoom>,
}
190
191pub(super) async fn restore_sliding_sync_state(
196 client: &Client,
197 storage_key: &str,
198 lists: &BTreeMap<String, SlidingSyncList>,
199) -> Result<Option<RestoredFields>> {
200 let _timer = timer!(format!("loading sliding sync {storage_key} state from DB"));
201
202 let mut restored_fields = RestoredFields::default();
203
204 #[cfg(feature = "e2e-encryption")]
205 if let Some(olm_machine) = &*client.olm_machine().await {
206 match olm_machine.store().next_batch_token().await? {
207 Some(token) => {
208 restored_fields.to_device_token = Some(token);
209 }
210 None => trace!("No `SlidingSync` in the crypto-store cache"),
211 }
212 }
213
214 let storage = client.store();
215 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
216
217 match storage
219 .get_custom_value(instance_storage_key.as_bytes())
220 .await?
221 .map(|custom_value| serde_json::from_slice::<FrozenSlidingSync>(&custom_value))
222 {
223 Some(Ok(FrozenSlidingSync { to_device_since, rooms: frozen_rooms })) => {
225 trace!("Successfully read the `SlidingSync` from the cache");
226 if restored_fields.to_device_token.is_none() {
229 restored_fields.to_device_token = to_device_since;
230 }
231
232 #[cfg(feature = "e2e-encryption")]
233 {
234 if let Some(olm_machine) = &*client.olm_machine().await {
235 if let Ok(Some(blob)) =
236 olm_machine.store().get_custom_value(&instance_storage_key).await
237 {
238 if let Ok(frozen_pos) =
239 serde_json::from_slice::<FrozenSlidingSyncPos>(&blob)
240 {
241 trace!("Successfully read the `Sliding Sync` pos from the crypto store cache");
242 restored_fields.pos = frozen_pos.pos;
243 }
244 }
245 }
246 }
247
248 restored_fields.rooms = frozen_rooms
249 .into_iter()
250 .map(|frozen_room| {
251 (frozen_room.room_id.clone(), SlidingSyncRoom::from_frozen(frozen_room))
252 })
253 .collect();
254 }
255
256 Some(Err(_)) => {
263 warn!(
264 "failed to deserialize `SlidingSync` from the cache, it is obsolete; removing the cache entry!"
265 );
266
267 clean_storage(client, storage_key, lists).await;
269
270 return Ok(None);
271 }
272
273 None => {
274 trace!("No Sliding Sync object in the cache");
275 }
276 }
277
278 Ok(Some(restored_fields))
279}
280
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use assert_matches::assert_matches;
    use matrix_sdk_test::async_test;
    use ruma::owned_room_id;

    use super::{
        super::FrozenSlidingSyncRoom, clean_storage, format_storage_key_for_sliding_sync,
        format_storage_key_for_sliding_sync_list, format_storage_key_prefix,
        restore_sliding_sync_state, store_sliding_sync_state, SlidingSyncList,
    };
    use crate::{test_utils::logged_in_client, Result, SlidingSyncRoom};

    // Round-trip through the state store: cache a `SlidingSync` with one cached list
    // (`list_foo`) and one non-cached list (`list_bar`), rebuild it with the same ID so it is
    // restored from the cache, then clean the cache and check nothing is left.
    #[allow(clippy::await_holding_lock)]
    #[async_test]
    async fn test_sliding_sync_can_be_stored_and_restored() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let store = client.store();

        // Nothing is cached yet.
        assert!(store
            .get_custom_value(format_storage_key_for_sliding_sync("hello").as_bytes())
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list("hello", "list_foo").as_bytes()
            )
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list("hello", "list_bar").as_bytes()
            )
            .await?
            .is_none());

        let room_id1 = owned_room_id!("!r1:matrix.org");
        let room_id2 = owned_room_id!("!r2:matrix.org");

        // Build a first `SlidingSync`, give it some state, and cache it.
        let storage_key = {
            let sync_id = "test-sync-id";
            let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo"))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            // Give each list a distinguishable maximum number of rooms.
            {
                let lists = sliding_sync.inner.lists.write().await;

                let list_foo = lists.get("list_foo").unwrap();
                list_foo.set_maximum_number_of_rooms(Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                list_bar.set_maximum_number_of_rooms(Some(1337));
            }

            // Add two rooms.
            {
                let mut rooms = sliding_sync.inner.rooms.write().await;

                rooms.insert(
                    room_id1.clone(),
                    SlidingSyncRoom::new(room_id1.clone(), None, Vec::new()),
                );
                rooms.insert(
                    room_id2.clone(),
                    SlidingSyncRoom::new(room_id2.clone(), None, Vec::new()),
                );
            }

            let position_guard = sliding_sync.inner.position.lock().await;
            assert!(sliding_sync.cache_to_storage(&position_guard).await.is_ok());

            storage_key
        };

        // The instance and the cached list are now in the store...
        assert!(store
            .get_custom_value(format_storage_key_for_sliding_sync(&storage_key).as_bytes())
            .await?
            .is_some());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_some());

        // ... but `list_bar`, added with `add_list` rather than `add_cached_list`, is not.
        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        // Build a second `SlidingSync` with the same ID; its state should come from the cache.
        let storage_key = {
            let sync_id = "test-sync-id";
            let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());
            let max_number_of_room_stream = Arc::new(RwLock::new(None));
            let cloned_stream = max_number_of_room_stream.clone();
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo").once_built(move |list| {
                    // The cached value is expected not to be applied yet when this runs.
                    assert_eq!(list.maximum_number_of_rooms(), None);

                    // Subscribe now, to later observe the value set by the restoration.
                    let mut stream = cloned_stream.write().unwrap();
                    *stream = Some(list.maximum_number_of_rooms_stream());
                    list
                }))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            // The cached list got its state back; the non-cached one did not.
            {
                let lists = sliding_sync.inner.lists.read().await;

                let list_foo = lists.get("list_foo").unwrap();
                assert_eq!(list_foo.maximum_number_of_rooms(), Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                assert_eq!(list_bar.maximum_number_of_rooms(), None);
            }

            // Both rooms were restored.
            {
                let rooms = sliding_sync.inner.rooms.read().await;

                assert!(rooms.contains_key(&room_id1));
                assert!(rooms.contains_key(&room_id2));
            }

            // The stream subscribed in `once_built` emitted the restored value.
            {
                let mut stream =
                    max_number_of_room_stream.write().unwrap().take().expect("stream must be set");
                let initial_max_number_of_rooms =
                    stream.next().await.expect("stream must have emitted something");
                assert_eq!(initial_max_number_of_rooms, Some(42));
            }

            // Remove every cache entry of this instance.
            let lists = sliding_sync.inner.lists.read().await;
            clean_storage(&client, &storage_key, &lists).await;
            storage_key
        };

        // Nothing is left in the store.
        assert!(store
            .get_custom_value(format_storage_key_for_sliding_sync(&storage_key).as_bytes())
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        Ok(())
    }

    // Exercise `store_sliding_sync_state` / `restore_sliding_sync_state` directly, checking
    // where `pos` and the to-device token are read from.
    #[cfg(feature = "e2e-encryption")]
    #[async_test]
    async fn test_sliding_sync_high_level_cache_and_restore() -> Result<()> {
        use imbl::Vector;
        use ruma::owned_room_id;

        use crate::sliding_sync::FrozenSlidingSync;

        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let sync_id = "test-sync-id";
        let storage_key_prefix = format_storage_key_prefix(sync_id, client.user_id().unwrap());
        let full_storage_key = format_storage_key_for_sliding_sync(&storage_key_prefix);
        let sliding_sync = client.sliding_sync(sync_id)?.build().await?;

        // Initially, neither the crypto store nor the state store holds anything.
        if let Some(olm_machine) = &*client.base_client().olm_machine().await {
            let store = olm_machine.store();
            assert!(store.next_batch_token().await?.is_none());
        }

        let state_store = client.store();
        assert!(state_store.get_custom_value(full_storage_key.as_bytes()).await?.is_none());

        // Set a `pos`, then cache the instance.
        let pos = "pos".to_owned();
        {
            let mut position_guard = sliding_sync.inner.position.lock().await;
            position_guard.pos = Some(pos.clone());

            store_sliding_sync_state(&sliding_sync, &position_guard).await?;
        }

        // The state store now holds the instance, without a to-device token.
        let state_store = client.store();
        assert_matches!(
            state_store.get_custom_value(full_storage_key.as_bytes()).await?,
            Some(bytes) => {
                let deserialized: FrozenSlidingSync = serde_json::from_slice(&bytes)?;
                assert!(deserialized.to_device_since.is_none());
            }
        );

        drop(sliding_sync);

        // Restoring brings `pos` back.
        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix, &[].into())
            .await?
            .expect("must have restored sliding sync fields");

        assert_eq!(restored_fields.pos.unwrap(), pos);

        // The crypto store still has no next batch token.
        {
            let olm_machine = client.base_client().olm_machine().await;
            let olm_machine = olm_machine.as_ref().unwrap();
            assert!(olm_machine.store().next_batch_token().await?.is_none());
        }

        // Overwrite the cached instance with one carrying a to-device token and a room.
        let to_device_token = "to_device_token".to_owned();

        let state_store = client.store();
        state_store
            .set_custom_value(
                full_storage_key.as_bytes(),
                serde_json::to_vec(&FrozenSlidingSync {
                    to_device_since: Some(to_device_token.clone()),
                    rooms: vec![FrozenSlidingSyncRoom {
                        room_id: owned_room_id!("!r0:matrix.org"),
                        prev_batch: Some("t0ken".to_owned()),
                        timeline_queue: Vector::new(),
                    }],
                })?,
            )
            .await?;

        // With no token in the crypto store, the cached to-device token is used; `pos` and the
        // room are restored too.
        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix, &[].into())
            .await?
            .expect("must have restored fields");

        assert_eq!(restored_fields.to_device_token.unwrap(), to_device_token);
        assert_eq!(restored_fields.pos.unwrap(), pos);
        assert_eq!(restored_fields.rooms.len(), 1);

        Ok(())
    }
}