matrix_sdk/event_cache/
paginator.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The paginator is a stateful helper object that handles reaching an event,
16//! either from a cache or network, and surrounding events ("context"). Then, it
17//! makes it possible to paginate forward or backward, from that event, until
18//! one end of the timeline (front or back) is reached.
19
20use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm};
24use ruma::{api::Direction, EventId, OwnedEventId, UInt};
25
26use super::pagination::PaginationToken;
27use crate::{
28    room::{EventWithContextResponse, Messages, MessagesOptions, WeakRoom},
29    Room,
30};
31
/// Current state of a [`Paginator`].
///
/// All variants are fieldless, so the type is cheaply `Copy` and has a total
/// equality (`Eq`).
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// The initial state of the paginator.
    Initial,

    /// The paginator is fetching the target initial event.
    FetchingTargetEvent,

    /// The target initial event could be found, zero or more paginations have
    /// happened since then, and the paginator is at rest now.
    Idle,

    /// The paginator is… paginating one direction or another.
    Paginating,
}
49
50/// An error that happened when using a [`Paginator`].
51#[derive(Debug, thiserror::Error)]
52pub enum PaginatorError {
53    /// The target event could not be found.
54    #[error("target event with id {0} could not be found")]
55    EventNotFound(OwnedEventId),
56
57    /// We're trying to manipulate the paginator in the wrong state.
58    #[error("expected paginator state {expected:?}, observed {actual:?}")]
59    InvalidPreviousState {
60        /// The state we were expecting to see.
61        expected: PaginatorState,
62        /// The actual state when doing the check.
63        actual: PaginatorState,
64    },
65
66    /// There was another SDK error while paginating.
67    #[error("an error happened while paginating: {0}")]
68    SdkError(#[from] Box<crate::Error>),
69}
70
71/// Paginations tokens used for backward and forward pagination.
72#[derive(Debug)]
73struct PaginationTokens {
74    /// Pagination token used for backward pagination.
75    previous: PaginationToken,
76    /// Pagination token used for forward pagination.
77    next: PaginationToken,
78}
79
80/// A stateful object to reach to an event, and then paginate backward and
81/// forward from it.
82///
83/// See also the module-level documentation.
84pub struct Paginator<PR: PaginableRoom> {
85    /// The room in which we're going to run the pagination.
86    room: PR,
87
88    /// Current state of the paginator.
89    state: SharedObservable<PaginatorState>,
90
91    /// Pagination tokens used for subsequent requests.
92    ///
93    /// This mutex is always short-lived, so it's sync.
94    tokens: Mutex<PaginationTokens>,
95}
96
97#[cfg(not(tarpaulin_include))]
98impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        // Don't include the room in the debug output.
101        f.debug_struct("Paginator")
102            .field("state", &self.state.get())
103            .field("tokens", &self.tokens)
104            .finish_non_exhaustive()
105    }
106}
107
108/// The result of a single pagination, be it from
109/// [`Paginator::paginate_backward`] or [`Paginator::paginate_forward`].
110#[derive(Debug)]
111pub struct PaginationResult {
112    /// Events returned during this pagination.
113    ///
114    /// If this is the result of a backward pagination, then the events are in
115    /// reverse topological order.
116    ///
117    /// If this is the result of a forward pagination, then the events are in
118    /// topological order.
119    pub events: Vec<TimelineEvent>,
120
121    /// Did we hit *an* end of the timeline?
122    ///
123    /// If this is the result of a backward pagination, this means we hit the
124    /// *start* of the timeline.
125    ///
126    /// If this is the result of a forward pagination, this means we hit the
127    /// *end* of the timeline.
128    pub hit_end_of_timeline: bool,
129}
130
131/// The result of an initial [`Paginator::start_from`] query.
132#[derive(Debug)]
133pub struct StartFromResult {
134    /// All the events returned during this pagination, in topological ordering.
135    pub events: Vec<TimelineEvent>,
136
137    /// Whether the /context query returned a previous batch token.
138    pub has_prev: bool,
139
140    /// Whether the /context query returned a next batch token.
141    pub has_next: bool,
142}
143
144/// Reset the state to a given target on drop.
145struct ResetStateGuard {
146    target: Option<PaginatorState>,
147    state: SharedObservable<PaginatorState>,
148}
149
150impl ResetStateGuard {
151    /// Create a new reset state guard.
152    fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
153        Self { target: Some(target), state }
154    }
155
156    /// Render the guard effectless, and consume it.
157    fn disarm(mut self) {
158        self.target = None;
159    }
160}
161
162impl Drop for ResetStateGuard {
163    fn drop(&mut self) {
164        if let Some(target) = self.target.take() {
165            self.state.set_if_not_eq(target);
166        }
167    }
168}
169
170impl<PR: PaginableRoom> Paginator<PR> {
171    /// Create a new [`Paginator`], given a room implementation.
172    pub fn new(room: PR) -> Self {
173        Self {
174            room,
175            state: SharedObservable::new(PaginatorState::Initial),
176            tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
177        }
178    }
179
180    /// Check if the current state of the paginator matches the expected one.
181    fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
182        let actual = self.state.get();
183        if actual != expected {
184            Err(PaginatorError::InvalidPreviousState { expected, actual })
185        } else {
186            Ok(())
187        }
188    }
189
190    /// Returns a subscriber to the internal [`PaginatorState`] machine.
191    pub fn state(&self) -> Subscriber<PaginatorState> {
192        self.state.subscribe()
193    }
194
195    /// Prepares the paginator to be in the idle state, ready for backwards- and
196    /// forwards- pagination.
197    ///
198    /// Will return an `InvalidPreviousState` error if the paginator is busy
199    /// (running /context or /messages).
200    pub(super) fn set_idle_state(
201        &self,
202        next_state: PaginatorState,
203        prev_batch_token: Option<String>,
204        next_batch_token: Option<String>,
205    ) -> Result<(), PaginatorError> {
206        let prev_state = self.state.get();
207
208        match next_state {
209            PaginatorState::Initial | PaginatorState::Idle => {}
210            PaginatorState::FetchingTargetEvent | PaginatorState::Paginating => {
211                panic!("internal error: set_idle_state only accept Initial|Idle next states");
212            }
213        }
214
215        match prev_state {
216            PaginatorState::Initial | PaginatorState::Idle => {}
217            PaginatorState::FetchingTargetEvent | PaginatorState::Paginating => {
218                // The paginator was busy. Don't interrupt it.
219                return Err(PaginatorError::InvalidPreviousState {
220                    // Technically it's initial OR idle, but we don't really care here.
221                    expected: PaginatorState::Idle,
222                    actual: prev_state,
223                });
224            }
225        }
226
227        self.state.set_if_not_eq(next_state);
228
229        {
230            let mut tokens = self.tokens.lock().unwrap();
231            tokens.previous = prev_batch_token.into();
232            tokens.next = next_batch_token.into();
233        }
234
235        Ok(())
236    }
237
238    /// Returns the current previous batch token, as stored in this paginator.
239    pub(super) fn prev_batch_token(&self) -> Option<String> {
240        match &self.tokens.lock().unwrap().previous {
241            PaginationToken::HitEnd | PaginationToken::None => None,
242            PaginationToken::HasMore(token) => Some(token.clone()),
243        }
244    }
245
246    /// Starts the pagination from the initial event, requesting `num_events`
247    /// additional context events.
248    ///
249    /// Only works for fresh [`Paginator`] objects, which are in the
250    /// [`PaginatorState::Initial`] state.
251    pub async fn start_from(
252        &self,
253        event_id: &EventId,
254        num_events: UInt,
255    ) -> Result<StartFromResult, PaginatorError> {
256        self.check_state(PaginatorState::Initial)?;
257
258        // Note: it's possible two callers have checked the state and both figured it's
259        // initial. This check below makes sure there's at most one which can set the
260        // state to FetchingTargetEvent, preventing a race condition.
261        if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
262            return Err(PaginatorError::InvalidPreviousState {
263                expected: PaginatorState::Initial,
264                actual: PaginatorState::FetchingTargetEvent,
265            });
266        }
267
268        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
269
270        // TODO: do we want to lazy load members?
271        let lazy_load_members = true;
272
273        let response =
274            self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
275
276        // NOTE: it's super important to not have any `await` after this point, since we
277        // don't want the task to be interrupted anymore, or the internal state
278        // may become incorrect.
279
280        let has_prev = response.prev_batch_token.is_some();
281        let has_next = response.next_batch_token.is_some();
282
283        {
284            let mut tokens = self.tokens.lock().unwrap();
285            tokens.previous = match response.prev_batch_token {
286                Some(token) => PaginationToken::HasMore(token),
287                None => PaginationToken::HitEnd,
288            };
289            tokens.next = match response.next_batch_token {
290                Some(token) => PaginationToken::HasMore(token),
291                None => PaginationToken::HitEnd,
292            };
293        }
294
295        // Forget the reset state guard, so its Drop method is not called.
296        reset_state_guard.disarm();
297        // And set the final state.
298        self.state.set(PaginatorState::Idle);
299
300        // Consolidate the events into a linear timeline, topologically ordered.
301        // - the events before are returned in the reverse topological order: invert
302        //   them.
303        // - insert the target event, if set.
304        // - the events after are returned in the correct topological order.
305
306        let events = response
307            .events_before
308            .into_iter()
309            .rev()
310            .chain(response.event)
311            .chain(response.events_after)
312            .collect();
313
314        Ok(StartFromResult { events, has_prev, has_next })
315    }
316
317    /// Runs a backward pagination (requesting `num_events` to the server), from
318    /// the current state of the object.
319    ///
320    /// Will return immediately if we have already hit the start of the
321    /// timeline.
322    ///
323    /// May return an error if it's already paginating, or if the call to
324    /// /messages failed.
325    pub async fn paginate_backward(
326        &self,
327        num_events: UInt,
328    ) -> Result<PaginationResult, PaginatorError> {
329        self.paginate(Direction::Backward, num_events).await
330    }
331
332    /// Returns whether we've hit the start of the timeline.
333    ///
334    /// This is true if, and only if, we didn't have a previous-batch token and
335    /// running backwards pagination would be useless.
336    pub fn hit_timeline_start(&self) -> bool {
337        matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
338    }
339
340    /// Returns whether we've hit the end of the timeline.
341    ///
342    /// This is true if, and only if, we didn't have a next-batch token and
343    /// running forwards pagination would be useless.
344    pub fn hit_timeline_end(&self) -> bool {
345        matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
346    }
347
348    /// Runs a forward pagination (requesting `num_events` to the server), from
349    /// the current state of the object.
350    ///
351    /// Will return immediately if we have already hit the end of the timeline.
352    ///
353    /// May return an error if it's already paginating, or if the call to
354    /// /messages failed.
355    pub async fn paginate_forward(
356        &self,
357        num_events: UInt,
358    ) -> Result<PaginationResult, PaginatorError> {
359        self.paginate(Direction::Forward, num_events).await
360    }
361
362    /// Paginate in the given direction, requesting `num_events` events to the
363    /// server, using the `token_lock` to read from and write the pagination
364    /// token.
365    async fn paginate(
366        &self,
367        dir: Direction,
368        num_events: UInt,
369    ) -> Result<PaginationResult, PaginatorError> {
370        self.check_state(PaginatorState::Idle)?;
371
372        let token = {
373            let tokens = self.tokens.lock().unwrap();
374
375            let token = match dir {
376                Direction::Backward => &tokens.previous,
377                Direction::Forward => &tokens.next,
378            };
379
380            match token {
381                PaginationToken::None => None,
382                PaginationToken::HasMore(val) => Some(val.clone()),
383                PaginationToken::HitEnd => {
384                    return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
385                }
386            }
387        };
388
389        // Note: it's possible two callers have checked the state and both figured it's
390        // idle. This check below makes sure there's at most one which can set the
391        // state to paginating, preventing a race condition.
392        if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
393            return Err(PaginatorError::InvalidPreviousState {
394                expected: PaginatorState::Idle,
395                actual: PaginatorState::Paginating,
396            });
397        }
398
399        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
400
401        let mut options = MessagesOptions::new(dir).from(token.as_deref());
402        options.limit = num_events;
403
404        // In case of error, the state is reset to idle automatically thanks to
405        // reset_state_guard.
406        let response = self.room.messages(options).await?;
407
408        // NOTE: it's super important to not have any `await` after this point, since we
409        // don't want the task to be interrupted anymore, or the internal state
410        // may be incorrect.
411
412        let hit_end_of_timeline = response.end.is_none();
413
414        {
415            let mut tokens = self.tokens.lock().unwrap();
416
417            let token = match dir {
418                Direction::Backward => &mut tokens.previous,
419                Direction::Forward => &mut tokens.next,
420            };
421
422            *token = match response.end {
423                Some(val) => PaginationToken::HasMore(val),
424                None => PaginationToken::HitEnd,
425            };
426        }
427
428        // TODO: what to do with state events?
429
430        // Forget the reset state guard, so its Drop method is not called.
431        reset_state_guard.disarm();
432        // And set the final state.
433        self.state.set(PaginatorState::Idle);
434
435        Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
436    }
437}
438
439/// A room that can be paginated.
440///
441/// Not [`crate::Room`] because we may want to paginate rooms we don't belong
442/// to.
443pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
444    /// Runs a /context query for the given room.
445    ///
446    /// ## Parameters
447    ///
448    /// - `event_id` is the identifier of the target event.
449    /// - `lazy_load_members` controls whether room membership events are lazily
450    ///   loaded as context state events.
451    /// - `num_events` is the number of events (including the fetched event) to
452    ///   return as context.
453    ///
454    /// ## Returns
455    ///
456    /// Must return [`PaginatorError::EventNotFound`] whenever the target event
457    /// could not be found, instead of causing an http `Err` result.
458    fn event_with_context(
459        &self,
460        event_id: &EventId,
461        lazy_load_members: bool,
462        num_events: UInt,
463    ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
464
465    /// Runs a /messages query for the given room.
466    fn messages(
467        &self,
468        opts: MessagesOptions,
469    ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
470}
471
472impl PaginableRoom for Room {
473    async fn event_with_context(
474        &self,
475        event_id: &EventId,
476        lazy_load_members: bool,
477        num_events: UInt,
478    ) -> Result<EventWithContextResponse, PaginatorError> {
479        let response =
480            match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
481                Ok(result) => result,
482
483                Err(err) => {
484                    // If the error was a 404, then the event wasn't found on the server;
485                    // special case this to make it easy to react to
486                    // such an error.
487                    if let Some(error) = err.as_client_api_error() {
488                        if error.status_code == 404 {
489                            // Event not found
490                            return Err(PaginatorError::EventNotFound(event_id.to_owned()));
491                        }
492                    }
493
494                    // Otherwise, just return a wrapped error.
495                    return Err(PaginatorError::SdkError(Box::new(err)));
496                }
497            };
498
499        Ok(response)
500    }
501
502    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
503        self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
504    }
505}
506
507impl PaginableRoom for WeakRoom {
508    async fn event_with_context(
509        &self,
510        event_id: &EventId,
511        lazy_load_members: bool,
512        num_events: UInt,
513    ) -> Result<EventWithContextResponse, PaginatorError> {
514        let Some(room) = self.get() else {
515            // Client is shutting down, return a default response.
516            return Ok(EventWithContextResponse::default());
517        };
518
519        PaginableRoom::event_with_context(&room, event_id, lazy_load_members, num_events).await
520    }
521
522    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
523        let Some(room) = self.get() else {
524            // Client is shutting down, return a default response.
525            return Ok(Messages::default());
526        };
527
528        PaginableRoom::messages(&room, opts).await
529    }
530}
531
532#[cfg(all(not(target_arch = "wasm32"), test))]
533mod tests {
534    use std::sync::Arc;
535
536    use assert_matches2::assert_let;
537    use futures_core::Future;
538    use futures_util::FutureExt as _;
539    use matrix_sdk_base::deserialized_responses::TimelineEvent;
540    use matrix_sdk_test::{async_test, event_factory::EventFactory};
541    use once_cell::sync::Lazy;
542    use ruma::{api::Direction, event_id, room_id, uint, user_id, EventId, RoomId, UInt, UserId};
543    use tokio::{
544        spawn,
545        sync::{Mutex, Notify},
546        task::AbortHandle,
547    };
548
549    use super::{PaginableRoom, PaginatorError, PaginatorState};
550    use crate::{
551        event_cache::paginator::Paginator,
552        room::{EventWithContextResponse, Messages, MessagesOptions},
553        test_utils::assert_event_matches_msg,
554    };
555
556    #[derive(Clone)]
557    struct TestRoom {
558        event_factory: Arc<EventFactory>,
559        wait_for_ready: bool,
560
561        target_event_text: Arc<Mutex<String>>,
562        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
563        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
564        prev_batch_token: Arc<Mutex<Option<String>>>,
565        next_batch_token: Arc<Mutex<Option<String>>>,
566
567        room_ready: Arc<Notify>,
568    }
569
570    impl TestRoom {
571        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
572            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));
573
574            Self {
575                event_factory,
576                wait_for_ready,
577
578                room_ready: Default::default(),
579                target_event_text: Default::default(),
580                next_events: Default::default(),
581                prev_events: Default::default(),
582                prev_batch_token: Default::default(),
583                next_batch_token: Default::default(),
584            }
585        }
586
587        /// Unblocks the next request.
588        fn mark_ready(&self) {
589            self.room_ready.notify_one();
590        }
591    }
592
593    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
594    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));
595
596    impl PaginableRoom for TestRoom {
597        async fn event_with_context(
598            &self,
599            event_id: &EventId,
600            _lazy_load_members: bool,
601            num_events: UInt,
602        ) -> Result<EventWithContextResponse, PaginatorError> {
603            // Wait for the room to be marked as ready first.
604            if self.wait_for_ready {
605                self.room_ready.notified().await;
606            }
607
608            let event = self
609                .event_factory
610                .text_msg(self.target_event_text.lock().await.clone())
611                .event_id(event_id)
612                .into_event();
613
614            // Properly simulate `num_events`: take either the closest num_events events
615            // before, or use all of the before events and then consume after events.
616            let mut num_events = u64::from(num_events) as usize;
617
618            let prev_events = self.prev_events.lock().await;
619
620            let events_before = if prev_events.is_empty() {
621                Vec::new()
622            } else {
623                let len = prev_events.len();
624                let take_before = num_events.min(len);
625                // Subtract is safe because take_before <= num_events.
626                num_events -= take_before;
627                // Subtract is safe because take_before <= len
628                prev_events[len - take_before..len].to_vec()
629            };
630
631            let events_after = self.next_events.lock().await;
632            let events_after = if events_after.is_empty() {
633                Vec::new()
634            } else {
635                events_after[0..num_events.min(events_after.len())].to_vec()
636            };
637
638            Ok(EventWithContextResponse {
639                event: Some(event),
640                events_before,
641                events_after,
642                prev_batch_token: self.prev_batch_token.lock().await.clone(),
643                next_batch_token: self.next_batch_token.lock().await.clone(),
644                state: Vec::new(),
645            })
646        }
647
648        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
649            if self.wait_for_ready {
650                self.room_ready.notified().await;
651            }
652
653            let limit = u64::from(opts.limit) as usize;
654
655            let (end, events) = match opts.dir {
656                Direction::Backward => {
657                    let events = self.prev_events.lock().await;
658                    let events = if events.is_empty() {
659                        Vec::new()
660                    } else {
661                        let len = events.len();
662                        let take_before = limit.min(len);
663                        // Subtract is safe because take_before <= len
664                        events[len - take_before..len].to_vec()
665                    };
666                    (self.prev_batch_token.lock().await.clone(), events)
667                }
668
669                Direction::Forward => {
670                    let events = self.next_events.lock().await;
671                    let events = if events.is_empty() {
672                        Vec::new()
673                    } else {
674                        events[0..limit.min(events.len())].to_vec()
675                    };
676                    (self.next_batch_token.lock().await.clone(), events)
677                }
678            };
679
680            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
681        }
682    }
683
684    async fn assert_invalid_state<T: std::fmt::Debug>(
685        task: impl Future<Output = Result<T, PaginatorError>>,
686        expected: PaginatorState,
687        actual: PaginatorState,
688    ) {
689        assert_let!(
690            Err(PaginatorError::InvalidPreviousState {
691                expected: real_expected,
692                actual: real_actual
693            }) = task.await
694        );
695        assert_eq!(real_expected, expected);
696        assert_eq!(real_actual, actual);
697    }
698
699    #[async_test]
700    async fn test_start_from() {
701        // Prepare test data.
702        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
703
704        let event_id = event_id!("$yoyoyo");
705        let event_factory = &room.event_factory;
706
707        *room.target_event_text.lock().await = "fetch_from".to_owned();
708        *room.prev_events.lock().await = (0..10)
709            .rev()
710            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
711            .collect();
712        *room.next_events.lock().await =
713            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();
714
715        // When I call `Paginator::start_from`, it works,
716        let paginator = Arc::new(Paginator::new(room.clone()));
717        let context =
718            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
719
720        assert!(!context.has_prev);
721        assert!(!context.has_next);
722
723        // And I get the events I expected.
724
725        // 10 events before, the target event, 10 events after.
726        assert_eq!(context.events.len(), 21);
727
728        for i in 0..10 {
729            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
730        }
731
732        assert_event_matches_msg(&context.events[10], "fetch_from");
733        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);
734
735        for i in 0..10 {
736            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
737        }
738    }
739
740    #[async_test]
741    async fn test_start_from_with_num_events() {
742        // Prepare test data.
743        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
744
745        let event_id = event_id!("$yoyoyo");
746        let event_factory = &room.event_factory;
747
748        *room.target_event_text.lock().await = "fetch_from".to_owned();
749        *room.prev_events.lock().await =
750            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();
751
752        // When I call `Paginator::start_from`, it works,
753        let paginator = Arc::new(Paginator::new(room.clone()));
754        let context =
755            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");
756
757        // Then I only get 10 events + the target event, even if there was more than 10
758        // events in the room.
759        assert_eq!(context.events.len(), 11);
760
761        for i in 0..10 {
762            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
763        }
764        assert_event_matches_msg(&context.events[10], "fetch_from");
765    }
766
767    #[async_test]
768    async fn test_paginate_backward() {
769        // Prepare test data.
770        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
771
772        let event_id = event_id!("$yoyoyo");
773        let event_factory = &room.event_factory;
774
775        *room.target_event_text.lock().await = "initial".to_owned();
776        *room.prev_batch_token.lock().await = Some("prev".to_owned());
777
778        // When I call `Paginator::start_from`, it works,
779        let paginator = Arc::new(Paginator::new(room.clone()));
780
781        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
782        assert!(
783            !paginator.hit_timeline_end(),
784            "we don't know about the status of the next-batch token"
785        );
786
787        let context =
788            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
789
790        // And I get the events I expected.
791        assert_eq!(context.events.len(), 1);
792        assert_event_matches_msg(&context.events[0], "initial");
793        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
794
795        // There's a previous batch, but no next batch.
796        assert!(context.has_prev);
797        assert!(!context.has_next);
798
799        assert!(!paginator.hit_timeline_start());
800        assert!(paginator.hit_timeline_end());
801
802        // Preparing data for the next back-pagination.
803        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
804        *room.prev_batch_token.lock().await = Some("prev2".to_owned());
805
806        // When I backpaginate, I get the events I expect.
807        let prev =
808            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
809        assert!(!prev.hit_end_of_timeline);
810        assert!(!paginator.hit_timeline_start());
811        assert_eq!(prev.events.len(), 1);
812        assert_event_matches_msg(&prev.events[0], "previous");
813
814        // And I can backpaginate again, because there's a prev batch token
815        // still.
816        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
817        *room.prev_batch_token.lock().await = None;
818
819        let prev = paginator
820            .paginate_backward(uint!(100))
821            .await
822            .expect("paginate backward the second time should work");
823        assert!(prev.hit_end_of_timeline);
824        assert!(paginator.hit_timeline_start());
825        assert_eq!(prev.events.len(), 1);
826        assert_event_matches_msg(&prev.events[0], "oldest");
827
828        // I've hit the start of the timeline, but back-paginating again will
829        // return immediately.
830        let prev = paginator
831            .paginate_backward(uint!(100))
832            .await
833            .expect("paginate backward the third time should work");
834        assert!(prev.hit_end_of_timeline);
835        assert!(paginator.hit_timeline_start());
836        assert!(prev.events.is_empty());
837    }
838
839    #[async_test]
840    async fn test_paginate_backward_with_limit() {
841        // Prepare test data.
842        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
843
844        let event_id = event_id!("$yoyoyo");
845        let event_factory = &room.event_factory;
846
847        *room.target_event_text.lock().await = "initial".to_owned();
848        *room.prev_batch_token.lock().await = Some("prev".to_owned());
849
850        // When I call `Paginator::start_from`, it works,
851        let paginator = Arc::new(Paginator::new(room.clone()));
852        let context =
853            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
854
855        // And I get the events I expected.
856        assert_eq!(context.events.len(), 1);
857        assert_event_matches_msg(&context.events[0], "initial");
858        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
859
860        // There's a previous batch.
861        assert!(context.has_prev);
862        assert!(!context.has_next);
863
864        // Preparing data for the next back-pagination.
865        *room.prev_events.lock().await = (0..100)
866            .rev()
867            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
868            .collect();
869        *room.prev_batch_token.lock().await = None;
870
871        // When I backpaginate and request 100 events, I get only 10 events.
872        let prev =
873            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
874        assert!(prev.hit_end_of_timeline);
875        assert_eq!(prev.events.len(), 10);
876        for i in 0..10 {
877            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
878        }
879    }
880
881    #[async_test]
882    async fn test_paginate_forward() {
883        // Prepare test data.
884        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
885
886        let event_id = event_id!("$yoyoyo");
887        let event_factory = &room.event_factory;
888
889        *room.target_event_text.lock().await = "initial".to_owned();
890        *room.next_batch_token.lock().await = Some("next".to_owned());
891
892        // When I call `Paginator::start_from`, it works,
893        let paginator = Arc::new(Paginator::new(room.clone()));
894        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
895        assert!(
896            !paginator.hit_timeline_start(),
897            "we don't know about the status of the prev-batch token"
898        );
899
900        let context =
901            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
902
903        // And I get the events I expected.
904        assert_eq!(context.events.len(), 1);
905        assert_event_matches_msg(&context.events[0], "initial");
906        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
907
908        // There's a next batch, but no previous batch (i.e. we've hit the start of the
909        // timeline).
910        assert!(!context.has_prev);
911        assert!(context.has_next);
912
913        assert!(paginator.hit_timeline_start());
914        assert!(!paginator.hit_timeline_end());
915
916        // Preparing data for the next forward-pagination.
917        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
918        *room.next_batch_token.lock().await = Some("next2".to_owned());
919
920        // When I forward-paginate, I get the events I expect.
921        let next =
922            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
923        assert!(!next.hit_end_of_timeline);
924        assert_eq!(next.events.len(), 1);
925        assert_event_matches_msg(&next.events[0], "next");
926        assert!(!paginator.hit_timeline_end());
927
928        // And I can forward-paginate again, because there's a prev batch token
929        // still.
930        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
931        *room.next_batch_token.lock().await = None;
932
933        let next = paginator
934            .paginate_forward(uint!(100))
935            .await
936            .expect("paginate forward the second time should work");
937        assert!(next.hit_end_of_timeline);
938        assert_eq!(next.events.len(), 1);
939        assert_event_matches_msg(&next.events[0], "latest");
940        assert!(paginator.hit_timeline_end());
941
942        // I've hit the start of the timeline, but back-paginating again will
943        // return immediately.
944        let next = paginator
945            .paginate_forward(uint!(100))
946            .await
947            .expect("paginate forward the third time should work");
948        assert!(next.hit_end_of_timeline);
949        assert!(next.events.is_empty());
950        assert!(paginator.hit_timeline_end());
951    }
952
953    #[async_test]
954    async fn test_state() {
955        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);
956
957        *room.prev_batch_token.lock().await = Some("prev".to_owned());
958        *room.next_batch_token.lock().await = Some("next".to_owned());
959
960        let paginator = Arc::new(Paginator::new(room.clone()));
961
962        let event_id = event_id!("$yoyoyo");
963
964        let mut state = paginator.state();
965
966        assert_eq!(state.get(), PaginatorState::Initial);
967        assert!(state.next().now_or_never().is_none());
968
969        // Attempting to run pagination must fail and not change the state.
970        assert_invalid_state(
971            paginator.paginate_backward(uint!(100)),
972            PaginatorState::Idle,
973            PaginatorState::Initial,
974        )
975        .await;
976
977        assert!(state.next().now_or_never().is_none());
978
979        // Running the initial query must work.
980        let p = paginator.clone();
981        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });
982
983        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
984        assert!(state.next().now_or_never().is_none());
985
986        // The query is pending. Running other operations must fail.
987        assert_invalid_state(
988            paginator.start_from(event_id, uint!(100)),
989            PaginatorState::Initial,
990            PaginatorState::FetchingTargetEvent,
991        )
992        .await;
993
994        assert_invalid_state(
995            paginator.paginate_backward(uint!(100)),
996            PaginatorState::Idle,
997            PaginatorState::FetchingTargetEvent,
998        )
999        .await;
1000
1001        assert!(state.next().now_or_never().is_none());
1002
1003        // Mark the dummy room as ready. The query may now terminate.
1004        room.mark_ready();
1005
1006        // After fetching the initial event data, the paginator switches to `Idle`.
1007        assert_eq!(state.next().await, Some(PaginatorState::Idle));
1008
1009        join_handle.await.expect("joined failed").expect("/context failed");
1010
1011        assert!(state.next().now_or_never().is_none());
1012
1013        let p = paginator.clone();
1014        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });
1015
1016        assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1017
1018        // The query is pending. Running other operations must fail.
1019        assert_invalid_state(
1020            paginator.start_from(event_id, uint!(100)),
1021            PaginatorState::Initial,
1022            PaginatorState::Paginating,
1023        )
1024        .await;
1025
1026        assert_invalid_state(
1027            paginator.paginate_backward(uint!(100)),
1028            PaginatorState::Idle,
1029            PaginatorState::Paginating,
1030        )
1031        .await;
1032
1033        assert_invalid_state(
1034            paginator.paginate_forward(uint!(100)),
1035            PaginatorState::Idle,
1036            PaginatorState::Paginating,
1037        )
1038        .await;
1039
1040        assert!(state.next().now_or_never().is_none());
1041
1042        room.mark_ready();
1043
1044        assert_eq!(state.next().await, Some(PaginatorState::Idle));
1045
1046        join_handle.await.expect("joined failed").expect("/messages failed");
1047
1048        assert!(state.next().now_or_never().is_none());
1049    }
1050
1051    mod aborts {
1052        use super::*;
1053
1054        #[derive(Clone, Default)]
1055        struct AbortingRoom {
1056            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
1057            room_ready: Arc<Notify>,
1058        }
1059
1060        impl AbortingRoom {
1061            async fn wait_abort_and_yield(&self) -> ! {
1062                // Wait for the controller to tell us we're ready.
1063                self.room_ready.notified().await;
1064
1065                // Abort the given handle.
1066                let mut guard = self.abort_handle.lock().await;
1067                let handle = guard.take().expect("only call me when i'm initialized");
1068                handle.abort();
1069
1070                // Enter an endless loop of yielding.
1071                loop {
1072                    tokio::task::yield_now().await;
1073                }
1074            }
1075        }
1076
1077        impl PaginableRoom for AbortingRoom {
1078            async fn event_with_context(
1079                &self,
1080                _event_id: &EventId,
1081                _lazy_load_members: bool,
1082                _num_events: UInt,
1083            ) -> Result<EventWithContextResponse, PaginatorError> {
1084                self.wait_abort_and_yield().await
1085            }
1086
1087            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
1088                self.wait_abort_and_yield().await
1089            }
1090        }
1091
1092        #[async_test]
1093        async fn test_abort_while_starting_from() {
1094            let room = AbortingRoom::default();
1095
1096            let paginator = Arc::new(Paginator::new(room.clone()));
1097
1098            let mut state = paginator.state();
1099
1100            assert_eq!(state.get(), PaginatorState::Initial);
1101            assert!(state.next().now_or_never().is_none());
1102
1103            // When I try to start the initial query…
1104            let p = paginator.clone();
1105            let join_handle = spawn(async move {
1106                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
1107            });
1108
1109            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1110
1111            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
1112            assert!(state.next().now_or_never().is_none());
1113
1114            room.room_ready.notify_one();
1115
1116            // But it's aborted when awaiting the task.
1117            let join_result = join_handle.await;
1118            assert!(join_result.unwrap_err().is_cancelled());
1119
1120            // Then the state is reset to initial.
1121            assert_eq!(state.next().await, Some(PaginatorState::Initial));
1122            assert!(state.next().now_or_never().is_none());
1123        }
1124
1125        #[async_test]
1126        async fn test_abort_while_paginating() {
1127            let room = AbortingRoom::default();
1128
1129            // Assuming a paginator ready to back- or forward- paginate,
1130            let paginator = Paginator::new(room.clone());
1131            paginator
1132                .set_idle_state(
1133                    PaginatorState::Idle,
1134                    Some("prev".to_owned()),
1135                    Some("next".to_owned()),
1136                )
1137                .unwrap();
1138
1139            let paginator = Arc::new(paginator);
1140
1141            let mut state = paginator.state();
1142
1143            assert_eq!(state.get(), PaginatorState::Idle);
1144            assert!(state.next().now_or_never().is_none());
1145
1146            // When I try to back-paginate…
1147            let p = paginator.clone();
1148            let join_handle = spawn(async move {
1149                let _ = p.paginate_backward(uint!(100)).await;
1150            });
1151
1152            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1153
1154            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1155            assert!(state.next().now_or_never().is_none());
1156
1157            room.room_ready.notify_one();
1158
1159            // But it's aborted when awaiting the task.
1160            let join_result = join_handle.await;
1161            assert!(join_result.unwrap_err().is_cancelled());
1162
1163            // Then the state is reset to idle.
1164            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1165            assert!(state.next().now_or_never().is_none());
1166
1167            // And ditto for forward pagination.
1168            let p = paginator.clone();
1169            let join_handle = spawn(async move {
1170                let _ = p.paginate_forward(uint!(100)).await;
1171            });
1172
1173            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1174
1175            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1176            assert!(state.next().now_or_never().is_none());
1177
1178            room.room_ready.notify_one();
1179
1180            let join_result = join_handle.await;
1181            assert!(join_result.unwrap_err().is_cancelled());
1182
1183            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1184            assert!(state.next().now_or_never().is_none());
1185        }
1186    }
1187}