matrix_sdk_crypto/olm/group_sessions/
outbound.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cmp::max,
17    collections::{BTreeMap, BTreeSet},
18    fmt,
19    ops::Bound,
20    sync::{
21        Arc, RwLockReadGuard,
22        atomic::{AtomicBool, AtomicU64, Ordering},
23    },
24    time::Duration,
25};
26
27use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock};
28#[cfg(feature = "experimental-encrypted-state-events")]
29use ruma::events::AnyStateEventContent;
30use ruma::{
31    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId,
32    SecondsSinceUnixEpoch, TransactionId, UserId,
33    events::{
34        AnyMessageLikeEventContent,
35        room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility},
36    },
37    serde::Raw,
38};
39use serde::{Deserialize, Serialize};
40use tokio::sync::RwLock;
41use tracing::{debug, error, info};
42use vodozemac::{Curve25519PublicKey, megolm::SessionConfig};
43pub use vodozemac::{
44    PickleError,
45    megolm::{GroupSession, GroupSessionPickle, MegolmMessage, SessionKey},
46    olm::IdentityKeys,
47};
48
49use super::SessionCreationError;
50#[cfg(feature = "experimental-algorithms")]
51use crate::types::events::room::encrypted::MegolmV2AesSha2Content;
52use crate::{
53    DeviceData,
54    olm::account::shared_history_from_history_visibility,
55    session_manager::CollectStrategy,
56    store::caches::SequenceNumber,
57    types::{
58        EventEncryptionAlgorithm,
59        events::{
60            room::encrypted::{
61                MegolmV1AesSha2Content, RoomEncryptedEventContent, RoomEventEncryptionScheme,
62            },
63            room_key::{MegolmV1AesSha2Content as MegolmV1AesSha2RoomKeyContent, RoomKeyContent},
64            room_key_withheld::RoomKeyWithheldContent,
65        },
66        requests::ToDeviceRequest,
67    },
68};
69
/// One hour. Also used as the lower bound for the time-based rotation period.
const ONE_HOUR: Duration = Duration::from_secs(3_600);
/// Seven days, built from [`ONE_HOUR`].
const ONE_WEEK: Duration = Duration::from_secs(7 * 24 * ONE_HOUR.as_secs());

/// Default maximum age of an outbound group session.
const ROTATION_PERIOD: Duration = ONE_WEEK;
/// Default maximum number of messages encrypted with one outbound group
/// session.
const ROTATION_MESSAGES: u64 = 100;
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
77/// Information about whether a session was shared with a device.
78pub(crate) enum ShareState {
79    /// The session was not shared with the device.
80    NotShared,
81    /// The session was shared with the device with the given device ID, but
82    /// with a different curve25519 key.
83    SharedButChangedSenderKey,
84    /// The session was shared with the device, at the given message index. The
85    /// `olm_wedging_index` is the value of the `olm_wedging_index` from the
86    /// [`DeviceData`] at the time that we last shared the session with the
87    /// device, and indicates whether we need to re-share the session with the
88    /// device.
89    Shared { message_index: u32, olm_wedging_index: SequenceNumber },
90}
91
92/// Settings for an encrypted room.
93///
94/// This determines the algorithm and rotation periods of a group session.
95#[derive(Clone, Debug, Deserialize, Serialize)]
96pub struct EncryptionSettings {
97    /// The encryption algorithm that should be used in the room.
98    pub algorithm: EventEncryptionAlgorithm,
99    /// Whether state event encryption is enabled.
100    #[cfg(feature = "experimental-encrypted-state-events")]
101    #[serde(default)]
102    pub encrypt_state_events: bool,
103    /// How long the session should be used before changing it.
104    pub rotation_period: Duration,
105    /// How many messages should be sent before changing the session.
106    pub rotation_period_msgs: u64,
107    /// The history visibility of the room when the session was created.
108    pub history_visibility: HistoryVisibility,
109    /// The strategy used to distribute the room keys to participant.
110    /// Default will send to all devices.
111    #[serde(default)]
112    pub sharing_strategy: CollectStrategy,
113}
114
115impl Default for EncryptionSettings {
116    fn default() -> Self {
117        Self {
118            algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
119            #[cfg(feature = "experimental-encrypted-state-events")]
120            encrypt_state_events: false,
121            rotation_period: ROTATION_PERIOD,
122            rotation_period_msgs: ROTATION_MESSAGES,
123            history_visibility: HistoryVisibility::Shared,
124            sharing_strategy: CollectStrategy::default(),
125        }
126    }
127}
128
129impl EncryptionSettings {
130    /// Create new encryption settings using an `RoomEncryptionEventContent`,
131    /// a history visibility, and key sharing strategy.
132    pub fn new(
133        content: RoomEncryptionEventContent,
134        history_visibility: HistoryVisibility,
135        sharing_strategy: CollectStrategy,
136    ) -> Self {
137        let rotation_period: Duration =
138            content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
139        let rotation_period_msgs: u64 =
140            content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
141
142        Self {
143            algorithm: EventEncryptionAlgorithm::from(content.algorithm.as_str()),
144            #[cfg(feature = "experimental-encrypted-state-events")]
145            encrypt_state_events: false,
146            rotation_period,
147            rotation_period_msgs,
148            history_visibility,
149            sharing_strategy,
150        }
151    }
152}
153
154/// The result of encrypting a message with an outbound group session.
155///
156/// Contains the encrypted content, the algorithm used, and the session ID.
157#[derive(Debug)]
158pub struct OutboundGroupSessionEncryptionResult {
159    /// The encrypted content of the message.
160    pub content: Raw<RoomEncryptedEventContent>,
161    /// The algorithm used to encrypt the message.
162    pub algorithm: EventEncryptionAlgorithm,
163    /// The session ID used to encrypt the message.
164    pub session_id: Arc<str>,
165}
166
167/// Outbound group session.
168///
169/// Outbound group sessions are used to exchange room messages between a group
170/// of participants. Outbound group sessions are used to encrypt the room
171/// messages.
172#[derive(Clone)]
173pub struct OutboundGroupSession {
174    inner: Arc<RwLock<GroupSession>>,
175    device_id: OwnedDeviceId,
176    account_identity_keys: Arc<IdentityKeys>,
177    session_id: Arc<str>,
178    room_id: OwnedRoomId,
179    pub(crate) creation_time: SecondsSinceUnixEpoch,
180    message_count: Arc<AtomicU64>,
181    shared: Arc<AtomicBool>,
182    invalidated: Arc<AtomicBool>,
183    settings: Arc<EncryptionSettings>,
184    shared_with_set: Arc<StdRwLock<ShareInfoSet>>,
185    to_share_with_set: Arc<StdRwLock<ToShareMap>>,
186}
187
188/// A a map of userid/device it to a `ShareInfo`.
189///
190/// Holds the `ShareInfo` for all the user/device pairs that will receive the
191/// room key.
192pub type ShareInfoSet = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>;
193
194type ToShareMap = BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>;
195
196/// Struct holding info about the share state of a outbound group session.
197#[derive(Clone, Debug, Serialize, Deserialize)]
198pub enum ShareInfo {
199    /// When the key has been shared
200    Shared(SharedWith),
201    /// When the session has been withheld
202    Withheld(WithheldCode),
203}
204
205impl ShareInfo {
206    /// Helper to create a SharedWith info
207    pub fn new_shared(
208        sender_key: Curve25519PublicKey,
209        message_index: u32,
210        olm_wedging_index: SequenceNumber,
211    ) -> Self {
212        ShareInfo::Shared(SharedWith { sender_key, message_index, olm_wedging_index })
213    }
214
215    /// Helper to create a Withheld info
216    pub fn new_withheld(code: WithheldCode) -> Self {
217        ShareInfo::Withheld(code)
218    }
219}
220
221#[derive(Clone, Debug, Serialize, Deserialize)]
222pub struct SharedWith {
223    /// The sender key of the device that was used to encrypt the room key.
224    pub sender_key: Curve25519PublicKey,
225    /// The message index that the device received.
226    pub message_index: u32,
227    /// The Olm wedging index of the device at the time the session was shared.
228    #[serde(default)]
229    pub olm_wedging_index: SequenceNumber,
230}
231
232/// A read-only view into the device sharing state of an
233/// [`OutboundGroupSession`].
234pub(crate) struct SharingView<'a> {
235    shared_with_set: RwLockReadGuard<'a, ShareInfoSet>,
236    to_share_with_set: RwLockReadGuard<'a, ToShareMap>,
237}
238
239impl SharingView<'_> {
240    /// Has the session been shared with the given user/device pair (or if not,
241    /// is there such a request pending).
242    pub(crate) fn get_share_state(&self, device: &DeviceData) -> ShareState {
243        self.iter_shares(Some(device.user_id()), Some(device.device_id()))
244            .map(|(_, _, info)| match info {
245                ShareInfo::Shared(info) => {
246                    if device.curve25519_key() == Some(info.sender_key) {
247                        ShareState::Shared {
248                            message_index: info.message_index,
249                            olm_wedging_index: info.olm_wedging_index,
250                        }
251                    } else {
252                        ShareState::SharedButChangedSenderKey
253                    }
254                }
255                ShareInfo::Withheld(_) => ShareState::NotShared,
256            })
257            // Return the most "definitive" ShareState found (in case there
258            // are multiple entries for the same device).
259            .max()
260            .unwrap_or(ShareState::NotShared)
261    }
262
263    /// Has the session been withheld for the given user/device pair (or if not,
264    /// is there such a request pending).
265    pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
266        self.iter_shares(Some(device.user_id()), Some(device.device_id()))
267            .any(|(_, _, info)| matches!(info, ShareInfo::Withheld(c) if c == code))
268    }
269
270    /// Enumerate all sent or pending sharing requests for the given device (or
271    /// for all devices if not specified).  This can yield the same device
272    /// multiple times.
273    pub(crate) fn iter_shares<'b, 'c>(
274        &self,
275        user_id: Option<&'b UserId>,
276        device_id: Option<&'c DeviceId>,
277    ) -> impl Iterator<Item = (&UserId, &DeviceId, &ShareInfo)> + use<'_, 'b, 'c> {
278        fn iter_share_info_set<'a, 'b, 'c>(
279            set: &'a ShareInfoSet,
280            user_ids: (Bound<&'b UserId>, Bound<&'b UserId>),
281            device_ids: (Bound<&'c DeviceId>, Bound<&'c DeviceId>),
282        ) -> impl Iterator<Item = (&'a UserId, &'a DeviceId, &'a ShareInfo)> + use<'a, 'b, 'c>
283        {
284            set.range::<UserId, _>(user_ids).flat_map(move |(uid, d)| {
285                d.range::<DeviceId, _>(device_ids)
286                    .map(|(id, info)| (uid.as_ref(), id.as_ref(), info))
287            })
288        }
289
290        let user_ids = user_id
291            .map(|u| (Bound::Included(u), Bound::Included(u)))
292            .unwrap_or((Bound::Unbounded, Bound::Unbounded));
293        let device_ids = device_id
294            .map(|d| (Bound::Included(d), Bound::Included(d)))
295            .unwrap_or((Bound::Unbounded, Bound::Unbounded));
296
297        let already_shared = iter_share_info_set(&self.shared_with_set, user_ids, device_ids);
298        let pending = self
299            .to_share_with_set
300            .values()
301            .flat_map(move |(_, set)| iter_share_info_set(set, user_ids, device_ids));
302        already_shared.chain(pending)
303    }
304
305    /// Enumerate all users that have received the session, or have pending
306    /// requests to receive it.  This can yield the same user multiple times,
307    /// so you may want to `collect()` the result into a `BTreeSet`.
308    pub(crate) fn shared_with_users(&self) -> impl Iterator<Item = &UserId> {
309        self.iter_shares(None, None).filter_map(|(u, _, info)| match info {
310            ShareInfo::Shared(_) => Some(u),
311            ShareInfo::Withheld(_) => None,
312        })
313    }
314}
315
316impl OutboundGroupSession {
317    pub(super) fn session_config(
318        algorithm: &EventEncryptionAlgorithm,
319    ) -> Result<SessionConfig, SessionCreationError> {
320        match algorithm {
321            EventEncryptionAlgorithm::MegolmV1AesSha2 => Ok(SessionConfig::version_1()),
322            #[cfg(feature = "experimental-algorithms")]
323            EventEncryptionAlgorithm::MegolmV2AesSha2 => Ok(SessionConfig::version_2()),
324            _ => Err(SessionCreationError::Algorithm(algorithm.to_owned())),
325        }
326    }
327
328    /// Create a new outbound group session for the given room.
329    ///
330    /// Outbound group sessions are used to encrypt room messages.
331    ///
332    /// # Arguments
333    ///
334    /// * `device_id` - The id of the device that created this session.
335    ///
336    /// * `identity_keys` - The identity keys of the account that created this
337    ///   session.
338    ///
339    /// * `room_id` - The id of the room that the session is used in.
340    ///
341    /// * `settings` - Settings determining the algorithm and rotation period of
342    ///   the outbound group session.
343    pub fn new(
344        device_id: OwnedDeviceId,
345        identity_keys: Arc<IdentityKeys>,
346        room_id: &RoomId,
347        settings: EncryptionSettings,
348    ) -> Result<Self, SessionCreationError> {
349        let config = Self::session_config(&settings.algorithm)?;
350
351        let session = GroupSession::new(config);
352        let session_id = session.session_id();
353
354        Ok(OutboundGroupSession {
355            inner: RwLock::new(session).into(),
356            room_id: room_id.into(),
357            device_id,
358            account_identity_keys: identity_keys,
359            session_id: session_id.into(),
360            creation_time: SecondsSinceUnixEpoch::now(),
361            message_count: Arc::new(AtomicU64::new(0)),
362            shared: Arc::new(AtomicBool::new(false)),
363            invalidated: Arc::new(AtomicBool::new(false)),
364            settings: Arc::new(settings),
365            shared_with_set: Default::default(),
366            to_share_with_set: Default::default(),
367        })
368    }
369
370    /// Add a to-device request that is sending the session key (or room key)
371    /// belonging to this [`OutboundGroupSession`] to other members of the
372    /// group.
373    ///
374    /// The request will get persisted with the session which allows seamless
375    /// session reuse across application restarts.
376    ///
377    /// **Warning** this method is only exposed to be used in integration tests
378    /// of crypto-store implementations. **Do not use this outside of tests**.
379    pub fn add_request(
380        &self,
381        request_id: OwnedTransactionId,
382        request: Arc<ToDeviceRequest>,
383        share_infos: ShareInfoSet,
384    ) {
385        self.to_share_with_set.write().insert(request_id, (request, share_infos));
386    }
387
388    /// Create a new `m.room_key.withheld` event content with the given code for
389    /// this outbound group session.
390    pub fn withheld_code(&self, code: WithheldCode) -> RoomKeyWithheldContent {
391        RoomKeyWithheldContent::new(
392            self.settings().algorithm.to_owned(),
393            code,
394            self.room_id().to_owned(),
395            self.session_id().to_owned(),
396            self.sender_key().to_owned(),
397            (*self.device_id).to_owned(),
398        )
399    }
400
401    /// This should be called if an the user wishes to rotate this session.
402    pub fn invalidate_session(&self) {
403        self.invalidated.store(true, Ordering::Relaxed)
404    }
405
406    /// Get the encryption settings of this outbound session.
407    pub fn settings(&self) -> &EncryptionSettings {
408        &self.settings
409    }
410
411    /// Mark the request with the given request id as sent.
412    ///
413    /// This removes the request from the queue and marks the set of
414    /// users/devices that received the session.
415    pub fn mark_request_as_sent(
416        &self,
417        request_id: &TransactionId,
418    ) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
419        let mut no_olm_devices = BTreeMap::new();
420
421        let removed = self.to_share_with_set.write().remove(request_id);
422        if let Some((to_device, request)) = removed {
423            let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
424                .iter()
425                .map(|(u, d)| (u.as_ref(), d.keys().map(|d| d.as_ref()).collect()))
426                .collect();
427
428            info!(
429                ?request_id,
430                ?recipients,
431                ?to_device.event_type,
432                "Marking to-device request carrying a room key or a withheld as sent"
433            );
434
435            for (user_id, info) in request {
436                let no_olms: BTreeSet<OwnedDeviceId> = info
437                    .iter()
438                    .filter(|(_, info)| matches!(info, ShareInfo::Withheld(WithheldCode::NoOlm)))
439                    .map(|(d, _)| d.to_owned())
440                    .collect();
441                no_olm_devices.insert(user_id.to_owned(), no_olms);
442
443                self.shared_with_set.write().entry(user_id).or_default().extend(info);
444            }
445
446            if self.to_share_with_set.read().is_empty() {
447                debug!(
448                    session_id = self.session_id(),
449                    room_id = ?self.room_id,
450                    "All m.room_key and withheld to-device requests were sent out, marking \
451                     session as shared.",
452                );
453
454                self.mark_as_shared();
455            }
456        } else {
457            let request_ids: Vec<String> =
458                self.to_share_with_set.read().keys().map(|k| k.to_string()).collect();
459
460            error!(
461                all_request_ids = ?request_ids,
462                ?request_id,
463                "Marking to-device request carrying a room key as sent but no \
464                 request found with the given id"
465            );
466        }
467
468        no_olm_devices
469    }
470
471    /// Encrypt the given plaintext using this session.
472    ///
473    /// Returns the encrypted ciphertext.
474    ///
475    /// # Arguments
476    ///
477    /// * `plaintext` - The plaintext that should be encrypted.
478    pub(crate) async fn encrypt_helper(&self, plaintext: String) -> MegolmMessage {
479        let mut session = self.inner.write().await;
480        self.message_count.fetch_add(1, Ordering::SeqCst);
481        session.encrypt(&plaintext)
482    }
483
484    /// Encrypt an arbitrary event for the given room.
485    ///
486    /// Beware that a room key needs to be shared before this method
487    /// can be called using the `share_room_key()` method.
488    ///
489    /// # Arguments
490    ///
491    /// * `payload` - The plaintext content of the event that should be
492    ///   serialized to JSON and encrypted.
493    ///
494    /// # Panics
495    ///
496    /// Panics if the content can't be serialized.
497    async fn encrypt_inner<T: Serialize>(
498        &self,
499        payload: &T,
500        relates_to: Option<serde_json::Value>,
501    ) -> OutboundGroupSessionEncryptionResult {
502        let ciphertext = self
503            .encrypt_helper(
504                serde_json::to_string(payload).expect("payload serialization never fails"),
505            )
506            .await;
507        let scheme: RoomEventEncryptionScheme = match self.settings.algorithm {
508            EventEncryptionAlgorithm::MegolmV1AesSha2 => MegolmV1AesSha2Content {
509                ciphertext,
510                sender_key: Some(self.account_identity_keys.curve25519),
511                session_id: self.session_id().to_owned(),
512                device_id: Some(self.device_id.clone()),
513            }
514            .into(),
515            #[cfg(feature = "experimental-algorithms")]
516            EventEncryptionAlgorithm::MegolmV2AesSha2 => {
517                MegolmV2AesSha2Content { ciphertext, session_id: self.session_id().to_owned() }
518                    .into()
519            }
520            _ => unreachable!(
521                "An outbound group session is always using one of the supported algorithms"
522            ),
523        };
524        let content = RoomEncryptedEventContent { scheme, relates_to, other: Default::default() };
525
526        OutboundGroupSessionEncryptionResult {
527            content: Raw::new(&content)
528                .expect("m.room.encrypted event content can always be serialized"),
529            algorithm: self.settings.algorithm.to_owned(),
530            session_id: self.session_id.clone(),
531        }
532    }
533
534    /// Encrypt a room message for the given room.
535    ///
536    /// Beware that a room key needs to be shared before this method
537    /// can be called using the `share_room_key()` method.
538    ///
539    /// # Arguments
540    ///
541    /// * `event_type` - The plaintext type of the event, the outer type of the
542    ///   event will become `m.room.encrypted`.
543    ///
544    /// * `content` - The plaintext content of the message that should be
545    ///   encrypted in raw JSON form.
546    ///
547    /// # Panics
548    ///
549    /// Panics if the content can't be serialized.
550    pub async fn encrypt(
551        &self,
552        event_type: &str,
553        content: &Raw<AnyMessageLikeEventContent>,
554    ) -> OutboundGroupSessionEncryptionResult {
555        #[derive(Serialize)]
556        struct Payload<'a> {
557            #[serde(rename = "type")]
558            event_type: &'a str,
559            content: &'a Raw<AnyMessageLikeEventContent>,
560            room_id: &'a RoomId,
561        }
562
563        let payload = Payload { event_type, content, room_id: &self.room_id };
564
565        let relates_to = content
566            .get_field::<serde_json::Value>("m.relates_to")
567            .expect("serde_json::Value deserialization with valid JSON input never fails");
568
569        self.encrypt_inner(&payload, relates_to).await
570    }
571
572    /// Encrypt a room state event for the given room.
573    ///
574    /// Beware that a room key needs to be shared before this method
575    /// can be called using the `share_room_key()` method.
576    ///
577    /// # Arguments
578    ///
579    /// * `event_type` - The plaintext type of the event, the outer type of the
580    ///   event will become `m.room.encrypted`.
581    ///
582    /// * `state_key` - The plaintext state key of the event, the outer state
583    ///   key will be derived from this and the event type.
584    ///
585    /// * `content` - The plaintext content of the message that should be
586    ///   encrypted in raw JSON form.
587    ///
588    /// # Panics
589    ///
590    /// Panics if the content can't be serialized.
591    #[cfg(feature = "experimental-encrypted-state-events")]
592    pub async fn encrypt_state(
593        &self,
594        event_type: &str,
595        state_key: &str,
596        content: &Raw<AnyStateEventContent>,
597    ) -> Raw<RoomEncryptedEventContent> {
598        #[derive(Serialize)]
599        struct Payload<'a> {
600            #[serde(rename = "type")]
601            event_type: &'a str,
602            state_key: &'a str,
603            content: &'a Raw<AnyStateEventContent>,
604            room_id: &'a RoomId,
605        }
606
607        let payload = Payload { event_type, state_key, content, room_id: &self.room_id };
608        self.encrypt_inner(&payload, None).await.content
609    }
610
611    fn elapsed(&self) -> bool {
612        let creation_time = Duration::from_secs(self.creation_time.get().into());
613        let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
614        now.checked_sub(creation_time)
615            .map(|elapsed| elapsed >= self.safe_rotation_period())
616            .unwrap_or(true)
617    }
618
619    /// Returns the rotation_period_ms that was set for this session, clamped
620    /// to be no less than one hour.
621    ///
622    /// This is to prevent a malicious or careless user causing sessions to be
623    /// rotated very frequently.
624    ///
625    /// The feature flag `_disable-minimum-rotation-period-ms` can
626    /// be used to prevent this behaviour (which can be useful for tests).
627    fn safe_rotation_period(&self) -> Duration {
628        if cfg!(feature = "_disable-minimum-rotation-period-ms") {
629            self.settings.rotation_period
630        } else {
631            max(self.settings.rotation_period, ONE_HOUR)
632        }
633    }
634
635    /// Check if the session has expired and if it should be rotated.
636    ///
637    /// A session will expire after some time or if enough messages have been
638    /// encrypted using it.
639    pub fn expired(&self) -> bool {
640        let count = self.message_count.load(Ordering::SeqCst);
641        // We clamp the rotation period for message counts to be between 1 and
642        // 10000. The Megolm session should be usable for at least 1 message,
643        // and at most 10000 messages. Realistically Megolm uses u32 for it's
644        // internal counter and one could use the Megolm session for up to
645        // u32::MAX messages, but we're staying on the safe side of things.
646        let rotation_period_msgs = self.settings.rotation_period_msgs.clamp(1, 10_000);
647
648        count >= rotation_period_msgs || self.elapsed()
649    }
650
651    /// Has the session been invalidated.
652    pub fn invalidated(&self) -> bool {
653        self.invalidated.load(Ordering::Relaxed)
654    }
655
656    /// Mark the session as shared.
657    ///
658    /// Messages shouldn't be encrypted with the session before it has been
659    /// shared.
660    pub fn mark_as_shared(&self) {
661        self.shared.store(true, Ordering::Relaxed);
662    }
663
664    /// Check if the session has been marked as shared.
665    pub fn shared(&self) -> bool {
666        self.shared.load(Ordering::Relaxed)
667    }
668
669    /// Get the session key of this session.
670    ///
671    /// A session key can be used to to create an `InboundGroupSession`.
672    pub async fn session_key(&self) -> SessionKey {
673        let session = self.inner.read().await;
674        session.session_key()
675    }
676
677    /// Gets the Sender Key
678    pub fn sender_key(&self) -> Curve25519PublicKey {
679        self.account_identity_keys.as_ref().curve25519.to_owned()
680    }
681
682    /// Get the room id of the room this session belongs to.
683    pub fn room_id(&self) -> &RoomId {
684        &self.room_id
685    }
686
687    /// Returns the unique identifier for this session.
688    pub fn session_id(&self) -> &str {
689        &self.session_id
690    }
691
692    /// Get the current message index for this session.
693    ///
694    /// Each message is sent with an increasing index. This returns the
695    /// message index that will be used for the next encrypted message.
696    pub async fn message_index(&self) -> u32 {
697        let session = self.inner.read().await;
698        session.message_index()
699    }
700
701    pub(crate) async fn as_content(&self) -> RoomKeyContent {
702        let session_key = self.session_key().await;
703        let shared_history =
704            shared_history_from_history_visibility(&self.settings.history_visibility);
705
706        RoomKeyContent::MegolmV1AesSha2(
707            MegolmV1AesSha2RoomKeyContent::new(
708                self.room_id().to_owned(),
709                self.session_id().to_owned(),
710                session_key,
711                shared_history,
712            )
713            .into(),
714        )
715    }
716
717    /// Create a read-only view into the device sharing state of this session.
718    /// This view includes pending requests, so it is not guaranteed that the
719    /// represented state has been fully propagated yet.
720    pub(crate) fn sharing_view(&self) -> SharingView<'_> {
721        SharingView {
722            shared_with_set: self.shared_with_set.read(),
723            to_share_with_set: self.to_share_with_set.read(),
724        }
725    }
726
727    /// Mark the session as shared with the given user/device pair, starting
728    /// from some message index.
729    #[cfg(test)]
730    pub fn mark_shared_with_from_index(
731        &self,
732        user_id: &UserId,
733        device_id: &DeviceId,
734        sender_key: Curve25519PublicKey,
735        index: u32,
736    ) {
737        self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert(
738            device_id.to_owned(),
739            ShareInfo::new_shared(sender_key, index, Default::default()),
740        );
741    }
742
743    /// Mark the session as shared with the given user/device pair, starting
744    /// from the current index.
745    #[cfg(test)]
746    pub async fn mark_shared_with(
747        &self,
748        user_id: &UserId,
749        device_id: &DeviceId,
750        sender_key: Curve25519PublicKey,
751    ) {
752        let share_info =
753            ShareInfo::new_shared(sender_key, self.message_index().await, Default::default());
754        self.shared_with_set
755            .write()
756            .entry(user_id.to_owned())
757            .or_default()
758            .insert(device_id.to_owned(), share_info);
759    }
760
761    /// Get the list of requests that need to be sent out for this session to be
762    /// marked as shared.
763    pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
764        self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect()
765    }
766
767    /// Get the list of request ids this session is waiting for to be sent out.
768    pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
769        self.to_share_with_set.read().keys().cloned().collect()
770    }
771
772    /// Restore a Session from a previously pickled string.
773    ///
774    /// Returns the restored group session or a `OlmGroupSessionError` if there
775    /// was an error.
776    ///
777    /// # Arguments
778    ///
779    /// * `device_id` - The device ID of the device that created this session.
780    ///   Put differently, our own device ID.
781    ///
782    /// * `identity_keys` - The identity keys of the device that created this
783    ///   session, our own identity keys.
784    ///
785    /// * `pickle` - The pickled version of the `OutboundGroupSession`.
786    ///
787    /// * `pickle_mode` - The mode that was used to pickle the session, either
788    ///   an unencrypted mode or an encrypted using passphrase.
789    pub fn from_pickle(
790        device_id: OwnedDeviceId,
791        identity_keys: Arc<IdentityKeys>,
792        pickle: PickledOutboundGroupSession,
793    ) -> Result<Self, PickleError> {
794        let inner: GroupSession = pickle.pickle.into();
795        let session_id = inner.session_id();
796
797        Ok(Self {
798            inner: Arc::new(RwLock::new(inner)),
799            device_id,
800            account_identity_keys: identity_keys,
801            session_id: session_id.into(),
802            room_id: pickle.room_id,
803            creation_time: pickle.creation_time,
804            message_count: AtomicU64::from(pickle.message_count).into(),
805            shared: AtomicBool::from(pickle.shared).into(),
806            invalidated: AtomicBool::from(pickle.invalidated).into(),
807            settings: pickle.settings,
808            shared_with_set: Arc::new(StdRwLock::new(pickle.shared_with_set)),
809            to_share_with_set: Arc::new(StdRwLock::new(pickle.requests)),
810        })
811    }
812
813    /// Store the group session as a base64 encoded string and associated data
814    /// belonging to the session.
815    ///
816    /// # Arguments
817    ///
818    /// * `pickle_mode` - The mode that should be used to pickle the group
819    ///   session, either an unencrypted mode or an encrypted using passphrase.
820    pub async fn pickle(&self) -> PickledOutboundGroupSession {
821        let pickle = self.inner.read().await.pickle();
822
823        PickledOutboundGroupSession {
824            pickle,
825            room_id: self.room_id.clone(),
826            settings: self.settings.clone(),
827            creation_time: self.creation_time,
828            message_count: self.message_count.load(Ordering::SeqCst),
829            shared: self.shared(),
830            invalidated: self.invalidated(),
831            shared_with_set: self.shared_with_set.read().clone(),
832            requests: self.to_share_with_set.read().clone(),
833        }
834    }
835}
836
837#[cfg(not(tarpaulin_include))]
838impl fmt::Debug for OutboundGroupSession {
839    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
840        f.debug_struct("OutboundGroupSession")
841            .field("session_id", &self.session_id)
842            .field("room_id", &self.room_id)
843            .field("creation_time", &self.creation_time)
844            .field("message_count", &self.message_count)
845            .finish()
846    }
847}
848
849/// A pickled version of an `InboundGroupSession`.
850///
851/// Holds all the information that needs to be stored in a database to restore
852/// an InboundGroupSession.
853#[derive(Deserialize, Serialize)]
854#[allow(missing_debug_implementations)]
855pub struct PickledOutboundGroupSession {
856    /// The pickle string holding the OutboundGroupSession.
857    pub pickle: GroupSessionPickle,
858    /// The settings this session adheres to.
859    pub settings: Arc<EncryptionSettings>,
860    /// The room id this session is used for.
861    pub room_id: OwnedRoomId,
862    /// The timestamp when this session was created.
863    pub creation_time: SecondsSinceUnixEpoch,
864    /// The number of messages this session has already encrypted.
865    pub message_count: u64,
866    /// Is the session shared.
867    pub shared: bool,
868    /// Has the session been invalidated.
869    pub invalidated: bool,
870    /// The set of users the session has been already shared with.
871    pub shared_with_set: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
872    /// Requests that need to be sent out to share the session.
873    pub requests: BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>,
874}
875
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ruma::{
        EventEncryptionAlgorithm,
        events::room::{
            encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility,
        },
        uint,
    };

    use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD, ShareState};
    use crate::CollectStrategy;

    /// `EncryptionSettings::new` uses the default rotation limits when the
    /// encryption event leaves them unset, and copies them over when present.
    #[test]
    fn test_encryption_settings_conversion() {
        // No explicit rotation parameters: the defaults apply.
        let plain_content =
            RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);
        let defaults = EncryptionSettings::new(
            plain_content,
            HistoryVisibility::Joined,
            CollectStrategy::AllDevices,
        );

        assert_eq!(defaults.rotation_period, ROTATION_PERIOD);
        assert_eq!(defaults.rotation_period_msgs, ROTATION_MESSAGES);

        // Explicit rotation parameters are carried into the settings.
        let mut tuned_content =
            RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);
        tuned_content.rotation_period_ms = Some(uint!(3600));
        tuned_content.rotation_period_msgs = Some(uint!(500));

        let tuned = EncryptionSettings::new(
            tuned_content,
            HistoryVisibility::Shared,
            CollectStrategy::AllDevices,
        );

        assert_eq!(tuned.rotation_period, Duration::from_millis(3600));
        assert_eq!(tuned.rotation_period_msgs, 500);
    }

    /// `ShareState`'s `PartialOrd` ranks the variants from least to most
    /// specific.
    #[test]
    fn test_share_state_ordering() {
        let ordered = [
            ShareState::NotShared,
            ShareState::SharedButChangedSenderKey,
            ShareState::Shared { message_index: 1, olm_wedging_index: Default::default() },
        ];

        // Exhaustive on purpose: a new `ShareState` variant breaks the build
        // here, as a reminder to add it to `ordered` above.
        match &ordered[0] {
            ShareState::NotShared
            | ShareState::SharedButChangedSenderKey
            | ShareState::Shared { .. } => {}
        }

        assert!(ordered.is_sorted());
    }

    #[cfg(any(target_os = "linux", target_os = "macos", target_family = "wasm"))]
    mod expiration {
        use std::{sync::atomic::Ordering, time::Duration};

        use matrix_sdk_test::async_test;
        use ruma::{
            SecondsSinceUnixEpoch, device_id, events::room::message::RoomMessageEventContent,
            room_id, serde::Raw, uint, user_id,
        };

        use crate::{
            Account, EncryptionSettings, MegolmError,
            olm::{OutboundGroupSession, SenderData},
        };

        const TWO_HOURS: Duration = Duration::from_secs(2 * 60 * 60);

        #[async_test]
        async fn test_session_is_not_expired_if_no_messages_sent_and_no_time_passed() {
            // A session that should rotate after a single message...
            let fresh = create_session(EncryptionSettings {
                rotation_period_msgs: 1,
                ..Default::default()
            })
            .await;

            // ...stays valid as long as nothing has been encrypted with it.
            assert!(!fresh.expired());
        }

        #[async_test]
        async fn test_session_is_expired_if_we_rotate_every_message_and_one_was_sent()
        -> Result<(), MegolmError> {
            let one_shot = create_session(EncryptionSettings {
                rotation_period_msgs: 1,
                ..Default::default()
            })
            .await;

            // A single encrypted message exhausts the one-message budget.
            encrypt_test_message(&one_shot).await?;

            assert!(one_shot.expired());

            Ok(())
        }

        #[async_test]
        async fn test_session_with_rotation_period_is_not_expired_after_no_time() {
            let fresh = create_session(EncryptionSettings {
                rotation_period: TWO_HOURS,
                ..Default::default()
            })
            .await;

            // Without any time passing, a two-hour rotation period cannot elapse.
            assert!(!fresh.expired());
        }

        #[async_test]
        async fn test_session_is_expired_after_rotation_period() {
            let mut aged = create_session(EncryptionSettings {
                rotation_period: TWO_HOURS,
                ..Default::default()
            })
            .await;

            // Backdate the creation time by three hours (10800 s).
            let current = SecondsSinceUnixEpoch::now();
            aged.creation_time = SecondsSinceUnixEpoch(current.get() - uint!(10800));

            // Three hours is longer than the two-hour rotation period.
            assert!(aged.expired());
        }

        #[async_test]
        #[cfg(not(feature = "_disable-minimum-rotation-period-ms"))]
        async fn test_session_does_not_expire_under_one_hour_even_if_we_ask_for_shorter() {
            let mut short_lived = create_session(EncryptionSettings {
                rotation_period: Duration::from_millis(100),
                ..Default::default()
            })
            .await;

            let current = SecondsSinceUnixEpoch::now();

            // Half an hour old: the enforced one-hour minimum keeps it alive,
            // despite the 100ms rotation period that was requested.
            short_lived.creation_time = SecondsSinceUnixEpoch(current.get() - uint!(1800));
            assert!(!short_lived.expired());

            // Just over an hour old: now the minimum has been exceeded.
            short_lived.creation_time = SecondsSinceUnixEpoch(current.get() - uint!(3601));
            assert!(short_lived.expired());
        }

        #[async_test]
        #[cfg(feature = "_disable-minimum-rotation-period-ms")]
        async fn test_with_disable_minrotperiod_feature_sessions_can_expire_quickly() {
            let mut short_lived = create_session(EncryptionSettings {
                rotation_period: Duration::from_millis(100),
                ..Default::default()
            })
            .await;

            // Half an hour old.
            let current = SecondsSinceUnixEpoch::now();
            short_lived.creation_time = SecondsSinceUnixEpoch(current.get() - uint!(1800));

            // With the feature enabled no one-hour floor is applied, so the
            // 100ms rotation period has long elapsed.
            assert!(short_lived.expired());
        }

        #[async_test]
        async fn test_session_with_zero_msgs_rotation_is_not_expired_initially() {
            let zero_budget = create_session(EncryptionSettings {
                rotation_period_msgs: 0,
                ..Default::default()
            })
            .await;

            // A zero-message rotation setting does not make a brand-new
            // session expired before it has been used.
            assert!(!zero_budget.expired());
        }

        #[async_test]
        async fn test_session_with_zero_msgs_rotation_expires_after_one_message()
        -> Result<(), MegolmError> {
            let zero_budget = create_session(EncryptionSettings {
                rotation_period_msgs: 0,
                ..Default::default()
            })
            .await;

            encrypt_test_message(&zero_budget).await?;

            // rotation_period_msgs = 0 behaves like rotation_period_msgs = 1.
            assert!(zero_budget.expired());

            Ok(())
        }

        #[async_test]
        async fn test_session_expires_after_10k_messages_even_if_we_ask_for_more() {
            // Ask for rotation only after 100K messages.
            let greedy = create_session(EncryptionSettings {
                rotation_period_msgs: 100_000,
                ..Default::default()
            })
            .await;

            // Sanity check: below 10K sent messages the session stays valid.
            assert!(!greedy.expired());
            for sent in [1000, 9999] {
                greedy.message_count.store(sent, Ordering::SeqCst);
                assert!(!greedy.expired());
            }

            // Reaching 10K messages hits the enforced upper bound.
            greedy.message_count.store(10_000, Ordering::SeqCst);
            assert!(greedy.expired());
        }

        /// Encrypt one plain-text test message with `session`, discarding the
        /// result of the encryption itself.
        async fn encrypt_test_message(session: &OutboundGroupSession) -> Result<(), MegolmError> {
            let _ = session
                .encrypt(
                    "m.room.message",
                    &Raw::new(&RoomMessageEventContent::text_plain("Test message"))?.cast(),
                )
                .await;

            Ok(())
        }

        /// Create the outbound half of a fresh group session pair for a fixed
        /// test room, using `settings`.
        async fn create_session(settings: EncryptionSettings) -> OutboundGroupSession {
            let static_account =
                Account::with_device_id(user_id!("@alice:example.org"), device_id!("DEVICEID"))
                    .static_data;

            let (outbound, _) = static_account
                .create_group_session_pair(
                    room_id!("!test_room:example.org"),
                    settings,
                    SenderData::unknown(),
                )
                .await
                .unwrap();

            outbound
        }
    }
}