matrix_sdk_common/
failures_cache.rs1use std::{borrow::Borrow, collections::HashMap, hash::Hash, sync::Arc, time::Duration};
19
20use ruma::time::Instant;
21
22use super::locks::RwLock;
23
/// Default upper bound, in seconds, on how long a failed item is kept in the cache (15 minutes).
const MAX_DELAY: u64 = 60 * 15;
/// Default base delay, in seconds, for an item's first failure; it doubles on every repeat
/// failure until `MAX_DELAY` is reached.
const MULTIPLIER: u64 = 15;
26
27#[derive(Clone, Debug)]
32pub struct FailuresCache<T: Eq + Hash> {
33 inner: Arc<InnerCache<T>>,
34}
35
36#[derive(Debug)]
37struct InnerCache<T: Eq + Hash> {
38 max_delay: Duration,
39 backoff_multiplier: u64,
40 items: RwLock<HashMap<T, FailuresItem>>,
41}
42
43impl<T: Eq + Hash> Default for InnerCache<T> {
44 fn default() -> Self {
45 Self {
46 max_delay: Duration::from_secs(MAX_DELAY),
47 backoff_multiplier: MULTIPLIER,
48 items: Default::default(),
49 }
50 }
51}
52
/// Bookkeeping for one failed item.
#[derive(Debug, Clone, Copy)]
struct FailuresItem {
    /// When the most recent failure was recorded.
    insertion_time: Instant,
    /// How long after `insertion_time` the item keeps counting as present.
    duration: Duration,
    /// Stored failure counter for the item.
    failure_count: u8,
}

impl FailuresItem {
    /// Returns `true` once at least `duration` has passed since `insertion_time`.
    fn expired(&self) -> bool {
        let age = self.insertion_time.elapsed();
        age >= self.duration
    }

    /// Makes the item count as expired from now on, without touching its failure count.
    fn expire(&mut self) {
        self.duration = Duration::ZERO;
    }
}
78
79impl<T> FailuresCache<T>
80where
81 T: Eq + Hash,
82{
83 pub fn new() -> Self {
84 Self { inner: Default::default() }
85 }
86
87 pub fn with_settings(max_delay: Duration, multiplier: u8) -> Self {
88 Self {
89 inner: InnerCache {
90 max_delay,
91 backoff_multiplier: multiplier.into(),
92 items: Default::default(),
93 }
94 .into(),
95 }
96 }
97
98 pub fn contains<Q>(&self, key: &Q) -> bool
100 where
101 T: Borrow<Q>,
102 Q: Hash + Eq + ?Sized,
103 {
104 let lock = self.inner.items.read();
105
106 let contains = if let Some(item) = lock.get(key) { !item.expired() } else { false };
107
108 contains
109 }
110
111 pub fn failure_count<Q>(&self, key: &Q) -> Option<u8>
122 where
123 T: Borrow<Q>,
124 Q: Hash + Eq + ?Sized,
125 {
126 let lock = self.inner.items.read();
127 lock.get(key).map(|i| i.failure_count)
128 }
129
130 fn calculate_delay(&self, failure_count: u8) -> Duration {
137 let exponential_backoff = 2u64.saturating_pow(failure_count.into());
138 let delay = exponential_backoff.saturating_mul(self.inner.backoff_multiplier);
139
140 Duration::from_secs(delay).clamp(Duration::from_secs(1), self.inner.max_delay)
141 }
142
143 pub fn insert(&self, item: T) {
145 self.extend([item]);
146 }
147
148 pub fn extend(&self, iterator: impl IntoIterator<Item = T>) {
154 let mut lock = self.inner.items.write();
155
156 let now = Instant::now();
157
158 for key in iterator {
159 let failure_count = if let Some(value) = lock.get(&key) {
160 value.failure_count.saturating_add(1)
161 } else {
162 0
163 };
164
165 let delay = self.calculate_delay(failure_count);
166
167 let item = FailuresItem { insertion_time: now, duration: delay, failure_count };
168
169 lock.insert(key, item);
170 }
171 }
172
173 pub fn remove<'a, I, Q>(&'a self, iterator: I)
175 where
176 I: Iterator<Item = &'a Q>,
177 T: Borrow<Q>,
178 Q: Hash + Eq + 'a + ?Sized,
179 {
180 let mut lock = self.inner.items.write();
181
182 for item in iterator {
183 lock.remove(item);
184 }
185 }
186
187 #[doc(hidden)]
192 pub fn expire(&self, item: &T) {
193 let mut lock = self.inner.items.write();
194 lock.get_mut(item).map(FailuresItem::expire);
195 }
196}
197
198impl<T: Eq + Hash> Default for FailuresCache<T> {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use proptest::prelude::*;

    use super::FailuresCache;

    #[test]
    fn failures_cache() {
        let cache = FailuresCache::new();

        // Nothing is reported before a failure has been recorded.
        assert!(!cache.contains(&1));
        cache.extend([1u8].iter());
        assert!(cache.contains(&1));

        // Zero the entry's delay: it stays stored but no longer counts as contained.
        {
            let mut items = cache.inner.items.write();
            let entry = items.get_mut(&1).unwrap();
            entry.duration = Duration::from_secs(0);
        }
        assert!(!cache.contains(&1));

        // Removal drops the entry from the underlying map entirely.
        cache.remove([1u8].iter());
        assert!(cache.inner.items.read().get(&1).is_none())
    }

    #[test]
    fn failures_cache_timeout() {
        let cache: FailuresCache<u8> = FailuresCache::new();

        // The 15 second base delay doubles per failure until it hits the 15 minute cap.
        let expected_secs = [15, 30, 60, 120, 240, 480, 900, 900];
        for (failure_count, &secs) in (0u8..).zip(expected_secs.iter()) {
            assert_eq!(cache.calculate_delay(failure_count).as_secs(), secs);
        }
    }

    proptest! {
        #[test]
        fn failures_cache_proptest_timeout(count in 0..10u8) {
            let cache: FailuresCache<u8> = FailuresCache::new();
            let delay = cache.calculate_delay(count).as_secs();

            assert!((15..=900).contains(&delay));
        }
    }
}