1use as_variant::as_variant;
18use ruma::{
19 events::{
20 room::{
21 encrypted::Relation as EncryptedRelation,
22 message::{
23 AddMentions, ForwardThread, ReplyMetadata, ReplyWithinThread,
24 RoomMessageEventContent, RoomMessageEventContentWithoutRelation,
25 },
26 },
27 AnySyncTimelineEvent,
28 },
29 OwnedEventId, UserId,
30};
31use thiserror::Error;
32use tracing::{error, instrument};
33
34use super::{EventSource, Room};
35
36#[derive(Debug)]
38pub struct Reply {
39 pub event_id: OwnedEventId,
41 pub enforce_thread: EnforceThread,
43}
44
45#[derive(Debug, Error)]
47pub enum ReplyError {
48 #[error("Couldn't fetch the remote event: {0}")]
50 Fetch(Box<crate::Error>),
51 #[error("failed to deserialize event to reply to")]
53 Deserialization,
54 #[error("tried to reply to a state event")]
56 StateEvent,
57}
58
59#[derive(Clone, Copy, Debug, PartialEq, Eq)]
63pub enum EnforceThread {
64 Threaded(ReplyWithinThread),
67
68 MaybeThreaded,
71
72 Unthreaded,
75}
76
77impl Room {
78 #[instrument(skip(self, content), fields(room = %self.room_id()))]
90 pub async fn make_reply_event(
91 &self,
92 content: RoomMessageEventContentWithoutRelation,
93 reply: Reply,
94 ) -> Result<RoomMessageEventContent, ReplyError> {
95 make_reply_event(self, self.own_user_id(), content, reply).await
96 }
97}
98
99async fn make_reply_event<S: EventSource>(
100 source: S,
101 own_user_id: &UserId,
102 content: RoomMessageEventContentWithoutRelation,
103 reply: Reply,
104) -> Result<RoomMessageEventContent, ReplyError> {
105 let event =
106 source.get_event(&reply.event_id).await.map_err(|err| ReplyError::Fetch(Box::new(err)))?;
107
108 let raw_event = event.into_raw();
109 let event = raw_event.deserialize().map_err(|_| ReplyError::Deserialization)?;
110
111 let relation = as_variant!(&event, AnySyncTimelineEvent::MessageLike)
112 .ok_or(ReplyError::StateEvent)?
113 .original_content()
114 .and_then(|content| content.relation());
115 let thread =
116 relation.as_ref().and_then(|relation| as_variant!(relation, EncryptedRelation::Thread));
117
118 let reply_metadata = ReplyMetadata::new(event.event_id(), event.sender(), thread);
119
120 let mention_the_sender =
128 if own_user_id == event.sender() { AddMentions::No } else { AddMentions::Yes };
129
130 let content = match reply.enforce_thread {
131 EnforceThread::Threaded(is_reply) => {
132 content.make_for_thread(reply_metadata, is_reply, mention_the_sender)
133 }
134 EnforceThread::MaybeThreaded => {
135 content.make_reply_to(reply_metadata, ForwardThread::Yes, mention_the_sender)
136 }
137 EnforceThread::Unthreaded => {
138 content.make_reply_to(reply_metadata, ForwardThread::No, mention_the_sender)
139 }
140 };
141
142 Ok(content)
143}
144
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use assert_matches2::{assert_let, assert_matches};
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{
        event_id,
        events::{
            room::message::{
                Relation, ReplyWithinThread, RoomMessageEventContent,
                RoomMessageEventContentWithoutRelation,
            },
            AnySyncTimelineEvent,
        },
        serde::Raw,
        user_id, EventId, OwnedEventId, UserId,
    };
    use serde_json::json;

    use super::{make_reply_event, EnforceThread, EventSource, Reply, ReplyError};
    use crate::{event_cache::EventCacheError, Error};

    /// In-memory [`EventSource`] backed by a map from event ID to event.
    #[derive(Default)]
    struct TestEventCache {
        events: BTreeMap<OwnedEventId, TimelineEvent>,
    }

    impl TestEventCache {
        /// Returns the cache with `event` stored under `event_id`.
        fn with_event(mut self, event_id: &EventId, event: TimelineEvent) -> Self {
            self.events.insert(event_id.to_owned(), event);
            self
        }
    }

    impl EventSource for TestEventCache {
        async fn get_event(&self, event_id: &EventId) -> Result<TimelineEvent, Error> {
            match self.events.get(event_id) {
                Some(event) => Ok(event.clone()),
                // Any error will do here: the tests only check that a lookup
                // failure is surfaced as `ReplyError::Fetch`.
                None => Err(Error::EventCache(Box::new(EventCacheError::ClientDropped))),
            }
        }
    }

    /// The reply body shared by all the tests.
    fn reply_content() -> RoomMessageEventContentWithoutRelation {
        RoomMessageEventContentWithoutRelation::text_plain("the reply")
    }

    /// Builds a [`Reply`] targeting `event_id` with the given thread policy.
    fn reply_to(event_id: &EventId, enforce_thread: EnforceThread) -> Reply {
        Reply { event_id: event_id.to_owned(), enforce_thread }
    }

    /// A cache holding a single, unthreaded text message sent by `sender`.
    fn single_message_cache(event_id: &EventId, sender: &UserId) -> TestEventCache {
        let f = EventFactory::new();

        TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(sender).into())
    }

    /// A cache holding a thread root and one message inside that thread, both
    /// sent by `sender`.
    fn thread_cache(thread_root: &EventId, event_id: &EventId, sender: &UserId) -> TestEventCache {
        let f = EventFactory::new();

        TestEventCache::default()
            .with_event(thread_root, f.text_msg("hi").event_id(thread_root).sender(sender).into())
            .with_event(
                event_id,
                f.text_msg("ho")
                    .in_thread(thread_root, thread_root)
                    .event_id(event_id)
                    .sender(sender)
                    .into(),
            )
    }

    /// Asserts that `reply_event` carries a thread relation rooted at
    /// `thread_root`, replying to `replied_to`, with the given fallback flag.
    #[track_caller]
    fn assert_thread_relation(
        reply_event: &RoomMessageEventContent,
        thread_root: &EventId,
        replied_to: &EventId,
        is_falling_back: bool,
    ) {
        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);

        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, replied_to);
        assert_eq!(thread.is_falling_back, is_falling_back);
    }

    #[async_test]
    async fn test_cannot_reply_to_unknown_event() {
        let known_event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");
        let cache = single_message_cache(known_event_id, own_user_id);

        // `$2` was never stored in the cache.
        let result = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id!("$2"), EnforceThread::Unthreaded),
        )
        .await;

        assert_matches!(result, Err(ReplyError::Fetch(_)));
    }

    #[async_test]
    async fn test_cannot_reply_to_invalid_event() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        // Note that this event has no `sender` field; it is expected to fail
        // deserialization.
        let raw_event = Raw::<AnySyncTimelineEvent>::from_json_string(
            json!({
                "content": {
                    "body": "hi"
                },
                "event_id": event_id,
                "origin_server_ts": 1,
                "type": "m.room.message",
            })
            .to_string(),
        )
        .unwrap();

        let cache = TestEventCache::default()
            .with_event(event_id, TimelineEvent::from_plaintext(raw_event));

        let result = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded),
        )
        .await;

        assert_matches!(result, Err(ReplyError::Deserialization));
    }

    #[async_test]
    async fn test_cannot_reply_to_state_event() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");

        let f = EventFactory::new();
        let cache = TestEventCache::default().with_event(
            event_id,
            f.room_name("lobby").event_id(event_id).sender(own_user_id).into(),
        );

        let result = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded),
        )
        .await;

        assert_matches!(result, Err(ReplyError::StateEvent));
    }

    #[async_test]
    async fn test_reply_unthreaded() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");
        let cache = single_message_cache(event_id, own_user_id);

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Unthreaded),
        )
        .await
        .unwrap();

        assert_let!(Some(Relation::Reply { in_reply_to }) = &reply_event.relates_to);

        assert_eq!(in_reply_to.event_id, event_id);
    }

    #[async_test]
    async fn test_start_thread() {
        let event_id = event_id!("$1");
        let own_user_id = user_id!("@me:saucisse.bzh");
        let cache = single_message_cache(event_id, own_user_id);

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::No)),
        )
        .await
        .unwrap();

        // The replied-to event has no thread relation, so it becomes the root
        // of the thread.
        assert_thread_relation(&reply_event, event_id, event_id, true);
    }

    #[async_test]
    async fn test_reply_on_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let own_user_id = user_id!("@me:saucisse.bzh");
        let cache = thread_cache(thread_root, event_id, own_user_id);

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::No)),
        )
        .await
        .unwrap();

        assert_thread_relation(&reply_event, thread_root, event_id, true);
    }

    #[async_test]
    async fn test_reply_on_thread_as_reply() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let own_user_id = user_id!("@me:saucisse.bzh");
        let cache = thread_cache(thread_root, event_id, own_user_id);

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::Yes)),
        )
        .await
        .unwrap();

        // An explicit reply within the thread is not a fallback.
        assert_thread_relation(&reply_event, thread_root, event_id, false);
    }

    #[async_test]
    async fn test_reply_forwarding_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let own_user_id = user_id!("@me:saucisse.bzh");
        let cache = thread_cache(thread_root, event_id, own_user_id);

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::MaybeThreaded),
        )
        .await
        .unwrap();

        assert_thread_relation(&reply_event, thread_root, event_id, true);
    }

    #[async_test]
    async fn test_reply_forwarding_thread_for_poll_start() {
        let thread_root = event_id!("$thread_root");
        let event_id = event_id!("$thread_reply");
        let own_user_id = user_id!("@me:saucisse.bzh");

        // Only the in-thread poll is stored; the thread root itself is not.
        let f = EventFactory::new();
        let cache = TestEventCache::default().with_event(
            event_id,
            f.poll_start(
                "would you rather… A) eat a pineapple pizza, B) drink pickle juice",
                "would you rather…",
                vec!["eat a pineapple pizza", "drink pickle juice"],
            )
            .in_thread(thread_root, thread_root)
            .event_id(event_id)
            .sender(own_user_id)
            .into(),
        );

        let reply_event = make_reply_event(
            cache,
            own_user_id,
            reply_content(),
            reply_to(event_id, EnforceThread::Threaded(ReplyWithinThread::No)),
        )
        .await
        .unwrap();

        assert_thread_relation(&reply_event, thread_root, event_id, true);
    }
}