1use std::fmt;
19
20use async_trait::async_trait;
21use ruma::{events::AnyTimelineEvent, serde::Raw};
22use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer};
23use tracing::{debug, error};
24
25use super::{
26 filter::MatrixEventFilterInput, EventFilter, MessageLikeEventFilter, StateEventFilter,
27};
28
#[async_trait]
/// Decides which widget capabilities get granted.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait CapabilitiesProvider: Send + Sync + 'static {
    /// Receives a set of capabilities (presumably the ones a widget asked for)
    /// and returns the set of capabilities to grant.
    ///
    /// NOTE(review): presumably the returned set is expected to be a subset of
    /// the input — confirm against callers.
    async fn acquire_capabilities(&self, capabilities: Capabilities) -> Capabilities;
}
39
#[derive(Clone, Debug, Default)]
/// A set of widget capabilities.
///
/// On the wire this is a flat list of capability strings (see the `Serialize`
/// and `Deserialize` impls for this type).
#[cfg_attr(test, derive(PartialEq))]
pub struct Capabilities {
    /// Filters for events the widget may read. Serialized with the
    /// `org.matrix.msc2762.receive.event` / `org.matrix.msc2762.receive.state_event`
    /// prefixes.
    pub read: Vec<EventFilter>,
    /// Filters for events the widget may send. Serialized with the
    /// `org.matrix.msc2762.send.event` / `org.matrix.msc2762.send.state_event`
    /// prefixes.
    pub send: Vec<EventFilter>,
    /// Set by the `io.element.requires_client` capability.
    ///
    /// NOTE(review): presumably means the widget cannot run detached from the
    /// client's widget API connection — confirm.
    pub requires_client: bool,
    /// Set by the `org.matrix.msc4157.update_delayed_event` capability.
    pub update_delayed_event: bool,
    /// Set by the `org.matrix.msc4157.send.delayed_event` capability.
    pub send_delayed_event: bool,
}
59
60impl Capabilities {
61 pub fn raw_event_matches_read_filter(&self, raw: &Raw<AnyTimelineEvent>) -> bool {
63 let filter_in = match raw.deserialize_as::<MatrixEventFilterInput>() {
64 Ok(filter) => filter,
65 Err(err) => {
66 error!("Failed to deserialize raw event as MatrixEventFilterInput: {err}");
67 return false;
68 }
69 };
70
71 self.read.iter().any(|f| f.matches(&filter_in))
72 }
73}
74
// Capability string identifiers used in the serialized capability list.
//
// The `SEND_*` / `READ_*` constants are prefixes: the serializer emits them as
// `<prefix>:<event filter>`, and the deserializer splits on the first `:`.

/// Prefix for permission to send message-like events (MSC2762).
const SEND_EVENT: &str = "org.matrix.msc2762.send.event";
/// Prefix for permission to read message-like events (MSC2762).
const READ_EVENT: &str = "org.matrix.msc2762.receive.event";
/// Prefix for permission to send state events (MSC2762).
const SEND_STATE: &str = "org.matrix.msc2762.send.state_event";
/// Prefix for permission to read state events (MSC2762).
const READ_STATE: &str = "org.matrix.msc2762.receive.state_event";
/// Flag capability mapped to `Capabilities::requires_client`.
const REQUIRES_CLIENT: &str = "io.element.requires_client";
/// Flag capability mapped to `Capabilities::send_delayed_event` (MSC4157).
pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event";
/// Flag capability mapped to `Capabilities::update_delayed_event` (MSC4157).
pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event";
82
83impl Serialize for Capabilities {
84 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85 where
86 S: Serializer,
87 {
88 struct PrintEventFilter<'a>(&'a EventFilter);
89 impl fmt::Display for PrintEventFilter<'_> {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 match self.0 {
92 EventFilter::MessageLike(filter) => PrintMessageLikeEventFilter(filter).fmt(f),
93 EventFilter::State(filter) => PrintStateEventFilter(filter).fmt(f),
94 }
95 }
96 }
97
98 struct PrintMessageLikeEventFilter<'a>(&'a MessageLikeEventFilter);
99 impl fmt::Display for PrintMessageLikeEventFilter<'_> {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match self.0 {
102 MessageLikeEventFilter::WithType(event_type) => {
103 write!(f, "{event_type}")
105 }
106 MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => {
107 write!(f, "m.room.message#{msgtype}")
108 }
109 }
110 }
111 }
112
113 struct PrintStateEventFilter<'a>(&'a StateEventFilter);
114 impl fmt::Display for PrintStateEventFilter<'_> {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 match self.0 {
118 StateEventFilter::WithType(event_type) => write!(f, "{event_type}"),
119 StateEventFilter::WithTypeAndStateKey(event_type, state_key) => {
120 write!(f, "{event_type}#{state_key}")
121 }
122 }
123 }
124 }
125
126 let mut seq = serializer.serialize_seq(None)?;
127
128 if self.requires_client {
129 seq.serialize_element(REQUIRES_CLIENT)?;
130 }
131 if self.update_delayed_event {
132 seq.serialize_element(UPDATE_DELAYED_EVENT)?;
133 }
134 if self.send_delayed_event {
135 seq.serialize_element(SEND_DELAYED_EVENT)?;
136 }
137 for filter in &self.read {
138 let name = match filter {
139 EventFilter::MessageLike(_) => READ_EVENT,
140 EventFilter::State(_) => READ_STATE,
141 };
142 seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
143 }
144 for filter in &self.send {
145 let name = match filter {
146 EventFilter::MessageLike(_) => SEND_EVENT,
147 EventFilter::State(_) => SEND_STATE,
148 };
149 seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?;
150 }
151
152 seq.end()
153 }
154}
155
156impl<'de> Deserialize<'de> for Capabilities {
157 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158 where
159 D: Deserializer<'de>,
160 {
161 enum Permission {
162 RequiresClient,
163 UpdateDelayedEvent,
164 SendDelayedEvent,
165 Read(EventFilter),
166 Send(EventFilter),
167 Unknown,
168 }
169
170 impl<'de> Deserialize<'de> for Permission {
171 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
172 where
173 D: Deserializer<'de>,
174 {
175 let s = ruma::serde::deserialize_cow_str(deserializer)?;
176 if s == REQUIRES_CLIENT {
177 return Ok(Self::RequiresClient);
178 }
179 if s == UPDATE_DELAYED_EVENT {
180 return Ok(Self::UpdateDelayedEvent);
181 }
182 if s == SEND_DELAYED_EVENT {
183 return Ok(Self::SendDelayedEvent);
184 }
185
186 match s.split_once(':') {
187 Some((READ_EVENT, filter_s)) => Ok(Permission::Read(EventFilter::MessageLike(
188 parse_message_event_filter(filter_s),
189 ))),
190 Some((SEND_EVENT, filter_s)) => Ok(Permission::Send(EventFilter::MessageLike(
191 parse_message_event_filter(filter_s),
192 ))),
193 Some((READ_STATE, filter_s)) => {
194 Ok(Permission::Read(EventFilter::State(parse_state_event_filter(filter_s))))
195 }
196 Some((SEND_STATE, filter_s)) => {
197 Ok(Permission::Send(EventFilter::State(parse_state_event_filter(filter_s))))
198 }
199 _ => {
200 debug!("Unknown capability `{s}`");
201 Ok(Self::Unknown)
202 }
203 }
204 }
205 }
206
207 fn parse_message_event_filter(s: &str) -> MessageLikeEventFilter {
208 match s.strip_prefix("m.room.message#") {
209 Some(msgtype) => MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype.to_owned()),
210 None => MessageLikeEventFilter::WithType(s.into()),
212 }
213 }
214
215 fn parse_state_event_filter(s: &str) -> StateEventFilter {
216 match s.split_once('#') {
218 Some((event_type, state_key)) => {
219 StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key.to_owned())
220 }
221 None => StateEventFilter::WithType(s.into()),
222 }
223 }
224
225 let mut capabilities = Capabilities::default();
226 for capability in Vec::<Permission>::deserialize(deserializer)? {
227 match capability {
228 Permission::RequiresClient => capabilities.requires_client = true,
229 Permission::Read(filter) => capabilities.read.push(filter),
230 Permission::Send(filter) => capabilities.send.push(filter),
231 Permission::Unknown => {}
233 Permission::UpdateDelayedEvent => capabilities.update_delayed_event = true,
234 Permission::SendDelayedEvent => capabilities.send_delayed_event = true,
235 }
236 }
237
238 Ok(capabilities)
239 }
240}
241
#[cfg(test)]
mod tests {
    use ruma::events::StateEventType;

    use super::*;

    /// An empty list deserializes to the default (empty) capabilities.
    #[test]
    fn deserialization_of_no_capabilities() {
        let capabilities_str = r#"[]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities::default();

        assert_eq!(parsed, expected);
    }

    /// Known capabilities are parsed; unknown ones (`m.always_on_screen`) are skipped.
    #[test]
    fn deserialization_of_capabilities() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "io.element.requires_client",
            "org.matrix.msc2762.receive.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member",
            "org.matrix.msc2762.send.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4157.update_delayed_event"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                EventFilter::State(StateEventFilter::WithType(
                    "org.matrix.msc3401.call.member".into(),
                )),
            ],
            send: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        assert_eq!(parsed, expected);
    }

    #[test]
    fn serialization_and_deserialization_are_symmetrical() {
        let capabilities = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            send: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: false,
            send_delayed_event: false,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    /// Pins the exact wire format and element order, including the
    /// `m.room.message#<msgtype>` filter and the delayed-event flags. The other
    /// tests never serialize these.
    #[test]
    fn serialization_of_capabilities() {
        let capabilities = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                    "m.text".to_owned(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
            ],
            send: vec![EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                "org.matrix.msc3401.call.member".into(),
                "@user:matrix.server".into(),
            ))],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        let serialized = serde_json::to_value(&capabilities).unwrap();
        assert_eq!(
            serialized,
            serde_json::json!([
                "io.element.requires_client",
                "org.matrix.msc4157.update_delayed_event",
                "org.matrix.msc4157.send.delayed_event",
                "org.matrix.msc2762.receive.event:m.room.message#m.text",
                "org.matrix.msc2762.receive.state_event:m.room.member",
                "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            ])
        );

        let parsed = serde_json::from_value::<Capabilities>(serialized).unwrap();
        assert_eq!(parsed, capabilities);
    }
}