use std::fmt;
use async_trait::async_trait;
use ruma::{events::AnyTimelineEvent, serde::Raw};
use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer};
use tracing::{debug, error};
use super::{
filter::MatrixEventFilterInput, EventFilter, MessageLikeEventFilter, StateEventFilter,
};
/// Source of the capabilities a widget ends up with.
///
/// Implementors receive a set of [`Capabilities`] and asynchronously return a
/// (possibly different) set of [`Capabilities`].
///
/// NOTE(review): presumably the returned set is the subset of the requested
/// capabilities that the user/client approved — confirm against implementors.
#[async_trait]
pub trait CapabilitiesProvider: Send + Sync + 'static {
    /// Takes `capabilities` (presumably the ones a widget asked for — verify at
    /// call sites) and returns the capabilities to use.
    async fn acquire_capabilities(&self, capabilities: Capabilities) -> Capabilities;
}
/// A set of widget capabilities.
///
/// On the wire this is a flat list of capability strings; see the `Serialize`
/// and `Deserialize` impls for `Capabilities` in this module for the format.
#[derive(Clone, Debug, Default)]
#[cfg_attr(test, derive(PartialEq))]
pub struct Capabilities {
    /// Event filters the widget may read, (de)serialized with the
    /// `org.matrix.msc2762.receive.event:` / `org.matrix.msc2762.receive.state_event:`
    /// prefixes.
    pub read: Vec<EventFilter>,
    /// Event filters the widget may send, (de)serialized with the
    /// `org.matrix.msc2762.send.event:` / `org.matrix.msc2762.send.state_event:`
    /// prefixes.
    pub send: Vec<EventFilter>,
    /// Set when the `io.element.requires_client` capability is present.
    pub requires_client: bool,
    /// Set when the `org.matrix.msc4157.update_delayed_event` capability is present.
    pub update_delayed_event: bool,
    /// Set when the `org.matrix.msc4157.send.delayed_event` capability is present.
    pub send_delayed_event: bool,
}
impl Capabilities {
    /// Checks whether a raw timeline event is covered by at least one of the
    /// [`read`](Self::read) filters.
    ///
    /// The event is first deserialized into a `MatrixEventFilterInput`. If that
    /// fails, the error is logged and the event is treated as not matching
    /// (`false`). An empty `read` list never matches.
    pub fn raw_event_matches_read_filter(&self, raw: &Raw<AnyTimelineEvent>) -> bool {
        match raw.deserialize_as::<MatrixEventFilterInput>() {
            Ok(input) => self.read.iter().any(|filter| filter.matches(&input)),
            Err(err) => {
                // An event we cannot even inspect is never forwarded to the widget.
                error!("Failed to deserialize raw event as MatrixEventFilterInput: {err}");
                false
            }
        }
    }
}
// Filter-carrying capabilities: each is serialized as `<prefix>:<filter>`.

/// Prefix for message-like events the widget may read.
const READ_EVENT: &str = "org.matrix.msc2762.receive.event";
/// Prefix for message-like events the widget may send.
const SEND_EVENT: &str = "org.matrix.msc2762.send.event";
/// Prefix for state events the widget may read.
const READ_STATE: &str = "org.matrix.msc2762.receive.state_event";
/// Prefix for state events the widget may send.
const SEND_STATE: &str = "org.matrix.msc2762.send.state_event";

// Flag capabilities: serialized verbatim, with no filter part.

/// Capability string for [`Capabilities::requires_client`].
const REQUIRES_CLIENT: &str = "io.element.requires_client";
/// MSC4157 capability string for [`Capabilities::send_delayed_event`].
pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event";
/// MSC4157 capability string for [`Capabilities::update_delayed_event`].
pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event";
/// Serializes [`Capabilities`] as a sequence of capability strings.
///
/// The flag capabilities come first, in the order `requires_client`,
/// `update_delayed_event`, `send_delayed_event` (each only if set). They are
/// followed by one `<capability>:<filter>` entry per read filter and then one
/// per send filter.
impl Serialize for Capabilities {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        /// One filter-carrying entry, rendered as `<capability>:<filter>`.
        struct FilterEntry<'a> {
            capability: &'a str,
            filter: &'a EventFilter,
        }

        impl fmt::Display for FilterEntry<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let capability = self.capability;
                match self.filter {
                    EventFilter::MessageLike(MessageLikeEventFilter::WithType(event_type)) => {
                        write!(f, "{capability}:{event_type}")
                    }
                    EventFilter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                        msgtype,
                    )) => write!(f, "{capability}:m.room.message#{msgtype}"),
                    EventFilter::State(StateEventFilter::WithType(event_type)) => {
                        write!(f, "{capability}:{event_type}")
                    }
                    EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                        event_type,
                        state_key,
                    )) => write!(f, "{capability}:{event_type}#{state_key}"),
                }
            }
        }

        let mut seq = serializer.serialize_seq(None)?;

        // Flag capabilities, in their fixed output order.
        let flags = [
            (self.requires_client, REQUIRES_CLIENT),
            (self.update_delayed_event, UPDATE_DELAYED_EVENT),
            (self.send_delayed_event, SEND_DELAYED_EVENT),
        ];
        for (enabled, name) in flags {
            if enabled {
                seq.serialize_element(name)?;
            }
        }

        // Read filters, then send filters. Each is tagged with the prefixes to use
        // for its message-like and state variants.
        let tagged = self
            .read
            .iter()
            .map(|filter| (filter, READ_EVENT, READ_STATE))
            .chain(self.send.iter().map(|filter| (filter, SEND_EVENT, SEND_STATE)));
        for (filter, message_like_capability, state_capability) in tagged {
            let capability = match filter {
                EventFilter::MessageLike(_) => message_like_capability,
                EventFilter::State(_) => state_capability,
            };
            seq.serialize_element(&FilterEntry { capability, filter }.to_string())?;
        }

        seq.end()
    }
}
/// Deserializes [`Capabilities`] from a sequence of capability strings.
///
/// Recognized flag strings set the matching boolean field. Strings of the form
/// `<capability>:<filter>` with a known read/send prefix push an [`EventFilter`]
/// onto `read` or `send`. Any other string is logged at debug level and ignored.
impl<'de> Deserialize<'de> for Capabilities {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// A single parsed entry of the capability list.
        enum Permission {
            RequiresClient,
            UpdateDelayedEvent,
            SendDelayedEvent,
            Read(EventFilter),
            Send(EventFilter),
            Unknown,
        }

        impl<'de> Deserialize<'de> for Permission {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                let s = ruma::serde::deserialize_cow_str(deserializer)?;

                // Flag capabilities are matched verbatim. Everything else must look
                // like `<capability>:<filter>`, split at the first `:` so that filters
                // may themselves contain colons (e.g. user IDs as state keys).
                let permission = match &*s {
                    REQUIRES_CLIENT => Permission::RequiresClient,
                    UPDATE_DELAYED_EVENT => Permission::UpdateDelayedEvent,
                    SEND_DELAYED_EVENT => Permission::SendDelayedEvent,
                    _ => match s.split_once(':') {
                        Some((READ_EVENT, spec)) => {
                            Permission::Read(EventFilter::MessageLike(message_like_filter(spec)))
                        }
                        Some((SEND_EVENT, spec)) => {
                            Permission::Send(EventFilter::MessageLike(message_like_filter(spec)))
                        }
                        Some((READ_STATE, spec)) => {
                            Permission::Read(EventFilter::State(state_filter(spec)))
                        }
                        Some((SEND_STATE, spec)) => {
                            Permission::Send(EventFilter::State(state_filter(spec)))
                        }
                        _ => {
                            debug!("Unknown capability `{s}`");
                            Permission::Unknown
                        }
                    },
                };

                Ok(permission)
            }
        }

        /// Parses a message-like filter. `m.room.message#<msgtype>` narrows to a
        /// message type; anything else is taken as a plain event type.
        fn message_like_filter(spec: &str) -> MessageLikeEventFilter {
            if let Some(msgtype) = spec.strip_prefix("m.room.message#") {
                MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype.to_owned())
            } else {
                MessageLikeEventFilter::WithType(spec.into())
            }
        }

        /// Parses a state filter: `<type>#<state_key>` (split at the first `#`) or
        /// a bare `<type>`.
        fn state_filter(spec: &str) -> StateEventFilter {
            spec.split_once('#').map_or_else(
                || StateEventFilter::WithType(spec.into()),
                |(event_type, state_key)| {
                    StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key.to_owned())
                },
            )
        }

        let permissions = Vec::<Permission>::deserialize(deserializer)?;
        Ok(permissions.into_iter().fold(Capabilities::default(), |mut capabilities, permission| {
            match permission {
                Permission::RequiresClient => capabilities.requires_client = true,
                Permission::UpdateDelayedEvent => capabilities.update_delayed_event = true,
                Permission::SendDelayedEvent => capabilities.send_delayed_event = true,
                Permission::Read(filter) => capabilities.read.push(filter),
                Permission::Send(filter) => capabilities.send.push(filter),
                Permission::Unknown => {}
            }
            capabilities
        }))
    }
}
#[cfg(test)]
mod tests {
    use ruma::events::StateEventType;

    use super::*;

    #[test]
    fn deserialization_of_no_capabilities() {
        let capabilities_str = r#"[]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities::default();

        assert_eq!(parsed, expected);
    }

    #[test]
    fn deserialization_of_capabilities() {
        let capabilities_str = r#"[
            "m.always_on_screen",
            "io.element.requires_client",
            "org.matrix.msc2762.receive.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.receive.state_event:m.room.member",
            "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member",
            "org.matrix.msc2762.send.event:org.matrix.rageshake_request",
            "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc4157.update_delayed_event"
        ]"#;

        let parsed = serde_json::from_str::<Capabilities>(capabilities_str).unwrap();
        let expected = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                EventFilter::State(StateEventFilter::WithType(
                    "org.matrix.msc3401.call.member".into(),
                )),
            ],
            send: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "org.matrix.rageshake_request".into(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        };

        assert_eq!(parsed, expected);
    }

    #[test]
    fn serialization_and_deserialization_are_symmetrical() {
        let capabilities = Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::State(StateEventFilter::WithType(StateEventType::RoomMember)),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            send: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                    "org.matrix.msc3401.call.member".into(),
                    "@user:matrix.server".into(),
                )),
            ],
            requires_client: true,
            update_delayed_event: false,
            send_delayed_event: false,
        };

        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    /// Capabilities exercising every flag and every filter variant, including
    /// the `m.room.message#<msgtype>` form that the tests above never cover.
    fn all_variants() -> Capabilities {
        Capabilities {
            read: vec![
                EventFilter::MessageLike(MessageLikeEventFilter::WithType(
                    "io.element.custom".into(),
                )),
                EventFilter::MessageLike(MessageLikeEventFilter::RoomMessageWithMsgtype(
                    "m.text".to_owned(),
                )),
            ],
            send: vec![EventFilter::State(StateEventFilter::WithTypeAndStateKey(
                StateEventType::RoomMember,
                "@user:matrix.server".into(),
            ))],
            requires_client: true,
            update_delayed_event: true,
            send_delayed_event: true,
        }
    }

    #[test]
    fn serialization_emits_flags_then_read_then_send() {
        let serialized = serde_json::to_value(all_variants()).unwrap();
        let expected = serde_json::json!([
            "io.element.requires_client",
            "org.matrix.msc4157.update_delayed_event",
            "org.matrix.msc4157.send.delayed_event",
            "org.matrix.msc2762.receive.event:io.element.custom",
            "org.matrix.msc2762.receive.event:m.room.message#m.text",
            "org.matrix.msc2762.send.state_event:m.room.member#@user:matrix.server"
        ]);
        assert_eq!(serialized, expected);
    }

    #[test]
    fn round_trip_covers_msgtype_filters_and_delayed_events() {
        let capabilities = all_variants();
        let capabilities_str = serde_json::to_string(&capabilities).unwrap();
        let parsed = serde_json::from_str::<Capabilities>(&capabilities_str).unwrap();
        assert_eq!(parsed, capabilities);
    }

    #[test]
    fn known_prefix_without_filter_is_ignored() {
        // A filter-carrying capability with no `:<filter>` part is not understood
        // and is dropped rather than failing the whole list.
        let parsed =
            serde_json::from_str::<Capabilities>(r#"["org.matrix.msc2762.receive.event"]"#)
                .unwrap();
        assert_eq!(parsed, Capabilities::default());
    }
}