1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
1415//! Facilities to deduplicate similar queries running at the same time.
16//!
17//! See [`DeduplicatingHandler`].
1819use std::{collections::BTreeMap, sync::Arc};
2021use futures_core::Future;
22use matrix_sdk_common::SendOutsideWasm;
23use tokio::sync::Mutex;
2425use crate::{Error, Result};
/// State machine for the state of a query deduplicated by the
/// [`DeduplicatingHandler`].
///
/// Each in-flight query owns one of these behind a mutex; the task holding
/// that mutex runs the query and records its outcome here, and waiters read it
/// once they obtain the mutex.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum QueryState {
    /// The query hasn't completed. This doesn't mean it hasn't *started* yet,
    /// but rather that it couldn't get to completion: some intermediate
    /// steps might have run.
    ///
    /// This is also the initial state of a freshly registered query, so a lock
    /// holder that is cancelled before recording an outcome leaves the query
    /// in this state, and the next holder re-runs it.
    Cancelled,
    /// The query has completed with an `Ok` result.
    Success,
    /// The query has completed with an `Err` result.
    Failure,
}
3940type DeduplicatedRequestMap<Key> = Mutex<BTreeMap<Key, Arc<Mutex<QueryState>>>>;
4142/// Handler that properly deduplicates function calls given a key uniquely
43/// identifying the call kind, and will properly report error upwards in case
44/// the concurrent call failed.
45///
46/// This is handy for deduplicating per-room requests, but can also be used in
47/// other contexts.
48pub(crate) struct DeduplicatingHandler<Key> {
49/// Map of outstanding function calls, grouped by key.
50inflight: DeduplicatedRequestMap<Key>,
51}
5253impl<Key> Default for DeduplicatingHandler<Key> {
54fn default() -> Self {
55Self { inflight: Default::default() }
56 }
57}
5859impl<Key: Clone + Ord + std::hash::Hash> DeduplicatingHandler<Key> {
60/// Runs the given code if and only if there wasn't a similar query running
61 /// for the same key.
62 ///
63 /// Note: the `code` may be run multiple times, if the first query to run it
64 /// has been aborted by the caller (i.e. the future has been dropped).
65 /// As a consequence, it's important that the `code` future be
66 /// idempotent.
67 ///
68 /// See also [`DeduplicatingHandler`] for more details.
69pub async fn run<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
70&self,
71 key: Key,
72 code: F,
73 ) -> Result<()> {
74let mut map = self.inflight.lock().await;
7576if let Some(request_mutex) = map.get(&key).cloned() {
77// If a request is already going on, await the release of the lock.
78drop(map);
7980let mut request_guard = request_mutex.lock().await;
8182return match *request_guard {
83 QueryState::Success => {
84// The query completed with a success: forward this success.
85Ok(())
86 }
8788 QueryState::Failure => {
89// The query completed with an error, but we don't know what it is; report
90 // there was an error.
91Err(Error::ConcurrentRequestFailed)
92 }
9394 QueryState::Cancelled => {
95// If we could take a hold onto the mutex without it being in the success or
96 // failure state, then the query hasn't completed (e.g. it could have been
97 // cancelled). Repeat it.
98 //
99 // Note: there might be other waiters for the deduplicated result; they will
100 // still be waiting for the mutex above, since the mutex is obtained for at
101 // most one holder at the same time.
102self.run_code(key, code, &mut request_guard).await
103}
104 };
105 }
106107// Let's assume the cancelled state, if we succeed or fail we'll modify the
108 // result.
109let request_mutex = Arc::new(Mutex::new(QueryState::Cancelled));
110111 map.insert(key.clone(), request_mutex.clone());
112113let mut request_guard = request_mutex.lock().await;
114 drop(map);
115116self.run_code(key, code, &mut request_guard).await
117}
118119async fn run_code<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
120&self,
121 key: Key,
122 code: F,
123 result: &mut QueryState,
124 ) -> Result<()> {
125match code.await {
126Ok(()) => {
127// Mark the request as completed.
128*result = QueryState::Success;
129130self.inflight.lock().await.remove(&key);
131132Ok(())
133 }
134135Err(err) => {
136// Propagate the error state to other callers.
137*result = QueryState::Failure;
138139// Remove the request from the in-flights set.
140self.inflight.lock().await.remove(&key);
141142// Bubble up the error.
143Err(err)
144 }
145 }
146 }
147}
// Sorry wasm32, you don't have tokio::join :(
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::sync::Arc;

    use matrix_sdk_test::async_test;
    use tokio::{join, spawn, sync::Mutex, task::yield_now};

    use crate::deduplicating_handler::DeduplicatingHandler;

    /// Two concurrent calls with the same key run the code only once, and both
    /// observe the success.
    #[async_test]
    async fn test_deduplicating_handler_same_key() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                // Yield around the increment so the other future in `join!` gets polled
                // while this query is still in flight.
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(0, inner()));

        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(*num_calls.lock().await, 1);

        Ok(())
    }

    /// Concurrent calls with different keys are not deduplicated: each runs its
    /// own code.
    #[async_test]
    async fn test_deduplicating_handler_different_keys() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Ok(())
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(1, inner()));

        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(*num_calls.lock().await, 2);

        Ok(())
    }

    /// A failing query is run once; the waiter also gets an error (the handler
    /// reports `ConcurrentRequestFailed` to waiters), and the key is usable again
    /// afterwards.
    #[async_test]
    async fn test_deduplicating_handler_failure() -> anyhow::Result<()> {
        let num_calls = Arc::new(Mutex::new(0));

        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                yield_now().await;
                *num_calls_cloned.lock().await += 1;
                yield_now().await;
                Err(crate::Error::AuthenticationRequired)
            }
        };

        let handler = DeduplicatingHandler::default();

        let (first, second) = join!(handler.run(0, inner()), handler.run(0, inner()));

        assert!(first.is_err());
        assert!(second.is_err());
        assert_eq!(*num_calls.lock().await, 1);

        // Then we can still do subsequent requests that may succeed (or fail), for the
        // same key.
        let inner = || {
            let num_calls_cloned = num_calls.clone();
            async move {
                *num_calls_cloned.lock().await += 1;
                Ok(())
            }
        };

        *num_calls.lock().await = 0;
        handler.run(0, inner()).await?;
        assert_eq!(*num_calls.lock().await, 1);

        Ok(())
    }

    /// Aborting the task that runs a deduplicated query makes a waiting task
    /// re-run the query from the start.
    ///
    /// NOTE(review): the `yield_now`-based checkpoints below assume `async_test`
    /// drives these tasks on a single-threaded runtime; confirm before changing
    /// the ordering of any statement here.
    #[async_test]
    async fn test_cancelling_deduplicated_query() -> anyhow::Result<()> {
        // A mutex used to prevent progress in the `inner` function.
        let allow_progress = Arc::new(Mutex::new(()));

        // Number of calls up to the `allow_progress` lock taking.
        let num_before = Arc::new(Mutex::new(0));
        // Number of calls after the `allow_progress` lock taking.
        let num_after = Arc::new(Mutex::new(0));

        let inner = || {
            let num_before = num_before.clone();
            let num_after = num_after.clone();
            let allow_progress = allow_progress.clone();

            async move {
                *num_before.lock().await += 1;
                // Binding to `_` drops the guard immediately: this only waits until the
                // lock is free, it doesn't keep it held.
                let _ = allow_progress.lock().await;
                *num_after.lock().await += 1;
                Ok(())
            }
        };

        let handler = Arc::new(DeduplicatingHandler::default());

        // First, take the lock so that the `inner` can't complete.
        let progress_guard = allow_progress.lock().await;

        // Then, spawn deduplicated tasks.
        let first = spawn({
            let handler = handler.clone();
            let query = inner();
            async move { handler.run(0, query).await }
        });

        let second = spawn({
            let handler = handler.clone();
            let query = inner();
            async move { handler.run(0, query).await }
        });

        // At this point, only the "before" count has been incremented, and only once
        // (per the deduplication contract).
        yield_now().await;

        assert_eq!(*num_before.lock().await, 1);
        assert_eq!(*num_after.lock().await, 0);

        // Cancel the first task.
        first.abort();
        assert!(first.await.unwrap_err().is_cancelled());

        // The second task restarts the whole query from the beginning.
        yield_now().await;

        assert_eq!(*num_before.lock().await, 2);
        assert_eq!(*num_after.lock().await, 0);

        // Release the progress lock; the second query can now finish.
        drop(progress_guard);

        assert!(second.await.unwrap().is_ok());

        // We should've reached completion once.
        assert_eq!(*num_before.lock().await, 2);
        assert_eq!(*num_after.lock().await, 1);

        Ok(())
    }
}