1use std::{fmt, time::Duration};
18
19use async_channel::{Receiver, Sender};
20use ruma::api::client::delayed_events::DelayParameters;
21use serde::de::{self, Deserialize, Deserializer, Visitor};
22use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
23use tokio_util::sync::{CancellationToken, DropGuard};
24
25use self::{
26 machine::{
27 Action, IncomingMessage, MatrixDriverRequestData, MatrixDriverResponse, SendEventRequest,
28 WidgetMachine,
29 },
30 matrix::MatrixDriver,
31};
32use crate::{room::Room, Result};
33
34mod capabilities;
35mod filter;
36mod machine;
37mod matrix;
38mod settings;
39
40pub use self::{
41 capabilities::{Capabilities, CapabilitiesProvider},
42 filter::{EventFilter, MessageLikeEventFilter, StateEventFilter},
43 settings::{
44 ClientProperties, EncryptionSystem, VirtualElementCallWidgetOptions, WidgetSettings,
45 },
46};
47
#[derive(Debug)]
/// Drives the client widget API for a single widget, translating between the
/// raw messages exchanged with the widget and operations on a Matrix room.
///
/// Created together with its [`WidgetDriverHandle`] by [`WidgetDriver::new`]
/// and consumed by [`WidgetDriver::run`].
pub struct WidgetDriver {
    /// Settings of the driven widget; `run` reads the widget id and the
    /// `init_on_content_load` flag from here.
    settings: WidgetSettings,

    /// Raw messages the widget submitted through [`WidgetDriverHandle::send`].
    /// `run` forwards each one to the widget machine as a widget message.
    from_widget_rx: Receiver<String>,

    /// Raw messages destined for the widget, picked up by
    /// [`WidgetDriverHandle::recv`]. Written whenever the widget machine emits
    /// a "send to widget" action.
    to_widget_tx: Sender<String>,

    /// Set while the widget is subscribed to room events. Dropping the guard
    /// (on unsubscribe, or when the driver itself is dropped) cancels the
    /// task that forwards room events to the widget machine.
    event_forwarding_guard: Option<DropGuard>,
}
71
#[derive(Clone, Debug)]
/// The embedder's end of the channels connecting a widget to its
/// [`WidgetDriver`].
///
/// Clones share the same underlying channels.
pub struct WidgetDriverHandle {
    /// Raw messages the driver wants delivered to the widget.
    to_widget_rx: Receiver<String>,

    /// Raw messages coming from the widget, consumed by the driver.
    from_widget_tx: Sender<String>,
}
92
93impl WidgetDriverHandle {
94 pub async fn recv(&self) -> Option<String> {
100 self.to_widget_rx.recv().await.ok()
101 }
102
103 pub async fn send(&self, message: String) -> bool {
107 self.from_widget_tx.send(message).await.is_ok()
108 }
109}
110
111impl WidgetDriver {
112 pub fn new(settings: WidgetSettings) -> (Self, WidgetDriverHandle) {
115 let (from_widget_tx, from_widget_rx) = async_channel::unbounded();
116 let (to_widget_tx, to_widget_rx) = async_channel::unbounded();
117
118 let driver = Self { settings, from_widget_rx, to_widget_tx, event_forwarding_guard: None };
119 let channels = WidgetDriverHandle { from_widget_tx, to_widget_rx };
120
121 (driver, channels)
122 }
123
124 pub async fn run(
129 mut self,
130 room: Room,
131 capabilities_provider: impl CapabilitiesProvider,
132 ) -> Result<(), ()> {
133 let (incoming_msg_tx, mut incoming_msg_rx) = unbounded_channel();
140
141 tokio::spawn({
143 let incoming_msg_tx = incoming_msg_tx.clone();
144 let from_widget_rx = self.from_widget_rx.clone();
145 async move {
146 while let Ok(msg) = from_widget_rx.recv().await {
147 let _ = incoming_msg_tx.send(IncomingMessage::WidgetMessage(msg));
148 }
149 }
150 });
151
152 let (mut widget_machine, initial_actions) = WidgetMachine::new(
154 self.settings.widget_id().to_owned(),
155 room.room_id().to_owned(),
156 self.settings.init_on_content_load(),
157 );
158
159 let matrix_driver = MatrixDriver::new(room.clone());
160
161 for action in initial_actions {
163 self.process_action(&matrix_driver, &incoming_msg_tx, &capabilities_provider, action)
164 .await?;
165 }
166
167 while let Some(msg) = incoming_msg_rx.recv().await {
169 for action in widget_machine.process(msg) {
170 self.process_action(
171 &matrix_driver,
172 &incoming_msg_tx,
173 &capabilities_provider,
174 action,
175 )
176 .await?;
177 }
178 }
179
180 Ok(())
181 }
182
183 async fn process_action(
185 &mut self,
186 matrix_driver: &MatrixDriver,
187 incoming_msg_tx: &UnboundedSender<IncomingMessage>,
188 capabilities_provider: &impl CapabilitiesProvider,
189 action: Action,
190 ) -> Result<(), ()> {
191 match action {
192 Action::SendToWidget(msg) => {
193 self.to_widget_tx.send(msg).await.map_err(|_| ())?;
194 }
195
196 Action::MatrixDriverRequest { request_id, data } => {
197 let response = match data {
198 MatrixDriverRequestData::AcquireCapabilities(cmd) => {
199 let obtained = capabilities_provider
200 .acquire_capabilities(cmd.desired_capabilities)
201 .await;
202 Ok(MatrixDriverResponse::CapabilitiesAcquired(obtained))
203 }
204
205 MatrixDriverRequestData::GetOpenId => {
206 matrix_driver.get_open_id().await.map(MatrixDriverResponse::OpenIdReceived)
207 }
208
209 MatrixDriverRequestData::ReadMessageLikeEvent(cmd) => matrix_driver
210 .read_message_like_events(cmd.event_type.clone(), cmd.limit)
211 .await
212 .map(MatrixDriverResponse::MatrixEventRead),
213
214 MatrixDriverRequestData::ReadStateEvent(cmd) => matrix_driver
215 .read_state_events(cmd.event_type.clone(), &cmd.state_key)
216 .await
217 .map(MatrixDriverResponse::MatrixEventRead),
218
219 MatrixDriverRequestData::SendMatrixEvent(req) => {
220 let SendEventRequest { event_type, state_key, content, delay } = req;
221 let delay_event_parameter = delay.map(|d| DelayParameters::Timeout {
226 timeout: Duration::from_millis(d),
227 });
228 matrix_driver
229 .send(event_type, state_key, content, delay_event_parameter)
230 .await
231 .map(MatrixDriverResponse::MatrixEventSent)
232 }
233
234 MatrixDriverRequestData::UpdateDelayedEvent(req) => matrix_driver
235 .update_delayed_event(req.delay_id, req.action)
236 .await
237 .map(MatrixDriverResponse::MatrixDelayedEventUpdate),
238 };
239
240 incoming_msg_tx
242 .send(IncomingMessage::MatrixDriverResponse { request_id, response })
243 .map_err(|_| ())?;
244 }
245
246 Action::Subscribe => {
247 if self.event_forwarding_guard.is_some() {
249 return Ok(());
250 }
251
252 let (stop_forwarding, guard) = {
253 let token = CancellationToken::new();
254 (token.child_token(), token.drop_guard())
255 };
256
257 self.event_forwarding_guard = Some(guard);
258
259 let mut matrix = matrix_driver.events();
260 let incoming_msg_tx = incoming_msg_tx.clone();
261
262 tokio::spawn(async move {
263 loop {
264 tokio::select! {
265 _ = stop_forwarding.cancelled() => {
266 return;
268 }
269
270 Some(event) = matrix.recv() => {
271 let _ = incoming_msg_tx.send(IncomingMessage::MatrixEventReceived(event));
273 }
274 }
275 }
276 });
277 }
278
279 Action::Unsubscribe => {
280 self.event_forwarding_guard = None;
281 }
282 }
283
284 Ok(())
285 }
286}
287
#[derive(Clone, Debug)]
/// Selects state keys: either one specific key or any key.
///
/// Deserializes from a string (a specific key) or the boolean `true` (any
/// key); see the `Deserialize` impl for this type.
pub(crate) enum StateKeySelector {
    /// Exactly this state key.
    Key(String),
    /// Any state key.
    Any,
}
294
295impl<'de> Deserialize<'de> for StateKeySelector {
296 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
297 where
298 D: Deserializer<'de>,
299 {
300 struct StateKeySelectorVisitor;
301
302 impl Visitor<'_> for StateKeySelectorVisitor {
303 type Value = StateKeySelector;
304
305 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 write!(f, "a string or `true`")
307 }
308
309 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
310 where
311 E: de::Error,
312 {
313 if v {
314 Ok(StateKeySelector::Any)
315 } else {
316 Err(E::invalid_value(de::Unexpected::Bool(v), &self))
317 }
318 }
319
320 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
321 where
322 E: de::Error,
323 {
324 self.visit_string(v.to_owned())
325 }
326
327 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
328 where
329 E: de::Error,
330 {
331 Ok(StateKeySelector::Key(v))
332 }
333 }
334
335 deserializer.deserialize_any(StateKeySelectorVisitor)
336 }
337}
338
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use serde_json::json;

    use super::StateKeySelector;

    #[test]
    fn state_key_selector_from_true() {
        let state_key = serde_json::from_value(json!(true)).unwrap();
        assert_matches!(state_key, StateKeySelector::Any);
    }

    #[test]
    fn state_key_selector_from_string() {
        let state_key = serde_json::from_value(json!("test")).unwrap();
        assert_matches!(state_key, StateKeySelector::Key(k) if k == "test");
    }

    #[test]
    fn state_key_selector_from_str_input() {
        // `from_value` always hands the visitor an owned `String`; parsing
        // from text hands it a `&str` instead, exercising `visit_str`.
        let state_key = serde_json::from_str::<StateKeySelector>(r#""test""#).unwrap();
        assert_matches!(state_key, StateKeySelector::Key(k) if k == "test");
    }

    #[test]
    fn state_key_selector_from_false() {
        let result = serde_json::from_value::<StateKeySelector>(json!(false));
        assert_matches!(result, Err(e) if e.is_data());
    }

    #[test]
    fn state_key_selector_from_number() {
        let result = serde_json::from_value::<StateKeySelector>(json!(5));
        assert_matches!(result, Err(e) if e.is_data());
    }

    #[test]
    fn state_key_selector_from_null() {
        let result = serde_json::from_value::<StateKeySelector>(json!(null));
        assert_matches!(result, Err(e) if e.is_data());
    }
}