1use std::{
21 collections::{BTreeMap, HashMap, HashSet},
22 fmt::Display,
23 sync::{
24 atomic::{AtomicBool, Ordering},
25 Arc, Weak,
26 },
27};
28
29use matrix_sdk_common::locks::RwLock as StdRwLock;
30use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
31use serde::{Deserialize, Serialize};
32use tokio::sync::{Mutex, RwLock};
33use tracing::{field::display, instrument, trace, Span};
34
35use crate::{identities::DeviceData, olm::Session};
36
/// In-memory store of Olm sessions, grouped by the sender key they belong to.
///
/// The map lives behind an `Arc`, so clones of a `SessionStore` share the same underlying
/// storage.
#[derive(Debug, Default, Clone)]
pub struct SessionStore {
    /// Map from a sender key (the `to_base64()` form used by `SessionStore::add`) to the list of
    /// sessions for that key. Each list sits behind its own `Mutex`, so a caller can hold one
    /// sender's sessions after the outer map lock has been released.
    // The nested lock/collection type trips clippy's `type_complexity` lint; spelling it out
    // keeps the sharing/locking structure visible at the field definition.
    #[allow(clippy::type_complexity)]
    pub(crate) entries: Arc<RwLock<BTreeMap<String, Arc<Mutex<Vec<Session>>>>>>,
}
43
44impl SessionStore {
45 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub async fn clear(&self) {
54 self.entries.write().await.clear()
55 }
56
57 pub async fn add(&self, session: Session) -> bool {
62 let sessions_lock =
63 self.entries.write().await.entry(session.sender_key.to_base64()).or_default().clone();
64
65 let mut sessions = sessions_lock.lock().await;
66
67 if !sessions.contains(&session) {
68 sessions.push(session);
69 true
70 } else {
71 false
72 }
73 }
74
75 pub async fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
77 self.entries.read().await.get(sender_key).cloned()
78 }
79
80 pub async fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
82 self.entries.write().await.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
83 }
84}
85
/// In-memory store of devices, keyed first by user ID and then by device ID.
#[derive(Debug, Default)]
pub struct DeviceStore {
    /// Per-user device maps, guarded by the synchronous (non-async) lock from
    /// `matrix_sdk_common::locks`.
    entries: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceData>>>,
}
91
92impl DeviceStore {
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn add(&self, device: DeviceData) -> bool {
102 let user_id = device.user_id();
103 self.entries
104 .write()
105 .entry(user_id.to_owned())
106 .or_default()
107 .insert(device.device_id().into(), device)
108 .is_none()
109 }
110
111 pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
113 Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
114 }
115
116 pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
121 self.entries.write().get_mut(user_id)?.remove(device_id)
122 }
123
124 pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
126 self.entries
127 .write()
128 .entry(user_id.to_owned())
129 .or_default()
130 .iter()
131 .map(|(key, value)| (key.to_owned(), value.clone()))
132 .collect()
133 }
134}
135
/// A wrapping counter that `UsersForKeyQuery` uses to order user invalidations against key
/// queries.
///
/// Values are compared with wrap-around semantics (see the `PartialOrd`/`Ord` impls). The
/// counter can therefore be incremented past `i64::MAX`, and the new value still compares as
/// later than recent ones. It starts at `0` (`Default`) and is serialized as the bare inner
/// `i64` (`#[serde(transparent)]`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(transparent)]
pub struct SequenceNumber(i64);
148
149impl Display for SequenceNumber {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 self.0.fmt(f)
152 }
153}
154
155impl PartialOrd for SequenceNumber {
156 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
157 Some(self.0.wrapping_sub(other.0).cmp(&0))
158 }
159}
160
161impl Ord for SequenceNumber {
162 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
163 self.0.wrapping_sub(other.0).cmp(&0)
164 }
165}
166
167impl SequenceNumber {
168 pub(crate) fn increment(&mut self) {
169 self.0 = self.0.wrapping_add(1)
170 }
171
172 fn previous(&self) -> Self {
173 Self(self.0.wrapping_sub(1))
174 }
175}
176
/// Handle for a task that is waiting for a key query covering `user` to complete.
///
/// Created by `UsersForKeyQuery::maybe_register_waiting_task`, which keeps only a weak
/// reference to it.
#[derive(Debug)]
pub(super) struct KeysQueryWaiter {
    /// The user whose keys the task is waiting for.
    user: OwnedUserId,

    /// The user's invalidation sequence number at the time the waiter was registered. A
    /// completed key query whose sequence number is at least this large satisfies the wait.
    sequence_number: SequenceNumber,

    /// Set to `true` by `UsersForKeyQuery::maybe_remove_user` once a satisfying query has
    /// completed.
    pub(super) completed: AtomicBool,
}
193
/// Tracks which users need a key query, and the tasks waiting for those queries to complete.
#[derive(Debug, Default)]
pub(super) struct UsersForKeyQuery {
    /// Sequence number that the next call to `insert_user` will assign.
    next_sequence_number: SequenceNumber,

    /// Users currently flagged for a key query, mapped to the sequence number of their most
    /// recent invalidation.
    user_map: HashMap<OwnedUserId, SequenceNumber>,

    /// Tasks waiting for a key query to complete. They are held weakly, so a waiter that its
    /// task has dropped is pruned the next time `maybe_remove_user` runs.
    tasks_awaiting_key_query: Vec<Weak<KeysQueryWaiter>>,
}
216
impl UsersForKeyQuery {
    /// Flags `user` as needing a key query.
    ///
    /// The user is recorded with the current `next_sequence_number`, which is then incremented.
    /// If the user was already flagged, their stored sequence number is overwritten with this
    /// newer one.
    pub(super) fn insert_user(&mut self, user: &UserId) {
        let sequence_number = self.next_sequence_number;

        trace!(?user, %sequence_number, "Flagging user for key query");

        self.user_map.insert(user.to_owned(), sequence_number);
        self.next_sequence_number.increment();
    }

    /// Records that a key query covering `user` has completed, and clears the user's flag if
    /// the query is recent enough.
    ///
    /// `query_sequence` is the sequence number that `users_for_key_query` returned when the
    /// query was started.
    ///
    /// The method has two side effects on `tasks_awaiting_key_query`:
    /// - Every waiter registered for `user` whose sequence number is at most `query_sequence`
    ///   is marked completed and removed from the list.
    /// - Waiters whose owning task has dropped them are pruned.
    ///
    /// Returns `true` if the user is now up to date: either their flag was cleared, or they
    /// were not flagged at all. Returns `false` if the user was invalidated again after the
    /// query started and so still needs another query.
    #[instrument(level = "trace", skip(self), fields(invalidation_sequence))]
    pub(super) fn maybe_remove_user(
        &mut self,
        user: &UserId,
        query_sequence: SequenceNumber,
    ) -> bool {
        let last_invalidation = self.user_map.get(user).copied();

        self.tasks_awaiting_key_query.retain(|waiter| {
            // Only a weak reference is stored here. If the upgrade fails, the waiting task has
            // dropped its handle, so the entry is simply discarded.
            let Some(waiter) = waiter.upgrade() else {
                trace!("removing expired waiting task");

                return false;
            };

            if waiter.user == user && waiter.sequence_number <= query_sequence {
                trace!(
                    ?user,
                    %query_sequence,
                    waiter_sequence = %waiter.sequence_number,
                    "Removing completed waiting task"
                );

                // NOTE(review): `Relaxed` presumably relies on readers of `completed`
                // synchronizing through some other mechanism (e.g. a lock or notifier) — confirm
                // at the read site.
                waiter.completed.store(true, Ordering::Relaxed);

                false
            } else {
                trace!(
                    ?user,
                    %query_sequence,
                    waiter_user = ?waiter.user,
                    waiter_sequence= %waiter.sequence_number,
                    "Retaining still-waiting task"
                );

                true
            }
        });

        if let Some(last_invalidation) = last_invalidation {
            // Fill in the span field that `#[instrument]` declared empty above.
            Span::current().record("invalidation_sequence", display(last_invalidation));

            // A newer invalidation means this query's results may already be stale.
            if last_invalidation > query_sequence {
                trace!("User invalidated since this query started: still not up-to-date");
                false
            } else {
                trace!("User now up-to-date");
                self.user_map.remove(user);
                true
            }
        } else {
            trace!("User already up-to-date, nothing to do");
            true
        }
    }

    /// Returns the users currently flagged for a key query, together with the sequence number
    /// to pass to `maybe_remove_user` once that query completes.
    ///
    /// The sequence number is one less than `next_sequence_number` (wrapping), i.e. the number
    /// assigned by the most recent `insert_user` call, if any. Every flag recorded so far is
    /// therefore covered by the query.
    pub(super) fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
        let sequence_number = self.next_sequence_number.previous();

        (self.user_map.keys().cloned().collect(), sequence_number)
    }

    /// If `user` is currently flagged for a key query, registers and returns a new waiter for
    /// them. Otherwise returns `None`.
    ///
    /// The waiter captures the user's current invalidation sequence number. Only a weak
    /// reference is kept here, so once the caller drops the returned `Arc`, the entry is
    /// pruned on the next `maybe_remove_user` call.
    pub(super) fn maybe_register_waiting_task(
        &mut self,
        user: &UserId,
    ) -> Option<Arc<KeysQueryWaiter>> {
        self.user_map.get(user).map(|&sequence_number| {
            trace!(?user, %sequence_number, "Registering new waiting task");

            let waiter = Arc::new(KeysQueryWaiter {
                sequence_number,
                user: user.to_owned(),
                completed: AtomicBool::new(false),
            });

            self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter));

            waiter
        })
    }
}
330
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use proptest::prelude::*;

    use super::{DeviceStore, SequenceNumber, SessionStore};
    use crate::{
        identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
    };

    #[async_test]
    async fn test_session_store() {
        let (_, session) = get_account_and_session_test_helper();

        let store = SessionStore::new();

        assert!(store.add(session.clone()).await);
        assert!(!store.add(session.clone()).await);

        let sessions = store.get(&session.sender_key.to_base64()).await.unwrap();
        let sessions = sessions.lock().await;

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_session_store_bulk_storing() {
        let (_, session) = get_account_and_session_test_helper();

        let store = SessionStore::new();
        store.set_for_sender(&session.sender_key.to_base64(), vec![session.clone()]).await;

        let sessions = store.get(&session.sender_key.to_base64()).await.unwrap();
        let sessions = sessions.lock().await;

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = DeviceStore::new();

        assert!(store.add(device.clone()));
        assert!(!store.add(device.clone()));

        let loaded_device = store.get(device.user_id(), device.device_id()).unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.user_devices(device.user_id());

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        // Removing hands back the stored device and leaves nothing behind for the user.
        assert_eq!(store.remove(device.user_id(), device.device_id()), Some(device.clone()));
        assert!(store.get(device.user_id(), device.device_id()).is_none());
        assert!(store.user_devices(device.user_id()).is_empty());

        // A second removal finds nothing to remove.
        assert!(store.remove(device.user_id(), device.device_id()).is_none());
    }

    #[test]
    fn test_sequence_number_at_boundary() {
        let first = SequenceNumber(i64::MAX);
        let second = SequenceNumber(first.0.wrapping_add(1));
        let third = SequenceNumber(first.0.wrapping_sub(1));

        assert!(second > first);
        assert!(first < second);
        assert!(third < first);
        assert!(first > third);
        assert!(second > third);
        assert!(third < second);
    }

    proptest! {
        /// Neighbouring sequence numbers must order correctly everywhere in the `i64` range,
        /// including across the wrap-around point.
        #[test]
        fn test_sequence_number_ordering(sequence in any::<i64>()) {
            let first = SequenceNumber(sequence);
            let second = SequenceNumber(first.0.wrapping_add(1));
            let third = SequenceNumber(first.0.wrapping_sub(1));

            assert!(second > first);
            assert!(first < second);
            assert!(third < first);
            assert!(first > third);
            assert!(second > third);
            assert!(third < second);
        }
    }
}