matrix_sdk/event_cache/
paginator.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The paginator is a stateful helper object that handles reaching an event,
16//! either from a cache or network, and surrounding events ("context"). Then, it
17//! makes it possible to paginate forward or backward, from that event, until
18//! one end of the timeline (front or back) is reached.
19
20use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm};
24use ruma::{api::Direction, EventId, OwnedEventId, UInt};
25
26use super::pagination::PaginationToken;
27use crate::{
28    room::{EventWithContextResponse, Messages, MessagesOptions, WeakRoom},
29    Room,
30};
31
/// The lifecycle stage a [`Paginator`] is currently in.
///
/// Transitions can be observed through [`Paginator::state`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// Freshly created: no target event has been requested yet.
    Initial,

    /// The request for the target event (and its context) is in flight.
    FetchingTargetEvent,

    /// The target event has been reached, and no request is currently
    /// running, no matter how many paginations happened before.
    Idle,

    /// A backward or forward pagination request is in flight.
    Paginating,
}
49
50/// An error that happened when using a [`Paginator`].
51#[derive(Debug, thiserror::Error)]
52pub enum PaginatorError {
53    /// The target event could not be found.
54    #[error("target event with id {0} could not be found")]
55    EventNotFound(OwnedEventId),
56
57    /// We're trying to manipulate the paginator in the wrong state.
58    #[error("expected paginator state {expected:?}, observed {actual:?}")]
59    InvalidPreviousState {
60        /// The state we were expecting to see.
61        expected: PaginatorState,
62        /// The actual state when doing the check.
63        actual: PaginatorState,
64    },
65
66    /// There was another SDK error while paginating.
67    #[error("an error happened while paginating: {0}")]
68    SdkError(#[from] Box<crate::Error>),
69}
70
71/// Paginations tokens used for backward and forward pagination.
72#[derive(Debug)]
73struct PaginationTokens {
74    /// Pagination token used for backward pagination.
75    previous: PaginationToken,
76    /// Pagination token used for forward pagination.
77    next: PaginationToken,
78}
79
80/// A stateful object to reach to an event, and then paginate backward and
81/// forward from it.
82///
83/// See also the module-level documentation.
84pub struct Paginator<PR: PaginableRoom> {
85    /// The room in which we're going to run the pagination.
86    room: PR,
87
88    /// Current state of the paginator.
89    state: SharedObservable<PaginatorState>,
90
91    /// Pagination tokens used for subsequent requests.
92    ///
93    /// This mutex is always short-lived, so it's sync.
94    tokens: Mutex<PaginationTokens>,
95}
96
97#[cfg(not(tarpaulin_include))]
98impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        // Don't include the room in the debug output.
101        f.debug_struct("Paginator")
102            .field("state", &self.state.get())
103            .field("tokens", &self.tokens)
104            .finish_non_exhaustive()
105    }
106}
107
108/// The result of a single pagination, be it from
109/// [`Paginator::paginate_backward`] or [`Paginator::paginate_forward`].
110#[derive(Debug)]
111pub struct PaginationResult {
112    /// Events returned during this pagination.
113    ///
114    /// If this is the result of a backward pagination, then the events are in
115    /// reverse topological order.
116    ///
117    /// If this is the result of a forward pagination, then the events are in
118    /// topological order.
119    pub events: Vec<TimelineEvent>,
120
121    /// Did we hit *an* end of the timeline?
122    ///
123    /// If this is the result of a backward pagination, this means we hit the
124    /// *start* of the timeline.
125    ///
126    /// If this is the result of a forward pagination, this means we hit the
127    /// *end* of the timeline.
128    pub hit_end_of_timeline: bool,
129}
130
131/// The result of an initial [`Paginator::start_from`] query.
132#[derive(Debug)]
133pub struct StartFromResult {
134    /// All the events returned during this pagination, in topological ordering.
135    pub events: Vec<TimelineEvent>,
136
137    /// Whether the /context query returned a previous batch token.
138    pub has_prev: bool,
139
140    /// Whether the /context query returned a next batch token.
141    pub has_next: bool,
142}
143
144/// Reset the state to a given target on drop.
145struct ResetStateGuard {
146    target: Option<PaginatorState>,
147    state: SharedObservable<PaginatorState>,
148}
149
150impl ResetStateGuard {
151    /// Create a new reset state guard.
152    fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
153        Self { target: Some(target), state }
154    }
155
156    /// Render the guard effectless, and consume it.
157    fn disarm(mut self) {
158        self.target = None;
159    }
160}
161
162impl Drop for ResetStateGuard {
163    fn drop(&mut self) {
164        if let Some(target) = self.target.take() {
165            self.state.set_if_not_eq(target);
166        }
167    }
168}
169
170impl<PR: PaginableRoom> Paginator<PR> {
171    /// Create a new [`Paginator`], given a room implementation.
172    pub fn new(room: PR) -> Self {
173        Self {
174            room,
175            state: SharedObservable::new(PaginatorState::Initial),
176            tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
177        }
178    }
179
180    /// Check if the current state of the paginator matches the expected one.
181    fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
182        let actual = self.state.get();
183        if actual != expected {
184            Err(PaginatorError::InvalidPreviousState { expected, actual })
185        } else {
186            Ok(())
187        }
188    }
189
190    /// Returns a subscriber to the internal [`PaginatorState`] machine.
191    pub fn state(&self) -> Subscriber<PaginatorState> {
192        self.state.subscribe()
193    }
194
195    /// Starts the pagination from the initial event, requesting `num_events`
196    /// additional context events.
197    ///
198    /// Only works for fresh [`Paginator`] objects, which are in the
199    /// [`PaginatorState::Initial`] state.
200    pub async fn start_from(
201        &self,
202        event_id: &EventId,
203        num_events: UInt,
204    ) -> Result<StartFromResult, PaginatorError> {
205        self.check_state(PaginatorState::Initial)?;
206
207        // Note: it's possible two callers have checked the state and both figured it's
208        // initial. This check below makes sure there's at most one which can set the
209        // state to FetchingTargetEvent, preventing a race condition.
210        if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
211            return Err(PaginatorError::InvalidPreviousState {
212                expected: PaginatorState::Initial,
213                actual: PaginatorState::FetchingTargetEvent,
214            });
215        }
216
217        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
218
219        // TODO: do we want to lazy load members?
220        let lazy_load_members = true;
221
222        let response =
223            self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
224
225        // NOTE: it's super important to not have any `await` after this point, since we
226        // don't want the task to be interrupted anymore, or the internal state
227        // may become incorrect.
228
229        let has_prev = response.prev_batch_token.is_some();
230        let has_next = response.next_batch_token.is_some();
231
232        {
233            let mut tokens = self.tokens.lock().unwrap();
234            tokens.previous = match response.prev_batch_token {
235                Some(token) => PaginationToken::HasMore(token),
236                None => PaginationToken::HitEnd,
237            };
238            tokens.next = match response.next_batch_token {
239                Some(token) => PaginationToken::HasMore(token),
240                None => PaginationToken::HitEnd,
241            };
242        }
243
244        // Forget the reset state guard, so its Drop method is not called.
245        reset_state_guard.disarm();
246        // And set the final state.
247        self.state.set(PaginatorState::Idle);
248
249        // Consolidate the events into a linear timeline, topologically ordered.
250        // - the events before are returned in the reverse topological order: invert
251        //   them.
252        // - insert the target event, if set.
253        // - the events after are returned in the correct topological order.
254
255        let events = response
256            .events_before
257            .into_iter()
258            .rev()
259            .chain(response.event)
260            .chain(response.events_after)
261            .collect();
262
263        Ok(StartFromResult { events, has_prev, has_next })
264    }
265
266    /// Runs a backward pagination (requesting `num_events` to the server), from
267    /// the current state of the object.
268    ///
269    /// Will return immediately if we have already hit the start of the
270    /// timeline.
271    ///
272    /// May return an error if it's already paginating, or if the call to
273    /// /messages failed.
274    pub async fn paginate_backward(
275        &self,
276        num_events: UInt,
277    ) -> Result<PaginationResult, PaginatorError> {
278        self.paginate(Direction::Backward, num_events).await
279    }
280
281    /// Returns whether we've hit the start of the timeline.
282    ///
283    /// This is true if, and only if, we didn't have a previous-batch token and
284    /// running backwards pagination would be useless.
285    pub fn hit_timeline_start(&self) -> bool {
286        matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
287    }
288
289    /// Returns whether we've hit the end of the timeline.
290    ///
291    /// This is true if, and only if, we didn't have a next-batch token and
292    /// running forwards pagination would be useless.
293    pub fn hit_timeline_end(&self) -> bool {
294        matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
295    }
296
297    /// Runs a forward pagination (requesting `num_events` to the server), from
298    /// the current state of the object.
299    ///
300    /// Will return immediately if we have already hit the end of the timeline.
301    ///
302    /// May return an error if it's already paginating, or if the call to
303    /// /messages failed.
304    pub async fn paginate_forward(
305        &self,
306        num_events: UInt,
307    ) -> Result<PaginationResult, PaginatorError> {
308        self.paginate(Direction::Forward, num_events).await
309    }
310
311    /// Paginate in the given direction, requesting `num_events` events to the
312    /// server, using the `token_lock` to read from and write the pagination
313    /// token.
314    async fn paginate(
315        &self,
316        dir: Direction,
317        num_events: UInt,
318    ) -> Result<PaginationResult, PaginatorError> {
319        self.check_state(PaginatorState::Idle)?;
320
321        let token = {
322            let tokens = self.tokens.lock().unwrap();
323
324            let token = match dir {
325                Direction::Backward => &tokens.previous,
326                Direction::Forward => &tokens.next,
327            };
328
329            match token {
330                PaginationToken::None => None,
331                PaginationToken::HasMore(val) => Some(val.clone()),
332                PaginationToken::HitEnd => {
333                    return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
334                }
335            }
336        };
337
338        // Note: it's possible two callers have checked the state and both figured it's
339        // idle. This check below makes sure there's at most one which can set the
340        // state to paginating, preventing a race condition.
341        if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
342            return Err(PaginatorError::InvalidPreviousState {
343                expected: PaginatorState::Idle,
344                actual: PaginatorState::Paginating,
345            });
346        }
347
348        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
349
350        let mut options = MessagesOptions::new(dir).from(token.as_deref());
351        options.limit = num_events;
352
353        // In case of error, the state is reset to idle automatically thanks to
354        // reset_state_guard.
355        let response = self.room.messages(options).await?;
356
357        // NOTE: it's super important to not have any `await` after this point, since we
358        // don't want the task to be interrupted anymore, or the internal state
359        // may be incorrect.
360
361        let hit_end_of_timeline = response.end.is_none();
362
363        {
364            let mut tokens = self.tokens.lock().unwrap();
365
366            let token = match dir {
367                Direction::Backward => &mut tokens.previous,
368                Direction::Forward => &mut tokens.next,
369            };
370
371            *token = match response.end {
372                Some(val) => PaginationToken::HasMore(val),
373                None => PaginationToken::HitEnd,
374            };
375        }
376
377        // TODO: what to do with state events?
378
379        // Forget the reset state guard, so its Drop method is not called.
380        reset_state_guard.disarm();
381        // And set the final state.
382        self.state.set(PaginatorState::Idle);
383
384        Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
385    }
386}
387
388/// A room that can be paginated.
389///
390/// Not [`crate::Room`] because we may want to paginate rooms we don't belong
391/// to.
392pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
393    /// Runs a /context query for the given room.
394    ///
395    /// ## Parameters
396    ///
397    /// - `event_id` is the identifier of the target event.
398    /// - `lazy_load_members` controls whether room membership events are lazily
399    ///   loaded as context state events.
400    /// - `num_events` is the number of events (including the fetched event) to
401    ///   return as context.
402    ///
403    /// ## Returns
404    ///
405    /// Must return [`PaginatorError::EventNotFound`] whenever the target event
406    /// could not be found, instead of causing an http `Err` result.
407    fn event_with_context(
408        &self,
409        event_id: &EventId,
410        lazy_load_members: bool,
411        num_events: UInt,
412    ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
413
414    /// Runs a /messages query for the given room.
415    fn messages(
416        &self,
417        opts: MessagesOptions,
418    ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
419}
420
421impl PaginableRoom for Room {
422    async fn event_with_context(
423        &self,
424        event_id: &EventId,
425        lazy_load_members: bool,
426        num_events: UInt,
427    ) -> Result<EventWithContextResponse, PaginatorError> {
428        let response =
429            match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
430                Ok(result) => result,
431
432                Err(err) => {
433                    // If the error was a 404, then the event wasn't found on the server;
434                    // special case this to make it easy to react to
435                    // such an error.
436                    if let Some(error) = err.as_client_api_error() {
437                        if error.status_code == 404 {
438                            // Event not found
439                            return Err(PaginatorError::EventNotFound(event_id.to_owned()));
440                        }
441                    }
442
443                    // Otherwise, just return a wrapped error.
444                    return Err(PaginatorError::SdkError(Box::new(err)));
445                }
446            };
447
448        Ok(response)
449    }
450
451    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
452        self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
453    }
454}
455
456impl PaginableRoom for WeakRoom {
457    async fn event_with_context(
458        &self,
459        event_id: &EventId,
460        lazy_load_members: bool,
461        num_events: UInt,
462    ) -> Result<EventWithContextResponse, PaginatorError> {
463        let Some(room) = self.get() else {
464            // Client is shutting down, return a default response.
465            return Ok(EventWithContextResponse::default());
466        };
467
468        PaginableRoom::event_with_context(&room, event_id, lazy_load_members, num_events).await
469    }
470
471    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
472        let Some(room) = self.get() else {
473            // Client is shutting down, return a default response.
474            return Ok(Messages::default());
475        };
476
477        PaginableRoom::messages(&room, opts).await
478    }
479}
480
481#[cfg(all(not(target_arch = "wasm32"), test))]
482mod tests {
483    use std::sync::Arc;
484
485    use assert_matches2::assert_let;
486    use futures_core::Future;
487    use futures_util::FutureExt as _;
488    use matrix_sdk_base::deserialized_responses::TimelineEvent;
489    use matrix_sdk_test::{async_test, event_factory::EventFactory};
490    use once_cell::sync::Lazy;
491    use ruma::{api::Direction, event_id, room_id, uint, user_id, EventId, RoomId, UInt, UserId};
492    use tokio::{
493        spawn,
494        sync::{Mutex, Notify},
495        task::AbortHandle,
496    };
497
498    use super::{PaginableRoom, PaginatorError, PaginatorState};
499    use crate::{
500        event_cache::paginator::Paginator,
501        room::{EventWithContextResponse, Messages, MessagesOptions},
502        test_utils::assert_event_matches_msg,
503    };
504
505    #[derive(Clone)]
506    struct TestRoom {
507        event_factory: Arc<EventFactory>,
508        wait_for_ready: bool,
509
510        target_event_text: Arc<Mutex<String>>,
511        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
512        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
513        prev_batch_token: Arc<Mutex<Option<String>>>,
514        next_batch_token: Arc<Mutex<Option<String>>>,
515
516        room_ready: Arc<Notify>,
517    }
518
519    impl TestRoom {
520        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
521            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));
522
523            Self {
524                event_factory,
525                wait_for_ready,
526
527                room_ready: Default::default(),
528                target_event_text: Default::default(),
529                next_events: Default::default(),
530                prev_events: Default::default(),
531                prev_batch_token: Default::default(),
532                next_batch_token: Default::default(),
533            }
534        }
535
536        /// Unblocks the next request.
537        fn mark_ready(&self) {
538            self.room_ready.notify_one();
539        }
540    }
541
542    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
543    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));
544
545    impl PaginableRoom for TestRoom {
546        async fn event_with_context(
547            &self,
548            event_id: &EventId,
549            _lazy_load_members: bool,
550            num_events: UInt,
551        ) -> Result<EventWithContextResponse, PaginatorError> {
552            // Wait for the room to be marked as ready first.
553            if self.wait_for_ready {
554                self.room_ready.notified().await;
555            }
556
557            let event = self
558                .event_factory
559                .text_msg(self.target_event_text.lock().await.clone())
560                .event_id(event_id)
561                .into_event();
562
563            // Properly simulate `num_events`: take either the closest num_events events
564            // before, or use all of the before events and then consume after events.
565            let mut num_events = u64::from(num_events) as usize;
566
567            let prev_events = self.prev_events.lock().await;
568
569            let events_before = if prev_events.is_empty() {
570                Vec::new()
571            } else {
572                let len = prev_events.len();
573                let take_before = num_events.min(len);
574                // Subtract is safe because take_before <= num_events.
575                num_events -= take_before;
576                // Subtract is safe because take_before <= len
577                prev_events[len - take_before..len].to_vec()
578            };
579
580            let events_after = self.next_events.lock().await;
581            let events_after = if events_after.is_empty() {
582                Vec::new()
583            } else {
584                events_after[0..num_events.min(events_after.len())].to_vec()
585            };
586
587            Ok(EventWithContextResponse {
588                event: Some(event),
589                events_before,
590                events_after,
591                prev_batch_token: self.prev_batch_token.lock().await.clone(),
592                next_batch_token: self.next_batch_token.lock().await.clone(),
593                state: Vec::new(),
594            })
595        }
596
597        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
598            if self.wait_for_ready {
599                self.room_ready.notified().await;
600            }
601
602            let limit = u64::from(opts.limit) as usize;
603
604            let (end, events) = match opts.dir {
605                Direction::Backward => {
606                    let events = self.prev_events.lock().await;
607                    let events = if events.is_empty() {
608                        Vec::new()
609                    } else {
610                        let len = events.len();
611                        let take_before = limit.min(len);
612                        // Subtract is safe because take_before <= len
613                        events[len - take_before..len].to_vec()
614                    };
615                    (self.prev_batch_token.lock().await.clone(), events)
616                }
617
618                Direction::Forward => {
619                    let events = self.next_events.lock().await;
620                    let events = if events.is_empty() {
621                        Vec::new()
622                    } else {
623                        events[0..limit.min(events.len())].to_vec()
624                    };
625                    (self.next_batch_token.lock().await.clone(), events)
626                }
627            };
628
629            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
630        }
631    }
632
633    async fn assert_invalid_state<T: std::fmt::Debug>(
634        task: impl Future<Output = Result<T, PaginatorError>>,
635        expected: PaginatorState,
636        actual: PaginatorState,
637    ) {
638        assert_let!(
639            Err(PaginatorError::InvalidPreviousState {
640                expected: real_expected,
641                actual: real_actual
642            }) = task.await
643        );
644        assert_eq!(real_expected, expected);
645        assert_eq!(real_actual, actual);
646    }
647
648    #[async_test]
649    async fn test_start_from() {
650        // Prepare test data.
651        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
652
653        let event_id = event_id!("$yoyoyo");
654        let event_factory = &room.event_factory;
655
656        *room.target_event_text.lock().await = "fetch_from".to_owned();
657        *room.prev_events.lock().await = (0..10)
658            .rev()
659            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
660            .collect();
661        *room.next_events.lock().await =
662            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();
663
664        // When I call `Paginator::start_from`, it works,
665        let paginator = Arc::new(Paginator::new(room.clone()));
666        let context =
667            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
668
669        assert!(!context.has_prev);
670        assert!(!context.has_next);
671
672        // And I get the events I expected.
673
674        // 10 events before, the target event, 10 events after.
675        assert_eq!(context.events.len(), 21);
676
677        for i in 0..10 {
678            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
679        }
680
681        assert_event_matches_msg(&context.events[10], "fetch_from");
682        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);
683
684        for i in 0..10 {
685            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
686        }
687    }
688
689    #[async_test]
690    async fn test_start_from_with_num_events() {
691        // Prepare test data.
692        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
693
694        let event_id = event_id!("$yoyoyo");
695        let event_factory = &room.event_factory;
696
697        *room.target_event_text.lock().await = "fetch_from".to_owned();
698        *room.prev_events.lock().await =
699            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();
700
701        // When I call `Paginator::start_from`, it works,
702        let paginator = Arc::new(Paginator::new(room.clone()));
703        let context =
704            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");
705
706        // Then I only get 10 events + the target event, even if there was more than 10
707        // events in the room.
708        assert_eq!(context.events.len(), 11);
709
710        for i in 0..10 {
711            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
712        }
713        assert_event_matches_msg(&context.events[10], "fetch_from");
714    }
715
716    #[async_test]
717    async fn test_paginate_backward() {
718        // Prepare test data.
719        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
720
721        let event_id = event_id!("$yoyoyo");
722        let event_factory = &room.event_factory;
723
724        *room.target_event_text.lock().await = "initial".to_owned();
725        *room.prev_batch_token.lock().await = Some("prev".to_owned());
726
727        // When I call `Paginator::start_from`, it works,
728        let paginator = Arc::new(Paginator::new(room.clone()));
729
730        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
731        assert!(
732            !paginator.hit_timeline_end(),
733            "we don't know about the status of the next-batch token"
734        );
735
736        let context =
737            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
738
739        // And I get the events I expected.
740        assert_eq!(context.events.len(), 1);
741        assert_event_matches_msg(&context.events[0], "initial");
742        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
743
744        // There's a previous batch, but no next batch.
745        assert!(context.has_prev);
746        assert!(!context.has_next);
747
748        assert!(!paginator.hit_timeline_start());
749        assert!(paginator.hit_timeline_end());
750
751        // Preparing data for the next back-pagination.
752        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
753        *room.prev_batch_token.lock().await = Some("prev2".to_owned());
754
755        // When I backpaginate, I get the events I expect.
756        let prev =
757            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
758        assert!(!prev.hit_end_of_timeline);
759        assert!(!paginator.hit_timeline_start());
760        assert_eq!(prev.events.len(), 1);
761        assert_event_matches_msg(&prev.events[0], "previous");
762
763        // And I can backpaginate again, because there's a prev batch token
764        // still.
765        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
766        *room.prev_batch_token.lock().await = None;
767
768        let prev = paginator
769            .paginate_backward(uint!(100))
770            .await
771            .expect("paginate backward the second time should work");
772        assert!(prev.hit_end_of_timeline);
773        assert!(paginator.hit_timeline_start());
774        assert_eq!(prev.events.len(), 1);
775        assert_event_matches_msg(&prev.events[0], "oldest");
776
777        // I've hit the start of the timeline, but back-paginating again will
778        // return immediately.
779        let prev = paginator
780            .paginate_backward(uint!(100))
781            .await
782            .expect("paginate backward the third time should work");
783        assert!(prev.hit_end_of_timeline);
784        assert!(paginator.hit_timeline_start());
785        assert!(prev.events.is_empty());
786    }
787
788    #[async_test]
789    async fn test_paginate_backward_with_limit() {
790        // Prepare test data.
791        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
792
793        let event_id = event_id!("$yoyoyo");
794        let event_factory = &room.event_factory;
795
796        *room.target_event_text.lock().await = "initial".to_owned();
797        *room.prev_batch_token.lock().await = Some("prev".to_owned());
798
799        // When I call `Paginator::start_from`, it works,
800        let paginator = Arc::new(Paginator::new(room.clone()));
801        let context =
802            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
803
804        // And I get the events I expected.
805        assert_eq!(context.events.len(), 1);
806        assert_event_matches_msg(&context.events[0], "initial");
807        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
808
809        // There's a previous batch.
810        assert!(context.has_prev);
811        assert!(!context.has_next);
812
813        // Preparing data for the next back-pagination.
814        *room.prev_events.lock().await = (0..100)
815            .rev()
816            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
817            .collect();
818        *room.prev_batch_token.lock().await = None;
819
820        // When I backpaginate and request 100 events, I get only 10 events.
821        let prev =
822            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
823        assert!(prev.hit_end_of_timeline);
824        assert_eq!(prev.events.len(), 10);
825        for i in 0..10 {
826            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
827        }
828    }
829
830    #[async_test]
831    async fn test_paginate_forward() {
832        // Prepare test data.
833        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
834
835        let event_id = event_id!("$yoyoyo");
836        let event_factory = &room.event_factory;
837
838        *room.target_event_text.lock().await = "initial".to_owned();
839        *room.next_batch_token.lock().await = Some("next".to_owned());
840
841        // When I call `Paginator::start_from`, it works,
842        let paginator = Arc::new(Paginator::new(room.clone()));
843        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
844        assert!(
845            !paginator.hit_timeline_start(),
846            "we don't know about the status of the prev-batch token"
847        );
848
849        let context =
850            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
851
852        // And I get the events I expected.
853        assert_eq!(context.events.len(), 1);
854        assert_event_matches_msg(&context.events[0], "initial");
855        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
856
857        // There's a next batch, but no previous batch (i.e. we've hit the start of the
858        // timeline).
859        assert!(!context.has_prev);
860        assert!(context.has_next);
861
862        assert!(paginator.hit_timeline_start());
863        assert!(!paginator.hit_timeline_end());
864
865        // Preparing data for the next forward-pagination.
866        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
867        *room.next_batch_token.lock().await = Some("next2".to_owned());
868
869        // When I forward-paginate, I get the events I expect.
870        let next =
871            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
872        assert!(!next.hit_end_of_timeline);
873        assert_eq!(next.events.len(), 1);
874        assert_event_matches_msg(&next.events[0], "next");
875        assert!(!paginator.hit_timeline_end());
876
877        // And I can forward-paginate again, because there's a prev batch token
878        // still.
879        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
880        *room.next_batch_token.lock().await = None;
881
882        let next = paginator
883            .paginate_forward(uint!(100))
884            .await
885            .expect("paginate forward the second time should work");
886        assert!(next.hit_end_of_timeline);
887        assert_eq!(next.events.len(), 1);
888        assert_event_matches_msg(&next.events[0], "latest");
889        assert!(paginator.hit_timeline_end());
890
891        // I've hit the start of the timeline, but back-paginating again will
892        // return immediately.
893        let next = paginator
894            .paginate_forward(uint!(100))
895            .await
896            .expect("paginate forward the third time should work");
897        assert!(next.hit_end_of_timeline);
898        assert!(next.events.is_empty());
899        assert!(paginator.hit_timeline_end());
900    }
901
    // Walks the paginator through its state machine
    // (Initial -> FetchingTargetEvent -> Idle -> Paginating -> Idle), checking
    // that every observable transition is published exactly once on the
    // `state()` subscriber, and that operations started from the wrong state
    // are rejected without changing the state.
    #[async_test]
    async fn test_state() {
        // NOTE(review): the `true` flag presumably makes the room's requests
        // block until `room.mark_ready()` is called (see the comments below);
        // confirm against `TestRoom::new`.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        // `now_or_never().is_none()` asserts that no further state update is
        // pending on the subscriber at this point.
        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Attempting to run pagination must fail and not change the state.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Running the initial query must work.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // The query is pending. Running other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        // The rejected calls above must not have published any state change.
        assert!(state.next().now_or_never().is_none());

        // Mark the dummy room as ready. The query may now terminate.
        room.mark_ready();

        // After fetching the initial event data, the paginator switches to `Idle`.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Now back-paginate from `Idle`; the room again holds the request until
        // it is marked ready.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // The query is pending. Running other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Let the pending back-pagination complete.
        room.mark_ready();

        // Once the pagination finishes, the paginator is back at rest.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }
999
1000    mod aborts {
1001        use super::*;
1002        use crate::event_cache::{paginator::PaginationTokens, PaginationToken};
1003
1004        #[derive(Clone, Default)]
1005        struct AbortingRoom {
1006            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
1007            room_ready: Arc<Notify>,
1008        }
1009
1010        impl AbortingRoom {
1011            async fn wait_abort_and_yield(&self) -> ! {
1012                // Wait for the controller to tell us we're ready.
1013                self.room_ready.notified().await;
1014
1015                // Abort the given handle.
1016                let mut guard = self.abort_handle.lock().await;
1017                let handle = guard.take().expect("only call me when i'm initialized");
1018                handle.abort();
1019
1020                // Enter an endless loop of yielding.
1021                loop {
1022                    tokio::task::yield_now().await;
1023                }
1024            }
1025        }
1026
1027        impl PaginableRoom for AbortingRoom {
1028            async fn event_with_context(
1029                &self,
1030                _event_id: &EventId,
1031                _lazy_load_members: bool,
1032                _num_events: UInt,
1033            ) -> Result<EventWithContextResponse, PaginatorError> {
1034                self.wait_abort_and_yield().await
1035            }
1036
1037            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
1038                self.wait_abort_and_yield().await
1039            }
1040        }
1041
1042        #[async_test]
1043        async fn test_abort_while_starting_from() {
1044            let room = AbortingRoom::default();
1045
1046            let paginator = Arc::new(Paginator::new(room.clone()));
1047
1048            let mut state = paginator.state();
1049
1050            assert_eq!(state.get(), PaginatorState::Initial);
1051            assert!(state.next().now_or_never().is_none());
1052
1053            // When I try to start the initial query…
1054            let p = paginator.clone();
1055            let join_handle = spawn(async move {
1056                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
1057            });
1058
1059            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1060
1061            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
1062            assert!(state.next().now_or_never().is_none());
1063
1064            room.room_ready.notify_one();
1065
1066            // But it's aborted when awaiting the task.
1067            let join_result = join_handle.await;
1068            assert!(join_result.unwrap_err().is_cancelled());
1069
1070            // Then the state is reset to initial.
1071            assert_eq!(state.next().await, Some(PaginatorState::Initial));
1072            assert!(state.next().now_or_never().is_none());
1073        }
1074
1075        #[async_test]
1076        async fn test_abort_while_paginating() {
1077            let room = AbortingRoom::default();
1078
1079            // Assuming a paginator ready to back- or forward- paginate,
1080            let paginator = Paginator::new(room.clone());
1081            paginator.state.set(PaginatorState::Idle);
1082            *paginator.tokens.lock().unwrap() = PaginationTokens {
1083                previous: PaginationToken::HasMore("prev".to_owned()),
1084                next: PaginationToken::HasMore("next".to_owned()),
1085            };
1086
1087            let paginator = Arc::new(paginator);
1088
1089            let mut state = paginator.state();
1090
1091            assert_eq!(state.get(), PaginatorState::Idle);
1092            assert!(state.next().now_or_never().is_none());
1093
1094            // When I try to back-paginate…
1095            let p = paginator.clone();
1096            let join_handle = spawn(async move {
1097                let _ = p.paginate_backward(uint!(100)).await;
1098            });
1099
1100            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1101
1102            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1103            assert!(state.next().now_or_never().is_none());
1104
1105            room.room_ready.notify_one();
1106
1107            // But it's aborted when awaiting the task.
1108            let join_result = join_handle.await;
1109            assert!(join_result.unwrap_err().is_cancelled());
1110
1111            // Then the state is reset to idle.
1112            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1113            assert!(state.next().now_or_never().is_none());
1114
1115            // And ditto for forward pagination.
1116            let p = paginator.clone();
1117            let join_handle = spawn(async move {
1118                let _ = p.paginate_forward(uint!(100)).await;
1119            });
1120
1121            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1122
1123            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1124            assert!(state.next().now_or_never().is_none());
1125
1126            room.room_ready.notify_one();
1127
1128            let join_result = join_handle.await;
1129            assert!(join_result.unwrap_err().is_cancelled());
1130
1131            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1132            assert!(state.next().now_or_never().is_none());
1133        }
1134    }
1135}