1use std::{fmt, future::Future};
19
20use matrix_sdk_common::{SendOutsideWasm, SyncOutsideWasm};
21use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer};
22use tracing::{debug, warn};
23
24use super::{
25 filter::{Filter, FilterInput, ToDeviceEventFilter},
26 MessageLikeEventFilter, StateEventFilter,
27};
28
/// Source of the capabilities that end up being granted to a widget.
///
/// Implementors must be shareable across threads on non-wasm targets
/// (`SendOutsideWasm + SyncOutsideWasm`) and own all their data (`'static`).
pub trait CapabilitiesProvider: SendOutsideWasm + SyncOutsideWasm + 'static {
    /// Takes a set of `capabilities` and asynchronously produces the
    /// `Capabilities` to use.
    ///
    /// NOTE(review): presumably the input is what the widget requested and the
    /// output is what is actually granted (possibly a subset) — confirm against
    /// implementors.
    fn acquire_capabilities(
        &self,
        capabilities: Capabilities,
    ) -> impl Future<Output = Capabilities> + SendOutsideWasm;
}
41
/// A widget's capabilities.
///
/// They are (de)serialized as a flat list of capability strings by the
/// `Serialize`/`Deserialize` impls for this type.
#[derive(Clone, Debug, Default)]
// `PartialEq` is only needed by tests, to compare parsed values.
#[cfg_attr(test, derive(PartialEq))]
pub struct Capabilities {
    /// Filters describing which events may be read; an event is readable if
    /// any of them matches (see `allow_reading`).
    pub read: Vec<Filter>,
    /// Filters describing which events may be sent; an event is sendable if
    /// any of them matches (see `allow_sending`).
    pub send: Vec<Filter>,
    /// Corresponds to the `io.element.requires_client` capability string.
    pub requires_client: bool,
    /// Corresponds to the `org.matrix.msc4157.update_delayed_event` capability
    /// string.
    pub update_delayed_event: bool,
    /// Corresponds to the `org.matrix.msc4157.send.delayed_event` capability
    /// string.
    pub send_delayed_event: bool,
}
61
62impl Capabilities {
63 pub(super) fn allow_reading<'a>(
69 &self,
70 event_filter_input: impl TryInto<FilterInput<'a>>,
71 ) -> bool {
72 match &event_filter_input.try_into() {
73 Err(_) => {
74 warn!("Failed to convert event into filter input for `allow_reading`.");
75 false
76 }
77 Ok(filter_input) => self.read.iter().any(|f| f.matches(filter_input)),
78 }
79 }
80
81 pub(super) fn allow_sending<'a>(
87 &self,
88 event_filter_input: impl TryInto<FilterInput<'a>>,
89 ) -> bool {
90 match &event_filter_input.try_into() {
91 Err(_) => {
92 warn!("Failed to convert event into filter input for `allow_sending`.");
93 false
94 }
95 Ok(filter_input) => self.send.iter().any(|f| f.matches(filter_input)),
96 }
97 }
98
99 pub(super) fn has_read_filter_for_type(&self, event_type: &str) -> bool {
103 self.read.iter().any(|f| f.filter_event_type() == event_type)
104 }
105}
106
// Capability identifiers as they appear in the serialized capability list.
// The `send`/`receive` event, state and to-device identifiers are always
// followed by `:` and a filter (see the `Serialize` impl for `Capabilities`);
// the remaining ones are stand-alone flags.

/// Send message-like events (MSC2762); followed by `:<filter>`.
pub(super) const SEND_EVENT: &str = "org.matrix.msc2762.send.event";
/// Receive message-like events (MSC2762); followed by `:<filter>`.
pub(super) const READ_EVENT: &str = "org.matrix.msc2762.receive.event";
/// Send state events (MSC2762); followed by `:<filter>`.
pub(super) const SEND_STATE: &str = "org.matrix.msc2762.send.state_event";
/// Receive state events (MSC2762); followed by `:<filter>`.
pub(super) const READ_STATE: &str = "org.matrix.msc2762.receive.state_event";
/// Send to-device events (MSC3819); followed by `:<filter>`.
pub(super) const SEND_TODEVICE: &str = "org.matrix.msc3819.send.to_device";
/// Receive to-device events (MSC3819); followed by `:<filter>`.
pub(super) const READ_TODEVICE: &str = "org.matrix.msc3819.receive.to_device";
/// Flag mapped to `Capabilities::requires_client`.
pub(super) const REQUIRES_CLIENT: &str = "io.element.requires_client";
/// Flag mapped to `Capabilities::send_delayed_event` (MSC4157).
pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event";
/// Flag mapped to `Capabilities::update_delayed_event` (MSC4157).
pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event";
116
117impl Serialize for Capabilities {
118 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
119 where
120 S: Serializer,
121 {
122 struct PrintEventFilter<'a>(&'a Filter);
123 impl fmt::Display for PrintEventFilter<'_> {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 match self.0 {
126 Filter::MessageLike(filter) => PrintMessageLikeEventFilter(filter).fmt(f),
127 Filter::State(filter) => PrintStateEventFilter(filter).fmt(f),
128 Filter::ToDevice(filter) => {
129 write!(f, "{}", filter.event_type)
133 }
134 }
135 }
136 }
137
138 struct PrintMessageLikeEventFilter<'a>(&'a MessageLikeEventFilter);
139 impl fmt::Display for PrintMessageLikeEventFilter<'_> {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 match self.0 {
142 MessageLikeEventFilter::WithType(event_type) => {
143 write!(f, "{event_type}")
145 }
146 MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => {
147 write!(f, "m.room.message#{msgtype}")
148 }
149 }
150 }
151 }
152
153 struct PrintStateEventFilter<'a>(&'a StateEventFilter);
154 impl fmt::Display for PrintStateEventFilter<'_> {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 match self.0 {
158 StateEventFilter::WithType(event_type) => write!(f, "{event_type}"),
159 StateEventFilter::WithTypeAndStateKey(event_type, state_key) => {
160 write!(f, "{event_type}#{state_key}")
161 }
162 }
163 }
164 }
165
166 let mut seq = serializer.serialize_seq(None)?;
167
168 if self.requires_client {
169 seq.serialize_element(REQUIRES_CLIENT)?;
170 }
171 if self.update_delayed_event {
172 seq.serialize_element(UPDATE_DELAYED_EVENT)?;
173 }
174 if self.send_delayed_event {
175 seq.serialize_element(SEND_DELAYED_EVENT)?;
176 }
177 for filter in &self.read {
178 let name = match filter {
179 Filter::MessageLike(_) => READ_EVENT,
180 Filter::State(_) => READ_STATE,
181 Filter::ToDevice(_) => READ_TODEVICE,
182 };
183 seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
184 }
185 for filter in &self.send {
186 let name = match filter {
187 Filter::MessageLike(_) => SEND_EVENT,
188 Filter::State(_) => SEND_STATE,
189 Filter::ToDevice(_) => SEND_TODEVICE,
190 };
191 seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
192 }
193
194 seq.end()
195 }
196}
197
198impl<'de> Deserialize<'de> for Capabilities {
199 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
200 where
201 D: Deserializer<'de>,
202 {
203 enum Permission {
204 RequiresClient,
205 UpdateDelayedEvent,
206 SendDelayedEvent,
207 Read(Filter),
208 Send(Filter),
209 Unknown,
210 }
211
212 impl<'de> Deserialize<'de> for Permission {
213 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
214 where
215 D: Deserializer<'de>,
216 {
217 let s = ruma::serde::deserialize_cow_str(deserializer)?;
218 if s == REQUIRES_CLIENT {
219 return Ok(Self::RequiresClient);
220 }
221 if s == UPDATE_DELAYED_EVENT {
222 return Ok(Self::UpdateDelayedEvent);
223 }
224 if s == SEND_DELAYED_EVENT {
225 return Ok(Self::SendDelayedEvent);
226 }
227
228 match s.split_once(':') {
229 Some((READ_EVENT, filter_s)) => Ok(Permission::Read(Filter::MessageLike(
230 parse_message_event_filter(filter_s),
231 ))),
232 Some((SEND_EVENT, filter_s)) => Ok(Permission::Send(Filter::MessageLike(
233 parse_message_event_filter(filter_s),
234 ))),
235 Some((READ_STATE, filter_s)) => {
236 Ok(Permission::Read(Filter::State(parse_state_event_filter(filter_s))))
237 }
238 Some((SEND_STATE, filter_s)) => {
239 Ok(Permission::Send(Filter::State(parse_state_event_filter(filter_s))))
240 }
241 Some((READ_TODEVICE, filter_s)) => Ok(Permission::Read(Filter::ToDevice(
242 parse_to_device_event_filter(filter_s),
243 ))),
244 Some((SEND_TODEVICE, filter_s)) => Ok(Permission::Send(Filter::ToDevice(
245 parse_to_device_event_filter(filter_s),
246 ))),
247 _ => {
248 debug!("Unknown capability `{s}`");
249 Ok(Self::Unknown)
250 }
251 }
252 }
253 }
254
255 fn parse_message_event_filter(s: &str) -> MessageLikeEventFilter {
256 match s.strip_prefix("m.room.message#") {
257 Some(msgtype) => MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype.to_owned()),
258 None => MessageLikeEventFilter::WithType(s.into()),
260 }
261 }
262
263 fn parse_state_event_filter(s: &str) -> StateEventFilter {
264 match s.split_once('#') {
266 Some((event_type, state_key)) => {
267 StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key.to_owned())
268 }
269 None => StateEventFilter::WithType(s.into()),
270 }
271 }
272
273 fn parse_to_device_event_filter(s: &str) -> ToDeviceEventFilter {
274 ToDeviceEventFilter::new(s.into())
275 }
276
277 let mut capabilities = Capabilities::default();
278 for capability in Vec::<Permission>::deserialize(deserializer)? {
279 match capability {
280 Permission::RequiresClient => capabilities.requires_client = true,
281 Permission::Read(filter) => capabilities.read.push(filter),
282 Permission::Send(filter) => capabilities.send.push(filter),
283 Permission::Unknown => {}
285 Permission::UpdateDelayedEvent => capabilities.update_delayed_event = true,
286 Permission::SendDelayedEvent => capabilities.send_delayed_event = true,
287 }
288 }
289
290 Ok(capabilities)
291 }
292}
293
#[cfg(test)]
mod tests {
    use ruma::events::StateEventType;

    use super::*;
    use crate::widget::filter::ToDeviceEventFilter;

    /// An empty capability list parses to the default (empty) capabilities.
    #[test]
    fn deserialization_of_no_capabilities() {
        let capabilities_str = r#"[]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities::default();

        assert_eq!(parsed, expected);
    }

    /// Every known capability kind is parsed; unknown ones are ignored.
    #[test]
    fn deserialization_of_capabilities() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "io.element.requires_client",
            "org.matrix.msc2762.receive.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member",
            "org.matrix.msc3819.receive.to_device:io.element.call.encryption_keys",
            "org.matrix.msc2762.send.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc3819.send.to_device:io.element.call.encryption_keys",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4157.update_delayed_event"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                Filter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                Filter::State(StateEventFilter::WithType("org.matrix.msc3401.call.member".into())),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            send: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        assert_eq!(parsed, expected);
    }

    /// Serializing and re-parsing yields the original capabilities.
    #[test]
    fn serialization_and_deserialization_are_symmetrical() {
        let capabilities = Capabilities {
            read: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType("io.element.custom".into())),
                Filter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new(
                    "io.element.call.encryption_keys".into(),
                )),
            ],
            send: vec![
                Filter::MessageLike(MessageLikeEventFilter::WithType("io.element.custom".into())),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
                Filter::ToDevice(ToDeviceEventFilter::new("my.org.other.to_device_event".into())),
            ],
            requires_client: true,
            update_delayed_event: false,
            send_delayed_event: false,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    /// The delayed-event flags, `m.room.message#<msgtype>` filters and empty
    /// state keys are not covered by the round trip above; check them too.
    #[test]
    fn serialization_round_trips_flags_msgtype_and_empty_state_key() {
        let capabilities = Capabilities {
            read: vec![
                Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                    "m.text".to_owned(),
                )),
                Filter::State(StateEventFilter::WithTypeAndStateKey(
                    "m.room.name".into(),
                    String::new(),
                )),
            ],
            send: vec![Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.notice".to_owned(),
            ))],
            requires_client: false,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    /// Pins the exact serialized form: flags first, then reads, then sends.
    #[test]
    fn serialization_output_format() {
        let capabilities = Capabilities {
            read: vec![Filter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                "m.text".to_owned(),
            ))],
            send: vec![Filter::State(StateEventFilter::WithTypeAndStateKey(
                "m.room.name".into(),
                String::new(),
            ))],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        assert_eq!(
            serde_json::to_value(&capabilities).unwrap(),
            serde_json::json!([
                "io.element.requires_client",
                "org.matrix.msc4157.update_delayed_event",
                "org.matrix.msc4157.send.delayed_event",
                "org.matrix.msc2762.receive.event:m.room.message#m.text",
                "org.matrix.msc2762.send.state_event:m.room.name#"
            ])
        );
    }
}