Skip to main content

matrix_sdk/utils/
mod.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Utility types and traits.
16
17#[cfg(feature = "e2e-encryption")]
18use std::sync::{Arc, RwLock};
19
20#[cfg(feature = "e2e-encryption")]
21use futures_core::Stream;
22#[cfg(feature = "e2e-encryption")]
23use futures_util::StreamExt;
24#[cfg(feature = "markdown")]
25use ruma::events::room::message::FormattedBody;
26use ruma::{
27    RoomAliasId,
28    events::{AnyMessageLikeEventContent, AnyStateEventContent},
29    serde::Raw,
30};
31use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue};
32#[cfg(feature = "e2e-encryption")]
33use tokio::sync::broadcast;
34#[cfg(feature = "e2e-encryption")]
35use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
36use url::Url;
37
38#[cfg(feature = "local-server")]
39pub mod local_server;
40
41#[cfg(feature = "local-server")]
42use self::local_server::QueryString;
43#[cfg(doc)]
44use crate::Room;
45
/// An observable with channel semantics.
///
/// Channel semantics means that each update to the shared mutable value will be
/// sent out to subscribers. That is, intermediate updates to the value will not
/// be skipped like they would be in an observable without channel semantics.
///
/// There is one exception. The underlying broadcast channel has a bounded
/// capacity (see [`ChannelObservable::new`]). A subscriber that falls further
/// behind than that receives a [`BroadcastStreamRecvError`] instead of the
/// missed values.
///
/// Clones share the same underlying value and broadcast channel.
#[cfg(feature = "e2e-encryption")]
#[derive(Clone, Debug)]
pub(crate) struct ChannelObservable<T: Clone + Send> {
    /// The current value, shared between all clones.
    value: Arc<RwLock<T>>,
    /// Broadcasts every value passed to [`ChannelObservable::set`].
    channel: broadcast::Sender<T>,
}

#[cfg(feature = "e2e-encryption")]
impl<T: Default + Clone + Send + 'static> Default for ChannelObservable<T> {
    /// Creates an observable whose initial value is `T::default()`.
    fn default() -> Self {
        Self::new(T::default())
    }
}

#[cfg(feature = "e2e-encryption")]
impl<T: 'static + Send + Clone> ChannelObservable<T> {
    /// Create a new [`ChannelObservable`] with the given value for the
    /// underlying data.
    ///
    /// The broadcast channel buffers up to 100 updates per subscriber.
    pub(crate) fn new(value: T) -> Self {
        let channel = broadcast::Sender::new(100);
        Self { value: RwLock::new(value).into(), channel }
    }

    /// Subscribe to updates to the observable value.
    ///
    /// The current value will always be emitted as the first item in the
    /// stream. It is followed by every value passed to
    /// [`ChannelObservable::set`] after this call, in the order they were set.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    pub(crate) fn subscribe(
        &self,
    ) -> impl Stream<Item = Result<T, BroadcastStreamRecvError>> + use<T> {
        // Take the snapshot and register the receiver under the same read lock.
        // `set` broadcasts while holding the write lock, so no update can fall
        // between the snapshot and the subscription and be lost. An update also
        // can't be seen twice (once in the snapshot and again from the channel).
        let (current_value, receiver) = {
            let guard = self.value.read().unwrap();
            ((*guard).clone(), self.channel.subscribe())
        };

        tokio_stream::once(Ok(current_value)).chain(BroadcastStream::new(receiver))
    }

    /// Set the underlying data to `new_value`, broadcast it to all
    /// subscribers, and return the previous value.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    pub(crate) fn set(&self, new_value: T) -> T {
        let mut guard = self.value.write().unwrap();
        let old_value = std::mem::replace(&mut *guard, new_value.clone());

        // Broadcast while still holding the write lock. Concurrent `set` calls
        // are then delivered in the order they were stored, and `subscribe`
        // (which holds the read lock) cannot interleave with an update.
        // `send` only fails when there are no receivers, which is fine to ignore.
        let _ = self.channel.send(new_value);
        drop(guard);

        old_value
    }

    /// Get a clone of the current value of the underlying data.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    pub(crate) fn get(&self) -> T {
        self.value.read().unwrap().clone()
    }
}
107
108/// The set of types that can be used with [`Room::send_raw`].
109pub trait IntoRawMessageLikeEventContent {
110    #[doc(hidden)]
111    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent>;
112}
113
114impl IntoRawMessageLikeEventContent for Raw<AnyMessageLikeEventContent> {
115    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
116        self
117    }
118}
119
120impl IntoRawMessageLikeEventContent for &Raw<AnyMessageLikeEventContent> {
121    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
122        self.clone()
123    }
124}
125
126impl IntoRawMessageLikeEventContent for JsonValue {
127    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
128        (&self).into_raw_message_like_event_content()
129    }
130}
131
132impl IntoRawMessageLikeEventContent for &JsonValue {
133    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
134        Raw::new(self).expect("serde_json::Value never fails to serialize").cast_unchecked()
135    }
136}
137
138impl IntoRawMessageLikeEventContent for Box<RawJsonValue> {
139    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
140        Raw::from_json(self)
141    }
142}
143
144impl IntoRawMessageLikeEventContent for &RawJsonValue {
145    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
146        self.to_owned().into_raw_message_like_event_content()
147    }
148}
149
150impl IntoRawMessageLikeEventContent for &Box<RawJsonValue> {
151    fn into_raw_message_like_event_content(self) -> Raw<AnyMessageLikeEventContent> {
152        self.clone().into_raw_message_like_event_content()
153    }
154}
155
156/// The set of types that can be used with [`Room::send_state_event_raw`].
157pub trait IntoRawStateEventContent {
158    #[doc(hidden)]
159    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent>;
160}
161
162impl IntoRawStateEventContent for Raw<AnyStateEventContent> {
163    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
164        self
165    }
166}
167
168impl IntoRawStateEventContent for &Raw<AnyStateEventContent> {
169    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
170        self.clone()
171    }
172}
173
174impl IntoRawStateEventContent for JsonValue {
175    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
176        (&self).into_raw_state_event_content()
177    }
178}
179
180impl IntoRawStateEventContent for &JsonValue {
181    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
182        Raw::new(self).expect("serde_json::Value never fails to serialize").cast_unchecked()
183    }
184}
185
186impl IntoRawStateEventContent for Box<RawJsonValue> {
187    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
188        Raw::from_json(self)
189    }
190}
191
192impl IntoRawStateEventContent for &RawJsonValue {
193    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
194        self.to_owned().into_raw_state_event_content()
195    }
196}
197
198impl IntoRawStateEventContent for &Box<RawJsonValue> {
199    fn into_raw_state_event_content(self) -> Raw<AnyStateEventContent> {
200        self.clone().into_raw_state_event_content()
201    }
202}
203
204const INVALID_ROOM_ALIAS_NAME_CHARS: &str = "#,:{}\\";
205
206/// Verifies the passed `String` matches the expected room alias format:
207///
208/// This means it's lowercase, with no whitespace chars, has a single leading
209/// `#` char and a single `:` separator between the local and domain parts, and
210/// the local part only contains characters that can't be percent encoded.
211pub fn is_room_alias_format_valid(alias: String) -> bool {
212    let alias_parts: Vec<&str> = alias.split(':').collect();
213    if alias_parts.len() != 2 {
214        return false;
215    }
216
217    let local_part = alias_parts[0];
218    let has_valid_format = local_part.chars().skip(1).all(|c| {
219        c.is_ascii()
220            && !c.is_whitespace()
221            && !c.is_control()
222            && !INVALID_ROOM_ALIAS_NAME_CHARS.contains(c)
223    });
224
225    let is_lowercase = alias.to_lowercase() == alias;
226
227    // Checks both the local part and the domain part
228    has_valid_format && is_lowercase && RoomAliasId::parse(alias).is_ok()
229}
230
/// Picks the formatted body to use for a message.
///
/// If `formatted_body` is `Some`, it is returned unchanged and `body` is
/// ignored. Otherwise `body`, if present, is rendered as Markdown with
/// [`FormattedBody::markdown`], which may itself return `None`.
#[cfg(feature = "markdown")]
pub fn formatted_body_from(
    body: Option<&str>,
    formatted_body: Option<FormattedBody>,
) -> Option<FormattedBody> {
    formatted_body.or_else(|| body.and_then(FormattedBody::markdown))
}
243
244/// A full URL or just the query part of a URL.
245///
246/// This is a convenience type to be able to access the query part of a URL from
247/// different sources.
248#[derive(Debug, Clone, PartialEq, Eq)]
249pub enum UrlOrQuery {
250    /// A full URL.
251    Url(Url),
252
253    /// The query part of a URL.
254    Query(String),
255}
256
257impl UrlOrQuery {
258    /// Get the query part of this [`UrlOrQuery`].
259    ///
260    /// If this is a [`Url`], this extracts the query.
261    pub fn query(&self) -> Option<&str> {
262        match self {
263            Self::Url(url) => url.query(),
264            Self::Query(query) => Some(query),
265        }
266    }
267}
268
269impl From<Url> for UrlOrQuery {
270    fn from(value: Url) -> Self {
271        Self::Url(value)
272    }
273}
274
275#[cfg(feature = "local-server")]
276impl From<QueryString> for UrlOrQuery {
277    fn from(value: QueryString) -> Self {
278        Self::Query(value.0)
279    }
280}
281
#[cfg(test)]
mod test {
    #[cfg(feature = "markdown")]
    use assert_matches2::{assert_let, assert_matches};
    #[cfg(feature = "markdown")]
    use ruma::events::room::message::FormattedBody;

    #[cfg(feature = "markdown")]
    use crate::utils::formatted_body_from;
    use crate::utils::is_room_alias_format_valid;

    #[cfg(feature = "e2e-encryption")]
    #[test]
    fn test_channel_observable_get_set() {
        let observable = super::ChannelObservable::new(0);

        assert_eq!(observable.get(), 0);
        assert_eq!(observable.set(1), 0);
        assert_eq!(observable.set(10), 1);
        assert_eq!(observable.get(), 10);
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_no_leading_hash_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("alias:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_several_colon_chars_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:something:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_it_has_no_colon_chars_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias.domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_server_part_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_whitespace_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias with whitespace:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_control_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias\u{0009}:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_has_invalid_char_is_not_valid() {
        assert!(!is_room_alias_format_valid("#a#lias,{t\\est}:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_name_part_is_not_lowercase_is_not_valid() {
        assert!(!is_room_alias_format_valid("#Alias:domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_server_part_is_not_lowercase_is_not_valid() {
        assert!(!is_room_alias_format_valid("#alias:Domain.org".to_owned()))
    }

    #[test]
    fn test_is_room_alias_format_valid_when_has_valid_format() {
        assert!(is_room_alias_format_valid("#alias.test:domain.org".to_owned()))
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_nothing_returns_none() {
        assert_matches!(formatted_body_from(None, None), None);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_only_formatted_body_returns_the_formatted_body() {
        let formatted_body = FormattedBody::html(r"<h1>Hello!</h1>");

        assert_let!(
            Some(result_formatted_body) = formatted_body_from(None, Some(formatted_body.clone()))
        );

        assert_eq!(formatted_body.body, result_formatted_body.body);
        // Compare against the input; comparing the result with itself would
        // always pass.
        assert_eq!(formatted_body.format, result_formatted_body.format);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_markdown_body_returns_a_processed_formatted_body() {
        let markdown_body = Some(r"# Parsed");

        assert_let!(Some(result_formatted_body) = formatted_body_from(markdown_body, None));

        let expected_formatted_body = FormattedBody::html("<h1>Parsed</h1>\n".to_owned());
        assert_eq!(expected_formatted_body.body, result_formatted_body.body);
        assert_eq!(expected_formatted_body.format, result_formatted_body.format);
    }

    #[test]
    #[cfg(feature = "markdown")]
    fn test_formatted_body_from_body_and_formatted_body_returns_the_formatted_body() {
        let markdown_body = Some(r"# Markdown");
        let formatted_body = FormattedBody::html(r"<h1>HTML</h1>");

        assert_let!(
            Some(result_formatted_body) =
                formatted_body_from(markdown_body, Some(formatted_body.clone()))
        );

        assert_eq!(formatted_body.body, result_formatted_body.body);
        assert_eq!(formatted_body.format, result_formatted_body.format);
    }
}