Skip to main content

matrix_sdk/widget/
capabilities.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Types and traits related to the capabilities that a widget can request from
16//! a client.
17
18use std::{fmt, future::Future};
19
20use matrix_sdk_common::{SendOutsideWasm, SyncOutsideWasm};
21use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::SerializeSeq};
22use tracing::{debug, warn};
23
24use super::{
25    MessageLikeEventFilter, StateEventFilter,
26    filter::{Filter, FilterInput, ToDeviceEventFilter},
27};
28
29/// Must be implemented by a component that provides functionality of deciding
30/// whether a widget is allowed to use certain capabilities (typically by
31/// providing a prompt to the user).
32pub trait CapabilitiesProvider: SendOutsideWasm + SyncOutsideWasm + 'static {
33    /// Receives a request for given capabilities and returns the actual
34    /// capabilities that the clients grants to a given widget (usually by
35    /// prompting the user).
36    fn acquire_capabilities(
37        &self,
38        capabilities: Capabilities,
39    ) -> impl Future<Output = Capabilities> + SendOutsideWasm;
40}
41
42/// Capabilities that a widget can request from a client.
43#[derive(Clone, Debug, Default)]
44#[cfg_attr(test, derive(PartialEq))]
45pub struct Capabilities {
46    /// Types of the messages that a widget wants to be able to fetch.
47    pub read: Vec<Filter>,
48    /// Types of the messages that a widget wants to be able to send.
49    pub send: Vec<Filter>,
50    /// If this capability is requested by the widget, it can not operate
51    /// separately from the Matrix client.
52    ///
53    /// This means clients should not offer to open the widget in a separate
54    /// browser/tab/webview that is not connected to the postmessage widget-api.
55    pub requires_client: bool,
56    /// This allows the widget to ask the client to update delayed events.
57    pub update_delayed_event: bool,
58    /// This allows the widget to send events with a delay.
59    pub send_delayed_event: bool,
60
61    /// This allows the widget to download files as per MSC4039.
62    pub download_file: bool,
63}
64
65impl Capabilities {
66    /// Checks if a given event is allowed to be forwarded to the widget.
67    ///
68    /// - `event_filter_input` is a minimized event representation that contains
69    ///   only the information needed to check if the widget is allowed to
70    ///   receive the event. (See [`FilterInput`])
71    pub(super) fn allow_reading<'a>(
72        &self,
73        event_filter_input: impl TryInto<FilterInput<'a>>,
74    ) -> bool {
75        match &event_filter_input.try_into() {
76            Err(_) => {
77                warn!("Failed to convert event into filter input for `allow_reading`.");
78                false
79            }
80            Ok(filter_input) => self.read.iter().any(|f| f.matches(filter_input)),
81        }
82    }
83
84    /// Checks if a given event is allowed to be sent by the widget.
85    ///
86    /// - `event_filter_input` is a minimized event representation that contains
87    ///   only the information needed to check if the widget is allowed to send
88    ///   the event to a matrix room. (See [`FilterInput`])
89    pub(super) fn allow_sending<'a>(
90        &self,
91        event_filter_input: impl TryInto<FilterInput<'a>>,
92    ) -> bool {
93        match &event_filter_input.try_into() {
94            Err(_) => {
95                warn!("Failed to convert event into filter input for `allow_sending`.");
96                false
97            }
98            Ok(filter_input) => self.send.iter().any(|f| f.matches(filter_input)),
99        }
100    }
101
102    /// Checks if a filter exists for the given event type, useful for
103    /// optimization. Avoids unnecessary read event requests when no matching
104    /// filter is present.
105    pub(super) fn has_read_filter_for_type(&self, event_type: &str) -> bool {
106        self.read.iter().any(|f| f.filter_event_type() == event_type)
107    }
108}
109
110pub(super) const SEND_EVENT: &str = "org.matrix.msc2762.send.event";
111pub(super) const READ_EVENT: &str = "org.matrix.msc2762.receive.event";
112pub(super) const SEND_STATE: &str = "org.matrix.msc2762.send.state_event";
113pub(super) const READ_STATE: &str = "org.matrix.msc2762.receive.state_event";
114pub(super) const SEND_TODEVICE: &str = "org.matrix.msc3819.send.to_device";
115pub(super) const READ_TODEVICE: &str = "org.matrix.msc3819.receive.to_device";
116pub(super) const REQUIRES_CLIENT: &str = "io.element.requires_client";
117pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event";
118pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event";
119
120pub(super) const DOWNLOAD_FILE: &str = "org.matrix.msc4039.download_file";
121
122impl Serialize for Capabilities {
123    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
124    where
125        S: Serializer,
126    {
127        struct PrintEventFilter<'a>(&'a Filter);
128        impl fmt::Display for PrintEventFilter<'_> {
129            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130                match self.0 {
131                    Filter::MessageLike(filter) => PrintMessageLikeEventFilter(filter).fmt(f),
132                    Filter::State(filter) => PrintStateEventFilter(filter).fmt(f),
133                    Filter::ToDevice(filter) => {
134                        // As per MSC 3819 https://github.com/matrix-org/matrix-spec-proposals/pull/3819
135                        // ToDevice capabilities is in the form of `m.send.to_device:<event type>`
136                        // or `m.receive.to_device:<event type>`
137                        write!(f, "{}", filter.event_type)
138                    }
139                }
140            }
141        }
142
143        struct PrintMessageLikeEventFilter<'a>(&'a MessageLikeEventFilter);
144        impl fmt::Display for PrintMessageLikeEventFilter<'_> {
145            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146                match self.0 {
147                    MessageLikeEventFilter::WithType(event_type) => {
148                        // TODO: escape `#` as `\#` and `\` as `\\` in event_type
149                        write!(f, "{event_type}")
150                    }
151                    MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => {
152                        write!(f, "m.room.message#{msgtype}")
153                    }
154                }
155            }
156        }
157
158        struct PrintStateEventFilter<'a>(&'a StateEventFilter);
159        impl fmt::Display for PrintStateEventFilter<'_> {
160            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161                // TODO: escape `#` as `\#` and `\` as `\\` in event_type
162                match self.0 {
163                    StateEventFilter::WithType(event_type) => write!(f, "{event_type}"),
164                    StateEventFilter::WithTypeAndStateKey(event_type, state_key) => {
165                        write!(f, "{event_type}#{state_key}")
166                    }
167                }
168            }
169        }
170
171        let mut seq = serializer.serialize_seq(None)?;
172
173        if self.requires_client {
174            seq.serialize_element(REQUIRES_CLIENT)?;
175        }
176        if self.update_delayed_event {
177            seq.serialize_element(UPDATE_DELAYED_EVENT)?;
178        }
179        if self.send_delayed_event {
180            seq.serialize_element(SEND_DELAYED_EVENT)?;
181        }
182        if self.download_file {
183            seq.serialize_element(DOWNLOAD_FILE)?;
184        }
185        for filter in &self.read {
186            let name = match filter {
187                Filter::MessageLike(_) => READ_EVENT,
188                Filter::State(_) => READ_STATE,
189                Filter::ToDevice(_) => READ_TODEVICE,
190            };
191            seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
192        }
193        for filter in &self.send {
194            let name = match filter {
195                Filter::MessageLike(_) => SEND_EVENT,
196                Filter::State(_) => SEND_STATE,
197                Filter::ToDevice(_) => SEND_TODEVICE,
198            };
199            seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
200        }
201
202        seq.end()
203    }
204}
205
206impl<'de> Deserialize<'de> for Capabilities {
207    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
208    where
209        D: Deserializer<'de>,
210    {
211        enum Permission {
212            RequiresClient,
213            UpdateDelayedEvent,
214            SendDelayedEvent,
215            DownloadFile,
216            Read(Filter),
217            Send(Filter),
218            Unknown,
219        }
220
221        impl<'de> Deserialize<'de> for Permission {
222            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
223            where
224                D: Deserializer<'de>,
225            {
226                let s = ruma::serde::deserialize_cow_str(deserializer)?;
227                if s == REQUIRES_CLIENT {
228                    return Ok(Self::RequiresClient);
229                }
230                if s == UPDATE_DELAYED_EVENT {
231                    return Ok(Self::UpdateDelayedEvent);
232                }
233                if s == SEND_DELAYED_EVENT {
234                    return Ok(Self::SendDelayedEvent);
235                }
236                if s == DOWNLOAD_FILE {
237                    return Ok(Self::DownloadFile);
238                }
239
240                match s.split_once(':') {
241                    Some((READ_EVENT, filter_s)) => Ok(Permission::Read(Filter::MessageLike(
242                        parse_message_event_filter(filter_s),
243                    ))),
244                    Some((SEND_EVENT, filter_s)) => Ok(Permission::Send(Filter::MessageLike(
245                        parse_message_event_filter(filter_s),
246                    ))),
247                    Some((READ_STATE, filter_s)) => {
248                        Ok(Permission::Read(Filter::State(parse_state_event_filter(filter_s))))
249                    }
250                    Some((SEND_STATE, filter_s)) => {
251                        Ok(Permission::Send(Filter::State(parse_state_event_filter(filter_s))))
252                    }
253                    Some((READ_TODEVICE, filter_s)) => Ok(Permission::Read(Filter::ToDevice(
254                        parse_to_device_event_filter(filter_s),
255                    ))),
256                    Some((SEND_TODEVICE, filter_s)) => Ok(Permission::Send(Filter::ToDevice(
257                        parse_to_device_event_filter(filter_s),
258                    ))),
259                    _ => {
260                        debug!("Unknown capability `{s}`");
261                        Ok(Self::Unknown)
262                    }
263                }
264            }
265        }
266
267        fn parse_message_event_filter(s: &str) -> MessageLikeEventFilter {
268            match s.strip_prefix("m.room.message#") {
269                Some(msgtype) => MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype.to_owned()),
270                // TODO: Replace `\\` by `\` and `\#` by `#`, enforce no unescaped `#`
271                None => MessageLikeEventFilter::WithType(s.into()),
272            }
273        }
274
275        fn parse_state_event_filter(s: &str) -> StateEventFilter {
276            // TODO: Search for un-escaped `#` only, replace `\\` by `\` and `\#` by `#`
277            match s.split_once('#') {
278                Some((event_type, state_key)) => {
279                    StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key.to_owned())
280                }
281                None => StateEventFilter::WithType(s.into()),
282            }
283        }
284
285        fn parse_to_device_event_filter(s: &str) -> ToDeviceEventFilter {
286            ToDeviceEventFilter::new(s.into())
287        }
288
289        let mut capabilities = Capabilities::default();
290        for capability in Vec::<Permission>::deserialize(deserializer)? {
291            match capability {
292                Permission::RequiresClient => capabilities.requires_client = true,
293                Permission::Read(filter) => capabilities.read.push(filter),
294                Permission::Send(filter) => capabilities.send.push(filter),
295                // ignore unknown capabilities
296                Permission::Unknown => {}
297                Permission::UpdateDelayedEvent => capabilities.update_delayed_event = true,
298                Permission::SendDelayedEvent => capabilities.send_delayed_event = true,
299                Permission::DownloadFile => capabilities.download_file = true,
300            }
301        }
302
303        Ok(capabilities)
304    }
305}
306
#[cfg(test)]
mod tests {
    use ruma::events::StateEventType;

    use super::*;
    use crate::widget::filter::ToDeviceEventFilter;

    #[test]
    fn deserialization_of_no_capabilities() {
        let capabilities_str = r#"[]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities::default();

        assert_eq!(parsed, expected);
    }

    #[test]
    fn deserialization_of_capabilities() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "io.element.requires_client",
            "org.matrix.msc2762.receive.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member",
            "org.matrix.msc3819.receive.to_device:io.element.call.encryption_keys",
            "org.matrix.msc2762.send.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc3819.send.to_device:io.element.call.encryption_keys",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4157.update_delayed_event",
            "org.matrix.msc4039.download_file"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                Filter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                Filter::State(StateEventFilter::WithType("org.matrix.msc3401.call.member".into())),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            send: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
            download_file: true,
        };

        assert_eq!(parsed, expected);
    }

    /// `m.room.message#<msgtype>` must be parsed into a msgtype filter rather
    /// than a plain event-type filter.
    #[test]
    fn deserialization_of_room_message_msgtype_filters() {
        let capabilities_str = r#"[
            "org.matrix.msc2762.receive.event:m.room.message#m.text",
            "org.matrix.msc2762.send.event:m.room.message#m.notice"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.text".to_owned(),
            ))],
            send: vec![Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.notice".to_owned(),
            ))],
            ..Default::default()
        };

        assert_eq!(parsed, expected);
    }

    /// Unknown flags, known prefixes without a `:<filter>` part and unknown
    /// prefixes are all skipped instead of failing deserialization.
    #[test]
    fn unknown_and_malformed_capabilities_are_ignored() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "org.matrix.msc2762.receive.event",
            "org.example.unknown:some.event.type"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();

        assert_eq!(parsed, Capabilities::default());
    }

    /// Pins the exact wire format: flags first, then read filters, then send
    /// filters.
    #[test]
    fn serialization_puts_flags_before_read_and_send_filters() {
        let capabilities = Capabilities {
            read: vec![
                Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                    "m.text".to_owned(),
                )),
                Filter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
            ],
            send: vec![
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
            download_file: true,
        };

        let serialized = serde_json::to_value(&capabilities).unwrap();
        let expected = serde_json::json!([
            "io.element.requires_client",
            "org.matrix.msc4157.update_delayed_event",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4039.download_file",
            "org.matrix.msc2762.receive.event:m.room.message#m.text",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc3819.send.to_device:io.element.call.encryption_keys"
        ]);

        assert_eq!(serialized, expected);
    }

    #[test]
    fn serialization_and_deserialization_are_symmetrical() {
        let capabilities = Capabilities {
            read: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType("io.element.custom".into())),
                Filter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            send: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType("io.element.custom".into())),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new("my.org.other.to_device_event".into())),
            ],
            requires_client: true,
            update_delayed_event: false,
            send_delayed_event: false,
            download_file: false,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    /// Round-trips every boolean flag and the msgtype filter variant, which
    /// the symmetry test above does not cover.
    #[test]
    fn round_trip_preserves_flags_and_msgtype_filters() {
        let capabilities = Capabilities {
            read: vec![Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.text".to_owned(),
            ))],
            send: vec![Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.emote".to_owned(),
            ))],
            requires_client: false,
            update_delayed_event: true,
            send_delayed_event: true,
            download_file: true,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }
}