1use std::sync::Arc;
16
17use eyeball::{ObservableWriteGuard, SharedObservable, Subscriber};
18use eyeball_im::{ObservableVector, VectorDiff, VectorSubscriberBatchedStream};
19use futures_util::future::join_all;
20use imbl::Vector;
21use matrix_sdk::{
22 Result, Room,
23 deserialized_responses::TimelineEvent,
24 event_cache::{RoomEventCacheSubscriber, RoomEventCacheUpdate},
25 locks::Mutex,
26 paginators::PaginationToken,
27 room::ListThreadsOptions,
28 task_monitor::BackgroundTaskHandle,
29};
30use matrix_sdk_common::serde_helpers::extract_thread_root;
31use ruma::{MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedUserId};
32use tokio::sync::Mutex as AsyncMutex;
33use tracing::{error, trace, warn};
34
35use crate::timeline::{Profile, TimelineDetails, TimelineItemContent, traits::RoomDataProvider};
36
/// An entry of the thread list: a thread root together with what is known about its replies.
#[derive(Clone, Debug)]
pub struct ThreadListItem {
    /// The event that started the thread.
    pub root_event: ThreadListItemEvent,

    /// The latest reply in the thread, if known.
    ///
    /// Initially built from the latest thread event bundled with the root by the server (when
    /// present), then replaced by the service's event cache listener when replies to the thread
    /// are observed.
    pub latest_event: Option<ThreadListItemEvent>,

    /// The number of replies in the thread.
    ///
    /// Initially taken from the root's thread summary (0 when there is none), then incremented
    /// by the service's event cache listener when replies to the thread are observed.
    pub num_replies: u32,
}
64
/// A single event shown in the thread list (a thread root or a thread reply), with the
/// information needed to display it.
#[derive(Clone, Debug)]
pub struct ThreadListItemEvent {
    /// The ID of the event.
    pub event_id: OwnedEventId,

    /// The timestamp of the event, as reported by `TimelineEvent::timestamp`.
    pub timestamp: MilliSecondsSinceUnixEpoch,

    /// The sender of the event.
    pub sender: OwnedUserId,

    /// Whether the event was sent by the current user.
    pub is_own: bool,

    /// The sender's profile, loaded when this event was built.
    pub sender_profile: TimelineDetails<Profile>,

    /// The displayable content of the event; `None` when `TimelineItemContent::from_event`
    /// produced no content for it.
    pub content: Option<TimelineItemContent>,
}
91
/// The pagination state of a [`ThreadListService`].
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ThreadListPaginationState {
    /// No pagination request is in progress.
    Idle {
        /// Whether the last page of thread roots has been loaded. While `true`,
        /// [`ThreadListService::paginate`] does nothing.
        end_reached: bool,
    },
    /// A pagination request is in progress.
    Loading,
}
105
/// Errors returned by [`ThreadListService`].
#[derive(Debug, thiserror::Error)]
pub enum ThreadListServiceError {
    /// An error reported by the Matrix SDK, e.g. when listing the room's threads fails.
    #[error(transparent)]
    Sdk(#[from] matrix_sdk::Error),
}
113
/// A paginated, observable list of the threads of a room.
///
/// Thread roots are loaded page by page with [`ThreadListService::paginate`]; a background task
/// listens to the room's event cache to keep already-listed threads up to date.
pub struct ThreadListService {
    /// The room whose threads are listed.
    room: Room,

    /// The token for the next pagination request. Locked by both pagination and reset, which
    /// serializes them.
    token: AsyncMutex<PaginationToken>,

    /// The current pagination state, observable through subscriptions.
    pagination_state: SharedObservable<ThreadListPaginationState>,

    /// The items loaded so far, shared with the event cache listener task.
    items: Arc<Mutex<ObservableVector<ThreadListItem>>>,

    /// Handle to the event cache listener task; the task is aborted when this is dropped.
    _event_cache_task: BackgroundTaskHandle,
}
165
166impl ThreadListService {
167 pub fn new(room: Room) -> Self {
173 let items: Arc<Mutex<ObservableVector<ThreadListItem>>> =
174 Arc::new(Mutex::new(ObservableVector::new()));
175
176 if let Err(e) = room.client().event_cache().subscribe() {
179 warn!("ThreadListService: failed to subscribe event cache to sync: {e}");
180 }
181
182 let event_cache_task = room
183 .client()
184 .task_monitor()
185 .spawn_infinite_task("thread_list_service::event_cache_listener", {
186 let room = room.clone();
187 let items = items.clone();
188 async move {
189 let (_event_cache_drop, mut subscriber) = match async {
191 let (room_event_cache, drop_handles) = room.event_cache().await?;
192 let (_, subscriber) = room_event_cache.subscribe().await?;
193 matrix_sdk::event_cache::Result::Ok((drop_handles, subscriber))
194 }
195 .await
196 {
197 Ok(pair) => pair,
198 Err(e) => {
199 error!(
200 "ThreadListService: failed to subscribe to room event cache, \
201 live updates will not work: {e}"
202 );
203 return;
204 }
205 };
206
207 trace!("ThreadListService: event cache listener started");
208
209 Self::event_cache_listener_loop(&room, &mut subscriber, items).await;
210 }
211 })
212 .abort_on_drop();
213
214 Self {
215 room,
216 token: AsyncMutex::new(PaginationToken::None),
217 pagination_state: SharedObservable::new(ThreadListPaginationState::Idle {
218 end_reached: false,
219 }),
220 items,
221 _event_cache_task: event_cache_task,
222 }
223 }
224
225 pub fn pagination_state(&self) -> ThreadListPaginationState {
227 self.pagination_state.get()
228 }
229
230 pub fn subscribe_to_pagination_state_updates(&self) -> Subscriber<ThreadListPaginationState> {
235 self.pagination_state.subscribe()
236 }
237
238 pub fn items(&self) -> Vec<ThreadListItem> {
240 self.items.lock().iter().cloned().collect()
241 }
242
243 pub fn subscribe_to_items_updates(
248 &self,
249 ) -> (Vector<ThreadListItem>, VectorSubscriberBatchedStream<ThreadListItem>) {
250 self.items.lock().subscribe().into_values_and_batched_stream()
251 }
252
253 pub async fn paginate(&self) -> Result<(), ThreadListServiceError> {
261 {
263 let mut pagination_state = self.pagination_state.write();
264
265 match *pagination_state {
266 ThreadListPaginationState::Idle { end_reached: true }
267 | ThreadListPaginationState::Loading => return Ok(()),
268 _ => {}
269 }
270
271 ObservableWriteGuard::set(&mut pagination_state, ThreadListPaginationState::Loading);
272 }
273
274 let mut pagination_token = self.token.lock().await;
275
276 let from = match &*pagination_token {
278 PaginationToken::HasMore(token) => Some(token.clone()),
279 _ => None,
280 };
281
282 let opts = ListThreadsOptions { from, ..Default::default() };
283
284 match self.load_thread_list(opts).await {
285 Ok(thread_list) => {
286 *pagination_token = match &thread_list.prev_batch_token {
288 Some(token) => PaginationToken::HasMore(token.clone()),
289 None => PaginationToken::HitEnd,
290 };
291
292 let end_reached = thread_list.prev_batch_token.is_none();
293
294 self.items.lock().append(thread_list.items.into());
296
297 self.pagination_state.set(ThreadListPaginationState::Idle { end_reached });
298
299 Ok(())
300 }
301 Err(err) => {
302 self.pagination_state.set(ThreadListPaginationState::Idle { end_reached: false });
303 Err(ThreadListServiceError::Sdk(err))
304 }
305 }
306 }
307
308 pub async fn reset(&self) {
315 let mut pagination_token = self.token.lock().await;
316 *pagination_token = PaginationToken::None;
317
318 self.items.lock().clear();
319
320 self.pagination_state.set(ThreadListPaginationState::Idle { end_reached: false });
321 }
322
323 async fn load_thread_list(&self, opts: ListThreadsOptions) -> Result<ThreadList> {
324 let thread_roots = self.room.list_threads(opts).await?;
325
326 let list_items = join_all(
327 thread_roots
328 .chunk
329 .into_iter()
330 .map(|timeline_event| Self::build_thread_list_item(&self.room, timeline_event))
331 .collect::<Vec<_>>(),
332 )
333 .await
334 .into_iter()
335 .flatten()
336 .collect();
337
338 Ok(ThreadList { items: list_items, prev_batch_token: thread_roots.prev_batch_token })
339 }
340
341 async fn build_thread_list_item(
342 room: &Room,
343 timeline_event: TimelineEvent,
344 ) -> Option<ThreadListItem> {
345 let thread_summary = timeline_event.thread_summary.summary().cloned();
347 let bundled_latest_thread_event = timeline_event.bundled_latest_thread_event.clone();
348
349 let root_event = Self::build_event(room, timeline_event).await?;
351
352 let num_replies = thread_summary.as_ref().map(|s| s.num_replies).unwrap_or(0);
354
355 let latest_event = if let Some(ev) = bundled_latest_thread_event.map(|b| *b) {
356 Self::build_event(room, ev).await
357 } else {
358 None
359 };
360
361 Some(ThreadListItem { root_event, latest_event, num_replies })
362 }
363
364 async fn build_event(
366 room: &Room,
367 timeline_event: TimelineEvent,
368 ) -> Option<ThreadListItemEvent> {
369 let event_id = timeline_event.event_id()?;
370 let timestamp = timeline_event.timestamp()?;
371 let sender = timeline_event.sender()?;
372 let is_own = room.own_user_id() == sender;
373 let sender_profile =
374 TimelineDetails::from_initial_value(Profile::load(room, &sender).await);
375 let content = TimelineItemContent::from_event(room, timeline_event).await;
376 Some(ThreadListItemEvent { event_id, timestamp, sender, is_own, sender_profile, content })
377 }
378
379 async fn event_cache_listener_loop(
385 room: &Room,
386 subscriber: &mut RoomEventCacheSubscriber,
387 items: Arc<Mutex<ObservableVector<ThreadListItem>>>,
388 ) {
389 use tokio::sync::broadcast::error::RecvError;
390
391 loop {
392 let update = match subscriber.recv().await {
393 Ok(update) => update,
394 Err(RecvError::Closed) => {
395 error!("ThreadListService: event cache channel closed, stopping listener");
396 break;
397 }
398 Err(RecvError::Lagged(n)) => {
399 warn!("ThreadListService: lagged behind {n} event cache updates");
400 continue;
401 }
402 };
403
404 if let RoomEventCacheUpdate::UpdateTimelineEvents(timeline_diffs) = update {
405 let new_events = Self::collect_events_from_diffs(timeline_diffs.diffs);
406
407 for event in new_events {
408 let Some(thread_root) = extract_thread_root(event.raw()) else { continue };
410
411 let position = {
413 let guard = items.lock();
414 guard.iter().position(|item| item.root_event.event_id == thread_root)
415 };
416
417 if let Some(index) = position {
418 if let Some(latest_event) = Self::build_event(room, event).await {
420 let mut guard = items.lock();
421
422 if index < guard.len()
425 && guard[index].root_event.event_id == thread_root
426 {
427 let mut updated = guard[index].clone();
428 updated.latest_event = Some(latest_event);
429 updated.num_replies = updated.num_replies.saturating_add(1);
430 guard.set(index, updated);
431 }
432 }
433 }
434 }
435 }
436 }
437 }
438
439 fn collect_events_from_diffs(
441 diffs: Vec<VectorDiff<matrix_sdk_base::event_cache::Event>>,
442 ) -> Vec<matrix_sdk_base::event_cache::Event> {
443 let mut events = Vec::new();
444
445 for diff in diffs {
446 match diff {
447 VectorDiff::Append { values } => events.extend(values),
448 VectorDiff::PushBack { value }
449 | VectorDiff::PushFront { value }
450 | VectorDiff::Insert { value, .. }
451 | VectorDiff::Set { value, .. } => events.push(value),
452 VectorDiff::Reset { values } => events.extend(values),
453 VectorDiff::Clear
455 | VectorDiff::PopBack
456 | VectorDiff::PopFront
457 | VectorDiff::Remove { .. }
458 | VectorDiff::Truncate { .. } => {}
459 }
460 }
461
462 events
463 }
464}
465
/// One page of thread list items, as produced by `ThreadListService::load_thread_list`.
#[derive(Clone, Debug)]
struct ThreadList {
    /// The items built from this page's thread roots.
    pub items: Vec<ThreadListItem>,

    /// The token to request the next page, or `None` if this was the last page.
    pub prev_batch_token: Option<String>,
}
476
// Tests for `ThreadListService` pagination, reset and state observation, run against a mocked
// homeserver.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_util::pin_mut;
    use matrix_sdk::test_utils::mocks::MatrixMockServer;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{event_id, events::AnyTimelineEvent, room_id, serde::Raw, user_id};
    use serde_json::json;
    use stream_assert::{assert_next_matches, assert_pending};
    use wiremock::ResponseTemplate;

    use super::{ThreadListPaginationState, ThreadListService};

    // A freshly created service is idle, not at the end, and has no items.
    #[async_test]
    async fn test_initial_state() {
        let server = MatrixMockServer::new().await;
        let service = make_service(&server).await;

        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );
        assert!(service.items().is_empty());
    }

    // Two pages: the first returns a token, the second (requested with that token) returns
    // none, which marks the end of the list.
    #[async_test]
    async fn test_pagination() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");

        let f = EventFactory::new().room(room_id).sender(sender_id);

        let eid1 = event_id!("$1");
        let eid2 = event_id!("$2");

        server
            .mock_room_threads()
            .ok(
                vec![f.text_msg("Thread root 1").event_id(eid1).into_raw()],
                Some("next_page_token".to_owned()),
            )
            .mock_once()
            .mount()
            .await;

        server
            .mock_room_threads()
            .match_from("next_page_token")
            .ok(vec![f.text_msg("Thread root 2").event_id(eid2).into_raw()], None)
            .mock_once()
            .mount()
            .await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        service.paginate().await.expect("first paginate failed");

        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );
        assert_eq!(service.items().len(), 1);
        assert_eq!(service.items()[0].root_event.event_id, eid1);

        service.paginate().await.expect("second paginate failed");

        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
        assert_eq!(service.items().len(), 2);
        assert_eq!(service.items()[1].root_event.event_id, eid2);
    }

    // Once the end is reached, further pagination is a successful no-op (the mock only
    // answers once, so a second request would not be served).
    #[async_test]
    async fn test_pagination_end_reached() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");

        server
            .mock_room_threads()
            .ok(vec![f.text_msg("Thread root").event_id(eid1).into_raw()], None)
            .mock_once()
            .mount()
            .await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        service.paginate().await.expect("paginate failed");
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
        assert_eq!(service.items().len(), 1);

        service.paginate().await.expect("second paginate should be a no-op");
        assert_eq!(service.items().len(), 1);
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
    }

    // Two concurrent `paginate` calls result in a single request (`expect(1)`): the second
    // call sees the `Loading` state and returns immediately.
    #[async_test]
    async fn test_concurrent_pagination_is_not_possible() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");

        // The response is delayed so that both calls overlap.
        let chunk: Vec<Raw<AnyTimelineEvent>> =
            vec![f.text_msg("Thread root").event_id(eid1).into_raw()];
        server
            .mock_room_threads()
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(json!({ "chunk": chunk, "next_batch": null }))
                    .set_delay(Duration::from_millis(100)),
            )
            .expect(1)
            .mount()
            .await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        let (first, second) = tokio::join!(service.paginate(), service.paginate());

        first.expect("first paginate should succeed");
        second.expect("second (concurrent) paginate should succeed as a no-op");

        // Only one page was appended.
        assert_eq!(service.items().len(), 1);
        assert_eq!(service.items()[0].root_event.event_id, eid1);
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
    }

    // A server error is returned to the caller and leaves the service idle (not at the end)
    // with no items.
    #[async_test]
    async fn test_pagination_error() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");

        server.mock_room_threads().error500().mock_once().mount().await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        service.paginate().await.expect_err("paginate should fail on a 500 response");

        // The state must not stay `Loading`, so that pagination can be retried.
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );

        assert!(service.items().is_empty());
    }

    // `reset` clears items and state; paginating again issues a fresh request (the mock
    // expects exactly two).
    #[async_test]
    async fn test_reset() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");

        server
            .mock_room_threads()
            .ok(vec![f.text_msg("Thread root").event_id(eid1).into_raw()], None)
            .expect(2)
            .mount()
            .await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        service.paginate().await.expect("first paginate failed");
        assert_eq!(service.items().len(), 1);
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );

        service.reset().await;
        assert!(service.items().is_empty());
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );

        service.paginate().await.expect("paginate after reset failed");
        assert_eq!(service.items().len(), 1);
    }

    // A pagination-state subscriber is notified of the change caused by `paginate`. The
    // intermediate `Loading` state is not asserted: the next value observed after `paginate`
    // completes is the final `Idle` state.
    #[async_test]
    async fn test_pagination_state_subscriber() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");

        server
            .mock_room_threads()
            .ok(
                vec![f.text_msg("Thread root").event_id(eid1).into_raw()],
                Some("next_token".to_owned()),
            )
            .mock_once()
            .mount()
            .await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        let subscriber = service.subscribe_to_pagination_state_updates();
        pin_mut!(subscriber);

        assert_pending!(subscriber);

        service.paginate().await.expect("paginate failed");

        assert_next_matches!(subscriber, ThreadListPaginationState::Idle { end_reached: false });
    }

    // Without a bundled thread summary, an item has zero replies and no latest event.
    #[async_test]
    async fn test_paginated_items_have_num_replies_zero_without_summary() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");

        // The root is a plain message, with no bundled thread summary.
        server
            .mock_room_threads()
            .ok(vec![f.text_msg("Thread root").event_id(eid1).into_raw()], None)
            .mock_once()
            .mount()
            .await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        service.paginate().await.expect("paginate failed");

        let items = service.items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].num_replies, 0);
        assert!(items[0].latest_event.is_none());
    }

    // With a bundled thread summary, the reply count and latest event come from it.
    #[async_test]
    async fn test_paginated_items_have_num_replies_from_bundled_summary() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let root_id = event_id!("$root");
        let reply_id = event_id!("$reply");

        // The reply that is bundled as the thread's latest event.
        let reply_event =
            f.text_msg("Reply in thread").event_id(reply_id).into_raw_sync().cast_unchecked();

        // The root carries a bundled summary with `reply_event` as latest event and 3 replies.
        let thread_root = f
            .text_msg("Thread root")
            .event_id(root_id)
            .with_bundled_thread_summary(reply_event, 3, false)
            .into_raw();

        server.mock_room_threads().ok(vec![thread_root], None).mock_once().mount().await;

        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);

        service.paginate().await.expect("paginate failed");

        let items = service.items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].root_event.event_id, root_id);
        assert_eq!(items[0].num_replies, 3);

        let latest = items[0].latest_event.as_ref().expect("should have latest_event");
        assert_eq!(latest.event_id, reply_id);
        assert_eq!(latest.sender.as_str(), sender_id.as_str());
    }

    // Builds a service for a freshly synced joined room on `server`, with no threads mocked.
    async fn make_service(server: &MatrixMockServer) -> ThreadListService {
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let room = server.sync_joined_room(&client, room_id).await;
        ThreadListService::new(room)
    }
}