1use std::{
21 future::Future,
22 sync::{Arc, Mutex},
23};
24
25use eyeball::{SharedObservable, Subscriber};
26use matrix_sdk_base::{SendOutsideWasm, SyncOutsideWasm, deserialized_responses::TimelineEvent};
27use ruma::{EventId, UInt, api::Direction};
28
29use crate::{
30 Room,
31 paginators::{PaginationResult, PaginationToken, PaginatorError},
32 room::{EventWithContextResponse, Messages, MessagesOptions},
33};
34
/// Current state of a [`Paginator`].
///
/// The state only ever moves along these transitions:
/// `Initial -> FetchingTargetEvent -> Idle` (or back to `Initial` on failure) and
/// `Idle -> Paginating -> Idle`.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// State right after creation; also restored when fetching the target event fails or
    /// is aborted.
    Initial,

    /// The target event and its context are being fetched (`start_from` is running).
    FetchingTargetEvent,

    /// The target event was fetched and no pagination is currently running.
    Idle,

    /// A backward or forward pagination is running.
    Paginating,
}
52
53#[derive(Debug, Clone)]
55pub struct PaginationTokens {
56 pub previous: PaginationToken,
58 pub next: PaginationToken,
60}
61
62pub struct Paginator<PR: PaginableRoom> {
67 room: PR,
69
70 state: SharedObservable<PaginatorState>,
72
73 tokens: Mutex<PaginationTokens>,
77}
78
79#[cfg(not(tarpaulin_include))]
80impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("Paginator")
84 .field("state", &self.state.get())
85 .field("tokens", &self.tokens)
86 .finish_non_exhaustive()
87 }
88}
89
90#[derive(Debug)]
92pub struct StartFromResult {
93 pub events: Vec<TimelineEvent>,
95
96 pub has_prev: bool,
98
99 pub has_next: bool,
101}
102
103struct ResetStateGuard {
105 target: Option<PaginatorState>,
106 state: SharedObservable<PaginatorState>,
107}
108
109impl ResetStateGuard {
110 fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
112 Self { target: Some(target), state }
113 }
114
115 fn disarm(mut self) {
117 self.target = None;
118 }
119}
120
121impl Drop for ResetStateGuard {
122 fn drop(&mut self) {
123 if let Some(target) = self.target.take() {
124 self.state.set_if_not_eq(target);
125 }
126 }
127}
128
129impl<PR: PaginableRoom> Paginator<PR> {
130 pub fn new(room: PR) -> Self {
132 Self {
133 room,
134 state: SharedObservable::new(PaginatorState::Initial),
135 tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
136 }
137 }
138
139 fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
141 let actual = self.state.get();
142 if actual != expected {
143 Err(PaginatorError::InvalidPreviousState { expected, actual })
144 } else {
145 Ok(())
146 }
147 }
148
149 pub fn state(&self) -> Subscriber<PaginatorState> {
151 self.state.subscribe()
152 }
153
154 pub async fn start_from(
160 &self,
161 event_id: &EventId,
162 num_events: UInt,
163 ) -> Result<StartFromResult, PaginatorError> {
164 self.check_state(PaginatorState::Initial)?;
165
166 if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
170 return Err(PaginatorError::InvalidPreviousState {
171 expected: PaginatorState::Initial,
172 actual: PaginatorState::FetchingTargetEvent,
173 });
174 }
175
176 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
177
178 let lazy_load_members = true;
180
181 let response =
182 self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
183
184 let has_prev = response.prev_batch_token.is_some();
189 let has_next = response.next_batch_token.is_some();
190
191 {
192 let mut tokens = self.tokens.lock().unwrap();
193 tokens.previous = match response.prev_batch_token {
194 Some(token) => PaginationToken::HasMore(token),
195 None => PaginationToken::HitEnd,
196 };
197 tokens.next = match response.next_batch_token {
198 Some(token) => PaginationToken::HasMore(token),
199 None => PaginationToken::HitEnd,
200 };
201 }
202
203 reset_state_guard.disarm();
205 self.state.set(PaginatorState::Idle);
207
208 let events = response
215 .events_before
216 .into_iter()
217 .rev()
218 .chain(response.event)
219 .chain(response.events_after)
220 .collect();
221
222 Ok(StartFromResult { events, has_prev, has_next })
223 }
224
225 pub async fn paginate_backward(
234 &self,
235 num_events: UInt,
236 ) -> Result<PaginationResult, PaginatorError> {
237 self.paginate(Direction::Backward, num_events).await
238 }
239
240 pub fn hit_timeline_start(&self) -> bool {
245 matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
246 }
247
248 pub fn hit_timeline_end(&self) -> bool {
253 matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
254 }
255
256 pub async fn paginate_forward(
264 &self,
265 num_events: UInt,
266 ) -> Result<PaginationResult, PaginatorError> {
267 self.paginate(Direction::Forward, num_events).await
268 }
269
270 async fn paginate(
274 &self,
275 dir: Direction,
276 num_events: UInt,
277 ) -> Result<PaginationResult, PaginatorError> {
278 self.check_state(PaginatorState::Idle)?;
279
280 let token = {
281 let tokens = self.tokens.lock().unwrap();
282
283 let token = match dir {
284 Direction::Backward => &tokens.previous,
285 Direction::Forward => &tokens.next,
286 };
287
288 match token {
289 PaginationToken::None => None,
290 PaginationToken::HasMore(val) => Some(val.clone()),
291 PaginationToken::HitEnd => {
292 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
293 }
294 }
295 };
296
297 if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
301 return Err(PaginatorError::InvalidPreviousState {
302 expected: PaginatorState::Idle,
303 actual: PaginatorState::Paginating,
304 });
305 }
306
307 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
308
309 let mut options = MessagesOptions::new(dir).from(token.as_deref());
310 options.limit = num_events;
311
312 let response = self.room.messages(options).await?;
315
316 let hit_end_of_timeline = response.end.is_none();
321
322 {
323 let mut tokens = self.tokens.lock().unwrap();
324
325 let token = match dir {
326 Direction::Backward => &mut tokens.previous,
327 Direction::Forward => &mut tokens.next,
328 };
329
330 *token = match response.end {
331 Some(val) => PaginationToken::HasMore(val),
332 None => PaginationToken::HitEnd,
333 };
334 }
335
336 reset_state_guard.disarm();
340 self.state.set(PaginatorState::Idle);
342
343 Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
344 }
345
346 pub fn tokens(&self) -> PaginationTokens {
348 self.tokens.lock().unwrap().clone()
349 }
350}
351
352pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
357 fn event_with_context(
372 &self,
373 event_id: &EventId,
374 lazy_load_members: bool,
375 num_events: UInt,
376 ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
377
378 fn messages(
380 &self,
381 opts: MessagesOptions,
382 ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
383}
384
385impl PaginableRoom for Room {
386 async fn event_with_context(
387 &self,
388 event_id: &EventId,
389 lazy_load_members: bool,
390 num_events: UInt,
391 ) -> Result<EventWithContextResponse, PaginatorError> {
392 let response =
393 match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
394 Ok(result) => result,
395
396 Err(err) => {
397 if let Some(error) = err.as_client_api_error()
401 && error.status_code == 404
402 {
403 return Err(PaginatorError::EventNotFound(event_id.to_owned()));
405 }
406
407 return Err(PaginatorError::SdkError(Arc::new(err)));
409 }
410 };
411
412 Ok(response)
413 }
414
415 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
416 self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Arc::new(err)))
417 }
418}
419
420#[cfg(all(not(target_family = "wasm"), test))]
421mod tests {
422 use std::sync::{Arc, LazyLock};
423
424 use assert_matches2::assert_let;
425 use futures_core::Future;
426 use futures_util::FutureExt as _;
427 use matrix_sdk_base::deserialized_responses::TimelineEvent;
428 use matrix_sdk_test::{async_test, event_factory::EventFactory};
429 use ruma::{EventId, RoomId, UInt, UserId, api::Direction, event_id, room_id, uint, user_id};
430 use tokio::{
431 spawn,
432 sync::{Mutex, Notify},
433 task::AbortHandle,
434 };
435
436 use super::{PaginableRoom, PaginatorError, PaginatorState};
437 use crate::{
438 paginators::Paginator,
439 room::{EventWithContextResponse, Messages, MessagesOptions},
440 test_utils::assert_event_matches_msg,
441 };
442
443 #[derive(Clone)]
444 struct TestRoom {
445 event_factory: Arc<EventFactory>,
446 wait_for_ready: bool,
447
448 target_event_text: Arc<Mutex<String>>,
449 next_events: Arc<Mutex<Vec<TimelineEvent>>>,
450 prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
451 prev_batch_token: Arc<Mutex<Option<String>>>,
452 next_batch_token: Arc<Mutex<Option<String>>>,
453
454 room_ready: Arc<Notify>,
455 }
456
457 impl TestRoom {
458 fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
459 let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));
460
461 Self {
462 event_factory,
463 wait_for_ready,
464
465 room_ready: Default::default(),
466 target_event_text: Default::default(),
467 next_events: Default::default(),
468 prev_events: Default::default(),
469 prev_batch_token: Default::default(),
470 next_batch_token: Default::default(),
471 }
472 }
473
474 fn mark_ready(&self) {
476 self.room_ready.notify_one();
477 }
478 }
479
480 static ROOM_ID: LazyLock<&RoomId> = LazyLock::new(|| room_id!("!dune:herbert.org"));
481 static USER_ID: LazyLock<&UserId> = LazyLock::new(|| user_id!("@paul:atreid.es"));
482
483 impl PaginableRoom for TestRoom {
484 async fn event_with_context(
485 &self,
486 event_id: &EventId,
487 _lazy_load_members: bool,
488 num_events: UInt,
489 ) -> Result<EventWithContextResponse, PaginatorError> {
490 if self.wait_for_ready {
492 self.room_ready.notified().await;
493 }
494
495 let event = self
496 .event_factory
497 .text_msg(self.target_event_text.lock().await.clone())
498 .event_id(event_id)
499 .into_event();
500
501 let mut num_events = u64::from(num_events) as usize;
504
505 let prev_events = self.prev_events.lock().await;
506
507 let events_before = if prev_events.is_empty() {
508 Vec::new()
509 } else {
510 let len = prev_events.len();
511 let take_before = num_events.min(len);
512 num_events -= take_before;
514 prev_events[len - take_before..len].to_vec()
516 };
517
518 let events_after = self.next_events.lock().await;
519 let events_after = if events_after.is_empty() {
520 Vec::new()
521 } else {
522 events_after[0..num_events.min(events_after.len())].to_vec()
523 };
524
525 Ok(EventWithContextResponse {
526 event: Some(event),
527 events_before,
528 events_after,
529 prev_batch_token: self.prev_batch_token.lock().await.clone(),
530 next_batch_token: self.next_batch_token.lock().await.clone(),
531 state: Vec::new(),
532 })
533 }
534
535 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
536 if self.wait_for_ready {
537 self.room_ready.notified().await;
538 }
539
540 let limit = u64::from(opts.limit) as usize;
541
542 let (end, events) = match opts.dir {
543 Direction::Backward => {
544 let events = self.prev_events.lock().await;
545 let events = if events.is_empty() {
546 Vec::new()
547 } else {
548 let len = events.len();
549 let take_before = limit.min(len);
550 events[len - take_before..len].to_vec()
552 };
553 (self.prev_batch_token.lock().await.clone(), events)
554 }
555
556 Direction::Forward => {
557 let events = self.next_events.lock().await;
558 let events = if events.is_empty() {
559 Vec::new()
560 } else {
561 events[0..limit.min(events.len())].to_vec()
562 };
563 (self.next_batch_token.lock().await.clone(), events)
564 }
565 };
566
567 Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
568 }
569 }
570
571 async fn assert_invalid_state<T: std::fmt::Debug>(
572 task: impl Future<Output = Result<T, PaginatorError>>,
573 expected: PaginatorState,
574 actual: PaginatorState,
575 ) {
576 assert_let!(
577 Err(PaginatorError::InvalidPreviousState {
578 expected: real_expected,
579 actual: real_actual
580 }) = task.await
581 );
582 assert_eq!(real_expected, expected);
583 assert_eq!(real_actual, actual);
584 }
585
586 #[async_test]
587 async fn test_start_from() {
588 let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
590
591 let event_id = event_id!("$yoyoyo");
592 let event_factory = &room.event_factory;
593
594 *room.target_event_text.lock().await = "fetch_from".to_owned();
595 *room.prev_events.lock().await = (0..10)
596 .rev()
597 .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
598 .collect();
599 *room.next_events.lock().await =
600 (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();
601
602 let paginator = Arc::new(Paginator::new(room.clone()));
604 let context =
605 paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
606
607 assert!(!context.has_prev);
608 assert!(!context.has_next);
609
610 assert_eq!(context.events.len(), 21);
614
615 for i in 0..10 {
616 assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
617 }
618
619 assert_event_matches_msg(&context.events[10], "fetch_from");
620 assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);
621
622 for i in 0..10 {
623 assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
624 }
625 }
626
627 #[async_test]
628 async fn test_start_from_with_num_events() {
629 let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
631
632 let event_id = event_id!("$yoyoyo");
633 let event_factory = &room.event_factory;
634
635 *room.target_event_text.lock().await = "fetch_from".to_owned();
636 *room.prev_events.lock().await =
637 (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();
638
639 let paginator = Arc::new(Paginator::new(room.clone()));
641 let context =
642 paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");
643
644 assert_eq!(context.events.len(), 11);
647
648 for i in 0..10 {
649 assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
650 }
651 assert_event_matches_msg(&context.events[10], "fetch_from");
652 }
653
    /// Backward pagination follows the previous-batch tokens until the room stops returning
    /// one. After that, further backward paginations return nothing.
    #[async_test]
    async fn test_paginate_backward() {
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        // Before `start_from`, no token is known in either direction, so neither end counts
        // as reached.
        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
        assert!(
            !paginator.hit_timeline_end(),
            "we don't know about the status of the next-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // No context events were configured: only the target event comes back.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // Only the previous-batch token was set.
        assert!(context.has_prev);
        assert!(!context.has_next);

        assert!(!paginator.hit_timeline_start());
        assert!(paginator.hit_timeline_end());

        // First backward pagination: one event, and the room still returns a token.
        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
        *room.prev_batch_token.lock().await = Some("prev2".to_owned());

        let prev =
            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
        assert!(!prev.hit_end_of_timeline);
        assert!(!paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "previous");

        // Second backward pagination: one event, and no token, so the start is reached.
        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
        *room.prev_batch_token.lock().await = None;

        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the second time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "oldest");

        // Third backward pagination: the paginator short-circuits and returns nothing, even
        // though the room fixture still holds "oldest".
        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the third time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert!(prev.events.is_empty());
    }
725
    /// Backward pagination requests at most `num_events` events from the room.
    #[async_test]
    async fn test_paginate_backward_with_limit() {
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event comes back.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        assert!(context.has_prev);
        assert!(!context.has_next);

        // 100 previous events ("prev99" … "prev0") and no further token.
        *room.prev_events.lock().await = (0..100)
            .rev()
            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
            .collect();
        *room.prev_batch_token.lock().await = None;

        // Limit of 10: the room hands out its last 10 fixture events ("prev9" … "prev0"). The
        // paginator returns the chunk as-is, without reordering it.
        let prev =
            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
        assert!(prev.hit_end_of_timeline);
        assert_eq!(prev.events.len(), 10);
        for i in 0..10 {
            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
        }
    }
767
    /// Forward pagination follows the next-batch tokens until the room stops returning one.
    /// After that, further forward paginations return nothing.
    #[async_test]
    async fn test_paginate_forward() {
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));
        // Before `start_from`, no token is known in either direction, so neither end counts
        // as reached.
        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
        assert!(
            !paginator.hit_timeline_start(),
            "we don't know about the status of the prev-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // No context events were configured: only the target event comes back.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // Only the next-batch token was set.
        assert!(!context.has_prev);
        assert!(context.has_next);

        assert!(paginator.hit_timeline_start());
        assert!(!paginator.hit_timeline_end());

        // First forward pagination: one event, and the room still returns a token.
        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
        *room.next_batch_token.lock().await = Some("next2".to_owned());

        let next =
            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
        assert!(!next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "next");
        assert!(!paginator.hit_timeline_end());

        // Second forward pagination: one event, and no token, so the end is reached.
        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
        *room.next_batch_token.lock().await = None;

        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the second time should work");
        assert!(next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "latest");
        assert!(paginator.hit_timeline_end());

        // Third forward pagination: the paginator short-circuits and returns nothing.
        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the third time should work");
        assert!(next.hit_end_of_timeline);
        assert!(next.events.is_empty());
        assert!(paginator.hit_timeline_end());
    }
839
    /// Walks the paginator through its state machine. Each state observed through the
    /// subscriber is checked, and calls made in the wrong state must fail with
    /// `InvalidPreviousState`.
    ///
    /// The room is created with `wait_for_ready = true`, so every request stays in flight
    /// until `mark_ready` is called.
    #[async_test]
    async fn test_state() {
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        // `now_or_never().is_none()` asserts that no state update is pending.
        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Paginating requires `Idle`.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        // The failed call did not touch the state.
        assert!(state.next().now_or_never().is_none());

        // Start fetching the target event in the background; the room holds the request.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // While fetching, neither `start_from` nor pagination is allowed.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Release the held request: the paginator becomes idle.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Paginate backward in the background; the room holds the request again.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // While paginating, `start_from` and paginating in either direction are rejected.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Release the held request: back to idle.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }
937
938 mod aborts {
939 use super::*;
940 use crate::paginators::room::{PaginationToken, PaginationTokens};
941
942 #[derive(Clone, Default)]
943 struct AbortingRoom {
944 abort_handle: Arc<Mutex<Option<AbortHandle>>>,
945 room_ready: Arc<Notify>,
946 }
947
948 impl AbortingRoom {
949 async fn wait_abort_and_yield(&self) -> ! {
950 self.room_ready.notified().await;
952
953 let mut guard = self.abort_handle.lock().await;
955 let handle = guard.take().expect("only call me when i'm initialized");
956 handle.abort();
957
958 loop {
960 tokio::task::yield_now().await;
961 }
962 }
963 }
964
965 impl PaginableRoom for AbortingRoom {
966 async fn event_with_context(
967 &self,
968 _event_id: &EventId,
969 _lazy_load_members: bool,
970 _num_events: UInt,
971 ) -> Result<EventWithContextResponse, PaginatorError> {
972 self.wait_abort_and_yield().await
973 }
974
975 async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
976 self.wait_abort_and_yield().await
977 }
978 }
979
        /// Aborting the task running `start_from` while the request is in flight must put
        /// the paginator back into `Initial`.
        #[async_test]
        async fn test_abort_while_starting_from() {
            let room = AbortingRoom::default();

            let paginator = Arc::new(Paginator::new(room.clone()));

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Initial);
            assert!(state.next().now_or_never().is_none());

            // Start fetching in a task that the room will abort.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
            assert!(state.next().now_or_never().is_none());

            // Let the room abort the task mid-request.
            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // The dropped future's reset guard restored `Initial`.
            assert_eq!(state.next().await, Some(PaginatorState::Initial));
            assert!(state.next().now_or_never().is_none());
        }
1012
        /// Aborting the task running a pagination (in either direction) while the request is
        /// in flight must put the paginator back into `Idle`.
        #[async_test]
        async fn test_abort_while_paginating() {
            let room = AbortingRoom::default();

            // Force the paginator into an idle state with tokens in both directions, as if
            // `start_from` had already succeeded.
            let paginator = Paginator::new(room.clone());
            paginator.state.set(PaginatorState::Idle);
            *paginator.tokens.lock().unwrap() = PaginationTokens {
                previous: PaginationToken::HasMore("prev".to_owned()),
                next: PaginationToken::HasMore("next".to_owned()),
            };

            let paginator = Arc::new(paginator);

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Idle);
            assert!(state.next().now_or_never().is_none());

            // Backward pagination, aborted mid-request.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_backward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // The dropped future's reset guard restored `Idle`.
            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());

            // Same scenario, paginating forward.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_forward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());
        }
1072 }
1073}