Skip to main content

matrix_sdk/widget/
mod.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(rustdoc::private_intra_doc_links)]
16#![doc = include_str!("README.md")]
17
18use std::{fmt, time::Duration};
19
20use async_channel::{Receiver, Sender};
21use futures_util::StreamExt;
22use matrix_sdk_common::executor::spawn;
23use ruma::api::client::delayed_events::DelayParameters;
24use serde::de::{self, Deserialize, Deserializer, Visitor};
25use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
26use tokio_stream::wrappers::UnboundedReceiverStream;
27use tokio_util::sync::{CancellationToken, DropGuard};
28
29use self::{
30    machine::{
31        Action, IncomingMessage, MatrixDriverRequestData, MatrixDriverResponse, SendEventRequest,
32        WidgetMachine,
33    },
34    matrix::MatrixDriver,
35};
36use crate::{Result, room::Room, widget::machine::DownloadFileResponse};
37
38mod capabilities;
39mod filter;
40mod machine;
41mod matrix;
42mod settings;
43
44pub use self::{
45    capabilities::{Capabilities, CapabilitiesProvider},
46    filter::{Filter, MessageLikeEventFilter, StateEventFilter, ToDeviceEventFilter},
47    settings::{
48        ClientProperties, EncryptionSystem, Intent, VirtualElementCallWidgetConfig,
49        VirtualElementCallWidgetProperties, WidgetSettings,
50    },
51};
52
53/// An object that handles all interactions of a widget living inside a webview
54/// or iframe with the Matrix world.
55#[derive(Debug)]
56pub struct WidgetDriver {
57    settings: WidgetSettings,
58
59    /// Raw incoming messages from the widget (normally formatted as JSON).
60    ///
61    /// These can be both requests and responses.
62    from_widget_rx: Receiver<String>,
63
64    /// Raw outgoing messages from the client (SDK) to the widget (normally
65    /// formatted as JSON).
66    ///
67    /// These can be both requests and responses.
68    to_widget_tx: Sender<String>,
69
70    /// Drop guard for an event handler forwarding all events from the Matrix
71    /// room to the widget.
72    ///
73    /// Only set if a subscription happened ([`Action::Subscribe`]).
74    event_forwarding_guard: Option<DropGuard>,
75}
76
77/// A handle that encapsulates the communication between a widget driver and the
78/// corresponding widget (inside a webview or iframe).
79#[derive(Clone, Debug)]
80pub struct WidgetDriverHandle {
81    /// Raw incoming messages from the widget driver to the widget (normally
82    /// formatted as JSON).
83    ///
84    /// These can be both requests and responses. Users of this API should not
85    /// care what's what though because they are only supposed to forward
86    /// messages between the webview / iframe, and the SDK's widget driver.
87    to_widget_rx: Receiver<String>,
88
89    /// Raw outgoing messages from the widget to the widget driver (normally
90    /// formatted as JSON).
91    ///
92    /// These can be both requests and responses. Users of this API should not
93    /// care what's what though because they are only supposed to forward
94    /// messages between the webview / iframe, and the SDK's widget driver.
95    from_widget_tx: Sender<String>,
96}
97
98impl WidgetDriverHandle {
99    /// Receive a message from the widget driver.
100    ///
101    /// The message must be passed on to the widget.
102    ///
103    /// Returns `None` if the widget driver is no longer running.
104    pub async fn recv(&self) -> Option<String> {
105        self.to_widget_rx.recv().await.ok()
106    }
107
108    /// Send a message from the widget to the widget driver.
109    ///
110    /// Returns `false` if the widget driver is no longer running.
111    pub async fn send(&self, message: String) -> bool {
112        self.from_widget_tx.send(message).await.is_ok()
113    }
114}
115
116impl WidgetDriver {
117    /// Creates a new `WidgetDriver` and a corresponding set of channels to let
118    /// the widget (inside a webview or iframe) communicate with it.
119    pub fn new(settings: WidgetSettings) -> (Self, WidgetDriverHandle) {
120        let (from_widget_tx, from_widget_rx) = async_channel::unbounded();
121        let (to_widget_tx, to_widget_rx) = async_channel::unbounded();
122
123        let driver = Self { settings, from_widget_rx, to_widget_tx, event_forwarding_guard: None };
124        let channels = WidgetDriverHandle { from_widget_tx, to_widget_rx };
125
126        (driver, channels)
127    }
128
129    /// Run client widget API state machine in a given joined `room` forever.
130    ///
131    /// The function returns once the widget is disconnected or any terminal
132    /// error occurs.
133    pub async fn run(
134        mut self,
135        room: Room,
136        capabilities_provider: impl CapabilitiesProvider,
137    ) -> Result<(), ()> {
138        // Create a channel so that we can conveniently send all messages to it.
139        //
140        // It will receive:
141        // - all incoming messages from the widget
142        // - all responses from the Matrix driver
143        // - all events from the Matrix driver, if subscribed
144        let (incoming_msg_tx, incoming_msg_rx) = unbounded_channel();
145
146        // Forward all of the incoming messages from the widget.
147        // TODO: This spawns a detached task, it would be nice to have an owner for this
148        // task. One way to achieve this if `WidgetDriver::run()` returns a handle that
149        // we can drop which will clean up the task and the channels. It's not too bad,
150        // since canelling `run()` will drop the sender this task listens which finishes
151        // the task.
152        spawn({
153            let incoming_msg_tx = incoming_msg_tx.clone();
154            let from_widget_rx = self.from_widget_rx.clone();
155
156            async move {
157                while let Ok(msg) = from_widget_rx.recv().await {
158                    let _ = incoming_msg_tx.send(IncomingMessage::WidgetMessage(msg));
159                }
160            }
161        });
162
163        // Create the widget API machine. The widget machine will process messages it
164        // receives from the widget and convert it into actions the `MatrixDriver` will
165        // then execute on.
166        let (mut widget_machine, initial_actions) = WidgetMachine::new(
167            self.settings.widget_id().to_owned(),
168            room.room_id().to_owned(),
169            self.settings.init_on_content_load(),
170        );
171
172        let matrix_driver = MatrixDriver::new(room.clone());
173
174        // Convert the incoming message receiver into a stream of actions.
175        let stream = UnboundedReceiverStream::new(incoming_msg_rx)
176            .flat_map(|message| tokio_stream::iter(widget_machine.process(message)));
177
178        // Let's combine our set of initial actions with the stream of received actions.
179        let mut combined = tokio_stream::iter(initial_actions).chain(stream);
180
181        // Let's now process all actions we receive forever.
182        while let Some(action) = combined.next().await {
183            self.process_action(&matrix_driver, &incoming_msg_tx, &capabilities_provider, action)
184                .await?;
185        }
186
187        Ok(())
188    }
189
190    /// Process a single [`Action`].
191    async fn process_action(
192        &mut self,
193        matrix_driver: &MatrixDriver,
194        incoming_msg_tx: &UnboundedSender<IncomingMessage>,
195        capabilities_provider: &impl CapabilitiesProvider,
196        action: Action,
197    ) -> Result<(), ()> {
198        match action {
199            Action::SendToWidget(msg) => {
200                self.to_widget_tx.send(msg).await.map_err(|_| ())?;
201            }
202
203            Action::MatrixDriverRequest { request_id, data } => {
204                let response = match data {
205                    MatrixDriverRequestData::AcquireCapabilities(cmd) => {
206                        let obtained = capabilities_provider
207                            .acquire_capabilities(cmd.desired_capabilities)
208                            .await;
209                        Ok(MatrixDriverResponse::CapabilitiesAcquired(obtained))
210                    }
211
212                    MatrixDriverRequestData::GetOpenId => {
213                        matrix_driver.get_open_id().await.map(MatrixDriverResponse::OpenIdReceived)
214                    }
215
216                    MatrixDriverRequestData::ReadEvents(cmd) => matrix_driver
217                        .read_events(cmd.event_type.into(), cmd.state_key, cmd.limit)
218                        .await
219                        .map(MatrixDriverResponse::EventsRead),
220
221                    MatrixDriverRequestData::ReadState(cmd) => matrix_driver
222                        .read_state(cmd.event_type.into(), &cmd.state_key)
223                        .await
224                        .map(MatrixDriverResponse::StateRead),
225
226                    MatrixDriverRequestData::SendEvent(req) => {
227                        let SendEventRequest { event_type, state_key, content, delay } = req;
228                        // The widget api action does not use the unstable prefix:
229                        // `org.matrix.msc4140.delay` so we
230                        // cannot use the `DelayParameters` here and need to convert
231                        // manually.
232                        let delay_event_parameter = delay.map(|d| DelayParameters::Timeout {
233                            timeout: Duration::from_millis(d),
234                        });
235                        matrix_driver
236                            .send(event_type.into(), state_key, content, delay_event_parameter)
237                            .await
238                            .map(MatrixDriverResponse::EventSent)
239                    }
240
241                    MatrixDriverRequestData::UpdateDelayedEvent(req) => matrix_driver
242                        .update_delayed_event(req.delay_id, req.action)
243                        .await
244                        .map(MatrixDriverResponse::DelayedEventUpdated),
245
246                    MatrixDriverRequestData::SendToDeviceEvent(send_to_device_request) => {
247                        matrix_driver
248                            .send_to_device(
249                                send_to_device_request.event_type.into(),
250                                send_to_device_request.messages,
251                            )
252                            .await
253                            .map(MatrixDriverResponse::ToDeviceSent)
254                    }
255                    MatrixDriverRequestData::DownloadFile(req) => matrix_driver
256                        .download_attachment(req.content_uri)
257                        .await
258                        .map(|file_data_base64| {
259                            MatrixDriverResponse::FileDownloaded(DownloadFileResponse {
260                                file_data_base64,
261                            })
262                        }),
263                };
264
265                // Forward the Matrix driver response to the incoming message stream.
266                incoming_msg_tx
267                    .send(IncomingMessage::MatrixDriverResponse { request_id, response })
268                    .map_err(|_| ())?;
269            }
270
271            Action::Subscribe => {
272                // Only subscribe if we are not already subscribed.
273                if self.event_forwarding_guard.is_some() {
274                    return Ok(());
275                }
276
277                let (stop_forwarding, guard) = {
278                    let token = CancellationToken::new();
279                    (token.child_token(), token.drop_guard())
280                };
281
282                self.event_forwarding_guard = Some(guard);
283
284                let mut events = matrix_driver.events();
285                let mut state_updates = matrix_driver.state_updates();
286                let mut to_device_events = matrix_driver.to_device_events();
287                let incoming_msg_tx = incoming_msg_tx.clone();
288
289                spawn(async move {
290                    loop {
291                        tokio::select! {
292                            _ = stop_forwarding.cancelled() => {
293                                // Upon cancellation, stop this task.
294                                return;
295                            }
296
297                            Some(event) = events.recv() => {
298                                // Forward all events to the incoming messages stream.
299                                let _ = incoming_msg_tx.send(IncomingMessage::MatrixEventReceived(event));
300                            }
301
302                            Ok(state) = state_updates.recv() => {
303                                // Forward all state updates to the incoming messages stream.
304                                let _ = incoming_msg_tx.send(IncomingMessage::StateUpdateReceived(state));
305                            }
306
307                            Some(event) = to_device_events.recv() => {
308                                // Forward all events to the incoming messages stream.
309                                let _ = incoming_msg_tx.send(IncomingMessage::ToDeviceReceived(event));
310                            }
311                        }
312                    }
313                });
314            }
315
316            Action::Unsubscribe => {
317                self.event_forwarding_guard = None;
318            }
319        }
320
321        Ok(())
322    }
323}
324
325// TODO: Decide which module this type should live in
326#[derive(Clone, Debug)]
327pub(crate) enum StateKeySelector {
328    Key(String),
329    Any,
330}
331
332impl<'de> Deserialize<'de> for StateKeySelector {
333    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
334    where
335        D: Deserializer<'de>,
336    {
337        struct StateKeySelectorVisitor;
338
339        impl Visitor<'_> for StateKeySelectorVisitor {
340            type Value = StateKeySelector;
341
342            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343                write!(f, "a string or `true`")
344            }
345
346            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
347            where
348                E: de::Error,
349            {
350                if v {
351                    Ok(StateKeySelector::Any)
352                } else {
353                    Err(E::invalid_value(de::Unexpected::Bool(v), &self))
354                }
355            }
356
357            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
358            where
359                E: de::Error,
360            {
361                self.visit_string(v.to_owned())
362            }
363
364            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
365            where
366                E: de::Error,
367            {
368                Ok(StateKeySelector::Key(v))
369            }
370        }
371
372        deserializer.deserialize_any(StateKeySelectorVisitor)
373    }
374}
375
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use serde_json::{Value, json};

    use super::StateKeySelector;

    /// Deserializes `value` as a [`StateKeySelector`].
    fn parse(value: Value) -> serde_json::Result<StateKeySelector> {
        serde_json::from_value(value)
    }

    #[test]
    fn state_key_selector_from_true() {
        // `true` maps onto the `Any` variant.
        let selector = parse(json!(true)).unwrap();
        assert_matches!(selector, StateKeySelector::Any);
    }

    #[test]
    fn state_key_selector_from_string() {
        // A string maps onto the `Key` variant and is kept verbatim.
        let selector = parse(json!("test")).unwrap();
        assert_matches!(selector, StateKeySelector::Key(key) if key == "test");
    }

    #[test]
    fn state_key_selector_from_false() {
        // `false` is rejected with a data error.
        assert_matches!(parse(json!(false)), Err(error) if error.is_data());
    }

    #[test]
    fn state_key_selector_from_number() {
        // Numbers are rejected with a data error.
        assert_matches!(parse(json!(5)), Err(error) if error.is_data());
    }
}