matrix_sdk/
deduplicating_handler.rs1use std::{collections::BTreeMap, sync::Arc};
20
21use futures_core::Future;
22use matrix_sdk_common::SendOutsideWasm;
23use tokio::sync::Mutex;
24
25use crate::{Error, Result};
26
/// State of a query deduplicated by the `DeduplicatingHandler`, shared with
/// every caller waiting on the same key.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum QueryState {
    /// No outcome has been recorded yet. This is the initial state of a slot,
    /// and it is also what a waiter observes when the query's future was
    /// dropped before finishing (some of its steps may already have run).
    Cancelled,
    /// The query completed with an `Ok` result.
    Success,
    /// The query completed with an `Err` result.
    Failure,
}
39
40type DeduplicatedRequestMap<Key> = Mutex<BTreeMap<Key, Arc<Mutex<QueryState>>>>;
41
42pub(crate) struct DeduplicatingHandler<Key> {
49 inflight: DeduplicatedRequestMap<Key>,
51}
52
53impl<Key> Default for DeduplicatingHandler<Key> {
54 fn default() -> Self {
55 Self { inflight: Default::default() }
56 }
57}
58
59impl<Key: Clone + Ord + std::hash::Hash> DeduplicatingHandler<Key> {
60 pub async fn run<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
70 &self,
71 key: Key,
72 code: F,
73 ) -> Result<()> {
74 let mut map = self.inflight.lock().await;
75
76 if let Some(request_mutex) = map.get(&key).cloned() {
77 drop(map);
79
80 let mut request_guard = request_mutex.lock().await;
81
82 return match *request_guard {
83 QueryState::Success => {
84 Ok(())
86 }
87
88 QueryState::Failure => {
89 Err(Error::ConcurrentRequestFailed)
92 }
93
94 QueryState::Cancelled => {
95 self.run_code(key, code, &mut request_guard).await
103 }
104 };
105 }
106
107 let request_mutex = Arc::new(Mutex::new(QueryState::Cancelled));
110
111 map.insert(key.clone(), request_mutex.clone());
112
113 let mut request_guard = request_mutex.lock().await;
114 drop(map);
115
116 self.run_code(key, code, &mut request_guard).await
117 }
118
119 async fn run_code<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
120 &self,
121 key: Key,
122 code: F,
123 result: &mut QueryState,
124 ) -> Result<()> {
125 match code.await {
126 Ok(()) => {
127 *result = QueryState::Success;
129
130 self.inflight.lock().await.remove(&key);
131
132 Ok(())
133 }
134
135 Err(err) => {
136 *result = QueryState::Failure;
138
139 self.inflight.lock().await.remove(&key);
141
142 Err(err)
144 }
145 }
146 }
147}
148
// Tests for `DeduplicatingHandler`; not compiled on wasm32 targets.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::sync::Arc;

    use matrix_sdk_test::async_test;
    use tokio::{join, spawn, sync::Mutex, task::yield_now};

    use crate::deduplicating_handler::DeduplicatingHandler;

    /// Two concurrent calls with the same key run the query only once, and
    /// both report success.
    #[async_test]
    async fn test_deduplicating_handler_same_key() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        // Each invocation builds a fresh query that counts how often it really
        // ran; the `yield_now` calls let the other joined future be polled while
        // this one is mid-run.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(0, inner()));

        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(*num_calls.lock().await, 1);

        Ok(())
    }

    /// Concurrent calls with different keys are not deduplicated: both run.
    #[async_test]
    async fn test_deduplicating_handler_different_keys() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(1, inner()));

        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(*num_calls.lock().await, 2);

        Ok(())
    }

    /// A failed query is reported as an error to every deduplicated caller,
    /// and a later call for the same key runs its query again.
    #[async_test]
    async fn test_deduplicating_handler_failure() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        // The specific error variant doesn't matter; any `Err` marks failure.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Err(crate::Error::AuthenticationRequired)
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(0, inner()));

        assert!(first.is_err());
        assert!(second.is_err());
        assert_eq!(*num_calls.lock().await, 1);

        // After the failure the key must be free again: a new, successful query
        // for the same key actually runs.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                *num_calls_cloned.lock().await += 1;
                Ok(())
            }
        };

        *num_calls.lock().await = 0;
        handler.run(0, inner()).await?;
        assert_eq!(*num_calls.lock().await, 1);

        Ok(())
    }

    /// When the caller running the query is aborted, a waiting caller for the
    /// same key runs the query itself and succeeds.
    #[async_test]
    async fn test_cancelling_deduplicated_query() -> anyhow::Result<()> {
        // Held by the test to block queries between their two counters.
        let allow_progress = Arc::new(Mutex::new(()));

        // Number of times a query started.
        let num_before = Arc::new(Mutex::new(0));

        // Number of times a query got past `allow_progress`.
        let num_after = Arc::new(Mutex::new(0));

        let inner = || {
            let num_before = num_before.clone();
            let num_after = num_after.clone();
            let allow_progress = allow_progress.clone();

            async move {
                *num_before.lock().await += 1;
                // `let _ =` drops the guard at the end of the statement: this
                // only waits until the test releases `allow_progress`.
                let _ = allow_progress.lock().await;
                *num_after.lock().await += 1;
                Ok(())
            }
        };

        let handler = Arc::new(DeduplicatingHandler::default());

        let progress_guard = allow_progress.lock().await;

        let first = spawn({
            let handler = handler.clone();
            let query = inner();
            async move { handler.run(0, query).await }
        });

        let second = spawn({
            let handler = handler.clone();
            let query = inner();
            async move { handler.run(0, query).await }
        });

        // Let the spawned tasks make progress. The exact counts asserted below
        // depend on how the test runtime schedules tasks across this yield.
        // NOTE(review): presumably a single-threaded runtime; confirm.
        yield_now().await;

        // Only the first query started; it is blocked on `allow_progress`.
        assert_eq!(*num_before.lock().await, 1);
        assert_eq!(*num_after.lock().await, 0);

        // Cancel the caller that holds the slot for key 0.
        first.abort();

        assert!(first.await.unwrap_err().is_cancelled());

        yield_now().await;

        // The second caller found the slot without a recorded outcome and
        // started the query again; it is also blocked on `allow_progress`.
        assert_eq!(*num_before.lock().await, 2);
        assert_eq!(*num_after.lock().await, 0);

        drop(progress_guard);

        assert!(second.await.unwrap().is_ok());

        // Only the second query ran to completion.
        assert_eq!(*num_before.lock().await, 2);
        assert_eq!(*num_after.lock().await, 1);

        Ok(())
    }
}