matrix_sdk_common/
executor.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Abstraction over an executor so that tasks can be spawned on Wasm the same
//! way they are spawned on other platforms.
//!
//! On non-Wasm platforms, this re-exports parts of tokio directly. On Wasm, it
//! provides a single-threaded implementation that mirrors the parts of tokio's
//! interface used here, so it can serve as a drop-in replacement.
21
22use std::{
23    future::Future,
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28#[cfg(not(target_family = "wasm"))]
29mod sys {
30    pub use tokio::{
31        runtime::{Handle, Runtime},
32        task::{AbortHandle, JoinError, JoinHandle, spawn},
33    };
34}
35
#[cfg(target_family = "wasm")]
mod sys {
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        task::{Context, Poll},
    };

    pub use futures_util::future::AbortHandle;
    use futures_util::{
        FutureExt,
        future::{Abortable, RemoteHandle},
    };

    /// A Wasm specific version of `tokio::task::JoinError` designed to work
    /// in the single-threaded environment available in Wasm environments.
    #[derive(Debug)]
    pub enum JoinError {
        /// The task was aborted, via [`JoinHandle::abort`] or its
        /// [`AbortHandle`], before its result was retrieved.
        Cancelled,
        /// The task's result could not be retrieved for a reason other than
        /// cancellation (mirrors the panic case of `tokio::task::JoinError`).
        Panic,
    }

    impl JoinError {
        /// Returns true if the error was caused by the task being cancelled,
        /// i.e. the task was aborted through [`JoinHandle::abort`] or through
        /// its [`AbortHandle`].
        pub fn is_cancelled(&self) -> bool {
            matches!(self, JoinError::Cancelled)
        }

        /// Returns true if the error was not caused by cancellation.
        ///
        /// Provided for parity with `tokio::task::JoinError::is_panic`.
        pub fn is_panic(&self) -> bool {
            matches!(self, JoinError::Panic)
        }
    }

    impl std::fmt::Display for JoinError {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                JoinError::Cancelled => write!(fmt, "task was cancelled"),
                JoinError::Panic => write!(fmt, "task panicked"),
            }
        }
    }

    // `tokio::task::JoinError` implements `Error`; do the same so code using
    // `?`/error wrappers compiles on both platforms.
    impl std::error::Error for JoinError {}

    /// A Wasm specific version of `tokio::task::JoinHandle` that
    /// holds handles to locally executing futures.
    ///
    /// Dropping the handle detaches the task; it keeps running.
    #[derive(Debug)]
    pub struct JoinHandle<T> {
        /// Receives the task's output. Only taken (set to `None`) in `Drop`.
        remote_handle: Option<RemoteHandle<T>>,
        /// Used to abort the spawned task.
        abort_handle: AbortHandle,
        /// Set by the spawned task once its future has resolved, whether it
        /// ran to completion or observed an abort.
        finished: Arc<AtomicBool>,
    }

    impl<T> JoinHandle<T> {
        /// Aborts the spawned future, preventing it from being polled again.
        pub fn abort(&self) {
            self.abort_handle.abort();
        }

        /// Returns a clone of the [`AbortHandle`] that can be used to abort
        /// the spawned future.
        pub fn abort_handle(&self) -> AbortHandle {
            self.abort_handle.clone()
        }

        /// Returns true if the spawned task has stopped running, either
        /// because it ran to completion or because [`abort`](Self::abort) has
        /// been called.
        ///
        /// This does not wait for the task; it only reports its current state.
        pub fn is_finished(&self) -> bool {
            self.abort_handle.is_aborted() || self.finished.load(Ordering::SeqCst)
        }
    }

    impl<T> Drop for JoinHandle<T> {
        fn drop(&mut self) {
            // Don't abort the spawned future: forgetting the remote handle
            // lets the task keep running after the handle is dropped.
            if let Some(h) = self.remote_handle.take() {
                h.forget();
            }
        }
    }

    impl<T: 'static> Future for JoinHandle<T> {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.abort_handle.is_aborted() {
                // The future has been aborted. It is not possible to poll it again.
                Poll::Ready(Err(JoinError::Cancelled))
            } else if let Some(handle) = self.remote_handle.as_mut() {
                Pin::new(handle).poll(cx).map(Ok)
            } else {
                // `remote_handle` is only taken in `Drop`, so this branch is not
                // expected to be reached while the handle is still being polled.
                Poll::Ready(Err(JoinError::Panic))
            }
        }
    }

    /// A Wasm specific version of `tokio::task::spawn` that utilizes
    /// wasm_bindgen_futures to spawn futures on the local executor.
    pub fn spawn<F, T>(future: F) -> JoinHandle<T>
    where
        F: Future<Output = T> + 'static,
    {
        let (future, remote_handle) = future.remote_handle();
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let future = Abortable::new(future, abort_registration);

        let finished = Arc::new(AtomicBool::new(false));
        let task_finished = Arc::clone(&finished);

        wasm_bindgen_futures::spawn_local(async move {
            // Poll the future, and ignore the result (either it's `Ok(())`, or it's
            // `Err(Aborted)`).
            let _ = future.await;
            // Either way the task has stopped running; record it so that
            // `JoinHandle::is_finished` can report completion, not just aborts.
            task_finished.store(true, Ordering::SeqCst);
        });

        JoinHandle { remote_handle: Some(remote_handle), abort_handle, finished }
    }
}
147
148pub use sys::*;
149
150/// A type ensuring a task is aborted on drop.
151#[derive(Debug)]
152pub struct AbortOnDrop<T>(JoinHandle<T>);
153
154impl<T> AbortOnDrop<T> {
155    pub fn new(join_handle: JoinHandle<T>) -> Self {
156        Self(join_handle)
157    }
158}
159
160impl<T> Drop for AbortOnDrop<T> {
161    fn drop(&mut self) {
162        self.0.abort();
163    }
164}
165
166impl<T: 'static> Future for AbortOnDrop<T> {
167    type Output = Result<T, JoinError>;
168
169    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
170        Pin::new(&mut self.0).poll(context)
171    }
172}
173
174/// Trait to create an [`AbortOnDrop`] from a [`JoinHandle`].
175pub trait JoinHandleExt<T> {
176    fn abort_on_drop(self) -> AbortOnDrop<T>;
177}
178
179impl<T> JoinHandleExt<T> for JoinHandle<T> {
180    fn abort_on_drop(self) -> AbortOnDrop<T> {
181        AbortOnDrop::new(self)
182    }
183}
184
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test_macros::async_test;

    use super::{JoinHandleExt, spawn};

    #[async_test]
    async fn test_spawn() {
        let future = async { 42 };
        let join_handle = spawn(future);

        assert_matches!(join_handle.await, Ok(42));
    }

    #[async_test]
    async fn test_abort() {
        let future = async { 42 };
        let join_handle = spawn(future);

        join_handle.abort();

        // Aborting must surface specifically as a cancellation, not as any
        // other kind of join error.
        assert_matches!(join_handle.await, Err(error) if error.is_cancelled());
    }

    #[async_test]
    async fn test_abort_on_drop_can_be_awaited() {
        // Wrapping a handle must not prevent retrieving the task's output.
        let task = spawn(async { 42 }).abort_on_drop();

        assert_matches!(task.await, Ok(42));
    }
}