1use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{SendOutsideWasm, SyncOutsideWasm, deserialized_responses::TimelineEvent};
24use ruma::{EventId, UInt, api::Direction};
25
26use crate::{
27 Room,
28 paginators::{PaginationResult, PaginationToken, PaginatorError},
29 room::{EventWithContextResponse, Messages, MessagesOptions},
30};
31
/// The lifecycle state of a [`Paginator`], observable through
/// [`Paginator::state`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// Nothing has been fetched yet; [`Paginator::start_from`] may be called.
    Initial,

    /// [`Paginator::start_from`] is fetching the target event and its context.
    FetchingTargetEvent,

    /// The target event has been fetched, and no request is in flight;
    /// backward or forward pagination may be started.
    Idle,

    /// A backward or forward pagination request is in flight.
    Paginating,
}
49
50#[derive(Debug, Clone)]
52pub struct PaginationTokens {
53 pub previous: PaginationToken,
55 pub next: PaginationToken,
57}
58
59pub struct Paginator<PR: PaginableRoom> {
64 room: PR,
66
67 state: SharedObservable<PaginatorState>,
69
70 tokens: Mutex<PaginationTokens>,
74}
75
76#[cfg(not(tarpaulin_include))]
77impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Paginator")
81 .field("state", &self.state.get())
82 .field("tokens", &self.tokens)
83 .finish_non_exhaustive()
84 }
85}
86
87#[derive(Debug)]
89pub struct StartFromResult {
90 pub events: Vec<TimelineEvent>,
92
93 pub has_prev: bool,
95
96 pub has_next: bool,
98}
99
100struct ResetStateGuard {
102 target: Option<PaginatorState>,
103 state: SharedObservable<PaginatorState>,
104}
105
106impl ResetStateGuard {
107 fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
109 Self { target: Some(target), state }
110 }
111
112 fn disarm(mut self) {
114 self.target = None;
115 }
116}
117
118impl Drop for ResetStateGuard {
119 fn drop(&mut self) {
120 if let Some(target) = self.target.take() {
121 self.state.set_if_not_eq(target);
122 }
123 }
124}
125
126impl<PR: PaginableRoom> Paginator<PR> {
127 pub fn new(room: PR) -> Self {
129 Self {
130 room,
131 state: SharedObservable::new(PaginatorState::Initial),
132 tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
133 }
134 }
135
136 fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
138 let actual = self.state.get();
139 if actual != expected {
140 Err(PaginatorError::InvalidPreviousState { expected, actual })
141 } else {
142 Ok(())
143 }
144 }
145
146 pub fn state(&self) -> Subscriber<PaginatorState> {
148 self.state.subscribe()
149 }
150
151 pub async fn start_from(
157 &self,
158 event_id: &EventId,
159 num_events: UInt,
160 ) -> Result<StartFromResult, PaginatorError> {
161 self.check_state(PaginatorState::Initial)?;
162
163 if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
167 return Err(PaginatorError::InvalidPreviousState {
168 expected: PaginatorState::Initial,
169 actual: PaginatorState::FetchingTargetEvent,
170 });
171 }
172
173 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
174
175 let lazy_load_members = true;
177
178 let response =
179 self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
180
181 let has_prev = response.prev_batch_token.is_some();
186 let has_next = response.next_batch_token.is_some();
187
188 {
189 let mut tokens = self.tokens.lock().unwrap();
190 tokens.previous = match response.prev_batch_token {
191 Some(token) => PaginationToken::HasMore(token),
192 None => PaginationToken::HitEnd,
193 };
194 tokens.next = match response.next_batch_token {
195 Some(token) => PaginationToken::HasMore(token),
196 None => PaginationToken::HitEnd,
197 };
198 }
199
200 reset_state_guard.disarm();
202 self.state.set(PaginatorState::Idle);
204
205 let events = response
212 .events_before
213 .into_iter()
214 .rev()
215 .chain(response.event)
216 .chain(response.events_after)
217 .collect();
218
219 Ok(StartFromResult { events, has_prev, has_next })
220 }
221
222 pub async fn paginate_backward(
231 &self,
232 num_events: UInt,
233 ) -> Result<PaginationResult, PaginatorError> {
234 self.paginate(Direction::Backward, num_events).await
235 }
236
237 pub fn hit_timeline_start(&self) -> bool {
242 matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
243 }
244
245 pub fn hit_timeline_end(&self) -> bool {
250 matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
251 }
252
253 pub async fn paginate_forward(
261 &self,
262 num_events: UInt,
263 ) -> Result<PaginationResult, PaginatorError> {
264 self.paginate(Direction::Forward, num_events).await
265 }
266
267 async fn paginate(
271 &self,
272 dir: Direction,
273 num_events: UInt,
274 ) -> Result<PaginationResult, PaginatorError> {
275 self.check_state(PaginatorState::Idle)?;
276
277 let token = {
278 let tokens = self.tokens.lock().unwrap();
279
280 let token = match dir {
281 Direction::Backward => &tokens.previous,
282 Direction::Forward => &tokens.next,
283 };
284
285 match token {
286 PaginationToken::None => None,
287 PaginationToken::HasMore(val) => Some(val.clone()),
288 PaginationToken::HitEnd => {
289 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
290 }
291 }
292 };
293
294 if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
298 return Err(PaginatorError::InvalidPreviousState {
299 expected: PaginatorState::Idle,
300 actual: PaginatorState::Paginating,
301 });
302 }
303
304 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
305
306 let mut options = MessagesOptions::new(dir).from(token.as_deref());
307 options.limit = num_events;
308
309 let response = self.room.messages(options).await?;
312
313 let hit_end_of_timeline = response.end.is_none();
318
319 {
320 let mut tokens = self.tokens.lock().unwrap();
321
322 let token = match dir {
323 Direction::Backward => &mut tokens.previous,
324 Direction::Forward => &mut tokens.next,
325 };
326
327 *token = match response.end {
328 Some(val) => PaginationToken::HasMore(val),
329 None => PaginationToken::HitEnd,
330 };
331 }
332
333 reset_state_guard.disarm();
337 self.state.set(PaginatorState::Idle);
339
340 Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
341 }
342
343 pub fn tokens(&self) -> PaginationTokens {
345 self.tokens.lock().unwrap().clone()
346 }
347}
348
349pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
354 fn event_with_context(
369 &self,
370 event_id: &EventId,
371 lazy_load_members: bool,
372 num_events: UInt,
373 ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
374
375 fn messages(
377 &self,
378 opts: MessagesOptions,
379 ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
380}
381
382impl PaginableRoom for Room {
383 async fn event_with_context(
384 &self,
385 event_id: &EventId,
386 lazy_load_members: bool,
387 num_events: UInt,
388 ) -> Result<EventWithContextResponse, PaginatorError> {
389 let response =
390 match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
391 Ok(result) => result,
392
393 Err(err) => {
394 if let Some(error) = err.as_client_api_error()
398 && error.status_code == 404
399 {
400 return Err(PaginatorError::EventNotFound(event_id.to_owned()));
402 }
403
404 return Err(PaginatorError::SdkError(Box::new(err)));
406 }
407 };
408
409 Ok(response)
410 }
411
412 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
413 self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
414 }
415}
416
#[cfg(all(not(target_family = "wasm"), test))]
mod tests {
    use std::sync::Arc;

    use assert_matches2::assert_let;
    use futures_core::Future;
    use futures_util::FutureExt as _;
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use once_cell::sync::Lazy;
    use ruma::{EventId, RoomId, UInt, UserId, api::Direction, event_id, room_id, uint, user_id};
    use tokio::{
        spawn,
        sync::{Mutex, Notify},
        task::AbortHandle,
    };

    use super::{PaginableRoom, PaginatorError, PaginatorState};
    use crate::{
        paginators::Paginator,
        room::{EventWithContextResponse, Messages, MessagesOptions},
        test_utils::assert_event_matches_msg,
    };

    /// In-memory [`PaginableRoom`] whose responses are built from the shared,
    /// test-controlled fields below.
    #[derive(Clone)]
    struct TestRoom {
        event_factory: Arc<EventFactory>,
        // When true, every request waits for `room_ready` to be notified
        // before answering, so tests can observe intermediate states.
        wait_for_ready: bool,

        // Body of the target event built by `event_with_context`.
        target_event_text: Arc<Mutex<String>>,
        // Events served for forward requests / `events_after`.
        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
        // Events served for backward requests / `events_before`.
        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
        // Tokens returned by the fake responses.
        prev_batch_token: Arc<Mutex<Option<String>>>,
        next_batch_token: Arc<Mutex<Option<String>>>,

        room_ready: Arc<Notify>,
    }

    impl TestRoom {
        /// Create an empty test room whose events are sent by `sender` in
        /// `room_id`.
        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));

            Self {
                event_factory,
                wait_for_ready,

                room_ready: Default::default(),
                target_event_text: Default::default(),
                next_events: Default::default(),
                prev_events: Default::default(),
                prev_batch_token: Default::default(),
                next_batch_token: Default::default(),
            }
        }

        /// Let one pending (or the next) request proceed when
        /// `wait_for_ready` is set.
        fn mark_ready(&self) {
            self.room_ready.notify_one();
        }
    }

    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));

    impl PaginableRoom for TestRoom {
        /// Build a `/context` response: the target event, then up to
        /// `num_events` events split between the tail of `prev_events` and
        /// the head of `next_events`.
        async fn event_with_context(
            &self,
            event_id: &EventId,
            _lazy_load_members: bool,
            num_events: UInt,
        ) -> Result<EventWithContextResponse, PaginatorError> {
            // Wait for the test to allow the response, if requested.
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let event = self
                .event_factory
                .text_msg(self.target_event_text.lock().await.clone())
                .event_id(event_id)
                .into_event();

            // Remaining event budget, consumed first by the events before.
            let mut num_events = u64::from(num_events) as usize;

            let prev_events = self.prev_events.lock().await;

            let events_before = if prev_events.is_empty() {
                Vec::new()
            } else {
                let len = prev_events.len();
                let take_before = num_events.min(len);
                // Whatever is left of the budget goes to the events after.
                num_events -= take_before;
                // Serve the last `take_before` events of `prev_events`.
                prev_events[len - take_before..len].to_vec()
            };

            let events_after = self.next_events.lock().await;
            let events_after = if events_after.is_empty() {
                Vec::new()
            } else {
                events_after[0..num_events.min(events_after.len())].to_vec()
            };

            Ok(EventWithContextResponse {
                event: Some(event),
                events_before,
                events_after,
                prev_batch_token: self.prev_batch_token.lock().await.clone(),
                next_batch_token: self.next_batch_token.lock().await.clone(),
                state: Vec::new(),
            })
        }

        /// Build a `/messages` response of at most `opts.limit` events.
        ///
        /// Backward requests are served from the tail of `prev_events`, with
        /// `prev_batch_token` as `end`. Forward requests are served from the
        /// head of `next_events`, with `next_batch_token` as `end`.
        /// Panics if `opts.from` is `None`.
        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let limit = u64::from(opts.limit) as usize;

            let (end, events) = match opts.dir {
                Direction::Backward => {
                    let events = self.prev_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        let len = events.len();
                        let take_before = limit.min(len);
                        // Serve the last `take_before` events.
                        events[len - take_before..len].to_vec()
                    };
                    (self.prev_batch_token.lock().await.clone(), events)
                }

                Direction::Forward => {
                    let events = self.next_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        events[0..limit.min(events.len())].to_vec()
                    };
                    (self.next_batch_token.lock().await.clone(), events)
                }
            };

            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
        }
    }

    /// Await `task` and assert it failed with
    /// `PaginatorError::InvalidPreviousState { expected, actual }`.
    async fn assert_invalid_state<T: std::fmt::Debug>(
        task: impl Future<Output = Result<T, PaginatorError>>,
        expected: PaginatorState,
        actual: PaginatorState,
    ) {
        assert_let!(
            Err(PaginatorError::InvalidPreviousState {
                expected: real_expected,
                actual: real_actual
            }) = task.await
        );
        assert_eq!(real_expected, expected);
        assert_eq!(real_actual, actual);
    }

    /// `start_from` returns the reversed events before, the target event and
    /// the events after.
    #[async_test]
    async fn test_start_from() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await = (0..10)
            .rev()
            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
            .collect();
        *room.next_events.lock().await =
            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();

        // Calling `start_from` succeeds.
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // No batch tokens were configured on the room.
        assert!(!context.has_prev);
        assert!(!context.has_next);

        // 10 events before, the target event, 10 events after.
        assert_eq!(context.events.len(), 21);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
        }

        assert_event_matches_msg(&context.events[10], "fetch_from");
        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
        }
    }

    /// `start_from` forwards `num_events`, limiting the context size.
    #[async_test]
    async fn test_start_from_with_num_events() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await =
            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();

        // Ask for 10 events of context.
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");

        // 10 events before plus the target event.
        assert_eq!(context.events.len(), 11);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
        }
        assert_event_matches_msg(&context.events[10], "fetch_from");
    }

    /// Backward pagination follows the previous-batch token until it is
    /// exhausted, then keeps returning no events.
    #[async_test]
    async fn test_paginate_backward() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        // Before `start_from`, no token is known in either direction.
        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
        assert!(
            !paginator.hit_timeline_end(),
            "we don't know about the status of the next-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // Only a previous-batch token was returned.
        assert!(context.has_prev);
        assert!(!context.has_next);

        assert!(!paginator.hit_timeline_start());
        assert!(paginator.hit_timeline_end());

        // First backward pagination: one event, and another token follows.
        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
        *room.prev_batch_token.lock().await = Some("prev2".to_owned());

        let prev =
            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
        assert!(!prev.hit_end_of_timeline);
        assert!(!paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "previous");

        // Second backward pagination: one event, and no further token.
        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
        *room.prev_batch_token.lock().await = None;

        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the second time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "oldest");

        // Third backward pagination: the start was hit, so no events come back.
        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the third time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert!(prev.events.is_empty());
    }

    /// Backward pagination forwards `num_events` as the request limit.
    #[async_test]
    async fn test_paginate_backward_with_limit() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // Only a previous-batch token was returned.
        assert!(context.has_prev);
        assert!(!context.has_next);

        // 100 events available, no further token.
        *room.prev_events.lock().await = (0..100)
            .rev()
            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
            .collect();
        *room.prev_batch_token.lock().await = None;

        // Only 10 of them are returned.
        let prev =
            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
        assert!(prev.hit_end_of_timeline);
        assert_eq!(prev.events.len(), 10);
        for i in 0..10 {
            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
        }
    }

    /// Forward pagination follows the next-batch token until it is
    /// exhausted, then keeps returning no events.
    #[async_test]
    async fn test_paginate_forward() {
        // Prepare test data.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        // Before `start_from`, no token is known in either direction.
        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
        assert!(
            !paginator.hit_timeline_start(),
            "we don't know about the status of the prev-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // Only a next-batch token was returned.
        assert!(!context.has_prev);
        assert!(context.has_next);

        assert!(paginator.hit_timeline_start());
        assert!(!paginator.hit_timeline_end());

        // First forward pagination: one event, and another token follows.
        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
        *room.next_batch_token.lock().await = Some("next2".to_owned());

        let next =
            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
        assert!(!next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "next");
        assert!(!paginator.hit_timeline_end());

        // Second forward pagination: one event, and no further token.
        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
        *room.next_batch_token.lock().await = None;

        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the second time should work");
        assert!(next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "latest");
        assert!(paginator.hit_timeline_end());

        // Third forward pagination: the end was hit, so no events come back.
        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the third time should work");
        assert!(next.hit_end_of_timeline);
        assert!(next.events.is_empty());
        assert!(paginator.hit_timeline_end());
    }

    /// The observable state goes through the expected transitions.
    ///
    /// Operations that are invalid in the current state fail without
    /// changing it.
    #[async_test]
    async fn test_state() {
        // The room waits for `mark_ready` before answering any request.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Paginating before `start_from` fails and leaves the state alone.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Start the initial query in the background; it blocks until ready.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // While fetching, neither `start_from` nor pagination may run.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Let the request complete.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Start a backward pagination in the background; it blocks until ready.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // While paginating, no other operation may run.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Let the request complete.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }

    /// Tests checking that the state is restored when an in-flight operation
    /// is aborted.
    mod aborts {
        use super::*;
        use crate::paginators::room::{PaginationToken, PaginationTokens};

        /// A room whose requests abort the task running them, then never
        /// complete.
        #[derive(Clone, Default)]
        struct AbortingRoom {
            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
            room_ready: Arc<Notify>,
        }

        impl AbortingRoom {
            /// Wait for `room_ready`, abort the stored task handle, then yield
            /// forever.
            ///
            /// The yields give the runtime a chance to act on the abort.
            /// Panics if no abort handle was stored.
            async fn wait_abort_and_yield(&self) -> ! {
                self.room_ready.notified().await;

                let mut guard = self.abort_handle.lock().await;
                let handle = guard.take().expect("only call me when i'm initialized");
                handle.abort();

                loop {
                    tokio::task::yield_now().await;
                }
            }
        }

        impl PaginableRoom for AbortingRoom {
            async fn event_with_context(
                &self,
                _event_id: &EventId,
                _lazy_load_members: bool,
                _num_events: UInt,
            ) -> Result<EventWithContextResponse, PaginatorError> {
                self.wait_abort_and_yield().await
            }

            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
                self.wait_abort_and_yield().await
            }
        }

        /// Aborting `start_from` brings the state back to `Initial`.
        #[async_test]
        async fn test_abort_while_starting_from() {
            let room = AbortingRoom::default();

            let paginator = Arc::new(Paginator::new(room.clone()));

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Initial);
            assert!(state.next().now_or_never().is_none());

            // Spawn the task, then give the room its abort handle.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
            assert!(state.next().now_or_never().is_none());

            // Let the room abort the task.
            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // The state was reset when the task's future was dropped.
            assert_eq!(state.next().await, Some(PaginatorState::Initial));
            assert!(state.next().now_or_never().is_none());
        }

        /// Aborting a backward or forward pagination brings the state back
        /// to `Idle`.
        #[async_test]
        async fn test_abort_while_paginating() {
            let room = AbortingRoom::default();

            // Start from an `Idle` paginator with tokens in both directions.
            let paginator = Paginator::new(room.clone());
            paginator.state.set(PaginatorState::Idle);
            *paginator.tokens.lock().unwrap() = PaginationTokens {
                previous: PaginationToken::HasMore("prev".to_owned()),
                next: PaginationToken::HasMore("next".to_owned()),
            };

            let paginator = Arc::new(paginator);

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Idle);
            assert!(state.next().now_or_never().is_none());

            // Backward pagination: spawn, then give the room its abort handle.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_backward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            // Let the room abort the task.
            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // The state was reset when the task's future was dropped.
            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());

            // Same for forward pagination.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_forward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());
        }
    }
}