1use as_variant::as_variant;
18use ruma::{
19 OwnedEventId, UserId,
20 events::{
21 AnySyncTimelineEvent,
22 room::{
23 encrypted::Relation as EncryptedRelation,
24 message::{
25 AddMentions, ForwardThread, ReplyMetadata, ReplyWithinThread,
26 RoomMessageEventContent, RoomMessageEventContentWithoutRelation,
27 },
28 },
29 },
30};
31use thiserror::Error;
32use tracing::instrument;
33
34use super::{EventSource, Room};
35
36#[derive(Debug)]
38pub struct Reply {
39 pub event_id: OwnedEventId,
41 pub enforce_thread: EnforceThread,
43 pub add_mentions: AddMentions,
46}
47
48#[derive(Debug, Error)]
50pub enum ReplyError {
51 #[error("Couldn't fetch the remote event: {0}")]
53 Fetch(Box<crate::Error>),
54 #[error("failed to deserialize event to reply to")]
56 Deserialization,
57 #[error("tried to reply to a state event")]
59 StateEvent,
60}
61
62#[derive(Clone, Copy, Debug, PartialEq, Eq)]
66pub enum EnforceThread {
67 Threaded(ReplyWithinThread),
70
71 MaybeThreaded,
74
75 Unthreaded,
78}
79
80impl Room {
81 #[instrument(skip(self, content), fields(room = %self.room_id()))]
93 pub async fn make_reply_event(
94 &self,
95 content: RoomMessageEventContentWithoutRelation,
96 reply: Reply,
97 ) -> Result<RoomMessageEventContent, ReplyError> {
98 make_reply_event(self, self.own_user_id(), content, reply).await
99 }
100}
101
102async fn make_reply_event<S: EventSource>(
103 source: S,
104 own_user_id: &UserId,
105 content: RoomMessageEventContentWithoutRelation,
106 reply: Reply,
107) -> Result<RoomMessageEventContent, ReplyError> {
108 let event =
109 source.get_event(&reply.event_id).await.map_err(|err| ReplyError::Fetch(Box::new(err)))?;
110
111 let raw_event = event.into_raw();
112 let event = raw_event.deserialize().map_err(|_| ReplyError::Deserialization)?;
113
114 let relation = as_variant!(&event, AnySyncTimelineEvent::MessageLike)
115 .ok_or(ReplyError::StateEvent)?
116 .original_content()
117 .and_then(|content| content.relation());
118 let thread =
119 relation.as_ref().and_then(|relation| as_variant!(relation, EncryptedRelation::Thread));
120
121 let reply_metadata = ReplyMetadata::new(event.event_id(), event.sender(), thread);
122
123 let mention_the_sender =
131 if own_user_id == event.sender() { AddMentions::No } else { reply.add_mentions };
132
133 let content = match reply.enforce_thread {
134 EnforceThread::Threaded(is_reply) => {
135 content.make_for_thread(reply_metadata, is_reply, mention_the_sender)
136 }
137 EnforceThread::MaybeThreaded => {
138 content.make_reply_to(reply_metadata, ForwardThread::Yes, mention_the_sender)
139 }
140 EnforceThread::Unthreaded => {
141 content.make_reply_to(reply_metadata, ForwardThread::No, mention_the_sender)
142 }
143 };
144
145 Ok(content)
146}
147
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use assert_matches2::{assert_let, assert_matches};
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{
        EventId, OwnedEventId, event_id,
        events::{
            AnySyncTimelineEvent,
            room::message::{
                AddMentions, Relation, ReplyWithinThread, RoomMessageEventContentWithoutRelation,
            },
        },
        serde::Raw,
        user_id,
    };
    use serde_json::json;

    use super::{EnforceThread, EventSource, Reply, ReplyError, make_reply_event};
    use crate::{Error, event_cache::EventCacheError};

    /// An in-memory [`EventSource`] backed by a map of events.
    #[derive(Default)]
    struct TestEventCache {
        events: BTreeMap<OwnedEventId, TimelineEvent>,
    }

    impl EventSource for TestEventCache {
        async fn get_event(&self, event_id: &EventId) -> Result<TimelineEvent, Error> {
            // Unknown events are reported with an arbitrary error: the tests
            // only care that fetching fails.
            self.events
                .get(event_id)
                .cloned()
                .ok_or_else(|| Error::EventCache(Box::new(EventCacheError::ClientDropped)))
        }
    }

    /// The content every test replies with.
    fn reply_content() -> RoomMessageEventContentWithoutRelation {
        RoomMessageEventContentWithoutRelation::text_plain("the reply")
    }

    /// Builds a [`Reply`] to `event_id` with the given thread and mention
    /// settings.
    fn reply_to(
        event_id: &EventId,
        enforce_thread: EnforceThread,
        add_mentions: AddMentions,
    ) -> Reply {
        Reply { event_id: event_id.to_owned(), enforce_thread, add_mentions }
    }

    #[async_test]
    async fn test_cannot_reply_to_unknown_event() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("hi").event_id(event_id).sender(own_user_id).into(),
        );

        // `$2` is not in the cache.
        assert_matches!(
            make_reply_event(
                cache,
                own_user_id,
                reply_content(),
                reply_to(event_id!("$2"), EnforceThread::Unthreaded, AddMentions::Yes),
            )
            .await,
            Err(ReplyError::Fetch(_))
        );
    }

    #[async_test]
    async fn test_cannot_reply_to_invalid_event() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();

        cache.events.insert(
            event_id.to_owned(),
            TimelineEvent::from_plaintext(
                Raw::<AnySyncTimelineEvent>::from_json_string(
                    // Required fields (e.g. `sender`) are missing, so this
                    // can't be deserialized as a timeline event.
                    json!({
                        "content": {
                            "body": "hi"
                        },
                        "event_id": event_id,
                        "origin_server_ts": 1,
                        "type": "m.room.message",
                    })
                    .to_string(),
                )
                .unwrap(),
            ),
        );

        assert_matches!(
            make_reply_event(
                cache,
                own_user_id,
                reply_content(),
                reply_to(event_id, EnforceThread::Unthreaded, AddMentions::Yes),
            )
            .await,
            Err(ReplyError::Deserialization)
        );
    }

    #[async_test]
    async fn test_cannot_reply_to_state_event() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.room_name("lobby").event_id(event_id).sender(own_user_id).into(),
        );

        assert_matches!(
            make_reply_event(
                cache,
                own_user_id,
                reply_content(),
                reply_to(event_id, EnforceThread::Unthreaded, AddMentions::Yes),
            )
            .await,
            Err(ReplyError::StateEvent)
        );
    }

    #[async_test]
    async fn test_reply_unthreaded() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("hi").event_id(event_id).sender(own_user_id).into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded, AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Reply(reply)) = &reply_event.relates_to);

        assert_eq!(reply.in_reply_to.event_id, event_id);
    }

    #[async_test]
    async fn test_start_thread() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("hi").event_id(event_id).sender(own_user_id).into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::No), AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);

        // The replied-to event becomes the root of a new thread.
        assert_eq!(thread.event_id, event_id);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            thread_root.to_owned(),
            f.text_msg("hi").event_id(thread_root).sender(own_user_id).into(),
        );
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("ho")
                .in_thread(thread_root, thread_root)
                .event_id(event_id)
                .sender(own_user_id)
                .into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::No), AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread_as_reply() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            thread_root.to_owned(),
            f.text_msg("hi").event_id(thread_root).sender(own_user_id).into(),
        );
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("ho")
                .in_thread(thread_root, thread_root)
                .event_id(event_id)
                .sender(own_user_id)
                .into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::Yes), AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        // A real reply within the thread, not a fallback.
        assert!(!thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_forwarding_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            thread_root.to_owned(),
            f.text_msg("hi").event_id(thread_root).sender(own_user_id).into(),
        );
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("ho")
                .in_thread(thread_root, thread_root)
                .event_id(event_id)
                .sender(own_user_id)
                .into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::MaybeThreaded, AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_forwarding_thread_for_poll_start() {
        let thread_root = event_id!("$thread_root");
        let event_id = event_id!("$thread_reply");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();

        cache.events.insert(
            event_id.to_owned(),
            f.poll_start(
                "would you rather… A) eat a pineapple pizza, B) drink pickle juice",
                "would you rather…",
                vec!["eat a pineapple pizza", "drink pickle juice"],
            )
            .in_thread(thread_root, thread_root)
            .event_id(event_id)
            .sender(own_user_id)
            .into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::No), AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_without_add_mentions() {
        let event_id = event_id!("$1");
        let other_user_id = user_id!("@you:saucisse.bzh");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("hi").event_id(event_id).sender(other_user_id).into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded, AddMentions::No),
        )
        .await
        .unwrap();

        assert!(reply_event.mentions.is_none());
    }

    #[async_test]
    async fn test_reply_with_add_mentions() {
        let event_id = event_id!("$1");
        let other_user_id = user_id!("@you:saucisse.bzh");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("hi").event_id(event_id).sender(other_user_id).into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded, AddMentions::Yes),
        )
        .await
        .unwrap();

        assert_let!(Some(mentions) = reply_event.mentions);
        assert!(mentions.user_ids.contains(other_user_id));
    }

    #[async_test]
    async fn test_reply_to_own_event_does_not_mention_self() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let mut cache = TestEventCache::default();
        let f = EventFactory::new();
        cache.events.insert(
            event_id.to_owned(),
            f.text_msg("hi").event_id(event_id).sender(own_user_id).into(),
        );

        // Mentions are requested, but the replied-to event is our own, so we
        // must not end up mentioning ourselves.
        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded, AddMentions::Yes),
        )
        .await
        .unwrap();

        if let Some(mentions) = &reply_event.mentions {
            assert!(!mentions.user_ids.contains(own_user_id));
        }
    }
}