1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
1415//! A TTL cache which can be used to time out repeated operations that might
16//! experience intermittent failures.
1718use std::{borrow::Borrow, collections::HashMap, hash::Hash, sync::Arc, time::Duration};
1920use ruma::time::Instant;
2122use super::locks::RwLock;
/// Default upper bound, in seconds, for the backoff delay of an item
/// (15 minutes).
const MAX_DELAY: u64 = 900;
/// Default base delay, in seconds, used for an item's first failure; it is
/// doubled for every further failure.
const MULTIPLIER: u64 = 15;
2627/// A TTL cache where items get inactive instead of discarded.
28///
29/// The items need to be explicitly removed from the cache. This allows us to
30/// implement exponential backoff based TTL.
31#[derive(Clone, Debug)]
32pub struct FailuresCache<T: Eq + Hash> {
33 inner: Arc<InnerCache<T>>,
34}
3536#[derive(Debug)]
37struct InnerCache<T: Eq + Hash> {
38 max_delay: Duration,
39 backoff_multiplier: u64,
40 items: RwLock<HashMap<T, FailuresItem>>,
41}
4243impl<T: Eq + Hash> Default for InnerCache<T> {
44fn default() -> Self {
45Self {
46 max_delay: Duration::from_secs(MAX_DELAY),
47 backoff_multiplier: MULTIPLIER,
48 items: Default::default(),
49 }
50 }
51}
/// Backoff bookkeeping for a single entry of a `FailuresCache`.
#[derive(Debug, Clone, Copy)]
struct FailuresItem {
    /// The moment this entry was last (re-)inserted into the cache.
    insertion_time: Instant,
    /// How long after `insertion_time` the entry is considered active.
    duration: Duration,

    /// Number of times that this item has failed after it was first added to
    /// the cache. (In other words, one less than the total number of
    /// failures.)
    failure_count: u8,
}

impl FailuresItem {
    /// Whether the TTL of this entry has run out, i.e. at least `duration`
    /// has passed since `insertion_time`.
    fn expired(&self) -> bool {
        let age = self.insertion_time.elapsed();
        age >= self.duration
    }

    /// Force the expiry of this item.
    ///
    /// The failure count is left as it is, but the item is immediately
    /// considered expired and therefore ready for a retry.
    fn expire(&mut self) {
        self.duration = Duration::ZERO;
    }
}
7879impl<T> FailuresCache<T>
80where
81T: Eq + Hash,
82{
83pub fn new() -> Self {
84Self { inner: Default::default() }
85 }
8687pub fn with_settings(max_delay: Duration, multiplier: u8) -> Self {
88Self {
89 inner: InnerCache {
90 max_delay,
91 backoff_multiplier: multiplier.into(),
92 items: Default::default(),
93 }
94 .into(),
95 }
96 }
9798/// Is the given key non-expired and part of the cache.
99pub fn contains<Q>(&self, key: &Q) -> bool
100where
101T: Borrow<Q>,
102 Q: Hash + Eq + ?Sized,
103 {
104let lock = self.inner.items.read();
105106let contains = if let Some(item) = lock.get(key) { !item.expired() } else { false };
107108 contains
109 }
110111/// Get the failure count for a given key.
112 ///
113 /// # Returns
114 ///
115 /// * `None` if this key is not in the failure cache. (It has never failed,
116 /// or it has been [`FailuresCache::remove()`]d since the last failure.)
117 ///
118 /// * `Some(u8)`: the number of times it has failed since it was first
119 /// added to the failure cache. (In other words, one less than the total
120 /// number of failures.)
121pub fn failure_count<Q>(&self, key: &Q) -> Option<u8>
122where
123T: Borrow<Q>,
124 Q: Hash + Eq + ?Sized,
125 {
126let lock = self.inner.items.read();
127 lock.get(key).map(|i| i.failure_count)
128 }
129130/// This will calculate a duration that determines how long an item is
131 /// considered to be valid while being in the cache.
132 ///
133 /// The returned duration will follow this sequence if the default
134 /// multiplier and `max_delay` values are used, values are in minutes:
135 /// [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 15.0]
136fn calculate_delay(&self, failure_count: u8) -> Duration {
137let exponential_backoff = 2u64.saturating_pow(failure_count.into());
138let delay = exponential_backoff.saturating_mul(self.inner.backoff_multiplier);
139140 Duration::from_secs(delay).clamp(Duration::from_secs(1), self.inner.max_delay)
141 }
142143/// Add a single item to the cache.
144pub fn insert(&self, item: T) {
145self.extend([item]);
146 }
147148/// Extend the cache with the given iterator of items.
149 ///
150 /// Items that are already part of the cache, whether they are expired or
151 /// not, will have their TTL extended using an exponential backoff
152 /// algorithm.
153pub fn extend(&self, iterator: impl IntoIterator<Item = T>) {
154let mut lock = self.inner.items.write();
155156let now = Instant::now();
157158for key in iterator {
159let failure_count = if let Some(value) = lock.get(&key) {
160 value.failure_count.saturating_add(1)
161 } else {
1620
163};
164165let delay = self.calculate_delay(failure_count);
166167let item = FailuresItem { insertion_time: now, duration: delay, failure_count };
168169 lock.insert(key, item);
170 }
171 }
172173/// Remove the items contained in the iterator from the cache.
174pub fn remove<'a, I, Q>(&'a self, iterator: I)
175where
176I: Iterator<Item = &'a Q>,
177 T: Borrow<Q>,
178 Q: Hash + Eq + 'a + ?Sized,
179 {
180let mut lock = self.inner.items.write();
181182for item in iterator {
183 lock.remove(item);
184 }
185 }
186187/// Force the expiry of the given item, if it is present in the cache.
188 ///
189 /// This doesn't reset the failure count, but does mark the item as ready
190 /// for immediate retry.
191#[doc(hidden)]
192pub fn expire(&self, item: &T) {
193let mut lock = self.inner.items.write();
194 lock.get_mut(item).map(FailuresItem::expire);
195 }
196}
197198impl<T: Eq + Hash> Default for FailuresCache<T> {
199fn default() -> Self {
200Self::new()
201 }
202}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use proptest::prelude::*;

    use super::FailuresCache;

    #[test]
    fn failures_cache() {
        let cache = FailuresCache::new();

        // A fresh cache knows nothing about the key.
        assert!(!cache.contains(&1));

        // Once inserted, the key is active.
        cache.extend([1u8].iter());
        assert!(cache.contains(&1));

        // Zeroing the TTL makes the key inactive. The write guard is dropped
        // at the end of this scope so that `contains()` can take the read lock.
        {
            let mut items = cache.inner.items.write();
            items.get_mut(&1).unwrap().duration = Duration::ZERO;
        }
        assert!(!cache.contains(&1));

        // Removing the key drops it from the underlying map entirely.
        cache.remove([1u8].iter());
        assert!(!cache.inner.items.read().contains_key(&1));
    }

    #[test]
    fn failures_cache_timeout() {
        let cache: FailuresCache<u8> = FailuresCache::new();

        // Expected delays, in seconds, for failure counts 0, 1, 2, ...
        let expected_secs = [15, 30, 60, 120, 240, 480, 900, 900];

        for (failure_count, expected) in (0u8..).zip(expected_secs) {
            assert_eq!(cache.calculate_delay(failure_count).as_secs(), expected);
        }
    }

    proptest! {
        #[test]
        fn failures_cache_proptest_timeout(count in 0..10u8) {
            let delay = FailuresCache::<u8>::new().calculate_delay(count).as_secs();

            assert!((15..=900).contains(&delay));
        }
    }
}