1use std::collections::HashSet;
71
72use matrix_sdk_base::{RoomStateFilter, deserialized_responses::TimelineEvent};
73use matrix_sdk_search::error::IndexError;
74#[cfg(doc)]
75use matrix_sdk_search::index::RoomIndex;
76use ruma::{OwnedEventId, OwnedRoomId};
77
78use crate::{Client, Room};
79
80impl Room {
81 pub async fn search(
84 &self,
85 query: &str,
86 max_number_of_results: usize,
87 pagination_offset: Option<usize>,
88 ) -> Result<Vec<(f32, OwnedEventId)>, IndexError> {
89 let mut search_index_guard = self.client.search_index().lock().await;
90 search_index_guard.search(query, max_number_of_results, pagination_offset, self.room_id())
91 }
92}
93
94#[derive(thiserror::Error, Debug)]
97pub enum SearchError {
98 #[error(transparent)]
100 IndexError(#[from] IndexError),
101 #[error(transparent)]
103 EventLoadError(#[from] crate::Error),
104}
105
106impl Room {
107 pub fn search_messages(
110 &self,
111 query: String,
112 num_results_per_batch: usize,
113 ) -> RoomSearchIterator {
114 RoomSearchIterator {
115 room: self.clone(),
116 query,
117 offset: None,
118 is_done: false,
119 num_results_per_batch,
120 }
121 }
122}
123
124#[derive(Debug)]
126pub struct RoomSearchIterator {
127 room: Room,
129
130 query: String,
132
133 offset: Option<usize>,
136
137 is_done: bool,
139
140 num_results_per_batch: usize,
143}
144
145impl RoomSearchIterator {
146 pub async fn next(&mut self) -> Result<Option<Vec<OwnedEventId>>, IndexError> {
149 if self.is_done {
150 return Ok(None);
151 }
152
153 let result = self.room.search(&self.query, self.num_results_per_batch, self.offset).await?;
156
157 if result.is_empty() {
158 self.is_done = true;
159 Ok(None)
160 } else {
161 self.offset = Some(self.offset.unwrap_or(0) + result.len());
162 Ok(Some(result.into_iter().map(|(_, id)| id).collect()))
163 }
164 }
165
166 pub async fn next_events(&mut self) -> Result<Option<Vec<TimelineEvent>>, SearchError> {
169 let Some(event_ids) = self.next().await? else {
170 return Ok(None);
171 };
172 let mut results = Vec::new();
173 for event_id in event_ids {
174 results.push(self.room.load_or_fetch_event(&event_id, None).await?);
175 }
176 Ok(Some(results))
177 }
178}
179
180#[derive(Debug)]
181struct GlobalSearchRoomState {
182 room: Room,
184
185 offset: Option<usize>,
188}
189
190impl GlobalSearchRoomState {
191 fn new(room: Room) -> Self {
192 Self { room, offset: None }
193 }
194}
195
196#[derive(Debug)]
199pub struct GlobalSearchBuilder {
200 client: Client,
201
202 query: String,
204
205 num_results_per_batch: usize,
208
209 room_set: Vec<Room>,
211}
212
213impl GlobalSearchBuilder {
214 fn new(client: Client, query: String, num_results_per_batch: usize) -> Self {
216 let room_set = client.rooms_filtered(RoomStateFilter::JOINED);
217 Self { client, query, room_set, num_results_per_batch }
218 }
219
220 pub async fn only_dm_rooms(mut self) -> Result<Self, crate::Error> {
222 let mut to_remove = HashSet::new();
223 for room in &self.room_set {
224 if !room.compute_is_dm().await? {
225 to_remove.insert(room.room_id().to_owned());
226 }
227 }
228 self.room_set.retain(|room| !to_remove.contains(room.room_id()));
229 Ok(self)
230 }
231
232 pub async fn no_dms(mut self) -> Result<Self, crate::Error> {
234 let mut to_remove = HashSet::new();
235 for room in &self.room_set {
236 if room.compute_is_dm().await? {
237 to_remove.insert(room.room_id().to_owned());
238 }
239 }
240 self.room_set.retain(|room| !to_remove.contains(room.room_id()));
241 Ok(self)
242 }
243
244 pub fn build(self) -> GlobalSearchIterator {
246 GlobalSearchIterator {
247 client: self.client,
248 query: self.query,
249 room_state: Vec::from_iter(self.room_set.into_iter().map(GlobalSearchRoomState::new)),
250 current_batch: Vec::new(),
251 num_results_per_batch: self.num_results_per_batch,
252 }
253 }
254}
255
256impl Client {
257 pub fn search_messages(
260 &self,
261 query: String,
262 num_results_per_batch: usize,
263 ) -> GlobalSearchBuilder {
264 GlobalSearchBuilder::new(self.clone(), query, num_results_per_batch)
265 }
266}
267
268#[derive(Debug)]
270pub struct GlobalSearchIterator {
271 client: Client,
272
273 query: String,
275
276 room_state: Vec<GlobalSearchRoomState>,
282
283 current_batch: Vec<(f32, OwnedRoomId, OwnedEventId)>,
286
287 num_results_per_batch: usize,
290}
291
292impl GlobalSearchIterator {
293 pub async fn next(&mut self) -> Result<Option<Vec<(OwnedRoomId, OwnedEventId)>>, SearchError> {
296 if self.room_state.is_empty() {
297 return Ok(None);
298 }
299
300 if self.current_batch.len() >= self.num_results_per_batch {
303 return Ok(Some(
304 self.current_batch
305 .drain(0..self.num_results_per_batch)
306 .map(|(_, room_id, event_id)| (room_id, event_id))
307 .collect(),
308 ));
309 }
310
311 let mut to_remove = HashSet::new();
312
313 for room_state in &mut self.room_state {
316 let room_results = room_state
317 .room
318 .search(&self.query, self.num_results_per_batch, room_state.offset)
319 .await?;
320
321 if room_results.is_empty() {
322 to_remove.insert(room_state.room.room_id().to_owned());
324 } else {
325 room_state.offset = Some(room_state.offset.unwrap_or(0) + room_results.len());
327
328 self.current_batch.extend(room_results.into_iter().map(|(score, event_id)| {
330 (score, room_state.room.room_id().to_owned(), event_id)
331 }));
332
333 if self.current_batch.len() >= self.num_results_per_batch {
334 break;
336 }
337 }
338 }
339
340 for room_id in to_remove {
342 self.room_state.retain(|room_state| room_state.room.room_id() != room_id);
343 }
344
345 if !self.current_batch.is_empty() {
346 self.current_batch.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
349 let high = self.num_results_per_batch.min(self.current_batch.len());
350 Ok(Some(
351 self.current_batch
352 .drain(0..high)
353 .map(|(_, room_id, event_id)| (room_id, event_id))
354 .collect(),
355 ))
356 } else {
357 debug_assert!(self.room_state.is_empty());
358 Ok(None)
359 }
360 }
361
362 pub async fn next_events(
365 &mut self,
366 ) -> Result<Option<Vec<(OwnedRoomId, TimelineEvent)>>, SearchError> {
367 let Some(event_ids) = self.next().await? else {
368 return Ok(None);
369 };
370 let mut results = Vec::with_capacity(event_ids.len());
371 for (room_id, event_id) in event_ids {
372 let Some(room) = self.client.get_room(&room_id) else {
373 continue;
374 };
375 results.push((room_id, room.load_or_fetch_event(&event_id, None).await?));
376 }
377 Ok(Some(results))
378 }
379}
380
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use matrix_sdk_test::{BOB, JoinedRoomBuilder, async_test, event_factory::EventFactory};
    use ruma::{event_id, room_id, user_id};

    use crate::{sleep::sleep, test_utils::mocks::MatrixMockServer};

    /// Single-room search covering three cases: a query with no match, a match
    /// read through `next`, and the same match read through `next_events`.
    #[async_test]
    async fn test_room_message_search() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id = room_id!("!room_id:localhost");
        let room = server.sync_joined_room(&client, room_id).await;

        let f = EventFactory::new().room(room_id).sender(user_id!("@user_id:localhost"));

        let event_id = event_id!("$event_id:localhost");

        server
            .sync_room(
                &client,
                JoinedRoomBuilder::new(room_id)
                    .add_timeline_event(f.text_msg("hello world").event_id(event_id)),
            )
            .await;

        // NOTE(review): presumably gives the event cache time to index the synced
        // event before searching; a fixed delay may be flaky on slow machines.
        sleep(Duration::from_millis(200)).await;

        {
            // No message matches: the iterator yields `None`, and keeps doing so.
            let mut room_search = room.search_messages("search query".to_owned(), 5);

            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());

            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // One match, returned as an event ID, followed by exhaustion.
            let mut room_search = room.search_messages("world".to_owned(), 5);

            let maybe_results = room_search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(&results[0], event_id,);

            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // Same match, this time loaded as a full event.
            let mut room_search = room.search_messages("world".to_owned(), 5);

            let maybe_results = room_search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].event_id().unwrap(), event_id,);

            let maybe_results = room_search.next_events().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }

    /// Global search over two joined rooms, each containing one "world" message.
    #[async_test]
    async fn test_global_message_search() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        let result_event_id1 = event_id!("$result1:localhost");
        let result_event_id2 = event_id!("$result2:localhost");

        // Room 1 also gets a non-matching message ("hello back").
        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("hello world").room(room_id1).event_id(result_event_id1),
                            )
                            .add_timeline_event(f.text_msg("hello back").room(room_id1)),
                    )
                    .add_joined_room(JoinedRoomBuilder::new(room_id2).add_timeline_event(
                        f.text_msg("it's a mad world").room(room_id2).event_id(result_event_id2),
                    ));
            })
            .await;

        // NOTE(review): presumably lets indexing of the synced events finish.
        sleep(Duration::from_millis(200)).await;

        {
            // No match in any room.
            let mut search = client.search_messages("search query".to_owned(), 5).build();

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // Both matches fit in one batch. Their relative order is not
            // asserted, so `contains` is used.
            let mut search = client.search_messages("world".to_owned(), 5).build();

            let maybe_results = search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 2);
            assert!(results.contains(&(room_id1.to_owned(), result_event_id1.to_owned())));
            assert!(results.contains(&(room_id2.to_owned(), result_event_id2.to_owned())));

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // Same, loading the full events together with their room IDs.
            let mut search = client.search_messages("world".to_owned(), 5).build();

            let maybe_results = search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 2);
            assert!(results.iter().any(|(room_id, event)| {
                room_id == room_id1 && event.event_id() == Some(result_event_id1)
            }));
            assert!(results.iter().any(|(room_id, event)| {
                room_id == room_id2 && event.event_id() == Some(result_event_id2)
            }));

            let maybe_results = search.next_events().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }

    /// Results from different rooms are interleaved by score within one batch.
    #[async_test]
    async fn test_global_message_search_score_ordering() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        // `rankN` is the expected position of each event in the result batch.
        let r1_rank1 = event_id!("$r1_rank1:localhost");
        let r2_rank2 = event_id!("$r2_rank2:localhost");
        let r1_rank3 = event_id!("$r1_rank3:localhost");
        let r2_rank4 = event_id!("$r2_rank4:localhost");

        // All messages are ten words long and differ only in how often "world"
        // occurs (4, 3, 2, 1). NOTE(review): this assumes the index scores
        // higher term frequency higher at equal length; confirm against the
        // search crate's scoring.
        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("world world world world filler filler filler filler filler filler")
                                    .room(room_id1)
                                    .event_id(r1_rank1),
                            )
                            .add_timeline_event(
                                f.text_msg("world world filler filler filler filler filler filler filler filler")
                                    .room(room_id1)
                                    .event_id(r1_rank3),
                            ),
                    )
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id2)
                            .add_timeline_event(
                                f.text_msg("world world world filler filler filler filler filler filler filler")
                                    .room(room_id2)
                                    .event_id(r2_rank2),
                            )
                            .add_timeline_event(
                                f.text_msg("world filler filler filler filler filler filler filler filler filler")
                                    .room(room_id2)
                                    .event_id(r2_rank4),
                            ),
                    );
            })
            .await;

        sleep(Duration::from_millis(200)).await;

        // Batch size 10 exceeds the 4 hits, so both rooms are queried before
        // sorting and the whole ranking lands in one batch.
        let mut search = client.search_messages("world".to_owned(), 10).build();

        let results = search.next().await.unwrap().unwrap();
        assert_eq!(results.len(), 4);

        assert_eq!(results[0], (room_id1.to_owned(), r1_rank1.to_owned()));
        assert_eq!(results[1], (room_id2.to_owned(), r2_rank2.to_owned()));
        assert_eq!(results[2], (room_id1.to_owned(), r1_rank3.to_owned()));
        assert_eq!(results[3], (room_id2.to_owned(), r2_rank4.to_owned()));
    }

    /// The DM filters on the builder select the expected subset of rooms.
    #[async_test]
    async fn test_global_message_search_dm_or_groups() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        // Room 1 is marked as a DM with BOB through `m.direct` account data
        // below; room 2 is a regular room.
        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        let result_event_id1 = event_id!("$result1:localhost");
        let result_event_id2 = event_id!("$result2:localhost");

        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("hello world").room(room_id1).event_id(result_event_id1),
                            )
                            .add_timeline_event(f.text_msg("hello back").room(room_id1)),
                    )
                    .add_joined_room(JoinedRoomBuilder::new(room_id2).add_timeline_event(
                        f.text_msg("it's a mad world").room(room_id2).event_id(result_event_id2),
                    ))
                    .add_global_account_data(
                        f.direct().add_user((*BOB).to_owned().into(), room_id1),
                    );
            })
            .await;

        sleep(Duration::from_millis(200)).await;

        {
            // Only the DM room is searched.
            let mut search = client
                .search_messages("world".to_owned(), 5)
                .only_dm_rooms()
                .await
                .unwrap()
                .build();

            let maybe_results = search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(&results[0], &(room_id1.to_owned(), result_event_id1.to_owned()));

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // Only the non-DM room is searched; the events API is exercised here.
            let mut search =
                client.search_messages("world".to_owned(), 5).no_dms().await.unwrap().build();

            let maybe_results = search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, room_id2);
            assert_eq!(results[0].1.event_id().unwrap(), result_event_id2);

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }
}