matrix_sdk/paginators/
room.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The paginator is a stateful helper object that handles reaching an event,
16//! either from a cache or network, and surrounding events ("context"). Then, it
17//! makes it possible to paginate forward or backward, from that event, until
18//! one end of the timeline (front or back) is reached.
19
20use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{SendOutsideWasm, SyncOutsideWasm, deserialized_responses::TimelineEvent};
24use ruma::{EventId, UInt, api::Direction};
25
26use crate::{
27    Room,
28    paginators::{PaginationResult, PaginationToken, PaginatorError},
29    room::{EventWithContextResponse, Messages, MessagesOptions},
30};
31
/// Current state of a [`Paginator`].
///
/// The paginator moves from [`PaginatorState::Initial`] to
/// [`PaginatorState::FetchingTargetEvent`] while reaching the target event,
/// then alternates between [`PaginatorState::Idle`] and
/// [`PaginatorState::Paginating`].
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// The initial state of the paginator.
    Initial,

    /// The paginator is fetching the target initial event.
    FetchingTargetEvent,

    /// The target initial event could be found, zero or more paginations have
    /// happened since then, and the paginator is at rest now.
    Idle,

    /// The paginator is paginating in one direction or another.
    Paginating,
}
49
50/// Paginations tokens used for backward and forward pagination.
51#[derive(Debug, Clone)]
52pub struct PaginationTokens {
53    /// Pagination token used for backward pagination.
54    pub previous: PaginationToken,
55    /// Pagination token used for forward pagination.
56    pub next: PaginationToken,
57}
58
59/// A stateful object to reach to an event, and then paginate backward and
60/// forward from it.
61///
62/// See also the module-level documentation.
63pub struct Paginator<PR: PaginableRoom> {
64    /// The room in which we're going to run the pagination.
65    room: PR,
66
67    /// Current state of the paginator.
68    state: SharedObservable<PaginatorState>,
69
70    /// Pagination tokens used for subsequent requests.
71    ///
72    /// This mutex is always short-lived, so it's sync.
73    tokens: Mutex<PaginationTokens>,
74}
75
76#[cfg(not(tarpaulin_include))]
77impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        // Don't include the room in the debug output.
80        f.debug_struct("Paginator")
81            .field("state", &self.state.get())
82            .field("tokens", &self.tokens)
83            .finish_non_exhaustive()
84    }
85}
86
87/// The result of an initial [`Paginator::start_from`] query.
88#[derive(Debug)]
89pub struct StartFromResult {
90    /// All the events returned during this pagination, in topological ordering.
91    pub events: Vec<TimelineEvent>,
92
93    /// Whether the /context query returned a previous batch token.
94    pub has_prev: bool,
95
96    /// Whether the /context query returned a next batch token.
97    pub has_next: bool,
98}
99
100/// Reset the state to a given target on drop.
101struct ResetStateGuard {
102    target: Option<PaginatorState>,
103    state: SharedObservable<PaginatorState>,
104}
105
106impl ResetStateGuard {
107    /// Create a new reset state guard.
108    fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
109        Self { target: Some(target), state }
110    }
111
112    /// Render the guard effectless, and consume it.
113    fn disarm(mut self) {
114        self.target = None;
115    }
116}
117
118impl Drop for ResetStateGuard {
119    fn drop(&mut self) {
120        if let Some(target) = self.target.take() {
121            self.state.set_if_not_eq(target);
122        }
123    }
124}
125
126impl<PR: PaginableRoom> Paginator<PR> {
127    /// Create a new [`Paginator`], given a room implementation.
128    pub fn new(room: PR) -> Self {
129        Self {
130            room,
131            state: SharedObservable::new(PaginatorState::Initial),
132            tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
133        }
134    }
135
136    /// Check if the current state of the paginator matches the expected one.
137    fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
138        let actual = self.state.get();
139        if actual != expected {
140            Err(PaginatorError::InvalidPreviousState { expected, actual })
141        } else {
142            Ok(())
143        }
144    }
145
146    /// Returns a subscriber to the internal [`PaginatorState`] machine.
147    pub fn state(&self) -> Subscriber<PaginatorState> {
148        self.state.subscribe()
149    }
150
151    /// Starts the pagination from the initial event, requesting `num_events`
152    /// additional context events.
153    ///
154    /// Only works for fresh [`Paginator`] objects, which are in the
155    /// [`PaginatorState::Initial`] state.
156    pub async fn start_from(
157        &self,
158        event_id: &EventId,
159        num_events: UInt,
160    ) -> Result<StartFromResult, PaginatorError> {
161        self.check_state(PaginatorState::Initial)?;
162
163        // Note: it's possible two callers have checked the state and both figured it's
164        // initial. This check below makes sure there's at most one which can set the
165        // state to FetchingTargetEvent, preventing a race condition.
166        if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
167            return Err(PaginatorError::InvalidPreviousState {
168                expected: PaginatorState::Initial,
169                actual: PaginatorState::FetchingTargetEvent,
170            });
171        }
172
173        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
174
175        // TODO: do we want to lazy load members?
176        let lazy_load_members = true;
177
178        let response =
179            self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
180
181        // NOTE: it's super important to not have any `await` after this point, since we
182        // don't want the task to be interrupted anymore, or the internal state
183        // may become incorrect.
184
185        let has_prev = response.prev_batch_token.is_some();
186        let has_next = response.next_batch_token.is_some();
187
188        {
189            let mut tokens = self.tokens.lock().unwrap();
190            tokens.previous = match response.prev_batch_token {
191                Some(token) => PaginationToken::HasMore(token),
192                None => PaginationToken::HitEnd,
193            };
194            tokens.next = match response.next_batch_token {
195                Some(token) => PaginationToken::HasMore(token),
196                None => PaginationToken::HitEnd,
197            };
198        }
199
200        // Forget the reset state guard, so its Drop method is not called.
201        reset_state_guard.disarm();
202        // And set the final state.
203        self.state.set(PaginatorState::Idle);
204
205        // Consolidate the events into a linear timeline, topologically ordered.
206        // - the events before are returned in the reverse topological order: invert
207        //   them.
208        // - insert the target event, if set.
209        // - the events after are returned in the correct topological order.
210
211        let events = response
212            .events_before
213            .into_iter()
214            .rev()
215            .chain(response.event)
216            .chain(response.events_after)
217            .collect();
218
219        Ok(StartFromResult { events, has_prev, has_next })
220    }
221
222    /// Runs a backward pagination (requesting `num_events` to the server), from
223    /// the current state of the object.
224    ///
225    /// Will return immediately if we have already hit the start of the
226    /// timeline.
227    ///
228    /// May return an error if it's already paginating, or if the call to
229    /// /messages failed.
230    pub async fn paginate_backward(
231        &self,
232        num_events: UInt,
233    ) -> Result<PaginationResult, PaginatorError> {
234        self.paginate(Direction::Backward, num_events).await
235    }
236
237    /// Returns whether we've hit the start of the timeline.
238    ///
239    /// This is true if, and only if, we didn't have a previous-batch token and
240    /// running backwards pagination would be useless.
241    pub fn hit_timeline_start(&self) -> bool {
242        matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
243    }
244
245    /// Returns whether we've hit the end of the timeline.
246    ///
247    /// This is true if, and only if, we didn't have a next-batch token and
248    /// running forwards pagination would be useless.
249    pub fn hit_timeline_end(&self) -> bool {
250        matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
251    }
252
253    /// Runs a forward pagination (requesting `num_events` to the server), from
254    /// the current state of the object.
255    ///
256    /// Will return immediately if we have already hit the end of the timeline.
257    ///
258    /// May return an error if it's already paginating, or if the call to
259    /// /messages failed.
260    pub async fn paginate_forward(
261        &self,
262        num_events: UInt,
263    ) -> Result<PaginationResult, PaginatorError> {
264        self.paginate(Direction::Forward, num_events).await
265    }
266
267    /// Paginate in the given direction, requesting `num_events` events to the
268    /// server, using the `token_lock` to read from and write the pagination
269    /// token.
270    async fn paginate(
271        &self,
272        dir: Direction,
273        num_events: UInt,
274    ) -> Result<PaginationResult, PaginatorError> {
275        self.check_state(PaginatorState::Idle)?;
276
277        let token = {
278            let tokens = self.tokens.lock().unwrap();
279
280            let token = match dir {
281                Direction::Backward => &tokens.previous,
282                Direction::Forward => &tokens.next,
283            };
284
285            match token {
286                PaginationToken::None => None,
287                PaginationToken::HasMore(val) => Some(val.clone()),
288                PaginationToken::HitEnd => {
289                    return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
290                }
291            }
292        };
293
294        // Note: it's possible two callers have checked the state and both figured it's
295        // idle. This check below makes sure there's at most one which can set the
296        // state to paginating, preventing a race condition.
297        if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
298            return Err(PaginatorError::InvalidPreviousState {
299                expected: PaginatorState::Idle,
300                actual: PaginatorState::Paginating,
301            });
302        }
303
304        let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
305
306        let mut options = MessagesOptions::new(dir).from(token.as_deref());
307        options.limit = num_events;
308
309        // In case of error, the state is reset to idle automatically thanks to
310        // reset_state_guard.
311        let response = self.room.messages(options).await?;
312
313        // NOTE: it's super important to not have any `await` after this point, since we
314        // don't want the task to be interrupted anymore, or the internal state
315        // may be incorrect.
316
317        let hit_end_of_timeline = response.end.is_none();
318
319        {
320            let mut tokens = self.tokens.lock().unwrap();
321
322            let token = match dir {
323                Direction::Backward => &mut tokens.previous,
324                Direction::Forward => &mut tokens.next,
325            };
326
327            *token = match response.end {
328                Some(val) => PaginationToken::HasMore(val),
329                None => PaginationToken::HitEnd,
330            };
331        }
332
333        // TODO: what to do with state events?
334
335        // Forget the reset state guard, so its Drop method is not called.
336        reset_state_guard.disarm();
337        // And set the final state.
338        self.state.set(PaginatorState::Idle);
339
340        Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
341    }
342
343    /// Returns the current pagination tokens.
344    pub fn tokens(&self) -> PaginationTokens {
345        self.tokens.lock().unwrap().clone()
346    }
347}
348
349/// A room that can be paginated.
350///
351/// Not [`crate::Room`] because we may want to paginate rooms we don't belong
352/// to.
353pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
354    /// Runs a /context query for the given room.
355    ///
356    /// ## Parameters
357    ///
358    /// - `event_id` is the identifier of the target event.
359    /// - `lazy_load_members` controls whether room membership events are lazily
360    ///   loaded as context state events.
361    /// - `num_events` is the number of events (including the fetched event) to
362    ///   return as context.
363    ///
364    /// ## Returns
365    ///
366    /// Must return [`PaginatorError::EventNotFound`] whenever the target event
367    /// could not be found, instead of causing an http `Err` result.
368    fn event_with_context(
369        &self,
370        event_id: &EventId,
371        lazy_load_members: bool,
372        num_events: UInt,
373    ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
374
375    /// Runs a /messages query for the given room.
376    fn messages(
377        &self,
378        opts: MessagesOptions,
379    ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
380}
381
382impl PaginableRoom for Room {
383    async fn event_with_context(
384        &self,
385        event_id: &EventId,
386        lazy_load_members: bool,
387        num_events: UInt,
388    ) -> Result<EventWithContextResponse, PaginatorError> {
389        let response =
390            match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
391                Ok(result) => result,
392
393                Err(err) => {
394                    // If the error was a 404, then the event wasn't found on the server;
395                    // special case this to make it easy to react to
396                    // such an error.
397                    if let Some(error) = err.as_client_api_error()
398                        && error.status_code == 404
399                    {
400                        // Event not found
401                        return Err(PaginatorError::EventNotFound(event_id.to_owned()));
402                    }
403
404                    // Otherwise, just return a wrapped error.
405                    return Err(PaginatorError::SdkError(Box::new(err)));
406                }
407            };
408
409        Ok(response)
410    }
411
412    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
413        self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
414    }
415}
416
417#[cfg(all(not(target_family = "wasm"), test))]
418mod tests {
419    use std::sync::Arc;
420
421    use assert_matches2::assert_let;
422    use futures_core::Future;
423    use futures_util::FutureExt as _;
424    use matrix_sdk_base::deserialized_responses::TimelineEvent;
425    use matrix_sdk_test::{async_test, event_factory::EventFactory};
426    use once_cell::sync::Lazy;
427    use ruma::{EventId, RoomId, UInt, UserId, api::Direction, event_id, room_id, uint, user_id};
428    use tokio::{
429        spawn,
430        sync::{Mutex, Notify},
431        task::AbortHandle,
432    };
433
434    use super::{PaginableRoom, PaginatorError, PaginatorState};
435    use crate::{
436        paginators::Paginator,
437        room::{EventWithContextResponse, Messages, MessagesOptions},
438        test_utils::assert_event_matches_msg,
439    };
440
441    #[derive(Clone)]
442    struct TestRoom {
443        event_factory: Arc<EventFactory>,
444        wait_for_ready: bool,
445
446        target_event_text: Arc<Mutex<String>>,
447        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
448        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
449        prev_batch_token: Arc<Mutex<Option<String>>>,
450        next_batch_token: Arc<Mutex<Option<String>>>,
451
452        room_ready: Arc<Notify>,
453    }
454
455    impl TestRoom {
456        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
457            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));
458
459            Self {
460                event_factory,
461                wait_for_ready,
462
463                room_ready: Default::default(),
464                target_event_text: Default::default(),
465                next_events: Default::default(),
466                prev_events: Default::default(),
467                prev_batch_token: Default::default(),
468                next_batch_token: Default::default(),
469            }
470        }
471
472        /// Unblocks the next request.
473        fn mark_ready(&self) {
474            self.room_ready.notify_one();
475        }
476    }
477
478    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
479    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));
480
481    impl PaginableRoom for TestRoom {
482        async fn event_with_context(
483            &self,
484            event_id: &EventId,
485            _lazy_load_members: bool,
486            num_events: UInt,
487        ) -> Result<EventWithContextResponse, PaginatorError> {
488            // Wait for the room to be marked as ready first.
489            if self.wait_for_ready {
490                self.room_ready.notified().await;
491            }
492
493            let event = self
494                .event_factory
495                .text_msg(self.target_event_text.lock().await.clone())
496                .event_id(event_id)
497                .into_event();
498
499            // Properly simulate `num_events`: take either the closest num_events events
500            // before, or use all of the before events and then consume after events.
501            let mut num_events = u64::from(num_events) as usize;
502
503            let prev_events = self.prev_events.lock().await;
504
505            let events_before = if prev_events.is_empty() {
506                Vec::new()
507            } else {
508                let len = prev_events.len();
509                let take_before = num_events.min(len);
510                // Subtract is safe because take_before <= num_events.
511                num_events -= take_before;
512                // Subtract is safe because take_before <= len
513                prev_events[len - take_before..len].to_vec()
514            };
515
516            let events_after = self.next_events.lock().await;
517            let events_after = if events_after.is_empty() {
518                Vec::new()
519            } else {
520                events_after[0..num_events.min(events_after.len())].to_vec()
521            };
522
523            Ok(EventWithContextResponse {
524                event: Some(event),
525                events_before,
526                events_after,
527                prev_batch_token: self.prev_batch_token.lock().await.clone(),
528                next_batch_token: self.next_batch_token.lock().await.clone(),
529                state: Vec::new(),
530            })
531        }
532
533        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
534            if self.wait_for_ready {
535                self.room_ready.notified().await;
536            }
537
538            let limit = u64::from(opts.limit) as usize;
539
540            let (end, events) = match opts.dir {
541                Direction::Backward => {
542                    let events = self.prev_events.lock().await;
543                    let events = if events.is_empty() {
544                        Vec::new()
545                    } else {
546                        let len = events.len();
547                        let take_before = limit.min(len);
548                        // Subtract is safe because take_before <= len
549                        events[len - take_before..len].to_vec()
550                    };
551                    (self.prev_batch_token.lock().await.clone(), events)
552                }
553
554                Direction::Forward => {
555                    let events = self.next_events.lock().await;
556                    let events = if events.is_empty() {
557                        Vec::new()
558                    } else {
559                        events[0..limit.min(events.len())].to_vec()
560                    };
561                    (self.next_batch_token.lock().await.clone(), events)
562                }
563            };
564
565            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
566        }
567    }
568
569    async fn assert_invalid_state<T: std::fmt::Debug>(
570        task: impl Future<Output = Result<T, PaginatorError>>,
571        expected: PaginatorState,
572        actual: PaginatorState,
573    ) {
574        assert_let!(
575            Err(PaginatorError::InvalidPreviousState {
576                expected: real_expected,
577                actual: real_actual
578            }) = task.await
579        );
580        assert_eq!(real_expected, expected);
581        assert_eq!(real_actual, actual);
582    }
583
584    #[async_test]
585    async fn test_start_from() {
586        // Prepare test data.
587        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
588
589        let event_id = event_id!("$yoyoyo");
590        let event_factory = &room.event_factory;
591
592        *room.target_event_text.lock().await = "fetch_from".to_owned();
593        *room.prev_events.lock().await = (0..10)
594            .rev()
595            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
596            .collect();
597        *room.next_events.lock().await =
598            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();
599
600        // When I call `Paginator::start_from`, it works,
601        let paginator = Arc::new(Paginator::new(room.clone()));
602        let context =
603            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
604
605        assert!(!context.has_prev);
606        assert!(!context.has_next);
607
608        // And I get the events I expected.
609
610        // 10 events before, the target event, 10 events after.
611        assert_eq!(context.events.len(), 21);
612
613        for i in 0..10 {
614            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
615        }
616
617        assert_event_matches_msg(&context.events[10], "fetch_from");
618        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);
619
620        for i in 0..10 {
621            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
622        }
623    }
624
625    #[async_test]
626    async fn test_start_from_with_num_events() {
627        // Prepare test data.
628        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
629
630        let event_id = event_id!("$yoyoyo");
631        let event_factory = &room.event_factory;
632
633        *room.target_event_text.lock().await = "fetch_from".to_owned();
634        *room.prev_events.lock().await =
635            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();
636
637        // When I call `Paginator::start_from`, it works,
638        let paginator = Arc::new(Paginator::new(room.clone()));
639        let context =
640            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");
641
642        // Then I only get 10 events + the target event, even if there was more than 10
643        // events in the room.
644        assert_eq!(context.events.len(), 11);
645
646        for i in 0..10 {
647            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
648        }
649        assert_event_matches_msg(&context.events[10], "fetch_from");
650    }
651
652    #[async_test]
653    async fn test_paginate_backward() {
654        // Prepare test data.
655        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
656
657        let event_id = event_id!("$yoyoyo");
658        let event_factory = &room.event_factory;
659
660        *room.target_event_text.lock().await = "initial".to_owned();
661        *room.prev_batch_token.lock().await = Some("prev".to_owned());
662
663        // When I call `Paginator::start_from`, it works,
664        let paginator = Arc::new(Paginator::new(room.clone()));
665
666        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
667        assert!(
668            !paginator.hit_timeline_end(),
669            "we don't know about the status of the next-batch token"
670        );
671
672        let context =
673            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
674
675        // And I get the events I expected.
676        assert_eq!(context.events.len(), 1);
677        assert_event_matches_msg(&context.events[0], "initial");
678        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
679
680        // There's a previous batch, but no next batch.
681        assert!(context.has_prev);
682        assert!(!context.has_next);
683
684        assert!(!paginator.hit_timeline_start());
685        assert!(paginator.hit_timeline_end());
686
687        // Preparing data for the next back-pagination.
688        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
689        *room.prev_batch_token.lock().await = Some("prev2".to_owned());
690
691        // When I backpaginate, I get the events I expect.
692        let prev =
693            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
694        assert!(!prev.hit_end_of_timeline);
695        assert!(!paginator.hit_timeline_start());
696        assert_eq!(prev.events.len(), 1);
697        assert_event_matches_msg(&prev.events[0], "previous");
698
699        // And I can backpaginate again, because there's a prev batch token
700        // still.
701        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
702        *room.prev_batch_token.lock().await = None;
703
704        let prev = paginator
705            .paginate_backward(uint!(100))
706            .await
707            .expect("paginate backward the second time should work");
708        assert!(prev.hit_end_of_timeline);
709        assert!(paginator.hit_timeline_start());
710        assert_eq!(prev.events.len(), 1);
711        assert_event_matches_msg(&prev.events[0], "oldest");
712
713        // I've hit the start of the timeline, but back-paginating again will
714        // return immediately.
715        let prev = paginator
716            .paginate_backward(uint!(100))
717            .await
718            .expect("paginate backward the third time should work");
719        assert!(prev.hit_end_of_timeline);
720        assert!(paginator.hit_timeline_start());
721        assert!(prev.events.is_empty());
722    }
723
724    #[async_test]
725    async fn test_paginate_backward_with_limit() {
726        // Prepare test data.
727        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
728
729        let event_id = event_id!("$yoyoyo");
730        let event_factory = &room.event_factory;
731
732        *room.target_event_text.lock().await = "initial".to_owned();
733        *room.prev_batch_token.lock().await = Some("prev".to_owned());
734
735        // When I call `Paginator::start_from`, it works,
736        let paginator = Arc::new(Paginator::new(room.clone()));
737        let context =
738            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
739
740        // And I get the events I expected.
741        assert_eq!(context.events.len(), 1);
742        assert_event_matches_msg(&context.events[0], "initial");
743        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
744
745        // There's a previous batch.
746        assert!(context.has_prev);
747        assert!(!context.has_next);
748
749        // Preparing data for the next back-pagination.
750        *room.prev_events.lock().await = (0..100)
751            .rev()
752            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
753            .collect();
754        *room.prev_batch_token.lock().await = None;
755
756        // When I backpaginate and request 100 events, I get only 10 events.
757        let prev =
758            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
759        assert!(prev.hit_end_of_timeline);
760        assert_eq!(prev.events.len(), 10);
761        for i in 0..10 {
762            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
763        }
764    }
765
766    #[async_test]
767    async fn test_paginate_forward() {
768        // Prepare test data.
769        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
770
771        let event_id = event_id!("$yoyoyo");
772        let event_factory = &room.event_factory;
773
774        *room.target_event_text.lock().await = "initial".to_owned();
775        *room.next_batch_token.lock().await = Some("next".to_owned());
776
777        // When I call `Paginator::start_from`, it works,
778        let paginator = Arc::new(Paginator::new(room.clone()));
779        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
780        assert!(
781            !paginator.hit_timeline_start(),
782            "we don't know about the status of the prev-batch token"
783        );
784
785        let context =
786            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
787
788        // And I get the events I expected.
789        assert_eq!(context.events.len(), 1);
790        assert_event_matches_msg(&context.events[0], "initial");
791        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
792
793        // There's a next batch, but no previous batch (i.e. we've hit the start of the
794        // timeline).
795        assert!(!context.has_prev);
796        assert!(context.has_next);
797
798        assert!(paginator.hit_timeline_start());
799        assert!(!paginator.hit_timeline_end());
800
801        // Preparing data for the next forward-pagination.
802        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
803        *room.next_batch_token.lock().await = Some("next2".to_owned());
804
805        // When I forward-paginate, I get the events I expect.
806        let next =
807            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
808        assert!(!next.hit_end_of_timeline);
809        assert_eq!(next.events.len(), 1);
810        assert_event_matches_msg(&next.events[0], "next");
811        assert!(!paginator.hit_timeline_end());
812
813        // And I can forward-paginate again, because there's a prev batch token
814        // still.
815        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
816        *room.next_batch_token.lock().await = None;
817
818        let next = paginator
819            .paginate_forward(uint!(100))
820            .await
821            .expect("paginate forward the second time should work");
822        assert!(next.hit_end_of_timeline);
823        assert_eq!(next.events.len(), 1);
824        assert_event_matches_msg(&next.events[0], "latest");
825        assert!(paginator.hit_timeline_end());
826
827        // I've hit the start of the timeline, but back-paginating again will
828        // return immediately.
829        let next = paginator
830            .paginate_forward(uint!(100))
831            .await
832            .expect("paginate forward the third time should work");
833        assert!(next.hit_end_of_timeline);
834        assert!(next.events.is_empty());
835        assert!(paginator.hit_timeline_end());
836    }
837
    /// Checks the sequence of [`PaginatorState`] values observed through
    /// `Paginator::state()`: `Initial` -> `FetchingTargetEvent` -> `Idle` ->
    /// `Paginating` -> `Idle`. It also checks that operations attempted in the
    /// wrong state are rejected without emitting a new state.
    #[async_test]
    async fn test_state() {
        // NOTE(review): the `true` argument presumably makes the test room hold
        // its responses until `mark_ready()` is called (see the calls below) —
        // confirm against `TestRoom::new`.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        // Subscriber used to observe every state transition in order.
        let mut state = paginator.state();

        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Attempting to run pagination must fail and not change the state.
        // NOTE(review): `assert_invalid_state` is defined elsewhere; it appears to
        // take (future, expected state, actual state) — confirm its exact checks.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Running the initial query must work.
        // It is spawned so that the test can observe the intermediate state while
        // the query is still pending.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // The query is pending. Running other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        // The rejected calls above must not have emitted any state change.
        assert!(state.next().now_or_never().is_none());

        // Mark the dummy room as ready. The query may now terminate.
        room.mark_ready();

        // After fetching the initial event data, the paginator switches to `Idle`.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Now run a back-pagination in the background and observe `Paginating`.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // The query is pending. Running other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Let the pending /messages query complete.
        room.mark_ready();

        // Once the pagination completes, the paginator is back to `Idle`.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }
935
936    mod aborts {
937        use super::*;
938        use crate::paginators::room::{PaginationToken, PaginationTokens};
939
940        #[derive(Clone, Default)]
941        struct AbortingRoom {
942            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
943            room_ready: Arc<Notify>,
944        }
945
946        impl AbortingRoom {
947            async fn wait_abort_and_yield(&self) -> ! {
948                // Wait for the controller to tell us we're ready.
949                self.room_ready.notified().await;
950
951                // Abort the given handle.
952                let mut guard = self.abort_handle.lock().await;
953                let handle = guard.take().expect("only call me when i'm initialized");
954                handle.abort();
955
956                // Enter an endless loop of yielding.
957                loop {
958                    tokio::task::yield_now().await;
959                }
960            }
961        }
962
963        impl PaginableRoom for AbortingRoom {
964            async fn event_with_context(
965                &self,
966                _event_id: &EventId,
967                _lazy_load_members: bool,
968                _num_events: UInt,
969            ) -> Result<EventWithContextResponse, PaginatorError> {
970                self.wait_abort_and_yield().await
971            }
972
973            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
974                self.wait_abort_and_yield().await
975            }
976        }
977
978        #[async_test]
979        async fn test_abort_while_starting_from() {
980            let room = AbortingRoom::default();
981
982            let paginator = Arc::new(Paginator::new(room.clone()));
983
984            let mut state = paginator.state();
985
986            assert_eq!(state.get(), PaginatorState::Initial);
987            assert!(state.next().now_or_never().is_none());
988
989            // When I try to start the initial query…
990            let p = paginator.clone();
991            let join_handle = spawn(async move {
992                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
993            });
994
995            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
996
997            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
998            assert!(state.next().now_or_never().is_none());
999
1000            room.room_ready.notify_one();
1001
1002            // But it's aborted when awaiting the task.
1003            let join_result = join_handle.await;
1004            assert!(join_result.unwrap_err().is_cancelled());
1005
1006            // Then the state is reset to initial.
1007            assert_eq!(state.next().await, Some(PaginatorState::Initial));
1008            assert!(state.next().now_or_never().is_none());
1009        }
1010
1011        #[async_test]
1012        async fn test_abort_while_paginating() {
1013            let room = AbortingRoom::default();
1014
1015            // Assuming a paginator ready to back- or forward- paginate,
1016            let paginator = Paginator::new(room.clone());
1017            paginator.state.set(PaginatorState::Idle);
1018            *paginator.tokens.lock().unwrap() = PaginationTokens {
1019                previous: PaginationToken::HasMore("prev".to_owned()),
1020                next: PaginationToken::HasMore("next".to_owned()),
1021            };
1022
1023            let paginator = Arc::new(paginator);
1024
1025            let mut state = paginator.state();
1026
1027            assert_eq!(state.get(), PaginatorState::Idle);
1028            assert!(state.next().now_or_never().is_none());
1029
1030            // When I try to back-paginate…
1031            let p = paginator.clone();
1032            let join_handle = spawn(async move {
1033                let _ = p.paginate_backward(uint!(100)).await;
1034            });
1035
1036            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1037
1038            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1039            assert!(state.next().now_or_never().is_none());
1040
1041            room.room_ready.notify_one();
1042
1043            // But it's aborted when awaiting the task.
1044            let join_result = join_handle.await;
1045            assert!(join_result.unwrap_err().is_cancelled());
1046
1047            // Then the state is reset to idle.
1048            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1049            assert!(state.next().now_or_never().is_none());
1050
1051            // And ditto for forward pagination.
1052            let p = paginator.clone();
1053            let join_handle = spawn(async move {
1054                let _ = p.paginate_forward(uint!(100)).await;
1055            });
1056
1057            *room.abort_handle.lock().await = Some(join_handle.abort_handle());
1058
1059            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
1060            assert!(state.next().now_or_never().is_none());
1061
1062            room.room_ready.notify_one();
1063
1064            let join_result = join_handle.await;
1065            assert!(join_result.unwrap_err().is_cancelled());
1066
1067            assert_eq!(state.next().await, Some(PaginatorState::Idle));
1068            assert!(state.next().now_or_never().is_none());
1069        }
1070    }
1071}