matrix_sdk_common/
failures_cache.rs1use std::{borrow::Borrow, collections::HashMap, hash::Hash, sync::Arc, time::Duration};
19
20use ruma::time::Instant;
21
22use super::locks::RwLock;
23
/// Default upper bound, in seconds, for how long a failed key stays cached (15 minutes).
const MAX_DELAY: u64 = 60 * 15;
/// Default back-off multiplier, in seconds: a key's first failure keeps it cached this long.
const MULTIPLIER: u64 = 15;
26
27#[derive(Clone, Debug)]
32pub struct FailuresCache<T: Eq + Hash> {
33 inner: Arc<InnerCache<T>>,
34}
35
36#[derive(Debug)]
37struct InnerCache<T: Eq + Hash> {
38 max_delay: Duration,
39 backoff_multiplier: u64,
40 items: RwLock<HashMap<T, FailuresItem>>,
41}
42
43impl<T: Eq + Hash> Default for InnerCache<T> {
44 fn default() -> Self {
45 Self {
46 max_delay: Duration::from_secs(MAX_DELAY),
47 backoff_multiplier: MULTIPLIER,
48 items: Default::default(),
49 }
50 }
51}
52
/// Bookkeeping stored for each key in the cache.
#[derive(Clone, Copy, Debug)]
struct FailuresItem {
    /// When the key was last inserted; reset on every insertion.
    insertion_time: Instant,
    /// How long after `insertion_time` the entry still counts as present.
    duration: Duration,

    /// Number of insertions after the first one (0 for a newly seen key).
    failure_count: u8,
}
63
64impl FailuresItem {
65 fn expired(&self) -> bool {
67 self.insertion_time.elapsed() >= self.duration
68 }
69
70 fn expire(&mut self) {
75 self.duration = Duration::from_secs(0);
76 }
77}
78
79impl<T> FailuresCache<T>
80where
81 T: Eq + Hash,
82{
83 pub fn new() -> Self {
84 Self { inner: Default::default() }
85 }
86
87 pub fn with_settings(max_delay: Duration, multiplier: u8) -> Self {
88 Self {
89 inner: InnerCache {
90 max_delay,
91 backoff_multiplier: multiplier.into(),
92 items: Default::default(),
93 }
94 .into(),
95 }
96 }
97
98 pub fn contains<Q>(&self, key: &Q) -> bool
100 where
101 T: Borrow<Q>,
102 Q: Hash + Eq + ?Sized,
103 {
104 self.inner.items.read().get(key).is_some_and(|item| !item.expired())
105 }
106
107 pub fn failure_count<Q>(&self, key: &Q) -> Option<u8>
118 where
119 T: Borrow<Q>,
120 Q: Hash + Eq + ?Sized,
121 {
122 self.inner.items.read().get(key).map(|i| i.failure_count)
123 }
124
125 fn calculate_delay(&self, failure_count: u8) -> Duration {
132 let exponential_backoff = 2u64.saturating_pow(failure_count.into());
133 let delay = exponential_backoff.saturating_mul(self.inner.backoff_multiplier);
134
135 Duration::from_secs(delay).clamp(Duration::from_secs(1), self.inner.max_delay)
136 }
137
138 pub fn insert(&self, item: T) {
140 self.extend([item]);
141 }
142
143 pub fn extend(&self, iterator: impl IntoIterator<Item = T>) {
149 let mut lock = self.inner.items.write();
150
151 let now = Instant::now();
152
153 for key in iterator {
154 let failure_count = if let Some(value) = lock.get(&key) {
155 value.failure_count.saturating_add(1)
156 } else {
157 0
158 };
159
160 let delay = self.calculate_delay(failure_count);
161
162 let item = FailuresItem { insertion_time: now, duration: delay, failure_count };
163
164 lock.insert(key, item);
165 }
166 }
167
168 pub fn remove<'a, I, Q>(&'a self, iterator: I)
170 where
171 I: Iterator<Item = &'a Q>,
172 T: Borrow<Q>,
173 Q: Hash + Eq + 'a + ?Sized,
174 {
175 let mut lock = self.inner.items.write();
176
177 for item in iterator {
178 lock.remove(item);
179 }
180 }
181
182 #[doc(hidden)]
187 pub fn expire(&self, item: &T) {
188 self.inner.items.write().get_mut(item).map(FailuresItem::expire);
189 }
190}
191
192impl<T: Eq + Hash> Default for FailuresCache<T> {
193 fn default() -> Self {
194 Self::new()
195 }
196}
197
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use proptest::prelude::*;

    use super::FailuresCache;

    #[test]
    fn failures_cache() {
        let cache = FailuresCache::new();

        // Unknown keys are absent; a freshly inserted key is present.
        assert!(!cache.contains(&1));
        cache.extend([1u8].iter());
        assert!(cache.contains(&1));

        // Zero the entry's lifetime so it counts as expired. The write guard
        // is dropped at the end of this scope, before `contains` reads again.
        {
            let mut items = cache.inner.items.write();
            items.get_mut(&1).unwrap().duration = Duration::ZERO;
        }
        assert!(!cache.contains(&1));

        // Removing the key drops the stored entry entirely.
        cache.remove([1u8].iter());
        let still_stored = cache.inner.items.read().get(&1).is_some();
        assert!(!still_stored);
    }

    #[test]
    fn failures_cache_timeout() {
        let cache: FailuresCache<u8> = FailuresCache::new();

        // (failure count, expected delay in seconds); the last two hit the 15 minute cap.
        let expectations: [(u8, u64); 8] =
            [(0, 15), (1, 30), (2, 60), (3, 120), (4, 240), (5, 480), (6, 900), (7, 900)];

        for &(failure_count, seconds) in &expectations {
            assert_eq!(cache.calculate_delay(failure_count).as_secs(), seconds);
        }
    }

    proptest! {
        #[test]
        fn failures_cache_proptest_timeout(count in 0..10u8) {
            let cache: FailuresCache<u8> = FailuresCache::new();
            let delay = cache.calculate_delay(count).as_secs();

            // Whatever the failure count, the delay stays within the default bounds.
            assert!((15..=900).contains(&delay));
        }
    }
}