matrix_sdk_common/
executor.rs1use std::{
23 future::Future,
24 pin::Pin,
25 task::{Context, Poll},
26};
27
28#[cfg(not(target_family = "wasm"))]
29mod sys {
30 pub use tokio::{
31 runtime::{Handle, Runtime},
32 task::{AbortHandle, JoinError, JoinHandle, spawn},
33 };
34}
35
/// WebAssembly executor backend built on `wasm_bindgen_futures::spawn_local`
/// and `futures_util`'s abortable / remote-handle futures, exposing the same
/// names as the native (tokio) backend.
#[cfg(target_family = "wasm")]
mod sys {
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        task::{Context, Poll},
    };

    pub use futures_util::future::AbortHandle;
    use futures_util::{
        FutureExt,
        future::{Abortable, RemoteHandle},
    };

    /// Error returned when awaiting a [`JoinHandle`] does not yield the task's
    /// output.
    #[derive(Debug)]
    pub enum JoinError {
        /// The task was aborted.
        Cancelled,
        /// The task's output is unavailable: reported when the handle no
        /// longer holds its remote output handle.
        Panic,
    }

    impl JoinError {
        /// Returns `true` if this error means the task was cancelled.
        pub fn is_cancelled(&self) -> bool {
            matches!(self, JoinError::Cancelled)
        }
    }

    impl std::fmt::Display for JoinError {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                JoinError::Cancelled => write!(fmt, "task was cancelled"),
                JoinError::Panic => write!(fmt, "task panicked"),
            }
        }
    }

    /// Handle to a task started with [`spawn`].
    ///
    /// Awaiting it yields the task's output. Dropping it detaches the task
    /// (the remote handle is forgotten rather than cancelled).
    #[derive(Debug)]
    pub struct JoinHandle<T> {
        /// Receives the task's output; only taken out in `Drop`.
        remote_handle: Option<RemoteHandle<T>>,
        /// Cancels the `Abortable` wrapper around the task's future.
        abort_handle: AbortHandle,
        /// Set by the spawned task once its future has stopped, whether it
        /// ran to completion or was aborted.
        finished: Arc<AtomicBool>,
    }

    impl<T> JoinHandle<T> {
        /// Requests cancellation of the task.
        pub fn abort(&self) {
            self.abort_handle.abort();
        }

        /// Returns a cloned handle that can abort the task independently of
        /// this `JoinHandle`.
        pub fn abort_handle(&self) -> AbortHandle {
            self.abort_handle.clone()
        }

        /// Returns `true` once the task's future has stopped (completed or
        /// aborted), and also as soon as an abort has been requested.
        pub fn is_finished(&self) -> bool {
            // Previously only `is_aborted()` was consulted, so a task that
            // completed normally was never reported as finished.
            self.finished.load(Ordering::SeqCst) || self.abort_handle.is_aborted()
        }
    }

    impl<T> Drop for JoinHandle<T> {
        fn drop(&mut self) {
            if let Some(h) = self.remote_handle.take() {
                // Detach rather than cancel: the task keeps running after the
                // handle is dropped.
                h.forget();
            }
        }
    }

    impl<T: 'static> Future for JoinHandle<T> {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Checked before polling the remote handle: an aborted task drops
            // its `Remote` half without sending a value, and `RemoteHandle`
            // does not produce an output in that case.
            if self.abort_handle.is_aborted() {
                Poll::Ready(Err(JoinError::Cancelled))
            } else if let Some(handle) = self.remote_handle.as_mut() {
                Pin::new(handle).poll(cx).map(Ok)
            } else {
                Poll::Ready(Err(JoinError::Panic))
            }
        }
    }

    /// Spawns `future` on the current thread's wasm executor and returns a
    /// [`JoinHandle`] to await or abort it.
    pub fn spawn<F, T>(future: F) -> JoinHandle<T>
    where
        F: Future<Output = T> + 'static,
    {
        let (future, remote_handle) = future.remote_handle();
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let future = Abortable::new(future, abort_registration);

        let finished = Arc::new(AtomicBool::new(false));
        let task_finished = Arc::clone(&finished);

        wasm_bindgen_futures::spawn_local(async move {
            // `Err(Aborted)` just means the task was cancelled; either way the
            // task is done, which `JoinHandle::is_finished` must observe.
            let _ = future.await;
            task_finished.store(true, Ordering::SeqCst);
        });

        JoinHandle { remote_handle: Some(remote_handle), abort_handle, finished }
    }
}
147
148pub use sys::*;
149
150#[derive(Debug)]
152pub struct AbortOnDrop<T>(JoinHandle<T>);
153
154impl<T> AbortOnDrop<T> {
155 pub fn new(join_handle: JoinHandle<T>) -> Self {
156 Self(join_handle)
157 }
158}
159
160impl<T> Drop for AbortOnDrop<T> {
161 fn drop(&mut self) {
162 self.0.abort();
163 }
164}
165
166impl<T: 'static> Future for AbortOnDrop<T> {
167 type Output = Result<T, JoinError>;
168
169 fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
170 Pin::new(&mut self.0).poll(context)
171 }
172}
173
174pub trait JoinHandleExt<T> {
176 fn abort_on_drop(self) -> AbortOnDrop<T>;
177}
178
179impl<T> JoinHandleExt<T> for JoinHandle<T> {
180 fn abort_on_drop(self) -> AbortOnDrop<T> {
181 AbortOnDrop::new(self)
182 }
183}
184
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test_macros::async_test;

    use super::spawn;

    /// Awaiting a spawned task yields the future's output.
    #[async_test]
    async fn test_spawn() {
        let future = async { 42 };
        let join_handle = spawn(future);

        assert_matches!(join_handle.await, Ok(42));
    }

    /// Awaiting an aborted task reports a cancellation.
    #[async_test]
    async fn test_abort() {
        let future = async { 42 };
        let join_handle = spawn(future);

        join_handle.abort();

        // Pin the specific error: `is_err()` alone would also accept a panic.
        let error = join_handle.await.expect_err("an aborted task must not yield its output");
        assert!(error.is_cancelled(), "expected a cancellation error, got: {error}");
    }
}