1#![allow(rustdoc::private_intra_doc_links)]
16#![doc = include_str!("README.md")]
17
18use std::{fmt, time::Duration};
19
20use async_channel::{Receiver, Sender};
21use futures_util::StreamExt;
22use matrix_sdk_common::executor::spawn;
23use ruma::api::client::delayed_events::DelayParameters;
24use serde::de::{self, Deserialize, Deserializer, Visitor};
25use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
26use tokio_stream::wrappers::UnboundedReceiverStream;
27use tokio_util::sync::{CancellationToken, DropGuard};
28
29use self::{
30 machine::{
31 Action, IncomingMessage, MatrixDriverRequestData, MatrixDriverResponse, SendEventRequest,
32 WidgetMachine,
33 },
34 matrix::MatrixDriver,
35};
36use crate::{Result, room::Room, widget::machine::DownloadFileResponse};
37
38mod capabilities;
39mod filter;
40mod machine;
41mod matrix;
42mod settings;
43
44pub use self::{
45 capabilities::{Capabilities, CapabilitiesProvider},
46 filter::{Filter, MessageLikeEventFilter, StateEventFilter, ToDeviceEventFilter},
47 settings::{
48 ClientProperties, EncryptionSystem, Intent, VirtualElementCallWidgetConfig,
49 VirtualElementCallWidgetProperties, WidgetSettings,
50 },
51};
52
/// Drives the widget API for a single widget.
///
/// Created by [`WidgetDriver::new`] together with a [`WidgetDriverHandle`],
/// and started with [`WidgetDriver::run`].
#[derive(Debug)]
pub struct WidgetDriver {
    /// Settings of the widget served by this driver.
    settings: WidgetSettings,

    /// Receiving end of the channel carrying raw messages sent by the widget
    /// (fed through [`WidgetDriverHandle::send`]).
    from_widget_rx: Receiver<String>,

    /// Sending end of the channel carrying raw messages to the widget (read
    /// through [`WidgetDriverHandle::recv`]).
    to_widget_tx: Sender<String>,

    /// Guard of the task forwarding Matrix events to the widget machine.
    ///
    /// `Some` while subscribed; replacing it with `None` (or dropping the
    /// driver) drops the guard, which cancels the forwarding task.
    event_forwarding_guard: Option<DropGuard>,
}
76
/// Embedder-facing side of a [`WidgetDriver`]: carries raw string messages
/// between the widget and the driver.
///
/// Clones share the same underlying channels.
#[derive(Clone, Debug)]
pub struct WidgetDriverHandle {
    /// Receives the messages the driver wants delivered to the widget.
    to_widget_rx: Receiver<String>,

    /// Sends the messages coming from the widget to the driver.
    from_widget_tx: Sender<String>,
}
97
98impl WidgetDriverHandle {
99 pub async fn recv(&self) -> Option<String> {
105 self.to_widget_rx.recv().await.ok()
106 }
107
108 pub async fn send(&self, message: String) -> bool {
112 self.from_widget_tx.send(message).await.is_ok()
113 }
114}
115
116impl WidgetDriver {
117 pub fn new(settings: WidgetSettings) -> (Self, WidgetDriverHandle) {
120 let (from_widget_tx, from_widget_rx) = async_channel::unbounded();
121 let (to_widget_tx, to_widget_rx) = async_channel::unbounded();
122
123 let driver = Self { settings, from_widget_rx, to_widget_tx, event_forwarding_guard: None };
124 let channels = WidgetDriverHandle { from_widget_tx, to_widget_rx };
125
126 (driver, channels)
127 }
128
129 pub async fn run(
134 mut self,
135 room: Room,
136 capabilities_provider: impl CapabilitiesProvider,
137 ) -> Result<(), ()> {
138 let (incoming_msg_tx, incoming_msg_rx) = unbounded_channel();
145
146 spawn({
153 let incoming_msg_tx = incoming_msg_tx.clone();
154 let from_widget_rx = self.from_widget_rx.clone();
155
156 async move {
157 while let Ok(msg) = from_widget_rx.recv().await {
158 let _ = incoming_msg_tx.send(IncomingMessage::WidgetMessage(msg));
159 }
160 }
161 });
162
163 let (mut widget_machine, initial_actions) = WidgetMachine::new(
167 self.settings.widget_id().to_owned(),
168 room.room_id().to_owned(),
169 self.settings.init_on_content_load(),
170 );
171
172 let matrix_driver = MatrixDriver::new(room.clone());
173
174 let stream = UnboundedReceiverStream::new(incoming_msg_rx)
176 .flat_map(|message| tokio_stream::iter(widget_machine.process(message)));
177
178 let mut combined = tokio_stream::iter(initial_actions).chain(stream);
180
181 while let Some(action) = combined.next().await {
183 self.process_action(&matrix_driver, &incoming_msg_tx, &capabilities_provider, action)
184 .await?;
185 }
186
187 Ok(())
188 }
189
190 async fn process_action(
192 &mut self,
193 matrix_driver: &MatrixDriver,
194 incoming_msg_tx: &UnboundedSender<IncomingMessage>,
195 capabilities_provider: &impl CapabilitiesProvider,
196 action: Action,
197 ) -> Result<(), ()> {
198 match action {
199 Action::SendToWidget(msg) => {
200 self.to_widget_tx.send(msg).await.map_err(|_| ())?;
201 }
202
203 Action::MatrixDriverRequest { request_id, data } => {
204 let response = match data {
205 MatrixDriverRequestData::AcquireCapabilities(cmd) => {
206 let obtained = capabilities_provider
207 .acquire_capabilities(cmd.desired_capabilities)
208 .await;
209 Ok(MatrixDriverResponse::CapabilitiesAcquired(obtained))
210 }
211
212 MatrixDriverRequestData::GetOpenId => {
213 matrix_driver.get_open_id().await.map(MatrixDriverResponse::OpenIdReceived)
214 }
215
216 MatrixDriverRequestData::ReadEvents(cmd) => matrix_driver
217 .read_events(cmd.event_type.into(), cmd.state_key, cmd.limit)
218 .await
219 .map(MatrixDriverResponse::EventsRead),
220
221 MatrixDriverRequestData::ReadState(cmd) => matrix_driver
222 .read_state(cmd.event_type.into(), &cmd.state_key)
223 .await
224 .map(MatrixDriverResponse::StateRead),
225
226 MatrixDriverRequestData::SendEvent(req) => {
227 let SendEventRequest { event_type, state_key, content, delay } = req;
228 let delay_event_parameter = delay.map(|d| DelayParameters::Timeout {
233 timeout: Duration::from_millis(d),
234 });
235 matrix_driver
236 .send(event_type.into(), state_key, content, delay_event_parameter)
237 .await
238 .map(MatrixDriverResponse::EventSent)
239 }
240
241 MatrixDriverRequestData::UpdateDelayedEvent(req) => matrix_driver
242 .update_delayed_event(req.delay_id, req.action)
243 .await
244 .map(MatrixDriverResponse::DelayedEventUpdated),
245
246 MatrixDriverRequestData::SendToDeviceEvent(send_to_device_request) => {
247 matrix_driver
248 .send_to_device(
249 send_to_device_request.event_type.into(),
250 send_to_device_request.messages,
251 )
252 .await
253 .map(MatrixDriverResponse::ToDeviceSent)
254 }
255 MatrixDriverRequestData::DownloadFile(req) => matrix_driver
256 .download_attachment(req.content_uri)
257 .await
258 .map(|file_data_base64| {
259 MatrixDriverResponse::FileDownloaded(DownloadFileResponse {
260 file_data_base64,
261 })
262 }),
263 };
264
265 incoming_msg_tx
267 .send(IncomingMessage::MatrixDriverResponse { request_id, response })
268 .map_err(|_| ())?;
269 }
270
271 Action::Subscribe => {
272 if self.event_forwarding_guard.is_some() {
274 return Ok(());
275 }
276
277 let (stop_forwarding, guard) = {
278 let token = CancellationToken::new();
279 (token.child_token(), token.drop_guard())
280 };
281
282 self.event_forwarding_guard = Some(guard);
283
284 let mut events = matrix_driver.events();
285 let mut state_updates = matrix_driver.state_updates();
286 let mut to_device_events = matrix_driver.to_device_events();
287 let incoming_msg_tx = incoming_msg_tx.clone();
288
289 spawn(async move {
290 loop {
291 tokio::select! {
292 _ = stop_forwarding.cancelled() => {
293 return;
295 }
296
297 Some(event) = events.recv() => {
298 let _ = incoming_msg_tx.send(IncomingMessage::MatrixEventReceived(event));
300 }
301
302 Ok(state) = state_updates.recv() => {
303 let _ = incoming_msg_tx.send(IncomingMessage::StateUpdateReceived(state));
305 }
306
307 Some(event) = to_device_events.recv() => {
308 let _ = incoming_msg_tx.send(IncomingMessage::ToDeviceReceived(event));
310 }
311 }
312 }
313 });
314 }
315
316 Action::Unsubscribe => {
317 self.event_forwarding_guard = None;
318 }
319 }
320
321 Ok(())
322 }
323}
324
/// Selects which state key(s) a request applies to.
///
/// Deserialized from either a string (one specific key) or the boolean
/// `true` (any key); see the `Deserialize` implementation below.
#[derive(Clone, Debug)]
pub(crate) enum StateKeySelector {
    /// One specific state key (possibly the empty string).
    Key(String),
    /// Any state key.
    Any,
}
331
332impl<'de> Deserialize<'de> for StateKeySelector {
333 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
334 where
335 D: Deserializer<'de>,
336 {
337 struct StateKeySelectorVisitor;
338
339 impl Visitor<'_> for StateKeySelectorVisitor {
340 type Value = StateKeySelector;
341
342 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343 write!(f, "a string or `true`")
344 }
345
346 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
347 where
348 E: de::Error,
349 {
350 if v {
351 Ok(StateKeySelector::Any)
352 } else {
353 Err(E::invalid_value(de::Unexpected::Bool(v), &self))
354 }
355 }
356
357 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
358 where
359 E: de::Error,
360 {
361 self.visit_string(v.to_owned())
362 }
363
364 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
365 where
366 E: de::Error,
367 {
368 Ok(StateKeySelector::Key(v))
369 }
370 }
371
372 deserializer.deserialize_any(StateKeySelectorVisitor)
373 }
374}
375
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use serde_json::json;

    use super::StateKeySelector;

    #[test]
    fn state_key_selector_from_true() {
        let state_key = serde_json::from_value(json!(true)).unwrap();
        assert_matches!(state_key, StateKeySelector::Any);
    }

    #[test]
    fn state_key_selector_from_string() {
        let state_key = serde_json::from_value(json!("test")).unwrap();
        assert_matches!(state_key, StateKeySelector::Key(k) if k == "test");
    }

    #[test]
    fn state_key_selector_from_borrowed_string() {
        // `from_value` hands the visitor an owned `String` (`visit_string`),
        // whereas `from_str` hands it a string slice, which ends up in
        // `visit_str`. Cover that second path too.
        let state_key: StateKeySelector = serde_json::from_str(r#""test""#).unwrap();
        assert_matches!(state_key, StateKeySelector::Key(k) if k == "test");
    }

    #[test]
    fn state_key_selector_from_empty_string() {
        // The empty string is a specific (empty) state key, not a wildcard.
        let state_key: StateKeySelector = serde_json::from_value(json!("")).unwrap();
        assert_matches!(state_key, StateKeySelector::Key(k) if k.is_empty());
    }

    #[test]
    fn state_key_selector_from_false() {
        let result = serde_json::from_value::<StateKeySelector>(json!(false));
        assert_matches!(result, Err(e) if e.is_data());
    }

    #[test]
    fn state_key_selector_from_number() {
        let result = serde_json::from_value::<StateKeySelector>(json!(5));
        assert_matches!(result, Err(e) if e.is_data());
    }

    #[test]
    fn state_key_selector_from_null() {
        let result = serde_json::from_value::<StateKeySelector>(json!(null));
        assert_matches!(result, Err(e) if e.is_data());
    }
}