// matrix_sdk_crypto/store/caches.rs

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Collection of small in-memory stores that can be used to cache Olm objects.
//!
//! Note: You'll only be interested in these if you are implementing a custom
//! `CryptoStore`.

20use std::{
21    collections::{BTreeMap, BTreeSet, HashMap, HashSet},
22    fmt::Display,
23    ops::Deref,
24    sync::{
25        atomic::{AtomicBool, Ordering},
26        Arc, Weak,
27    },
28};
29
30use matrix_sdk_common::locks::RwLock as StdRwLock;
31use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
32use serde::{Deserialize, Serialize};
33use tokio::sync::{Mutex, MutexGuard, OwnedRwLockReadGuard, RwLock};
34use tracing::{field::display, instrument, trace, Span};
35
36use super::{CryptoStoreError, CryptoStoreWrapper};
37use crate::{identities::DeviceData, olm::Session, Account};
38
/// In-memory store for Olm Sessions.
#[derive(Debug, Default, Clone)]
pub struct SessionStore {
    /// Sessions indexed by the base64-encoded sender key, each list behind
    /// its own lock so different senders can be worked on concurrently.
    #[allow(clippy::type_complexity)]
    pub(crate) entries: Arc<RwLock<BTreeMap<String, Arc<Mutex<Vec<Session>>>>>>,
}
45
46impl SessionStore {
47    /// Create a new empty Session store.
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Clear all entries in the session store.
53    ///
54    /// This is intended to be used when regenerating olm machines.
55    pub async fn clear(&self) {
56        self.entries.write().await.clear()
57    }
58
59    /// Add a session to the store.
60    ///
61    /// Returns true if the session was added, false if the session was
62    /// already in the store.
63    pub async fn add(&self, session: Session) -> bool {
64        let sessions_lock =
65            self.entries.write().await.entry(session.sender_key.to_base64()).or_default().clone();
66
67        let mut sessions = sessions_lock.lock().await;
68
69        if !sessions.contains(&session) {
70            sessions.push(session);
71            true
72        } else {
73            false
74        }
75    }
76
77    /// Get all the sessions that belong to the given sender key.
78    pub async fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
79        self.entries.read().await.get(sender_key).cloned()
80    }
81
82    /// Add a list of sessions belonging to the sender key.
83    pub async fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
84        self.entries.write().await.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
85    }
86}
87
/// In-memory store holding the devices of users.
#[derive(Debug, Default)]
pub struct DeviceStore {
    // Devices indexed first by owning user, then by device id.
    entries: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceData>>>,
}
93
94impl DeviceStore {
95    /// Create a new empty device store.
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    /// Add a device to the store.
101    ///
102    /// Returns true if the device was already in the store, false otherwise.
103    pub fn add(&self, device: DeviceData) -> bool {
104        let user_id = device.user_id();
105        self.entries
106            .write()
107            .entry(user_id.to_owned())
108            .or_default()
109            .insert(device.device_id().into(), device)
110            .is_none()
111    }
112
113    /// Get the device with the given device_id and belonging to the given user.
114    pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
115        Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
116    }
117
118    /// Remove the device with the given device_id and belonging to the given
119    /// user.
120    ///
121    /// Returns the device if it was removed, None if it wasn't in the store.
122    pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
123        self.entries.write().get_mut(user_id)?.remove(device_id)
124    }
125
126    /// Get a read-only view over all devices of the given user.
127    pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
128        self.entries
129            .write()
130            .entry(user_id.to_owned())
131            .or_default()
132            .iter()
133            .map(|(key, value)| (key.to_owned(), value.clone()))
134            .collect()
135    }
136}
137
/// A numeric type that can represent an infinite ordered sequence.
///
/// It uses wrapping arithmetic to make sure we never run out of numbers. (2**64
/// should be enough for anyone, but it's easy enough just to make it wrap.)
///
/// Internally it uses a *signed* counter so that we can compare values via a
/// subtraction. For example, suppose we've just overflowed from i64::MAX to
/// i64::MIN. (i64::MAX.wrapping_sub(i64::MIN)) is -1, which tells us that
/// i64::MAX comes before i64::MIN in the sequence.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(transparent)]
pub struct SequenceNumber(i64);
150
151impl Display for SequenceNumber {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        self.0.fmt(f)
154    }
155}
156
impl PartialOrd for SequenceNumber {
    // Delegate to `Ord::cmp` so the partial and total orderings can never
    // disagree — the standard pattern for types with a total order.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
162
impl Ord for SequenceNumber {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Wrap-aware comparison: the sign of the wrapping difference decides
        // which value comes first in the circular sequence, so the ordering
        // remains correct across the i64::MAX -> i64::MIN overflow (see the
        // doc comment on `SequenceNumber`).
        self.0.wrapping_sub(other.0).cmp(&0)
    }
}
168
impl SequenceNumber {
    /// Advance to the next number in the sequence, wrapping around on
    /// overflow rather than panicking.
    pub(crate) fn increment(&mut self) {
        self.0 = self.0.wrapping_add(1)
    }

    /// Return the number immediately preceding this one in the sequence,
    /// wrapping around on underflow.
    fn previous(&self) -> Self {
        Self(self.0.wrapping_sub(1))
    }
}
178
/// Information on a task which is waiting for a `/keys/query` to complete.
#[derive(Debug)]
pub(super) struct KeysQueryWaiter {
    /// The user that we are waiting for
    user: OwnedUserId,

    /// The sequence number of the last invalidation of the user's device list
    /// when we started waiting (ie, any `/keys/query` result with the same or
    /// greater sequence number will satisfy this waiter)
    sequence_number: SequenceNumber,

    /// Whether the `/keys/query` has completed.
    ///
    /// This is only modified whilst holding the mutex on `users_for_key_query`.
    pub(super) completed: AtomicBool,
}
195
/// Record of the users that are waiting for a /keys/query.
///
/// To avoid races, we maintain a sequence number which is updated each time we
/// receive an invalidation notification. We also record the sequence number at
/// which each user was last invalidated. Then, we attach the current sequence
/// number to each `/keys/query` request, and when we get the response we can
/// tell if any users have been invalidated more recently than that request.
#[derive(Debug, Default)]
pub(super) struct UsersForKeyQuery {
    /// The sequence number we will assign to the next addition to `user_map`
    next_sequence_number: SequenceNumber,

    /// The users pending a lookup, together with the sequence number at which
    /// they were added to the list
    user_map: HashMap<OwnedUserId, SequenceNumber>,

    /// A list of tasks waiting for key queries to complete.
    ///
    /// Stored as weak references so a waiter that times out and is dropped by
    /// its owner can simply be pruned here.
    ///
    /// We expect this list to remain fairly short, so don't bother partitioning
    /// by user.
    tasks_awaiting_key_query: Vec<Weak<KeysQueryWaiter>>,
}
218
impl UsersForKeyQuery {
    /// Record a new user that requires a key query.
    ///
    /// The user is stamped with the current sequence number, which is then
    /// advanced so later invalidations sort after this one.
    pub(super) fn insert_user(&mut self, user: &UserId) {
        let sequence_number = self.next_sequence_number;

        trace!(?user, %sequence_number, "Flagging user for key query");

        self.user_map.insert(user.to_owned(), sequence_number);
        self.next_sequence_number.increment();
    }

    /// Record that a user has received an update with the given sequence
    /// number.
    ///
    /// If the sequence number is newer than the oldest invalidation for this
    /// user, it is removed from the list of those needing an update.
    ///
    /// Returns true if the user is now up-to-date, else false
    #[instrument(level = "trace", skip(self), fields(invalidation_sequence))]
    pub(super) fn maybe_remove_user(
        &mut self,
        user: &UserId,
        query_sequence: SequenceNumber,
    ) -> bool {
        // Look this up *before* mutating anything: it tells us when the user
        // was most recently flagged for a key query.
        let last_invalidation = self.user_map.get(user).copied();

        // If there were any jobs waiting for this key query to complete, we can flag
        // them as completed and remove them from our list. We also clear out any tasks
        // that have been cancelled.
        self.tasks_awaiting_key_query.retain(|waiter| {
            let Some(waiter) = waiter.upgrade() else {
                // the TaskAwaitingKeyQuery has been dropped, so it probably timed out and the
                // caller went away. We can remove it from our list whether or not it's for this
                // user.
                trace!("removing expired waiting task");

                return false;
            };

            // A waiter is satisfied when it is for this user and started
            // waiting no later than the query we just completed.
            if waiter.user == user && waiter.sequence_number <= query_sequence {
                trace!(
                    ?user,
                    %query_sequence,
                    waiter_sequence = %waiter.sequence_number,
                    "Removing completed waiting task"
                );

                waiter.completed.store(true, Ordering::Relaxed);

                false
            } else {
                trace!(
                    ?user,
                    %query_sequence,
                    waiter_user = ?waiter.user,
                    waiter_sequence= %waiter.sequence_number,
                    "Retaining still-waiting task"
                );

                true
            }
        });

        if let Some(last_invalidation) = last_invalidation {
            Span::current().record("invalidation_sequence", display(last_invalidation));

            if last_invalidation > query_sequence {
                trace!("User invalidated since this query started: still not up-to-date");
                false
            } else {
                trace!("User now up-to-date");
                self.user_map.remove(user);
                true
            }
        } else {
            trace!("User already up-to-date, nothing to do");
            true
        }
    }

    /// Fetch the list of users waiting for a key query, and the current
    /// sequence number
    pub(super) fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
        // we return the sequence number of the last invalidation
        let sequence_number = self.next_sequence_number.previous();
        (self.user_map.keys().cloned().collect(), sequence_number)
    }

    /// Check if a key query is pending for a user, and register for a wakeup if
    /// so.
    ///
    /// If no key query is currently pending, returns `None`. Otherwise, returns
    /// (an `Arc` to) a `KeysQueryWaiter`, whose `completed` flag will
    /// be set once the lookup completes.
    pub(super) fn maybe_register_waiting_task(
        &mut self,
        user: &UserId,
    ) -> Option<Arc<KeysQueryWaiter>> {
        self.user_map.get(user).map(|&sequence_number| {
            trace!(?user, %sequence_number, "Registering new waiting task");

            let waiter = Arc::new(KeysQueryWaiter {
                sequence_number,
                user: user.to_owned(),
                completed: AtomicBool::new(false),
            });

            // Keep only a weak reference; the strong one is handed to the
            // caller so a dropped waiter can be pruned in `maybe_remove_user`.
            self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter));

            waiter
        })
    }
}
332
/// In-memory cache layered over a [`CryptoStoreWrapper`], holding the
/// tracked-users set and a lazily-loaded copy of the [`Account`].
#[derive(Debug)]
pub(crate) struct StoreCache {
    // The underlying persistent crypto store.
    pub(super) store: Arc<CryptoStoreWrapper>,
    // Cached set of tracked users. NOTE(review): population of this set is
    // not visible in this file — presumably loaded from `store`; confirm
    // against the callers that flip `loaded_tracked_users`.
    pub(super) tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
    // Whether `tracked_users` has been loaded yet (inferred from the name;
    // the loading itself happens outside this file — TODO confirm).
    pub(super) loaded_tracked_users: RwLock<bool>,
    // Cached `Account`, populated on first call to `StoreCache::account`.
    pub(super) account: Mutex<Option<Account>>,
}
340
341impl StoreCache {
342    pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
343        self.store.as_ref()
344    }
345
346    /// Returns a reference to the `Account`.
347    ///
348    /// Either load the account from the cache, or the store if missing from
349    /// the cache.
350    ///
351    /// Note there should always be an account stored at least in the store, so
352    /// this doesn't return an `Option`.
353    ///
354    /// Note: this method should remain private, otherwise it's possible to ask
355    /// for a `StoreTransaction`, then get the `StoreTransaction::cache()`
356    /// and thus have two different live copies of the `Account` at once.
357    pub(super) async fn account(&self) -> super::Result<impl Deref<Target = Account> + '_> {
358        let mut guard = self.account.lock().await;
359        if guard.is_some() {
360            Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
361        } else {
362            match self.store.load_account().await? {
363                Some(account) => {
364                    *guard = Some(account);
365                    Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
366                }
367                None => Err(CryptoStoreError::AccountUnset),
368            }
369        }
370    }
371}
372
/// Read-only store cache guard.
///
/// This type should hold all the methods that are available when the cache is
/// borrowed in read-only mode, while all the write operations on those fields
/// should happen as part of a `StoreTransaction`.
pub(crate) struct StoreCacheGuard {
    // Owned read guard over the shared `StoreCache`; holding it keeps the
    // cache readable for the lifetime of this guard.
    pub(super) cache: OwnedRwLockReadGuard<StoreCache>,
    // TODO: (bnjbvr, #2624) add cross-process lock guard here.
}
382
impl StoreCacheGuard {
    /// Returns a reference to the `Account`.
    ///
    /// Either load the account from the cache, or the store if missing from
    /// the cache.
    ///
    /// Note there should always be an account stored at least in the store, so
    /// this doesn't return an `Option`.
    ///
    /// This simply delegates to [`StoreCache::account`] on the guarded cache.
    pub async fn account(&self) -> super::Result<impl Deref<Target = Account> + '_> {
        self.cache.account().await
    }
}
395
396impl Deref for StoreCacheGuard {
397    type Target = StoreCache;
398
399    fn deref(&self) -> &Self::Target {
400        &self.cache
401    }
402}
403
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use proptest::prelude::*;

    use super::{DeviceStore, SequenceNumber, SessionStore};
    use crate::{
        identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
    };

    // Adding a session twice must only store it once, and it must be
    // retrievable under its sender key.
    #[async_test]
    async fn test_session_store() {
        let (_, session) = get_account_and_session_test_helper();

        let store = SessionStore::new();

        assert!(store.add(session.clone()).await);
        assert!(!store.add(session.clone()).await);

        let sessions = store.get(&session.sender_key.to_base64()).await.unwrap();
        let sessions = sessions.lock().await;

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    // `set_for_sender` replaces the whole list for a sender key in one call.
    #[async_test]
    async fn test_session_store_bulk_storing() {
        let (_, session) = get_account_and_session_test_helper();

        let store = SessionStore::new();
        store.set_for_sender(&session.sender_key.to_base64(), vec![session.clone()]).await;

        let sessions = store.get(&session.sender_key.to_base64()).await.unwrap();
        let sessions = sessions.lock().await;

        let loaded_session = &sessions[0];

        assert_eq!(&session, loaded_session);
    }

    // Exercises add / get / user_devices / remove round-trips on DeviceStore.
    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = DeviceStore::new();

        // First insert reports "new", second reports "already present".
        assert!(store.add(device.clone()));
        assert!(!store.add(device.clone()));

        let loaded_device = store.get(device.user_id(), device.device_id()).unwrap();

        assert_eq!(device, loaded_device);

        let user_devices = store.user_devices(device.user_id());

        assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(user_devices.values().next().unwrap(), &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();

        assert_eq!(&device, loaded_device);

        store.remove(device.user_id(), device.device_id());

        let loaded_device = store.get(device.user_id(), device.device_id());
        assert!(loaded_device.is_none());
    }

    // The ordering must stay correct right at the i64::MAX -> i64::MIN wrap.
    #[test]
    fn sequence_at_boundary() {
        let first = SequenceNumber(i64::MAX);
        let second = SequenceNumber(first.0.wrapping_add(1));
        let third = SequenceNumber(first.0.wrapping_sub(1));

        assert!(second > first);
        assert!(first < second);
        assert!(third < first);
        assert!(first > third);
        assert!(second > third);
        assert!(third < second);
    }

    proptest! {
        // Property test: for any starting point, successor/predecessor
        // ordering holds under wrapping arithmetic.
        #[test]
        fn partial_eq_sequence_number(sequence in i64::MIN..i64::MAX) {
            let first = SequenceNumber(sequence);
            let second = SequenceNumber(first.0.wrapping_add(1));
            let third = SequenceNumber(first.0.wrapping_sub(1));

            assert!(second > first);
            assert!(first < second);
            assert!(third < first);
            assert!(first > third);
            assert!(second > third);
            assert!(third < second);
        }
    }
}