1use ruma::{
18 events::{
19 relation::Thread,
20 room::{
21 encrypted::Relation as EncryptedRelation,
22 message::{
23 AddMentions, ForwardThread, OriginalRoomMessageEvent, Relation, ReplyWithinThread,
24 RoomMessageEventContent, RoomMessageEventContentWithoutRelation,
25 },
26 },
27 AnyMessageLikeEventContent, AnySyncMessageLikeEvent, AnySyncTimelineEvent,
28 SyncMessageLikeEvent,
29 },
30 serde::Raw,
31 EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedUserId, RoomId, UserId,
32};
33use serde::Deserialize;
34use thiserror::Error;
35use tracing::{error, instrument};
36
37use super::{EventSource, Room};
38
39#[derive(Debug, Clone)]
41struct RepliedToInfo {
42 event_id: OwnedEventId,
44 sender: OwnedUserId,
46 timestamp: MilliSecondsSinceUnixEpoch,
48 content: ReplyContent,
50}
51
52#[derive(Debug, Clone)]
54enum ReplyContent {
55 Message(RoomMessageEventContent),
57 Raw(Raw<AnySyncTimelineEvent>),
59}
60
61#[derive(Debug, Error)]
63pub enum ReplyError {
64 #[error("Couldn't fetch the remote event: {0}")]
66 Fetch(Box<crate::Error>),
67 #[error("failed to deserialize event to reply to")]
69 Deserialization,
70 #[error("tried to reply to a state event")]
72 StateEvent,
73}
74
75#[derive(Clone, Copy, Debug, PartialEq, Eq)]
77pub enum EnforceThread {
78 Threaded(ReplyWithinThread),
81
82 MaybeThreaded,
85
86 Unthreaded,
89}
90
91impl Room {
92 #[instrument(skip(self, content), fields(room = %self.room_id()))]
98 pub async fn make_reply_event(
99 &self,
100 content: RoomMessageEventContentWithoutRelation,
101 event_id: &EventId,
102 enforce_thread: EnforceThread,
103 ) -> Result<AnyMessageLikeEventContent, ReplyError> {
104 make_reply_event(
105 self,
106 self.room_id(),
107 self.own_user_id(),
108 content,
109 event_id,
110 enforce_thread,
111 )
112 .await
113 }
114}
115
116async fn make_reply_event<S: EventSource>(
117 source: S,
118 room_id: &RoomId,
119 own_user_id: &UserId,
120 content: RoomMessageEventContentWithoutRelation,
121 event_id: &EventId,
122 enforce_thread: EnforceThread,
123) -> Result<AnyMessageLikeEventContent, ReplyError> {
124 let replied_to_info = replied_to_info_from_event_id(source, event_id).await?;
125
126 let mention_the_sender =
134 if own_user_id == replied_to_info.sender { AddMentions::No } else { AddMentions::Yes };
135
136 let content = match replied_to_info.content {
137 ReplyContent::Message(replied_to_content) => {
138 let event = OriginalRoomMessageEvent {
139 event_id: replied_to_info.event_id,
140 sender: replied_to_info.sender,
141 origin_server_ts: replied_to_info.timestamp,
142 room_id: room_id.to_owned(),
143 content: replied_to_content,
144 unsigned: Default::default(),
145 };
146
147 match enforce_thread {
148 EnforceThread::Threaded(is_reply) => {
149 content.make_for_thread(&event, is_reply, mention_the_sender)
150 }
151 EnforceThread::MaybeThreaded => {
152 content.make_reply_to(&event, ForwardThread::Yes, mention_the_sender)
153 }
154 EnforceThread::Unthreaded => {
155 content.make_reply_to(&event, ForwardThread::No, mention_the_sender)
156 }
157 }
158 }
159
160 ReplyContent::Raw(raw_event) => {
161 match enforce_thread {
162 EnforceThread::Threaded(is_reply) => {
163 #[derive(Deserialize)]
168 struct ContentDeHelper {
169 #[serde(rename = "m.relates_to")]
170 relates_to: Option<EncryptedRelation>,
171 }
172
173 let previous_content =
174 raw_event.get_field::<ContentDeHelper>("content").ok().flatten();
175
176 let mut content = if is_reply == ReplyWithinThread::Yes {
177 content.make_reply_to_raw(
178 &raw_event,
179 replied_to_info.event_id.to_owned(),
180 room_id,
181 ForwardThread::No,
182 mention_the_sender,
183 )
184 } else {
185 content.into()
186 };
187
188 let thread_root = if let Some(EncryptedRelation::Thread(thread)) =
189 previous_content.as_ref().and_then(|c| c.relates_to.as_ref())
190 {
191 thread.event_id.to_owned()
192 } else {
193 replied_to_info.event_id.to_owned()
194 };
195
196 let thread = if is_reply == ReplyWithinThread::Yes {
197 Thread::reply(thread_root, replied_to_info.event_id)
198 } else {
199 Thread::plain(thread_root, replied_to_info.event_id)
200 };
201
202 content.relates_to = Some(Relation::Thread(thread));
203 content
204 }
205
206 EnforceThread::MaybeThreaded => content.make_reply_to_raw(
207 &raw_event,
208 replied_to_info.event_id,
209 room_id,
210 ForwardThread::Yes,
211 mention_the_sender,
212 ),
213
214 EnforceThread::Unthreaded => content.make_reply_to_raw(
215 &raw_event,
216 replied_to_info.event_id,
217 room_id,
218 ForwardThread::No,
219 mention_the_sender,
220 ),
221 }
222 }
223 };
224
225 Ok(content.into())
226}
227
228async fn replied_to_info_from_event_id<S: EventSource>(
229 source: S,
230 event_id: &EventId,
231) -> Result<RepliedToInfo, ReplyError> {
232 let event = source.get_event(event_id).await.map_err(|err| ReplyError::Fetch(Box::new(err)))?;
233
234 let raw_event = event.into_raw();
235 let event = raw_event.deserialize().map_err(|_| ReplyError::Deserialization)?;
236
237 let reply_content = match &event {
238 AnySyncTimelineEvent::MessageLike(event) => {
239 if let AnySyncMessageLikeEvent::RoomMessage(SyncMessageLikeEvent::Original(
240 original_event,
241 )) = event
242 {
243 ReplyContent::Message(original_event.content.clone())
244 } else {
245 ReplyContent::Raw(raw_event)
246 }
247 }
248 AnySyncTimelineEvent::State(_) => return Err(ReplyError::StateEvent),
249 };
250
251 Ok(RepliedToInfo {
252 event_id: event_id.to_owned(),
253 sender: event.sender().to_owned(),
254 timestamp: event.origin_server_ts(),
255 content: reply_content,
256 })
257}
258
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use assert_matches2::{assert_let, assert_matches};
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{
        event_id,
        events::{
            room::message::{Relation, ReplyWithinThread, RoomMessageEventContentWithoutRelation},
            AnyMessageLikeEventContent, AnySyncTimelineEvent,
        },
        room_id,
        serde::Raw,
        user_id, EventId, OwnedEventId, UserId,
    };
    use serde_json::json;

    use super::{make_reply_event, EnforceThread, EventSource, ReplyError};
    use crate::{event_cache::EventCacheError, Error};

    /// An in-memory [`EventSource`] serving a fixed set of events.
    #[derive(Default)]
    struct TestEventCache {
        events: BTreeMap<OwnedEventId, TimelineEvent>,
    }

    impl TestEventCache {
        /// Adds `event` to the cache under `event_id`, builder-style.
        fn with_event(mut self, event_id: &EventId, event: impl Into<TimelineEvent>) -> Self {
            self.events.insert(event_id.to_owned(), event.into());
            self
        }
    }

    impl EventSource for TestEventCache {
        async fn get_event(&self, event_id: &EventId) -> Result<TimelineEvent, Error> {
            // Any error will do for unknown events; the tests only check the
            // resulting `ReplyError::Fetch`.
            self.events
                .get(event_id)
                .cloned()
                .ok_or(Error::EventCache(EventCacheError::ClientDropped))
        }
    }

    /// The user writing the replies in every test.
    fn own_user_id() -> &'static UserId {
        user_id!("@me:saucisse.bzh")
    }

    /// Builds the reply "the reply" to `event_id` as [`own_user_id`], in a
    /// fixed room.
    async fn reply_to(
        cache: TestEventCache,
        event_id: &EventId,
        enforce_thread: EnforceThread,
    ) -> Result<AnyMessageLikeEventContent, ReplyError> {
        make_reply_event(
            cache,
            room_id!("!galette:saucisse.bzh"),
            own_user_id(),
            RoomMessageEventContentWithoutRelation::text_plain("the reply"),
            event_id,
            enforce_thread,
        )
        .await
    }

    /// A cache with a text message `thread_root` and a text message `event_id`
    /// in the thread rooted at `thread_root`.
    fn cache_with_thread(thread_root: &EventId, event_id: &EventId) -> TestEventCache {
        let f = EventFactory::new();
        TestEventCache::default()
            .with_event(thread_root, f.text_msg("hi").event_id(thread_root).sender(own_user_id()))
            .with_event(
                event_id,
                f.text_msg("ho")
                    .in_thread(thread_root, thread_root)
                    .event_id(event_id)
                    .sender(own_user_id()),
            )
    }

    /// An encrypted event `event_id` in the thread rooted at `thread_root`.
    ///
    /// It is not an `m.room.message`, so replying to it goes through the raw
    /// event path.
    fn encrypted_event_in_thread(thread_root: &EventId, event_id: &EventId) -> TimelineEvent {
        TimelineEvent::new(
            Raw::<AnySyncTimelineEvent>::from_json_string(
                json!({
                    "content": {
                        "algorithm": "m.megolm.v1.aes-sha2",
                        "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg",
                        "device_id": "DEVICEID",
                        "sender_key": "SENDERKEY",
                        "session_id": "SESSIONID",
                        "m.relates_to": {
                            "rel_type": "m.thread",
                            "event_id": thread_root,
                            "is_falling_back": true,
                            "m.in_reply_to": { "event_id": thread_root }
                        }
                    },
                    "event_id": event_id,
                    "origin_server_ts": 1,
                    "sender": "@other:saucisse.bzh",
                    "type": "m.room.encrypted"
                })
                .to_string(),
            )
            .unwrap(),
        )
    }

    #[async_test]
    async fn test_cannot_reply_to_unknown_event() {
        let event_id = event_id!("$1");

        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(own_user_id()));

        assert_matches!(
            reply_to(cache, event_id!("$2"), EnforceThread::Unthreaded).await,
            Err(ReplyError::Fetch(_))
        );
    }

    #[async_test]
    async fn test_cannot_reply_to_invalid_event() {
        let event_id = event_id!("$1");

        // This event lacks required fields such as `sender`.
        let cache = TestEventCache::default().with_event(
            event_id,
            TimelineEvent::new(
                Raw::<AnySyncTimelineEvent>::from_json_string(
                    json!({
                        "content": {
                            "body": "hi"
                        },
                        "event_id": event_id,
                        "origin_server_ts": 1,
                        "type": "m.room.message",
                    })
                    .to_string(),
                )
                .unwrap(),
            ),
        );

        assert_matches!(
            reply_to(cache, event_id, EnforceThread::Unthreaded).await,
            Err(ReplyError::Deserialization)
        );
    }

    #[async_test]
    async fn test_cannot_reply_to_state_event() {
        let event_id = event_id!("$1");

        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.room_name("lobby").event_id(event_id).sender(own_user_id()));

        assert_matches!(
            reply_to(cache, event_id, EnforceThread::Unthreaded).await,
            Err(ReplyError::StateEvent)
        );
    }

    #[async_test]
    async fn test_reply_unthreaded() {
        let event_id = event_id!("$1");

        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(own_user_id()));

        let reply_event = reply_to(cache, event_id, EnforceThread::Unthreaded).await.unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Reply { in_reply_to }) = &msg.relates_to);

        assert_eq!(in_reply_to.event_id, event_id);
    }

    #[async_test]
    async fn test_start_thread() {
        let event_id = event_id!("$1");

        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(own_user_id()));

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::No))
                .await
                .unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Thread(thread)) = &msg.relates_to);

        assert_eq!(thread.event_id, event_id);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");

        let cache = cache_with_thread(thread_root, event_id);

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::No))
                .await
                .unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Thread(thread)) = &msg.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread_as_reply() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");

        let cache = cache_with_thread(thread_root, event_id);

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::Yes))
                .await
                .unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Thread(thread)) = &msg.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(!thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_forwarding_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");

        let cache = cache_with_thread(thread_root, event_id);

        let reply_event = reply_to(cache, event_id, EnforceThread::MaybeThreaded).await.unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Thread(thread)) = &msg.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread_raw_event() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");

        let cache = TestEventCache::default()
            .with_event(event_id, encrypted_event_in_thread(thread_root, event_id));

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::No))
                .await
                .unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Thread(thread)) = &msg.relates_to);

        // The thread root must be taken from the raw event's own relation.
        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread_raw_event_as_reply() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");

        let cache = TestEventCache::default()
            .with_event(event_id, encrypted_event_in_thread(thread_root, event_id));

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::Yes))
                .await
                .unwrap();

        assert_let!(AnyMessageLikeEventContent::RoomMessage(msg) = &reply_event);
        assert_let!(Some(Relation::Thread(thread)) = &msg.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(!thread.is_falling_back);
    }
}