1use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm};
24use ruma::{api::Direction, EventId, OwnedEventId, UInt};
25
26use super::pagination::PaginationToken;
27use crate::{
28 room::{EventWithContextResponse, Messages, MessagesOptions, WeakRoom},
29 Room,
30};
31
/// Lifecycle state of a [`Paginator`].
///
/// Transitions can be observed through [`Paginator::state`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// Freshly created paginator; no target event has been loaded yet.
    ///
    /// This is the only state from which [`Paginator::start_from`] may be
    /// called.
    Initial,

    /// [`Paginator::start_from`] is currently loading the target event and
    /// its surrounding context.
    FetchingTargetEvent,

    /// No request is running. Backward and forward pagination may only be
    /// started from this state.
    Idle,

    /// A backward or forward pagination request is in flight.
    Paginating,
}
49
50#[derive(Debug, thiserror::Error)]
52pub enum PaginatorError {
53 #[error("target event with id {0} could not be found")]
55 EventNotFound(OwnedEventId),
56
57 #[error("expected paginator state {expected:?}, observed {actual:?}")]
59 InvalidPreviousState {
60 expected: PaginatorState,
62 actual: PaginatorState,
64 },
65
66 #[error("an error happened while paginating: {0}")]
68 SdkError(#[from] Box<crate::Error>),
69}
70
71#[derive(Debug)]
73struct PaginationTokens {
74 previous: PaginationToken,
76 next: PaginationToken,
78}
79
80pub struct Paginator<PR: PaginableRoom> {
85 room: PR,
87
88 state: SharedObservable<PaginatorState>,
90
91 tokens: Mutex<PaginationTokens>,
95}
96
97#[cfg(not(tarpaulin_include))]
98impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("Paginator")
102 .field("state", &self.state.get())
103 .field("tokens", &self.tokens)
104 .finish_non_exhaustive()
105 }
106}
107
108#[derive(Debug)]
111pub struct PaginationResult {
112 pub events: Vec<TimelineEvent>,
120
121 pub hit_end_of_timeline: bool,
129}
130
131#[derive(Debug)]
133pub struct StartFromResult {
134 pub events: Vec<TimelineEvent>,
136
137 pub has_prev: bool,
139
140 pub has_next: bool,
142}
143
144struct ResetStateGuard {
146 target: Option<PaginatorState>,
147 state: SharedObservable<PaginatorState>,
148}
149
150impl ResetStateGuard {
151 fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
153 Self { target: Some(target), state }
154 }
155
156 fn disarm(mut self) {
158 self.target = None;
159 }
160}
161
162impl Drop for ResetStateGuard {
163 fn drop(&mut self) {
164 if let Some(target) = self.target.take() {
165 self.state.set_if_not_eq(target);
166 }
167 }
168}
169
170impl<PR: PaginableRoom> Paginator<PR> {
171 pub fn new(room: PR) -> Self {
173 Self {
174 room,
175 state: SharedObservable::new(PaginatorState::Initial),
176 tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
177 }
178 }
179
180 fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
182 let actual = self.state.get();
183 if actual != expected {
184 Err(PaginatorError::InvalidPreviousState { expected, actual })
185 } else {
186 Ok(())
187 }
188 }
189
190 pub fn state(&self) -> Subscriber<PaginatorState> {
192 self.state.subscribe()
193 }
194
195 pub async fn start_from(
201 &self,
202 event_id: &EventId,
203 num_events: UInt,
204 ) -> Result<StartFromResult, PaginatorError> {
205 self.check_state(PaginatorState::Initial)?;
206
207 if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
211 return Err(PaginatorError::InvalidPreviousState {
212 expected: PaginatorState::Initial,
213 actual: PaginatorState::FetchingTargetEvent,
214 });
215 }
216
217 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
218
219 let lazy_load_members = true;
221
222 let response =
223 self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
224
225 let has_prev = response.prev_batch_token.is_some();
230 let has_next = response.next_batch_token.is_some();
231
232 {
233 let mut tokens = self.tokens.lock().unwrap();
234 tokens.previous = match response.prev_batch_token {
235 Some(token) => PaginationToken::HasMore(token),
236 None => PaginationToken::HitEnd,
237 };
238 tokens.next = match response.next_batch_token {
239 Some(token) => PaginationToken::HasMore(token),
240 None => PaginationToken::HitEnd,
241 };
242 }
243
244 reset_state_guard.disarm();
246 self.state.set(PaginatorState::Idle);
248
249 let events = response
256 .events_before
257 .into_iter()
258 .rev()
259 .chain(response.event)
260 .chain(response.events_after)
261 .collect();
262
263 Ok(StartFromResult { events, has_prev, has_next })
264 }
265
266 pub async fn paginate_backward(
275 &self,
276 num_events: UInt,
277 ) -> Result<PaginationResult, PaginatorError> {
278 self.paginate(Direction::Backward, num_events).await
279 }
280
281 pub fn hit_timeline_start(&self) -> bool {
286 matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
287 }
288
289 pub fn hit_timeline_end(&self) -> bool {
294 matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
295 }
296
297 pub async fn paginate_forward(
305 &self,
306 num_events: UInt,
307 ) -> Result<PaginationResult, PaginatorError> {
308 self.paginate(Direction::Forward, num_events).await
309 }
310
311 async fn paginate(
315 &self,
316 dir: Direction,
317 num_events: UInt,
318 ) -> Result<PaginationResult, PaginatorError> {
319 self.check_state(PaginatorState::Idle)?;
320
321 let token = {
322 let tokens = self.tokens.lock().unwrap();
323
324 let token = match dir {
325 Direction::Backward => &tokens.previous,
326 Direction::Forward => &tokens.next,
327 };
328
329 match token {
330 PaginationToken::None => None,
331 PaginationToken::HasMore(val) => Some(val.clone()),
332 PaginationToken::HitEnd => {
333 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
334 }
335 }
336 };
337
338 if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
342 return Err(PaginatorError::InvalidPreviousState {
343 expected: PaginatorState::Idle,
344 actual: PaginatorState::Paginating,
345 });
346 }
347
348 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
349
350 let mut options = MessagesOptions::new(dir).from(token.as_deref());
351 options.limit = num_events;
352
353 let response = self.room.messages(options).await?;
356
357 let hit_end_of_timeline = response.end.is_none();
362
363 {
364 let mut tokens = self.tokens.lock().unwrap();
365
366 let token = match dir {
367 Direction::Backward => &mut tokens.previous,
368 Direction::Forward => &mut tokens.next,
369 };
370
371 *token = match response.end {
372 Some(val) => PaginationToken::HasMore(val),
373 None => PaginationToken::HitEnd,
374 };
375 }
376
377 reset_state_guard.disarm();
381 self.state.set(PaginatorState::Idle);
383
384 Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
385 }
386}
387
388pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
393 fn event_with_context(
408 &self,
409 event_id: &EventId,
410 lazy_load_members: bool,
411 num_events: UInt,
412 ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
413
414 fn messages(
416 &self,
417 opts: MessagesOptions,
418 ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
419}
420
421impl PaginableRoom for Room {
422 async fn event_with_context(
423 &self,
424 event_id: &EventId,
425 lazy_load_members: bool,
426 num_events: UInt,
427 ) -> Result<EventWithContextResponse, PaginatorError> {
428 let response =
429 match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
430 Ok(result) => result,
431
432 Err(err) => {
433 if let Some(error) = err.as_client_api_error() {
437 if error.status_code == 404 {
438 return Err(PaginatorError::EventNotFound(event_id.to_owned()));
440 }
441 }
442
443 return Err(PaginatorError::SdkError(Box::new(err)));
445 }
446 };
447
448 Ok(response)
449 }
450
451 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
452 self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
453 }
454}
455
456impl PaginableRoom for WeakRoom {
457 async fn event_with_context(
458 &self,
459 event_id: &EventId,
460 lazy_load_members: bool,
461 num_events: UInt,
462 ) -> Result<EventWithContextResponse, PaginatorError> {
463 let Some(room) = self.get() else {
464 return Ok(EventWithContextResponse::default());
466 };
467
468 PaginableRoom::event_with_context(&room, event_id, lazy_load_members, num_events).await
469 }
470
471 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
472 let Some(room) = self.get() else {
473 return Ok(Messages::default());
475 };
476
477 PaginableRoom::messages(&room, opts).await
478 }
479}
480
#[cfg(all(not(target_arch = "wasm32"), test))]
mod tests {
    use std::sync::Arc;

    use assert_matches2::assert_let;
    use futures_core::Future;
    use futures_util::FutureExt as _;
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use once_cell::sync::Lazy;
    use ruma::{api::Direction, event_id, room_id, uint, user_id, EventId, RoomId, UInt, UserId};
    use tokio::{
        spawn,
        sync::{Mutex, Notify},
        task::AbortHandle,
    };

    use super::{PaginableRoom, PaginatorError, PaginatorState};
    use crate::{
        event_cache::paginator::Paginator,
        room::{EventWithContextResponse, Messages, MessagesOptions},
        test_utils::assert_event_matches_msg,
    };

    /// A fake room whose responses are entirely configured by the test.
    ///
    /// When `wait_for_ready` is set, each request first waits on `room_ready`.
    /// This lets a test observe the paginator's intermediate states before
    /// releasing the request with [`TestRoom::mark_ready`].
    #[derive(Clone)]
    struct TestRoom {
        event_factory: Arc<EventFactory>,
        wait_for_ready: bool,

        // Canned response data, shared between clones of the room.
        target_event_text: Arc<Mutex<String>>,
        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
        prev_batch_token: Arc<Mutex<Option<String>>>,
        next_batch_token: Arc<Mutex<Option<String>>>,

        room_ready: Arc<Notify>,
    }

    impl TestRoom {
        /// Creates an empty test room whose events are sent by `sender` in
        /// `room_id`.
        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));

            Self {
                event_factory,
                wait_for_ready,

                room_ready: Default::default(),
                target_event_text: Default::default(),
                next_events: Default::default(),
                prev_events: Default::default(),
                prev_batch_token: Default::default(),
                next_batch_token: Default::default(),
            }
        }

        /// Releases one request waiting on `room_ready`.
        ///
        /// If nobody is waiting yet, `Notify::notify_one` stores a permit for
        /// the next waiter.
        fn mark_ready(&self) {
            self.room_ready.notify_one();
        }
    }

    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));

    impl PaginableRoom for TestRoom {
        async fn event_with_context(
            &self,
            event_id: &EventId,
            _lazy_load_members: bool,
            num_events: UInt,
        ) -> Result<EventWithContextResponse, PaginatorError> {
            // Block until the test marks the room ready, if requested.
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let event = self
                .event_factory
                .text_msg(self.target_event_text.lock().await.clone())
                .event_id(event_id)
                .into_event();

            // Simulate the `num_events` budget: it's first spent on "before" events,
            // and whatever remains is spent on "after" events.
            let mut num_events = u64::from(num_events) as usize;

            let prev_events = self.prev_events.lock().await;

            let events_before = if prev_events.is_empty() {
                Vec::new()
            } else {
                let len = prev_events.len();
                let take_before = num_events.min(len);
                // Cannot underflow: `take_before <= num_events`.
                num_events -= take_before;
                // The last `take_before` entries of `prev_events`.
                prev_events[len - take_before..len].to_vec()
            };

            let events_after = self.next_events.lock().await;
            let events_after = if events_after.is_empty() {
                Vec::new()
            } else {
                // The first (remaining budget) entries of `next_events`.
                events_after[0..num_events.min(events_after.len())].to_vec()
            };

            Ok(EventWithContextResponse {
                event: Some(event),
                events_before,
                events_after,
                prev_batch_token: self.prev_batch_token.lock().await.clone(),
                next_batch_token: self.next_batch_token.lock().await.clone(),
                state: Vec::new(),
            })
        }

        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
            // Block until the test marks the room ready, if requested.
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let limit = u64::from(opts.limit) as usize;

            // Pick the events and the token matching the requested direction.
            let (end, events) = match opts.dir {
                Direction::Backward => {
                    let events = self.prev_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        let len = events.len();
                        let take_before = limit.min(len);
                        // The last `take_before` entries of `prev_events`.
                        events[len - take_before..len].to_vec()
                    };
                    (self.prev_batch_token.lock().await.clone(), events)
                }

                Direction::Forward => {
                    let events = self.next_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        // The first `limit` entries of `next_events`.
                        events[0..limit.min(events.len())].to_vec()
                    };
                    (self.next_batch_token.lock().await.clone(), events)
                }
            };

            // Panics if `from` is unset; every test here paginates with a token.
            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
        }
    }

    /// Awaits `task` and asserts it failed with
    /// `PaginatorError::InvalidPreviousState { expected, actual }`.
    async fn assert_invalid_state<T: std::fmt::Debug>(
        task: impl Future<Output = Result<T, PaginatorError>>,
        expected: PaginatorState,
        actual: PaginatorState,
    ) {
        assert_let!(
            Err(PaginatorError::InvalidPreviousState {
                expected: real_expected,
                actual: real_actual
            }) = task.await
        );
        assert_eq!(real_expected, expected);
        assert_eq!(real_actual, actual);
    }

    #[async_test]
    async fn test_start_from() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await = (0..10)
            .rev()
            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
            .collect();
        *room.next_events.lock().await =
            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();

        // Calling `start_from` works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // No tokens were configured, so neither direction can continue.
        assert!(!context.has_prev);
        assert!(!context.has_next);

        // 10 events before, the target event, then 10 events after.
        assert_eq!(context.events.len(), 21);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
        }

        assert_event_matches_msg(&context.events[10], "fetch_from");
        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
        }
    }

    #[async_test]
    async fn test_start_from_with_num_events() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await =
            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();

        // Calling `start_from` with a budget of 10 events works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");

        // 10 "before" events (the whole budget) plus the target event.
        assert_eq!(context.events.len(), 11);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
        }
        assert_event_matches_msg(&context.events[10], "fetch_from");
    }

    #[async_test]
    async fn test_paginate_backward() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        // Before `start_from`, neither end of the timeline is known to be reached.
        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
        assert!(
            !paginator.hit_timeline_end(),
            "we don't know about the status of the next-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // A previous token exists, but no next token.
        assert!(context.has_prev);
        assert!(!context.has_next);

        assert!(!paginator.hit_timeline_start());
        assert!(paginator.hit_timeline_end());

        // First backward pagination: one event, and a new token.
        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
        *room.prev_batch_token.lock().await = Some("prev2".to_owned());

        let prev =
            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
        assert!(!prev.hit_end_of_timeline);
        assert!(!paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "previous");

        // Second backward pagination: one event, and no more token.
        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
        *room.prev_batch_token.lock().await = None;

        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the second time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "oldest");

        // Third backward pagination: the start was already hit, so the paginator
        // answers with no events.
        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the third time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert!(prev.events.is_empty());
    }

    #[async_test]
    async fn test_paginate_backward_with_limit() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        assert!(context.has_prev);
        assert!(!context.has_next);

        // 100 earlier events are available, but only 10 are requested.
        *room.prev_events.lock().await = (0..100)
            .rev()
            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
            .collect();
        *room.prev_batch_token.lock().await = None;

        let prev =
            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
        assert!(prev.hit_end_of_timeline);
        assert_eq!(prev.events.len(), 10);
        for i in 0..10 {
            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
        }
    }

    #[async_test]
    async fn test_paginate_forward() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        // Before `start_from`, neither end of the timeline is known to be reached.
        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
        assert!(
            !paginator.hit_timeline_start(),
            "we don't know about the status of the prev-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // A next token exists, but no previous token.
        assert!(!context.has_prev);
        assert!(context.has_next);

        assert!(paginator.hit_timeline_start());
        assert!(!paginator.hit_timeline_end());

        // First forward pagination: one event, and a new token.
        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
        *room.next_batch_token.lock().await = Some("next2".to_owned());

        let next =
            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
        assert!(!next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "next");
        assert!(!paginator.hit_timeline_end());

        // Second forward pagination: one event, and no more token.
        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
        *room.next_batch_token.lock().await = None;

        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the second time should work");
        assert!(next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "latest");
        assert!(paginator.hit_timeline_end());

        // Third forward pagination: the end was already hit, so the paginator
        // answers with no events.
        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the third time should work");
        assert!(next.hit_end_of_timeline);
        assert!(next.events.is_empty());
        assert!(paginator.hit_timeline_end());
    }

    #[async_test]
    async fn test_state() {
        // Requests block until `mark_ready` is called, so intermediate states
        // can be observed.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Paginating before `start_from` is rejected.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        // A rejected call must not change the state.
        assert!(state.next().now_or_never().is_none());

        // Run `start_from` in the background; it blocks on `room_ready`.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // While fetching the target event, both operations are rejected.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Release the `/context` request.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Run a backward pagination in the background; it blocks on `room_ready`.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // While paginating, every operation is rejected.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Release the `/messages` request.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }

    /// Tests that aborting the task that runs a request resets the paginator
    /// state through `ResetStateGuard`.
    mod aborts {
        use super::*;
        use crate::event_cache::{paginator::PaginationTokens, PaginationToken};

        /// A room whose requests abort their own task, then never complete.
        #[derive(Clone, Default)]
        struct AbortingRoom {
            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
            room_ready: Arc<Notify>,
        }

        impl AbortingRoom {
            /// Waits until the test signals readiness, aborts the stored task
            /// handle, then yields forever so the request never resolves.
            async fn wait_abort_and_yield(&self) -> ! {
                // Wait for the controlling test to say we're ready.
                self.room_ready.notified().await;

                // Abort the task owning this request.
                let mut guard = self.abort_handle.lock().await;
                let handle = guard.take().expect("only call me when i'm initialized");
                handle.abort();

                // Yield forever; the abort takes effect at one of these yield points.
                loop {
                    tokio::task::yield_now().await;
                }
            }
        }

        impl PaginableRoom for AbortingRoom {
            async fn event_with_context(
                &self,
                _event_id: &EventId,
                _lazy_load_members: bool,
                _num_events: UInt,
            ) -> Result<EventWithContextResponse, PaginatorError> {
                self.wait_abort_and_yield().await
            }

            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
                self.wait_abort_and_yield().await
            }
        }

        #[async_test]
        async fn test_abort_while_starting_from() {
            let room = AbortingRoom::default();

            let paginator = Arc::new(Paginator::new(room.clone()));

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Initial);
            assert!(state.next().now_or_never().is_none());

            // Start a `start_from` task, and give the room its abort handle.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
            assert!(state.next().now_or_never().is_none());

            // Let the room abort the task.
            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // Dropping the aborted future reset the state to `Initial`.
            assert_eq!(state.next().await, Some(PaginatorState::Initial));
            assert!(state.next().now_or_never().is_none());
        }

        #[async_test]
        async fn test_abort_while_paginating() {
            let room = AbortingRoom::default();

            // Start from an `Idle` paginator with tokens in both directions,
            // bypassing `start_from`.
            let paginator = Paginator::new(room.clone());
            paginator.state.set(PaginatorState::Idle);
            *paginator.tokens.lock().unwrap() = PaginationTokens {
                previous: PaginationToken::HasMore("prev".to_owned()),
                next: PaginationToken::HasMore("next".to_owned()),
            };

            let paginator = Arc::new(paginator);

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Idle);
            assert!(state.next().now_or_never().is_none());

            // Abort a backward pagination.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_backward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // Dropping the aborted future reset the state to `Idle`.
            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());

            // Same for a forward pagination.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_forward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());
        }
    }
}