1use std::{
21 collections::{BTreeMap, BTreeSet, HashMap, HashSet},
22 fmt::Display,
23 ops::Deref,
24 sync::{
25 atomic::{AtomicBool, Ordering},
26 Arc, Weak,
27 },
28};
29
30use matrix_sdk_common::locks::RwLock as StdRwLock;
31use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
32use serde::{Deserialize, Serialize};
33use tokio::sync::{Mutex, MutexGuard, OwnedRwLockReadGuard, RwLock};
34use tracing::{field::display, instrument, trace, Span};
35
36use super::{CryptoStoreError, CryptoStoreWrapper};
37use crate::{identities::DeviceData, olm::Session, Account};
38
39#[derive(Debug, Default, Clone)]
41pub struct SessionStore {
42 #[allow(clippy::type_complexity)]
43 pub(crate) entries: Arc<RwLock<BTreeMap<String, Arc<Mutex<Vec<Session>>>>>>,
44}
45
46impl SessionStore {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub async fn clear(&self) {
56 self.entries.write().await.clear()
57 }
58
59 pub async fn add(&self, session: Session) -> bool {
64 let sessions_lock =
65 self.entries.write().await.entry(session.sender_key.to_base64()).or_default().clone();
66
67 let mut sessions = sessions_lock.lock().await;
68
69 if !sessions.contains(&session) {
70 sessions.push(session);
71 true
72 } else {
73 false
74 }
75 }
76
77 pub async fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
79 self.entries.read().await.get(sender_key).cloned()
80 }
81
82 pub async fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
84 self.entries.write().await.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
85 }
86}
87
88#[derive(Debug, Default)]
90pub struct DeviceStore {
91 entries: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceData>>>,
92}
93
94impl DeviceStore {
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn add(&self, device: DeviceData) -> bool {
104 let user_id = device.user_id();
105 self.entries
106 .write()
107 .entry(user_id.to_owned())
108 .or_default()
109 .insert(device.device_id().into(), device)
110 .is_none()
111 }
112
113 pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
115 Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
116 }
117
118 pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
123 self.entries.write().get_mut(user_id)?.remove(device_id)
124 }
125
126 pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
128 self.entries
129 .write()
130 .entry(user_id.to_owned())
131 .or_default()
132 .iter()
133 .map(|(key, value)| (key.to_owned(), value.clone()))
134 .collect()
135 }
136}
137
138#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
148#[serde(transparent)]
149pub struct SequenceNumber(i64);
150
151impl Display for SequenceNumber {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 self.0.fmt(f)
154 }
155}
156
157impl PartialOrd for SequenceNumber {
158 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
159 Some(self.cmp(other))
160 }
161}
162
163impl Ord for SequenceNumber {
164 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
165 self.0.wrapping_sub(other.0).cmp(&0)
166 }
167}
168
169impl SequenceNumber {
170 pub(crate) fn increment(&mut self) {
171 self.0 = self.0.wrapping_add(1)
172 }
173
174 fn previous(&self) -> Self {
175 Self(self.0.wrapping_sub(1))
176 }
177}
178
179#[derive(Debug)]
181pub(super) struct KeysQueryWaiter {
182 user: OwnedUserId,
184
185 sequence_number: SequenceNumber,
189
190 pub(super) completed: AtomicBool,
194}
195
196#[derive(Debug, Default)]
204pub(super) struct UsersForKeyQuery {
205 next_sequence_number: SequenceNumber,
207
208 user_map: HashMap<OwnedUserId, SequenceNumber>,
211
212 tasks_awaiting_key_query: Vec<Weak<KeysQueryWaiter>>,
217}
218
219impl UsersForKeyQuery {
220 pub(super) fn insert_user(&mut self, user: &UserId) {
222 let sequence_number = self.next_sequence_number;
223
224 trace!(?user, %sequence_number, "Flagging user for key query");
225
226 self.user_map.insert(user.to_owned(), sequence_number);
227 self.next_sequence_number.increment();
228 }
229
230 #[instrument(level = "trace", skip(self), fields(invalidation_sequence))]
238 pub(super) fn maybe_remove_user(
239 &mut self,
240 user: &UserId,
241 query_sequence: SequenceNumber,
242 ) -> bool {
243 let last_invalidation = self.user_map.get(user).copied();
244
245 self.tasks_awaiting_key_query.retain(|waiter| {
249 let Some(waiter) = waiter.upgrade() else {
250 trace!("removing expired waiting task");
254
255 return false;
256 };
257
258 if waiter.user == user && waiter.sequence_number <= query_sequence {
259 trace!(
260 ?user,
261 %query_sequence,
262 waiter_sequence = %waiter.sequence_number,
263 "Removing completed waiting task"
264 );
265
266 waiter.completed.store(true, Ordering::Relaxed);
267
268 false
269 } else {
270 trace!(
271 ?user,
272 %query_sequence,
273 waiter_user = ?waiter.user,
274 waiter_sequence= %waiter.sequence_number,
275 "Retaining still-waiting task"
276 );
277
278 true
279 }
280 });
281
282 if let Some(last_invalidation) = last_invalidation {
283 Span::current().record("invalidation_sequence", display(last_invalidation));
284
285 if last_invalidation > query_sequence {
286 trace!("User invalidated since this query started: still not up-to-date");
287 false
288 } else {
289 trace!("User now up-to-date");
290 self.user_map.remove(user);
291 true
292 }
293 } else {
294 trace!("User already up-to-date, nothing to do");
295 true
296 }
297 }
298
299 pub(super) fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
302 let sequence_number = self.next_sequence_number.previous();
304 (self.user_map.keys().cloned().collect(), sequence_number)
305 }
306
307 pub(super) fn maybe_register_waiting_task(
314 &mut self,
315 user: &UserId,
316 ) -> Option<Arc<KeysQueryWaiter>> {
317 self.user_map.get(user).map(|&sequence_number| {
318 trace!(?user, %sequence_number, "Registering new waiting task");
319
320 let waiter = Arc::new(KeysQueryWaiter {
321 sequence_number,
322 user: user.to_owned(),
323 completed: AtomicBool::new(false),
324 });
325
326 self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter));
327
328 waiter
329 })
330 }
331}
332
333#[derive(Debug)]
334pub(crate) struct StoreCache {
335 pub(super) store: Arc<CryptoStoreWrapper>,
336 pub(super) tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
337 pub(super) loaded_tracked_users: RwLock<bool>,
338 pub(super) account: Mutex<Option<Account>>,
339}
340
341impl StoreCache {
342 pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
343 self.store.as_ref()
344 }
345
346 pub(super) async fn account(&self) -> super::Result<impl Deref<Target = Account> + '_> {
358 let mut guard = self.account.lock().await;
359 if guard.is_some() {
360 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
361 } else {
362 match self.store.load_account().await? {
363 Some(account) => {
364 *guard = Some(account);
365 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
366 }
367 None => Err(CryptoStoreError::AccountUnset),
368 }
369 }
370 }
371}
372
373pub(crate) struct StoreCacheGuard {
379 pub(super) cache: OwnedRwLockReadGuard<StoreCache>,
380 }
382
383impl StoreCacheGuard {
384 pub async fn account(&self) -> super::Result<impl Deref<Target = Account> + '_> {
392 self.cache.account().await
393 }
394}
395
396impl Deref for StoreCacheGuard {
397 type Target = StoreCache;
398
399 fn deref(&self) -> &Self::Target {
400 &self.cache
401 }
402}
403
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use proptest::prelude::*;

    use super::{DeviceStore, SequenceNumber, SessionStore};
    use crate::{
        identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
    };

    #[async_test]
    async fn test_session_store() {
        let (_, session) = get_account_and_session_test_helper();

        let store = SessionStore::new();

        assert!(store.add(session.clone()).await);
        assert!(!store.add(session.clone()).await);

        let sessions = store.get(&session.sender_key.to_base64()).await.unwrap();
        let sessions = sessions.lock().await;

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_session_store_bulk_storing() {
        let (_, session) = get_account_and_session_test_helper();

        let store = SessionStore::new();
        store.set_for_sender(&session.sender_key.to_base64(), vec![session.clone()]).await;

        let sessions = store.get(&session.sender_key.to_base64()).await.unwrap();
        let sessions = sessions.lock().await;

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = DeviceStore::new();

        assert!(store.add(device.clone()));
        assert!(!store.add(device.clone()));

        let loaded_device = store.get(device.user_id(), device.device_id()).unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.user_devices(device.user_id());

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.remove(device.user_id(), device.device_id());

        let loaded_device = store.get(device.user_id(), device.device_id());
        assert!(loaded_device.is_none());
    }

    #[async_test]
    async fn test_device_store_unknown_user() {
        let device = get_device();
        let store = DeviceStore::new();

        // Every lookup for a user that was never added comes back empty.
        assert!(store.user_devices(device.user_id()).is_empty());
        assert!(store.get(device.user_id(), device.device_id()).is_none());
        assert!(store.remove(device.user_id(), device.device_id()).is_none());

        // Querying the unknown user must not affect a later insertion.
        assert!(store.add(device.clone()));
        assert_eq!(store.user_devices(device.user_id()).len(), 1);
    }

    #[test]
    fn sequence_at_boundary() {
        let first = SequenceNumber(i64::MAX);
        let second = SequenceNumber(first.0.wrapping_add(1));
        let third = SequenceNumber(first.0.wrapping_sub(1));

        assert!(second > first);
        assert!(first < second);
        assert!(third < first);
        assert!(first > third);
        assert!(second > third);
        assert!(third < second);
    }

    #[test]
    fn sequence_display_and_wrapping_steps() {
        assert_eq!(SequenceNumber(42).to_string(), "42");
        assert_eq!(SequenceNumber(-7).to_string(), "-7");

        let mut sequence = SequenceNumber(i64::MAX);
        assert_eq!(sequence.previous(), SequenceNumber(i64::MAX - 1));

        // Incrementing past the top wraps around, and `previous` undoes it.
        sequence.increment();
        assert_eq!(sequence, SequenceNumber(i64::MIN));
        assert_eq!(sequence.previous(), SequenceNumber(i64::MAX));
        assert!(sequence > SequenceNumber(i64::MAX));
    }

    proptest! {
        #[test]
        fn partial_eq_sequence_number(sequence in i64::MIN..i64::MAX) {
            let first = SequenceNumber(sequence);
            let second = SequenceNumber(first.0.wrapping_add(1));
            let third = SequenceNumber(first.0.wrapping_sub(1));

            assert!(second > first);
            assert!(first < second);
            assert!(third < first);
            assert!(first > third);
            assert!(second > third);
            assert!(third < second);
        }
    }
}