1use std::collections::BTreeMap;
19
20use matrix_sdk_base::deserialized_responses::RawAnySyncOrStrippedState;
21use ruma::{
22 api::client::{
23 account::request_openid_token::v3::{Request as OpenIdRequest, Response as OpenIdResponse},
24 delayed_events::{self, update_delayed_event::unstable::UpdateAction},
25 filter::RoomEventFilter,
26 },
27 assign,
28 events::{
29 AnyMessageLikeEventContent, AnyStateEventContent, AnySyncMessageLikeEvent,
30 AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, MessageLikeEventType,
31 StateEventType, TimelineEventType,
32 },
33 serde::{from_raw_json_value, Raw},
34 EventId, RoomId, TransactionId,
35};
36use serde_json::{value::RawValue as RawJsonValue, Value};
37use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
38use tracing::error;
39
40use super::{machine::SendEventResponse, StateKeySelector};
41use crate::{event_handler::EventHandlerDropGuard, room::MessagesOptions, Error, Result, Room};
42
/// Thin wrapper around a [`Room`] that performs the Matrix operations (reading
/// events, sending events, OpenID token requests, delayed-event updates) needed
/// by the parent module's state machine.
pub(crate) struct MatrixDriver {
    /// The room every operation of this driver is performed in.
    room: Room,
}
48
49impl MatrixDriver {
50 pub(crate) fn new(room: Room) -> Self {
52 Self { room }
53 }
54
55 pub(crate) async fn get_open_id(&self) -> Result<OpenIdResponse> {
57 let user_id = self.room.own_user_id().to_owned();
58 self.room.client.send(OpenIdRequest::new(user_id)).await.map_err(Error::Http)
59 }
60
61 pub(crate) async fn read_message_like_events(
63 &self,
64 event_type: MessageLikeEventType,
65 limit: u32,
66 ) -> Result<Vec<Raw<AnyTimelineEvent>>> {
67 let options = assign!(MessagesOptions::backward(), {
68 limit: limit.into(),
69 filter: assign!(RoomEventFilter::default(), {
70 types: Some(vec![event_type.to_string()])
71 }),
72 });
73
74 let messages = self.room.messages(options).await?;
75 Ok(messages.chunk.into_iter().map(|ev| ev.into_raw().cast()).collect())
76 }
77
78 pub(crate) async fn read_state_events(
79 &self,
80 event_type: StateEventType,
81 state_key: &StateKeySelector,
82 ) -> Result<Vec<Raw<AnyTimelineEvent>>> {
83 let room_id = self.room.room_id();
84 let convert = |sync_or_stripped_state| match sync_or_stripped_state {
85 RawAnySyncOrStrippedState::Sync(ev) => Some(attach_room_id(ev.cast_ref(), room_id)),
86 RawAnySyncOrStrippedState::Stripped(_) => {
87 error!("MatrixDriver can't operate in invited rooms");
88 None
89 }
90 };
91
92 let events = match state_key {
93 StateKeySelector::Key(state_key) => self
94 .room
95 .get_state_event(event_type, state_key)
96 .await?
97 .and_then(convert)
98 .into_iter()
99 .collect(),
100 StateKeySelector::Any => {
101 let events = self.room.get_state_events(event_type).await?;
102 events.into_iter().filter_map(convert).collect()
103 }
104 };
105
106 Ok(events)
107 }
108
109 pub(crate) async fn send(
115 &self,
116 event_type: TimelineEventType,
117 state_key: Option<String>,
118 content: Box<RawJsonValue>,
119 delayed_event_parameters: Option<delayed_events::DelayParameters>,
120 ) -> Result<SendEventResponse> {
121 let type_str = event_type.to_string();
122
123 if let Some(redacts) = from_raw_json_value::<Value, serde_json::Error>(&content)
124 .ok()
125 .and_then(|b| b["redacts"].as_str().and_then(|s| EventId::parse(s).ok()))
126 {
127 return Ok(SendEventResponse::from_event_id(
128 self.room.redact(&redacts, None, None).await?.event_id,
129 ));
130 }
131
132 Ok(match (state_key, delayed_event_parameters) {
133 (None, None) => SendEventResponse::from_event_id(
134 self.room.send_raw(&type_str, content).await?.event_id,
135 ),
136
137 (Some(key), None) => SendEventResponse::from_event_id(
138 self.room.send_state_event_raw(&type_str, &key, content).await?.event_id,
139 ),
140
141 (None, Some(delayed_event_parameters)) => {
142 let r = delayed_events::delayed_message_event::unstable::Request::new_raw(
143 self.room.room_id().to_owned(),
144 TransactionId::new(),
145 MessageLikeEventType::from(type_str),
146 delayed_event_parameters,
147 Raw::<AnyMessageLikeEventContent>::from_json(content),
148 );
149 self.room.client.send(r).await.map(|r| r.into())?
150 }
151
152 (Some(key), Some(delayed_event_parameters)) => {
153 let r = delayed_events::delayed_state_event::unstable::Request::new_raw(
154 self.room.room_id().to_owned(),
155 key,
156 StateEventType::from(type_str),
157 delayed_event_parameters,
158 Raw::<AnyStateEventContent>::from_json(content),
159 );
160 self.room.client.send(r).await.map(|r| r.into())?
161 }
162 })
163 }
164
165 pub(crate) async fn update_delayed_event(
170 &self,
171 delay_id: String,
172 action: UpdateAction,
173 ) -> Result<delayed_events::update_delayed_event::unstable::Response> {
174 let r = delayed_events::update_delayed_event::unstable::Request::new(delay_id, action);
175 self.room.client.send(r).await.map_err(Error::Http)
176 }
177
178 pub(crate) fn events(&self) -> EventReceiver {
181 let (tx, rx) = unbounded_channel();
182 let room_id = self.room.room_id().to_owned();
183
184 let _tx = tx.clone();
186 let _room_id = room_id.clone();
187 let handle_msg_like =
188 self.room.add_event_handler(move |raw: Raw<AnySyncMessageLikeEvent>| {
189 let _ = _tx.send(attach_room_id(raw.cast_ref(), &_room_id));
190 async {}
191 });
192 let drop_guard_msg_like = self.room.client().event_handler_drop_guard(handle_msg_like);
193
194 let handle_state = self.room.add_event_handler(move |raw: Raw<AnySyncStateEvent>| {
196 let _ = tx.send(attach_room_id(raw.cast_ref(), &room_id));
197 async {}
198 });
199 let drop_guard_state = self.room.client().event_handler_drop_guard(handle_state);
200
201 EventReceiver { rx, _drop_guards: [drop_guard_msg_like, drop_guard_state] }
208 }
209}
210
/// Receiving end for the timeline events forwarded by the handlers registered in
/// [`MatrixDriver::events`], each with its `room_id` attached.
pub(crate) struct EventReceiver {
    /// Channel fed by the message-like and state event handlers.
    rx: UnboundedReceiver<Raw<AnyTimelineEvent>>,
    /// Guards for the two registered handlers. They are kept only so that they live
    /// as long as this receiver (presumably dropping them unregisters the handlers).
    _drop_guards: [EventHandlerDropGuard; 2],
}
217
218impl EventReceiver {
219 pub(crate) async fn recv(&mut self) -> Option<Raw<AnyTimelineEvent>> {
220 self.rx.recv().await
221 }
222}
223
224fn attach_room_id(raw_ev: &Raw<AnySyncTimelineEvent>, room_id: &RoomId) -> Raw<AnyTimelineEvent> {
225 let mut ev_obj = raw_ev.deserialize_as::<BTreeMap<String, Box<RawJsonValue>>>().unwrap();
226 ev_obj.insert("room_id".to_owned(), serde_json::value::to_raw_value(room_id).unwrap());
227 Raw::new(&ev_obj).unwrap().cast()
228}