matrix_sdk_common/
executor.rs1#[cfg(target_arch = "wasm32")]
19use std::{
20 future::Future,
21 pin::Pin,
22 task::{Context, Poll},
23};
24
25#[cfg(target_arch = "wasm32")]
26pub use futures_util::future::Aborted as JoinError;
27#[cfg(target_arch = "wasm32")]
28use futures_util::{
29 future::{AbortHandle, Abortable, RemoteHandle},
30 FutureExt,
31};
32#[cfg(not(target_arch = "wasm32"))]
33pub use tokio::task::{spawn, JoinError, JoinHandle};
34
/// Spawns `future` onto the current thread with
/// `wasm_bindgen_futures::spawn_local` and returns a [`JoinHandle`] that can
/// be awaited for the future's output or used to abort it.
///
/// This is the `wasm32` counterpart of `tokio::task::spawn`, which is used on
/// other targets. The future only needs to be `'static`, not `Send`.
///
/// As with Tokio, dropping the returned [`JoinHandle`] detaches the task: it
/// keeps running in the background and its output is discarded.
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + 'static,
{
    let (future, remote_handle) = future.remote_handle();
    let (abort_handle, abort_registration) = AbortHandle::new_pair();
    let future = Abortable::new(future, abort_registration);

    // Shared with the handle so that `JoinHandle::is_finished` and
    // `JoinHandle::poll` can tell a task that ran to completion apart from one
    // that was aborted before completing.
    let completed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let task_completed = std::sync::Arc::clone(&completed);

    wasm_bindgen_futures::spawn_local(async move {
        // `Ok(())` means the wrapped future finished and its output was handed
        // to `remote_handle`; `Err(Aborted)` means it was cancelled first.
        if future.await.is_ok() {
            task_completed.store(true, std::sync::atomic::Ordering::SeqCst);
        }
    });

    JoinHandle { remote_handle: Some(remote_handle), abort_handle, completed }
}

/// Handle to a task spawned with [`spawn`] on `wasm32`.
///
/// Awaiting it yields `Ok(output)` once the task completes, or
/// `Err(JoinError)` if the task was aborted before completing. Dropping it
/// detaches the task rather than cancelling it.
#[cfg(target_arch = "wasm32")]
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// Receives the task's output. Always `Some` while the handle is alive;
    /// it is only taken in `Drop`, to be passed to `RemoteHandle::forget`.
    remote_handle: Option<RemoteHandle<T>>,
    /// Cancels the spawned task.
    abort_handle: AbortHandle,
    /// Set by the spawned task once the wrapped future has completed without
    /// being aborted.
    completed: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

#[cfg(target_arch = "wasm32")]
impl<T> JoinHandle<T> {
    /// Requests cancellation of the task.
    ///
    /// If the task has not completed yet, it is cancelled without making
    /// further progress and awaiting this handle yields `Err(JoinError)`. If
    /// it already completed, awaiting this handle still yields its output.
    pub fn abort(&self) {
        self.abort_handle.abort();
    }

    /// Returns `true` if the task has run to completion or has been aborted.
    pub fn is_finished(&self) -> bool {
        self.abort_handle.is_aborted()
            || self.completed.load(std::sync::atomic::Ordering::SeqCst)
    }
}

#[cfg(target_arch = "wasm32")]
impl<T> Drop for JoinHandle<T> {
    fn drop(&mut self) {
        // Dropping a `RemoteHandle` cancels its remote future. `forget` lets
        // the task keep running instead, matching Tokio's detach-on-drop.
        if let Some(remote_handle) = self.remote_handle.take() {
            remote_handle.forget();
        }
    }
}

#[cfg(target_arch = "wasm32")]
impl<T: 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let completed = self.completed.load(std::sync::atomic::Ordering::SeqCst);

        // A task aborted before completing never sends an output, so
        // `remote_handle` must not be polled (futures-util panics when polling
        // a `RemoteHandle` whose sender is gone). A task that completed before
        // `abort` was called still has its output waiting, so return it.
        if self.abort_handle.is_aborted() && !completed {
            return Poll::Ready(Err(JoinError));
        }

        match self.remote_handle.as_mut() {
            Some(remote_handle) => Pin::new(remote_handle).poll(cx).map(Ok),
            // Only `Drop` takes the handle, so this cannot happen while the
            // `JoinHandle` is still being polled.
            None => Poll::Ready(Err(JoinError)),
        }
    }
}
84
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test_macros::async_test;

    use super::spawn;

    /// The output of a spawned future is delivered through its join handle.
    #[async_test]
    async fn test_spawn() {
        let handle = spawn(async { 42 });

        let outcome = handle.await;
        assert_matches!(outcome, Ok(42));
    }

    /// Aborting a task right after spawning it makes its join handle resolve
    /// to an error.
    #[async_test]
    async fn test_abort() {
        let handle = spawn(async { 42 });
        handle.abort();

        let outcome = handle.await;
        assert!(outcome.is_err());
    }
}