Skip to main content

matrix_sdk/paginators/
room.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The paginator is a stateful helper object that handles reaching an event,
16//! either from a cache or network, and surrounding events ("context"). Then, it
17//! makes it possible to paginate forward or backward, from that event, until
18//! one end of the timeline (front or back) is reached.
19
20use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{SendOutsideWasm, SyncOutsideWasm, deserialized_responses::TimelineEvent};
24use ruma::{EventId, UInt, api::Direction};
25
26use crate::{
27    Room,
28    paginators::{PaginationResult, PaginationToken, PaginatorError},
29    room::{EventWithContextResponse, Messages, MessagesOptions},
30};
31
/// Current state of a [`Paginator`].
///
/// A paginator starts in [`PaginatorState::Initial`], goes through
/// [`PaginatorState::FetchingTargetEvent`] while the target event is being
/// fetched, and then alternates between [`PaginatorState::Idle`] and
/// [`PaginatorState::Paginating`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// Nothing happened yet: the paginator has just been created.
    Initial,

    /// The target event, with its surrounding context, is being fetched.
    FetchingTargetEvent,

    /// The target event has been found, zero or more paginations have
    /// completed since then, and no pagination is running right now.
    Idle,

    /// A pagination, either backward or forward, is currently running.
    Paginating,
}
49
50/// Paginations tokens used for backward and forward pagination.
51#[derive(Debug, Clone)]
52pub struct PaginationTokens {
53    /// Pagination token used for backward pagination.
54    pub previous: PaginationToken,
55    /// Pagination token used for forward pagination.
56    pub next: PaginationToken,
57}
58
59/// A stateful object to reach to an event, and then paginate backward and
60/// forward from it.
61///
62/// See also the module-level documentation.
63pub struct Paginator<PR: PaginableRoom> {
64    /// The room in which we're going to run the pagination.
65    room: PR,
66
67    /// Current state of the paginator.
68    state: SharedObservable<PaginatorState>,
69
70    /// Pagination tokens used for subsequent requests.
71    ///
72    /// This mutex is always short-lived, so it's sync.
73    tokens: Mutex<PaginationTokens>,
74}
75
76#[cfg(not(tarpaulin_include))]
77impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        // Don't include the room in the debug output.
80        f.debug_struct("Paginator")
81            .field("state", &self.state.get())
82            .field("tokens", &self.tokens)
83            .finish_non_exhaustive()
84    }
85}
86
87/// The result of an initial [`Paginator::start_from`] query.
88#[derive(Debug)]
89pub struct StartFromResult {
90    /// All the events returned during this pagination, in topological ordering.
91    pub events: Vec<TimelineEvent>,
92
93    /// Whether the /context query returned a previous batch token.
94    pub has_prev: bool,
95
96    /// Whether the /context query returned a next batch token.
97    pub has_next: bool,
98}
99
100/// Reset the state to a given target on drop.
101struct ResetStateGuard {
102    target: Option<PaginatorState>,
103    state: SharedObservable<PaginatorState>,
104}
105
106impl ResetStateGuard {
107    /// Create a new reset state guard.
108    fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
109        Self { target: Some(target), state }
110    }
111
112    /// Render the guard effectless, and consume it.
113    fn disarm(mut self) {
114        self.target = None;
115    }
116}
117
118impl Drop for ResetStateGuard {
119    fn drop(&mut self) {
120        if let Some(target) = self.target.take() {
121            self.state.set_if_not_eq(target);
122        }
123    }
124}
125
126impl<PR: PaginableRoom> Paginator<PR> {
127    /// Create a new [`Paginator`], given a room implementation.
128    pub fn new(room: PR) -> Self {
129        Self {
130            room,
131            state: SharedObservable::new(PaginatorState::Initial),
132            tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
133        }
134    }
135
136    /// Check if the current state of the paginator matches the expected one.
137    fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
138        let actual = self.state.get();
139        if actual != expected {
140            Err(PaginatorError::InvalidPreviousState { expected, actual })
141        } else {
142            Ok(())
143        }
144    }
145
146    /// Returns a subscriber to the internal [`PaginatorState`] machine.
147    pub fn state(&self) -> Subscriber<PaginatorState> {
148        self.state.subscribe()
149    }
150
151    /// Starts the pagination from the initial event, requesting `num_events`
152    /// additional context events.
153    ///
154    /// Only works for fresh [`Paginator`] objects, which are in the
155    /// [`PaginatorState::Initial`] state.
156    pub async fn start_from(
157        &self,
158        event_id: &EventId,
159        num_events: UInt,
160    ) -> Result<StartFromResult, PaginatorError> {
161        self.check_state(PaginatorState::Initial)?;
162
163        // Note: it's possible two callers have checked the state and both figured it's
164        // initial. This check below makes sure there's at most one which can set the
165        // state to FetchingTargetEvent, preventing a race condition.
166        if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
167            return Err(PaginatorError::InvalidPreviousState {
168                expected: PaginatorState::Initial,
169                actual: PaginatorState::FetchingTargetEvent,
170            });
171        }
172
173        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
174
175        // TODO: do we want to lazy load members?
176        let lazy_load_members = true;
177
178        let response =
179            self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
180
181        // NOTE: it's super important to not have any `await` after this point, since we
182        // don't want the task to be interrupted anymore, or the internal state
183        // may become incorrect.
184
185        let has_prev = response.prev_batch_token.is_some();
186        let has_next = response.next_batch_token.is_some();
187
188        {
189            let mut tokens = self.tokens.lock().unwrap();
190            tokens.previous = match response.prev_batch_token {
191                Some(token) => PaginationToken::HasMore(token),
192                None => PaginationToken::HitEnd,
193            };
194            tokens.next = match response.next_batch_token {
195                Some(token) => PaginationToken::HasMore(token),
196                None => PaginationToken::HitEnd,
197            };
198        }
199
200        // Forget the reset state guard, so its Drop method is not called.
201        reset_state_guard.disarm();
202        // And set the final state.
203        self.state.set(PaginatorState::Idle);
204
205        // Consolidate the events into a linear timeline, topologically ordered.
206        // - the events before are returned in the reverse topological order: invert
207        //   them.
208        // - insert the target event, if set.
209        // - the events after are returned in the correct topological order.
210
211        let events = response
212            .events_before
213            .into_iter()
214            .rev()
215            .chain(response.event)
216            .chain(response.events_after)
217            .collect();
218
219        Ok(StartFromResult { events, has_prev, has_next })
220    }
221
222    /// Runs a backward pagination (requesting `num_events` to the server), from
223    /// the current state of the object.
224    ///
225    /// Will return immediately if we have already hit the start of the
226    /// timeline.
227    ///
228    /// May return an error if it's already paginating, or if the call to
229    /// /messages failed.
230    pub async fn paginate_backward(
231        &self,
232        num_events: UInt,
233    ) -> Result<PaginationResult, PaginatorError> {
234        self.paginate(Direction::Backward, num_events).await
235    }
236
237    /// Returns whether we've hit the start of the timeline.
238    ///
239    /// This is true if, and only if, we didn't have a previous-batch token and
240    /// running backwards pagination would be useless.
241    pub fn hit_timeline_start(&self) -> bool {
242        matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
243    }
244
245    /// Returns whether we've hit the end of the timeline.
246    ///
247    /// This is true if, and only if, we didn't have a next-batch token and
248    /// running forwards pagination would be useless.
249    pub fn hit_timeline_end(&self) -> bool {
250        matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
251    }
252
253    /// Runs a forward pagination (requesting `num_events` to the server), from
254    /// the current state of the object.
255    ///
256    /// Will return immediately if we have already hit the end of the timeline.
257    ///
258    /// May return an error if it's already paginating, or if the call to
259    /// /messages failed.
260    pub async fn paginate_forward(
261        &self,
262        num_events: UInt,
263    ) -> Result<PaginationResult, PaginatorError> {
264        self.paginate(Direction::Forward, num_events).await
265    }
266
267    /// Paginate in the given direction, requesting `num_events` events to the
268    /// server, using the `token_lock` to read from and write the pagination
269    /// token.
270    async fn paginate(
271        &self,
272        dir: Direction,
273        num_events: UInt,
274    ) -> Result<PaginationResult, PaginatorError> {
275        self.check_state(PaginatorState::Idle)?;
276
277        let token = {
278            let tokens = self.tokens.lock().unwrap();
279
280            let token = match dir {
281                Direction::Backward => &tokens.previous,
282                Direction::Forward => &tokens.next,
283            };
284
285            match token {
286                PaginationToken::None => None,
287                PaginationToken::HasMore(val) => Some(val.clone()),
288                PaginationToken::HitEnd => {
289                    return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
290                }
291            }
292        };
293
294        // Note: it's possible two callers have checked the state and both figured it's
295        // idle. This check below makes sure there's at most one which can set the
296        // state to paginating, preventing a race condition.
297        if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
298            return Err(PaginatorError::InvalidPreviousState {
299                expected: PaginatorState::Idle,
300                actual: PaginatorState::Paginating,
301            });
302        }
303
304        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
305
306        let mut options = MessagesOptions::new(dir).from(token.as_deref());
307        options.limit = num_events;
308
309        // In case of error, the state is reset to idle automatically thanks to
310        // reset_state_guard.
311        let response = self.room.messages(options).await?;
312
313        // NOTE: it's super important to not have any `await` after this point, since we
314        // don't want the task to be interrupted anymore, or the internal state
315        // may be incorrect.
316
317        let hit_end_of_timeline = response.end.is_none();
318
319        {
320            let mut tokens = self.tokens.lock().unwrap();
321
322            let token = match dir {
323                Direction::Backward => &mut tokens.previous,
324                Direction::Forward => &mut tokens.next,
325            };
326
327            *token = match response.end {
328                Some(val) => PaginationToken::HasMore(val),
329                None => PaginationToken::HitEnd,
330            };
331        }
332
333        // TODO: what to do with state events?
334
335        // Forget the reset state guard, so its Drop method is not called.
336        reset_state_guard.disarm();
337        // And set the final state.
338        self.state.set(PaginatorState::Idle);
339
340        Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
341    }
342
343    /// Returns the current pagination tokens.
344    pub fn tokens(&self) -> PaginationTokens {
345        self.tokens.lock().unwrap().clone()
346    }
347}
348
349/// A room that can be paginated.
350///
351/// Not [`crate::Room`] because we may want to paginate rooms we don't belong
352/// to.
353pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
354    /// Runs a /context query for the given room.
355    ///
356    /// ## Parameters
357    ///
358    /// - `event_id` is the identifier of the target event.
359    /// - `lazy_load_members` controls whether room membership events are lazily
360    ///   loaded as context state events.
361    /// - `num_events` is the number of events (including the fetched event) to
362    ///   return as context.
363    ///
364    /// ## Returns
365    ///
366    /// Must return [`PaginatorError::EventNotFound`] whenever the target event
367    /// could not be found, instead of causing an http `Err` result.
368    fn event_with_context(
369        &self,
370        event_id: &EventId,
371        lazy_load_members: bool,
372        num_events: UInt,
373    ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
374
375    /// Runs a /messages query for the given room.
376    fn messages(
377        &self,
378        opts: MessagesOptions,
379    ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
380}
381
382impl PaginableRoom for Room {
383    async fn event_with_context(
384        &self,
385        event_id: &EventId,
386        lazy_load_members: bool,
387        num_events: UInt,
388    ) -> Result<EventWithContextResponse, PaginatorError> {
389        let response =
390            match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
391                Ok(result) => result,
392
393                Err(err) => {
394                    // If the error was a 404, then the event wasn't found on the server;
395                    // special case this to make it easy to react to
396                    // such an error.
397                    if let Some(error) = err.as_client_api_error()
398                        && error.status_code == 404
399                    {
400                        // Event not found
401                        return Err(PaginatorError::EventNotFound(event_id.to_owned()));
402                    }
403
404                    // Otherwise, just return a wrapped error.
405                    return Err(PaginatorError::SdkError(Box::new(err)));
406                }
407            };
408
409        Ok(response)
410    }
411
412    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
413        self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
414    }
415}
416
417#[cfg(all(not(target_family = "wasm"), test))]
418mod tests {
419    use std::sync::{Arc, LazyLock};
420
421    use assert_matches2::assert_let;
422    use futures_core::Future;
423    use futures_util::FutureExt as _;
424    use matrix_sdk_base::deserialized_responses::TimelineEvent;
425    use matrix_sdk_test::{async_test, event_factory::EventFactory};
426    use ruma::{EventId, RoomId, UInt, UserId, api::Direction, event_id, room_id, uint, user_id};
427    use tokio::{
428        spawn,
429        sync::{Mutex, Notify},
430        task::AbortHandle,
431    };
432
433    use super::{PaginableRoom, PaginatorError, PaginatorState};
434    use crate::{
435        paginators::Paginator,
436        room::{EventWithContextResponse, Messages, MessagesOptions},
437        test_utils::assert_event_matches_msg,
438    };
439
440    #[derive(Clone)]
441    struct TestRoom {
442        event_factory: Arc<EventFactory>,
443        wait_for_ready: bool,
444
445        target_event_text: Arc<Mutex<String>>,
446        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
447        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
448        prev_batch_token: Arc<Mutex<Option<String>>>,
449        next_batch_token: Arc<Mutex<Option<String>>>,
450
451        room_ready: Arc<Notify>,
452    }
453
454    impl TestRoom {
455        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
456            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));
457
458            Self {
459                event_factory,
460                wait_for_ready,
461
462                room_ready: Default::default(),
463                target_event_text: Default::default(),
464                next_events: Default::default(),
465                prev_events: Default::default(),
466                prev_batch_token: Default::default(),
467                next_batch_token: Default::default(),
468            }
469        }
470
471        /// Unblocks the next request.
472        fn mark_ready(&self) {
473            self.room_ready.notify_one();
474        }
475    }
476
477    static ROOM_ID: LazyLock<&RoomId> = LazyLock::new(|| room_id!("!dune:herbert.org"));
478    static USER_ID: LazyLock<&UserId> = LazyLock::new(|| user_id!("@paul:atreid.es"));
479
480    impl PaginableRoom for TestRoom {
481        async fn event_with_context(
482            &self,
483            event_id: &EventId,
484            _lazy_load_members: bool,
485            num_events: UInt,
486        ) -> Result<EventWithContextResponse, PaginatorError> {
487            // Wait for the room to be marked as ready first.
488            if self.wait_for_ready {
489                self.room_ready.notified().await;
490            }
491
492            let event = self
493                .event_factory
494                .text_msg(self.target_event_text.lock().await.clone())
495                .event_id(event_id)
496                .into_event();
497
498            // Properly simulate `num_events`: take either the closest num_events events
499            // before, or use all of the before events and then consume after events.
500            let mut num_events = u64::from(num_events) as usize;
501
502            let prev_events = self.prev_events.lock().await;
503
504            let events_before = if prev_events.is_empty() {
505                Vec::new()
506            } else {
507                let len = prev_events.len();
508                let take_before = num_events.min(len);
509                // Subtract is safe because take_before <= num_events.
510                num_events -= take_before;
511                // Subtract is safe because take_before <= len
512                prev_events[len - take_before..len].to_vec()
513            };
514
515            let events_after = self.next_events.lock().await;
516            let events_after = if events_after.is_empty() {
517                Vec::new()
518            } else {
519                events_after[0..num_events.min(events_after.len())].to_vec()
520            };
521
522            Ok(EventWithContextResponse {
523                event: Some(event),
524                events_before,
525                events_after,
526                prev_batch_token: self.prev_batch_token.lock().await.clone(),
527                next_batch_token: self.next_batch_token.lock().await.clone(),
528                state: Vec::new(),
529            })
530        }
531
532        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
533            if self.wait_for_ready {
534                self.room_ready.notified().await;
535            }
536
537            let limit = u64::from(opts.limit) as usize;
538
539            let (end, events) = match opts.dir {
540                Direction::Backward => {
541                    let events = self.prev_events.lock().await;
542                    let events = if events.is_empty() {
543                        Vec::new()
544                    } else {
545                        let len = events.len();
546                        let take_before = limit.min(len);
547                        // Subtract is safe because take_before <= len
548                        events[len - take_before..len].to_vec()
549                    };
550                    (self.prev_batch_token.lock().await.clone(), events)
551                }
552
553                Direction::Forward => {
554                    let events = self.next_events.lock().await;
555                    let events = if events.is_empty() {
556                        Vec::new()
557                    } else {
558                        events[0..limit.min(events.len())].to_vec()
559                    };
560                    (self.next_batch_token.lock().await.clone(), events)
561                }
562            };
563
564            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
565        }
566    }
567
568    async fn assert_invalid_state<T: std::fmt::Debug>(
569        task: impl Future<Output = Result<T, PaginatorError>>,
570        expected: PaginatorState,
571        actual: PaginatorState,
572    ) {
573        assert_let!(
574            Err(PaginatorError::InvalidPreviousState {
575                expected: real_expected,
576                actual: real_actual
577            }) = task.await
578        );
579        assert_eq!(real_expected, expected);
580        assert_eq!(real_actual, actual);
581    }
582
583    #[async_test]
584    async fn test_start_from() {
585        // Prepare test data.
586        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
587
588        let event_id = event_id!("$yoyoyo");
589        let event_factory = &room.event_factory;
590
591        *room.target_event_text.lock().await = "fetch_from".to_owned();
592        *room.prev_events.lock().await = (0..10)
593            .rev()
594            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
595            .collect();
596        *room.next_events.lock().await =
597            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();
598
599        // When I call `Paginator::start_from`, it works,
600        let paginator = Arc::new(Paginator::new(room.clone()));
601        let context =
602            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
603
604        assert!(!context.has_prev);
605        assert!(!context.has_next);
606
607        // And I get the events I expected.
608
609        // 10 events before, the target event, 10 events after.
610        assert_eq!(context.events.len(), 21);
611
612        for i in 0..10 {
613            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
614        }
615
616        assert_event_matches_msg(&context.events[10], "fetch_from");
617        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);
618
619        for i in 0..10 {
620            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
621        }
622    }
623
624    #[async_test]
625    async fn test_start_from_with_num_events() {
626        // Prepare test data.
627        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
628
629        let event_id = event_id!("$yoyoyo");
630        let event_factory = &room.event_factory;
631
632        *room.target_event_text.lock().await = "fetch_from".to_owned();
633        *room.prev_events.lock().await =
634            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();
635
636        // When I call `Paginator::start_from`, it works,
637        let paginator = Arc::new(Paginator::new(room.clone()));
638        let context =
639            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");
640
641        // Then I only get 10 events + the target event, even if there was more than 10
642        // events in the room.
643        assert_eq!(context.events.len(), 11);
644
645        for i in 0..10 {
646            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
647        }
648        assert_event_matches_msg(&context.events[10], "fetch_from");
649    }
650
651    #[async_test]
652    async fn test_paginate_backward() {
653        // Prepare test data.
654        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
655
656        let event_id = event_id!("$yoyoyo");
657        let event_factory = &room.event_factory;
658
659        *room.target_event_text.lock().await = "initial".to_owned();
660        *room.prev_batch_token.lock().await = Some("prev".to_owned());
661
662        // When I call `Paginator::start_from`, it works,
663        let paginator = Arc::new(Paginator::new(room.clone()));
664
665        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
666        assert!(
667            !paginator.hit_timeline_end(),
668            "we don't know about the status of the next-batch token"
669        );
670
671        let context =
672            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
673
674        // And I get the events I expected.
675        assert_eq!(context.events.len(), 1);
676        assert_event_matches_msg(&context.events[0], "initial");
677        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
678
679        // There's a previous batch, but no next batch.
680        assert!(context.has_prev);
681        assert!(!context.has_next);
682
683        assert!(!paginator.hit_timeline_start());
684        assert!(paginator.hit_timeline_end());
685
686        // Preparing data for the next back-pagination.
687        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
688        *room.prev_batch_token.lock().await = Some("prev2".to_owned());
689
690        // When I backpaginate, I get the events I expect.
691        let prev =
692            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
693        assert!(!prev.hit_end_of_timeline);
694        assert!(!paginator.hit_timeline_start());
695        assert_eq!(prev.events.len(), 1);
696        assert_event_matches_msg(&prev.events[0], "previous");
697
698        // And I can backpaginate again, because there's a prev batch token
699        // still.
700        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
701        *room.prev_batch_token.lock().await = None;
702
703        let prev = paginator
704            .paginate_backward(uint!(100))
705            .await
706            .expect("paginate backward the second time should work");
707        assert!(prev.hit_end_of_timeline);
708        assert!(paginator.hit_timeline_start());
709        assert_eq!(prev.events.len(), 1);
710        assert_event_matches_msg(&prev.events[0], "oldest");
711
712        // I've hit the start of the timeline, but back-paginating again will
713        // return immediately.
714        let prev = paginator
715            .paginate_backward(uint!(100))
716            .await
717            .expect("paginate backward the third time should work");
718        assert!(prev.hit_end_of_timeline);
719        assert!(paginator.hit_timeline_start());
720        assert!(prev.events.is_empty());
721    }
722
723    #[async_test]
724    async fn test_paginate_backward_with_limit() {
725        // Prepare test data.
726        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
727
728        let event_id = event_id!("$yoyoyo");
729        let event_factory = &room.event_factory;
730
731        *room.target_event_text.lock().await = "initial".to_owned();
732        *room.prev_batch_token.lock().await = Some("prev".to_owned());
733
734        // When I call `Paginator::start_from`, it works,
735        let paginator = Arc::new(Paginator::new(room.clone()));
736        let context =
737            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
738
739        // And I get the events I expected.
740        assert_eq!(context.events.len(), 1);
741        assert_event_matches_msg(&context.events[0], "initial");
742        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
743
744        // There's a previous batch.
745        assert!(context.has_prev);
746        assert!(!context.has_next);
747
748        // Preparing data for the next back-pagination.
749        *room.prev_events.lock().await = (0..100)
750            .rev()
751            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
752            .collect();
753        *room.prev_batch_token.lock().await = None;
754
755        // When I backpaginate and request 100 events, I get only 10 events.
756        let prev =
757            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
758        assert!(prev.hit_end_of_timeline);
759        assert_eq!(prev.events.len(), 10);
760        for i in 0..10 {
761            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
762        }
763    }
764
765    #[async_test]
766    async fn test_paginate_forward() {
767        // Prepare test data.
768        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
769
770        let event_id = event_id!("$yoyoyo");
771        let event_factory = &room.event_factory;
772
773        *room.target_event_text.lock().await = "initial".to_owned();
774        *room.next_batch_token.lock().await = Some("next".to_owned());
775
776        // When I call `Paginator::start_from`, it works,
777        let paginator = Arc::new(Paginator::new(room.clone()));
778        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
779        assert!(
780            !paginator.hit_timeline_start(),
781            "we don't know about the status of the prev-batch token"
782        );
783
784        let context =
785            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
786
787        // And I get the events I expected.
788        assert_eq!(context.events.len(), 1);
789        assert_event_matches_msg(&context.events[0], "initial");
790        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
791
792        // There's a next batch, but no previous batch (i.e. we've hit the start of the
793        // timeline).
794        assert!(!context.has_prev);
795        assert!(context.has_next);
796
797        assert!(paginator.hit_timeline_start());
798        assert!(!paginator.hit_timeline_end());
799
800        // Preparing data for the next forward-pagination.
801        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
802        *room.next_batch_token.lock().await = Some("next2".to_owned());
803
804        // When I forward-paginate, I get the events I expect.
805        let next =
806            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
807        assert!(!next.hit_end_of_timeline);
808        assert_eq!(next.events.len(), 1);
809        assert_event_matches_msg(&next.events[0], "next");
810        assert!(!paginator.hit_timeline_end());
811
812        // And I can forward-paginate again, because there's a prev batch token
813        // still.
814        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
815        *room.next_batch_token.lock().await = None;
816
817        let next = paginator
818            .paginate_forward(uint!(100))
819            .await
820            .expect("paginate forward the second time should work");
821        assert!(next.hit_end_of_timeline);
822        assert_eq!(next.events.len(), 1);
823        assert_event_matches_msg(&next.events[0], "latest");
824        assert!(paginator.hit_timeline_end());
825
826        // I've hit the start of the timeline, but back-paginating again will
827        // return immediately.
828        let next = paginator
829            .paginate_forward(uint!(100))
830            .await
831            .expect("paginate forward the third time should work");
832        assert!(next.hit_end_of_timeline);
833        assert!(next.events.is_empty());
834        assert!(paginator.hit_timeline_end());
835    }
836
837    #[async_test]
838    async fn test_state() {
839        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);
840
841        *room.prev_batch_token.lock().await = Some("prev".to_owned());
842        *room.next_batch_token.lock().await = Some("next".to_owned());
843
844        let paginator = Arc::new(Paginator::new(room.clone()));
845
846        let event_id = event_id!("$yoyoyo");
847
848        let mut state = paginator.state();
849
850        assert_eq!(state.get(), PaginatorState::Initial);
851        assert!(state.next().now_or_never().is_none());
852
853        // Attempting to run pagination must fail and not change the state.
854        assert_invalid_state(
855            paginator.paginate_backward(uint!(100)),
856            PaginatorState::Idle,
857            PaginatorState::Initial,
858        )
859        .await;
860
861        assert!(state.next().now_or_never().is_none());
862
863        // Running the initial query must work.
864        let p = paginator.clone();
865        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });
866
867        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
868        assert!(state.next().now_or_never().is_none());
869
870        // The query is pending. Running other operations must fail.
871        assert_invalid_state(
872            paginator.start_from(event_id, uint!(100)),
873            PaginatorState::Initial,
874            PaginatorState::FetchingTargetEvent,
875        )
876        .await;
877
878        assert_invalid_state(
879            paginator.paginate_backward(uint!(100)),
880            PaginatorState::Idle,
881            PaginatorState::FetchingTargetEvent,
882        )
883        .await;
884
885        assert!(state.next().now_or_never().is_none());
886
887        // Mark the dummy room as ready. The query may now terminate.
888        room.mark_ready();
889
890        // After fetching the initial event data, the paginator switches to `Idle`.
891        assert_eq!(state.next().await, Some(PaginatorState::Idle));
892
893        join_handle.await.expect("joined failed").expect("/context failed");
894
895        assert!(state.next().now_or_never().is_none());
896
897        let p = paginator.clone();
898        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });
899
900        assert_eq!(state.next().await, Some(PaginatorState::Paginating));
901
902        // The query is pending. Running other operations must fail.
903        assert_invalid_state(
904            paginator.start_from(event_id, uint!(100)),
905            PaginatorState::Initial,
906            PaginatorState::Paginating,
907        )
908        .await;
909
910        assert_invalid_state(
911            paginator.paginate_backward(uint!(100)),
912            PaginatorState::Idle,
913            PaginatorState::Paginating,
914        )
915        .await;
916
917        assert_invalid_state(
918            paginator.paginate_forward(uint!(100)),
919            PaginatorState::Idle,
920            PaginatorState::Paginating,
921        )
922        .await;
923
924        assert!(state.next().now_or_never().is_none());
925
926        room.mark_ready();
927
928        assert_eq!(state.next().await, Some(PaginatorState::Idle));
929
930        join_handle.await.expect("joined failed").expect("/messages failed");
931
932        assert!(state.next().now_or_never().is_none());
933    }
934
935    mod aborts {
936        use super::*;
937        use crate::paginators::room::{PaginationToken, PaginationTokens};
938
939        #[derive(Clone, Default)]
940        struct AbortingRoom {
941            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
942            room_ready: Arc<Notify>,
943        }
944
945        impl AbortingRoom {
946            async fn wait_abort_and_yield(&self) -> ! {
947                // Wait for the controller to tell us we're ready.
948                self.room_ready.notified().await;
949
950                // Abort the given handle.
951                let mut guard = self.abort_handle.lock().await;
952                let handle = guard.take().expect("only call me when i'm initialized");
953                handle.abort();
954
955                // Enter an endless loop of yielding.
956                loop {
957                    tokio::task::yield_now().await;
958                }
959            }
960        }
961
962        impl PaginableRoom for AbortingRoom {
963            async fn event_with_context(
964                &self,
965                _event_id: &EventId,
966                _lazy_load_members: bool,
967                _num_events: UInt,
968            ) -> Result<EventWithContextResponse, PaginatorError> {
969                self.wait_abort_and_yield().await
970            }
971
972            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
973                self.wait_abort_and_yield().await
974            }
975        }
976
977        #[async_test]
978        async fn test_abort_while_starting_from() {
979            let room = AbortingRoom::default();
980
981            let paginator = Arc::new(Paginator::new(room.clone()));
982
983            let mut state = paginator.state();
984
985            assert_eq!(state.get(), PaginatorState::Initial);
986            assert!(state.next().now_or_never().is_none());
987
988            // When I try to start the initial query…
989            let p = paginator.clone();
990            let join_handle = spawn(async move {
991                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
992            });
993
994            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
995
996            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
997            assert!(state.next().now_or_never().is_none());
998
999            room.room_ready.notify_one();
1000
1001            // But it's aborted when awaiting the task.
1002            let join_result = join_handle.await;
1003            assert!(join_result.unwrap_err().is_cancelled());
1004
1005            // Then the state is reset to initial.
1006            assert_eq!(state.next().await, Some(PaginatorState::Initial));
1007            assert!(state.next().now_or_never().is_none());
1008        }
1009
1010        #[async_test]
1011        async fn test_abort_while_paginating() {
1012            let room = AbortingRoom::default();
1013
1014            // Assuming a paginator ready to back- or forward- paginate,
1015            let paginator = Paginator::new(room.clone());
1016            paginator.state.set(PaginatorState::Idle);
1017            *paginator.tokens.lock().unwrap() = PaginationTokens {
1018                previous: PaginationToken::HasMore("prev".to_owned()),
1019                next: PaginationToken::HasMore("next".to_owned()),
1020            };
1021
1022            let paginator = Arc::new(paginator);
1023
1024            let mut state = paginator.state();
1025
1026            assert_eq!(state.get(), PaginatorState::Idle);
1027            assert!(state.next().now_or_never().is_none());
1028
1029            // When I try to back-paginate…
1030            let p = paginator.clone();
1031            let join_handle = spawn(async move {
1032                let _ = p.paginate_backward(uint!(100)).await;
1033            });
1034
1035            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1036
1037            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1038            assert!(state.next().now_or_never().is_none());
1039
1040            room.room_ready.notify_one();
1041
1042            // But it's aborted when awaiting the task.
1043            let join_result = join_handle.await;
1044            assert!(join_result.unwrap_err().is_cancelled());
1045
1046            // Then the state is reset to idle.
1047            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1048            assert!(state.next().now_or_never().is_none());
1049
1050            // And ditto for forward pagination.
1051            let p = paginator.clone();
1052            let join_handle = spawn(async move {
1053                let _ = p.paginate_forward(uint!(100)).await;
1054            });
1055
1056            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1057
1058            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1059            assert!(state.next().now_or_never().is_none());
1060
1061            room.room_ready.notify_one();
1062
1063            let join_result = join_handle.await;
1064            assert!(join_result.unwrap_err().is_cancelled());
1065
1066            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1067            assert!(state.next().now_or_never().is_none());
1068        }
1069    }
1070}