1#![allow(rustdoc::private_intra_doc_links)]
16#![doc = include_str!("README.md")]
17
18use std::{fmt, time::Duration};
19
20use async_channel::{Receiver, Sender};
21use futures_util::StreamExt;
22use matrix_sdk_common::executor::spawn;
23use ruma::api::client::delayed_events::DelayParameters;
24use serde::de::{self, Deserialize, Deserializer, Visitor};
25use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
26use tokio_stream::wrappers::UnboundedReceiverStream;
27use tokio_util::sync::{CancellationToken, DropGuard};
28
29use self::{
30 machine::{
31 Action, IncomingMessage, MatrixDriverRequestData, MatrixDriverResponse, SendEventRequest,
32 WidgetMachine,
33 },
34 matrix::MatrixDriver,
35};
36use crate::{room::Room, Result};
37
38mod capabilities;
39mod filter;
40mod machine;
41mod matrix;
42mod settings;
43
44pub use self::{
45 capabilities::{Capabilities, CapabilitiesProvider},
46 filter::{Filter, MessageLikeEventFilter, StateEventFilter, ToDeviceEventFilter},
47 settings::{
48 ClientProperties, EncryptionSystem, Intent, VirtualElementCallWidgetOptions, WidgetSettings,
49 },
50};
51
/// Drives the interaction between a widget and the Matrix world for one room.
///
/// Created together with a [`WidgetDriverHandle`] by [`WidgetDriver::new`];
/// the handle owns the opposite ends of the two message channels stored here.
/// The driver is consumed by [`WidgetDriver::run`].
#[derive(Debug)]
pub struct WidgetDriver {
    /// Settings the driver was created with (widget id, whether to initialize
    /// on content load, ...).
    settings: WidgetSettings,

    /// Raw messages coming from the widget, i.e. whatever was passed to
    /// [`WidgetDriverHandle::send`].
    from_widget_rx: Receiver<String>,

    /// Raw messages destined for the widget, picked up through
    /// [`WidgetDriverHandle::recv`].
    to_widget_tx: Sender<String>,

    /// Guard for the task that forwards Matrix events into the widget machine.
    ///
    /// `Some` while subscribed. Dropping it (on `Action::Unsubscribe`, or when
    /// the driver itself is dropped) cancels the forwarding task.
    event_forwarding_guard: Option<DropGuard>,
}
75
/// Widget-facing end of a [`WidgetDriver`], returned by [`WidgetDriver::new`].
///
/// Cloning the handle clones both `async_channel` ends, so every clone talks
/// to the same driver. Because the channels are multi-consumer, a message for
/// the widget is handed to only one pending `recv` across all clones.
#[derive(Clone, Debug)]
pub struct WidgetDriverHandle {
    /// Messages produced by the driver that must be passed on to the widget.
    to_widget_rx: Receiver<String>,

    /// Messages from the widget, forwarded to the driver.
    from_widget_tx: Sender<String>,
}
96
97impl WidgetDriverHandle {
98 pub async fn recv(&self) -> Option<String> {
104 self.to_widget_rx.recv().await.ok()
105 }
106
107 pub async fn send(&self, message: String) -> bool {
111 self.from_widget_tx.send(message).await.is_ok()
112 }
113}
114
115impl WidgetDriver {
116 pub fn new(settings: WidgetSettings) -> (Self, WidgetDriverHandle) {
119 let (from_widget_tx, from_widget_rx) = async_channel::unbounded();
120 let (to_widget_tx, to_widget_rx) = async_channel::unbounded();
121
122 let driver = Self { settings, from_widget_rx, to_widget_tx, event_forwarding_guard: None };
123 let channels = WidgetDriverHandle { from_widget_tx, to_widget_rx };
124
125 (driver, channels)
126 }
127
128 pub async fn run(
133 mut self,
134 room: Room,
135 capabilities_provider: impl CapabilitiesProvider,
136 ) -> Result<(), ()> {
137 let (incoming_msg_tx, incoming_msg_rx) = unbounded_channel();
144
145 spawn({
152 let incoming_msg_tx = incoming_msg_tx.clone();
153 let from_widget_rx = self.from_widget_rx.clone();
154
155 async move {
156 while let Ok(msg) = from_widget_rx.recv().await {
157 let _ = incoming_msg_tx.send(IncomingMessage::WidgetMessage(msg));
158 }
159 }
160 });
161
162 let (mut widget_machine, initial_actions) = WidgetMachine::new(
166 self.settings.widget_id().to_owned(),
167 room.room_id().to_owned(),
168 self.settings.init_on_content_load(),
169 );
170
171 let matrix_driver = MatrixDriver::new(room.clone());
172
173 let stream = UnboundedReceiverStream::new(incoming_msg_rx)
175 .flat_map(|message| tokio_stream::iter(widget_machine.process(message)));
176
177 let mut combined = tokio_stream::iter(initial_actions).chain(stream);
179
180 while let Some(action) = combined.next().await {
182 self.process_action(&matrix_driver, &incoming_msg_tx, &capabilities_provider, action)
183 .await?;
184 }
185
186 Ok(())
187 }
188
189 async fn process_action(
191 &mut self,
192 matrix_driver: &MatrixDriver,
193 incoming_msg_tx: &UnboundedSender<IncomingMessage>,
194 capabilities_provider: &impl CapabilitiesProvider,
195 action: Action,
196 ) -> Result<(), ()> {
197 match action {
198 Action::SendToWidget(msg) => {
199 self.to_widget_tx.send(msg).await.map_err(|_| ())?;
200 }
201
202 Action::MatrixDriverRequest { request_id, data } => {
203 let response = match data {
204 MatrixDriverRequestData::AcquireCapabilities(cmd) => {
205 let obtained = capabilities_provider
206 .acquire_capabilities(cmd.desired_capabilities)
207 .await;
208 Ok(MatrixDriverResponse::CapabilitiesAcquired(obtained))
209 }
210
211 MatrixDriverRequestData::GetOpenId => {
212 matrix_driver.get_open_id().await.map(MatrixDriverResponse::OpenIdReceived)
213 }
214
215 MatrixDriverRequestData::ReadEvents(cmd) => matrix_driver
216 .read_events(cmd.event_type.into(), cmd.state_key, cmd.limit)
217 .await
218 .map(MatrixDriverResponse::EventsRead),
219
220 MatrixDriverRequestData::ReadState(cmd) => matrix_driver
221 .read_state(cmd.event_type.into(), &cmd.state_key)
222 .await
223 .map(MatrixDriverResponse::StateRead),
224
225 MatrixDriverRequestData::SendEvent(req) => {
226 let SendEventRequest { event_type, state_key, content, delay } = req;
227 let delay_event_parameter = delay.map(|d| DelayParameters::Timeout {
232 timeout: Duration::from_millis(d),
233 });
234 matrix_driver
235 .send(event_type.into(), state_key, content, delay_event_parameter)
236 .await
237 .map(MatrixDriverResponse::EventSent)
238 }
239
240 MatrixDriverRequestData::UpdateDelayedEvent(req) => matrix_driver
241 .update_delayed_event(req.delay_id, req.action)
242 .await
243 .map(MatrixDriverResponse::DelayedEventUpdated),
244
245 MatrixDriverRequestData::SendToDeviceEvent(send_to_device_request) => {
246 matrix_driver
247 .send_to_device(
248 send_to_device_request.event_type.into(),
249 send_to_device_request.messages,
250 )
251 .await
252 .map(MatrixDriverResponse::ToDeviceSent)
253 }
254 };
255
256 incoming_msg_tx
258 .send(IncomingMessage::MatrixDriverResponse { request_id, response })
259 .map_err(|_| ())?;
260 }
261
262 Action::Subscribe => {
263 if self.event_forwarding_guard.is_some() {
265 return Ok(());
266 }
267
268 let (stop_forwarding, guard) = {
269 let token = CancellationToken::new();
270 (token.child_token(), token.drop_guard())
271 };
272
273 self.event_forwarding_guard = Some(guard);
274
275 let mut events = matrix_driver.events();
276 let mut state_updates = matrix_driver.state_updates();
277 let mut to_device_events = matrix_driver.to_device_events();
278 let incoming_msg_tx = incoming_msg_tx.clone();
279
280 spawn(async move {
281 loop {
282 tokio::select! {
283 _ = stop_forwarding.cancelled() => {
284 return;
286 }
287
288 Some(event) = events.recv() => {
289 let _ = incoming_msg_tx.send(IncomingMessage::MatrixEventReceived(event));
291 }
292
293 Ok(state) = state_updates.recv() => {
294 let _ = incoming_msg_tx.send(IncomingMessage::StateUpdateReceived(state));
296 }
297
298 Some(event) = to_device_events.recv() => {
299 let _ = incoming_msg_tx.send(IncomingMessage::ToDeviceReceived(event));
301 }
302 }
303 }
304 });
305 }
306
307 Action::Unsubscribe => {
308 self.event_forwarding_guard = None;
309 }
310 }
311
312 Ok(())
313 }
314}
315
/// Selects a state key: either one specific key or any key.
///
/// Deserialized from a string (`Key`) or from the boolean `true` (`Any`);
/// `false` and every other kind of value are rejected.
#[derive(Clone, Debug)]
pub(crate) enum StateKeySelector {
    /// Exactly this state key.
    Key(String),
    /// Any state key.
    Any,
}
322
323impl<'de> Deserialize<'de> for StateKeySelector {
324 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
325 where
326 D: Deserializer<'de>,
327 {
328 struct StateKeySelectorVisitor;
329
330 impl Visitor<'_> for StateKeySelectorVisitor {
331 type Value = StateKeySelector;
332
333 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 write!(f, "a string or `true`")
335 }
336
337 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
338 where
339 E: de::Error,
340 {
341 if v {
342 Ok(StateKeySelector::Any)
343 } else {
344 Err(E::invalid_value(de::Unexpected::Bool(v), &self))
345 }
346 }
347
348 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
349 where
350 E: de::Error,
351 {
352 self.visit_string(v.to_owned())
353 }
354
355 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
356 where
357 E: de::Error,
358 {
359 Ok(StateKeySelector::Key(v))
360 }
361 }
362
363 deserializer.deserialize_any(StateKeySelectorVisitor)
364 }
365}
366
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use serde_json::json;

    use super::StateKeySelector;

    /// Deserializes `value` as a [`StateKeySelector`].
    fn parse(value: serde_json::Value) -> Result<StateKeySelector, serde_json::Error> {
        serde_json::from_value(value)
    }

    #[test]
    fn state_key_selector_from_true() {
        // `true` is the wildcard selector.
        let selector = parse(json!(true)).unwrap();
        assert_matches!(selector, StateKeySelector::Any);
    }

    #[test]
    fn state_key_selector_from_string() {
        // A string is taken verbatim as the concrete state key.
        let selector = parse(json!("test")).unwrap();
        assert_matches!(selector, StateKeySelector::Key(key) if key == "test");
    }

    #[test]
    fn state_key_selector_from_false() {
        // `false` must be rejected as invalid data.
        assert_matches!(parse(json!(false)), Err(error) if error.is_data());
    }

    #[test]
    fn state_key_selector_from_number() {
        // Numbers are neither strings nor booleans.
        assert_matches!(parse(json!(5)), Err(error) if error.is_data());
    }
}