matrix_sdk/
utils.rs

// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Utility types and traits.
16
17#[cfg(feature = "e2e-encryption")]
18use std::sync::{Arc, RwLock};
19
20#[cfg(feature = "e2e-encryption")]
21use futures_core::Stream;
22#[cfg(feature = "e2e-encryption")]
23use futures_util::StreamExt;
24#[cfg(feature = "markdown")]
25use ruma::events::room::message::FormattedBody;
26use ruma::{
27    events::{AnyMessageLikeEventContent, AnyStateEventContent},
28    serde::Raw,
29    RoomAliasId,
30};
31use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue};
32#[cfg(feature = "e2e-encryption")]
33use tokio::sync::broadcast;
34#[cfg(feature = "e2e-encryption")]
35use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
36
37#[cfg(doc)]
38use crate::Room;
39
/// An observable with channel semantics.
///
/// Channel semantics means that each update to the shared mutable value will be
/// sent out to subscribers. That is, intermediate updates to the value will not
/// be skipped like they would be in an observable without channel semantics.
///
/// Cloning a `ChannelObservable` yields a handle to the same shared value and
/// the same broadcast channel.
#[cfg(feature = "e2e-encryption")]
#[derive(Clone, Debug)]
pub(crate) struct ChannelObservable<T> {
    /// The latest value, shared between all clones of this observable.
    value: Arc<RwLock<T>>,
    /// Sender used to broadcast every newly set value to subscribers.
    channel: broadcast::Sender<T>,
}
51
/// Creates a [`ChannelObservable`] whose initial value is `T::default()`.
#[cfg(feature = "e2e-encryption")]
impl<T: Default + Clone + Send + 'static> Default for ChannelObservable<T> {
    fn default() -> Self {
        Self::new(T::default())
    }
}
59
#[cfg(feature = "e2e-encryption")]
impl<T: 'static + Send + Clone> ChannelObservable<T> {
    /// Create a new [`ChannelObservable`] with the given value for the
    /// underlying data.
    ///
    /// The broadcast channel is created with a capacity of 100 buffered
    /// updates (see `tokio::sync::broadcast` for what happens to receivers
    /// that fall further behind).
    pub(crate) fn new(value: T) -> Self {
        let channel = broadcast::Sender::new(100);
        Self { value: RwLock::new(value).into(), channel }
    }

    /// Subscribe to updates to the observable value.
    ///
    /// The current value will always be emitted as the first item in the
    /// stream, followed by every value passed to [`set`](Self::set) after the
    /// subscription was registered.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    pub(crate) fn subscribe(&self) -> impl Stream<Item = Result<T, BroadcastStreamRecvError>> {
        // Register the receiver while holding the read lock. `set` broadcasts
        // while holding the write lock, so every update is either already
        // visible in the initial value or delivered through the receiver.
        // It is never lost in between, and never delivered twice.
        let guard = self.value.read().unwrap();
        let receiver = self.channel.subscribe();
        let current_value = (*guard).clone();
        drop(guard);

        let initial_stream = tokio_stream::once(Ok(current_value));
        let broadcast_stream = BroadcastStream::new(receiver);

        initial_stream.chain(broadcast_stream)
    }

    /// Set the underlying data to the new value, broadcast it to all current
    /// subscribers, and return the previous value.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    pub(crate) fn set(&self, new_value: T) -> T {
        let mut guard = self.value.write().unwrap();
        let old_value = std::mem::replace(&mut *guard, new_value.clone());

        // Broadcast while still holding the write lock. Concurrent `set` calls
        // are then sent in the same order they were stored, and a concurrent
        // `subscribe` cannot miss this update.
        // We're ignoring the error case where no receivers exist.
        let _ = self.channel.send(new_value);

        old_value
    }

    /// Get a clone of the current value of the underlying data.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    pub(crate) fn get(&self) -> T {
        self.value.read().unwrap().clone()
    }
}
99
100/// The set of types that can be used with [`Room::send_raw`].
101pub trait IntoRawMessageLikeEventContent {
102    #[doc(hidden)]
103    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent>;
104}
105
106impl IntoRawMessageLikeEventContent for Raw<AnyMessageLikeEventContent> {
107    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
108        self
109    }
110}
111
112impl IntoRawMessageLikeEventContent for &Raw<AnyMessageLikeEventContent> {
113    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
114        self.clone()
115    }
116}
117
118impl IntoRawMessageLikeEventContent for JsonValue {
119    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
120        (&self).into_raw_message_like_event_content()
121    }
122}
123
124impl IntoRawMessageLikeEventContent for &JsonValue {
125    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
126        Raw::new(self).expect("serde_json::Value never fails to serialize").cast()
127    }
128}
129
130impl IntoRawMessageLikeEventContent for Box<RawJsonValue> {
131    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
132        Raw::from_json(self)
133    }
134}
135
136impl IntoRawMessageLikeEventContent for &RawJsonValue {
137    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
138        self.to_owned().into_raw_message_like_event_content()
139    }
140}
141
142impl IntoRawMessageLikeEventContent for &Box<RawJsonValue> {
143    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
144        self.clone().into_raw_message_like_event_content()
145    }
146}
147
148/// The set of types that can be used with [`Room::send_state_event_raw`].
149pub trait IntoRawStateEventContent {
150    #[doc(hidden)]
151    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent>;
152}
153
154impl IntoRawStateEventContent for Raw<AnyStateEventContent> {
155    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
156        self
157    }
158}
159
160impl IntoRawStateEventContent for &Raw<AnyStateEventContent> {
161    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
162        self.clone()
163    }
164}
165
166impl IntoRawStateEventContent for JsonValue {
167    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
168        (&self).into_raw_state_event_content()
169    }
170}
171
172impl IntoRawStateEventContent for &JsonValue {
173    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
174        Raw::new(self).expect("serde_json::Value never fails to serialize").cast()
175    }
176}
177
178impl IntoRawStateEventContent for Box<RawJsonValue> {
179    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
180        Raw::from_json(self)
181    }
182}
183
184impl IntoRawStateEventContent for &RawJsonValue {
185    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
186        self.to_owned().into_raw_state_event_content()
187    }
188}
189
190impl IntoRawStateEventContent for &Box<RawJsonValue> {
191    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
192        self.clone().into_raw_state_event_content()
193    }
194}
195
196const INVALID_ROOM_ALIAS_NAME_CHARS: &str = "#,:{}\\";
197
198/// Verifies the passed `String` matches the expected room alias format:
199///
200/// This means it's lowercase, with no whitespace chars, has a single leading
201/// `#` char and a single `:` separator between the local and domain parts, and
202/// the local part only contains characters that can't be percent encoded.
203pub fn is_room_alias_format_valid(alias: String) -> bool {
204    let alias_parts: Vec<&str> = alias.split(':').collect();
205    if alias_parts.len() != 2 {
206        return false;
207    }
208
209    let local_part = alias_parts[0];
210    let has_valid_format = local_part.chars().skip(1).all(|c| {
211        c.is_ascii()
212            && !c.is_whitespace()
213            && !c.is_control()
214            && !INVALID_ROOM_ALIAS_NAME_CHARS.contains(c)
215    });
216
217    let is_lowercase = alias.to_lowercase() == alias;
218
219    // Checks both the local part and the domain part
220    has_valid_format && is_lowercase && RoomAliasId::parse(alias).is_ok()
221}
222
/// Picks the formatted body to use from an optional plain `body` and an
/// optional `formatted_body`.
///
/// An explicit `formatted_body` always wins. Otherwise, `body` (if any) is
/// rendered as markdown, which may itself yield `None`.
#[cfg(feature = "markdown")]
pub fn formatted_body_from(
    body: Option<&str>,
    formatted_body: Option<FormattedBody>,
) -> Option<FormattedBody> {
    // Markdown rendering is only attempted when no formatted body was given.
    formatted_body.or_else(|| body.and_then(FormattedBody::markdown))
}
239
#[cfg(test)]
mod test {
    #[cfg(feature = "markdown")]
    use assert_matches2::{assert_let, assert_matches};
    #[cfg(feature = "markdown")]
    use ruma::events::room::message::FormattedBody;

    #[cfg(feature = "markdown")]
    use crate::utils::formatted_body_from;
    use crate::utils::is_room_alias_format_valid;

    #[cfg(feature = "e2e-encryption")]
    #[test]
    fn test_channel_observable_get_set() {
        let observable = super::ChannelObservable::new(0);

        assert_eq!(observable.get(), 0);
        assert_eq!(observable.set(1), 0);
        assert_eq!(observable.set(10), 1);
        assert_eq!(observable.get(), 10);
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_no_leading_hash_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("alias:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_several_colon_chars_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:something:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_no_colon_chars_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias.domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_server_part_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_whitespace_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias with whitespace:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_control_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias\u{0009}:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_invalid_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("#a#lias,{t\\est}:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_is_not_lowercase_is_not_valid() {
        assert!(!is_room_alias_format_valid("#Alias:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_server_part_is_not_lowercase_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:Domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_has_valid_format() {
        assert!(is_room_alias_format_valid("#alias.test:domain.org".to_owned()))
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_nothing_returns_none() {
        assert_matches!(formatted_body_from(None, None), None);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_only_formatted_body_returns_the_formatted_body() {
        let formatted_body = FormattedBody::html(r"<h1>Hello!</h1>");

        assert_let!(
            Some(result_formatted_body) = formatted_body_from(None, Some(formatted_body.clone()))
        );

        assert_eq!(formatted_body.body, result_formatted_body.body);
        // Compare against the input; the previous assertion compared the
        // result's format with itself and could never fail.
        assert_eq!(formatted_body.format, result_formatted_body.format);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_markdown_body_returns_a_processed_formatted_body() {
        let markdown_body = Some(r"# Parsed");

        assert_let!(Some(result_formatted_body) = formatted_body_from(markdown_body, None));

        let expected_formatted_body = FormattedBody::html("<h1>Parsed</h1>\n".to_owned());
        assert_eq!(expected_formatted_body.body, result_formatted_body.body);
        assert_eq!(expected_formatted_body.format, result_formatted_body.format);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_body_and_formatted_body_returns_the_formatted_body() {
        let markdown_body = Some(r"# Markdown");
        let formatted_body = FormattedBody::html(r"<h1>HTML</h1>");

        assert_let!(
            Some(result_formatted_body) =
                formatted_body_from(markdown_body, Some(formatted_body.clone()))
        );

        assert_eq!(formatted_body.body, result_formatted_body.body);
        assert_eq!(formatted_body.format, result_formatted_body.format);
    }
}