1use std::collections::HashSet;
71
72use matrix_sdk_base::{RoomStateFilter, deserialized_responses::TimelineEvent};
73use matrix_sdk_search::error::IndexError;
74#[cfg(doc)]
75use matrix_sdk_search::index::RoomIndex;
76use ruma::{OwnedEventId, OwnedRoomId};
77
78use crate::{Client, Room};
79
80impl Room {
81 pub async fn search(
84 &self,
85 query: &str,
86 max_number_of_results: usize,
87 pagination_offset: Option<usize>,
88 ) -> Result<Vec<OwnedEventId>, IndexError> {
89 let mut search_index_guard = self.client.search_index().lock().await;
90 search_index_guard.search(query, max_number_of_results, pagination_offset, self.room_id())
91 }
92}
93
94#[derive(thiserror::Error, Debug)]
97pub enum SearchError {
98 #[error(transparent)]
100 IndexError(#[from] IndexError),
101 #[error(transparent)]
103 EventLoadError(#[from] crate::Error),
104}
105
106impl Room {
107 pub fn search_messages(
110 &self,
111 query: String,
112 num_results_per_batch: usize,
113 ) -> RoomSearchIterator {
114 RoomSearchIterator {
115 room: self.clone(),
116 query,
117 offset: None,
118 is_done: false,
119 num_results_per_batch,
120 }
121 }
122}
123
124#[derive(Debug)]
126pub struct RoomSearchIterator {
127 room: Room,
129
130 query: String,
132
133 offset: Option<usize>,
136
137 is_done: bool,
139
140 num_results_per_batch: usize,
143}
144
145impl RoomSearchIterator {
146 pub async fn next(&mut self) -> Result<Option<Vec<OwnedEventId>>, IndexError> {
149 if self.is_done {
150 return Ok(None);
151 }
152
153 let result = self.room.search(&self.query, self.num_results_per_batch, self.offset).await?;
156
157 if result.is_empty() {
158 self.is_done = true;
159 Ok(None)
160 } else {
161 self.offset = Some(self.offset.unwrap_or(0) + result.len());
162 Ok(Some(result))
163 }
164 }
165
166 pub async fn next_events(&mut self) -> Result<Option<Vec<TimelineEvent>>, SearchError> {
169 let Some(event_ids) = self.next().await? else {
170 return Ok(None);
171 };
172 let mut results = Vec::new();
173 for event_id in event_ids {
174 results.push(self.room.load_or_fetch_event(&event_id, None).await?);
175 }
176 Ok(Some(results))
177 }
178}
179
180#[derive(Debug)]
181struct GlobalSearchRoomState {
182 room: Room,
184
185 offset: Option<usize>,
188}
189
190impl GlobalSearchRoomState {
191 fn new(room: Room) -> Self {
192 Self { room, offset: None }
193 }
194}
195
196#[derive(Debug)]
199pub struct GlobalSearchBuilder {
200 client: Client,
201
202 query: String,
204
205 num_results_per_batch: usize,
208
209 room_set: Vec<Room>,
211}
212
213impl GlobalSearchBuilder {
214 fn new(client: Client, query: String, num_results_per_batch: usize) -> Self {
216 let room_set = client.rooms_filtered(RoomStateFilter::JOINED);
217 Self { client, query, room_set, num_results_per_batch }
218 }
219
220 pub async fn only_dm_rooms(mut self) -> Result<Self, crate::Error> {
222 let mut to_remove = HashSet::new();
223 for room in &self.room_set {
224 if !room.is_direct().await? {
225 to_remove.insert(room.room_id().to_owned());
226 }
227 }
228 self.room_set.retain(|room| !to_remove.contains(room.room_id()));
229 Ok(self)
230 }
231
232 pub async fn no_dms(mut self) -> Result<Self, crate::Error> {
234 let mut to_remove = HashSet::new();
235 for room in &self.room_set {
236 if room.is_direct().await? {
237 to_remove.insert(room.room_id().to_owned());
238 }
239 }
240 self.room_set.retain(|room| !to_remove.contains(room.room_id()));
241 Ok(self)
242 }
243
244 pub fn build(self) -> GlobalSearchIterator {
246 GlobalSearchIterator {
247 client: self.client,
248 query: self.query,
249 room_state: Vec::from_iter(self.room_set.into_iter().map(GlobalSearchRoomState::new)),
250 current_batch: Vec::new(),
251 num_results_per_batch: self.num_results_per_batch,
252 }
253 }
254}
255
256impl Client {
257 pub fn search_messages(
260 &self,
261 query: String,
262 num_results_per_batch: usize,
263 ) -> GlobalSearchBuilder {
264 GlobalSearchBuilder::new(self.clone(), query, num_results_per_batch)
265 }
266}
267
268#[derive(Debug)]
270pub struct GlobalSearchIterator {
271 client: Client,
272
273 query: String,
275
276 room_state: Vec<GlobalSearchRoomState>,
282
283 current_batch: Vec<(OwnedRoomId, OwnedEventId)>,
287
288 num_results_per_batch: usize,
291}
292
293impl GlobalSearchIterator {
294 pub async fn next(&mut self) -> Result<Option<Vec<(OwnedRoomId, OwnedEventId)>>, SearchError> {
297 if self.room_state.is_empty() {
298 return Ok(None);
299 }
300
301 if self.current_batch.len() >= self.num_results_per_batch {
304 return Ok(Some(self.current_batch.drain(0..self.num_results_per_batch).collect()));
305 }
306
307 let mut to_remove = HashSet::new();
308
309 for room_state in &mut self.room_state {
312 let room_results = room_state
313 .room
314 .search(&self.query, self.num_results_per_batch, room_state.offset)
315 .await?;
316
317 if room_results.is_empty() {
318 to_remove.insert(room_state.room.room_id().to_owned());
320 } else {
321 room_state.offset = Some(room_state.offset.unwrap_or(0) + room_results.len());
323
324 self.current_batch.extend(
326 room_results
327 .into_iter()
328 .map(|event_id| (room_state.room.room_id().to_owned(), event_id)),
329 );
330
331 if self.current_batch.len() >= self.num_results_per_batch {
332 break;
334 }
335 }
336 }
337
338 for room_id in to_remove {
340 self.room_state.retain(|room_state| room_state.room.room_id() != room_id);
341 }
342
343 if !self.current_batch.is_empty() {
344 let high = self.num_results_per_batch.min(self.current_batch.len());
345 Ok(Some(self.current_batch.drain(0..high).collect()))
346 } else {
347 debug_assert!(self.room_state.is_empty());
348 Ok(None)
349 }
350 }
351
352 pub async fn next_events(
355 &mut self,
356 ) -> Result<Option<Vec<(OwnedRoomId, TimelineEvent)>>, SearchError> {
357 let Some(event_ids) = self.next().await? else {
358 return Ok(None);
359 };
360 let mut results = Vec::with_capacity(event_ids.len());
361 for (room_id, event_id) in event_ids {
362 let Some(room) = self.client.get_room(&room_id) else {
363 continue;
364 };
365 results.push((room_id, room.load_or_fetch_event(&event_id, None).await?));
366 }
367 Ok(Some(results))
368 }
369}
370
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use matrix_sdk_test::{BOB, JoinedRoomBuilder, async_test, event_factory::EventFactory};
    use ruma::{event_id, room_id, user_id};

    use crate::{sleep::sleep, test_utils::mocks::MatrixMockServer};

    // Single-room search: a query with no match, then a single match returned
    // through both `next` and `next_events`.
    #[async_test]
    async fn test_room_message_search() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id = room_id!("!room_id:localhost");
        let room = server.sync_joined_room(&client, room_id).await;

        let f = EventFactory::new().room(room_id).sender(user_id!("@user_id:localhost"));

        let event_id = event_id!("$event_id:localhost");

        // Sync a single "hello world" message into the room.
        server
            .sync_room(
                &client,
                JoinedRoomBuilder::new(room_id)
                    .add_timeline_event(f.text_msg("hello world").event_id(event_id)),
            )
            .await;

        // NOTE(review): presumably gives the synced event time to reach the
        // search index before querying; confirm whether indexing is async.
        sleep(Duration::from_millis(200)).await;

        {
            // A query matching nothing yields `None`, and keeps doing so.
            let mut room_search = room.search_messages("search query".to_owned(), 5);

            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());

            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // "world" matches the single synced message; the next batch is empty.
            let mut room_search = room.search_messages("world".to_owned(), 5);

            let maybe_results = room_search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(&results[0], event_id,);

            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // Same search, this time loading the full events.
            let mut room_search = room.search_messages("world".to_owned(), 5);

            let maybe_results = room_search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].event_id().as_deref().unwrap(), event_id,);

            let maybe_results = room_search.next_events().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }

    // Global search over two joined rooms, each holding one message that
    // contains "world".
    #[async_test]
    async fn test_global_message_search() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        let result_event_id1 = event_id!("$result1:localhost");
        let result_event_id2 = event_id!("$result2:localhost");

        // Room 1 gets a matching and a non-matching message; room 2 a matching one.
        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("hello world").room(room_id1).event_id(result_event_id1),
                            )
                            .add_timeline_event(f.text_msg("hello back").room(room_id1)),
                    )
                    .add_joined_room(JoinedRoomBuilder::new(room_id2).add_timeline_event(
                        f.text_msg("it's a mad world").room(room_id2).event_id(result_event_id2),
                    ));
            })
            .await;

        // NOTE(review): presumably waits for indexing; see the first test.
        sleep(Duration::from_millis(200)).await;

        {
            // No match anywhere: `None`, and still `None` on the next call.
            let mut search = client.search_messages("search query".to_owned(), 5).build();

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // One match per room, both returned in the first batch.
            let mut search = client.search_messages("world".to_owned(), 5).build();

            let maybe_results = search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 2);
            assert!(results.contains(&(room_id1.to_owned(), result_event_id1.to_owned())));
            assert!(results.contains(&(room_id2.to_owned(), result_event_id2.to_owned())));

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // Same search, loading the full events paired with their room ids.
            let mut search = client.search_messages("world".to_owned(), 5).build();

            let maybe_results = search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 2);
            assert!(results.iter().any(|(room_id, event)| {
                room_id == room_id1 && event.event_id().as_deref() == Some(result_event_id1)
            }));
            assert!(results.iter().any(|(room_id, event)| {
                room_id == room_id2 && event.event_id().as_deref() == Some(result_event_id2)
            }));

            let maybe_results = search.next_events().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }

    // Global search restricted to DM rooms, or excluding them. Room 1 is marked
    // as a direct room with BOB through global account data.
    #[async_test]
    async fn test_global_message_search_dm_or_groups() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        let result_event_id1 = event_id!("$result1:localhost");
        let result_event_id2 = event_id!("$result2:localhost");

        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("hello world").room(room_id1).event_id(result_event_id1),
                            )
                            .add_timeline_event(f.text_msg("hello back").room(room_id1)),
                    )
                    .add_joined_room(JoinedRoomBuilder::new(room_id2).add_timeline_event(
                        f.text_msg("it's a mad world").room(room_id2).event_id(result_event_id2),
                    ))
                    // Mark room 1 as a direct chat with BOB.
                    .add_global_account_data(
                        f.direct().add_user((*BOB).to_owned().into(), room_id1),
                    );
            })
            .await;

        // NOTE(review): presumably waits for indexing; see the first test.
        sleep(Duration::from_millis(200)).await;

        {
            // Only DM rooms: just room 1's match.
            let mut search = client
                .search_messages("world".to_owned(), 5)
                .only_dm_rooms()
                .await
                .unwrap()
                .build();

            let maybe_results = search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(&results[0], &(room_id1.to_owned(), result_event_id1.to_owned()));

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        {
            // DMs excluded: just room 2's match, loaded as a full event.
            let mut search =
                client.search_messages("world".to_owned(), 5).no_dms().await.unwrap().build();

            let maybe_results = search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, room_id2);
            assert_eq!(results[0].1.event_id().as_deref().unwrap(), result_event_id2);

            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }
}