matrix_sdk_crypto/store/
caches.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Collection of small in-memory stores that can be used to cache Olm objects.
16//!
17//! Note: You'll only be interested in these if you are implementing a custom
18//! `CryptoStore`.
19
20use std::{
21    collections::{BTreeMap, HashMap, HashSet},
22    fmt::Display,
23    sync::{
24        atomic::{AtomicBool, Ordering},
25        Arc, Weak,
26    },
27};
28
29use matrix_sdk_common::locks::RwLock as StdRwLock;
30use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
31use serde::{Deserialize, Serialize};
32use tokio::sync::{Mutex, RwLock};
33use tracing::{field::display, instrument, trace, Span};
34
35use crate::{identities::DeviceData, olm::Session};
36
/// In-memory store for Olm Sessions.
///
/// The map of sessions lives behind an `Arc`, so cloning a `SessionStore`
/// yields another handle onto the same underlying entries rather than an
/// independent copy.
#[derive(Debug, Default, Clone)]
pub struct SessionStore {
    /// Lists of sessions, keyed by a `String` (`add` uses the base64-encoded
    /// sender key of the session as the key).
    ///
    /// Each list sits behind its own `Arc<Mutex<..>>`, so a list can be handed
    /// out to a caller and locked on its own without keeping the lock on the
    /// outer map.
    // The nested lock/collection type is long, but it is only spelled out here.
    #[allow(clippy::type_complexity)]
    pub(crate) entries: Arc<RwLock<BTreeMap<String, Arc<Mutex<Vec<Session>>>>>>,
}
43
44impl SessionStore {
45    /// Create a new empty Session store.
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Clear all entries in the session store.
51    ///
52    /// This is intended to be used when regenerating olm machines.
53    pub async fn clear(&self) {
54        self.entries.write().await.clear()
55    }
56
57    /// Add a session to the store.
58    ///
59    /// Returns true if the session was added, false if the session was
60    /// already in the store.
61    pub async fn add(&self, session: Session) -> bool {
62        let sessions_lock =
63            self.entries.write().await.entry(session.sender_key.to_base64()).or_default().clone();
64
65        let mut sessions = sessions_lock.lock().await;
66
67        if !sessions.contains(&session) {
68            sessions.push(session);
69            true
70        } else {
71            false
72        }
73    }
74
75    /// Get all the sessions that belong to the given sender key.
76    pub async fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
77        self.entries.read().await.get(sender_key).cloned()
78    }
79
80    /// Add a list of sessions belonging to the sender key.
81    pub async fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
82        self.entries.write().await.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
83    }
84}
85
/// In-memory store holding the devices of users.
#[derive(Debug, Default)]
pub struct DeviceStore {
    /// Two-level map (user ID, then device ID, to the device data), with one
    /// lock guarding the whole structure.
    entries: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceData>>>,
}
91
92impl DeviceStore {
93    /// Create a new empty device store.
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Add a device to the store.
99    ///
100    /// Returns true if the device was already in the store, false otherwise.
101    pub fn add(&self, device: DeviceData) -> bool {
102        let user_id = device.user_id();
103        self.entries
104            .write()
105            .entry(user_id.to_owned())
106            .or_default()
107            .insert(device.device_id().into(), device)
108            .is_none()
109    }
110
111    /// Get the device with the given device_id and belonging to the given user.
112    pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
113        Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
114    }
115
116    /// Remove the device with the given device_id and belonging to the given
117    /// user.
118    ///
119    /// Returns the device if it was removed, None if it wasn't in the store.
120    pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
121        self.entries.write().get_mut(user_id)?.remove(device_id)
122    }
123
124    /// Get a read-only view over all devices of the given user.
125    pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
126        self.entries
127            .write()
128            .entry(user_id.to_owned())
129            .or_default()
130            .iter()
131            .map(|(key, value)| (key.to_owned(), value.clone()))
132            .collect()
133    }
134}
135
/// A numeric type that can represent an infinite ordered sequence.
///
/// It uses wrapping arithmetic to make sure we never run out of numbers. (2**64
/// should be enough for anyone, but it's easy enough just to make it wrap.)
///
/// Internally it uses a *signed* counter so that we can compare values via a
/// subtraction. For example, suppose we've just overflowed from `i64::MAX` to
/// `i64::MIN`. `i64::MAX.wrapping_sub(i64::MIN)` is -1, which tells us that
/// `i64::MAX` comes before `i64::MIN` in the sequence.
///
/// Because the comparison only looks at the sign of the wrapped difference, it
/// is only meaningful for two values that are less than 2**63 steps apart.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
// Serialized as the bare inner integer rather than as a wrapper.
#[serde(transparent)]
pub struct SequenceNumber(i64);
148
149impl Display for SequenceNumber {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        self.0.fmt(f)
152    }
153}
154
155impl PartialOrd for SequenceNumber {
156    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
157        Some(self.0.wrapping_sub(other.0).cmp(&0))
158    }
159}
160
161impl Ord for SequenceNumber {
162    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
163        self.0.wrapping_sub(other.0).cmp(&0)
164    }
165}
166
167impl SequenceNumber {
168    pub(crate) fn increment(&mut self) {
169        self.0 = self.0.wrapping_add(1)
170    }
171
172    fn previous(&self) -> Self {
173        Self(self.0.wrapping_sub(1))
174    }
175}
176
/// Information on a task which is waiting for a `/keys/query` to complete.
#[derive(Debug)]
pub(super) struct KeysQueryWaiter {
    /// The user that we are waiting for.
    user: OwnedUserId,

    /// The sequence number of the last invalidation of the user's device list
    /// when we started waiting (i.e., any `/keys/query` result with the same or
    /// greater sequence number will satisfy this waiter).
    sequence_number: SequenceNumber,

    /// Whether the `/keys/query` has completed.
    ///
    /// This is only modified whilst holding the mutex on `users_for_key_query`.
    pub(super) completed: AtomicBool,
}
193
/// Record of the users that are waiting for a `/keys/query`.
///
/// To avoid races, we maintain a sequence number which is updated each time we
/// receive an invalidation notification. We also record the sequence number at
/// which each user was last invalidated. Then, we attach the current sequence
/// number to each `/keys/query` request, and when we get the response we can
/// tell if any users have been invalidated more recently than that request.
#[derive(Debug, Default)]
pub(super) struct UsersForKeyQuery {
    /// The sequence number we will assign to the next addition to `user_map`.
    next_sequence_number: SequenceNumber,

    /// The users pending a lookup, together with the sequence number at which
    /// they were (most recently) added to the list.
    user_map: HashMap<OwnedUserId, SequenceNumber>,

    /// A list of tasks waiting for key queries to complete.
    ///
    /// Only `Weak` references are kept here, so this list does not keep a
    /// waiter alive once the task that registered it has dropped it.
    ///
    /// We expect this list to remain fairly short, so don't bother partitioning
    /// by user.
    tasks_awaiting_key_query: Vec<Weak<KeysQueryWaiter>>,
}
216
217impl UsersForKeyQuery {
218    /// Record a new user that requires a key query
219    pub(super) fn insert_user(&mut self, user: &UserId) {
220        let sequence_number = self.next_sequence_number;
221
222        trace!(?user, %sequence_number, "Flagging user for key query");
223
224        self.user_map.insert(user.to_owned(), sequence_number);
225        self.next_sequence_number.increment();
226    }
227
228    /// Record that a user has received an update with the given sequence
229    /// number.
230    ///
231    /// If the sequence number is newer than the oldest invalidation for this
232    /// user, it is removed from the list of those needing an update.
233    ///
234    /// Returns true if the user is now up-to-date, else false
235    #[instrument(level = "trace", skip(self), fields(invalidation_sequence))]
236    pub(super) fn maybe_remove_user(
237        &mut self,
238        user: &UserId,
239        query_sequence: SequenceNumber,
240    ) -> bool {
241        let last_invalidation = self.user_map.get(user).copied();
242
243        // If there were any jobs waiting for this key query to complete, we can flag
244        // them as completed and remove them from our list. We also clear out any tasks
245        // that have been cancelled.
246        self.tasks_awaiting_key_query.retain(|waiter| {
247            let Some(waiter) = waiter.upgrade() else {
248                // the TaskAwaitingKeyQuery has been dropped, so it probably timed out and the
249                // caller went away. We can remove it from our list whether or not it's for this
250                // user.
251                trace!("removing expired waiting task");
252
253                return false;
254            };
255
256            if waiter.user == user && waiter.sequence_number <= query_sequence {
257                trace!(
258                    ?user,
259                    %query_sequence,
260                    waiter_sequence = %waiter.sequence_number,
261                    "Removing completed waiting task"
262                );
263
264                waiter.completed.store(true, Ordering::Relaxed);
265
266                false
267            } else {
268                trace!(
269                    ?user,
270                    %query_sequence,
271                    waiter_user = ?waiter.user,
272                    waiter_sequence= %waiter.sequence_number,
273                    "Retaining still-waiting task"
274                );
275
276                true
277            }
278        });
279
280        if let Some(last_invalidation) = last_invalidation {
281            Span::current().record("invalidation_sequence", display(last_invalidation));
282
283            if last_invalidation > query_sequence {
284                trace!("User invalidated since this query started: still not up-to-date");
285                false
286            } else {
287                trace!("User now up-to-date");
288                self.user_map.remove(user);
289                true
290            }
291        } else {
292            trace!("User already up-to-date, nothing to do");
293            true
294        }
295    }
296
297    /// Fetch the list of users waiting for a key query, and the current
298    /// sequence number
299    pub(super) fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
300        // we return the sequence number of the last invalidation
301        let sequence_number = self.next_sequence_number.previous();
302        (self.user_map.keys().cloned().collect(), sequence_number)
303    }
304
305    /// Check if a key query is pending for a user, and register for a wakeup if
306    /// so.
307    ///
308    /// If no key query is currently pending, returns `None`. Otherwise, returns
309    /// (an `Arc` to) a `KeysQueryWaiter`, whose `completed` flag will
310    /// be set once the lookup completes.
311    pub(super) fn maybe_register_waiting_task(
312        &mut self,
313        user: &UserId,
314    ) -> Option<Arc<KeysQueryWaiter>> {
315        self.user_map.get(user).map(|&sequence_number| {
316            trace!(?user, %sequence_number, "Registering new waiting task");
317
318            let waiter = Arc::new(KeysQueryWaiter {
319                sequence_number,
320                user: user.to_owned(),
321                completed: AtomicBool::new(false),
322            });
323
324            self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter));
325
326            waiter
327        })
328    }
329}
330
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use proptest::prelude::*;

    use super::{DeviceStore, SequenceNumber, SessionStore};
    use crate::{
        identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
    };

    /// Check that `first` sorts strictly between its predecessor and its
    /// successor, where both neighbours are computed with wrapping arithmetic.
    fn assert_ordered_with_neighbours(first: SequenceNumber) {
        let second = SequenceNumber(first.0.wrapping_add(1));
        let third = SequenceNumber(first.0.wrapping_sub(1));

        assert!(second > first);
        assert!(first < second);
        assert!(third < first);
        assert!(first > third);
        assert!(second > third);
        assert!(third < second);
    }

    /// Adding a session twice only stores it once, and it can be read back.
    #[async_test]
    async fn test_session_store() {
        let (_, session) = get_account_and_session_test_helper();
        let sender_key = session.sender_key.to_base64();

        let store = SessionStore::new();

        // The first insertion is new, the second one is a duplicate.
        assert!(store.add(session.clone()).await);
        assert!(!store.add(session.clone()).await);

        let sessions = store.get(&sender_key).await.unwrap();
        let sessions = sessions.lock().await;

        assert_eq!(&session, &sessions[0]);
    }

    /// A list stored in bulk for a sender key can be read back.
    #[async_test]
    async fn test_session_store_bulk_storing() {
        let (_, session) = get_account_and_session_test_helper();
        let sender_key = session.sender_key.to_base64();

        let store = SessionStore::new();
        store.set_for_sender(&sender_key, vec![session.clone()]).await;

        let sessions = store.get(&sender_key).await.unwrap();
        let sessions = sessions.lock().await;

        assert_eq!(&session, &sessions[0]);
    }

    /// Devices can be added, looked up (alone or per user) and removed.
    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = DeviceStore::new();

        // The first insertion is new, the second one replaces the device.
        assert!(store.add(device.clone()));
        assert!(!store.add(device.clone()));

        let loaded_device = store.get(device.user_id(), device.device_id()).unwrap();
        assert_eq!(device, loaded_device);

        let user_devices = store.user_devices(device.user_id());

        let first_device_id = user_devices.keys().next().unwrap();
        assert_eq!(&**first_device_id, device.device_id());

        let first_device = user_devices.values().next().unwrap();
        assert_eq!(first_device, &device);

        let loaded_device = user_devices.get(device.device_id()).unwrap();
        assert_eq!(&device, loaded_device);

        store.remove(device.user_id(), device.device_id());

        assert!(store.get(device.user_id(), device.device_id()).is_none());
    }

    /// The ordering must hold across the wrap from `i64::MAX` to `i64::MIN`.
    #[test]
    fn sequence_at_boundary() {
        assert_ordered_with_neighbours(SequenceNumber(i64::MAX));
    }

    proptest! {
        #[test]
        fn partial_eq_sequence_number(sequence in i64::MIN..i64::MAX) {
            assert_ordered_with_neighbours(SequenceNumber(sequence));
        }
    }
}