1use std::sync::Arc;
33
34use eyeball_im::VectorDiff;
35use matrix_sdk_base::{
36 deserialized_responses::TimelineEvent,
37 event_cache::{Event, Gap},
38 linked_chunk::OwnedLinkedChunkId,
39};
40use matrix_sdk_common::{linked_chunk::ChunkIdentifier, serde_helpers::extract_thread_root};
41use ruma::{OwnedEventId, UInt, api::Direction};
42use tokio::sync::broadcast::{Receiver, Sender};
43use tracing::{instrument, trace};
44
45#[cfg(feature = "e2e-encryption")]
46use super::super::redecryptor::ResolvedUtd;
47use super::{
48 super::{
49 EventCacheError, EventsOrigin, Result, RoomEventCacheLinkedChunkUpdate,
50 states::{
51 CacheStateLock, ReloadPreprocessing, StateLock, selectors::EventFocusedStateSelector,
52 },
53 },
54 TimelineVectorDiffs,
55 event_linked_chunk::EventLinkedChunk,
56};
57use crate::{
58 Room,
59 paginators::{PaginationResult, Paginator, StartFromResult, thread::PaginableThread},
60 room::{IncludeRelations, MessagesOptions, RelationsOptions, WeakRoom},
61};
62
/// How thread membership is resolved when a cache is focused on an event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventFocusThreadMode {
    /// Always paginate as a thread: the thread root is the one referenced by
    /// the focused event, or the focused event itself if it references none.
    ForceThread,

    /// Paginate as a thread only when the focused event references a thread
    /// root; otherwise paginate the room and leave in-thread events out.
    Automatic,
}
82
83#[derive(Debug, Clone)]
85pub(crate) enum EventFocusedPaginationMode {
86 Room { hide_thread_events: bool },
89
90 Thread {
92 thread_root: OwnedEventId,
94 },
95}
96
97pub struct EventFocusedCacheState {
98 room: WeakRoom,
100
101 focused_event_id: OwnedEventId,
103
104 pagination_mode: EventFocusedPaginationMode,
106
107 chunk: EventLinkedChunk,
109
110 initial_num_context_events: u16,
114
115 thread_mode: EventFocusThreadMode,
117
118 pub update_sender: EventFocusedCacheUpdateSender,
120
121 linked_chunk_update_sender: Sender<RoomEventCacheLinkedChunkUpdate>,
123}
124
125impl EventFocusedCacheState {
126 #[instrument(skip(self), fields(room_id = %self.room.room_id(), event_id = %self.focused_event_id))]
137 async fn start_from(
138 &mut self,
139 num_context_events: u16,
140 thread_mode: EventFocusThreadMode,
141 ) -> Result<StartFromResult> {
142 self.initial_num_context_events = num_context_events;
143 self.thread_mode = thread_mode;
144
145 let result = self.reload_impl().await?;
146
147 let _ = self.chunk.updates_as_vector_diffs();
155
156 Ok(result)
157 }
158
159 #[must_use = "Propagate `VectorDiff` updates via `TimelineVectorDiffs`"]
165 pub async fn reload(
166 &mut self,
167 _preprocessing: ReloadPreprocessing,
168 ) -> Result<Vec<VectorDiff<Event>>> {
169 let _ = self.reload_impl().await?;
170
171 Ok(self.chunk.updates_as_vector_diffs())
172 }
173
174 async fn reload_impl(&mut self) -> Result<StartFromResult> {
177 let room = self.room.get().ok_or(EventCacheError::ClientDropped)?;
178 let num_context_events = self.initial_num_context_events;
179 let thread_mode = self.thread_mode;
180
181 trace!(num_context_events, "fetching event with context via /context");
182
183 let paginator = Paginator::new(room);
184
185 let result =
186 paginator.start_from(&self.focused_event_id, UInt::from(num_context_events)).await?;
187
188 let thread_root = match thread_mode {
190 EventFocusThreadMode::ForceThread => {
191 let focused_event = result
193 .events
194 .iter()
195 .find(|event| event.event_id() == Some(&self.focused_event_id));
196
197 let mut thread_root =
199 focused_event.and_then(|event| extract_thread_root(event.raw()));
200
201 if thread_root.is_none() {
204 thread_root = Some(self.focused_event_id.clone());
205 }
206
207 trace!("force thread mode enabled, treating focused event as thread root");
208 thread_root
209 }
210
211 EventFocusThreadMode::Automatic => {
212 trace!(
213 "automatic thread mode enabled, checking if focused event is part of a thread"
214 );
215 result
216 .events
217 .iter()
218 .find(|event| event.event_id() == Some(&self.focused_event_id))
219 .and_then(|event| extract_thread_root(event.raw()))
220 }
221 };
222
223 let tokens = paginator.tokens();
225
226 if let Some(root_id) = thread_root {
227 trace!(thread_root = %root_id, "focused event is part of a thread, setting up thread pagination");
228
229 let includes_root =
233 result.events.iter().any(|event| event.event_id() == Some(&root_id));
234
235 self.pagination_mode =
236 EventFocusedPaginationMode::Thread { thread_root: root_id.clone() };
237
238 let thread_events = result
240 .events
241 .iter()
242 .filter(|event| {
243 extract_thread_root(event.raw()).as_ref() == Some(&root_id)
244 || event.event_id() == Some(&root_id)
245 })
246 .cloned()
247 .collect();
248
249 let backward_token = if includes_root {
251 None
253 } else {
254 tokens.previous.into_token()
255 };
256
257 let forward_token = tokens.next.into_token();
259
260 self.add_initial_events_with_gaps(thread_events, backward_token, forward_token);
261 } else {
262 trace!("focused event is not part of a thread, setting up room pagination");
263
264 let backward_token = tokens.previous.into_token();
265 let forward_token = tokens.next.into_token();
266
267 let hide_thread_events =
268 matches!(thread_mode, EventFocusThreadMode::Automatic) && thread_root.is_none();
269
270 self.pagination_mode = EventFocusedPaginationMode::Room { hide_thread_events };
271
272 let events = if hide_thread_events {
273 result
274 .events
275 .iter()
276 .filter(|event| extract_thread_root(event.raw()).is_none())
277 .cloned()
278 .collect()
279 } else {
280 result.events.clone()
281 };
282
283 self.add_initial_events_with_gaps(events, backward_token, forward_token);
284 }
285
286 self.propagate_changes();
287
288 Ok(result)
289 }
290
291 fn add_initial_events_with_gaps(
293 &mut self,
294 events: Vec<TimelineEvent>,
295 prev_gap_token: Option<String>,
296 next_gap_token: Option<String>,
297 ) {
298 self.chunk.reset();
300
301 self.chunk
304 .push_live_events(prev_gap_token.map(|prev_token| Gap { token: prev_token }), &events);
305
306 if let Some(next_token) = next_gap_token {
308 trace!("inserting forward pagination gap at back");
309 self.chunk.push_gap(Gap { token: next_token });
310 }
311 }
312
313 fn propagate_changes(&mut self) {
315 let updates = self.chunk.store_updates().take();
316 if !updates.is_empty() {
317 let _ = self.linked_chunk_update_sender.send(RoomEventCacheLinkedChunkUpdate {
318 updates,
319 linked_chunk_id: OwnedLinkedChunkId::EventFocused(
320 self.room.room_id().to_owned(),
321 self.focused_event_id.clone(),
322 ),
323 });
324 }
325 }
326
327 fn notify_subscribers(&mut self, origin: EventsOrigin) {
329 let diffs = self.chunk.updates_as_vector_diffs();
330 if !diffs.is_empty() {
331 let _ = self.update_sender.send(TimelineVectorDiffs { diffs, origin });
332 }
333 }
334
335 fn first_chunk_as_gap(&self) -> Option<(ChunkIdentifier, Gap)> {
337 self.chunk.first_chunk_as_gap()
338 }
339
340 fn last_chunk_as_gap(&self) -> Option<(ChunkIdentifier, Gap)> {
342 self.chunk.last_chunk_as_gap()
343 }
344
345 #[instrument(skip(self), fields(room_id = %self.room.room_id()))]
351 async fn paginate_backwards(&mut self, num_events: u16) -> Result<PaginationResult> {
352 let room = self.room.get().ok_or(EventCacheError::ClientDropped)?;
353
354 let Some((gap_id, gap)) = self.first_chunk_as_gap() else {
356 trace!("no front gap found, already at timeline start");
358 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
359 };
360
361 let token = gap.token;
362 trace!(?token, "paginating backwards with token from front gap");
363
364 let (mut events, new_token) = match &self.pagination_mode {
366 EventFocusedPaginationMode::Room { .. } => {
367 Self::fetch_room_backwards(&room, num_events, &token).await?
368 }
369 EventFocusedPaginationMode::Thread { thread_root } => {
370 Self::fetch_thread_backwards(&room, num_events, &token, thread_root.clone()).await?
371 }
372 };
373
374 events.reverse();
377
378 let hit_end = new_token.is_none();
379 let new_gap = new_token.map(|t| Gap { token: t });
380
381 let hide_thread_events = match &self.pagination_mode {
382 EventFocusedPaginationMode::Room { hide_thread_events } => *hide_thread_events,
383 EventFocusedPaginationMode::Thread { .. } => false,
384 };
385
386 let events = if hide_thread_events {
387 events.into_iter().filter(|event| extract_thread_root(event.raw()).is_none()).collect()
388 } else {
389 events
390 };
391
392 self.chunk.push_backwards_pagination_events(Some(gap_id), new_gap, &events);
394
395 self.propagate_changes();
396 self.notify_subscribers(EventsOrigin::Pagination);
397
398 Ok(PaginationResult { events, hit_end_of_timeline: hit_end })
399 }
400
401 async fn fetch_room_backwards(
407 room: &Room,
408 num_events: u16,
409 token: &str,
410 ) -> Result<(Vec<Event>, Option<String>)> {
411 let mut options = MessagesOptions::backward().from(token);
412 options.limit = UInt::from(num_events);
413
414 let messages = room
415 .messages(options)
416 .await
417 .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;
418
419 Ok((messages.chunk, messages.end))
420 }
421
422 async fn fetch_thread_backwards(
427 room: &Room,
428 num_events: u16,
429 token: &str,
430 thread_root: OwnedEventId,
431 ) -> Result<(Vec<Event>, Option<String>)> {
432 let options = RelationsOptions {
433 from: Some(token.to_owned()),
434 dir: Direction::Backward,
435 limit: Some(UInt::from(num_events)),
436 include_relations: IncludeRelations::AllRelations,
437 recurse: true,
438 };
439
440 let mut result = room
441 .relations(thread_root.clone(), options)
442 .await
443 .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;
444
445 if result.next_batch_token.is_none() {
447 let root_event = room
448 .load_event(&thread_root)
449 .await
450 .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;
451 result.chunk.push(root_event);
452 }
453
454 Ok((result.chunk, result.next_batch_token))
455 }
456
457 #[instrument(skip(self), fields(room_id = %self.room.room_id()))]
463 async fn paginate_forwards(&mut self, num_events: u16) -> Result<PaginationResult> {
464 let room = self.room.get().ok_or(EventCacheError::ClientDropped)?;
465
466 let Some((gap_id, gap)) = self.last_chunk_as_gap() else {
468 trace!("no back gap found, already at timeline end");
470 return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
471 };
472
473 let token = gap.token;
474 trace!(?token, "paginating forwards with token from back gap");
475
476 let (events, new_token) = match &self.pagination_mode {
478 EventFocusedPaginationMode::Room { .. } => {
479 Self::fetch_room_forwards(&room, num_events, &token).await?
480 }
481 EventFocusedPaginationMode::Thread { thread_root } => {
482 Self::fetch_thread_forwards(&room, num_events, &token, thread_root.clone()).await?
483 }
484 };
485
486 let hit_end = new_token.is_none();
487 let new_gap = new_token.map(|t| Gap { token: t });
488
489 let hide_thread_events = match &self.pagination_mode {
490 EventFocusedPaginationMode::Room { hide_thread_events } => *hide_thread_events,
491 EventFocusedPaginationMode::Thread { .. } => false,
492 };
493
494 let events = if hide_thread_events {
495 events.into_iter().filter(|event| extract_thread_root(event.raw()).is_none()).collect()
496 } else {
497 events
498 };
499
500 self.chunk.push_forwards_pagination_events(Some(gap_id), new_gap, &events);
502
503 self.propagate_changes();
504 self.notify_subscribers(EventsOrigin::Pagination);
505
506 Ok(PaginationResult { events, hit_end_of_timeline: hit_end })
507 }
508
509 async fn fetch_room_forwards(
511 room: &Room,
512 num_events: u16,
513 token: &str,
514 ) -> Result<(Vec<Event>, Option<String>)> {
515 let mut options = MessagesOptions::new(Direction::Forward);
516 options = options.from(Some(token));
517 options.limit = UInt::from(num_events);
518
519 let messages = room
520 .messages(options)
521 .await
522 .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;
523
524 Ok((messages.chunk, messages.end))
525 }
526
527 async fn fetch_thread_forwards(
529 room: &Room,
530 num_events: u16,
531 token: &str,
532 thread_root: OwnedEventId,
533 ) -> Result<(Vec<Event>, Option<String>)> {
534 let options = RelationsOptions {
535 from: Some(token.to_owned()),
536 dir: Direction::Forward,
537 limit: Some(UInt::from(num_events)),
538 include_relations: IncludeRelations::AllRelations,
539 recurse: true,
540 };
541
542 let result = room
543 .relations(thread_root, options)
544 .await
545 .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;
546
547 Ok((result.chunk, result.next_batch_token))
548 }
549}
550
551#[derive(Clone)]
566pub struct EventFocusedCache {
567 inner: Arc<CacheStateLock<EventFocusedStateSelector>>,
568}
569
570impl EventFocusedCache {
571 pub(super) async fn new(
573 room: WeakRoom,
574 key: EventFocusedCacheKey,
575 state: &StateLock,
576 linked_chunk_update_sender: Sender<RoomEventCacheLinkedChunkUpdate>,
577 ) -> Result<Self> {
578 let cache_state = state
579 .try_insert_once_with(
580 EventFocusedStateSelector::new(room.room_id().to_owned(), key.clone()),
581 |_store_guard| async {
582 Ok(EventFocusedCacheState {
583 room,
584 focused_event_id: key.focused_event_id,
585 pagination_mode: EventFocusedPaginationMode::Room {
586 hide_thread_events: false,
587 },
588 chunk: EventLinkedChunk::new(),
589 initial_num_context_events: 0, thread_mode: EventFocusThreadMode::Automatic, update_sender: Sender::new(32),
592 linked_chunk_update_sender,
593 })
594 },
595 )
596 .await?;
597
598 Ok(Self { inner: Arc::new(cache_state) })
599 }
600
601 pub async fn events(&self) -> Result<Vec<Event>> {
606 let state = self.inner.read().await?;
607
608 Ok(state.chunk.events().map(|(_position, item)| item.clone()).collect())
609 }
610
611 pub async fn subscribe(&self) -> Result<(Vec<Event>, Receiver<TimelineVectorDiffs>)> {
613 let state = self.inner.read().await?;
614 let events = state.chunk.events().map(|(_position, item)| item.clone()).collect();
615 let recv = state.update_sender.subscribe();
616 Ok((events, recv))
617 }
618
619 pub async fn hit_timeline_start(&self) -> Result<bool> {
622 Ok(self.inner.read().await?.first_chunk_as_gap().is_none())
623 }
624
625 pub async fn hit_timeline_end(&self) -> Result<bool> {
628 Ok(self.inner.read().await?.last_chunk_as_gap().is_none())
629 }
630
631 pub(super) async fn start_from(
634 &self,
635 num_context_events: u16,
636 thread_mode: EventFocusThreadMode,
637 ) -> Result<StartFromResult> {
638 self.inner.write().await?.start_from(num_context_events, thread_mode).await
639 }
640
641 pub async fn paginate_backwards(&self, num_events: u16) -> Result<PaginationResult> {
644 self.inner.write().await?.paginate_backwards(num_events).await
645 }
646
647 pub async fn paginate_forwards(&self, num_events: u16) -> Result<PaginationResult> {
650 self.inner.write().await?.paginate_forwards(num_events).await
651 }
652
653 pub async fn thread_root(&self) -> Result<Option<OwnedEventId>> {
655 Ok(match &self.inner.read().await?.pagination_mode {
656 EventFocusedPaginationMode::Thread { thread_root } => Some(thread_root.clone()),
657 _ => None,
658 })
659 }
660
661 #[cfg(feature = "e2e-encryption")]
665 pub async fn replace_utds(&self, events: &[ResolvedUtd]) -> Result<()> {
666 let mut guard = self.inner.write().await?;
667
668 if guard.chunk.replace_utds(events) {
669 guard.propagate_changes();
670 guard.notify_subscribers(EventsOrigin::Cache);
671 }
672
673 Ok(())
674 }
675}
676
677#[cfg(not(tarpaulin_include))]
678impl std::fmt::Debug for EventFocusedCache {
679 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
680 f.debug_struct("EventFocusedCache").finish_non_exhaustive()
681 }
682}
683
684#[derive(Clone, Debug, Hash, PartialEq, Eq)]
686pub struct EventFocusedCacheKey {
687 pub focused_event_id: OwnedEventId,
689 pub thread_mode: EventFocusThreadMode,
691}
692
693pub type EventFocusedCacheUpdateSender = Sender<TimelineVectorDiffs>;