matrix_sdk/event_handler/
mod.rs

1// Copyright 2021 Jonas Platte
2// Copyright 2022 Famedly GmbH
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
//! Types and traits related to event handlers. For usage, see
17//! [`Client::add_event_handler`].
18//!
19//! ### How it works
20//!
21//! The `add_event_handler` method registers event handlers of different
22//! signatures by actually storing boxed closures that all have the same
23//! signature of `async (EventHandlerData) -> ()` where `EventHandlerData` is a
24//! private type that contains all of the data an event handler *might* need.
25//!
26//! The stored closure takes care of deserializing the event which the
27//! `EventHandlerData` contains as a (borrowed) [`serde_json::value::RawValue`],
28//! extracting the context arguments from other fields of `EventHandlerData` and
29//! calling / `.await`ing the event handler if the previous steps succeeded.
30//! It also logs any errors from the above chain of function calls.
31//!
32//! For more details, see the [`EventHandler`] trait.
33
34#[cfg(any(feature = "anyhow", feature = "eyre"))]
35use std::any::TypeId;
36use std::{
37    borrow::Cow,
38    fmt,
39    future::Future,
40    pin::Pin,
41    sync::{
42        atomic::{AtomicU64, Ordering::SeqCst},
43        Arc, RwLock, Weak,
44    },
45    task::{Context, Poll},
46};
47
48#[cfg(target_family = "wasm")]
49use anymap2::any::CloneAny;
50#[cfg(not(target_family = "wasm"))]
51use anymap2::any::CloneAnySendSync;
52use eyeball::{SharedObservable, Subscriber};
53use futures_core::Stream;
54use futures_util::stream::{FuturesUnordered, StreamExt};
55use matrix_sdk_base::{
56    deserialized_responses::{EncryptionInfo, TimelineEvent},
57    SendOutsideWasm, SyncOutsideWasm,
58};
59use matrix_sdk_common::deserialized_responses::ProcessedToDeviceEvent;
60use pin_project_lite::pin_project;
61use ruma::{events::AnySyncStateEvent, push::Action, serde::Raw, OwnedRoomId};
62use serde::{de::DeserializeOwned, Deserialize};
63use serde_json::value::RawValue as RawJsonValue;
64use tracing::{debug, error, field::debug, instrument, warn};
65
66use self::maps::EventHandlerMaps;
67use crate::{Client, Room};
68
69mod context;
70mod maps;
71mod static_events;
72
73pub use self::context::{Ctx, EventHandlerContext, RawEvent};
74
75#[cfg(not(target_family = "wasm"))]
76type EventHandlerFut = Pin<Box<dyn Future<Output = ()> + Send>>;
77#[cfg(target_family = "wasm")]
78type EventHandlerFut = Pin<Box<dyn Future<Output = ()>>>;
79
80#[cfg(not(target_family = "wasm"))]
81type EventHandlerFn = dyn Fn(EventHandlerData<'_>) -> EventHandlerFut + Send + Sync;
82#[cfg(target_family = "wasm")]
83type EventHandlerFn = dyn Fn(EventHandlerData<'_>) -> EventHandlerFut;
84
85#[cfg(not(target_family = "wasm"))]
86type AnyMap = anymap2::Map<dyn CloneAnySendSync + Send + Sync>;
87#[cfg(target_family = "wasm")]
88type AnyMap = anymap2::Map<dyn CloneAny>;
89
90#[derive(Default)]
91pub(crate) struct EventHandlerStore {
92    handlers: RwLock<EventHandlerMaps>,
93    context: RwLock<AnyMap>,
94    counter: AtomicU64,
95}
96
97impl EventHandlerStore {
98    pub fn add_handler(&self, handle: EventHandlerHandle, handler_fn: Box<EventHandlerFn>) {
99        self.handlers.write().unwrap().add(handle, handler_fn);
100    }
101
102    pub fn add_context<T>(&self, ctx: T)
103    where
104        T: Clone + Send + Sync + 'static,
105    {
106        self.context.write().unwrap().insert(ctx);
107    }
108
109    pub fn remove(&self, handle: EventHandlerHandle) {
110        self.handlers.write().unwrap().remove(handle);
111    }
112
113    #[cfg(test)]
114    fn len(&self) -> usize {
115        self.handlers.read().unwrap().len()
116    }
117}
118
/// The kind of event a handler is registered for.
///
/// Handlers are looked up by this kind together with the event type and an
/// optional room ID.
///
/// NOTE: the derived `PartialOrd` / `Ord` follow the declaration order of the
/// variants, so reordering them changes how kinds compare.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum HandlerKind {
    GlobalAccountData,
    RoomAccountData,
    EphemeralRoomData,
    /// Every timeline event, whatever its kind (`AnySyncTimelineEvent`
    /// handlers).
    Timeline,
    /// Message-like timeline events, redacted or not.
    MessageLike,
    /// Message-like timeline events that have not been redacted.
    OriginalMessageLike,
    /// Message-like timeline events that have been redacted.
    RedactedMessageLike,
    /// State events, redacted or not.
    State,
    /// State events that have not been redacted.
    OriginalState,
    /// State events that have been redacted.
    RedactedState,
    StrippedState,
    /// To-device events.
    ToDevice,
    Presence,
}
136
137impl HandlerKind {
138    fn message_like_redacted(redacted: bool) -> Self {
139        if redacted {
140            Self::RedactedMessageLike
141        } else {
142            Self::OriginalMessageLike
143        }
144    }
145
146    fn state_redacted(redacted: bool) -> Self {
147        if redacted {
148            Self::RedactedState
149        } else {
150            Self::OriginalState
151        }
152    }
153}
154
/// A statically-known event kind/type that can be retrieved from an event sync.
pub trait SyncEvent {
    /// The [`HandlerKind`] under which handlers for this event are registered.
    #[doc(hidden)]
    const KIND: HandlerKind;
    /// The event type handlers for this event are registered under; it is
    /// stored in the handler's `EventHandlerHandle` and used in log messages.
    ///
    /// NOTE(review): `None` presumably means "any event type of this kind";
    /// confirm against the lookup in the `maps` module.
    #[doc(hidden)]
    const TYPE: Option<&'static str>;
}
162
/// A registered, type-erased event handler together with its id.
pub(crate) struct EventHandlerWrapper {
    /// The type-erased handler closure.
    handler_fn: Box<EventHandlerFn>,
    /// The id of the handler (see `EventHandlerHandle::handler_id`);
    /// presumably used to find the handler again on removal — TODO confirm in
    /// the `maps` module.
    pub handler_id: u64,
}
167
/// Handle to remove a registered event handler by passing it to
/// [`Client::remove_event_handler`].
#[derive(Clone, Debug)]
pub struct EventHandlerHandle {
    /// The kind of event the handler was registered for.
    pub(crate) ev_kind: HandlerKind,
    /// The event type the handler was registered for (`SyncEvent::TYPE`).
    pub(crate) ev_type: Option<&'static str>,
    /// The room the handler is restricted to, if any.
    pub(crate) room_id: Option<OwnedRoomId>,
    /// Id taken from the handler store's counter at registration time.
    pub(crate) handler_id: u64,
}
177
178/// Interface for event handlers.
179///
180/// This trait is an abstraction for a certain kind of functions / closures,
181/// specifically:
182///
183/// * They must have at least one argument, which is the event itself, a type
184///   that implements [`SyncEvent`]. Any additional arguments need to implement
185///   the [`EventHandlerContext`] trait.
186/// * Their return type has to be one of: `()`, `Result<(), impl Display + Debug
187///   + 'static>` (if you are using `anyhow::Result` or `eyre::Result` you can
188///   additionally enable the `anyhow` / `eyre` feature to get the verbose
189///   `Debug` output printed on error)
190///
191/// ### How it works
192///
193/// This trait is basically a very constrained version of `Fn`: It requires at
194/// least one argument, which is represented as its own generic parameter `Ev`
195/// with the remaining parameter types being represented by the second generic
196/// parameter `Ctx`; they have to be stuffed into one generic parameter as a
197/// tuple because Rust doesn't have variadic generics.
198///
199/// `Ev` and `Ctx` are generic parameters rather than associated types because
200/// the argument list is a generic parameter for the `Fn` traits too, so a
201/// single type could implement `Fn` multiple times with different argument
202/// lists¹. Luckily, when calling [`Client::add_event_handler`] with a
203/// closure argument the trait solver takes into account that only a single one
204/// of the implementations applies (even though this could theoretically change
205/// through a dependency upgrade) and uses that rather than raising an ambiguity
206/// error. This is the same trick used by web frameworks like actix-web and
207/// axum.
208///
209/// ¹ the only thing stopping such types from existing in stable Rust is that
210/// all manual implementations of the `Fn` traits require a Nightly feature
211pub trait EventHandler<Ev, Ctx>: Clone + SendOutsideWasm + SyncOutsideWasm + 'static {
212    /// The future returned by `handle_event`.
213    #[doc(hidden)]
214    type Future: EventHandlerFuture;
215
216    /// Create a future for handling the given event.
217    ///
218    /// `data` provides additional data about the event, for example the room it
219    /// appeared in.
220    ///
221    /// Returns `None` if one of the context extractors failed.
222    #[doc(hidden)]
223    fn handle_event(self, ev: Ev, data: EventHandlerData<'_>) -> Option<Self::Future>;
224}
225
226#[doc(hidden)]
227pub trait EventHandlerFuture:
228    Future<Output = <Self as EventHandlerFuture>::Output> + SendOutsideWasm + 'static
229{
230    type Output: EventHandlerResult;
231}
232
233impl<T> EventHandlerFuture for T
234where
235    T: Future + SendOutsideWasm + 'static,
236    <T as Future>::Output: EventHandlerResult,
237{
238    type Output = <T as Future>::Output;
239}
240
/// Everything an event handler *might* need, handed to the type-erased
/// handler closures. Context arguments are extracted from it (see
/// [`EventHandlerContext`]).
#[doc(hidden)]
#[derive(Debug)]
pub struct EventHandlerData<'a> {
    /// The client dispatching the event.
    client: Client,
    /// The room the event was received in, if any.
    room: Option<Room>,
    /// The event's raw JSON.
    raw: &'a RawJsonValue,
    /// Encryption information for the event, if available.
    encryption_info: Option<&'a EncryptionInfo>,
    /// Push actions for the event; empty for non-timeline events and when none
    /// are available.
    push_actions: &'a [Action],
    /// The handle of the handler being invoked.
    handle: EventHandlerHandle,
}
251
252/// Return types supported for event handlers implement this trait.
253///
254/// It is not meant to be implemented outside of matrix-sdk.
255pub trait EventHandlerResult: Sized {
256    #[doc(hidden)]
257    fn print_error(&self, event_type: Option<&str>);
258}
259
260impl EventHandlerResult for () {
261    fn print_error(&self, _event_type: Option<&str>) {}
262}
263
264impl<E: fmt::Debug + fmt::Display + 'static> EventHandlerResult for Result<(), E> {
265    fn print_error(&self, event_type: Option<&str>) {
266        let msg_fragment = match event_type {
267            Some(event_type) => format!(" for `{event_type}`"),
268            None => "".to_owned(),
269        };
270
271        match self {
272            #[cfg(feature = "anyhow")]
273            Err(e) if TypeId::of::<E>() == TypeId::of::<anyhow::Error>() => {
274                error!("Event handler{msg_fragment} failed: {e:?}");
275            }
276            #[cfg(feature = "eyre")]
277            Err(e) if TypeId::of::<E>() == TypeId::of::<eyre::Report>() => {
278                error!("Event handler{msg_fragment} failed: {e:?}");
279            }
280            Err(e) => {
281                error!("Event handler{msg_fragment} failed: {e}");
282            }
283            Ok(_) => {}
284        }
285    }
286}
287
/// The part of an event's `unsigned` object that event dispatch inspects.
#[derive(Deserialize)]
struct UnsignedDetails {
    /// Only the presence of this field matters: its contents are skipped
    /// (`IgnoredAny`), and an event that has it is treated as redacted.
    redacted_because: Option<serde::de::IgnoredAny>,
}
292
293/// Event handling internals.
294impl Client {
295    pub(crate) fn add_event_handler_impl<Ev, Ctx, H>(
296        &self,
297        handler: H,
298        room_id: Option<OwnedRoomId>,
299    ) -> EventHandlerHandle
300    where
301        Ev: SyncEvent + DeserializeOwned + SendOutsideWasm + 'static,
302        H: EventHandler<Ev, Ctx>,
303    {
304        let handler_fn: Box<EventHandlerFn> = Box::new(move |data| {
305            let maybe_fut = serde_json::from_str(data.raw.get())
306                .map(|ev| handler.clone().handle_event(ev, data));
307
308            Box::pin(async move {
309                match maybe_fut {
310                    Ok(Some(fut)) => {
311                        fut.await.print_error(Ev::TYPE);
312                    }
313                    Ok(None) => {
314                        error!(
315                            event_type = Ev::TYPE, event_kind = ?Ev::KIND,
316                            "Event handler has an invalid context argument",
317                        );
318                    }
319                    Err(e) => {
320                        warn!(
321                            event_type = Ev::TYPE, event_kind = ?Ev::KIND,
322                            "Failed to deserialize event, skipping event handler.\n
323                             Deserialization error: {e}",
324                        );
325                    }
326                }
327            })
328        });
329
330        let handler_id = self.inner.event_handlers.counter.fetch_add(1, SeqCst);
331        let handle =
332            EventHandlerHandle { ev_kind: Ev::KIND, ev_type: Ev::TYPE, room_id, handler_id };
333
334        self.inner.event_handlers.add_handler(handle.clone(), handler_fn);
335
336        handle
337    }
338
339    pub(crate) async fn handle_sync_events<T>(
340        &self,
341        kind: HandlerKind,
342        room: Option<&Room>,
343        events: &[Raw<T>],
344    ) -> serde_json::Result<()> {
345        #[derive(Deserialize)]
346        struct ExtractType<'a> {
347            #[serde(borrow, rename = "type")]
348            event_type: Cow<'a, str>,
349        }
350
351        for raw_event in events {
352            let event_type = raw_event.deserialize_as::<ExtractType<'_>>()?.event_type;
353            self.call_event_handlers(room, raw_event.json(), kind, &event_type, None, &[]).await;
354        }
355
356        Ok(())
357    }
358
359    pub(crate) async fn handle_sync_to_device_events(
360        &self,
361        events: &[ProcessedToDeviceEvent],
362    ) -> serde_json::Result<()> {
363        #[derive(Deserialize)]
364        struct ExtractType<'a> {
365            #[serde(borrow, rename = "type")]
366            event_type: Cow<'a, str>,
367        }
368
369        for processed_to_device in events {
370            let (raw_event, encryption_info) = match processed_to_device {
371                ProcessedToDeviceEvent::Decrypted { raw, encryption_info } => {
372                    (raw, Some(encryption_info))
373                }
374                other => (&other.to_raw(), None),
375            };
376            let event_type = raw_event.deserialize_as::<ExtractType<'_>>()?.event_type;
377            self.call_event_handlers(
378                None,
379                raw_event.json(),
380                HandlerKind::ToDevice,
381                &event_type,
382                encryption_info,
383                &[],
384            )
385            .await;
386        }
387
388        Ok(())
389    }
390
391    pub(crate) async fn handle_sync_state_events(
392        &self,
393        room: Option<&Room>,
394        state_events: &[Raw<AnySyncStateEvent>],
395    ) -> serde_json::Result<()> {
396        #[derive(Deserialize)]
397        struct StateEventDetails<'a> {
398            #[serde(borrow, rename = "type")]
399            event_type: Cow<'a, str>,
400            unsigned: Option<UnsignedDetails>,
401        }
402
403        // Event handlers for possibly-redacted state events
404        self.handle_sync_events(HandlerKind::State, room, state_events).await?;
405
406        // Event handlers specifically for redacted OR unredacted state events
407        for raw_event in state_events {
408            let StateEventDetails { event_type, unsigned } = raw_event.deserialize_as()?;
409            let redacted = unsigned.and_then(|u| u.redacted_because).is_some();
410            let handler_kind = HandlerKind::state_redacted(redacted);
411
412            self.call_event_handlers(room, raw_event.json(), handler_kind, &event_type, None, &[])
413                .await;
414        }
415
416        Ok(())
417    }
418
419    pub(crate) async fn handle_sync_timeline_events(
420        &self,
421        room: Option<&Room>,
422        timeline_events: &[TimelineEvent],
423    ) -> serde_json::Result<()> {
424        #[derive(Deserialize)]
425        struct TimelineEventDetails<'a> {
426            #[serde(borrow, rename = "type")]
427            event_type: Cow<'a, str>,
428            state_key: Option<serde::de::IgnoredAny>,
429            unsigned: Option<UnsignedDetails>,
430        }
431
432        for item in timeline_events {
433            let TimelineEventDetails { event_type, state_key, unsigned } =
434                item.raw().deserialize_as()?;
435
436            let redacted = unsigned.and_then(|u| u.redacted_because).is_some();
437            let (handler_kind_g, handler_kind_r) = match state_key {
438                Some(_) => (HandlerKind::State, HandlerKind::state_redacted(redacted)),
439                None => (HandlerKind::MessageLike, HandlerKind::message_like_redacted(redacted)),
440            };
441
442            let raw_event = item.raw().json();
443            let encryption_info = item.encryption_info().map(|i| &**i);
444            let push_actions = item.push_actions().unwrap_or(&[]);
445
446            // Event handlers for possibly-redacted timeline events
447            self.call_event_handlers(
448                room,
449                raw_event,
450                handler_kind_g,
451                &event_type,
452                encryption_info,
453                push_actions,
454            )
455            .await;
456
457            // Event handlers specifically for redacted OR unredacted timeline events
458            self.call_event_handlers(
459                room,
460                raw_event,
461                handler_kind_r,
462                &event_type,
463                encryption_info,
464                push_actions,
465            )
466            .await;
467
468            // Event handlers for `AnySyncTimelineEvent`
469            let kind = HandlerKind::Timeline;
470            self.call_event_handlers(
471                room,
472                raw_event,
473                kind,
474                &event_type,
475                encryption_info,
476                push_actions,
477            )
478            .await;
479        }
480
481        Ok(())
482    }
483
484    #[instrument(skip_all, fields(?event_kind, ?event_type, room_id))]
485    async fn call_event_handlers(
486        &self,
487        room: Option<&Room>,
488        raw: &RawJsonValue,
489        event_kind: HandlerKind,
490        event_type: &str,
491        encryption_info: Option<&EncryptionInfo>,
492        push_actions: &[Action],
493    ) {
494        let room_id = room.map(|r| r.room_id());
495        if let Some(room_id) = room_id {
496            tracing::Span::current().record("room_id", debug(room_id));
497        }
498
499        // Construct event handler futures
500        let mut futures: FuturesUnordered<_> = self
501            .inner
502            .event_handlers
503            .handlers
504            .read()
505            .unwrap()
506            .get_handlers(event_kind, event_type, room_id)
507            .map(|(handle, handler_fn)| {
508                let data = EventHandlerData {
509                    client: self.clone(),
510                    room: room.cloned(),
511                    raw,
512                    encryption_info,
513                    push_actions,
514                    handle,
515                };
516
517                (handler_fn)(data)
518            })
519            .collect();
520
521        if !futures.is_empty() {
522            debug!(amount = futures.len(), "Calling event handlers");
523
524            // Run the event handler futures with the `self.event_handlers.handlers`
525            // lock no longer being held.
526            while let Some(()) = futures.next().await {}
527        }
528    }
529}
530
531/// A guard type that removes an event handler when it drops (goes out of
532/// scope).
533///
534/// Created with [`Client::event_handler_drop_guard`].
535#[derive(Debug)]
536pub struct EventHandlerDropGuard {
537    handle: EventHandlerHandle,
538    client: Client,
539}
540
541impl EventHandlerDropGuard {
542    pub(crate) fn new(handle: EventHandlerHandle, client: Client) -> Self {
543        Self { handle, client }
544    }
545}
546
547impl Drop for EventHandlerDropGuard {
548    fn drop(&mut self) {
549        self.client.remove_event_handler(self.handle.clone());
550    }
551}
552
// Implements `EventHandler<Ev, (A, B, ...)>` for every
// `FnOnce(Ev, A, B, ...) -> Fut` that also satisfies the supertrait bounds of
// `EventHandler` and whose extra arguments all implement `EventHandlerContext`.
// The context arguments are gathered in a tuple because Rust has no variadic
// generics.
macro_rules! impl_event_handler {
    ($($ty:ident),* $(,)?) => {
        impl<Ev, Fun, Fut, $($ty),*> EventHandler<Ev, ($($ty,)*)> for Fun
        where
            Ev: SyncEvent,
            Fun: FnOnce(Ev, $($ty),*) -> Fut + Clone + SendOutsideWasm + SyncOutsideWasm + 'static,
            Fut: EventHandlerFuture,
            $($ty: EventHandlerContext),*
        {
            type Future = Fut;

            fn handle_event(self, ev: Ev, _d: EventHandlerData<'_>) -> Option<Self::Future> {
                // Each context argument is extracted from `_d`; the first
                // extractor returning `None` makes the whole call return `None`
                // through `?`. `_d` is underscore-prefixed because the
                // zero-context-argument instantiation never reads it.
                Some((self)(ev, $($ty::from_data(&_d)?),*))
            }
        }
    };
}

// One implementation per arity: handlers take the event plus up to eight
// context arguments.
impl_event_handler!();
impl_event_handler!(A);
impl_event_handler!(A, B);
impl_event_handler!(A, B, C);
impl_event_handler!(A, B, C, D);
impl_event_handler!(A, B, C, D, E);
impl_event_handler!(A, B, C, D, E, F);
impl_event_handler!(A, B, C, D, E, F, G);
impl_event_handler!(A, B, C, D, E, F, G, H);
580
/// An observer of events (may be tailored to a room).
///
/// Only the most recent value can be observed. Subscribers are notified when a
/// new value is sent, but there is no guarantee that they will see all values.
///
/// To create such an observer, use [`Client::observe_events`] or
/// [`Client::observe_room_events`].
#[derive(Debug)]
pub struct ObservableEventHandler<T> {
    /// This type is actually nothing more than a thin glue layer between the
    /// [`EventHandler`] mechanism and the reactive programming types from
    /// [`eyeball`]. Here, we use a [`SharedObservable`] that is updated by the
    /// [`EventHandler`]. It holds `None` until a first event has been observed.
    shared_observable: SharedObservable<Option<T>>,

    /// This type owns the [`EventHandlerDropGuard`]. As soon as this type goes
    /// out of scope, the event handler is unregistered/removed.
    ///
    /// [`EventHandlerSubscriber`] holds a weak, non-owning reference to this
    /// guard. It is useful to detect when to close the [`Stream`]: as soon as
    /// this type goes out of scope, the subscriber will close itself on poll.
    event_handler_guard: Arc<EventHandlerDropGuard>,
}
604
605impl<T> ObservableEventHandler<T> {
606    pub(crate) fn new(
607        shared_observable: SharedObservable<Option<T>>,
608        event_handler_guard: EventHandlerDropGuard,
609    ) -> Self {
610        Self { shared_observable, event_handler_guard: Arc::new(event_handler_guard) }
611    }
612
613    /// Subscribe to this observer.
614    ///
615    /// It returns an [`EventHandlerSubscriber`], which implements [`Stream`].
616    /// See its documentation to learn more.
617    pub fn subscribe(&self) -> EventHandlerSubscriber<T> {
618        EventHandlerSubscriber::new(
619            self.shared_observable.subscribe(),
620            // The subscriber holds a weak non-owning reference to the event handler guard, so that
621            // it can detect when this observer is dropped, and can close the subscriber's stream.
622            Arc::downgrade(&self.event_handler_guard),
623        )
624    }
625}
626
pin_project! {
    /// The subscriber of an [`ObservableEventHandler`].
    ///
    /// To create such a subscriber, use [`ObservableEventHandler::subscribe`].
    ///
    /// This type implements [`Stream`], which means it is possible to poll the
    /// next value asynchronously. In other words, polling this type returns a
    /// newly synced event as soon as one is available (intermediate values may
    /// be skipped, see [`ObservableEventHandler`]). See
    /// [`Client::observe_events`] to learn more.
    #[derive(Debug)]
    pub struct EventHandlerSubscriber<T> {
        // The `Subscriber` associated to the `SharedObservable` inside
        // `ObservableEventHandler`.
        //
        // Keep in mind all this API is just a thin glue layer between
        // `EventHandler` and `SharedObservable`, that's… maagiic!
        #[pin]
        subscriber: Subscriber<Option<T>>,

        // A weak non-owning reference to the event handler guard from
        // `ObservableEventHandler`. When this type is polled (via its `Stream`
        // implementation), it is possible to detect whether the observable has
        // been dropped by upgrading this weak reference, and close the `Stream`
        // if it needs to.
        event_handler_guard: Weak<EventHandlerDropGuard>,
    }
}
654
655impl<T> EventHandlerSubscriber<T> {
656    fn new(
657        subscriber: Subscriber<Option<T>>,
658        event_handler_handle: Weak<EventHandlerDropGuard>,
659    ) -> Self {
660        Self { subscriber, event_handler_guard: event_handler_handle }
661    }
662}
663
664impl<T> Stream for EventHandlerSubscriber<T>
665where
666    T: Clone,
667{
668    type Item = T;
669
670    fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
671        let mut this = self.project();
672
673        let Some(_) = this.event_handler_guard.upgrade() else {
674            // The `EventHandlerHandle` has been dropped via `EventHandlerDropGuard`. It
675            // means the `ObservableEventHandler` has been dropped. It's time to
676            // close this stream.
677            return Poll::Ready(None);
678        };
679
680        // First off, the subscriber is of type `Subscriber<Option<T>>` because the
681        // `SharedObservable` starts with a `None` value to indicate it has no yet
682        // received any update. We want the `Stream` to return `T`, not `Option<T>`. We
683        // then filter out all `None` value.
684        //
685        // Second, when a `None` value is met, we want to poll again (hence the `loop`).
686        // At best, there is a new value to return. At worst, the subscriber will return
687        // `Poll::Pending` and will register the wakers accordingly.
688
689        loop {
690            match this.subscriber.as_mut().poll_next(context) {
691                // Stream has been closed somehow.
692                Poll::Ready(None) => return Poll::Ready(None),
693
694                // The initial value (of the `SharedObservable` behind `self.subscriber`) has been
695                // polled. We want to filter it out.
696                Poll::Ready(Some(None)) => {
697                    // Loop over.
698                    continue;
699                }
700
701                // We have a new value!
702                Poll::Ready(Some(Some(value))) => return Poll::Ready(Some(value)),
703
704                // Classical pending.
705                Poll::Pending => return Poll::Pending,
706            }
707        }
708    }
709}
710
711#[cfg(test)]
712mod tests {
713    use matrix_sdk_test::{
714        async_test,
715        event_factory::{EventFactory, PreviousMembership},
716        InvitedRoomBuilder, JoinedRoomBuilder, DEFAULT_TEST_ROOM_ID,
717    };
718    use stream_assert::{assert_closed, assert_pending, assert_ready};
719    #[cfg(target_family = "wasm")]
720    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
721    use std::{
722        future,
723        sync::{
724            atomic::{AtomicU8, Ordering::SeqCst},
725            Arc,
726        },
727    };
728
729    use assert_matches2::assert_let;
730    use matrix_sdk_common::{deserialized_responses::EncryptionInfo, locks::Mutex};
731    use matrix_sdk_test::{StateTestEvent, StrippedStateTestEvent, SyncResponseBuilder};
732    use once_cell::sync::Lazy;
733    use ruma::{
734        event_id,
735        events::{
736            room::{
737                member::{MembershipState, OriginalSyncRoomMemberEvent, StrippedRoomMemberEvent},
738                name::OriginalSyncRoomNameEvent,
739                power_levels::OriginalSyncRoomPowerLevelsEvent,
740            },
741            typing::SyncTypingEvent,
742            AnySyncStateEvent, AnySyncTimelineEvent, AnyToDeviceEvent,
743        },
744        room_id,
745        serde::Raw,
746        user_id,
747    };
748    use serde_json::json;
749
750    use crate::{
751        event_handler::Ctx,
752        test_utils::{logged_in_client, no_retry_test_client},
753        Client, Room,
754    };
755
756    static MEMBER_EVENT: Lazy<Raw<AnySyncTimelineEvent>> = Lazy::new(|| {
757        EventFactory::new()
758            .member(user_id!("@example:localhost"))
759            .membership(MembershipState::Join)
760            .display_name("example")
761            .event_id(event_id!("$151800140517rfvjc:localhost"))
762            .previous(PreviousMembership::new(MembershipState::Invite).display_name("example"))
763            .into()
764    });
765
766    #[async_test]
767    async fn test_add_event_handler() -> crate::Result<()> {
768        let client = logged_in_client(None).await;
769
770        let member_count = Arc::new(AtomicU8::new(0));
771        let typing_count = Arc::new(AtomicU8::new(0));
772        let power_levels_count = Arc::new(AtomicU8::new(0));
773        let invited_member_count = Arc::new(AtomicU8::new(0));
774
775        client.add_event_handler({
776            let member_count = member_count.clone();
777            move |_ev: OriginalSyncRoomMemberEvent, _room: Room| async move {
778                member_count.fetch_add(1, SeqCst);
779            }
780        });
781        client.add_event_handler({
782            let typing_count = typing_count.clone();
783            move |_ev: SyncTypingEvent| async move {
784                typing_count.fetch_add(1, SeqCst);
785            }
786        });
787        client.add_event_handler({
788            let power_levels_count = power_levels_count.clone();
789            move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: Room| async move {
790                power_levels_count.fetch_add(1, SeqCst);
791            }
792        });
793        client.add_event_handler({
794            let invited_member_count = invited_member_count.clone();
795            move |_ev: StrippedRoomMemberEvent| async move {
796                invited_member_count.fetch_add(1, SeqCst);
797            }
798        });
799
800        let f = EventFactory::new();
801        let response = SyncResponseBuilder::default()
802            .add_joined_room(
803                JoinedRoomBuilder::default()
804                    .add_timeline_event(MEMBER_EVENT.clone())
805                    .add_typing(
806                        f.typing(vec![user_id!("@alice:matrix.org"), user_id!("@bob:example.com")]),
807                    )
808                    .add_state_event(StateTestEvent::PowerLevels),
809            )
810            .add_invited_room(
811                InvitedRoomBuilder::new(room_id!("!test_invited:example.org")).add_state_event(
812                    StrippedStateTestEvent::Custom(json!({
813                        "content": {
814                            "avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF",
815                            "displayname": "Alice",
816                            "membership": "invite",
817                        },
818                        "event_id": "$143273582443PhrSn:example.org",
819                        "origin_server_ts": 1432735824653u64,
820                        "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
821                        "sender": "@example:example.org",
822                        "state_key": "@alice:example.org",
823                        "type": "m.room.member",
824                        "unsigned": {
825                            "age": 1234,
826                            "invite_room_state": [
827                                {
828                                    "content": {
829                                        "name": "Example Room"
830                                    },
831                                    "sender": "@bob:example.org",
832                                    "state_key": "",
833                                    "type": "m.room.name"
834                                },
835                                {
836                                    "content": {
837                                        "join_rule": "invite"
838                                    },
839                                    "sender": "@bob:example.org",
840                                    "state_key": "",
841                                    "type": "m.room.join_rules"
842                                }
843                            ]
844                        }
845                    })),
846                ),
847            )
848            .build_sync_response();
849        client.process_sync(response).await?;
850
851        assert_eq!(member_count.load(SeqCst), 1);
852        assert_eq!(typing_count.load(SeqCst), 1);
853        assert_eq!(power_levels_count.load(SeqCst), 1);
854        assert_eq!(invited_member_count.load(SeqCst), 1);
855
856        Ok(())
857    }
858
859    #[async_test]
860    #[allow(dependency_on_unit_never_type_fallback)]
861    async fn test_add_to_device_event_handler() -> crate::Result<()> {
862        let client = logged_in_client(None).await;
863
864        let captured_event: Arc<Mutex<Option<AnyToDeviceEvent>>> = Arc::new(Mutex::new(None));
865        let captured_info: Arc<Mutex<Option<EncryptionInfo>>> = Arc::new(Mutex::new(None));
866
867        client.add_event_handler({
868            let captured = captured_event.clone();
869            let captured_info = captured_info.clone();
870            move |ev: AnyToDeviceEvent, encryption_info: Option<EncryptionInfo>| {
871                let mut captured_lock = captured.lock();
872                *captured_lock = Some(ev);
873                let mut captured_info_lock = captured_info.lock();
874                *captured_info_lock = encryption_info;
875                future::ready(())
876            }
877        });
878
879        let response = SyncResponseBuilder::default()
880            .add_to_device_event(json!({
881              "sender": "@alice:example.com",
882              "type": "m.custom.to.device.type",
883              "content": {
884                "a": "test",
885              }
886            }))
887            .build_sync_response();
888        client.process_sync(response).await?;
889
890        let captured = captured_event.lock().clone();
891        assert_let!(Some(received_event) = captured);
892        assert_eq!(received_event.event_type().to_string(), "m.custom.to.device.type");
893        let info = captured_info.lock().clone();
894        assert!(info.is_none());
895        Ok(())
896    }
897
898    #[async_test]
899    #[allow(dependency_on_unit_never_type_fallback)]
900    async fn test_add_room_event_handler() -> crate::Result<()> {
901        let client = logged_in_client(None).await;
902
903        let room_id_a = room_id!("!foo:example.org");
904        let room_id_b = room_id!("!bar:matrix.org");
905
906        let member_count = Arc::new(AtomicU8::new(0));
907        let power_levels_count = Arc::new(AtomicU8::new(0));
908
909        // Room event handlers for member events in both rooms
910        client.add_room_event_handler(room_id_a, {
911            let member_count = member_count.clone();
912            move |_ev: OriginalSyncRoomMemberEvent, _room: Room| {
913                member_count.fetch_add(1, SeqCst);
914                future::ready(())
915            }
916        });
917        client.add_room_event_handler(room_id_b, {
918            let member_count = member_count.clone();
919            move |_ev: OriginalSyncRoomMemberEvent, _room: Room| {
920                member_count.fetch_add(1, SeqCst);
921                future::ready(())
922            }
923        });
924
925        // Power levels event handlers for member events in room A
926        client.add_room_event_handler(room_id_a, {
927            let power_levels_count = power_levels_count.clone();
928            move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: Room| {
929                power_levels_count.fetch_add(1, SeqCst);
930                future::ready(())
931            }
932        });
933
934        // Room name event handler for room name events in room B
935        client.add_room_event_handler(room_id_b, move |_ev: OriginalSyncRoomNameEvent| async {
936            unreachable!("No room event in room B")
937        });
938
939        let response = SyncResponseBuilder::default()
940            .add_joined_room(
941                JoinedRoomBuilder::new(room_id_a)
942                    .add_timeline_event(MEMBER_EVENT.clone())
943                    .add_state_event(StateTestEvent::PowerLevels)
944                    .add_state_event(StateTestEvent::RoomName),
945            )
946            .add_joined_room(
947                JoinedRoomBuilder::new(room_id_b)
948                    .add_timeline_event(MEMBER_EVENT.clone())
949                    .add_state_event(StateTestEvent::PowerLevels),
950            )
951            .build_sync_response();
952        client.process_sync(response).await?;
953
954        assert_eq!(member_count.load(SeqCst), 2);
955        assert_eq!(power_levels_count.load(SeqCst), 1);
956
957        Ok(())
958    }
959
960    #[async_test]
961    #[allow(dependency_on_unit_never_type_fallback)]
962    async fn test_add_event_handler_with_tuples() -> crate::Result<()> {
963        let client = logged_in_client(None).await;
964
965        client.add_event_handler(
966            |_ev: OriginalSyncRoomMemberEvent, (_room, _client): (Room, Client)| future::ready(()),
967        );
968
969        // If it compiles, it works. No need to assert anything.
970
971        Ok(())
972    }
973
974    #[async_test]
975    #[allow(dependency_on_unit_never_type_fallback)]
976    async fn test_remove_event_handler() -> crate::Result<()> {
977        let client = logged_in_client(None).await;
978
979        let member_count = Arc::new(AtomicU8::new(0));
980
981        client.add_event_handler({
982            let member_count = member_count.clone();
983            move |_ev: OriginalSyncRoomMemberEvent| async move {
984                member_count.fetch_add(1, SeqCst);
985            }
986        });
987
988        let handle_a = client.add_event_handler(move |_ev: OriginalSyncRoomMemberEvent| async {
989            panic!("handler should have been removed");
990        });
991        let handle_b = client.add_room_event_handler(
992            #[allow(unknown_lints, clippy::explicit_auto_deref)] // lint is buggy
993            *DEFAULT_TEST_ROOM_ID,
994            move |_ev: OriginalSyncRoomMemberEvent| async {
995                panic!("handler should have been removed");
996            },
997        );
998
999        client.add_event_handler({
1000            let member_count = member_count.clone();
1001            move |_ev: OriginalSyncRoomMemberEvent| async move {
1002                member_count.fetch_add(1, SeqCst);
1003            }
1004        });
1005
1006        let response = SyncResponseBuilder::default()
1007            .add_joined_room(JoinedRoomBuilder::default().add_timeline_event(MEMBER_EVENT.clone()))
1008            .build_sync_response();
1009
1010        client.remove_event_handler(handle_a);
1011        client.remove_event_handler(handle_b);
1012
1013        client.process_sync(response).await?;
1014
1015        assert_eq!(member_count.load(SeqCst), 2);
1016
1017        Ok(())
1018    }
1019
1020    #[async_test]
1021    async fn test_event_handler_drop_guard() {
1022        let client = no_retry_test_client(None).await;
1023
1024        let handle = client.add_event_handler(|_ev: OriginalSyncRoomMemberEvent| async {});
1025        assert_eq!(client.inner.event_handlers.len(), 1);
1026
1027        {
1028            let _guard = client.event_handler_drop_guard(handle);
1029            assert_eq!(client.inner.event_handlers.len(), 1);
1030            // guard dropped here
1031        }
1032
1033        assert_eq!(client.inner.event_handlers.len(), 0);
1034    }
1035
1036    #[async_test]
1037    async fn test_use_client_in_handler() {
1038        // This used to not work because we were requiring `Send` of event
1039        // handler futures even on WASM, where practically all futures that do
1040        // I/O aren't.
1041        let client = no_retry_test_client(None).await;
1042
1043        client.add_event_handler(|_ev: OriginalSyncRoomMemberEvent, client: Client| async move {
1044            // All of Client's async methods that do network requests (and
1045            // possibly some that don't) are `!Send` on wasm. We obviously want
1046            // to be able to use them in event handlers.
1047            let _caps = client.get_capabilities().await.map_err(|e| anyhow::anyhow!("{}", e))?;
1048            anyhow::Ok(())
1049        });
1050    }
1051
1052    #[async_test]
1053    async fn test_raw_event_handler() -> crate::Result<()> {
1054        let client = logged_in_client(None).await;
1055        let counter = Arc::new(AtomicU8::new(0));
1056        client.add_event_handler_context(counter.clone());
1057        client.add_event_handler(
1058            |_ev: Raw<OriginalSyncRoomMemberEvent>, counter: Ctx<Arc<AtomicU8>>| async move {
1059                counter.fetch_add(1, SeqCst);
1060            },
1061        );
1062
1063        let response = SyncResponseBuilder::default()
1064            .add_joined_room(JoinedRoomBuilder::default().add_timeline_event(MEMBER_EVENT.clone()))
1065            .build_sync_response();
1066        client.process_sync(response).await?;
1067
1068        assert_eq!(counter.load(SeqCst), 1);
1069        Ok(())
1070    }
1071
1072    #[async_test]
1073    async fn test_enum_event_handler() -> crate::Result<()> {
1074        let client = logged_in_client(None).await;
1075        let counter = Arc::new(AtomicU8::new(0));
1076        client.add_event_handler_context(counter.clone());
1077        client.add_event_handler(
1078            |_ev: AnySyncStateEvent, counter: Ctx<Arc<AtomicU8>>| async move {
1079                counter.fetch_add(1, SeqCst);
1080            },
1081        );
1082
1083        let response = SyncResponseBuilder::default()
1084            .add_joined_room(JoinedRoomBuilder::default().add_timeline_event(MEMBER_EVENT.clone()))
1085            .build_sync_response();
1086        client.process_sync(response).await?;
1087
1088        assert_eq!(counter.load(SeqCst), 1);
1089        Ok(())
1090    }
1091
1092    #[async_test]
1093    #[allow(dependency_on_unit_never_type_fallback)]
1094    async fn test_observe_events() -> crate::Result<()> {
1095        let client = logged_in_client(None).await;
1096
1097        let room_id_0 = room_id!("!r0.matrix.org");
1098        let room_id_1 = room_id!("!r1.matrix.org");
1099
1100        let observable = client.observe_events::<OriginalSyncRoomNameEvent, Room>();
1101
1102        let mut subscriber = observable.subscribe();
1103
1104        assert_pending!(subscriber);
1105
1106        let mut response_builder = SyncResponseBuilder::new();
1107        let response = response_builder
1108            .add_joined_room(JoinedRoomBuilder::new(room_id_0).add_state_event(
1109                StateTestEvent::Custom(json!({
1110                    "content": {
1111                        "name": "Name 0"
1112                    },
1113                    "event_id": "$ev0",
1114                    "origin_server_ts": 1,
1115                    "sender": "@mnt_io:matrix.org",
1116                    "state_key": "",
1117                    "type": "m.room.name",
1118                    "unsigned": {
1119                        "age": 1,
1120                    }
1121                })),
1122            ))
1123            .build_sync_response();
1124        client.process_sync(response).await?;
1125
1126        let (room_name, room) = assert_ready!(subscriber);
1127
1128        assert_eq!(room_name.event_id.as_str(), "$ev0");
1129        assert_eq!(room.room_id(), room_id_0);
1130        assert_eq!(room.name().unwrap(), "Name 0");
1131
1132        assert_pending!(subscriber);
1133
1134        let response = response_builder
1135            .add_joined_room(JoinedRoomBuilder::new(room_id_1).add_state_event(
1136                StateTestEvent::Custom(json!({
1137                    "content": {
1138                        "name": "Name 1"
1139                    },
1140                    "event_id": "$ev1",
1141                    "origin_server_ts": 2,
1142                    "sender": "@mnt_io:matrix.org",
1143                    "state_key": "",
1144                    "type": "m.room.name",
1145                    "unsigned": {
1146                        "age": 2,
1147                    }
1148                })),
1149            ))
1150            .build_sync_response();
1151        client.process_sync(response).await?;
1152
1153        let (room_name, room) = assert_ready!(subscriber);
1154
1155        assert_eq!(room_name.event_id.as_str(), "$ev1");
1156        assert_eq!(room.room_id(), room_id_1);
1157        assert_eq!(room.name().unwrap(), "Name 1");
1158
1159        assert_pending!(subscriber);
1160
1161        drop(observable);
1162        assert_closed!(subscriber);
1163
1164        Ok(())
1165    }
1166
1167    #[async_test]
1168    #[allow(dependency_on_unit_never_type_fallback)]
1169    async fn test_observe_room_events() -> crate::Result<()> {
1170        let client = logged_in_client(None).await;
1171
1172        let room_id = room_id!("!r0.matrix.org");
1173
1174        let observable_for_room =
1175            client.observe_room_events::<OriginalSyncRoomNameEvent, (Room, Client)>(room_id);
1176
1177        let mut subscriber_for_room = observable_for_room.subscribe();
1178
1179        assert_pending!(subscriber_for_room);
1180
1181        let mut response_builder = SyncResponseBuilder::new();
1182        let response = response_builder
1183            .add_joined_room(JoinedRoomBuilder::new(room_id).add_state_event(
1184                StateTestEvent::Custom(json!({
1185                    "content": {
1186                        "name": "Name 0"
1187                    },
1188                    "event_id": "$ev0",
1189                    "origin_server_ts": 1,
1190                    "sender": "@mnt_io:matrix.org",
1191                    "state_key": "",
1192                    "type": "m.room.name",
1193                    "unsigned": {
1194                        "age": 1,
1195                    }
1196                })),
1197            ))
1198            .build_sync_response();
1199        client.process_sync(response).await?;
1200
1201        let (room_name, (room, _client)) = assert_ready!(subscriber_for_room);
1202
1203        assert_eq!(room_name.event_id.as_str(), "$ev0");
1204        assert_eq!(room.name().unwrap(), "Name 0");
1205
1206        assert_pending!(subscriber_for_room);
1207
1208        let response = response_builder
1209            .add_joined_room(JoinedRoomBuilder::new(room_id).add_state_event(
1210                StateTestEvent::Custom(json!({
1211                    "content": {
1212                        "name": "Name 1"
1213                    },
1214                    "event_id": "$ev1",
1215                    "origin_server_ts": 2,
1216                    "sender": "@mnt_io:matrix.org",
1217                    "state_key": "",
1218                    "type": "m.room.name",
1219                    "unsigned": {
1220                        "age": 2,
1221                    }
1222                })),
1223            ))
1224            .build_sync_response();
1225        client.process_sync(response).await?;
1226
1227        let (room_name, (room, _client)) = assert_ready!(subscriber_for_room);
1228
1229        assert_eq!(room_name.event_id.as_str(), "$ev1");
1230        assert_eq!(room.name().unwrap(), "Name 1");
1231
1232        assert_pending!(subscriber_for_room);
1233
1234        drop(observable_for_room);
1235        assert_closed!(subscriber_for_room);
1236
1237        Ok(())
1238    }
1239
1240    #[async_test]
1241    async fn test_observe_several_room_events() -> crate::Result<()> {
1242        let client = logged_in_client(None).await;
1243
1244        let room_id = room_id!("!r0.matrix.org");
1245
1246        let observable_for_room =
1247            client.observe_room_events::<OriginalSyncRoomNameEvent, (Room, Client)>(room_id);
1248
1249        let mut subscriber_for_room = observable_for_room.subscribe();
1250
1251        assert_pending!(subscriber_for_room);
1252
1253        let mut response_builder = SyncResponseBuilder::new();
1254        let response = response_builder
1255            .add_joined_room(
1256                JoinedRoomBuilder::new(room_id)
1257                    .add_state_event(StateTestEvent::Custom(json!({
1258                        "content": {
1259                            "name": "Name 0"
1260                        },
1261                        "event_id": "$ev0",
1262                        "origin_server_ts": 1,
1263                        "sender": "@mnt_io:matrix.org",
1264                        "state_key": "",
1265                        "type": "m.room.name",
1266                        "unsigned": {
1267                            "age": 1,
1268                        }
1269                    })))
1270                    .add_state_event(StateTestEvent::Custom(json!({
1271                        "content": {
1272                            "name": "Name 1"
1273                        },
1274                        "event_id": "$ev1",
1275                        "origin_server_ts": 2,
1276                        "sender": "@mnt_io:matrix.org",
1277                        "state_key": "",
1278                        "type": "m.room.name",
1279                        "unsigned": {
1280                            "age": 1,
1281                        }
1282                    })))
1283                    .add_state_event(StateTestEvent::Custom(json!({
1284                        "content": {
1285                            "name": "Name 2"
1286                        },
1287                        "event_id": "$ev2",
1288                        "origin_server_ts": 3,
1289                        "sender": "@mnt_io:matrix.org",
1290                        "state_key": "",
1291                        "type": "m.room.name",
1292                        "unsigned": {
1293                            "age": 1,
1294                        }
1295                    }))),
1296            )
1297            .build_sync_response();
1298        client.process_sync(response).await?;
1299
1300        let (room_name, (room, _client)) = assert_ready!(subscriber_for_room);
1301
1302        // Check we only get notified about the latest received event
1303        assert_eq!(room_name.event_id.as_str(), "$ev2");
1304        assert_eq!(room.name().unwrap(), "Name 2");
1305
1306        assert_pending!(subscriber_for_room);
1307
1308        drop(observable_for_room);
1309        assert_closed!(subscriber_for_room);
1310
1311        Ok(())
1312    }
1313}