1use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm};
24use ruma::{api::Direction, EventId, OwnedEventId, UInt};
25
26use super::pagination::PaginationToken;
27use crate::{
28 room::{EventWithContextResponse, Messages, MessagesOptions, WeakRoom},
29 Room,
30};
31
/// The state a [`Paginator`] can be in.
///
/// The current value is observable through [`Paginator::state`].
#[derive(Debug, PartialEq, Copy, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// The state of a freshly created paginator; [`Paginator::start_from`]
    /// requires this state.
    Initial,

    /// [`Paginator::start_from`] is waiting for the target event and its
    /// context.
    FetchingTargetEvent,

    /// No request is in flight; the paginate methods require this state.
    Idle,

    /// A backward or forward pagination request is in flight.
    Paginating,
}
49
/// An error that happened while using a [`Paginator`].
#[derive(Debug, thiserror::Error)]
pub enum PaginatorError {
    /// The target event could not be found (the `Room` implementation of
    /// [`PaginableRoom`] returns this for an HTTP 404 response).
    #[error("target event with id {0} could not be found")]
    EventNotFound(OwnedEventId),

    /// An operation was attempted while the paginator was not in the state
    /// that operation requires.
    #[error("expected paginator state {expected:?}, observed {actual:?}")]
    InvalidPreviousState {
        /// The state the operation required.
        expected: PaginatorState,
        /// The state that was observed instead.
        actual: PaginatorState,
    },

    /// Any other SDK error raised while querying the room.
    #[error("an error happened while paginating: {0}")]
    SdkError(#[from] Box<crate::Error>),
}
70
/// The pair of pagination tokens tracked by a [`Paginator`].
#[derive(Debug)]
struct PaginationTokens {
    /// Token used for the next backward pagination.
    previous: PaginationToken,
    /// Token used for the next forward pagination.
    next: PaginationToken,
}
79
/// Paginates through the events of a room, backward and forward, keeping
/// track of the pagination tokens and exposing an observable state.
///
/// It is generic over [`PaginableRoom`], so the room backend can be
/// substituted (the tests in this file use in-memory rooms).
pub struct Paginator<PR: PaginableRoom> {
    /// The room that is queried for events.
    room: PR,

    /// Current state of the paginator, observable via [`Paginator::state`].
    state: SharedObservable<PaginatorState>,

    /// Tokens for the next backward and forward paginations, guarded by a
    /// synchronous mutex.
    tokens: Mutex<PaginationTokens>,
}
96
97#[cfg(not(tarpaulin_include))]
98impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("Paginator")
102 .field("state", &self.state.get())
103 .field("tokens", &self.tokens)
104 .finish_non_exhaustive()
105 }
106}
107
/// The result of a single backward or forward pagination.
#[derive(Debug)]
pub struct PaginationResult {
    /// The events returned by the room for this pagination, in the order
    /// the room returned them.
    pub events: Vec<TimelineEvent>,

    /// True when no token for further pagination in this direction is
    /// available, i.e. the end of the timeline in that direction was hit.
    pub hit_end_of_timeline: bool,
}
130
/// The result of [`Paginator::start_from`].
#[derive(Debug)]
pub struct StartFromResult {
    /// The events preceding the target (in reverse of the order the room
    /// returned them), then the target event if any, then the events
    /// following it.
    pub events: Vec<TimelineEvent>,

    /// Whether the response carried a token for backward pagination.
    pub has_prev: bool,

    /// Whether the response carried a token for forward pagination.
    pub has_next: bool,
}
143
/// Drop guard that resets the paginator state to `target` when dropped,
/// unless it has been disarmed first.
///
/// This restores the state when an in-flight operation returns early or its
/// future is dropped (e.g. when its task is aborted).
struct ResetStateGuard {
    /// State to restore on drop; `None` once the guard has been disarmed.
    target: Option<PaginatorState>,
    /// Handle to the paginator's shared state.
    state: SharedObservable<PaginatorState>,
}
149
150impl ResetStateGuard {
151 fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
153 Self { target: Some(target), state }
154 }
155
156 fn disarm(mut self) {
158 self.target = None;
159 }
160}
161
162impl Drop for ResetStateGuard {
163 fn drop(&mut self) {
164 if let Some(target) = self.target.take() {
165 self.state.set_if_not_eq(target);
166 }
167 }
168}
169
170impl<PR: PaginableRoom> Paginator<PR> {
171 pub fn new(room: PR) -> Self {
173 Self {
174 room,
175 state: SharedObservable::new(PaginatorState::Initial),
176 tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
177 }
178 }
179
180 fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
182 let actual = self.state.get();
183 if actual != expected {
184 Err(PaginatorError::InvalidPreviousState { expected, actual })
185 } else {
186 Ok(())
187 }
188 }
189
190 pub fn state(&self) -> Subscriber<PaginatorState> {
192 self.state.subscribe()
193 }
194
195 pub(super) fn set_idle_state(
201 &self,
202 next_state: PaginatorState,
203 prev_batch_token: Option<String>,
204 next_batch_token: Option<String>,
205 ) -> Result<(), PaginatorError> {
206 let prev_state = self.state.get();
207
208 match next_state {
209 PaginatorState::Initial | PaginatorState::Idle => {}
210 PaginatorState::FetchingTargetEvent | PaginatorState::Paginating => {
211 panic!("internal error: set_idle_state only accept Initial|Idle next states");
212 }
213 }
214
215 match prev_state {
216 PaginatorState::Initial | PaginatorState::Idle => {}
217 PaginatorState::FetchingTargetEvent | PaginatorState::Paginating => {
218 return Err(PaginatorError::InvalidPreviousState {
220 expected: PaginatorState::Idle,
222 actual: prev_state,
223 });
224 }
225 }
226
227 self.state.set_if_not_eq(next_state);
228
229 {
230 let mut tokens = self.tokens.lock().unwrap();
231 tokens.previous = prev_batch_token.into();
232 tokens.next = next_batch_token.into();
233 }
234
235 Ok(())
236 }
237
238 pub(super) fn prev_batch_token(&self) -> Option<String> {
240 match &self.tokens.lock().unwrap().previous {
241 PaginationToken::HitEnd | PaginationToken::None => None,
242 PaginationToken::HasMore(token) => Some(token.clone()),
243 }
244 }
245
246 pub async fn start_from(
252 &self,
253 event_id: &EventId,
254 num_events: UInt,
255 ) -> Result<StartFromResult, PaginatorError> {
256 self.check_state(PaginatorState::Initial)?;
257
258 if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
262 return Err(PaginatorError::InvalidPreviousState {
263 expected: PaginatorState::Initial,
264 actual: PaginatorState::FetchingTargetEvent,
265 });
266 }
267
268 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
269
270 let lazy_load_members = true;
272
273 let response =
274 self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
275
276 let has_prev = response.prev_batch_token.is_some();
281 let has_next = response.next_batch_token.is_some();
282
283 {
284 let mut tokens = self.tokens.lock().unwrap();
285 tokens.previous = match response.prev_batch_token {
286 Some(token) => PaginationToken::HasMore(token),
287 None => PaginationToken::HitEnd,
288 };
289 tokens.next = match response.next_batch_token {
290 Some(token) => PaginationToken::HasMore(token),
291 None => PaginationToken::HitEnd,
292 };
293 }
294
295 reset_state_guard.disarm();
297 self.state.set(PaginatorState::Idle);
299
300 let events = response
307 .events_before
308 .into_iter()
309 .rev()
310 .chain(response.event)
311 .chain(response.events_after)
312 .collect();
313
314 Ok(StartFromResult { events, has_prev, has_next })
315 }
316
317 pub async fn paginate_backward(
326 &self,
327 num_events: UInt,
328 ) -> Result<PaginationResult, PaginatorError> {
329 self.paginate(Direction::Backward, num_events).await
330 }
331
332 pub fn hit_timeline_start(&self) -> bool {
337 matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
338 }
339
340 pub fn hit_timeline_end(&self) -> bool {
345 matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
346 }
347
348 pub async fn paginate_forward(
356 &self,
357 num_events: UInt,
358 ) -> Result<PaginationResult, PaginatorError> {
359 self.paginate(Direction::Forward, num_events).await
360 }
361
362 async fn paginate(
366 &self,
367 dir: Direction,
368 num_events: UInt,
369 ) -> Result<PaginationResult, PaginatorError> {
370 self.check_state(PaginatorState::Idle)?;
371
372 let token = {
373 let tokens = self.tokens.lock().unwrap();
374
375 let token = match dir {
376 Direction::Backward => &tokens.previous,
377 Direction::Forward => &tokens.next,
378 };
379
380 match token {
381 PaginationToken::None => None,
382 PaginationToken::HasMore(val) => Some(val.clone()),
383 PaginationToken::HitEnd => {
384 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
385 }
386 }
387 };
388
389 if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
393 return Err(PaginatorError::InvalidPreviousState {
394 expected: PaginatorState::Idle,
395 actual: PaginatorState::Paginating,
396 });
397 }
398
399 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
400
401 let mut options = MessagesOptions::new(dir).from(token.as_deref());
402 options.limit = num_events;
403
404 let response = self.room.messages(options).await?;
407
408 let hit_end_of_timeline = response.end.is_none();
413
414 {
415 let mut tokens = self.tokens.lock().unwrap();
416
417 let token = match dir {
418 Direction::Backward => &mut tokens.previous,
419 Direction::Forward => &mut tokens.next,
420 };
421
422 *token = match response.end {
423 Some(val) => PaginationToken::HasMore(val),
424 None => PaginationToken::HitEnd,
425 };
426 }
427
428 reset_state_guard.disarm();
432 self.state.set(PaginatorState::Idle);
434
435 Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
436 }
437}
438
/// The room operations a [`Paginator`] relies on.
///
/// This file implements it for [`Room`] and [`WeakRoom`]; the tests implement
/// it for in-memory rooms.
pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
    /// Fetches the event `event_id` together with events surrounding it and
    /// the pagination tokens on both sides.
    ///
    /// `lazy_load_members` and `num_events` are forwarded to the underlying
    /// room request. NOTE(review): `num_events` looks like an upper bound on
    /// the number of context events — confirm against the room API.
    ///
    /// The `Room` implementation reports an HTTP 404 as
    /// [`PaginatorError::EventNotFound`] and any other failure as
    /// [`PaginatorError::SdkError`].
    fn event_with_context(
        &self,
        event_id: &EventId,
        lazy_load_members: bool,
        num_events: UInt,
    ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;

    /// Runs a single pagination request described by `opts` (direction,
    /// starting token and limit, as set up by the paginator).
    fn messages(
        &self,
        opts: MessagesOptions,
    ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
}
471
472impl PaginableRoom for Room {
473 async fn event_with_context(
474 &self,
475 event_id: &EventId,
476 lazy_load_members: bool,
477 num_events: UInt,
478 ) -> Result<EventWithContextResponse, PaginatorError> {
479 let response =
480 match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
481 Ok(result) => result,
482
483 Err(err) => {
484 if let Some(error) = err.as_client_api_error() {
488 if error.status_code == 404 {
489 return Err(PaginatorError::EventNotFound(event_id.to_owned()));
491 }
492 }
493
494 return Err(PaginatorError::SdkError(Box::new(err)));
496 }
497 };
498
499 Ok(response)
500 }
501
502 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
503 self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
504 }
505}
506
507impl PaginableRoom for WeakRoom {
508 async fn event_with_context(
509 &self,
510 event_id: &EventId,
511 lazy_load_members: bool,
512 num_events: UInt,
513 ) -> Result<EventWithContextResponse, PaginatorError> {
514 let Some(room) = self.get() else {
515 return Ok(EventWithContextResponse::default());
517 };
518
519 PaginableRoom::event_with_context(&room, event_id, lazy_load_members, num_events).await
520 }
521
522 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
523 let Some(room) = self.get() else {
524 return Ok(Messages::default());
526 };
527
528 PaginableRoom::messages(&room, opts).await
529 }
530}
531
532#[cfg(all(not(target_arch = "wasm32"), test))]
533mod tests {
534 use std::sync::Arc;
535
536 use assert_matches2::assert_let;
537 use futures_core::Future;
538 use futures_util::FutureExt as _;
539 use matrix_sdk_base::deserialized_responses::TimelineEvent;
540 use matrix_sdk_test::{async_test, event_factory::EventFactory};
541 use once_cell::sync::Lazy;
542 use ruma::{api::Direction, event_id, room_id, uint, user_id, EventId, RoomId, UInt, UserId};
543 use tokio::{
544 spawn,
545 sync::{Mutex, Notify},
546 task::AbortHandle,
547 };
548
549 use super::{PaginableRoom, PaginatorError, PaginatorState};
550 use crate::{
551 event_cache::paginator::Paginator,
552 room::{EventWithContextResponse, Messages, MessagesOptions},
553 test_utils::assert_event_matches_msg,
554 };
555
    /// In-memory [`PaginableRoom`] whose responses are scripted by the tests
    /// through its shared, mutable fields.
    #[derive(Clone)]
    struct TestRoom {
        /// Factory used to build the events returned by this room.
        event_factory: Arc<EventFactory>,
        /// When true, every request first waits for `room_ready` to be
        /// notified.
        wait_for_ready: bool,

        /// Text of the target event returned by `event_with_context`.
        target_event_text: Arc<Mutex<String>>,
        /// Events served for forward requests, taken from the front.
        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
        /// Events served for backward requests, taken from the back.
        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
        /// Backward token returned by the next response.
        prev_batch_token: Arc<Mutex<Option<String>>>,
        /// Forward token returned by the next response.
        next_batch_token: Arc<Mutex<Option<String>>>,

        /// Notified by `mark_ready` to unblock a waiting request.
        room_ready: Arc<Notify>,
    }
569
570 impl TestRoom {
571 fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
572 let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));
573
574 Self {
575 event_factory,
576 wait_for_ready,
577
578 room_ready: Default::default(),
579 target_event_text: Default::default(),
580 next_events: Default::default(),
581 prev_events: Default::default(),
582 prev_batch_token: Default::default(),
583 next_batch_token: Default::default(),
584 }
585 }
586
587 fn mark_ready(&self) {
589 self.room_ready.notify_one();
590 }
591 }
592
    /// Room id used by the `TestRoom`-based tests.
    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
    /// Sender id used for the events built by the `TestRoom`-based tests.
    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));
595
596 impl PaginableRoom for TestRoom {
597 async fn event_with_context(
598 &self,
599 event_id: &EventId,
600 _lazy_load_members: bool,
601 num_events: UInt,
602 ) -> Result<EventWithContextResponse, PaginatorError> {
603 if self.wait_for_ready {
605 self.room_ready.notified().await;
606 }
607
608 let event = self
609 .event_factory
610 .text_msg(self.target_event_text.lock().await.clone())
611 .event_id(event_id)
612 .into_event();
613
614 let mut num_events = u64::from(num_events) as usize;
617
618 let prev_events = self.prev_events.lock().await;
619
620 let events_before = if prev_events.is_empty() {
621 Vec::new()
622 } else {
623 let len = prev_events.len();
624 let take_before = num_events.min(len);
625 num_events -= take_before;
627 prev_events[len - take_before..len].to_vec()
629 };
630
631 let events_after = self.next_events.lock().await;
632 let events_after = if events_after.is_empty() {
633 Vec::new()
634 } else {
635 events_after[0..num_events.min(events_after.len())].to_vec()
636 };
637
638 Ok(EventWithContextResponse {
639 event: Some(event),
640 events_before,
641 events_after,
642 prev_batch_token: self.prev_batch_token.lock().await.clone(),
643 next_batch_token: self.next_batch_token.lock().await.clone(),
644 state: Vec::new(),
645 })
646 }
647
648 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
649 if self.wait_for_ready {
650 self.room_ready.notified().await;
651 }
652
653 let limit = u64::from(opts.limit) as usize;
654
655 let (end, events) = match opts.dir {
656 Direction::Backward => {
657 let events = self.prev_events.lock().await;
658 let events = if events.is_empty() {
659 Vec::new()
660 } else {
661 let len = events.len();
662 let take_before = limit.min(len);
663 events[len - take_before..len].to_vec()
665 };
666 (self.prev_batch_token.lock().await.clone(), events)
667 }
668
669 Direction::Forward => {
670 let events = self.next_events.lock().await;
671 let events = if events.is_empty() {
672 Vec::new()
673 } else {
674 events[0..limit.min(events.len())].to_vec()
675 };
676 (self.next_batch_token.lock().await.clone(), events)
677 }
678 };
679
680 Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
681 }
682 }
683
684 async fn assert_invalid_state<T: std::fmt::Debug>(
685 task: impl Future<Output = Result<T, PaginatorError>>,
686 expected: PaginatorState,
687 actual: PaginatorState,
688 ) {
689 assert_let!(
690 Err(PaginatorError::InvalidPreviousState {
691 expected: real_expected,
692 actual: real_actual
693 }) = task.await
694 );
695 assert_eq!(real_expected, expected);
696 assert_eq!(real_actual, actual);
697 }
698
699 #[async_test]
700 async fn test_start_from() {
701 let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
703
704 let event_id = event_id!("$yoyoyo");
705 let event_factory = &room.event_factory;
706
707 *room.target_event_text.lock().await = "fetch_from".to_owned();
708 *room.prev_events.lock().await = (0..10)
709 .rev()
710 .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
711 .collect();
712 *room.next_events.lock().await =
713 (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();
714
715 let paginator = Arc::new(Paginator::new(room.clone()));
717 let context =
718 paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
719
720 assert!(!context.has_prev);
721 assert!(!context.has_next);
722
723 assert_eq!(context.events.len(), 21);
727
728 for i in 0..10 {
729 assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
730 }
731
732 assert_event_matches_msg(&context.events[10], "fetch_from");
733 assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);
734
735 for i in 0..10 {
736 assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
737 }
738 }
739
740 #[async_test]
741 async fn test_start_from_with_num_events() {
742 let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
744
745 let event_id = event_id!("$yoyoyo");
746 let event_factory = &room.event_factory;
747
748 *room.target_event_text.lock().await = "fetch_from".to_owned();
749 *room.prev_events.lock().await =
750 (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();
751
752 let paginator = Arc::new(Paginator::new(room.clone()));
754 let context =
755 paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");
756
757 assert_eq!(context.events.len(), 11);
760
761 for i in 0..10 {
762 assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
763 }
764 assert_event_matches_msg(&context.events[10], "fetch_from");
765 }
766
    /// Backward pagination follows the previous-batch tokens until the room
    /// stops returning one, after which it yields empty results.
    #[async_test]
    async fn test_paginate_backward() {
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        // Before `start_from`, no end of timeline is recorded on either side.
        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
        assert!(
            !paginator.hit_timeline_end(),
            "we don't know about the status of the next-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned: the room has no other events yet.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // The room returned a previous-batch token but no next-batch token.
        assert!(context.has_prev);
        assert!(!context.has_next);

        assert!(!paginator.hit_timeline_start());
        assert!(paginator.hit_timeline_end());

        // First pagination: one event, and there is still a token afterwards.
        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
        *room.prev_batch_token.lock().await = Some("prev2".to_owned());

        let prev =
            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
        assert!(!prev.hit_end_of_timeline);
        assert!(!paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "previous");

        // Second pagination: the room returns no token, so the start is hit.
        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
        *room.prev_batch_token.lock().await = None;

        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the second time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "oldest");

        // Third pagination: nothing left, the result is empty.
        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the third time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert!(prev.events.is_empty());
    }
838
839 #[async_test]
840 async fn test_paginate_backward_with_limit() {
841 let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
843
844 let event_id = event_id!("$yoyoyo");
845 let event_factory = &room.event_factory;
846
847 *room.target_event_text.lock().await = "initial".to_owned();
848 *room.prev_batch_token.lock().await = Some("prev".to_owned());
849
850 let paginator = Arc::new(Paginator::new(room.clone()));
852 let context =
853 paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
854
855 assert_eq!(context.events.len(), 1);
857 assert_event_matches_msg(&context.events[0], "initial");
858 assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
859
860 assert!(context.has_prev);
862 assert!(!context.has_next);
863
864 *room.prev_events.lock().await = (0..100)
866 .rev()
867 .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
868 .collect();
869 *room.prev_batch_token.lock().await = None;
870
871 let prev =
873 paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
874 assert!(prev.hit_end_of_timeline);
875 assert_eq!(prev.events.len(), 10);
876 for i in 0..10 {
877 assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
878 }
879 }
880
    /// Forward pagination follows the next-batch tokens until the room
    /// stops returning one, after which it yields empty results.
    #[async_test]
    async fn test_paginate_forward() {
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        // Before `start_from`, no end of timeline is recorded on either side.
        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
        assert!(
            !paginator.hit_timeline_start(),
            "we don't know about the status of the prev-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned: the room has no other events yet.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // The room returned a next-batch token but no previous-batch token.
        assert!(!context.has_prev);
        assert!(context.has_next);

        assert!(paginator.hit_timeline_start());
        assert!(!paginator.hit_timeline_end());

        // First pagination: one event, and there is still a token afterwards.
        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
        *room.next_batch_token.lock().await = Some("next2".to_owned());

        let next =
            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
        assert!(!next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "next");
        assert!(!paginator.hit_timeline_end());

        // Second pagination: the room returns no token, so the end is hit.
        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
        *room.next_batch_token.lock().await = None;

        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the second time should work");
        assert!(next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "latest");
        assert!(paginator.hit_timeline_end());

        // Third pagination: nothing left, the result is empty.
        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the third time should work");
        assert!(next.hit_end_of_timeline);
        assert!(next.events.is_empty());
        assert!(paginator.hit_timeline_end());
    }
952
    /// Walks the paginator through its states and checks, at each step,
    /// both the observed state updates and that operations invalid for the
    /// current state are rejected without emitting any update.
    #[async_test]
    async fn test_state() {
        // This room blocks every request until `mark_ready` is called.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Paginating is not allowed in the initial state.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        // The rejected call did not emit any state update.
        assert!(state.next().now_or_never().is_none());

        // Start fetching the target event in the background.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // While the fetch is in flight, every other operation is rejected.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Unblock the room so the background request can complete.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Start a backward pagination in the background.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // While paginating, every other operation is rejected.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Unblock the room so the pagination can complete.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }
1050
1051 mod aborts {
1052 use super::*;
1053
        /// A [`PaginableRoom`] whose requests never complete: once notified,
        /// a request aborts the task registered in `abort_handle` and then
        /// yields forever.
        #[derive(Clone, Default)]
        struct AbortingRoom {
            /// Handle of the task to abort; set by the test after spawning.
            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
            /// Notified by the test to trigger the abort.
            room_ready: Arc<Notify>,
        }
1059
        impl AbortingRoom {
            /// Waits for `room_ready`, aborts the task registered in
            /// `abort_handle`, then yields in a loop and never returns.
            ///
            /// Panics if no abort handle has been registered.
            async fn wait_abort_and_yield(&self) -> ! {
                // Wait until the test signals that the handle is in place.
                self.room_ready.notified().await;

                // Take the handle so that a later request needs a fresh one.
                let mut guard = self.abort_handle.lock().await;
                let handle = guard.take().expect("only call me when i'm initialized");
                handle.abort();

                // Keep yielding: this future never completes on its own.
                loop {
                    tokio::task::yield_now().await;
                }
            }
        }
1076
        impl PaginableRoom for AbortingRoom {
            /// Never returns a response: waits for the signal, aborts the
            /// registered task and yields forever.
            async fn event_with_context(
                &self,
                _event_id: &EventId,
                _lazy_load_members: bool,
                _num_events: UInt,
            ) -> Result<EventWithContextResponse, PaginatorError> {
                self.wait_abort_and_yield().await
            }

            /// Never returns a response: waits for the signal, aborts the
            /// registered task and yields forever.
            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
                self.wait_abort_and_yield().await
            }
        }
1091
        /// Aborting a task in the middle of `start_from` puts the paginator
        /// back into the `Initial` state.
        #[async_test]
        async fn test_abort_while_starting_from() {
            let room = AbortingRoom::default();

            let paginator = Arc::new(Paginator::new(room.clone()));

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Initial);
            assert!(state.next().now_or_never().is_none());

            // Run `start_from` in a task that the room will abort.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
            assert!(state.next().now_or_never().is_none());

            // Let the room abort the task.
            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // The state went back to `Initial`, and nothing else was emitted.
            assert_eq!(state.next().await, Some(PaginatorState::Initial));
            assert!(state.next().now_or_never().is_none());
        }
1124
        /// Aborting a task in the middle of a backward or forward pagination
        /// puts the paginator back into the `Idle` state.
        #[async_test]
        async fn test_abort_while_paginating() {
            let room = AbortingRoom::default();

            // Skip `start_from`: put the paginator directly into `Idle`, with
            // tokens on both sides.
            let paginator = Paginator::new(room.clone());
            paginator
                .set_idle_state(
                    PaginatorState::Idle,
                    Some("prev".to_owned()),
                    Some("next".to_owned()),
                )
                .unwrap();

            let paginator = Arc::new(paginator);

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Idle);
            assert!(state.next().now_or_never().is_none());

            // Run a backward pagination in a task that the room will abort.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_backward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            // Let the room abort the task.
            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // The state went back to `Idle`, and nothing else was emitted.
            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());

            // Same scenario with a forward pagination.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_forward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());
        }
1186 }
1187}