1use std::{fmt, future::Future};
19
20use matrix_sdk_common::{SendOutsideWasm, SyncOutsideWasm};
21use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::SerializeSeq};
22use tracing::{debug, warn};
23
24use super::{
25 MessageLikeEventFilter, StateEventFilter,
26 filter::{Filter, FilterInput, ToDeviceEventFilter},
27};
28
29pub trait CapabilitiesProvider: SendOutsideWasm + SyncOutsideWasm + 'static {
33 fn acquire_capabilities(
37 &self,
38 capabilities: Capabilities,
39 ) -> impl Future<Output = Capabilities> + SendOutsideWasm;
40}
41
42#[derive(Clone, Debug, Default)]
44#[cfg_attr(test, derive(PartialEq))]
45pub struct Capabilities {
46 pub read: Vec<Filter>,
48 pub send: Vec<Filter>,
50 pub requires_client: bool,
56 pub update_delayed_event: bool,
58 pub send_delayed_event: bool,
60
61 pub download_file: bool,
63}
64
65impl Capabilities {
66 pub(super) fn allow_reading<'a>(
72 &self,
73 event_filter_input: impl TryInto<FilterInput<'a>>,
74 ) -> bool {
75 match &event_filter_input.try_into() {
76 Err(_) => {
77 warn!("Failed to convert event into filter input for `allow_reading`.");
78 false
79 }
80 Ok(filter_input) => self.read.iter().any(|f| f.matches(filter_input)),
81 }
82 }
83
84 pub(super) fn allow_sending<'a>(
90 &self,
91 event_filter_input: impl TryInto<FilterInput<'a>>,
92 ) -> bool {
93 match &event_filter_input.try_into() {
94 Err(_) => {
95 warn!("Failed to convert event into filter input for `allow_sending`.");
96 false
97 }
98 Ok(filter_input) => self.send.iter().any(|f| f.matches(filter_input)),
99 }
100 }
101
102 pub(super) fn has_read_filter_for_type(&self, event_type: &str) -> bool {
106 self.read.iter().any(|f| f.filter_event_type() == event_type)
107 }
108}
109
110pub(super) const SEND_EVENT: &str = "org.matrix.msc2762.send.event";
111pub(super) const READ_EVENT: &str = "org.matrix.msc2762.receive.event";
112pub(super) const SEND_STATE: &str = "org.matrix.msc2762.send.state_event";
113pub(super) const READ_STATE: &str = "org.matrix.msc2762.receive.state_event";
114pub(super) const SEND_TODEVICE: &str = "org.matrix.msc3819.send.to_device";
115pub(super) const READ_TODEVICE: &str = "org.matrix.msc3819.receive.to_device";
116pub(super) const REQUIRES_CLIENT: &str = "io.element.requires_client";
117pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event";
118pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event";
119
120pub(super) const DOWNLOAD_FILE: &str = "org.matrix.msc4039.download_file";
121
122impl Serialize for Capabilities {
123 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
124 where
125 S: Serializer,
126 {
127 struct PrintEventFilter<'a>(&'a Filter);
128 impl fmt::Display for PrintEventFilter<'_> {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 match self.0 {
131 Filter::MessageLike(filter) => PrintMessageLikeEventFilter(filter).fmt(f),
132 Filter::State(filter) => PrintStateEventFilter(filter).fmt(f),
133 Filter::ToDevice(filter) => {
134 write!(f, "{}", filter.event_type)
138 }
139 }
140 }
141 }
142
143 struct PrintMessageLikeEventFilter<'a>(&'a MessageLikeEventFilter);
144 impl fmt::Display for PrintMessageLikeEventFilter<'_> {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 match self.0 {
147 MessageLikeEventFilter::WithType(event_type) => {
148 write!(f, "{event_type}")
150 }
151 MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => {
152 write!(f, "m.room.message#{msgtype}")
153 }
154 }
155 }
156 }
157
158 struct PrintStateEventFilter<'a>(&'a StateEventFilter);
159 impl fmt::Display for PrintStateEventFilter<'_> {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 match self.0 {
163 StateEventFilter::WithType(event_type) => write!(f, "{event_type}"),
164 StateEventFilter::WithTypeAndStateKey(event_type, state_key) => {
165 write!(f, "{event_type}#{state_key}")
166 }
167 }
168 }
169 }
170
171 let mut seq = serializer.serialize_seq(None)?;
172
173 if self.requires_client {
174 seq.serialize_element(REQUIRES_CLIENT)?;
175 }
176 if self.update_delayed_event {
177 seq.serialize_element(UPDATE_DELAYED_EVENT)?;
178 }
179 if self.send_delayed_event {
180 seq.serialize_element(SEND_DELAYED_EVENT)?;
181 }
182 if self.download_file {
183 seq.serialize_element(DOWNLOAD_FILE)?;
184 }
185 for filter in &self.read {
186 let name = match filter {
187 Filter::MessageLike(_) => READ_EVENT,
188 Filter::State(_) => READ_STATE,
189 Filter::ToDevice(_) => READ_TODEVICE,
190 };
191 seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
192 }
193 for filter in &self.send {
194 let name = match filter {
195 Filter::MessageLike(_) => SEND_EVENT,
196 Filter::State(_) => SEND_STATE,
197 Filter::ToDevice(_) => SEND_TODEVICE,
198 };
199 seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
200 }
201
202 seq.end()
203 }
204}
205
206impl<'de> Deserialize<'de> for Capabilities {
207 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
208 where
209 D: Deserializer<'de>,
210 {
211 enum Permission {
212 RequiresClient,
213 UpdateDelayedEvent,
214 SendDelayedEvent,
215 DownloadFile,
216 Read(Filter),
217 Send(Filter),
218 Unknown,
219 }
220
221 impl<'de> Deserialize<'de> for Permission {
222 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
223 where
224 D: Deserializer<'de>,
225 {
226 let s = ruma::serde::deserialize_cow_str(deserializer)?;
227 if s == REQUIRES_CLIENT {
228 return Ok(Self::RequiresClient);
229 }
230 if s == UPDATE_DELAYED_EVENT {
231 return Ok(Self::UpdateDelayedEvent);
232 }
233 if s == SEND_DELAYED_EVENT {
234 return Ok(Self::SendDelayedEvent);
235 }
236 if s == DOWNLOAD_FILE {
237 return Ok(Self::DownloadFile);
238 }
239
240 match s.split_once(':') {
241 Some((READ_EVENT, filter_s)) => Ok(Permission::Read(Filter::MessageLike(
242 parse_message_event_filter(filter_s),
243 ))),
244 Some((SEND_EVENT, filter_s)) => Ok(Permission::Send(Filter::MessageLike(
245 parse_message_event_filter(filter_s),
246 ))),
247 Some((READ_STATE, filter_s)) => {
248 Ok(Permission::Read(Filter::State(parse_state_event_filter(filter_s))))
249 }
250 Some((SEND_STATE, filter_s)) => {
251 Ok(Permission::Send(Filter::State(parse_state_event_filter(filter_s))))
252 }
253 Some((READ_TODEVICE, filter_s)) => Ok(Permission::Read(Filter::ToDevice(
254 parse_to_device_event_filter(filter_s),
255 ))),
256 Some((SEND_TODEVICE, filter_s)) => Ok(Permission::Send(Filter::ToDevice(
257 parse_to_device_event_filter(filter_s),
258 ))),
259 _ => {
260 debug!("Unknown capability `{s}`");
261 Ok(Self::Unknown)
262 }
263 }
264 }
265 }
266
267 fn parse_message_event_filter(s: &str) -> MessageLikeEventFilter {
268 match s.strip_prefix("m.room.message#") {
269 Some(msgtype) => MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype.to_owned()),
270 None => MessageLikeEventFilter::WithType(s.into()),
272 }
273 }
274
275 fn parse_state_event_filter(s: &str) -> StateEventFilter {
276 match s.split_once('#') {
278 Some((event_type, state_key)) => {
279 StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key.to_owned())
280 }
281 None => StateEventFilter::WithType(s.into()),
282 }
283 }
284
285 fn parse_to_device_event_filter(s: &str) -> ToDeviceEventFilter {
286 ToDeviceEventFilter::new(s.into())
287 }
288
289 let mut capabilities = Capabilities::default();
290 for capability in Vec::<Permission>::deserialize(deserializer)? {
291 match capability {
292 Permission::RequiresClient => capabilities.requires_client = true,
293 Permission::Read(filter) => capabilities.read.push(filter),
294 Permission::Send(filter) => capabilities.send.push(filter),
295 Permission::Unknown => {}
297 Permission::UpdateDelayedEvent => capabilities.update_delayed_event = true,
298 Permission::SendDelayedEvent => capabilities.send_delayed_event = true,
299 Permission::DownloadFile => capabilities.download_file = true,
300 }
301 }
302
303 Ok(capabilities)
304 }
305}
306
#[cfg(test)]
mod tests {
    use ruma::events::StateEventType;

    use super::*;
    use crate::widget::filter::ToDeviceEventFilter;

    /// Wraps `event_type` into a `MessageLikeEventFilter::WithType` filter.
    fn message_like(event_type: &str) -> Filter {
        Filter::MessageLike(MessageLikeEventFilter::WithType(event_type.into()))
    }

    /// Wraps `event_type` into a `StateEventFilter::WithType` filter.
    fn state(event_type: StateEventType) -> Filter {
        Filter::State(StateEventFilter::WithType(event_type))
    }

    /// Wraps the pair into a `StateEventFilter::WithTypeAndStateKey` filter.
    fn state_with_key(event_type: &str, state_key: &str) -> Filter {
        Filter::State(StateEventFilter::WithTypeAndStateKey(
            event_type.into(),
            state_key.to_owned(),
        ))
    }

    /// Wraps `event_type` into a to-device filter.
    fn to_device(event_type: &str) -> Filter {
        Filter::ToDevice(ToDeviceEventFilter::new(event_type.into()))
    }

    /// An empty JSON array yields the default capabilities.
    #[test]
    fn deserialization_of_no_capabilities() {
        let parsed: Capabilities = serde_json::from_str(r#"[]"#).unwrap();

        assert_eq!(parsed, Capabilities::default());
    }

    /// Every known capability string is decoded into its field; the unknown
    /// `m.always_on_screen` entry is skipped.
    #[test]
    fn deserialization_of_capabilities() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "io.element.requires_client",
            "org.matrix.msc2762.receive.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member",
            "org.matrix.msc3819.receive.to_device:io.element.call.encryption_keys",
            "org.matrix.msc2762.send.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc3819.send.to_device:io.element.call.encryption_keys",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4157.update_delayed_event",
            "org.matrix.msc4039.download_file"
        ]"#;

        let expected = Capabilities {
            read: vec![
                message_like("org.matrix.rageshake_request"),
                state(StateEventType::RoomMember),
                state("org.matrix.msc3401.call.member".into()),
                to_device("io.element.call.encryption_keys"),
            ],
            send: vec![
                message_like("org.matrix.rageshake_request"),
                state_with_key("org.matrix.msc3401.call.member", "@user:matrix.server"),
                to_device("io.element.call.encryption_keys"),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
            download_file: true,
        };

        let parsed: Capabilities = serde_json::from_str(capabilities_str).unwrap();

        assert_eq!(parsed, expected);
    }

    /// Serializing and then deserializing gives back the starting value.
    #[test]
    fn serialization_and_deserialization_are_symmetrical() {
        let capabilities = Capabilities {
            read: vec![
                message_like("io.element.custom"),
                state(StateEventType::RoomMember),
                state_with_key("org.matrix.msc3401.call.member", "@user:matrix.server"),
                to_device("io.element.call.encryption_keys"),
            ],
            send: vec![
                message_like("io.element.custom"),
                state_with_key("org.matrix.msc3401.call.member", "@user:matrix.server"),
                to_device("my.org.other.to_device_event"),
            ],
            requires_client: true,
            update_delayed_event: false,
            send_delayed_event: false,
            download_file: false,
        };

        let serialized = serde_json::to_string(&capabilities).unwrap();
        let round_tripped: Capabilities = serde_json::from_str(&serialized).unwrap();

        assert_eq!(round_tripped, capabilities);
    }
}