matrix_sdk/widget/
capabilities.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Types and traits related to the capabilities that a widget can request from
16//! a client.
17
18use std::fmt;
19
20use async_trait::async_trait;
21use ruma::{events::AnyTimelineEvent, serde::Raw};
22use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer};
23use tracing::{debug, error};
24
25use super::{
26    filter::MatrixEventFilterInput, EventFilter, MessageLikeEventFilter, StateEventFilter,
27};
28
29/// Must be implemented by a component that provides functionality of deciding
30/// whether a widget is allowed to use certain capabilities (typically by
31/// providing a prompt to the user).
32#[async_trait]
33pub trait CapabilitiesProvider: Send + Sync + 'static {
34    /// Receives a request for given capabilities and returns the actual
35    /// capabilities that the clients grants to a given widget (usually by
36    /// prompting the user).
37    async fn acquire_capabilities(&self, capabilities: Capabilities) -> Capabilities;
38}
39
40/// Capabilities that a widget can request from a client.
41#[derive(Clone, Debug, Default)]
42#[cfg_attr(test, derive(PartialEq))]
43pub struct Capabilities {
44    /// Types of the messages that a widget wants to be able to fetch.
45    pub read: Vec<EventFilter>,
46    /// Types of the messages that a widget wants to be able to send.
47    pub send: Vec<EventFilter>,
48    /// If this capability is requested by the widget, it can not operate
49    /// separately from the matrix client.
50    ///
51    /// This means clients should not offer to open the widget in a separate
52    /// browser/tab/webview that is not connected to the postmessage widget-api.
53    pub requires_client: bool,
54    /// This allows the widget to ask the client to update delayed events.
55    pub update_delayed_event: bool,
56    /// This allows the widget to send events with a delay.
57    pub send_delayed_event: bool,
58}
59
60impl Capabilities {
61    /// Tells if a given raw event matches the read filter.
62    pub fn raw_event_matches_read_filter(&self, raw: &Raw<AnyTimelineEvent>) -> bool {
63        let filter_in = match raw.deserialize_as::<MatrixEventFilterInput>() {
64            Ok(filter) => filter,
65            Err(err) => {
66                error!("Failed to deserialize raw event as MatrixEventFilterInput: {err}");
67                return false;
68            }
69        };
70
71        self.read.iter().any(|f| f.matches(&filter_in))
72    }
73}
74
75const SEND_EVENT: &str = "org.matrix.msc2762.send.event";
76const READ_EVENT: &str = "org.matrix.msc2762.receive.event";
77const SEND_STATE: &str = "org.matrix.msc2762.send.state_event";
78const READ_STATE: &str = "org.matrix.msc2762.receive.state_event";
79const REQUIRES_CLIENT: &str = "io.element.requires_client";
80pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event";
81pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event";
82
83impl Serialize for Capabilities {
84    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85    where
86        S: Serializer,
87    {
88        struct PrintEventFilter<'a>(&'a EventFilter);
89        impl fmt::Display for PrintEventFilter<'_> {
90            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91                match self.0 {
92                    EventFilter::MessageLike(filter) => PrintMessageLikeEventFilter(filter).fmt(f),
93                    EventFilter::State(filter) => PrintStateEventFilter(filter).fmt(f),
94                }
95            }
96        }
97
98        struct PrintMessageLikeEventFilter<'a>(&'a MessageLikeEventFilter);
99        impl fmt::Display for PrintMessageLikeEventFilter<'_> {
100            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101                match self.0 {
102                    MessageLikeEventFilter::WithType(event_type) => {
103                        // TODO: escape `#` as `\#` and `\` as `\\` in event_type
104                        write!(f, "{event_type}")
105                    }
106                    MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => {
107                        write!(f, "m.room.message#{msgtype}")
108                    }
109                }
110            }
111        }
112
113        struct PrintStateEventFilter<'a>(&'a StateEventFilter);
114        impl fmt::Display for PrintStateEventFilter<'_> {
115            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116                // TODO: escape `#` as `\#` and `\` as `\\` in event_type
117                match self.0 {
118                    StateEventFilter::WithType(event_type) => write!(f, "{event_type}"),
119                    StateEventFilter::WithTypeAndStateKey(event_type, state_key) => {
120                        write!(f, "{event_type}#{state_key}")
121                    }
122                }
123            }
124        }
125
126        let mut seq = serializer.serialize_seq(None)?;
127
128        if self.requires_client {
129            seq.serialize_element(REQUIRES_CLIENT)?;
130        }
131        if self.update_delayed_event {
132            seq.serialize_element(UPDATE_DELAYED_EVENT)?;
133        }
134        if self.send_delayed_event {
135            seq.serialize_element(SEND_DELAYED_EVENT)?;
136        }
137        for filter in &self.read {
138            let name = match filter {
139                EventFilter::MessageLike(_) => READ_EVENT,
140                EventFilter::State(_) => READ_STATE,
141            };
142            seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
143        }
144        for filter in &self.send {
145            let name = match filter {
146                EventFilter::MessageLike(_) => SEND_EVENT,
147                EventFilter::State(_) => SEND_STATE,
148            };
149            seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
150        }
151
152        seq.end()
153    }
154}
155
156impl<'de> Deserialize<'de> for Capabilities {
157    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158    where
159        D: Deserializer<'de>,
160    {
161        enum Permission {
162            RequiresClient,
163            UpdateDelayedEvent,
164            SendDelayedEvent,
165            Read(EventFilter),
166            Send(EventFilter),
167            Unknown,
168        }
169
170        impl<'de> Deserialize<'de> for Permission {
171            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
172            where
173                D: Deserializer<'de>,
174            {
175                let s = ruma::serde::deserialize_cow_str(deserializer)?;
176                if s == REQUIRES_CLIENT {
177                    return Ok(Self::RequiresClient);
178                }
179                if s == UPDATE_DELAYED_EVENT {
180                    return Ok(Self::UpdateDelayedEvent);
181                }
182                if s == SEND_DELAYED_EVENT {
183                    return Ok(Self::SendDelayedEvent);
184                }
185
186                match s.split_once(':') {
187                    Some((READ_EVENT, filter_s)) => Ok(Permission::Read(EventFilter::MessageLike(
188                        parse_message_event_filter(filter_s),
189                    ))),
190                    Some((SEND_EVENT, filter_s)) => Ok(Permission::Send(EventFilter::MessageLike(
191                        parse_message_event_filter(filter_s),
192                    ))),
193                    Some((READ_STATE, filter_s)) => {
194                        Ok(Permission::Read(EventFilter::State(parse_state_event_filter(filter_s))))
195                    }
196                    Some((SEND_STATE, filter_s)) => {
197                        Ok(Permission::Send(EventFilter::State(parse_state_event_filter(filter_s))))
198                    }
199                    _ => {
200                        debug!("Unknown capability `{s}`");
201                        Ok(Self::Unknown)
202                    }
203                }
204            }
205        }
206
207        fn parse_message_event_filter(s: &str) -> MessageLikeEventFilter {
208            match s.strip_prefix("m.room.message#") {
209                Some(msgtype) => MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype.to_owned()),
210                // TODO: Replace `\\` by `\` and `\#` by `#`, enforce no unescaped `#`
211                None => MessageLikeEventFilter::WithType(s.into()),
212            }
213        }
214
215        fn parse_state_event_filter(s: &str) -> StateEventFilter {
216            // TODO: Search for un-escaped `#` only, replace `\\` by `\` and `\#` by `#`
217            match s.split_once('#') {
218                Some((event_type, state_key)) => {
219                    StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key.to_owned())
220                }
221                None => StateEventFilter::WithType(s.into()),
222            }
223        }
224
225        let mut capabilities = Capabilities::default();
226        for capability in Vec::<Permission>::deserialize(deserializer)? {
227            match capability {
228                Permission::RequiresClient => capabilities.requires_client = true,
229                Permission::Read(filter) => capabilities.read.push(filter),
230                Permission::Send(filter) => capabilities.send.push(filter),
231                // ignore unknown capabilities
232                Permission::Unknown => {}
233                Permission::UpdateDelayedEvent => capabilities.update_delayed_event = true,
234                Permission::SendDelayedEvent => capabilities.send_delayed_event = true,
235            }
236        }
237
238        Ok(capabilities)
239    }
240}
241
#[cfg(test)]
mod tests {
    use ruma::events::StateEventType;

    use super::*;

    #[test]
    fn deserialization_of_no_capabilities() {
        let capabilities_str = r#"[]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities::default();

        assert_eq!(parsed, expected);
    }

    #[test]
    fn deserialization_of_capabilities() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "io.element.requires_client",
            "org.matrix.msc2762.receive.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member",
            "org.matrix.msc2762.send.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4157.update_delayed_event"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                EventFilter::State(StateEventFilter::WithType(
                    "org.matrix.msc3401.call.member".into(),
                )),
            ],
            send: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        assert_eq!(parsed, expected);
    }

    #[test]
    fn deserialization_of_room_message_msgtype_filter() {
        let capabilities_str = r#"[
            "org.matrix.msc2762.receive.event:m.room.message#m.text",
            "org.matrix.msc2762.send.event:m.room.message"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![EventFilter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.text".to_owned(),
            ))],
            send: vec![EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                "m.room.message".into(),
            ))],
            ..Default::default()
        };

        assert_eq!(parsed, expected);
    }

    #[test]
    fn unknown_and_filterless_capabilities_are_ignored() {
        // An unknown flag, an event capability without a `:filter` part, and an
        // unknown `<prefix>:<filter>` pair must all be skipped without error.
        let capabilities_str = r#"[
            "m.always_on_screen",
            "org.matrix.msc2762.send.event",
            "something:else"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();

        assert_eq!(parsed, Capabilities::default());
    }

    #[test]
    fn serialization_of_all_capability_kinds() {
        let capabilities = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                    "m.text".to_owned(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            send: vec![EventFilter::State(StateEventFilter::WithType(
                StateEventType::RoomMember,
            ))],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        // Pin the exact wire format, including element order.
        let serialized = serde_json::to_value(&capabilities).unwrap();
        let expected = serde_json::json!([
            "io.element.requires_client",
            "org.matrix.msc4157.update_delayed_event",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc2762.receive.event:io.element.custom",
            "org.matrix.msc2762.receive.event:m.room.message#m.text",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc2762.send.state_event:m.room.member",
        ]);
        assert_eq!(serialized, expected);

        // Every flag and filter kind above must also survive a round-trip.
        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    #[test]
    fn serialization_and_deserialization_are_symmetrical() {
        let capabilities = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            send: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: false,
            send_delayed_event: false,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }
}