1use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{SendOutsideWasm, SyncOutsideWasm, deserialized_responses::TimelineEvent};
24use ruma::{EventId, UInt, api::Direction};
25
26use crate::{
27 Room,
28 paginators::{PaginationResult, PaginationToken, PaginatorError},
29 room::{EventWithContextResponse, Messages, MessagesOptions},
30};
31
/// The lifecycle state of a [`Paginator`].
///
/// `Initial` → `FetchingTargetEvent` → `Idle` ⇄ `Paginating`.
/// A failed or cancelled `start_from` returns to `Initial`; a failed or
/// cancelled pagination returns to `Idle`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// No target event has been loaded yet; `start_from` may be called.
    Initial,

    /// `start_from` is in flight, fetching the target event and its context.
    FetchingTargetEvent,

    /// The target event was loaded and no request is running; paginating
    /// backward or forward is allowed.
    Idle,

    /// A backward or forward pagination request is in flight.
    Paginating,
}
49
50#[derive(Debug, Clone)]
52pub struct PaginationTokens {
53 pub previous: PaginationToken,
55 pub next: PaginationToken,
57}
58
59pub struct Paginator<PR: PaginableRoom> {
64 room: PR,
66
67 state: SharedObservable<PaginatorState>,
69
70 tokens: Mutex<PaginationTokens>,
74}
75
76#[cfg(not(tarpaulin_include))]
77impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Paginator")
81 .field("state", &self.state.get())
82 .field("tokens", &self.tokens)
83 .finish_non_exhaustive()
84 }
85}
86
87#[derive(Debug)]
89pub struct StartFromResult {
90 pub events: Vec<TimelineEvent>,
92
93 pub has_prev: bool,
95
96 pub has_next: bool,
98}
99
100struct ResetStateGuard {
102 target: Option<PaginatorState>,
103 state: SharedObservable<PaginatorState>,
104}
105
106impl ResetStateGuard {
107 fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
109 Self { target: Some(target), state }
110 }
111
112 fn disarm(mut self) {
114 self.target = None;
115 }
116}
117
118impl Drop for ResetStateGuard {
119 fn drop(&mut self) {
120 if let Some(target) = self.target.take() {
121 self.state.set_if_not_eq(target);
122 }
123 }
124}
125
126impl<PR: PaginableRoom> Paginator<PR> {
127 pub fn new(room: PR) -> Self {
129 Self {
130 room,
131 state: SharedObservable::new(PaginatorState::Initial),
132 tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
133 }
134 }
135
136 fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
138 let actual = self.state.get();
139 if actual != expected {
140 Err(PaginatorError::InvalidPreviousState { expected, actual })
141 } else {
142 Ok(())
143 }
144 }
145
146 pub fn state(&self) -> Subscriber<PaginatorState> {
148 self.state.subscribe()
149 }
150
151 pub async fn start_from(
157 &self,
158 event_id: &EventId,
159 num_events: UInt,
160 ) -> Result<StartFromResult, PaginatorError> {
161 self.check_state(PaginatorState::Initial)?;
162
163 if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
167 return Err(PaginatorError::InvalidPreviousState {
168 expected: PaginatorState::Initial,
169 actual: PaginatorState::FetchingTargetEvent,
170 });
171 }
172
173 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
174
175 let lazy_load_members = true;
177
178 let response =
179 self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
180
181 let has_prev = response.prev_batch_token.is_some();
186 let has_next = response.next_batch_token.is_some();
187
188 {
189 let mut tokens = self.tokens.lock().unwrap();
190 tokens.previous = match response.prev_batch_token {
191 Some(token) => PaginationToken::HasMore(token),
192 None => PaginationToken::HitEnd,
193 };
194 tokens.next = match response.next_batch_token {
195 Some(token) => PaginationToken::HasMore(token),
196 None => PaginationToken::HitEnd,
197 };
198 }
199
200 reset_state_guard.disarm();
202 self.state.set(PaginatorState::Idle);
204
205 let events = response
212 .events_before
213 .into_iter()
214 .rev()
215 .chain(response.event)
216 .chain(response.events_after)
217 .collect();
218
219 Ok(StartFromResult { events, has_prev, has_next })
220 }
221
222 pub async fn paginate_backward(
231 &self,
232 num_events: UInt,
233 ) -> Result<PaginationResult, PaginatorError> {
234 self.paginate(Direction::Backward, num_events).await
235 }
236
237 pub fn hit_timeline_start(&self) -> bool {
242 matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
243 }
244
245 pub fn hit_timeline_end(&self) -> bool {
250 matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
251 }
252
253 pub async fn paginate_forward(
261 &self,
262 num_events: UInt,
263 ) -> Result<PaginationResult, PaginatorError> {
264 self.paginate(Direction::Forward, num_events).await
265 }
266
267 async fn paginate(
271 &self,
272 dir: Direction,
273 num_events: UInt,
274 ) -> Result<PaginationResult, PaginatorError> {
275 self.check_state(PaginatorState::Idle)?;
276
277 let token = {
278 let tokens = self.tokens.lock().unwrap();
279
280 let token = match dir {
281 Direction::Backward => &tokens.previous,
282 Direction::Forward => &tokens.next,
283 };
284
285 match token {
286 PaginationToken::None => None,
287 PaginationToken::HasMore(val) => Some(val.clone()),
288 PaginationToken::HitEnd => {
289 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
290 }
291 }
292 };
293
294 if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
298 return Err(PaginatorError::InvalidPreviousState {
299 expected: PaginatorState::Idle,
300 actual: PaginatorState::Paginating,
301 });
302 }
303
304 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
305
306 let mut options = MessagesOptions::new(dir).from(token.as_deref());
307 options.limit = num_events;
308
309 let response = self.room.messages(options).await?;
312
313 let hit_end_of_timeline = response.end.is_none();
318
319 {
320 let mut tokens = self.tokens.lock().unwrap();
321
322 let token = match dir {
323 Direction::Backward => &mut tokens.previous,
324 Direction::Forward => &mut tokens.next,
325 };
326
327 *token = match response.end {
328 Some(val) => PaginationToken::HasMore(val),
329 None => PaginationToken::HitEnd,
330 };
331 }
332
333 reset_state_guard.disarm();
337 self.state.set(PaginatorState::Idle);
339
340 Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
341 }
342
343 pub fn tokens(&self) -> PaginationTokens {
345 self.tokens.lock().unwrap().clone()
346 }
347}
348
349pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
354 fn event_with_context(
369 &self,
370 event_id: &EventId,
371 lazy_load_members: bool,
372 num_events: UInt,
373 ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
374
375 fn messages(
377 &self,
378 opts: MessagesOptions,
379 ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
380}
381
382impl PaginableRoom for Room {
383 async fn event_with_context(
384 &self,
385 event_id: &EventId,
386 lazy_load_members: bool,
387 num_events: UInt,
388 ) -> Result<EventWithContextResponse, PaginatorError> {
389 let response =
390 match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
391 Ok(result) => result,
392
393 Err(err) => {
394 if let Some(error) = err.as_client_api_error()
398 && error.status_code == 404
399 {
400 return Err(PaginatorError::EventNotFound(event_id.to_owned()));
402 }
403
404 return Err(PaginatorError::SdkError(Box::new(err)));
406 }
407 };
408
409 Ok(response)
410 }
411
412 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
413 self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
414 }
415}
416
#[cfg(all(not(target_family = "wasm"), test))]
mod tests {
    //! Tests for [`Paginator`], driven by in-memory [`PaginableRoom`] doubles.

    use std::sync::{Arc, LazyLock};

    use assert_matches2::assert_let;
    use futures_core::Future;
    use futures_util::FutureExt as _;
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{EventId, RoomId, UInt, UserId, api::Direction, event_id, room_id, uint, user_id};
    use tokio::{
        spawn,
        sync::{Mutex, Notify},
        task::AbortHandle,
    };

    use super::{PaginableRoom, PaginatorError, PaginatorState};
    use crate::{
        paginators::Paginator,
        room::{EventWithContextResponse, Messages, MessagesOptions},
        test_utils::assert_event_matches_msg,
    };

    /// A fake room whose responses are configured by each test through the
    /// shared, lockable fields below.
    #[derive(Clone)]
    struct TestRoom {
        event_factory: Arc<EventFactory>,
        // When true, every request waits for `mark_ready` before answering.
        wait_for_ready: bool,

        // Message body of the target event returned by `event_with_context`.
        target_event_text: Arc<Mutex<String>>,
        // Served after the target / on forward pagination: the *first* entries are used.
        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
        // Served before the target / on backward pagination: the *last* entries are used.
        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
        // Tokens returned as-is by both `event_with_context` and `messages`.
        prev_batch_token: Arc<Mutex<Option<String>>>,
        next_batch_token: Arc<Mutex<Option<String>>>,

        // Released by `mark_ready` when `wait_for_ready` is set.
        room_ready: Arc<Notify>,
    }

    impl TestRoom {
        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));

            Self {
                event_factory,
                wait_for_ready,

                room_ready: Default::default(),
                target_event_text: Default::default(),
                next_events: Default::default(),
                prev_events: Default::default(),
                prev_batch_token: Default::default(),
                next_batch_token: Default::default(),
            }
        }

        /// Lets one pending (or the next) request proceed.
        fn mark_ready(&self) {
            // `notify_one` stores a permit if nobody is waiting yet.
            self.room_ready.notify_one();
        }
    }

    static ROOM_ID: LazyLock<&RoomId> = LazyLock::new(|| room_id!("!dune:herbert.org"));
    static USER_ID: LazyLock<&UserId> = LazyLock::new(|| user_id!("@paul:atreid.es"));

    impl PaginableRoom for TestRoom {
        async fn event_with_context(
            &self,
            event_id: &EventId,
            _lazy_load_members: bool,
            num_events: UInt,
        ) -> Result<EventWithContextResponse, PaginatorError> {
            if self.wait_for_ready {
                // Block until the test allows the request to complete.
                self.room_ready.notified().await;
            }

            let event = self
                .event_factory
                .text_msg(self.target_event_text.lock().await.clone())
                .event_id(event_id)
                .into_event();

            // `num_events` is a shared budget: events before the target are taken
            // first, and whatever remains goes to the events after it.
            let mut num_events = u64::from(num_events) as usize;

            let prev_events = self.prev_events.lock().await;

            let events_before = if prev_events.is_empty() {
                Vec::new()
            } else {
                let len = prev_events.len();
                let take_before = num_events.min(len);
                num_events -= take_before;
                // Serve the last `take_before` entries, in stored order.
                prev_events[len - take_before..len].to_vec()
            };

            let events_after = self.next_events.lock().await;
            let events_after = if events_after.is_empty() {
                Vec::new()
            } else {
                events_after[0..num_events.min(events_after.len())].to_vec()
            };

            Ok(EventWithContextResponse {
                event: Some(event),
                events_before,
                events_after,
                prev_batch_token: self.prev_batch_token.lock().await.clone(),
                next_batch_token: self.next_batch_token.lock().await.clone(),
                state: Vec::new(),
            })
        }

        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let limit = u64::from(opts.limit) as usize;

            let (end, events) = match opts.dir {
                Direction::Backward => {
                    let events = self.prev_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        let len = events.len();
                        let take_before = limit.min(len);
                        // Last `limit` entries, in stored order.
                        events[len - take_before..len].to_vec()
                    };
                    (self.prev_batch_token.lock().await.clone(), events)
                }

                Direction::Forward => {
                    let events = self.next_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        // First `limit` entries.
                        events[0..limit.min(events.len())].to_vec()
                    };
                    (self.next_batch_token.lock().await.clone(), events)
                }
            };

            // NOTE: panics if the paginator sends no `from` token; every test here
            // paginates with a known token.
            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
        }
    }

    /// Awaits `task` and asserts it failed with `InvalidPreviousState` carrying
    /// exactly `expected` and `actual`.
    async fn assert_invalid_state<T: std::fmt::Debug>(
        task: impl Future<Output = Result<T, PaginatorError>>,
        expected: PaginatorState,
        actual: PaginatorState,
    ) {
        assert_let!(
            Err(PaginatorError::InvalidPreviousState {
                expected: real_expected,
                actual: real_actual
            }) = task.await
        );
        assert_eq!(real_expected, expected);
        assert_eq!(real_actual, actual);
    }

    #[async_test]
    async fn test_start_from() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        // Stored as before-9..before-0; the paginator reverses `events_before`, so
        // the output below reads before-0..before-9.
        *room.prev_events.lock().await = (0..10)
            .rev()
            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
            .collect();
        *room.next_events.lock().await =
            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();

        // When I call `Paginator::start_from`, it works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // No batch tokens were configured, so both directions are exhausted.
        assert!(!context.has_prev);
        assert!(!context.has_next);

        // 10 events before + the target event + 10 events after.
        assert_eq!(context.events.len(), 21);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
        }

        assert_event_matches_msg(&context.events[10], "fetch_from");
        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
        }
    }

    #[async_test]
    async fn test_start_from_with_num_events() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await =
            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();

        // When I call `Paginator::start_from`, it works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");

        // The budget of 10 is entirely consumed by events before the target, so we
        // get 10 of them plus the target itself.
        assert_eq!(context.events.len(), 11);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
        }
        assert_event_matches_msg(&context.events[10], "fetch_from");
    }

    #[async_test]
    async fn test_paginate_backward() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        // When I call `Paginator::start_from`, it works,
        let paginator = Arc::new(Paginator::new(room.clone()));

        // Before `start_from`, neither direction is known to be exhausted.
        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
        assert!(
            !paginator.hit_timeline_end(),
            "we don't know about the status of the next-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // There's a previous-batch token, but no next-batch token.
        assert!(context.has_prev);
        assert!(!context.has_next);

        assert!(!paginator.hit_timeline_start());
        assert!(paginator.hit_timeline_end());

        // Preparing data for the next back-pagination.
        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
        *room.prev_batch_token.lock().await = Some("prev2".to_owned());

        // When I backpaginate, I get the events I expect.
        let prev =
            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
        assert!(!prev.hit_end_of_timeline);
        assert!(!paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "previous");

        // And I can backpaginate again, this time hitting the start of the timeline
        // because no previous-batch token is returned.
        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
        *room.prev_batch_token.lock().await = None;

        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the second time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "oldest");

        // Back-paginating once more answers without a request: nothing is returned
        // even though the room still holds an event.
        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the third time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert!(prev.events.is_empty());
    }

    #[async_test]
    async fn test_paginate_backward_with_limit() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        // When I call `Paginator::start_from`, it works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // There's a previous-batch token, but no next-batch token.
        assert!(context.has_prev);
        assert!(!context.has_next);

        // Stored as prev99..prev0; a limit of 10 serves the last ten (prev9..prev0),
        // which the paginator returns unchanged.
        *room.prev_events.lock().await = (0..100)
            .rev()
            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
            .collect();
        *room.prev_batch_token.lock().await = None;

        // When I backpaginate with a limit, only `limit` events are returned.
        let prev =
            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
        assert!(prev.hit_end_of_timeline);
        assert_eq!(prev.events.len(), 10);
        for i in 0..10 {
            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
        }
    }

    #[async_test]
    async fn test_paginate_forward() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.next_batch_token.lock().await = Some("next".to_owned());

        // When I call `Paginator::start_from`, it works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
        assert!(
            !paginator.hit_timeline_start(),
            "we don't know about the status of the prev-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // There's a next-batch token, but no previous-batch token.
        assert!(!context.has_prev);
        assert!(context.has_next);

        assert!(paginator.hit_timeline_start());
        assert!(!paginator.hit_timeline_end());

        // Preparing data for the next forward pagination.
        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
        *room.next_batch_token.lock().await = Some("next2".to_owned());

        // When I paginate forward, I get the events I expect.
        let next =
            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
        assert!(!next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "next");
        assert!(!paginator.hit_timeline_end());

        // And I can paginate forward again, this time hitting the end of the
        // timeline because no next-batch token is returned.
        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
        *room.next_batch_token.lock().await = None;

        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the second time should work");
        assert!(next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "latest");
        assert!(paginator.hit_timeline_end());

        // Paginating forward once more answers without a request and returns no
        // events.
        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the third time should work");
        assert!(next.hit_end_of_timeline);
        assert!(next.events.is_empty());
        assert!(paginator.hit_timeline_end());
    }

    #[async_test]
    async fn test_state() {
        // Requests block until `mark_ready`, so the test can observe in-flight states.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Attempting to run pagination must fail and not change the state.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Running the initial query must work.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // The query is pending. Running other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Mark the dummy room as ready. The query may now terminate.
        room.mark_ready();

        // After fetching the initial event data, the paginator switches to `Idle`.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Start a pagination; it blocks until the room is marked ready again.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // While paginating, the other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        room.mark_ready();

        // Once the pagination completes, the paginator goes back to `Idle`.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }

    /// Checks that the state is reset when a request's task is aborted mid-flight.
    mod aborts {
        use super::*;
        use crate::paginators::room::{PaginationToken, PaginationTokens};

        /// A room whose requests abort their own task once released.
        #[derive(Clone, Default)]
        struct AbortingRoom {
            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
            room_ready: Arc<Notify>,
        }

        impl AbortingRoom {
            /// Waits for the test's go-ahead, aborts the task running this request,
            /// then yields forever so the abort takes effect at an await point.
            async fn wait_abort_and_yield(&self) -> ! {
                self.room_ready.notified().await;

                // Take the handle, so a subsequent request needs a fresh one.
                let mut guard = self.abort_handle.lock().await;
                let handle = guard.take().expect("only call me when i'm initialized");
                handle.abort();

                loop {
                    // Never resolves; the task is cancelled at one of these yields.
                    tokio::task::yield_now().await;
                }
            }
        }

        impl PaginableRoom for AbortingRoom {
            async fn event_with_context(
                &self,
                _event_id: &EventId,
                _lazy_load_members: bool,
                _num_events: UInt,
            ) -> Result<EventWithContextResponse, PaginatorError> {
                self.wait_abort_and_yield().await
            }

            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
                self.wait_abort_and_yield().await
            }
        }

        #[async_test]
        async fn test_abort_while_starting_from() {
            let room = AbortingRoom::default();

            let paginator = Arc::new(Paginator::new(room.clone()));

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Initial);
            assert!(state.next().now_or_never().is_none());

            // When starting the initial back-pagination, the state is updated.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            // But if the task is aborted,
            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // Then the state is reset to the initial value.
            assert_eq!(state.next().await, Some(PaginatorState::Initial));
            assert!(state.next().now_or_never().is_none());
        }

        #[async_test]
        async fn test_abort_while_paginating() {
            let room = AbortingRoom::default();

            // Assuming a paginator ready to back- or forward- paginate,
            let paginator = Paginator::new(room.clone());
            paginator.state.set(PaginatorState::Idle);
            *paginator.tokens.lock().unwrap() = PaginationTokens {
                previous: PaginationToken::HasMore("prev".to_owned()),
                next: PaginationToken::HasMore("next".to_owned()),
            };

            let paginator = Arc::new(paginator);

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Idle);
            assert!(state.next().now_or_never().is_none());

            // When starting a back-pagination, the state is updated.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_backward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            // But if the task is aborted,
            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // Then the state is reset to the previous value.
            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());

            // The same applies to forward pagination.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_forward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());
        }
    }
}