1use std::{future::Future, sync::Mutex};
21
22use eyeball::{SharedObservable, Subscriber};
23use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm};
24use ruma::{api::Direction, EventId, UInt};
25
26use crate::{
27 paginators::{PaginationResult, PaginationToken, PaginatorError},
28 room::{EventWithContextResponse, Messages, MessagesOptions},
29 Room,
30};
31
/// Current state of a [`Paginator`].
///
/// Observable through [`Paginator::state`].
#[derive(Debug, PartialEq, Copy, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// [`Paginator::start_from`] has not succeeded yet.
    ///
    /// This is the state of a freshly created paginator. It is also restored when a
    /// `start_from` call fails or its future is dropped before completing.
    Initial,

    /// [`Paginator::start_from`] is waiting for the target event and its context.
    FetchingTargetEvent,

    /// No request is in flight, and backward/forward pagination can be started.
    ///
    /// Reached after a successful `start_from`. It is also the state after a pagination
    /// completes, fails, or has its future dropped.
    Idle,

    /// A backward or forward pagination request is in flight.
    Paginating,
}
49
/// The pagination tokens known to a [`Paginator`], one per direction.
#[derive(Debug)]
struct PaginationTokens {
    /// Token used for backward pagination (checked by `hit_timeline_start`).
    previous: PaginationToken,
    /// Token used for forward pagination (checked by `hit_timeline_end`).
    next: PaginationToken,
}
58
/// Paginates a room's timeline in both directions, starting from a given event.
///
/// Call [`Paginator::start_from`] first, then [`Paginator::paginate_backward`] and/or
/// [`Paginator::paginate_forward`]. Calls made while the paginator is in the wrong
/// [`PaginatorState`] are rejected with [`PaginatorError::InvalidPreviousState`].
pub struct Paginator<PR: PaginableRoom> {
    /// Source of the events (a real `Room`, or a test double).
    room: PR,

    /// Current state, shared with subscribers and with the reset guards used while a request
    /// is in flight.
    state: SharedObservable<PaginatorState>,

    /// Pagination tokens for each direction.
    ///
    /// A synchronous mutex: every lock in this module is taken and released without an
    /// `.await` in between.
    tokens: Mutex<PaginationTokens>,
}
75
76#[cfg(not(tarpaulin_include))]
77impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Paginator")
81 .field("state", &self.state.get())
82 .field("tokens", &self.tokens)
83 .finish_non_exhaustive()
84 }
85}
86
/// The result of a successful [`Paginator::start_from`] call.
#[derive(Debug)]
pub struct StartFromResult {
    /// The events around the target, in this order:
    ///
    /// - the events before it, reversed relative to the order the room returned them;
    /// - the target event itself, if the room returned it;
    /// - the events after it, in the order the room returned them.
    pub events: Vec<TimelineEvent>,

    /// Whether the room returned a token to paginate backward from.
    pub has_prev: bool,

    /// Whether the room returned a token to paginate forward from.
    pub has_next: bool,
}
99
100struct ResetStateGuard {
102 target: Option<PaginatorState>,
103 state: SharedObservable<PaginatorState>,
104}
105
106impl ResetStateGuard {
107 fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
109 Self { target: Some(target), state }
110 }
111
112 fn disarm(mut self) {
114 self.target = None;
115 }
116}
117
118impl Drop for ResetStateGuard {
119 fn drop(&mut self) {
120 if let Some(target) = self.target.take() {
121 self.state.set_if_not_eq(target);
122 }
123 }
124}
125
126impl<PR: PaginableRoom> Paginator<PR> {
127 pub fn new(room: PR) -> Self {
129 Self {
130 room,
131 state: SharedObservable::new(PaginatorState::Initial),
132 tokens: Mutex::new(PaginationTokens { previous: None.into(), next: None.into() }),
133 }
134 }
135
136 fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
138 let actual = self.state.get();
139 if actual != expected {
140 Err(PaginatorError::InvalidPreviousState { expected, actual })
141 } else {
142 Ok(())
143 }
144 }
145
146 pub fn state(&self) -> Subscriber<PaginatorState> {
148 self.state.subscribe()
149 }
150
151 pub async fn start_from(
157 &self,
158 event_id: &EventId,
159 num_events: UInt,
160 ) -> Result<StartFromResult, PaginatorError> {
161 self.check_state(PaginatorState::Initial)?;
162
163 if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
167 return Err(PaginatorError::InvalidPreviousState {
168 expected: PaginatorState::Initial,
169 actual: PaginatorState::FetchingTargetEvent,
170 });
171 }
172
173 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
174
175 let lazy_load_members = true;
177
178 let response =
179 self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
180
181 let has_prev = response.prev_batch_token.is_some();
186 let has_next = response.next_batch_token.is_some();
187
188 {
189 let mut tokens = self.tokens.lock().unwrap();
190 tokens.previous = match response.prev_batch_token {
191 Some(token) => PaginationToken::HasMore(token),
192 None => PaginationToken::HitEnd,
193 };
194 tokens.next = match response.next_batch_token {
195 Some(token) => PaginationToken::HasMore(token),
196 None => PaginationToken::HitEnd,
197 };
198 }
199
200 reset_state_guard.disarm();
202 self.state.set(PaginatorState::Idle);
204
205 let events = response
212 .events_before
213 .into_iter()
214 .rev()
215 .chain(response.event)
216 .chain(response.events_after)
217 .collect();
218
219 Ok(StartFromResult { events, has_prev, has_next })
220 }
221
222 pub async fn paginate_backward(
231 &self,
232 num_events: UInt,
233 ) -> Result<PaginationResult, PaginatorError> {
234 self.paginate(Direction::Backward, num_events).await
235 }
236
237 pub fn hit_timeline_start(&self) -> bool {
242 matches!(self.tokens.lock().unwrap().previous, PaginationToken::HitEnd)
243 }
244
245 pub fn hit_timeline_end(&self) -> bool {
250 matches!(self.tokens.lock().unwrap().next, PaginationToken::HitEnd)
251 }
252
253 pub async fn paginate_forward(
261 &self,
262 num_events: UInt,
263 ) -> Result<PaginationResult, PaginatorError> {
264 self.paginate(Direction::Forward, num_events).await
265 }
266
267 async fn paginate(
271 &self,
272 dir: Direction,
273 num_events: UInt,
274 ) -> Result<PaginationResult, PaginatorError> {
275 self.check_state(PaginatorState::Idle)?;
276
277 let token = {
278 let tokens = self.tokens.lock().unwrap();
279
280 let token = match dir {
281 Direction::Backward => &tokens.previous,
282 Direction::Forward => &tokens.next,
283 };
284
285 match token {
286 PaginationToken::None => None,
287 PaginationToken::HasMore(val) => Some(val.clone()),
288 PaginationToken::HitEnd => {
289 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
290 }
291 }
292 };
293
294 if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() {
298 return Err(PaginatorError::InvalidPreviousState {
299 expected: PaginatorState::Idle,
300 actual: PaginatorState::Paginating,
301 });
302 }
303
304 let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);
305
306 let mut options = MessagesOptions::new(dir).from(token.as_deref());
307 options.limit = num_events;
308
309 let response = self.room.messages(options).await?;
312
313 let hit_end_of_timeline = response.end.is_none();
318
319 {
320 let mut tokens = self.tokens.lock().unwrap();
321
322 let token = match dir {
323 Direction::Backward => &mut tokens.previous,
324 Direction::Forward => &mut tokens.next,
325 };
326
327 *token = match response.end {
328 Some(val) => PaginationToken::HasMore(val),
329 None => PaginationToken::HitEnd,
330 };
331 }
332
333 reset_state_guard.disarm();
337 self.state.set(PaginatorState::Idle);
339
340 Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
341 }
342}
343
/// The room requests a [`Paginator`] needs.
///
/// Implemented for [`Room`]; the tests in this module provide in-memory implementations.
pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
    /// Fetches the event `event_id` together with surrounding context.
    ///
    /// `num_events` bounds how much context is requested. How it is split between events
    /// before and after the target is left to the implementation. `lazy_load_members` is a
    /// request option passed through by the paginator, which always sets it to `true`.
    ///
    /// The [`Room`] implementation reports a 404 from the server as
    /// [`PaginatorError::EventNotFound`].
    fn event_with_context(
        &self,
        event_id: &EventId,
        lazy_load_members: bool,
        num_events: UInt,
    ) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;

    /// Fetches a batch of events as described by `opts`.
    ///
    /// The paginator fills in the direction, the starting token and the limit.
    fn messages(
        &self,
        opts: MessagesOptions,
    ) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
}
376
377impl PaginableRoom for Room {
378 async fn event_with_context(
379 &self,
380 event_id: &EventId,
381 lazy_load_members: bool,
382 num_events: UInt,
383 ) -> Result<EventWithContextResponse, PaginatorError> {
384 let response =
385 match self.event_with_context(event_id, lazy_load_members, num_events, None).await {
386 Ok(result) => result,
387
388 Err(err) => {
389 if let Some(error) = err.as_client_api_error() {
393 if error.status_code == 404 {
394 return Err(PaginatorError::EventNotFound(event_id.to_owned()));
396 }
397 }
398
399 return Err(PaginatorError::SdkError(Box::new(err)));
401 }
402 };
403
404 Ok(response)
405 }
406
407 async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
408 self.messages(opts).await.map_err(|err| PaginatorError::SdkError(Box::new(err)))
409 }
410}
411
#[cfg(all(not(target_family = "wasm"), test))]
mod tests {
    use std::sync::Arc;

    use assert_matches2::assert_let;
    use futures_core::Future;
    use futures_util::FutureExt as _;
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use once_cell::sync::Lazy;
    use ruma::{api::Direction, event_id, room_id, uint, user_id, EventId, RoomId, UInt, UserId};
    use tokio::{
        spawn,
        sync::{Mutex, Notify},
        task::AbortHandle,
    };

    use super::{PaginableRoom, PaginatorError, PaginatorState};
    use crate::{
        paginators::Paginator,
        room::{EventWithContextResponse, Messages, MessagesOptions},
        test_utils::assert_event_matches_msg,
    };

    /// In-memory [`PaginableRoom`] whose answers are driven by the fields below.
    #[derive(Clone)]
    struct TestRoom {
        event_factory: Arc<EventFactory>,
        // When true, every request first waits for `room_ready` to be notified.
        wait_for_ready: bool,

        // Body of the target event returned by `event_with_context`.
        target_event_text: Arc<Mutex<String>>,
        // Events handed out for forward requests (taken from the front).
        next_events: Arc<Mutex<Vec<TimelineEvent>>>,
        // Events handed out for backward requests (taken from the back).
        prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
        // Tokens returned by the next request; `None` means "no more events".
        prev_batch_token: Arc<Mutex<Option<String>>>,
        next_batch_token: Arc<Mutex<Option<String>>>,

        room_ready: Arc<Notify>,
    }

    impl TestRoom {
        fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
            let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id));

            Self {
                event_factory,
                wait_for_ready,

                room_ready: Default::default(),
                target_event_text: Default::default(),
                next_events: Default::default(),
                prev_events: Default::default(),
                prev_batch_token: Default::default(),
                next_batch_token: Default::default(),
            }
        }

        /// Unblocks one pending request when `wait_for_ready` is set.
        fn mark_ready(&self) {
            self.room_ready.notify_one();
        }
    }

    static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
    static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));

    impl PaginableRoom for TestRoom {
        async fn event_with_context(
            &self,
            event_id: &EventId,
            _lazy_load_members: bool,
            num_events: UInt,
        ) -> Result<EventWithContextResponse, PaginatorError> {
            // Wait for the test to allow the response to be sent, if configured to.
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let event = self
                .event_factory
                .text_msg(self.target_event_text.lock().await.clone())
                .event_id(event_id)
                .into_event();

            // `num_events` is a shared budget: events before are taken first, and whatever
            // remains is used for the events after.
            let mut num_events = u64::from(num_events) as usize;

            let prev_events = self.prev_events.lock().await;

            let events_before = if prev_events.is_empty() {
                Vec::new()
            } else {
                let len = prev_events.len();
                let take_before = num_events.min(len);
                // Subtract is safe because take_before <= num_events.
                num_events -= take_before;
                // Take the last `take_before` events of `prev_events`.
                prev_events[len - take_before..len].to_vec()
            };

            let events_after = self.next_events.lock().await;
            let events_after = if events_after.is_empty() {
                Vec::new()
            } else {
                events_after[0..num_events.min(events_after.len())].to_vec()
            };

            Ok(EventWithContextResponse {
                event: Some(event),
                events_before,
                events_after,
                prev_batch_token: self.prev_batch_token.lock().await.clone(),
                next_batch_token: self.next_batch_token.lock().await.clone(),
                state: Vec::new(),
            })
        }

        async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
            if self.wait_for_ready {
                self.room_ready.notified().await;
            }

            let limit = u64::from(opts.limit) as usize;

            let (end, events) = match opts.dir {
                Direction::Backward => {
                    let events = self.prev_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        let len = events.len();
                        let take_before = limit.min(len);
                        // Take the last `take_before` events.
                        events[len - take_before..len].to_vec()
                    };
                    (self.prev_batch_token.lock().await.clone(), events)
                }

                Direction::Forward => {
                    let events = self.next_events.lock().await;
                    let events = if events.is_empty() {
                        Vec::new()
                    } else {
                        events[0..limit.min(events.len())].to_vec()
                    };
                    (self.next_batch_token.lock().await.clone(), events)
                }
            };

            // The paginator only issues requests with a `from` token in these tests.
            Ok(Messages { start: opts.from.unwrap(), end, chunk: events, state: Vec::new() })
        }
    }

    /// Awaits `task` and asserts it failed with `InvalidPreviousState { expected, actual }`.
    async fn assert_invalid_state<T: std::fmt::Debug>(
        task: impl Future<Output = Result<T, PaginatorError>>,
        expected: PaginatorState,
        actual: PaginatorState,
    ) {
        assert_let!(
            Err(PaginatorError::InvalidPreviousState {
                expected: real_expected,
                actual: real_actual
            }) = task.await
        );
        assert_eq!(real_expected, expected);
        assert_eq!(real_actual, actual);
    }

    #[async_test]
    async fn test_start_from() {
        // Prepare test data: 10 events before and 10 events after the target event.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await = (0..10)
            .rev()
            .map(|i| event_factory.text_msg(format!("before-{i}")).into_event())
            .collect();
        *room.next_events.lock().await =
            (0..10).map(|i| event_factory.text_msg(format!("after-{i}")).into_event()).collect();

        // When I call `Paginator::start_from`, it works,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // No batch token was configured, so there's nothing to paginate in either direction.
        assert!(!context.has_prev);
        assert!(!context.has_next);

        // And I get the 10 events before, the target event, and the 10 events after.
        assert_eq!(context.events.len(), 21);

        // `start_from` reverses the events before, which yields before-0..before-9 here.
        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("before-{i}"));
        }

        assert_event_matches_msg(&context.events[10], "fetch_from");
        assert_eq!(context.events[10].raw().deserialize().unwrap().event_id(), event_id);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}"));
        }
    }

    #[async_test]
    async fn test_start_from_with_num_events() {
        // Prepare test data: 100 events before the target event, none after.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "fetch_from".to_owned();
        *room.prev_events.lock().await =
            (0..100).rev().map(|i| event_factory.text_msg(format!("ev{i}")).into_event()).collect();

        // When I call `Paginator::start_from` with a limit of 10 events,
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");

        // Then I only get 10 events before, plus the target event.
        assert_eq!(context.events.len(), 11);

        for i in 0..10 {
            assert_event_matches_msg(&context.events[i], &format!("ev{i}"));
        }
        assert_event_matches_msg(&context.events[10], "fetch_from");
    }

    #[async_test]
    async fn test_paginate_backward() {
        // Prepare test data: no events, but a prev-batch token.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        // When I call `Paginator::start_from`, it works.
        let paginator = Arc::new(Paginator::new(room.clone()));

        assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
        assert!(
            !paginator.hit_timeline_end(),
            "we don't know about the status of the next-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // A prev-batch token was returned, but no next-batch token.
        assert!(context.has_prev);
        assert!(!context.has_next);

        assert!(!paginator.hit_timeline_start());
        assert!(paginator.hit_timeline_end());

        // Preparing data for the next back-pagination: one event and a new token.
        *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_event()];
        *room.prev_batch_token.lock().await = Some("prev2".to_owned());

        // When I backpaginate, I get the events I expect.
        let prev =
            paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
        assert!(!prev.hit_end_of_timeline);
        assert!(!paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "previous");

        // And I can backpaginate again, this time reaching the start of the timeline (no
        // token returned).
        *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_event()];
        *room.prev_batch_token.lock().await = None;

        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the second time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert_eq!(prev.events.len(), 1);
        assert_event_matches_msg(&prev.events[0], "oldest");

        // Once the start is hit, back-paginating again returns no events and stays at the
        // start of the timeline.
        let prev = paginator
            .paginate_backward(uint!(100))
            .await
            .expect("paginate backward the third time should work");
        assert!(prev.hit_end_of_timeline);
        assert!(paginator.hit_timeline_start());
        assert!(prev.events.is_empty());
    }

    #[async_test]
    async fn test_paginate_backward_with_limit() {
        // Prepare test data: no events, but a prev-batch token.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.prev_batch_token.lock().await = Some("prev".to_owned());

        // When I call `Paginator::start_from`, it works.
        let paginator = Arc::new(Paginator::new(room.clone()));
        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // A prev-batch token was returned, but no next-batch token.
        assert!(context.has_prev);
        assert!(!context.has_next);

        // Prepare 100 events for the next back-pagination, with no further token.
        *room.prev_events.lock().await = (0..100)
            .rev()
            .map(|i| event_factory.text_msg(format!("prev{i}")).into_event())
            .collect();
        *room.prev_batch_token.lock().await = None;

        // When I backpaginate with a limit of 10, I only get the 10 last stored events.
        let prev =
            paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");
        assert!(prev.hit_end_of_timeline);
        assert_eq!(prev.events.len(), 10);
        for i in 0..10 {
            assert_event_matches_msg(&prev.events[i], &format!("prev{}", 9 - i));
        }
    }

    #[async_test]
    async fn test_paginate_forward() {
        // Prepare test data: no events, but a next-batch token.
        let room = TestRoom::new(false, *ROOM_ID, *USER_ID);

        let event_id = event_id!("$yoyoyo");
        let event_factory = &room.event_factory;

        *room.target_event_text.lock().await = "initial".to_owned();
        *room.next_batch_token.lock().await = Some("next".to_owned());

        // When I call `Paginator::start_from`, it works.
        let paginator = Arc::new(Paginator::new(room.clone()));
        assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
        assert!(
            !paginator.hit_timeline_start(),
            "we don't know about the status of the prev-batch token"
        );

        let context =
            paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

        // Only the target event is returned.
        assert_eq!(context.events.len(), 1);
        assert_event_matches_msg(&context.events[0], "initial");
        assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);

        // A next-batch token was returned, but no prev-batch token.
        assert!(!context.has_prev);
        assert!(context.has_next);

        assert!(paginator.hit_timeline_start());
        assert!(!paginator.hit_timeline_end());

        // Preparing data for the next forward pagination: one event and a new token.
        *room.next_events.lock().await = vec![event_factory.text_msg("next").into_event()];
        *room.next_batch_token.lock().await = Some("next2".to_owned());

        // When I paginate forward, I get the events I expect.
        let next =
            paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
        assert!(!next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "next");
        assert!(!paginator.hit_timeline_end());

        // And I can paginate forward again, this time reaching the end of the timeline (no
        // token returned).
        *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_event()];
        *room.next_batch_token.lock().await = None;

        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the second time should work");
        assert!(next.hit_end_of_timeline);
        assert_eq!(next.events.len(), 1);
        assert_event_matches_msg(&next.events[0], "latest");
        assert!(paginator.hit_timeline_end());

        // Once the end is hit, paginating forward again returns no events and stays at the
        // end of the timeline.
        let next = paginator
            .paginate_forward(uint!(100))
            .await
            .expect("paginate forward the third time should work");
        assert!(next.hit_end_of_timeline);
        assert!(next.events.is_empty());
        assert!(paginator.hit_timeline_end());
    }

    #[async_test]
    async fn test_state() {
        // A room whose requests block until `mark_ready` is called, so that intermediate
        // states can be observed.
        let room = TestRoom::new(true, *ROOM_ID, *USER_ID);

        *room.prev_batch_token.lock().await = Some("prev".to_owned());
        *room.next_batch_token.lock().await = Some("next".to_owned());

        let paginator = Arc::new(Paginator::new(room.clone()));

        let event_id = event_id!("$yoyoyo");

        let mut state = paginator.state();

        assert_eq!(state.get(), PaginatorState::Initial);
        assert!(state.next().now_or_never().is_none());

        // Attempting to run pagination must fail and not change the state.
        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Initial,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Running the initial query must work.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
        assert!(state.next().now_or_never().is_none());

        // The query is pending. Running other operations must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::FetchingTargetEvent,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Mark the dummy room as ready. The query may now terminate.
        room.mark_ready();

        // After fetching the initial event data, the paginator switches to `Idle`.
        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/context failed");

        assert!(state.next().now_or_never().is_none());

        // Running a backward pagination switches to `Paginating` while it's pending.
        let p = paginator.clone();
        let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });

        assert_eq!(state.next().await, Some(PaginatorState::Paginating));

        // Running other operations while paginating must fail.
        assert_invalid_state(
            paginator.start_from(event_id, uint!(100)),
            PaginatorState::Initial,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_backward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert_invalid_state(
            paginator.paginate_forward(uint!(100)),
            PaginatorState::Idle,
            PaginatorState::Paginating,
        )
        .await;

        assert!(state.next().now_or_never().is_none());

        // Let the pending pagination complete; the paginator goes back to `Idle`.
        room.mark_ready();

        assert_eq!(state.next().await, Some(PaginatorState::Idle));

        join_handle.await.expect("joined failed").expect("/messages failed");

        assert!(state.next().now_or_never().is_none());
    }

    mod aborts {
        use super::*;
        use crate::paginators::room::{PaginationToken, PaginationTokens};

        /// A room whose requests never complete normally. Once `room_ready` is notified, a
        /// request aborts the task running it through the stored `abort_handle`.
        #[derive(Clone, Default)]
        struct AbortingRoom {
            abort_handle: Arc<Mutex<Option<AbortHandle>>>,
            room_ready: Arc<Notify>,
        }

        impl AbortingRoom {
            async fn wait_abort_and_yield(&self) -> ! {
                // Wait for the controller to tell us we're ready.
                self.room_ready.notified().await;

                // Abort the given handle (the task running this request).
                let mut guard = self.abort_handle.lock().await;
                let handle = guard.take().expect("only call me when i'm initialized");
                handle.abort();

                // Keep yielding so the runtime gets a chance to act on the abort; this future
                // never resolves by itself.
                loop {
                    tokio::task::yield_now().await;
                }
            }
        }

        impl PaginableRoom for AbortingRoom {
            async fn event_with_context(
                &self,
                _event_id: &EventId,
                _lazy_load_members: bool,
                _num_events: UInt,
            ) -> Result<EventWithContextResponse, PaginatorError> {
                self.wait_abort_and_yield().await
            }

            async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
                self.wait_abort_and_yield().await
            }
        }

        #[async_test]
        async fn test_abort_while_starting_from() {
            let room = AbortingRoom::default();

            let paginator = Arc::new(Paginator::new(room.clone()));

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Initial);
            assert!(state.next().now_or_never().is_none());

            // When I try to start the initial query…
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
            });

            // Give the room the handle of the task it must abort.
            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            // But it's aborted when awaiting the task.
            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // Then the state is reset to the initial state.
            assert_eq!(state.next().await, Some(PaginatorState::Initial));
            assert!(state.next().now_or_never().is_none());
        }

        #[async_test]
        async fn test_abort_while_paginating() {
            let room = AbortingRoom::default();

            // Assuming a paginator ready to back- or forward- paginate,
            let paginator = Paginator::new(room.clone());
            paginator.state.set(PaginatorState::Idle);
            *paginator.tokens.lock().unwrap() = PaginationTokens {
                previous: PaginationToken::HasMore("prev".to_owned()),
                next: PaginationToken::HasMore("next".to_owned()),
            };

            let paginator = Arc::new(paginator);

            let mut state = paginator.state();

            assert_eq!(state.get(), PaginatorState::Idle);
            assert!(state.next().now_or_never().is_none());

            // When I try to back-paginate…
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_backward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            // But it's aborted when awaiting the task.
            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            // Then the state is reset to idle.
            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());

            // And ditto for forward pagination.
            let p = paginator.clone();
            let join_handle = spawn(async move {
                let _ = p.paginate_forward(uint!(100)).await;
            });

            *room.abort_handle.lock().await = Some(join_handle.abort_handle());

            assert_eq!(state.next().await, Some(PaginatorState::Paginating));
            assert!(state.next().now_or_never().is_none());

            room.room_ready.notify_one();

            let join_result = join_handle.await;
            assert!(join_result.unwrap_err().is_cancelled());

            assert_eq!(state.next().await, Some(PaginatorState::Idle));
            assert!(state.next().now_or_never().is_none());
        }
    }
}