matrix_sdk/utils/
mod.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Utility types and traits.
16
17#[cfg(feature = "e2e-encryption")]
18use std::sync::{Arc, RwLock};
19
20#[cfg(feature = "e2e-encryption")]
21use futures_core::Stream;
22#[cfg(feature = "e2e-encryption")]
23use futures_util::StreamExt;
24#[cfg(feature = "markdown")]
25use ruma::events::room::message::FormattedBody;
26use ruma::{
27    RoomAliasId,
28    events::{AnyMessageLikeEventContent, AnyStateEventContent},
29    serde::Raw,
30};
31use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue};
32#[cfg(feature = "e2e-encryption")]
33use tokio::sync::broadcast;
34#[cfg(feature = "e2e-encryption")]
35use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
36
37#[cfg(feature = "local-server")]
38pub mod local_server;
39
40#[cfg(doc)]
41use crate::Room;
42
/// A shared, mutable value whose every update is broadcast to subscribers.
///
/// Unlike an observable that only guarantees delivery of the latest value,
/// this type has channel semantics: each update is sent out as-is, so
/// intermediate values are not skipped as long as subscribers keep up with the
/// channel's buffer.
///
/// Cloning the observable yields another handle to the same value and the same
/// channel.
#[cfg(feature = "e2e-encryption")]
#[derive(Debug, Clone)]
pub(crate) struct ChannelObservable<T>
where
    T: Clone + Send,
{
    /// The current value, shared between all clones of the observable.
    value: Arc<RwLock<T>>,
    /// Sending half of the broadcast channel that carries every update.
    channel: broadcast::Sender<T>,
}
54
#[cfg(feature = "e2e-encryption")]
impl<T> Default for ChannelObservable<T>
where
    T: Default + Clone + Send + 'static,
{
    /// Creates an observable whose initial value is `T::default()`.
    fn default() -> Self {
        Self::new(T::default())
    }
}
62
#[cfg(feature = "e2e-encryption")]
impl<T: 'static + Send + Clone> ChannelObservable<T> {
    /// Create a new [`ChannelObservable`] with the given value for the
    /// underlying data.
    ///
    /// The broadcast channel retains at most 100 pending updates; a subscriber
    /// that falls further behind receives a lag error from its stream instead
    /// of the dropped values.
    pub(crate) fn new(value: T) -> Self {
        let channel = broadcast::Sender::new(100);
        Self { value: RwLock::new(value).into(), channel }
    }

    /// Subscribe to updates to the observable value.
    ///
    /// The current value will always be emitted as the first item in the
    /// stream, followed by every value passed to [`ChannelObservable::set`]
    /// after that snapshot was taken.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub(crate) fn subscribe(
        &self,
    ) -> impl Stream<Item = Result<T, BroadcastStreamRecvError>> + use<T> {
        // Snapshot the value and register the receiver under the same read
        // lock. `set` replaces and broadcasts under the write lock, so every
        // update is either already in the snapshot or delivered through the
        // receiver. If the two steps were not atomic, an update landing between
        // them would be missing from both.
        let (current_value, receiver) = {
            let guard = self.value.read().unwrap();
            ((*guard).clone(), self.channel.subscribe())
        };

        let initial_stream = tokio_stream::once(Ok(current_value));
        initial_stream.chain(BroadcastStream::new(receiver))
    }

    /// Set the underlying data to the new value and broadcast it to all
    /// current subscribers.
    ///
    /// Returns the previous value.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub(crate) fn set(&self, new_value: T) -> T {
        let mut guard = self.value.write().unwrap();
        let old_value = std::mem::replace(&mut *guard, new_value.clone());

        // Broadcast while still holding the write lock. This keeps `subscribe`
        // atomic with respect to updates, and it makes concurrent `set` calls
        // broadcast in the same order they were applied to the value.
        // We're ignoring the error case where no receivers exist.
        let _ = self.channel.send(new_value);
        drop(guard);

        old_value
    }

    /// Get a clone of the current value of the underlying data.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub(crate) fn get(&self) -> T {
        self.value.read().unwrap().to_owned()
    }
}
104
105/// The set of types that can be used with [`Room::send_raw`].
106pub trait IntoRawMessageLikeEventContent {
107    #[doc(hidden)]
108    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent>;
109}
110
111impl IntoRawMessageLikeEventContent for Raw<AnyMessageLikeEventContent> {
112    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
113        self
114    }
115}
116
117impl IntoRawMessageLikeEventContent for &Raw<AnyMessageLikeEventContent> {
118    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
119        self.clone()
120    }
121}
122
123impl IntoRawMessageLikeEventContent for JsonValue {
124    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
125        (&self).into_raw_message_like_event_content()
126    }
127}
128
129impl IntoRawMessageLikeEventContent for &JsonValue {
130    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
131        Raw::new(self).expect("serde_json::Value never fails to serialize").cast_unchecked()
132    }
133}
134
135impl IntoRawMessageLikeEventContent for Box<RawJsonValue> {
136    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
137        Raw::from_json(self)
138    }
139}
140
141impl IntoRawMessageLikeEventContent for &RawJsonValue {
142    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
143        self.to_owned().into_raw_message_like_event_content()
144    }
145}
146
147impl IntoRawMessageLikeEventContent for &Box<RawJsonValue> {
148    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
149        self.clone().into_raw_message_like_event_content()
150    }
151}
152
153/// The set of types that can be used with [`Room::send_state_event_raw`].
154pub trait IntoRawStateEventContent {
155    #[doc(hidden)]
156    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent>;
157}
158
159impl IntoRawStateEventContent for Raw<AnyStateEventContent> {
160    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
161        self
162    }
163}
164
165impl IntoRawStateEventContent for &Raw<AnyStateEventContent> {
166    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
167        self.clone()
168    }
169}
170
171impl IntoRawStateEventContent for JsonValue {
172    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
173        (&self).into_raw_state_event_content()
174    }
175}
176
177impl IntoRawStateEventContent for &JsonValue {
178    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
179        Raw::new(self).expect("serde_json::Value never fails to serialize").cast_unchecked()
180    }
181}
182
183impl IntoRawStateEventContent for Box<RawJsonValue> {
184    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
185        Raw::from_json(self)
186    }
187}
188
189impl IntoRawStateEventContent for &RawJsonValue {
190    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
191        self.to_owned().into_raw_state_event_content()
192    }
193}
194
195impl IntoRawStateEventContent for &Box<RawJsonValue> {
196    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
197        self.clone().into_raw_state_event_content()
198    }
199}
200
/// Characters that [`is_room_alias_format_valid`] rejects in the name part of a
/// room alias (after its first character): `#`, `,`, `:`, `{`, `}` and `\`.
const INVALID_ROOM_ALIAS_NAME_CHARS: &str = "#,:{}\\";
202
203/// Verifies the passed `String` matches the expected room alias format:
204///
205/// This means it's lowercase, with no whitespace chars, has a single leading
206/// `#` char and a single `:` separator between the local and domain parts, and
207/// the local part only contains characters that can't be percent encoded.
208pub fn is_room_alias_format_valid(alias: String) -> bool {
209    let alias_parts: Vec<&str> = alias.split(':').collect();
210    if alias_parts.len() != 2 {
211        return false;
212    }
213
214    let local_part = alias_parts[0];
215    let has_valid_format = local_part.chars().skip(1).all(|c| {
216        c.is_ascii()
217            && !c.is_whitespace()
218            && !c.is_control()
219            && !INVALID_ROOM_ALIAS_NAME_CHARS.contains(c)
220    });
221
222    let is_lowercase = alias.to_lowercase() == alias;
223
224    // Checks both the local part and the domain part
225    has_valid_format && is_lowercase && RoomAliasId::parse(alias).is_ok()
226}
227
/// Picks the formatted body for a message from an optional plain `body` and an
/// optional `formatted_body`.
///
/// An explicit `formatted_body` always takes precedence. Otherwise, `body` (if
/// any) is interpreted as Markdown, which may itself yield `None`.
#[cfg(feature = "markdown")]
pub fn formatted_body_from(
    body: Option<&str>,
    formatted_body: Option<FormattedBody>,
) -> Option<FormattedBody> {
    // The closure only runs when no formatted body was supplied, so Markdown
    // is parsed lazily.
    formatted_body.or_else(|| body.and_then(FormattedBody::markdown))
}
240
#[cfg(test)]
mod test {
    #[cfg(feature = "markdown")]
    use assert_matches2::{assert_let, assert_matches};
    #[cfg(feature = "markdown")]
    use ruma::events::room::message::FormattedBody;

    #[cfg(feature = "markdown")]
    use crate::utils::formatted_body_from;
    use crate::utils::is_room_alias_format_valid;

    #[cfg(feature = "e2e-encryption")]
    #[test]
    fn test_channel_observable_get_set() {
        let observable = super::ChannelObservable::new(0);

        assert_eq!(observable.get(), 0);
        assert_eq!(observable.set(1), 0);
        assert_eq!(observable.set(10), 1);
        assert_eq!(observable.get(), 10);
    }

    #[cfg(feature = "e2e-encryption")]
    #[test]
    fn test_channel_observable_subscribe_emits_current_value_then_updates() {
        use futures_util::{FutureExt, StreamExt};

        let observable = super::ChannelObservable::new(0);
        observable.set(1);

        let mut stream = Box::pin(observable.subscribe());

        // The first item is the value current at subscription time.
        assert_eq!(stream.next().now_or_never().unwrap().unwrap().unwrap(), 1);

        // Later updates are delivered through the channel.
        observable.set(2);
        assert_eq!(stream.next().now_or_never().unwrap().unwrap().unwrap(), 2);

        // No further updates are pending.
        assert!(stream.next().now_or_never().is_none());
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_no_leading_hash_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("alias:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_several_colon_chars_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:something:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_no_colon_chars_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias.domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_server_part_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_whitespace_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias with whitespace:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_control_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias\u{0009}:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_invalid_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("#a#lias,{t\\est}:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_is_not_lowercase_is_not_valid() {
        assert!(!is_room_alias_format_valid("#Alias:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_server_part_is_not_lowercase_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:Domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_has_valid_format() {
        assert!(is_room_alias_format_valid("#alias.test:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_allowed_punctuation_is_valid() {
        assert!(is_room_alias_format_valid("#alias_test-1.x:domain.org".to_owned()))
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_nothing_returns_none() {
        assert_matches!(formatted_body_from(None, None), None);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_only_formatted_body_returns_the_formatted_body() {
        let formatted_body = FormattedBody::html(r"<h1>Hello!</h1>");

        assert_let!(
            Some(result_formatted_body) = formatted_body_from(None, Some(formatted_body.clone()))
        );

        assert_eq!(formatted_body.body, result_formatted_body.body);
        // Compare against the input; comparing the result with itself would
        // make this assertion vacuous.
        assert_eq!(formatted_body.format, result_formatted_body.format);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_markdown_body_returns_a_processed_formatted_body() {
        let markdown_body = Some(r"# Parsed");

        assert_let!(Some(result_formatted_body) = formatted_body_from(markdown_body, None));

        let expected_formatted_body = FormattedBody::html("<h1>Parsed</h1>\n".to_owned());
        assert_eq!(expected_formatted_body.body, result_formatted_body.body);
        assert_eq!(expected_formatted_body.format, result_formatted_body.format);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_body_and_formatted_body_returns_the_formatted_body() {
        let markdown_body = Some(r"# Markdown");
        let formatted_body = FormattedBody::html(r"<h1>HTML</h1>");

        assert_let!(
            Some(result_formatted_body) =
                formatted_body_from(markdown_body, Some(formatted_body.clone()))
        );

        assert_eq!(formatted_body.body, result_formatted_body.body);
        assert_eq!(formatted_body.format, result_formatted_body.format);
    }
}