Skip to main content

matrix_sdk_common/
cross_process_lock.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A cross-process lock implementation.
16//!
17//! This is a per-process lock that may be used only for very specific use
18//! cases, where multiple processes might concurrently write to the same
19//! database at the same time; this would invalidate store caches, so
20//! that should be done mindfully. Such a lock can be obtained multiple times by
21//! the same process, and it remains active as long as there's at least one user
22//! in a given process.
23//!
24//! The lock is implemented using time-based leases. The lock maintains the lock
25//! identifier (key), who's the current holder (value), and an expiration
26//! timestamp on the side; see also `CryptoStore::try_take_leased_lock` for more
27//! details.
28//!
29//! The lock is initially obtained for a certain period of time (namely, the
30//! duration of a lease, aka `LEASE_DURATION_MS`), and then a “heartbeat” task
31//! renews the lease to extend its duration, every so often (namely, every
32//! `EXTEND_LEASE_EVERY_MS`). Since the Tokio scheduler might be busy, the
33//! extension request should happen way more frequently than the duration of a
34//! lease, in case a deadline is missed. The current values have been chosen to
35//! reflect that, with a ratio of 1:10 as of 2023-06-23.
36//!
37//! Releasing the lock happens naturally, by not renewing a lease. It happens
38//! automatically after the duration of the last lease, at most.
39
40use std::{
41    error::Error,
42    future::Future,
43    sync::{
44        Arc,
45        atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
46    },
47    time::Duration,
48};
49
50use tokio::sync::Mutex;
51use tracing::{debug, error, instrument, trace, warn};
52
53use crate::{
54    SendOutsideWasm,
55    executor::{JoinHandle, spawn},
56    sleep::sleep,
57};
58
/// A lock generation is an integer incremented each time the lock is taken by
/// a different holder.
///
/// This is used to know if a lock has been dirtied: observing a generation
/// that differs from the one recorded at the previous successful locking
/// means another holder took the lock in the meantime.
pub type CrossProcessLockGeneration = u64;
64
/// Trait used to try to take a lock. Foundation of [`CrossProcessLock`].
pub trait TryLock {
    /// The error type returned when the underlying store fails while trying
    /// to take the lock (as opposed to the lock merely being busy).
    #[cfg(not(target_family = "wasm"))]
    type LockError: Error + Send + Sync;

    /// Same associated type, but the `Send + Sync` bounds are not imposed on
    /// Wasm targets.
    #[cfg(target_family = "wasm")]
    type LockError: Error;

    /// Try to take a leased lock.
    ///
    /// This attempts to take a lock for the given lease duration.
    ///
    /// - If we already had the lease, this will extend the lease.
    /// - If we didn't, but the previous lease has expired, we will obtain the
    ///   lock.
    /// - If there was no previous lease, we will obtain the lock.
    /// - Otherwise, we don't get the lock.
    ///
    /// Returns `Some(_)` to indicate the lock succeeded, `None` otherwise. The
    /// cross-process lock generation must be compared to the generation before
    /// the call to see if the lock has been dirtied: a different generation
    /// means the lock has been dirtied, i.e. taken by a different holder in
    /// the meantime.
    fn try_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> impl Future<Output = Result<Option<CrossProcessLockGeneration>, Self::LockError>>
    + SendOutsideWasm;
}
96
/// Small state machine to handle wait times.
///
/// Used by [`CrossProcessLock::spin_lock`] to implement exponential backoff:
/// the wait time is doubled after each failed attempt, and the state switches
/// to [`WaitingTime::Stop`] once the configured maximum is reached.
#[derive(Clone, Debug)]
enum WaitingTime {
    /// Some time to wait, in milliseconds.
    Some(u32),
    /// Stop waiting when seeing this value.
    Stop,
}
105
/// A guard of a cross-process lock.
///
/// The lock will be automatically released a short period of time after all the
/// guards have dropped.
///
/// Dropping a guard decrements the shared holder count; the background renew
/// task stops extending the lease once that count reaches zero.
#[derive(Clone, Debug)]
#[must_use = "If unused, the `CrossProcessLock` will unlock at the end of the lease"]
pub struct CrossProcessLockGuard {
    /// A clone of [`CrossProcessLock::num_holders`].
    num_holders: Arc<AtomicU32>,

    /// A clone of [`CrossProcessLock::is_dirty`].
    is_dirty: Arc<AtomicBool>,
}
119
impl CrossProcessLockGuard {
    /// Create a guard from clones of the lock's shared atomics.
    ///
    /// Note: the caller is responsible for incrementing `num_holders`; the
    /// guard only decrements it on drop.
    fn new(num_holders: Arc<AtomicU32>, is_dirty: Arc<AtomicBool>) -> Self {
        Self { num_holders, is_dirty }
    }

    /// Determine whether the cross-process lock associated to this guard is
    /// dirty.
    ///
    /// See [`CrossProcessLockState::Dirty`] to learn more about the semantics
    /// of _dirty_.
    pub fn is_dirty(&self) -> bool {
        self.is_dirty.load(Ordering::SeqCst)
    }

    /// Clear the dirty state from the cross-process lock associated to this
    /// guard.
    ///
    /// If the cross-process lock is dirtied, it will remain dirtied until
    /// this method is called. This allows recovering from a dirty state and
    /// marking that it has recovered.
    pub fn clear_dirty(&self) {
        self.is_dirty.store(false, Ordering::SeqCst);
    }
}
144
impl Drop for CrossProcessLockGuard {
    fn drop(&mut self) {
        // One holder fewer; the renew task polls this counter and stops
        // extending the lease once it observes 0.
        self.num_holders.fetch_sub(1, Ordering::SeqCst);
    }
}
150
/// A cross-process lock implementation.
///
/// See the doc-comment of this module for more information.
///
/// Clones share the holder count, generation, dirty flag, backoff state and
/// renew task (all behind `Arc`s), so a clone observes the same lock state.
#[derive(Clone, Debug)]
pub struct CrossProcessLock<L>
where
    L: TryLock + Clone + SendOutsideWasm + 'static,
{
    /// The locker implementation.
    ///
    /// `L` is responsible for trying to take the lock, while
    /// [`CrossProcessLock`] is responsible to make it cross-process, with the
    /// retry mechanism, plus guard and so on.
    locker: L,

    /// Number of holders of the lock in this process.
    ///
    /// If greater than 0, this means we've already obtained this lock, in this
    /// process, and the store lock mustn't be touched.
    ///
    /// When the number of holders is decreased to 0, then the lock must be
    /// released in the store.
    num_holders: Arc<AtomicU32>,

    /// A mutex to control an attempt to take the lock, to avoid making it
    /// reentrant.
    locking_attempt: Arc<Mutex<()>>,

    /// Current renew task spawned by `try_lock_once`.
    renew_task: Arc<Mutex<Option<JoinHandle<()>>>>,

    /// The key used in the key/value mapping for the lock entry.
    lock_key: String,

    /// The cross-process lock configuration.
    config: CrossProcessLockConfig,

    /// Backoff time, in milliseconds.
    ///
    /// Doubled on every failed `spin_lock` attempt, reset to
    /// [`INITIAL_BACKOFF_MS`] on success.
    backoff: Arc<Mutex<WaitingTime>>,

    /// This lock generation.
    ///
    /// Updated on every successful locking in `try_lock_once`; see
    /// [`CrossProcessLockGeneration`].
    generation: Arc<AtomicU64>,

    /// Whether the lock has been dirtied.
    ///
    /// See [`CrossProcessLockState::Dirty`] to learn more about the semantics
    /// of _dirty_.
    is_dirty: Arc<AtomicBool>,
}
200
/// Amount of time a lease of the lock should last, in milliseconds.
pub const LEASE_DURATION_MS: u32 = 500;

/// Period of time between two attempts to extend the lease. We'll
/// re-request a lease for an entire duration of `LEASE_DURATION_MS`
/// milliseconds, every `EXTEND_LEASE_EVERY_MS`, so this has to
/// be an amount safely low compared to `LEASE_DURATION_MS`, to make sure
/// that we can miss a deadline without compromising the lock.
pub const EXTEND_LEASE_EVERY_MS: u64 = 50;

/// Initial backoff, in milliseconds. This is the time we wait the first
/// time, if taking the lock initially failed.
const INITIAL_BACKOFF_MS: u32 = 10;

/// Maximal backoff, in milliseconds. This is the maximum amount of time
/// we'll wait for the lock, *between two attempts*.
///
/// This is the default used by [`CrossProcessLock::spin_lock`] when no
/// explicit maximum is provided.
pub const MAX_BACKOFF_MS: u32 = 1000;

/// Sentinel value representing the absence of a lock generation value.
///
/// When the lock is created, it has no generation. Once locked, it receives its
/// first generation from [`TryLock::try_lock`]. Subsequent lockings may
/// generate new lock generation. The generation is incremented by 1 every time.
///
/// The first generation is defined by [`FIRST_CROSS_PROCESS_LOCK_GENERATION`].
pub const NO_CROSS_PROCESS_LOCK_GENERATION: CrossProcessLockGeneration = 0;

/// Describe the first lock generation value (see
/// [`CrossProcessLockGeneration`]).
pub const FIRST_CROSS_PROCESS_LOCK_GENERATION: CrossProcessLockGeneration = 1;
231
232impl<L> CrossProcessLock<L>
233where
234    L: TryLock + Clone + SendOutsideWasm + 'static,
235{
236    /// Create a new cross-process lock.
237    ///
238    /// # Parameters
239    ///
240    /// - `lock_key`: key in the key-value store to store the lock's state.
241    /// - `config`: the cross-process lock configuration to use, if it's
242    ///   [`CrossProcessLockConfig::SingleProcess`], no actual lock will be
243    ///   taken.
244    pub fn new(locker: L, lock_key: String, config: CrossProcessLockConfig) -> Self {
245        Self {
246            locker,
247            lock_key,
248            config,
249            backoff: Arc::new(Mutex::new(WaitingTime::Some(INITIAL_BACKOFF_MS))),
250            num_holders: Arc::new(0.into()),
251            locking_attempt: Arc::new(Mutex::new(())),
252            renew_task: Default::default(),
253            generation: Arc::new(AtomicU64::new(NO_CROSS_PROCESS_LOCK_GENERATION)),
254            is_dirty: Arc::new(AtomicBool::new(false)),
255        }
256    }
257
    /// Determine whether the cross-process lock is dirty.
    ///
    /// See [`CrossProcessLockState::Dirty`] to learn more about the semantics
    /// of _dirty_.
    pub fn is_dirty(&self) -> bool {
        self.is_dirty.load(Ordering::SeqCst)
    }

    /// Clear the dirty state from this cross-process lock.
    ///
    /// If the cross-process lock is dirtied, it will remain dirtied until
    /// this method is called. This allows recovering from a dirty state and
    /// marking that it has recovered.
    ///
    /// The flag is shared with all guards handed out by this lock.
    pub fn clear_dirty(&self) {
        self.is_dirty.store(false, Ordering::SeqCst);
    }
274
275    /// Try to lock once, returns whether the lock was obtained or not.
276    ///
277    /// The lock can be obtained but it can be dirty. In all cases, the renew
278    /// task will run in the background.
279    #[instrument(skip(self), fields(?self.lock_key, ?self.config))]
280    pub async fn try_lock_once(
281        &self,
282    ) -> Result<Result<CrossProcessLockState, CrossProcessLockUnobtained>, L::LockError> {
283        // If it's not `MultiProcess`, this behaves as a no-op
284        let CrossProcessLockConfig::MultiProcess { holder_name } = &self.config else {
285            let guard = CrossProcessLockGuard::new(self.num_holders.clone(), self.is_dirty.clone());
286            return Ok(Ok(CrossProcessLockState::Clean(guard)));
287        };
288
289        // Hold onto the locking attempt mutex for the entire lifetime of this
290        // function, to avoid multiple reentrant calls.
291        let mut _attempt = self.locking_attempt.lock().await;
292
293        // If another thread obtained the lock, make sure to only superficially increase
294        // the number of holders, and carry on.
295        if self.num_holders.load(Ordering::SeqCst) > 0 {
296            // Note: between the above load and the fetch_add below, another thread may
297            // decrement `num_holders`. That's fine because that means the lock
298            // was taken by at least one thread, and after this call it will be
299            // taken by at least one thread.
300            trace!("We already had the lock, incrementing holder count");
301
302            self.num_holders.fetch_add(1, Ordering::SeqCst);
303
304            return Ok(Ok(CrossProcessLockState::Clean(CrossProcessLockGuard::new(
305                self.num_holders.clone(),
306                self.is_dirty.clone(),
307            ))));
308        }
309
310        if let Some(new_generation) =
311            self.locker.try_lock(LEASE_DURATION_MS, &self.lock_key, holder_name).await?
312        {
313            match self.generation.swap(new_generation, Ordering::SeqCst) {
314                // If there was no lock generation, it means this is the first time the lock is
315                // obtained. It cannot be dirty.
316                NO_CROSS_PROCESS_LOCK_GENERATION => {
317                    trace!(?new_generation, "Setting the lock generation for the first time");
318                }
319
320                // This was NOT the same generation, the lock has been dirtied!
321                previous_generation if previous_generation != new_generation => {
322                    warn!(
323                        ?previous_generation,
324                        ?new_generation,
325                        "The lock has been obtained, but it's been dirtied!"
326                    );
327                    self.is_dirty.store(true, Ordering::SeqCst);
328                }
329
330                // This was the same generation, no problem.
331                _ => {
332                    trace!("Same lock generation; no problem");
333                }
334            }
335
336            trace!("Lock obtained!");
337        } else {
338            trace!("Couldn't obtain the lock immediately.");
339            return Ok(Err(CrossProcessLockUnobtained::Busy));
340        }
341
342        trace!("Obtained the lock, spawning the lease extension task.");
343
344        // This is the first time we've obtaind the lock. We're going to spawn the task
345        // that will renew the lease.
346
347        // Clone data to be owned by the task.
348        let this = (*self).clone();
349
350        let mut renew_task = self.renew_task.lock().await;
351
352        // Cancel the previous task, if any. That's safe to do, because:
353        // - either the task was done,
354        // - or it was still running, but taking a lock in the db has to be an atomic
355        //   operation running in a transaction.
356
357        if let Some(_prev) = renew_task.take() {
358            #[cfg(not(target_family = "wasm"))]
359            if !_prev.is_finished() {
360                trace!("aborting the previous renew task");
361                _prev.abort();
362            }
363        }
364
365        // Restart a new one.
366        *renew_task = Some(spawn(async move {
367            let CrossProcessLockConfig::MultiProcess { holder_name } = this.config else { return };
368            loop {
369                {
370                    // First, check if there are still users of this lock.
371                    //
372                    // This is not racy, because:
373                    // - the `locking_attempt` mutex makes sure we don't have unexpected
374                    // interactions with the non-atomic sequence above in `try_lock_once`
375                    // (check > 0, then add 1).
376                    // - other entities holding onto the `num_holders` atomic will only
377                    // decrease it over time.
378
379                    let _guard = this.locking_attempt.lock().await;
380
381                    // If there are no more users, we can quit.
382                    if this.num_holders.load(Ordering::SeqCst) == 0 {
383                        trace!("exiting the lease extension loop");
384
385                        // Cancel the lease with another 0ms lease.
386                        // If we don't get the lock, that's (weird but) fine.
387                        let fut = this.locker.try_lock(0, &this.lock_key, &holder_name);
388                        let _ = fut.await;
389
390                        // Exit the loop.
391                        break;
392                    }
393                }
394
395                sleep(Duration::from_millis(EXTEND_LEASE_EVERY_MS)).await;
396
397                match this.locker.try_lock(LEASE_DURATION_MS, &this.lock_key, &holder_name).await {
398                    Ok(Some(_generation)) => {
399                        // It's impossible that the generation can be
400                        // different from the previous generation.
401                        //
402                        // As long as the task runs, the lock is renewed, so
403                        // the generation remains the same. If the lock is not
404                        // taken, it's because the lease has expired, which
405                        // is represented by the
406                        // `Ok(None)` value, and the task
407                        // must stop.
408                    }
409
410                    Ok(None) => {
411                        error!("Failed to renew the lock lease: the lock could not be obtained");
412
413                        // Exit the loop.
414                        break;
415                    }
416
417                    Err(err) => {
418                        error!("Error when extending the lock lease: {err:#}");
419
420                        // Exit the loop.
421                        break;
422                    }
423                }
424            }
425        }));
426
427        self.num_holders.fetch_add(1, Ordering::SeqCst);
428
429        let guard = CrossProcessLockGuard::new(self.num_holders.clone(), self.is_dirty.clone());
430
431        Ok(Ok(if self.is_dirty() {
432            CrossProcessLockState::Dirty(guard)
433        } else {
434            CrossProcessLockState::Clean(guard)
435        }))
436    }
437
438    /// Attempt to take the lock, with exponential backoff if the lock has
439    /// already been taken before.
440    ///
441    /// The `max_backoff` parameter is the maximum time (in milliseconds) that
442    /// should be waited for, between two attempts. When that time is
443    /// reached a second time, the lock will stop attempting to get the lock
444    /// and will return a timeout error upon locking. If not provided,
445    /// will wait for [`MAX_BACKOFF_MS`].
446    #[instrument(skip(self), fields(?self.lock_key, ?self.config))]
447    pub async fn spin_lock(
448        &self,
449        max_backoff: Option<u32>,
450    ) -> Result<Result<CrossProcessLockState, CrossProcessLockUnobtained>, L::LockError> {
451        // If there is no holder, this behaves as a no-op
452        let max_backoff = max_backoff.unwrap_or(MAX_BACKOFF_MS);
453
454        // Note: reads/writes to the backoff are racy across threads in theory, but the
455        // lock in `try_lock_once` should sequentialize it all.
456
457        loop {
458            // If the cross-process lock config is not `MultiProcess`, this behaves as a
459            // no-op and we just return
460            let lock_result = self.try_lock_once().await?;
461
462            if lock_result.is_ok() {
463                if matches!(self.config, CrossProcessLockConfig::MultiProcess { .. }) {
464                    // Reset backoff before returning, for the next attempt to lock.
465                    *self.backoff.lock().await = WaitingTime::Some(INITIAL_BACKOFF_MS);
466                }
467
468                return Ok(lock_result);
469            }
470
471            // Exponential backoff! Multiply by 2 the time we've waited before, cap it to
472            // max_backoff.
473            let mut backoff = self.backoff.lock().await;
474
475            let wait = match &mut *backoff {
476                WaitingTime::Some(val) => {
477                    let wait = *val;
478                    *val = val.saturating_mul(2);
479                    if *val >= max_backoff {
480                        *backoff = WaitingTime::Stop;
481                    }
482                    wait
483                }
484                WaitingTime::Stop => {
485                    // We've reached the maximum backoff, abandon.
486                    return Ok(Err(CrossProcessLockUnobtained::TimedOut));
487                }
488            };
489
490            debug!("Waiting {wait} before re-attempting to take the lock");
491            sleep(Duration::from_millis(wait.into())).await;
492        }
493    }
494
495    /// Returns the value in the database that represents the holder's
496    /// identifier.
497    pub fn lock_holder(&self) -> Option<&str> {
498        self.config.holder_name()
499    }
500}
501
/// Represent a successful result of a locking attempt, either by
/// [`CrossProcessLock::try_lock_once`] or [`CrossProcessLock::spin_lock`].
#[derive(Debug)]
#[must_use = "If unused, the `CrossProcessLock` will unlock at the end of the lease"]
pub enum CrossProcessLockState {
    /// The lock has been obtained successfully, all good.
    Clean(CrossProcessLockGuard),

    /// The lock has been obtained successfully, but the lock is dirty!
    ///
    /// This holder has obtained this cross-process lock once, then another
    /// holder has obtained this cross-process lock _before_ this holder
    /// obtained it again. The lock is marked as dirty. It means the value
    /// protected by the cross-process lock may need to be reloaded if
    /// synchronisation is important.
    ///
    /// Until [`CrossProcessLock::clear_dirty`] is called,
    /// [`CrossProcessLock::is_dirty`], [`CrossProcessLock::try_lock_once`] and
    /// [`CrossProcessLock::spin_lock`] will report the lock as dirty. Put it
    /// differently: dirty once, dirty forever, unless
    /// [`CrossProcessLock::clear_dirty`] is called.
    Dirty(CrossProcessLockGuard),
}
525
526impl CrossProcessLockState {
527    /// Map this value into the inner [`CrossProcessLockGuard`].
528    pub fn into_guard(self) -> CrossProcessLockGuard {
529        match self {
530            Self::Clean(guard) | Self::Dirty(guard) => guard,
531        }
532    }
533
534    /// Map this [`CrossProcessLockState`] into a
535    /// [`MappedCrossProcessLockState`].
536    ///
537    /// This is helpful when one wants to create its own wrapper over
538    /// [`CrossProcessLockGuard`].
539    pub fn map<F, G>(self, mapper: F) -> MappedCrossProcessLockState<G>
540    where
541        F: FnOnce(CrossProcessLockGuard) -> G,
542    {
543        match self {
544            Self::Clean(guard) => MappedCrossProcessLockState::Clean(mapper(guard)),
545            Self::Dirty(guard) => MappedCrossProcessLockState::Dirty(mapper(guard)),
546        }
547    }
548}
549
/// A mapped [`CrossProcessLockState`].
///
/// Created by [`CrossProcessLockState::map`]: the variant (clean or dirty)
/// is preserved, while the guard is replaced by the mapper's output.
#[derive(Debug)]
#[must_use = "If unused, the `CrossProcessLock` will unlock at the end of the lease"]
pub enum MappedCrossProcessLockState<G> {
    /// The equivalent of [`CrossProcessLockState::Clean`].
    Clean(G),

    /// The equivalent of [`CrossProcessLockState::Dirty`].
    Dirty(G),
}
562
563impl<G> MappedCrossProcessLockState<G> {
564    /// Return `Some(G)` if `Self` is [`Clean`][Self::Clean].
565    pub fn as_clean(&self) -> Option<&G> {
566        match self {
567            Self::Clean(guard) => Some(guard),
568            Self::Dirty(_) => None,
569        }
570    }
571}
572
/// Represent an unsuccessful result of a lock attempt, either by
/// [`CrossProcessLock::try_lock_once`] or [`CrossProcessLock::spin_lock`].
#[derive(Debug, thiserror::Error)]
pub enum CrossProcessLockUnobtained {
    /// The lock couldn't be obtained immediately because it is busy, i.e. it is
    /// held by another holder.
    ///
    /// Returned by [`CrossProcessLock::try_lock_once`].
    #[error(
        "The lock couldn't be obtained immediately because it is busy, i.e. it is held by another holder"
    )]
    Busy,

    /// The lock couldn't be obtained after several attempts: locking has timed
    /// out.
    ///
    /// Returned by [`CrossProcessLock::spin_lock`] once the backoff cap has
    /// been reached.
    #[error("The lock couldn't be obtained after several attempts: locking has timed out")]
    TimedOut,
}
589
/// Union of [`CrossProcessLockUnobtained`] and [`TryLock::LockError`].
#[derive(Debug, thiserror::Error)]
pub enum CrossProcessLockError {
    /// The lock could not be obtained (busy, or timed out).
    #[error(transparent)]
    Unobtained(#[from] CrossProcessLockUnobtained),

    /// The underlying [`TryLock`] implementation failed.
    #[error(transparent)]
    #[cfg(not(target_family = "wasm"))]
    TryLock(#[from] Box<dyn Error + Send + Sync>),

    /// The underlying [`TryLock`] implementation failed (Wasm variant,
    /// without the `Send + Sync` bounds).
    #[error(transparent)]
    #[cfg(target_family = "wasm")]
    TryLock(#[from] Box<dyn Error>),
}
604
/// The cross-process lock config to use for the various stores.
#[derive(Clone, Debug)]
pub enum CrossProcessLockConfig {
    /// The stores will be used in multiple processes, the holder name for the
    /// cross-process lock is the associated `String`.
    MultiProcess {
        /// The name of the holder of the cross-process lock.
        holder_name: String,
    },
    /// The stores will be used in a single process, there is no need for a
    /// cross-process lock.
    ///
    /// With this configuration, locking always succeeds without touching the
    /// store.
    SingleProcess,
}
618
619impl CrossProcessLockConfig {
620    /// Helper for quickly creating a [`CrossProcessLockConfig::MultiProcess`]
621    /// variant.
622    pub fn multi_process(holder_name: impl Into<String>) -> Self {
623        Self::MultiProcess { holder_name: holder_name.into() }
624    }
625
626    /// The holder name for the cross-process lock. This is only relevant for
627    /// [`CrossProcessLockConfig::MultiProcess`] variants.
628    pub fn holder_name(&self) -> Option<&str> {
629        match self {
630            Self::MultiProcess { holder_name } => Some(holder_name),
631            Self::SingleProcess => None,
632        }
633    }
634}
635
636#[cfg(test)]
637#[cfg(not(target_family = "wasm"))] // These tests require tokio::time, which is not implemented on wasm.
638mod tests {
639    use std::{
640        collections::HashMap,
641        ops::Not,
642        sync::{Arc, RwLock, atomic},
643    };
644
645    use assert_matches::assert_matches;
646    use matrix_sdk_test_macros::async_test;
647    use tokio::{spawn, task::yield_now};
648
649    use super::{
650        CrossProcessLock, CrossProcessLockConfig, CrossProcessLockError,
651        CrossProcessLockGeneration, CrossProcessLockState, CrossProcessLockUnobtained, TryLock,
652        memory_store_helper::{Lease, try_take_leased_lock},
653    };
654
    /// In-memory store used as the [`TryLock`] implementation under test.
    ///
    /// Clones share the same lease map, simulating several "processes"
    /// talking to the same backing store.
    #[derive(Clone, Default)]
    struct TestStore {
        /// Lease per lock key, shared between clones of the store.
        leases: Arc<RwLock<HashMap<String, Lease>>>,
    }

    impl TestStore {
        /// Delegate to the `memory_store_helper` lease implementation, using
        /// this store's shared lease map.
        fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Option<CrossProcessLockGeneration> {
            try_take_leased_lock(&mut self.leases.write().unwrap(), lease_duration_ms, key, holder)
        }
    }
670
    /// Error type for [`TestStore`]; it has no variants because the in-memory
    /// store cannot fail.
    #[derive(Debug, thiserror::Error)]
    enum DummyError {}

    impl From<DummyError> for CrossProcessLockError {
        fn from(value: DummyError) -> Self {
            Self::TryLock(Box::new(value))
        }
    }

    impl TryLock for TestStore {
        type LockError = DummyError;

        /// Try to take a lock using the given store.
        async fn try_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::LockError> {
            Ok(self.try_take_leased_lock(lease_duration_ms, key, holder))
        }
    }
693
    /// Drop the given lock state, then yield so that background tasks (e.g.
    /// the lease-renewal task) get a chance to run and observe the holder
    /// count change.
    async fn release_lock(lock: CrossProcessLockState) {
        drop(lock);
        yield_now().await;
    }

    /// Shorthand result type for the tests below.
    type TestResult = Result<(), CrossProcessLockError>;
700
    /// Locking and releasing with a single holder works, both via
    /// `try_lock_once` and `spin_lock`, and never dirties the lock.
    #[async_test]
    async fn test_simple_lock_unlock() -> TestResult {
        let store = TestStore::default();
        let lock = CrossProcessLock::new(
            store,
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("first"),
        );

        // The lock plain works when used with a single holder.
        let guard = lock.try_lock_once().await?.expect("lock must be obtained successfully");
        assert_matches!(guard, CrossProcessLockState::Clean(_));
        assert!(lock.is_dirty().not());
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);

        // Releasing works.
        release_lock(guard).await;
        assert!(lock.is_dirty().not());
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 0);

        // Spin locking on the same lock always works, assuming no concurrent access.
        let guard = lock.spin_lock(None).await?.expect("spin lock must be obtained successfully");
        assert!(lock.is_dirty().not());
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);

        // Releasing still works.
        release_lock(guard).await;
        assert!(lock.is_dirty().not());
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 0);

        Ok(())
    }
733
    /// A fresh `CrossProcessLock` with the same key and holder name can
    /// re-obtain a lease left behind by a dropped lock instance, and the
    /// lock stays clean (same holder, so no dirtying).
    #[async_test]
    async fn test_self_recovery() -> TestResult {
        let store = TestStore::default();
        let lock = CrossProcessLock::new(
            store.clone(),
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("first"),
        );

        // When a lock is obtained…
        let guard = lock.try_lock_once().await?.expect("lock must be obtained successfully");
        assert_matches!(guard, CrossProcessLockState::Clean(_));
        assert!(lock.is_dirty().not());
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);

        // But then forgotten… (note: no need to release the guard)
        drop(lock);

        // And when rematerializing the lock with the same key/value…
        let lock = CrossProcessLock::new(
            store.clone(),
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("first"),
        );

        // We still got it.
        let guard =
            lock.try_lock_once().await?.expect("lock (again) must be obtained successfully");
        assert_matches!(guard, CrossProcessLockState::Clean(_));
        assert!(lock.is_dirty().not());
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);

        Ok(())
    }
768
    /// The lock is reentrant within a process: taking it twice yields two
    /// guards, and each guard drop decrements the holder count by one.
    #[async_test]
    async fn test_multiple_holders_same_process() -> TestResult {
        let store = TestStore::default();
        let lock = CrossProcessLock::new(
            store,
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("first"),
        );

        // Taking the lock twice…
        let guard1 = lock.try_lock_once().await?.expect("lock must be obtained successfully");
        assert_matches!(guard1, CrossProcessLockState::Clean(_));
        let guard2 = lock.try_lock_once().await?.expect("lock must be obtained successfully");
        assert_matches!(guard2, CrossProcessLockState::Clean(_));
        assert!(lock.is_dirty().not());

        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 2);

        // … means we can release it twice.
        release_lock(guard1).await;
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);

        release_lock(guard2).await;
        assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 0);

        assert!(lock.is_dirty().not());

        Ok(())
    }
798
    #[async_test]
    async fn test_multiple_processes() -> TestResult {
        // Two locks over the same store and key, but with distinct holder
        // identifiers, simulate two different processes competing for the same
        // cross-process lock.
        let store = TestStore::default();
        let lock1 = CrossProcessLock::new(
            store.clone(),
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("first"),
        );
        let lock2 = CrossProcessLock::new(
            store,
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("second"),
        );

        // `lock1` acquires the lock.
        let guard1 = lock1.try_lock_once().await?.expect("lock must be obtained successfully");
        assert_matches!(guard1, CrossProcessLockState::Clean(_));
        assert!(lock1.is_dirty().not());

        // `lock2` cannot acquire the lock: `lock1` still holds it.
        let err = lock2.try_lock_once().await?.expect_err("lock must NOT be obtained");
        assert_matches!(err, CrossProcessLockUnobtained::Busy);

        // `lock2` is waiting in a task, spinning with a 500 ms timeout.
        let lock2_clone = lock2.clone();
        let task = spawn(async move { lock2_clone.spin_lock(Some(500)).await });

        // Yield so the spawned task gets a chance to start spinning before we
        // release the guard.
        yield_now().await;

        drop(guard1);

        // Once `lock1` is released, `lock2` managed to obtain it.
        let guard2 = task
            .await
            .expect("join handle is properly awaited")
            .expect("lock is successfully attempted")
            .expect("lock must be obtained successfully");
        assert_matches!(guard2, CrossProcessLockState::Clean(_));

        // `lock1` and `lock2` are both clean!
        assert!(lock1.is_dirty().not());
        assert!(lock2.is_dirty().not());

        // Now if `lock1` tries to obtain the lock with a small timeout, it will fail:
        // `guard2` is still alive, so `lock2` keeps the lease until the spin times out.
        assert_matches!(
            lock1.spin_lock(Some(200)).await,
            Ok(Err(CrossProcessLockUnobtained::TimedOut))
        );

        Ok(())
    }
850
    #[async_test]
    async fn test_multiple_processes_up_to_dirty() -> TestResult {
        // Two locks over the same store and key, with distinct holder
        // identifiers: “dirty” means the lock changed hands since the last
        // time this particular `CrossProcessLock` held it.
        let store = TestStore::default();
        let lock1 = CrossProcessLock::new(
            store.clone(),
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("first"),
        );
        let lock2 = CrossProcessLock::new(
            store,
            "key".to_owned(),
            CrossProcessLockConfig::multi_process("second"),
        );

        // Obtain `lock1` once.
        {
            let guard = lock1.try_lock_once().await?.expect("lock must be obtained successfully");
            assert_matches!(guard, CrossProcessLockState::Clean(_));
            assert!(lock1.is_dirty().not());
            drop(guard);

            // Yield so the release has a chance to be processed before the
            // next acquisition.
            yield_now().await;
        }

        // Obtain `lock2` once.
        {
            let guard = lock2.try_lock_once().await?.expect("lock must be obtained successfully");
            assert_matches!(guard, CrossProcessLockState::Clean(_));
            assert!(lock1.is_dirty().not());
            drop(guard);

            yield_now().await;
        }

        for _ in 0..3 {
            // Obtain `lock1` once more. Now it's dirty because `lock2` has acquired the
            // lock meanwhile.
            {
                let guard =
                    lock1.try_lock_once().await?.expect("lock must be obtained successfully");
                assert_matches!(guard, CrossProcessLockState::Dirty(_));
                assert!(lock1.is_dirty());

                drop(guard);
                yield_now().await;
            }

            // Obtain `lock1` once more! It still dirty because it has not been marked as
            // non-dirty: dirtiness is sticky until `clear_dirty` is called
            // explicitly.
            {
                let guard =
                    lock1.try_lock_once().await?.expect("lock must be obtained successfully");
                assert_matches!(guard, CrossProcessLockState::Dirty(_));
                assert!(lock1.is_dirty());
                lock1.clear_dirty();

                drop(guard);
                yield_now().await;
            }

            // Obtain `lock1` once more. Now it's clear!
            {
                let guard =
                    lock1.try_lock_once().await?.expect("lock must be obtained successfully");
                assert_matches!(guard, CrossProcessLockState::Clean(_));
                assert!(lock1.is_dirty().not());

                drop(guard);
                yield_now().await;
            }

            // Same dance with `lock2`! From its point of view, `lock1` has
            // held the lock in the meantime, so it observes a dirty state too.
            {
                let guard =
                    lock2.try_lock_once().await?.expect("lock must be obtained successfully");
                assert_matches!(guard, CrossProcessLockState::Dirty(_));
                assert!(lock2.is_dirty());
                lock2.clear_dirty();

                drop(guard);
                yield_now().await;
            }
        }

        Ok(())
    }
937}
938
939/// Some code that is shared by almost all `MemoryStore` implementations out
940/// there.
941pub mod memory_store_helper {
942    use std::collections::{HashMap, hash_map::Entry};
943
944    use ruma::time::{Duration, Instant};
945
946    use super::{CrossProcessLockGeneration, FIRST_CROSS_PROCESS_LOCK_GENERATION};
947
948    #[derive(Debug)]
949    pub struct Lease {
950        holder: String,
951        expiration: Instant,
952        generation: CrossProcessLockGeneration,
953    }
954
955    pub fn try_take_leased_lock(
956        leases: &mut HashMap<String, Lease>,
957        lease_duration_ms: u32,
958        key: &str,
959        holder: &str,
960    ) -> Option<CrossProcessLockGeneration> {
961        let now = Instant::now();
962        let expiration = now + Duration::from_millis(lease_duration_ms.into());
963
964        match leases.entry(key.to_owned()) {
965            // There is an existing holder.
966            Entry::Occupied(mut entry) => {
967                let Lease {
968                    holder: current_holder,
969                    expiration: current_expiration,
970                    generation: current_generation,
971                } = entry.get_mut();
972
973                if current_holder == holder {
974                    // We had the lease before, extend it.
975                    *current_expiration = expiration;
976
977                    Some(*current_generation)
978                } else {
979                    // We didn't have it.
980                    if *current_expiration < now {
981                        // Steal it!
982                        *current_holder = holder.to_owned();
983                        *current_expiration = expiration;
984                        *current_generation += 1;
985
986                        Some(*current_generation)
987                    } else {
988                        // We tried our best.
989                        None
990                    }
991                }
992            }
993
994            // There is no holder, easy.
995            Entry::Vacant(entry) => {
996                entry.insert(Lease {
997                    holder: holder.to_owned(),
998                    expiration: Instant::now() + Duration::from_millis(lease_duration_ms.into()),
999                    generation: FIRST_CROSS_PROCESS_LOCK_GENERATION,
1000                });
1001
1002                Some(FIRST_CROSS_PROCESS_LOCK_GENERATION)
1003            }
1004        }
1005    }
1006}