matrix_sdk/
deduplicating_handler.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Facilities to deduplicate similar queries running at the same time.
16//!
17//! See [`DeduplicatingHandler`].
18
19use std::{collections::BTreeMap, sync::Arc};
20
21use futures_core::Future;
22use matrix_sdk_common::SendOutsideWasm;
23use tokio::sync::Mutex;
24
25use crate::{Error, Result};
26
/// Outcome slot shared by every caller that the [`DeduplicatingHandler`] has
/// grouped onto the same query.
///
/// A fresh slot starts as [`QueryState::Cancelled`] and is only overwritten
/// once the query has run to completion.
enum QueryState {
    /// No caller has driven the query to completion (yet). The query may well
    /// have *started* and executed some intermediate steps, but it never
    /// produced a result, e.g. because the caller running it was dropped.
    Cancelled,
    /// The query ran to completion and returned `Ok`.
    Success,
    /// The query ran to completion and returned `Err`.
    Failure,
}
39
40type DeduplicatedRequestMap<Key> = Mutex<BTreeMap<Key, Arc<Mutex<QueryState>>>>;
41
42/// Handler that properly deduplicates function calls given a key uniquely
43/// identifying the call kind, and will properly report error upwards in case
44/// the concurrent call failed.
45///
46/// This is handy for deduplicating per-room requests, but can also be used in
47/// other contexts.
48pub(crate) struct DeduplicatingHandler<Key> {
49    /// Map of outstanding function calls, grouped by key.
50    inflight: DeduplicatedRequestMap<Key>,
51}
52
53impl<Key> Default for DeduplicatingHandler<Key> {
54    fn default() -> Self {
55        Self { inflight: Default::default() }
56    }
57}
58
59impl<Key: Clone + Ord + std::hash::Hash> DeduplicatingHandler<Key> {
60    /// Runs the given code if and only if there wasn't a similar query running
61    /// for the same key.
62    ///
63    /// Note: the `code` may be run multiple times, if the first query to run it
64    /// has been aborted by the caller (i.e. the future has been dropped).
65    /// As a consequence, it's important that the `code` future be
66    /// idempotent.
67    ///
68    /// See also [`DeduplicatingHandler`] for more details.
69    pub async fn run<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
70        &self,
71        key: Key,
72        code: F,
73    ) -> Result<()> {
74        let mut map = self.inflight.lock().await;
75
76        if let Some(request_mutex) = map.get(&key).cloned() {
77            // If a request is already going on, await the release of the lock.
78            drop(map);
79
80            let mut request_guard = request_mutex.lock().await;
81
82            return match *request_guard {
83                QueryState::Success => {
84                    // The query completed with a success: forward this success.
85                    Ok(())
86                }
87
88                QueryState::Failure => {
89                    // The query completed with an error, but we don't know what it is; report
90                    // there was an error.
91                    Err(Error::ConcurrentRequestFailed)
92                }
93
94                QueryState::Cancelled => {
95                    // If we could take a hold onto the mutex without it being in the success or
96                    // failure state, then the query hasn't completed (e.g. it could have been
97                    // cancelled). Repeat it.
98                    //
99                    // Note: there might be other waiters for the deduplicated result; they will
100                    // still be waiting for the mutex above, since the mutex is obtained for at
101                    // most one holder at the same time.
102                    self.run_code(key, code, &mut request_guard).await
103                }
104            };
105        }
106
107        // Let's assume the cancelled state, if we succeed or fail we'll modify the
108        // result.
109        let request_mutex = Arc::new(Mutex::new(QueryState::Cancelled));
110
111        map.insert(key.clone(), request_mutex.clone());
112
113        let mut request_guard = request_mutex.lock().await;
114        drop(map);
115
116        self.run_code(key, code, &mut request_guard).await
117    }
118
119    async fn run_code<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
120        &self,
121        key: Key,
122        code: F,
123        result: &mut QueryState,
124    ) -> Result<()> {
125        match code.await {
126            Ok(()) => {
127                // Mark the request as completed.
128                *result = QueryState::Success;
129
130                self.inflight.lock().await.remove(&key);
131
132                Ok(())
133            }
134
135            Err(err) => {
136                // Propagate the error state to other callers.
137                *result = QueryState::Failure;
138
139                // Remove the request from the in-flights set.
140                self.inflight.lock().await.remove(&key);
141
142                // Bubble up the error.
143                Err(err)
144            }
145        }
146    }
147}
148
// These tests rely on `tokio::join`, which wasm32 doesn't have, so they are
// only built for the other targets.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::sync::Arc;

    use matrix_sdk_test::async_test;
    use tokio::{join, spawn, sync::Mutex, task::yield_now};

    use crate::deduplicating_handler::DeduplicatingHandler;

    /// Two concurrent calls using the same key must run the query only once,
    /// and both must observe the success.
    #[async_test]
    async fn test_deduplicating_handler_same_key() -> anyhow::Result<()> {
        let counter = Arc::new(Mutex::new(0));

        // Builds a query that bumps the counter, yielding before and after so
        // that the two calls below get a chance to interleave.
        let make_query = || {
            let counter = counter.clone();
            async move {
                yield_now().await;
                *counter.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (lhs, rhs) = join!(handler.run(0, make_query()), handler.run(0, make_query()));

        assert!(lhs.is_ok());
        assert!(rhs.is_ok());
        assert_eq!(*counter.lock().await, 1);

        Ok(())
    }

    /// Two concurrent calls using different keys must both run their query.
    #[async_test]
    async fn test_deduplicating_handler_different_keys() -> anyhow::Result<()> {
        let counter = Arc::new(Mutex::new(0));

        let make_query = || {
            let counter = counter.clone();
            async move {
                yield_now().await;
                *counter.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (lhs, rhs) = join!(handler.run(0, make_query()), handler.run(1, make_query()));

        assert!(lhs.is_ok());
        assert!(rhs.is_ok());
        assert_eq!(*counter.lock().await, 2);

        Ok(())
    }

    /// A failing query is run once, the failure is reported to both callers,
    /// and the key can be reused afterwards.
    #[async_test]
    async fn test_deduplicating_handler_failure() -> anyhow::Result<()> {
        let counter = Arc::new(Mutex::new(0));

        let failing_query = || {
            let counter = counter.clone();
            async move {
                yield_now().await;
                *counter.lock().await += 1;
                yield_now().await;
                Err(crate::Error::AuthenticationRequired)
            }
        };

        let handler = DeduplicatingHandler::default();

        let (lhs, rhs) =
            join!(handler.run(0, failing_query()), handler.run(0, failing_query()));

        assert!(lhs.is_err());
        assert!(rhs.is_err());
        assert_eq!(*counter.lock().await, 1);

        // A later request for the same key is not affected by the previous
        // failure: it runs, and may succeed (or fail) on its own.
        let succeeding_query = || {
            let counter = counter.clone();
            async move {
                *counter.lock().await += 1;
                Ok(())
            }
        };

        *counter.lock().await = 0;
        handler.run(0, succeeding_query()).await?;
        assert_eq!(*counter.lock().await, 1);

        Ok(())
    }

    /// When the task running the query is aborted, a waiting task must restart
    /// the query from the beginning and drive it to completion.
    #[async_test]
    async fn test_cancelling_deduplicated_query() -> anyhow::Result<()> {
        // Gate used to keep the query from completing for as long as the test
        // holds it.
        let gate = Arc::new(Mutex::new(()));

        // How many times the query reached the gate.
        let started = Arc::new(Mutex::new(0));
        // How many times the query got past the gate.
        let finished = Arc::new(Mutex::new(0));

        let make_query = || {
            let started = started.clone();
            let finished = finished.clone();
            let gate = gate.clone();

            async move {
                *started.lock().await += 1;
                // Only wait for the gate to open; the guard isn't kept.
                drop(gate.lock().await);
                *finished.lock().await += 1;
                Ok(())
            }
        };

        let handler = Arc::new(DeduplicatingHandler::default());

        // Close the gate first, so that no query can complete.
        let gate_guard = gate.lock().await;

        // Then spawn two deduplicated tasks for the same key.
        let first_task = spawn({
            let handler = handler.clone();
            let query = make_query();
            async move { handler.run(0, query).await }
        });

        let second_task = spawn({
            let handler = handler.clone();
            let query = make_query();
            async move { handler.run(0, query).await }
        });

        // Let the tasks run. Only one of them may have started the query (per
        // the deduplication contract), and it is stuck at the gate.
        yield_now().await;

        assert_eq!(*started.lock().await, 1);
        assert_eq!(*finished.lock().await, 0);

        // Abort the first task.
        first_task.abort();
        assert!(first_task.await.unwrap_err().is_cancelled());

        // The second task takes over and restarts the query from scratch.
        yield_now().await;

        assert_eq!(*started.lock().await, 2);
        assert_eq!(*finished.lock().await, 0);

        // Open the gate: the second task's query can now finish.
        drop(gate_guard);

        assert!(second_task.await.unwrap().is_ok());

        // The query has been started twice, but completed only once.
        assert_eq!(*started.lock().await, 2);
        assert_eq!(*finished.lock().await, 1);

        Ok(())
    }
}
317}