1use std::collections::{BTreeMap, BTreeSet};
19
20use as_variant::as_variant;
21use matrix_sdk_base::{
22 crypto::CollectStrategy,
23 deserialized_responses::{EncryptionInfo, RawAnySyncOrStrippedState},
24 sync::State,
25};
26use ruma::{
27 EventId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId,
28 api::client::{
29 account::request_openid_token::v3::{Request as OpenIdRequest, Response as OpenIdResponse},
30 delayed_events::{self, update_delayed_event::unstable::UpdateAction},
31 filter::RoomEventFilter,
32 to_device::send_event_to_device::v3::Request as RumaToDeviceRequest,
33 },
34 assign,
35 events::{
36 AnyMessageLikeEventContent, AnyStateEvent, AnyStateEventContent, AnySyncStateEvent,
37 AnySyncTimelineEvent, AnyTimelineEvent, AnyToDeviceEvent, AnyToDeviceEventContent,
38 MessageLikeEventType, StateEventType, TimelineEventType, ToDeviceEventType,
39 },
40 serde::{Raw, from_raw_json_value},
41 to_device::DeviceIdOrAllDevices,
42};
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, value::RawValue as RawJsonValue};
45use tokio::sync::{
46 broadcast::{Receiver, error::RecvError},
47 mpsc::{UnboundedReceiver, unbounded_channel},
48};
49use tracing::{error, trace, warn};
50
51use super::{StateKeySelector, machine::SendEventResponse};
52use crate::{
53 Client, Error, Result, Room, event_handler::EventHandlerDropGuard, room::MessagesOptions,
54 sync::RoomUpdate, widget::machine::SendToDeviceEventResponse,
55};
56
57pub(crate) struct MatrixDriver {
60 room: Room,
61}
62
63impl MatrixDriver {
64 pub(crate) fn new(room: Room) -> Self {
66 Self { room }
67 }
68
69 pub(crate) async fn get_open_id(&self) -> Result<OpenIdResponse> {
71 let user_id = self.room.own_user_id().to_owned();
72 self.room
73 .client
74 .send(OpenIdRequest::new(user_id))
75 .await
76 .map_err(|error| Error::Http(Box::new(error)))
77 }
78
79 pub(crate) async fn read_events(
82 &self,
83 event_type: TimelineEventType,
84 state_key: Option<StateKeySelector>,
85 limit: u32,
86 ) -> Result<Vec<Raw<AnyTimelineEvent>>> {
87 let options = assign!(MessagesOptions::backward(), {
88 limit: limit.into(),
89 filter: assign!(RoomEventFilter::default(), {
90 types: Some(vec![event_type.to_string()])
91 }),
92 });
93
94 let messages = self.room.messages(options).await?;
95
96 Ok(messages
97 .chunk
98 .into_iter()
99 .map(|ev| ev.into_raw().cast_unchecked())
100 .filter(|ev| match &state_key {
101 Some(state_key) => {
102 ev.get_field::<String>("state_key").is_ok_and(|key| match state_key {
103 StateKeySelector::Key(state_key) => {
104 key.is_some_and(|key| &key == state_key)
105 }
106 StateKeySelector::Any => key.is_some(),
107 })
108 }
109 None => true,
110 })
111 .collect())
112 }
113
114 pub(crate) async fn read_state(
117 &self,
118 event_type: StateEventType,
119 state_key: &StateKeySelector,
120 ) -> Result<Vec<Raw<AnyStateEvent>>> {
121 let room_id = self.room.room_id();
122 let convert = |sync_or_stripped_state| match sync_or_stripped_state {
123 RawAnySyncOrStrippedState::Sync(ev) => Some(attach_room_id_state(&ev, room_id)),
124 RawAnySyncOrStrippedState::Stripped(_) => {
125 error!("MatrixDriver can't operate in invited rooms");
126 None
127 }
128 };
129
130 let events = match state_key {
131 StateKeySelector::Key(state_key) => self
132 .room
133 .get_state_event(event_type, state_key)
134 .await?
135 .and_then(convert)
136 .into_iter()
137 .collect(),
138 StateKeySelector::Any => {
139 let events = self.room.get_state_events(event_type).await?;
140 events.into_iter().filter_map(convert).collect()
141 }
142 };
143
144 Ok(events)
145 }
146
147 pub(crate) async fn send(
153 &self,
154 event_type: TimelineEventType,
155 state_key: Option<String>,
156 content: Box<RawJsonValue>,
157 delayed_event_parameters: Option<delayed_events::DelayParameters>,
158 ) -> Result<SendEventResponse> {
159 let type_str = event_type.to_string();
160
161 if let Some(redacts) = from_raw_json_value::<Value, serde_json::Error>(&content)
162 .ok()
163 .and_then(|b| b["redacts"].as_str().and_then(|s| EventId::parse(s).ok()))
164 {
165 return Ok(SendEventResponse::from_event_id(
166 self.room.redact(&redacts, None, None).await?.event_id,
167 ));
168 }
169
170 Ok(match (state_key, delayed_event_parameters) {
171 (None, None) => SendEventResponse::from_event_id(
172 self.room.send_raw(&type_str, content).await?.response.event_id,
173 ),
174
175 (Some(key), None) => SendEventResponse::from_event_id(
176 self.room.send_state_event_raw(&type_str, &key, content).await?.event_id,
177 ),
178
179 (None, Some(delayed_event_parameters)) => {
180 let r = delayed_events::delayed_message_event::unstable::Request::new_raw(
181 self.room.room_id().to_owned(),
182 TransactionId::new(),
183 MessageLikeEventType::from(type_str),
184 delayed_event_parameters,
185 Raw::<AnyMessageLikeEventContent>::from_json(content),
186 );
187 self.room.client.send(r).await.map(|r| r.into())?
188 }
189
190 (Some(key), Some(delayed_event_parameters)) => {
191 let r = delayed_events::delayed_state_event::unstable::Request::new_raw(
192 self.room.room_id().to_owned(),
193 key,
194 StateEventType::from(type_str),
195 delayed_event_parameters,
196 Raw::<AnyStateEventContent>::from_json(content),
197 );
198 self.room.client.send(r).await.map(|r| r.into())?
199 }
200 })
201 }
202
203 pub(crate) async fn update_delayed_event(
208 &self,
209 delay_id: String,
210 action: UpdateAction,
211 ) -> Result<delayed_events::update_delayed_event::unstable::Response> {
212 let r = delayed_events::update_delayed_event::unstable::Request::new(delay_id, action);
213 self.room.client.send(r).await.map_err(|error| Error::Http(Box::new(error)))
214 }
215
216 pub(crate) fn events(&self) -> EventReceiver<Raw<AnyTimelineEvent>> {
219 let (tx, rx) = unbounded_channel();
220 let room_id = self.room.room_id().to_owned();
221
222 let handle = self.room.add_event_handler(move |raw: Raw<AnySyncTimelineEvent>| {
223 let _ = tx.send(attach_room_id(&raw, &room_id));
224 async {}
225 });
226 let drop_guard = self.room.client().event_handler_drop_guard(handle);
227
228 EventReceiver { rx, _drop_guard: drop_guard }
232 }
233
234 pub(crate) fn state_updates(&self) -> StateUpdateReceiver {
236 StateUpdateReceiver { room_updates: self.room.subscribe_to_updates() }
237 }
238
239 pub(crate) fn to_device_events(&self) -> EventReceiver<Raw<AnyToDeviceEvent>> {
242 let (tx, rx) = unbounded_channel();
243
244 let room_id = self.room.room_id().to_owned();
245 let to_device_handle = self.room.client().add_event_handler(
246
247 async move |raw: Raw<AnyToDeviceEvent>, encryption_info: Option<EncryptionInfo>, client: Client| {
248
249 if Self::should_filter_message_to_widget(&raw) {
252 return;
253 }
254
255 let Some(room) = client.get_room(&room_id) else {
258 warn!("Room {room_id} not found in client.");
259 return;
260 };
261
262 let room_encrypted = room.latest_encryption_state().await
263 .map(|s| s.is_encrypted())
264 .unwrap_or(true);
266 if room_encrypted {
267 if encryption_info.is_none() {
269 warn!(
270 ?room_id,
271 "Received to-device event in clear for a widget in an e2e room, dropping."
272 );
273 return;
274 }
275
276 #[derive(Deserialize, Serialize)]
282 struct CleanEventHelper<'a> {
283 #[serde(rename = "type")]
284 event_type: String,
285 #[serde(borrow)]
286 content: &'a RawJsonValue,
287 sender: String,
288 }
289
290 let _ = serde_json::from_str::<CleanEventHelper<'_>>(raw.json().get())
291 .and_then(|clean_event_helper| {
292 serde_json::value::to_raw_value(&clean_event_helper)
293 })
294 .map_err(|err| warn!(?room_id, "Unable to process to-device message for widget: {err}"))
295 .map(|box_value | {
296 tx.send(Raw::from_json(box_value))
297 });
298
299 } else {
300 let _ = tx.send(raw);
303 }
304 },
305 );
306
307 let drop_guard = self.room.client().event_handler_drop_guard(to_device_handle);
308 EventReceiver { rx, _drop_guard: drop_guard }
309 }
310
311 fn should_filter_message_to_widget(raw_message: &Raw<AnyToDeviceEvent>) -> bool {
312 let Ok(Some(event_type)) = raw_message.get_field::<String>("type") else {
313 trace!("Invalid to-device message (no type) filtered out by widget driver.");
314 return true;
315 };
316
317 let filtered = Self::is_internal_type(event_type.as_str());
321
322 if filtered {
323 trace!("To-device message of type <{event_type}> filtered out by widget driver.",);
324 }
325 filtered
326 }
327
/// Returns `true` if `event_type` is one of the to-device event types this
/// driver treats as internal.
///
/// The comparison is an exact, case-sensitive string match.
fn is_internal_type(event_type: &str) -> bool {
    const INTERNAL_TYPES: &[&str] = &[
        "m.dummy",
        "m.room_key",
        "m.room_key_request",
        "m.forwarded_room_key",
        "m.key.verification.request",
        "m.key.verification.ready",
        "m.key.verification.start",
        "m.key.verification.cancel",
        "m.key.verification.accept",
        "m.key.verification.key",
        "m.key.verification.mac",
        "m.key.verification.done",
        "m.secret.request",
        "m.secret.send",
        "m.room.encrypted",
    ];

    INTERNAL_TYPES.contains(&event_type)
}
349
350 pub(crate) async fn send_to_device(
354 &self,
355 event_type: ToDeviceEventType,
356 messages: BTreeMap<
357 OwnedUserId,
358 BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>,
359 >,
360 ) -> Result<SendToDeviceEventResponse> {
361 if Self::is_internal_type(&event_type.to_string()) {
364 warn!("Widget tried to send internal to-device message <{}>, ignoring", event_type);
365 return Ok(Default::default());
367 }
368
369 let client = self.room.client();
370
371 let mut failures: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>> = BTreeMap::new();
372
373 let room_encrypted = self
374 .room
375 .latest_encryption_state()
376 .await
377 .map(|s| s.is_encrypted())
378 .unwrap_or(true);
380
381 if room_encrypted {
382 trace!("Sending to-device message in encrypted room <{}>", self.room.room_id());
383
384 let mut content_to_recipients_map: BTreeMap<
389 &str,
390 BTreeMap<OwnedUserId, Vec<DeviceIdOrAllDevices>>,
391 > = BTreeMap::new();
392
393 for (user_id, device_map) in messages.iter() {
394 for (device_id, content) in device_map.iter() {
395 content_to_recipients_map
396 .entry(content.json().get())
397 .or_default()
398 .entry(user_id.clone())
399 .or_default()
400 .push(device_id.to_owned());
401 }
402 }
403
404 for (content, user_to_list_of_device_id_or_all) in content_to_recipients_map {
406 self.encrypt_and_send_content_to_devices_helper(
407 &event_type,
408 content,
409 user_to_list_of_device_id_or_all,
410 &mut failures,
411 )
412 .await?
413 }
414
415 let failures = failures
416 .into_iter()
417 .map(|(u, list_of_devices)| {
418 (u.into(), list_of_devices.into_iter().map(|d| d.into()).collect())
419 })
420 .collect();
421
422 let response = SendToDeviceEventResponse { failures };
423 Ok(response)
424 } else {
425 let request = RumaToDeviceRequest::new_raw(event_type, TransactionId::new(), messages);
427 client.send(request).await?;
428 Ok(Default::default())
429 }
430 }
431
432 async fn encrypt_and_send_content_to_devices_helper(
434 &self,
435 event_type: &ToDeviceEventType,
436 content: &str,
437 user_to_list_of_device_id_or_all: BTreeMap<OwnedUserId, Vec<DeviceIdOrAllDevices>>,
438 failures: &mut BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
439 ) -> Result<()> {
440 let client = self.room.client();
441 let mut recipient_devices = Vec::<_>::new();
442
443 for (user_id, recipient_device_ids) in user_to_list_of_device_id_or_all {
444 let user_devices = client.encryption().get_user_devices(&user_id).await?;
445
446 let user_devices = if recipient_device_ids.contains(&DeviceIdOrAllDevices::AllDevices) {
447 let devices: Vec<_> = user_devices.devices().collect();
450 if devices.is_empty() {
452 warn!(
453 "Recipient list contains `AllDevices` but no devices found for user {user_id}."
454 )
455 }
456 if recipient_device_ids.len() > 1 {
459 warn!(
460 "The recipient_device_ids list for {user_id} contains both `AllDevices` and explicit `DeviceId` entries. Only consider `AllDevices`",
461 );
462 }
463 devices
464 } else {
465 let filtered_devices = user_devices
468 .devices()
469 .map(|device| (device.device_id().to_owned(), device))
470 .filter(|(device_id, _)| {
471 recipient_device_ids
472 .contains(&DeviceIdOrAllDevices::DeviceId(device_id.clone()))
473 });
474
475 let (found_device_ids, devices): (BTreeSet<_>, Vec<_>) = filtered_devices.unzip();
476
477 let list_of_devices: BTreeSet<_> = recipient_device_ids
478 .into_iter()
479 .filter_map(|d| as_variant!(d, DeviceIdOrAllDevices::DeviceId))
480 .collect();
481
482 let missing_devices: Vec<_> =
485 list_of_devices.difference(&found_device_ids).map(|d| d.to_owned()).collect();
486 if !missing_devices.is_empty() {
487 failures.insert(user_id, missing_devices);
488 }
489 devices
490 };
491
492 recipient_devices.extend(user_devices);
493 }
494
495 if !recipient_devices.is_empty() {
496 let encrypt_and_send_failures = client
498 .encryption()
499 .encrypt_and_send_raw_to_device(
500 recipient_devices.iter().collect(),
501 &event_type.to_string(),
502 Raw::from_json_string(content.to_owned())?,
503 CollectStrategy::AllDevices,
504 )
505 .await?;
506
507 for (user_id, device_id) in encrypt_and_send_failures {
508 failures.entry(user_id).or_default().push(device_id)
509 }
510 }
511
512 Ok(())
513 }
514}
515
516pub(crate) struct EventReceiver<E> {
519 rx: UnboundedReceiver<E>,
520 _drop_guard: EventHandlerDropGuard,
521}
522
523impl<T> EventReceiver<T> {
524 pub(crate) async fn recv(&mut self) -> Option<T> {
525 self.rx.recv().await
526 }
527}
528
529pub(crate) struct StateUpdateReceiver {
532 room_updates: Receiver<RoomUpdate>,
533}
534
535impl StateUpdateReceiver {
536 pub(crate) async fn recv(&mut self) -> Result<Vec<Raw<AnyStateEvent>>, RecvError> {
537 loop {
538 match self.room_updates.recv().await? {
539 RoomUpdate::Joined { room, updates } => {
540 let state_events = match updates.state {
541 State::Before(events) => events,
542 State::After(events) => events,
543 };
544
545 if !state_events.is_empty() {
546 return Ok(state_events
547 .into_iter()
548 .map(|ev| attach_room_id_state(&ev, room.room_id()))
549 .collect());
550 }
551 }
552 _ => {
553 error!("MatrixDriver can only operate in joined rooms");
554 return Err(RecvError::Closed);
555 }
556 }
557 }
558 }
559}
560
561fn attach_room_id(raw_ev: &Raw<AnySyncTimelineEvent>, room_id: &RoomId) -> Raw<AnyTimelineEvent> {
562 let mut ev_obj =
563 raw_ev.deserialize_as_unchecked::<BTreeMap<String, Box<RawJsonValue>>>().unwrap();
564 ev_obj.insert("room_id".to_owned(), serde_json::value::to_raw_value(room_id).unwrap());
565 Raw::new(&ev_obj).unwrap().cast_unchecked()
566}
567
568fn attach_room_id_state(raw_ev: &Raw<AnySyncStateEvent>, room_id: &RoomId) -> Raw<AnyStateEvent> {
569 attach_room_id(raw_ev.cast_ref(), room_id).cast_unchecked()
570}
571
#[cfg(test)]
mod tests {
    use insta;
    use ruma::{events::AnyTimelineEvent, room_id, serde::Raw};
    use serde_json::{Value, json};

    use super::attach_room_id;

    /// An event without a `room_id` gets the given one attached and then
    /// deserializes as a full timeline event for that room.
    #[test]
    fn test_add_room_id_to_raw() {
        let room_id = room_id!("!my_id:example.org");
        let sync_event = Raw::new(&json!({
            "type": "m.room.message",
            "event_id": "$1676512345:example.org",
            "sender": "@user:example.org",
            "origin_server_ts": 1676512345,
            "content": {
                "msgtype": "m.text",
                "body": "Hello world"
            }
        }))
        .unwrap()
        .cast_unchecked();

        let new = attach_room_id(&sync_event, room_id);

        insta::with_settings!({prepend_module_to_snapshot => false}, {
            insta::assert_json_snapshot!(new.deserialize_as::<Value>().unwrap())
        });

        let timeline_event: AnyTimelineEvent = new.deserialize().unwrap();
        assert_eq!(timeline_event.room_id(), room_id);
    }

    /// An event that already has a different `room_id` ends up with the
    /// given one instead.
    #[test]
    fn test_add_room_id_to_raw_override() {
        let room_id = room_id!("!my_id:example.org");
        let sync_event = Raw::new(&json!({
            "type": "m.room.message",
            "event_id": "$1676512345:example.org",
            "room_id": "!override_me:example.org",
            "sender": "@user:example.org",
            "origin_server_ts": 1676512345,
            "content": {
                "msgtype": "m.text",
                "body": "Hello world"
            }
        }))
        .unwrap()
        .cast_unchecked();

        let new = attach_room_id(&sync_event, room_id);

        insta::with_settings!({prepend_module_to_snapshot => false}, {
            insta::assert_json_snapshot!(new.deserialize_as::<Value>().unwrap())
        });

        let timeline_event: AnyTimelineEvent = new.deserialize().unwrap();
        assert_eq!(timeline_event.room_id(), room_id);
    }
}