use std::{future::Future, sync::Mutex};
use eyeball::{SharedObservable, Subscriber};
use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm};
use ruma::{api::Direction, EventId, OwnedEventId, UInt};
use super::pagination::PaginationToken;
use crate::{
room::{EventWithContextResponse, Messages, MessagesOptions, WeakRoom},
Room,
};
/// State machine of a [`Paginator`].
///
/// The enum is fieldless, so equality is total: `Eq` is derived alongside
/// `PartialEq` so the type can be used wherever a full equivalence relation is
/// required.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum PaginatorState {
    /// State of a freshly created paginator; also the state `start_from`
    /// rolls back to when it does not complete.
    Initial,

    /// `start_from` is running and waiting for the target event and its
    /// context.
    FetchingTargetEvent,

    /// No request is running; `paginate_backward` / `paginate_forward` may be
    /// called.
    Idle,

    /// A backward or forward pagination is running.
    Paginating,
}
/// Errors returned by the [`Paginator`] and by [`PaginableRoom`]
/// implementations.
#[derive(Debug, thiserror::Error)]
pub enum PaginatorError {
/// The target event could not be found. The `Room` implementation of
/// [`PaginableRoom`] produces this when the request fails with a client API
/// error whose status code is 404.
#[error("target event with id {0} could not be found")]
EventNotFound(OwnedEventId),
/// An operation was attempted while the paginator was not in the state that
/// operation requires.
#[error("expected paginator state {expected:?}, observed {actual:?}")]
InvalidPreviousState {
/// The state the operation needed the paginator to be in.
expected: PaginatorState,
/// The state the paginator was observed in.
actual: PaginatorState,
},
/// Any other SDK error, boxed; convertible from `Box<crate::Error>` through
/// `?` thanks to `#[from]`.
#[error("an error happened while paginating: {0}")]
SdkError(#[from] Box<crate::Error>),
}
/// The pair of pagination tokens a [`Paginator`] keeps, one per direction.
#[derive(Debug)]
struct PaginationTokens {
/// Token read and updated by backward paginations.
previous: PaginationToken,
/// Token read and updated by forward paginations.
next: PaginationToken,
}
/// Paginates a room's events backward and forward, starting either from a
/// target event (`start_from`) or from tokens supplied through
/// `set_idle_state`.
pub struct Paginator<PR: PaginableRoom> {
/// The room that serves the `event_with_context` and `messages` requests.
room: PR,
/// Observable current state; subscribers are obtained through `state()`.
state: SharedObservable<PaginatorState>,
/// Pagination tokens for both directions, guarded by a synchronous mutex.
tokens: Mutex<PaginationTokens>,
}
#[cfg(not(tarpaulin_include))]
impl<PR: PaginableRoom> std::fmt::Debug for Paginator<PR> {
    /// Prints the current state and the tokens. The `room` field is left out,
    /// which is why the output is finished as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Paginator");
        builder.field("state", &self.state.get());
        builder.field("tokens", &self.tokens);
        builder.finish_non_exhaustive()
    }
}
/// Result of a single backward or forward pagination.
#[derive(Debug)]
pub struct PaginationResult {
/// The events of the response chunk, in the order the room returned them.
/// Empty when the paginator already knew the end was reached and no request
/// was made.
pub events: Vec<TimelineEvent>,
/// `true` when the response carried no `end` token, or when the end had
/// already been reached before this call.
pub hit_end_of_timeline: bool,
}
/// Result of [`Paginator::start_from`].
#[derive(Debug)]
pub struct StartFromResult {
/// The events before the target (in reverse of the order the response gave
/// them), then the target event if the response contained one, then the
/// events after the target.
pub events: Vec<TimelineEvent>,
/// Whether the response included a previous-batch token.
pub has_prev: bool,
/// Whether the response included a next-batch token.
pub has_next: bool,
}
/// Drop guard that puts the paginator state back to a chosen value unless it
/// has been disarmed first.
struct ResetStateGuard {
/// State to restore on drop; `None` once the guard is disarmed.
target: Option<PaginatorState>,
/// Handle on the observable state that will be written on drop.
state: SharedObservable<PaginatorState>,
}
impl ResetStateGuard {
    /// Builds an armed guard: dropping it writes `target` into `state`
    /// (unless `state` already holds that value).
    fn new(state: SharedObservable<PaginatorState>, target: PaginatorState) -> Self {
        let target = Some(target);
        Self { state, target }
    }

    /// Consumes the guard and makes its drop a no-op.
    fn disarm(mut self) {
        // Emptying the target is what the `Drop` impl checks for; `self` is
        // dropped right after this statement.
        self.target.take();
    }
}
impl Drop for ResetStateGuard {
    /// Restores the target state if the guard is still armed.
    fn drop(&mut self) {
        let Some(target) = self.target.take() else {
            // Disarmed: leave the state untouched.
            return;
        };

        self.state.set_if_not_eq(target);
    }
}
impl<PR: PaginableRoom> Paginator<PR> {
/// Creates a paginator for `room`, in the `Initial` state and with both
/// pagination tokens built from `None`.
pub fn new(room: PR) -> Self {
    let state = SharedObservable::new(PaginatorState::Initial);
    let tokens = Mutex::new(PaginationTokens { previous: None.into(), next: None.into() });

    Self { room, state, tokens }
}
/// Returns `Ok(())` when the current state equals `expected`, and an
/// `InvalidPreviousState` error carrying both states otherwise.
fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> {
    let actual = self.state.get();

    if actual == expected {
        return Ok(());
    }

    Err(PaginatorError::InvalidPreviousState { expected, actual })
}
/// Returns a subscriber to the paginator's observable state.
pub fn state(&self) -> Subscriber<PaginatorState> {
    let observable = &self.state;
    observable.subscribe()
}
/// Moves the paginator to `next_state` and replaces both pagination tokens.
///
/// # Errors
///
/// Returns `InvalidPreviousState` (with `expected: Idle`) when the paginator
/// is currently `FetchingTargetEvent` or `Paginating`; nothing is modified in
/// that case.
///
/// # Panics
///
/// Panics when `next_state` is neither `Initial` nor `Idle`.
pub(super) fn set_idle_state(
    &self,
    next_state: PaginatorState,
    prev_batch_token: Option<String>,
    next_batch_token: Option<String>,
) -> Result<(), PaginatorError> {
    // `true` for the two states that `start_from` / `paginate` set while
    // they are running. Kept exhaustive so a new variant forces a decision.
    let is_running = |state: PaginatorState| match state {
        PaginatorState::FetchingTargetEvent | PaginatorState::Paginating => true,
        PaginatorState::Initial | PaginatorState::Idle => false,
    };

    let prev_state = self.state.get();

    if is_running(next_state) {
        panic!("internal error: set_idle_state only accept Initial|Idle next states");
    }

    if is_running(prev_state) {
        return Err(PaginatorError::InvalidPreviousState {
            expected: PaginatorState::Idle,
            actual: prev_state,
        });
    }

    self.state.set_if_not_eq(next_state);

    let mut tokens = self.tokens.lock().unwrap();
    tokens.previous = prev_batch_token.into();
    tokens.next = next_batch_token.into();
    drop(tokens);

    Ok(())
}
/// Returns a copy of the backward token when there is one to paginate with
/// (`HasMore`), and `None` for the `None` and `HitEnd` token states.
pub(super) fn prev_batch_token(&self) -> Option<String> {
    let tokens = self.tokens.lock().unwrap();

    if let PaginationToken::HasMore(token) = &tokens.previous {
        Some(token.clone())
    } else {
        None
    }
}
pub async fn start_from(
&self,
event_id: &EventId,
num_events: UInt,
) -> Result<StartFromResult, PaginatorError> {
self.check_state(PaginatorState::Initial)?;
if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() {
return Err(PaginatorError::InvalidPreviousState {
expected: PaginatorState::Initial,
actual: PaginatorState::FetchingTargetEvent,
});
}
let reset_state_guard = ResetStateGuard::new(self.state.clone(), PaginatorState::Initial);
let lazy_load_members = true;
let response =
self.room.event_with_context(event_id, lazy_load_members, num_events).await?;
let has_prev = response.prev_batch_token.is_some();
let has_next = response.next_batch_token.is_some();
{
let mut tokens = self.tokens.lock().unwrap();
tokens.previous = match response.prev_batch_token {
Some(token) => PaginationToken::HasMore(token),
None => PaginationToken::HitEnd,
};
tokens.next = match response.next_batch_token {
Some(token) => PaginationToken::HasMore(token),
None => PaginationToken::HitEnd,
};
}
reset_state_guard.disarm();
self.state.set(PaginatorState::Idle);
let events = response
.events_before
.into_iter()
.rev()
.chain(response.event)
.chain(response.events_after)
.collect();
Ok(StartFromResult { events, has_prev, has_next })
}
/// Runs one pagination in the backward direction, asking for `num_events`
/// events. See `paginate` for the state requirements and errors.
pub async fn paginate_backward(
    &self,
    num_events: UInt,
) -> Result<PaginationResult, PaginatorError> {
    let direction = Direction::Backward;
    self.paginate(direction, num_events).await
}
/// Returns `true` when the backward token is `HitEnd`.
pub fn hit_timeline_start(&self) -> bool {
    let tokens = self.tokens.lock().unwrap();
    matches!(tokens.previous, PaginationToken::HitEnd)
}
/// Returns `true` when the forward token is `HitEnd`.
pub fn hit_timeline_end(&self) -> bool {
    let tokens = self.tokens.lock().unwrap();
    matches!(tokens.next, PaginationToken::HitEnd)
}
/// Runs one pagination in the forward direction, asking for `num_events`
/// events. See `paginate` for the state requirements and errors.
pub async fn paginate_forward(
    &self,
    num_events: UInt,
) -> Result<PaginationResult, PaginatorError> {
    let direction = Direction::Forward;
    self.paginate(direction, num_events).await
}
/// Shared implementation of backward and forward pagination.
///
/// Requires the `Idle` state. If the token for `dir` is already `HitEnd`, no
/// request is made and an empty result flagged `hit_end_of_timeline` is
/// returned without touching the state. Otherwise the state is `Paginating`
/// for the duration of the request and `Idle` again afterwards, and the
/// token for `dir` is replaced by the response's `end` token (`HitEnd` when
/// absent).
///
/// # Errors
///
/// - `InvalidPreviousState` when the paginator is not `Idle`.
/// - Whatever the room's `messages` returns; the state is then put back to
///   `Idle`.
async fn paginate(
    &self,
    dir: Direction,
    num_events: UInt,
) -> Result<PaginationResult, PaginatorError> {
    self.check_state(PaginatorState::Idle)?;

    // Copy the token out so the lock is not held across the request.
    let from = {
        let guard = self.tokens.lock().unwrap();

        let cursor = match dir {
            Direction::Forward => &guard.next,
            Direction::Backward => &guard.previous,
        };

        match cursor {
            PaginationToken::HitEnd => {
                return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
            }
            PaginationToken::HasMore(value) => Some(value.clone()),
            PaginationToken::None => None,
        }
    };

    // The check above and this write are two separate steps; a `None` from
    // `set_if_not_eq` is treated as "already paginating".
    let already_paginating = self.state.set_if_not_eq(PaginatorState::Paginating).is_none();
    if already_paginating {
        return Err(PaginatorError::InvalidPreviousState {
            expected: PaginatorState::Idle,
            actual: PaginatorState::Paginating,
        });
    }

    // Any exit before `disarm()` (the `?` below, or this future being
    // dropped) resets the state to `Idle`.
    let rollback = ResetStateGuard::new(self.state.clone(), PaginatorState::Idle);

    let options = {
        let mut options = MessagesOptions::new(dir).from(from.as_deref());
        options.limit = num_events;
        options
    };

    let response = self.room.messages(options).await?;

    // No suspension point below: token update and state change stay together.

    let new_token = response.end.map_or(PaginationToken::HitEnd, PaginationToken::HasMore);
    let hit_end_of_timeline = matches!(new_token, PaginationToken::HitEnd);

    {
        let mut guard = self.tokens.lock().unwrap();
        match dir {
            Direction::Forward => guard.next = new_token,
            Direction::Backward => guard.previous = new_token,
        }
    }

    rollback.disarm();
    self.state.set(PaginatorState::Idle);

    Ok(PaginationResult { events: response.chunk, hit_end_of_timeline })
}
}
/// The room operations a [`Paginator`] relies on.
///
/// This file implements it for `Room` and `WeakRoom`; the test module adds
/// mock implementations.
pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm {
/// Fetches the event `event_id` together with its surrounding context.
///
/// `num_events` is the requested context size and `lazy_load_members` is
/// passed along by the paginator (always `true` in `start_from`); their
/// exact semantics are up to the implementation.
fn event_with_context(
&self,
event_id: &EventId,
lazy_load_members: bool,
num_events: UInt,
) -> impl Future<Output = Result<EventWithContextResponse, PaginatorError>> + SendOutsideWasm;
/// Fetches one batch of messages as described by `opts` (direction, `from`
/// token and limit are set by the paginator).
fn messages(
&self,
opts: MessagesOptions,
) -> impl Future<Output = Result<Messages, PaginatorError>> + SendOutsideWasm;
}
impl PaginableRoom for Room {
    /// Forwards to the room's own `event_with_context`, translating a client
    /// API error with status 404 into `EventNotFound` and every other error
    /// into `SdkError`.
    async fn event_with_context(
        &self,
        event_id: &EventId,
        lazy_load_members: bool,
        num_events: UInt,
    ) -> Result<EventWithContextResponse, PaginatorError> {
        // Four arguments: this is not a recursive call to the trait method
        // (presumably it resolves to `Room`'s inherent method).
        let outcome =
            self.event_with_context(event_id, lazy_load_members, num_events, None).await;

        outcome.map_err(|err| {
            let not_found =
                err.as_client_api_error().is_some_and(|api_err| api_err.status_code == 404);

            if not_found {
                PaginatorError::EventNotFound(event_id.to_owned())
            } else {
                PaginatorError::SdkError(Box::new(err))
            }
        })
    }

    /// Forwards to the room's own `messages`, boxing any error into
    /// `SdkError`.
    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
        // `?` turns the boxed SDK error into `PaginatorError::SdkError`
        // through the `#[from]` conversion.
        let messages = self.messages(opts).await.map_err(Box::new)?;
        Ok(messages)
    }
}
impl PaginableRoom for WeakRoom {
    /// Delegates to the `Room` implementation when the weak reference can
    /// still be resolved; otherwise succeeds with a default (empty) response.
    async fn event_with_context(
        &self,
        event_id: &EventId,
        lazy_load_members: bool,
        num_events: UInt,
    ) -> Result<EventWithContextResponse, PaginatorError> {
        match self.get() {
            Some(room) => {
                PaginableRoom::event_with_context(&room, event_id, lazy_load_members, num_events)
                    .await
            }
            None => Ok(EventWithContextResponse::default()),
        }
    }

    /// Delegates to the `Room` implementation when the weak reference can
    /// still be resolved; otherwise succeeds with default (empty) messages.
    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
        match self.get() {
            Some(room) => PaginableRoom::messages(&room, opts).await,
            None => Ok(Messages::default()),
        }
    }
}
#[cfg(all(not(target_arch = "wasm32"), test))]
mod tests {
use std::sync::Arc;
use assert_matches2::assert_let;
use futures_core::Future;
use futures_util::FutureExt as _;
use matrix_sdk_base::deserialized_responses::TimelineEvent;
use matrix_sdk_test::{async_test, event_factory::EventFactory};
use once_cell::sync::Lazy;
use ruma::{api::Direction, event_id, room_id, uint, user_id, EventId, RoomId, UInt, UserId};
use tokio::{
spawn,
sync::{Mutex, Notify},
task::AbortHandle,
};
use super::{PaginableRoom, PaginatorError, PaginatorState};
use crate::{
event_cache::paginator::Paginator,
room::{EventWithContextResponse, Messages, MessagesOptions},
test_utils::assert_event_matches_msg,
};
// Mock room for the paginator tests. Every piece of mutable state sits behind
// an `Arc`, so clones of a `TestRoom` share it and a test can reconfigure the
// room after handing a clone to the paginator.
#[derive(Clone)]
struct TestRoom {
// Factory used to build the target event returned by `event_with_context`.
event_factory: Arc<EventFactory>,
// When true, each request first waits for a `room_ready` notification.
wait_for_ready: bool,
// Text body of the target event.
target_event_text: Arc<Mutex<String>>,
// Events served after the target / by forward paginations.
next_events: Arc<Mutex<Vec<TimelineEvent>>>,
// Events served before the target / by backward paginations.
prev_events: Arc<Mutex<Vec<TimelineEvent>>>,
// Token returned for the backward direction.
prev_batch_token: Arc<Mutex<Option<String>>>,
// Token returned for the forward direction.
next_batch_token: Arc<Mutex<Option<String>>>,
// Notifier that releases a request blocked on `wait_for_ready`.
room_ready: Arc<Notify>,
}
impl TestRoom {
    // Builds a mock room with no events and no tokens. The event factory is
    // configured with the given sender and room id.
    fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self {
        let factory = EventFactory::default().sender(sender).room(room_id);

        Self {
            event_factory: Arc::new(factory),
            wait_for_ready,
            target_event_text: Arc::default(),
            next_events: Arc::default(),
            prev_events: Arc::default(),
            prev_batch_token: Arc::default(),
            next_batch_token: Arc::default(),
            room_ready: Arc::default(),
        }
    }

    // Lets one request blocked on `wait_for_ready` proceed.
    fn mark_ready(&self) {
        let notifier = &self.room_ready;
        notifier.notify_one();
    }
}
// Room id shared by all the tests, lazily initialised.
static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org"));
// Sender id shared by all the tests, lazily initialised.
static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es"));
impl PaginableRoom for TestRoom {
    // Serves the target event plus context. `num_events` is a shared budget:
    // it is spent first on the tail of `prev_events`, and whatever is left is
    // spent on the head of `next_events`.
    async fn event_with_context(
        &self,
        event_id: &EventId,
        _lazy_load_members: bool,
        num_events: UInt,
    ) -> Result<EventWithContextResponse, PaginatorError> {
        if self.wait_for_ready {
            self.room_ready.notified().await;
        }

        let text = self.target_event_text.lock().await.clone();
        let event = self.event_factory.text_msg(text).event_id(event_id).into_timeline();

        let mut remaining = u64::from(num_events) as usize;

        // An empty list yields an empty slice, so no special case is needed.
        let older = self.prev_events.lock().await;
        let take_before = remaining.min(older.len());
        remaining -= take_before;
        let events_before = older[older.len() - take_before..].to_vec();

        let newer = self.next_events.lock().await;
        let take_after = remaining.min(newer.len());
        let events_after = newer[..take_after].to_vec();

        Ok(EventWithContextResponse {
            event: Some(event),
            events_before,
            events_after,
            prev_batch_token: self.prev_batch_token.lock().await.clone(),
            next_batch_token: self.next_batch_token.lock().await.clone(),
            state: Vec::new(),
        })
    }

    // Serves up to `opts.limit` events: the tail of `prev_events` when going
    // backward, the head of `next_events` when going forward. The matching
    // batch token becomes the response's `end`. Panics if `opts.from` is
    // unset.
    async fn messages(&self, opts: MessagesOptions) -> Result<Messages, PaginatorError> {
        if self.wait_for_ready {
            self.room_ready.notified().await;
        }

        let limit = u64::from(opts.limit) as usize;

        let (end, chunk) = match opts.dir {
            Direction::Forward => {
                let newer = self.next_events.lock().await;
                let chunk = newer[..limit.min(newer.len())].to_vec();
                (self.next_batch_token.lock().await.clone(), chunk)
            }

            Direction::Backward => {
                let older = self.prev_events.lock().await;
                let first = older.len() - limit.min(older.len());
                let chunk = older[first..].to_vec();
                (self.prev_batch_token.lock().await.clone(), chunk)
            }
        };

        Ok(Messages { start: opts.from.unwrap(), end, chunk, state: Vec::new() })
    }
}
async fn assert_invalid_state<T: std::fmt::Debug>(
task: impl Future<Output = Result<T, PaginatorError>>,
expected: PaginatorState,
actual: PaginatorState,
) {
assert_let!(
Err(PaginatorError::InvalidPreviousState {
expected: real_expected,
actual: real_actual
}) = task.await
);
assert_eq!(real_expected, expected);
assert_eq!(real_actual, actual);
}
// `start_from` with a budget larger than the available context returns all
// ten events before, the target, and all ten events after, in that order.
#[async_test]
async fn test_start_from() {
    let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
    let event_id = event_id!("$yoyoyo");
    let event_factory = &room.event_factory;

    let make_event =
        |text: String| TimelineEvent::new(event_factory.text_msg(text).into_raw_timeline());

    *room.target_event_text.lock().await = "fetch_from".to_owned();

    // The room stores the "before" events newest-index first; `start_from`
    // reverses them.
    let before: Vec<_> = (0..10).rev().map(|i| make_event(format!("before-{i}"))).collect();
    *room.prev_events.lock().await = before;

    let after: Vec<_> = (0..10).map(|i| make_event(format!("after-{i}"))).collect();
    *room.next_events.lock().await = after;

    let paginator = Arc::new(Paginator::new(room.clone()));

    let context =
        paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

    // The room has no batch tokens configured.
    assert!(!context.has_prev);
    assert!(!context.has_next);

    assert_eq!(context.events.len(), 21);

    for (i, event) in context.events[..10].iter().enumerate() {
        assert_event_matches_msg(event, &format!("before-{i}"));
    }

    let target = &context.events[10];
    assert_event_matches_msg(target, "fetch_from");
    assert_eq!(target.raw().deserialize().unwrap().event_id(), event_id);

    for (i, event) in context.events[11..21].iter().enumerate() {
        assert_event_matches_msg(event, &format!("after-{i}"));
    }
}
// `start_from` with a budget of 10 and 100 events available before the target
// returns ten context events followed by the target.
#[async_test]
async fn test_start_from_with_num_events() {
    let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
    let event_id = event_id!("$yoyoyo");
    let event_factory = &room.event_factory;

    *room.target_event_text.lock().await = "fetch_from".to_owned();

    let before: Vec<_> = (0..100)
        .rev()
        .map(|i| TimelineEvent::new(event_factory.text_msg(format!("ev{i}")).into_raw_timeline()))
        .collect();
    *room.prev_events.lock().await = before;

    let paginator = Arc::new(Paginator::new(room.clone()));

    let context =
        paginator.start_from(event_id, uint!(10)).await.expect("start_from should work");

    assert_eq!(context.events.len(), 11);

    for (i, event) in context.events[..10].iter().enumerate() {
        assert_event_matches_msg(event, &format!("ev{i}"));
    }

    assert_event_matches_msg(&context.events[10], "fetch_from");
}
// Backward pagination: starts from a target event, then paginates backward
// three times while the test reconfigures the mock room between calls.
#[async_test]
async fn test_paginate_backward() {
let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
let event_id = event_id!("$yoyoyo");
let event_factory = &room.event_factory;
*room.target_event_text.lock().await = "initial".to_owned();
// Only a previous-batch token is configured; there is no next-batch token.
*room.prev_batch_token.lock().await = Some("prev".to_owned());
let paginator = Arc::new(Paginator::new(room.clone()));
// Before `start_from`, neither end is reported as reached.
assert!(!paginator.hit_timeline_start(), "we must have a prev-batch token");
assert!(
!paginator.hit_timeline_end(),
"we don't know about the status of the next-batch token"
);
let context =
paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
// The room has no context events yet, so only the target comes back.
assert_eq!(context.events.len(), 1);
assert_event_matches_msg(&context.events[0], "initial");
assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
assert!(context.has_prev);
assert!(!context.has_next);
// The missing next-batch token is recorded as having hit the end.
assert!(!paginator.hit_timeline_start());
assert!(paginator.hit_timeline_end());
// First backward pagination: one event, and a new token to continue with.
*room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_timeline()];
*room.prev_batch_token.lock().await = Some("prev2".to_owned());
let prev =
paginator.paginate_backward(uint!(100)).await.expect("paginate backward should work");
assert!(!prev.hit_end_of_timeline);
assert!(!paginator.hit_timeline_start());
assert_eq!(prev.events.len(), 1);
assert_event_matches_msg(&prev.events[0], "previous");
// Second backward pagination: the room returns no token, i.e. the start
// of the timeline is reached.
*room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_timeline()];
*room.prev_batch_token.lock().await = None;
let prev = paginator
.paginate_backward(uint!(100))
.await
.expect("paginate backward the second time should work");
assert!(prev.hit_end_of_timeline);
assert!(paginator.hit_timeline_start());
assert_eq!(prev.events.len(), 1);
assert_event_matches_msg(&prev.events[0], "oldest");
// Third backward pagination: succeeds with an empty result, even though
// the room would still serve the "oldest" event.
let prev = paginator
.paginate_backward(uint!(100))
.await
.expect("paginate backward the third time should work");
assert!(prev.hit_end_of_timeline);
assert!(paginator.hit_timeline_start());
assert!(prev.events.is_empty());
}
// Backward pagination honours the requested limit: with 100 events available
// and a limit of 10, exactly ten events come back.
#[async_test]
async fn test_paginate_backward_with_limit() {
    let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
    let event_id = event_id!("$yoyoyo");
    let event_factory = &room.event_factory;

    *room.target_event_text.lock().await = "initial".to_owned();
    *room.prev_batch_token.lock().await = Some("prev".to_owned());

    let paginator = Arc::new(Paginator::new(room.clone()));

    let context =
        paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");

    assert_eq!(context.events.len(), 1);
    let target = &context.events[0];
    assert_event_matches_msg(target, "initial");
    assert_eq!(target.raw().deserialize().unwrap().event_id(), event_id);
    assert!(context.has_prev);
    assert!(!context.has_next);

    // Fill the room only now, so the events are served by the pagination and
    // not by `start_from`.
    let older: Vec<_> = (0..100)
        .rev()
        .map(|i| TimelineEvent::new(event_factory.text_msg(format!("prev{i}")).into_raw_timeline()))
        .collect();
    *room.prev_events.lock().await = older;
    *room.prev_batch_token.lock().await = None;

    let prev =
        paginator.paginate_backward(uint!(10)).await.expect("paginate backward should work");

    assert!(prev.hit_end_of_timeline);
    assert_eq!(prev.events.len(), 10);

    // The mock serves the tail of its list: prev9 down to prev0.
    for (i, event) in prev.events.iter().enumerate() {
        assert_event_matches_msg(event, &format!("prev{}", 9 - i));
    }
}
// Forward pagination: starts from a target event, then paginates forward
// three times while the test reconfigures the mock room between calls.
#[async_test]
async fn test_paginate_forward() {
let room = TestRoom::new(false, *ROOM_ID, *USER_ID);
let event_id = event_id!("$yoyoyo");
let event_factory = &room.event_factory;
*room.target_event_text.lock().await = "initial".to_owned();
// Only a next-batch token is configured; there is no previous-batch token.
*room.next_batch_token.lock().await = Some("next".to_owned());
let paginator = Arc::new(Paginator::new(room.clone()));
// Before `start_from`, neither end is reported as reached.
assert!(!paginator.hit_timeline_end(), "we must have a next-batch token");
assert!(
!paginator.hit_timeline_start(),
"we don't know about the status of the prev-batch token"
);
let context =
paginator.start_from(event_id, uint!(100)).await.expect("start_from should work");
// The room has no context events yet, so only the target comes back.
assert_eq!(context.events.len(), 1);
assert_event_matches_msg(&context.events[0], "initial");
assert_eq!(context.events[0].raw().deserialize().unwrap().event_id(), event_id);
assert!(!context.has_prev);
assert!(context.has_next);
// The missing previous-batch token is recorded as having hit the start.
assert!(paginator.hit_timeline_start());
assert!(!paginator.hit_timeline_end());
// First forward pagination: one event, and a new token to continue with.
*room.next_events.lock().await = vec![event_factory.text_msg("next").into_timeline()];
*room.next_batch_token.lock().await = Some("next2".to_owned());
let next =
paginator.paginate_forward(uint!(100)).await.expect("paginate forward should work");
assert!(!next.hit_end_of_timeline);
assert_eq!(next.events.len(), 1);
assert_event_matches_msg(&next.events[0], "next");
assert!(!paginator.hit_timeline_end());
// Second forward pagination: the room returns no token, i.e. the end of
// the timeline is reached.
*room.next_events.lock().await = vec![event_factory.text_msg("latest").into_timeline()];
*room.next_batch_token.lock().await = None;
let next = paginator
.paginate_forward(uint!(100))
.await
.expect("paginate forward the second time should work");
assert!(next.hit_end_of_timeline);
assert_eq!(next.events.len(), 1);
assert_event_matches_msg(&next.events[0], "latest");
assert!(paginator.hit_timeline_end());
// Third forward pagination: succeeds with an empty result, even though
// the room would still serve the "latest" event.
let next = paginator
.paginate_forward(uint!(100))
.await
.expect("paginate forward the third time should work");
assert!(next.hit_end_of_timeline);
assert!(next.events.is_empty());
assert!(paginator.hit_timeline_end());
}
// Walks the paginator through its state machine and checks, at each step,
// both the state updates seen by a subscriber and the errors returned by
// operations that are not allowed in the current state.
#[async_test]
async fn test_state() {
// `wait_for_ready = true`: every room request blocks until `mark_ready()`.
let room = TestRoom::new(true, *ROOM_ID, *USER_ID);
*room.prev_batch_token.lock().await = Some("prev".to_owned());
*room.next_batch_token.lock().await = Some("next".to_owned());
let paginator = Arc::new(Paginator::new(room.clone()));
let event_id = event_id!("$yoyoyo");
let mut state = paginator.state();
// Starts in `Initial`, with no pending update for the subscriber.
assert_eq!(state.get(), PaginatorState::Initial);
assert!(state.next().now_or_never().is_none());
// Paginating is rejected in `Initial`, and the rejection emits no update.
assert_invalid_state(
paginator.paginate_backward(uint!(100)),
PaginatorState::Idle,
PaginatorState::Initial,
)
.await;
assert!(state.next().now_or_never().is_none());
// Run `start_from` in a separate task; it blocks inside the room request.
let p = paginator.clone();
let join_handle = spawn(async move { p.start_from(event_id, uint!(100)).await });
assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
assert!(state.next().now_or_never().is_none());
// While the target event is being fetched, both operations are rejected.
assert_invalid_state(
paginator.start_from(event_id, uint!(100)),
PaginatorState::Initial,
PaginatorState::FetchingTargetEvent,
)
.await;
assert_invalid_state(
paginator.paginate_backward(uint!(100)),
PaginatorState::Idle,
PaginatorState::FetchingTargetEvent,
)
.await;
assert!(state.next().now_or_never().is_none());
// Release the blocked request: `start_from` finishes and the state is `Idle`.
room.mark_ready();
assert_eq!(state.next().await, Some(PaginatorState::Idle));
join_handle.await.expect("joined failed").expect("/context failed");
assert!(state.next().now_or_never().is_none());
// Run a backward pagination in a separate task; it blocks the same way.
let p = paginator.clone();
let join_handle = spawn(async move { p.paginate_backward(uint!(100)).await });
assert_eq!(state.next().await, Some(PaginatorState::Paginating));
// While paginating, every operation is rejected.
assert_invalid_state(
paginator.start_from(event_id, uint!(100)),
PaginatorState::Initial,
PaginatorState::Paginating,
)
.await;
assert_invalid_state(
paginator.paginate_backward(uint!(100)),
PaginatorState::Idle,
PaginatorState::Paginating,
)
.await;
assert_invalid_state(
paginator.paginate_forward(uint!(100)),
PaginatorState::Idle,
PaginatorState::Paginating,
)
.await;
assert!(state.next().now_or_never().is_none());
// Release the blocked request: the pagination finishes, back to `Idle`.
room.mark_ready();
assert_eq!(state.next().await, Some(PaginatorState::Idle));
join_handle.await.expect("joined failed").expect("/messages failed");
assert!(state.next().now_or_never().is_none());
}
mod aborts {
use super::*;
// Mock room whose requests never complete: once notified, a request aborts
// the task registered in `abort_handle` and then yields forever.
#[derive(Clone, Default)]
struct AbortingRoom {
// Handle of the task to abort; the test must set it before notifying.
abort_handle: Arc<Mutex<Option<AbortHandle>>>,
// Notifier that lets a pending request go ahead with the abort.
room_ready: Arc<Notify>,
}
impl AbortingRoom {
// Waits for a `room_ready` notification, aborts the registered task, then
// yields in a loop and never returns. Panics if no abort handle has been
// registered.
async fn wait_abort_and_yield(&self) -> ! {
self.room_ready.notified().await;
let mut guard = self.abort_handle.lock().await;
let handle = guard.take().expect("only call me when i'm initialized");
handle.abort();
// Keep yielding so the runtime gets a chance to act on the abort; the
// tests assert that the join handle then reports a cancellation.
loop {
tokio::task::yield_now().await;
}
}
}
// Both requests ignore their arguments and delegate to
// `wait_abort_and_yield`, which never returns.
impl PaginableRoom for AbortingRoom {
async fn event_with_context(
&self,
_event_id: &EventId,
_lazy_load_members: bool,
_num_events: UInt,
) -> Result<EventWithContextResponse, PaginatorError> {
self.wait_abort_and_yield().await
}
async fn messages(&self, _opts: MessagesOptions) -> Result<Messages, PaginatorError> {
self.wait_abort_and_yield().await
}
}
// Aborting the task that runs `start_from` puts the paginator back in the
// `Initial` state.
#[async_test]
async fn test_abort_while_starting_from() {
let room = AbortingRoom::default();
let paginator = Arc::new(Paginator::new(room.clone()));
let mut state = paginator.state();
assert_eq!(state.get(), PaginatorState::Initial);
assert!(state.next().now_or_never().is_none());
let p = paginator.clone();
let join_handle = spawn(async move {
let _ = p.start_from(event_id!("$yoyoyo"), uint!(100)).await;
});
// Register the handle the room will use to abort the task above.
*room.abort_handle.lock().await = Some(join_handle.abort_handle());
assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent));
assert!(state.next().now_or_never().is_none());
// Let the pending room request proceed: it aborts the task.
room.room_ready.notify_one();
let join_result = join_handle.await;
assert!(join_result.unwrap_err().is_cancelled());
// The state was reset when the aborted future was dropped.
assert_eq!(state.next().await, Some(PaginatorState::Initial));
assert!(state.next().now_or_never().is_none());
}
// Aborting the task that runs a pagination puts the paginator back in the
// `Idle` state, for both the backward and the forward direction.
#[async_test]
async fn test_abort_while_paginating() {
let room = AbortingRoom::default();
let paginator = Paginator::new(room.clone());
// Skip `start_from`: go straight to `Idle` with tokens in both directions.
paginator
.set_idle_state(
PaginatorState::Idle,
Some("prev".to_owned()),
Some("next".to_owned()),
)
.unwrap();
let paginator = Arc::new(paginator);
let mut state = paginator.state();
assert_eq!(state.get(), PaginatorState::Idle);
assert!(state.next().now_or_never().is_none());
// Backward pagination, aborted while the room request is pending.
let p = paginator.clone();
let join_handle = spawn(async move {
let _ = p.paginate_backward(uint!(100)).await;
});
*room.abort_handle.lock().await = Some(join_handle.abort_handle());
assert_eq!(state.next().await, Some(PaginatorState::Paginating));
assert!(state.next().now_or_never().is_none());
room.room_ready.notify_one();
let join_result = join_handle.await;
assert!(join_result.unwrap_err().is_cancelled());
assert_eq!(state.next().await, Some(PaginatorState::Idle));
assert!(state.next().now_or_never().is_none());
// Forward pagination, aborted the same way.
let p = paginator.clone();
let join_handle = spawn(async move {
let _ = p.paginate_forward(uint!(100)).await;
});
*room.abort_handle.lock().await = Some(join_handle.abort_handle());
assert_eq!(state.next().await, Some(PaginatorState::Paginating));
assert!(state.next().now_or_never().is_none());
room.room_ready.notify_one();
let join_result = join_handle.await;
assert!(join_result.unwrap_err().is_cancelled());
assert_eq!(state.next().await, Some(PaginatorState::Idle));
assert!(state.next().now_or_never().is_none());
}
}
}