matrix_sdk_crypto/olm/group_sessions/
outbound.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cmp::max,
17    collections::{BTreeMap, BTreeSet},
18    fmt,
19    sync::{
20        atomic::{AtomicBool, AtomicU64, Ordering},
21        Arc,
22    },
23    time::Duration,
24};
25
26use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock};
27use ruma::{
28    events::{
29        room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility},
30        AnyMessageLikeEventContent,
31    },
32    serde::Raw,
33    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId,
34    SecondsSinceUnixEpoch, TransactionId, UserId,
35};
36use serde::{Deserialize, Serialize};
37use tokio::sync::RwLock;
38use tracing::{debug, error, info};
39use vodozemac::{megolm::SessionConfig, Curve25519PublicKey};
40pub use vodozemac::{
41    megolm::{GroupSession, GroupSessionPickle, MegolmMessage, SessionKey},
42    olm::IdentityKeys,
43    PickleError,
44};
45
46use super::SessionCreationError;
47#[cfg(feature = "experimental-algorithms")]
48use crate::types::events::room::encrypted::MegolmV2AesSha2Content;
49use crate::{
50    session_manager::CollectStrategy,
51    store::caches::SequenceNumber,
52    types::{
53        events::{
54            room::encrypted::{
55                MegolmV1AesSha2Content, RoomEncryptedEventContent, RoomEventEncryptionScheme,
56            },
57            room_key::{MegolmV1AesSha2Content as MegolmV1AesSha2RoomKeyContent, RoomKeyContent},
58            room_key_withheld::RoomKeyWithheldContent,
59        },
60        requests::ToDeviceRequest,
61        EventEncryptionAlgorithm,
62    },
63    DeviceData,
64};
65
/// One hour, also used as the lower bound for a session's rotation period.
const ONE_HOUR: Duration = Duration::from_secs(3600);
/// One week (seven days of twenty-four hours).
const ONE_WEEK: Duration = Duration::from_secs(7 * 24 * ONE_HOUR.as_secs());

/// Default time after which an outbound group session gets rotated.
const ROTATION_PERIOD: Duration = ONE_WEEK;
/// Default number of encrypted messages after which an outbound group session
/// gets rotated.
const ROTATION_MESSAGES: u64 = 100;
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73/// Information about whether a session was shared with a device.
74pub(crate) enum ShareState {
75    /// The session was not shared with the device.
76    NotShared,
77    /// The session was shared with the device with the given device ID, but
78    /// with a different curve25519 key.
79    SharedButChangedSenderKey,
80    /// The session was shared with the device, at the given message index. The
81    /// `olm_wedging_index` is the value of the `olm_wedging_index` from the
82    /// [`DeviceData`] at the time that we last shared the session with the
83    /// device, and indicates whether we need to re-share the session with the
84    /// device.
85    Shared { message_index: u32, olm_wedging_index: SequenceNumber },
86}
87
88/// Settings for an encrypted room.
89///
90/// This determines the algorithm and rotation periods of a group session.
91#[derive(Clone, Debug, Deserialize, Serialize)]
92pub struct EncryptionSettings {
93    /// The encryption algorithm that should be used in the room.
94    pub algorithm: EventEncryptionAlgorithm,
95    /// How long the session should be used before changing it.
96    pub rotation_period: Duration,
97    /// How many messages should be sent before changing the session.
98    pub rotation_period_msgs: u64,
99    /// The history visibility of the room when the session was created.
100    pub history_visibility: HistoryVisibility,
101    /// The strategy used to distribute the room keys to participant.
102    /// Default will send to all devices.
103    #[serde(default)]
104    pub sharing_strategy: CollectStrategy,
105}
106
107impl Default for EncryptionSettings {
108    fn default() -> Self {
109        Self {
110            algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
111            rotation_period: ROTATION_PERIOD,
112            rotation_period_msgs: ROTATION_MESSAGES,
113            history_visibility: HistoryVisibility::Shared,
114            sharing_strategy: CollectStrategy::default(),
115        }
116    }
117}
118
119impl EncryptionSettings {
120    /// Create new encryption settings using an `RoomEncryptionEventContent`,
121    /// a history visibility, and key sharing strategy.
122    pub fn new(
123        content: RoomEncryptionEventContent,
124        history_visibility: HistoryVisibility,
125        sharing_strategy: CollectStrategy,
126    ) -> Self {
127        let rotation_period: Duration =
128            content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
129        let rotation_period_msgs: u64 =
130            content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
131
132        Self {
133            algorithm: EventEncryptionAlgorithm::from(content.algorithm.as_str()),
134            rotation_period,
135            rotation_period_msgs,
136            history_visibility,
137            sharing_strategy,
138        }
139    }
140}
141
142/// Outbound group session.
143///
144/// Outbound group sessions are used to exchange room messages between a group
145/// of participants. Outbound group sessions are used to encrypt the room
146/// messages.
147#[derive(Clone)]
148pub struct OutboundGroupSession {
149    inner: Arc<RwLock<GroupSession>>,
150    device_id: OwnedDeviceId,
151    account_identity_keys: Arc<IdentityKeys>,
152    session_id: Arc<str>,
153    room_id: OwnedRoomId,
154    pub(crate) creation_time: SecondsSinceUnixEpoch,
155    message_count: Arc<AtomicU64>,
156    shared: Arc<AtomicBool>,
157    invalidated: Arc<AtomicBool>,
158    settings: Arc<EncryptionSettings>,
159    pub(crate) shared_with_set:
160        Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>>>,
161    #[allow(clippy::type_complexity)]
162    to_share_with_set:
163        Arc<StdRwLock<BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>>>,
164}
165
166/// A a map of userid/device it to a `ShareInfo`.
167///
168/// Holds the `ShareInfo` for all the user/device pairs that will receive the
169/// room key.
170pub type ShareInfoSet = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>;
171
172/// Struct holding info about the share state of a outbound group session.
173#[derive(Clone, Debug, Serialize, Deserialize)]
174pub enum ShareInfo {
175    /// When the key has been shared
176    Shared(SharedWith),
177    /// When the session has been withheld
178    Withheld(WithheldCode),
179}
180
181impl ShareInfo {
182    /// Helper to create a SharedWith info
183    pub fn new_shared(
184        sender_key: Curve25519PublicKey,
185        message_index: u32,
186        olm_wedging_index: SequenceNumber,
187    ) -> Self {
188        ShareInfo::Shared(SharedWith { sender_key, message_index, olm_wedging_index })
189    }
190
191    /// Helper to create a Withheld info
192    pub fn new_withheld(code: WithheldCode) -> Self {
193        ShareInfo::Withheld(code)
194    }
195}
196
197#[derive(Clone, Debug, Serialize, Deserialize)]
198pub struct SharedWith {
199    /// The sender key of the device that was used to encrypt the room key.
200    pub sender_key: Curve25519PublicKey,
201    /// The message index that the device received.
202    pub message_index: u32,
203    /// The Olm wedging index of the device at the time the session was shared.
204    #[serde(default)]
205    pub olm_wedging_index: SequenceNumber,
206}
207
208impl OutboundGroupSession {
209    pub(super) fn session_config(
210        algorithm: &EventEncryptionAlgorithm,
211    ) -> Result<SessionConfig, SessionCreationError> {
212        match algorithm {
213            EventEncryptionAlgorithm::MegolmV1AesSha2 => Ok(SessionConfig::version_1()),
214            #[cfg(feature = "experimental-algorithms")]
215            EventEncryptionAlgorithm::MegolmV2AesSha2 => Ok(SessionConfig::version_2()),
216            _ => Err(SessionCreationError::Algorithm(algorithm.to_owned())),
217        }
218    }
219
220    /// Create a new outbound group session for the given room.
221    ///
222    /// Outbound group sessions are used to encrypt room messages.
223    ///
224    /// # Arguments
225    ///
226    /// * `device_id` - The id of the device that created this session.
227    ///
228    /// * `identity_keys` - The identity keys of the account that created this
229    ///   session.
230    ///
231    /// * `room_id` - The id of the room that the session is used in.
232    ///
233    /// * `settings` - Settings determining the algorithm and rotation period of
234    ///   the outbound group session.
235    pub fn new(
236        device_id: OwnedDeviceId,
237        identity_keys: Arc<IdentityKeys>,
238        room_id: &RoomId,
239        settings: EncryptionSettings,
240    ) -> Result<Self, SessionCreationError> {
241        let config = Self::session_config(&settings.algorithm)?;
242
243        let session = GroupSession::new(config);
244        let session_id = session.session_id();
245
246        Ok(OutboundGroupSession {
247            inner: RwLock::new(session).into(),
248            room_id: room_id.into(),
249            device_id,
250            account_identity_keys: identity_keys,
251            session_id: session_id.into(),
252            creation_time: SecondsSinceUnixEpoch::now(),
253            message_count: Arc::new(AtomicU64::new(0)),
254            shared: Arc::new(AtomicBool::new(false)),
255            invalidated: Arc::new(AtomicBool::new(false)),
256            settings: Arc::new(settings),
257            shared_with_set: Default::default(),
258            to_share_with_set: Default::default(),
259        })
260    }
261
262    /// Add a to-device request that is sending the session key (or room key)
263    /// belonging to this [`OutboundGroupSession`] to other members of the
264    /// group.
265    ///
266    /// The request will get persisted with the session which allows seamless
267    /// session reuse across application restarts.
268    ///
269    /// **Warning** this method is only exposed to be used in integration tests
270    /// of crypto-store implementations. **Do not use this outside of tests**.
271    pub fn add_request(
272        &self,
273        request_id: OwnedTransactionId,
274        request: Arc<ToDeviceRequest>,
275        share_infos: ShareInfoSet,
276    ) {
277        self.to_share_with_set.write().insert(request_id, (request, share_infos));
278    }
279
280    /// Create a new `m.room_key.withheld` event content with the given code for
281    /// this outbound group session.
282    pub fn withheld_code(&self, code: WithheldCode) -> RoomKeyWithheldContent {
283        RoomKeyWithheldContent::new(
284            self.settings().algorithm.to_owned(),
285            code,
286            self.room_id().to_owned(),
287            self.session_id().to_owned(),
288            self.sender_key().to_owned(),
289            (*self.device_id).to_owned(),
290        )
291    }
292
293    /// This should be called if an the user wishes to rotate this session.
294    pub fn invalidate_session(&self) {
295        self.invalidated.store(true, Ordering::Relaxed)
296    }
297
298    /// Get the encryption settings of this outbound session.
299    pub fn settings(&self) -> &EncryptionSettings {
300        &self.settings
301    }
302
303    /// Mark the request with the given request id as sent.
304    ///
305    /// This removes the request from the queue and marks the set of
306    /// users/devices that received the session.
307    pub fn mark_request_as_sent(
308        &self,
309        request_id: &TransactionId,
310    ) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
311        let mut no_olm_devices = BTreeMap::new();
312
313        let removed = self.to_share_with_set.write().remove(request_id);
314        if let Some((to_device, request)) = removed {
315            let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
316                .iter()
317                .map(|(u, d)| (u.as_ref(), d.keys().map(|d| d.as_ref()).collect()))
318                .collect();
319
320            info!(
321                ?request_id,
322                ?recipients,
323                ?to_device.event_type,
324                "Marking to-device request carrying a room key or a withheld as sent"
325            );
326
327            for (user_id, info) in request {
328                let no_olms: BTreeSet<OwnedDeviceId> = info
329                    .iter()
330                    .filter(|(_, info)| matches!(info, ShareInfo::Withheld(WithheldCode::NoOlm)))
331                    .map(|(d, _)| d.to_owned())
332                    .collect();
333                no_olm_devices.insert(user_id.to_owned(), no_olms);
334
335                self.shared_with_set.write().entry(user_id).or_default().extend(info);
336            }
337
338            if self.to_share_with_set.read().is_empty() {
339                debug!(
340                    session_id = self.session_id(),
341                    room_id = ?self.room_id,
342                    "All m.room_key and withheld to-device requests were sent out, marking \
343                     session as shared.",
344                );
345
346                self.mark_as_shared();
347            }
348        } else {
349            let request_ids: Vec<String> =
350                self.to_share_with_set.read().keys().map(|k| k.to_string()).collect();
351
352            error!(
353                all_request_ids = ?request_ids,
354                request_id = ?request_id,
355                "Marking to-device request carrying a room key as sent but no \
356                 request found with the given id"
357            );
358        }
359
360        no_olm_devices
361    }
362
363    /// Encrypt the given plaintext using this session.
364    ///
365    /// Returns the encrypted ciphertext.
366    ///
367    /// # Arguments
368    ///
369    /// * `plaintext` - The plaintext that should be encrypted.
370    pub(crate) async fn encrypt_helper(&self, plaintext: String) -> MegolmMessage {
371        let mut session = self.inner.write().await;
372        self.message_count.fetch_add(1, Ordering::SeqCst);
373        session.encrypt(&plaintext)
374    }
375
376    /// Encrypt a room message for the given room.
377    ///
378    /// Beware that a room key needs to be shared before this method
379    /// can be called using the `share_room_key()` method.
380    ///
381    /// # Arguments
382    ///
383    /// * `event_type` - The plaintext type of the event, the outer type of the
384    ///   event will become `m.room.encrypted`.
385    ///
386    /// * `content` - The plaintext content of the message that should be
387    ///   encrypted in raw JSON form.
388    ///
389    /// # Panics
390    ///
391    /// Panics if the content can't be serialized.
392    pub async fn encrypt(
393        &self,
394        event_type: &str,
395        content: &Raw<AnyMessageLikeEventContent>,
396    ) -> Raw<RoomEncryptedEventContent> {
397        #[derive(Serialize)]
398        struct Payload<'a> {
399            #[serde(rename = "type")]
400            event_type: &'a str,
401            content: &'a Raw<AnyMessageLikeEventContent>,
402            room_id: &'a RoomId,
403        }
404
405        let payload = Payload { event_type, content, room_id: &self.room_id };
406        let payload_json =
407            serde_json::to_string(&payload).expect("payload serialization never fails");
408
409        let relates_to = content
410            .get_field::<serde_json::Value>("m.relates_to")
411            .expect("serde_json::Value deserialization with valid JSON input never fails");
412
413        let ciphertext = self.encrypt_helper(payload_json).await;
414        let scheme: RoomEventEncryptionScheme = match self.settings.algorithm {
415            EventEncryptionAlgorithm::MegolmV1AesSha2 => MegolmV1AesSha2Content {
416                ciphertext,
417                sender_key: self.account_identity_keys.curve25519,
418                session_id: self.session_id().to_owned(),
419                device_id: (*self.device_id).to_owned(),
420            }
421            .into(),
422            #[cfg(feature = "experimental-algorithms")]
423            EventEncryptionAlgorithm::MegolmV2AesSha2 => {
424                MegolmV2AesSha2Content { ciphertext, session_id: self.session_id().to_owned() }
425                    .into()
426            }
427            _ => unreachable!(
428                "An outbound group session is always using one of the supported algorithms"
429            ),
430        };
431
432        let content = RoomEncryptedEventContent { scheme, relates_to, other: Default::default() };
433
434        Raw::new(&content).expect("m.room.encrypted event content can always be serialized")
435    }
436
437    fn elapsed(&self) -> bool {
438        let creation_time = Duration::from_secs(self.creation_time.get().into());
439        let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
440        now.checked_sub(creation_time)
441            .map(|elapsed| elapsed >= self.safe_rotation_period())
442            .unwrap_or(true)
443    }
444
445    /// Returns the rotation_period_ms that was set for this session, clamped
446    /// to be no less than one hour.
447    ///
448    /// This is to prevent a malicious or careless user causing sessions to be
449    /// rotated very frequently.
450    ///
451    /// The feature flag `_disable-minimum-rotation-period-ms` can
452    /// be used to prevent this behaviour (which can be useful for tests).
453    fn safe_rotation_period(&self) -> Duration {
454        if cfg!(feature = "_disable-minimum-rotation-period-ms") {
455            self.settings.rotation_period
456        } else {
457            max(self.settings.rotation_period, ONE_HOUR)
458        }
459    }
460
461    /// Check if the session has expired and if it should be rotated.
462    ///
463    /// A session will expire after some time or if enough messages have been
464    /// encrypted using it.
465    pub fn expired(&self) -> bool {
466        let count = self.message_count.load(Ordering::SeqCst);
467        // We clamp the rotation period for message counts to be between 1 and
468        // 10000. The Megolm session should be usable for at least 1 message,
469        // and at most 10000 messages. Realistically Megolm uses u32 for it's
470        // internal counter and one could use the Megolm session for up to
471        // u32::MAX messages, but we're staying on the safe side of things.
472        let rotation_period_msgs = self.settings.rotation_period_msgs.clamp(1, 10_000);
473
474        count >= rotation_period_msgs || self.elapsed()
475    }
476
477    /// Has the session been invalidated.
478    pub fn invalidated(&self) -> bool {
479        self.invalidated.load(Ordering::Relaxed)
480    }
481
482    /// Mark the session as shared.
483    ///
484    /// Messages shouldn't be encrypted with the session before it has been
485    /// shared.
486    pub fn mark_as_shared(&self) {
487        self.shared.store(true, Ordering::Relaxed);
488    }
489
490    /// Check if the session has been marked as shared.
491    pub fn shared(&self) -> bool {
492        self.shared.load(Ordering::Relaxed)
493    }
494
495    /// Get the session key of this session.
496    ///
497    /// A session key can be used to to create an `InboundGroupSession`.
498    pub async fn session_key(&self) -> SessionKey {
499        let session = self.inner.read().await;
500        session.session_key()
501    }
502
503    /// Gets the Sender Key
504    pub fn sender_key(&self) -> Curve25519PublicKey {
505        self.account_identity_keys.as_ref().curve25519.to_owned()
506    }
507
508    /// Get the room id of the room this session belongs to.
509    pub fn room_id(&self) -> &RoomId {
510        &self.room_id
511    }
512
513    /// Returns the unique identifier for this session.
514    pub fn session_id(&self) -> &str {
515        &self.session_id
516    }
517
518    /// Get the current message index for this session.
519    ///
520    /// Each message is sent with an increasing index. This returns the
521    /// message index that will be used for the next encrypted message.
522    pub async fn message_index(&self) -> u32 {
523        let session = self.inner.read().await;
524        session.message_index()
525    }
526
527    pub(crate) async fn as_content(&self) -> RoomKeyContent {
528        let session_key = self.session_key().await;
529
530        RoomKeyContent::MegolmV1AesSha2(
531            MegolmV1AesSha2RoomKeyContent::new(
532                self.room_id().to_owned(),
533                self.session_id().to_owned(),
534                session_key,
535            )
536            .into(),
537        )
538    }
539
540    /// Has or will the session be shared with the given user/device pair.
541    pub(crate) fn is_shared_with(&self, device: &DeviceData) -> ShareState {
542        // Check if we shared the session.
543        let shared_state = self.shared_with_set.read().get(device.user_id()).and_then(|d| {
544            d.get(device.device_id()).map(|s| match s {
545                ShareInfo::Shared(s) => {
546                    if device.curve25519_key() == Some(s.sender_key) {
547                        ShareState::Shared {
548                            message_index: s.message_index,
549                            olm_wedging_index: s.olm_wedging_index,
550                        }
551                    } else {
552                        ShareState::SharedButChangedSenderKey
553                    }
554                }
555                ShareInfo::Withheld(_) => ShareState::NotShared,
556            })
557        });
558
559        if let Some(state) = shared_state {
560            state
561        } else {
562            // If we haven't shared the session, check if we're going to share
563            // the session.
564
565            // Find the first request that contains the given user id and
566            // device ID.
567            let shared = self.to_share_with_set.read().values().find_map(|(_, share_info)| {
568                let d = share_info.get(device.user_id())?;
569                let info = d.get(device.device_id())?;
570                Some(match info {
571                    ShareInfo::Shared(info) => {
572                        if device.curve25519_key() == Some(info.sender_key) {
573                            ShareState::Shared {
574                                message_index: info.message_index,
575                                olm_wedging_index: info.olm_wedging_index,
576                            }
577                        } else {
578                            ShareState::SharedButChangedSenderKey
579                        }
580                    }
581                    ShareInfo::Withheld(_) => ShareState::NotShared,
582                })
583            });
584
585            shared.unwrap_or(ShareState::NotShared)
586        }
587    }
588
589    pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
590        self.shared_with_set
591            .read()
592            .get(device.user_id())
593            .and_then(|d| {
594                let info = d.get(device.device_id())?;
595                Some(matches!(info, ShareInfo::Withheld(c) if c == code))
596            })
597            .unwrap_or_else(|| {
598                // If we haven't yet withheld, check if we're going to withheld
599                // the session.
600
601                // Find the first request that contains the given user id and
602                // device ID.
603                self.to_share_with_set.read().values().any(|(_, share_info)| {
604                    share_info
605                        .get(device.user_id())
606                        .and_then(|d| d.get(device.device_id()))
607                        .is_some_and(|info| matches!(info, ShareInfo::Withheld(c) if c == code))
608                })
609            })
610    }
611
612    /// Mark the session as shared with the given user/device pair, starting
613    /// from some message index.
614    #[cfg(test)]
615    pub fn mark_shared_with_from_index(
616        &self,
617        user_id: &UserId,
618        device_id: &DeviceId,
619        sender_key: Curve25519PublicKey,
620        index: u32,
621    ) {
622        self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert(
623            device_id.to_owned(),
624            ShareInfo::new_shared(sender_key, index, Default::default()),
625        );
626    }
627
628    /// Mark the session as shared with the given user/device pair, starting
629    /// from the current index.
630    #[cfg(test)]
631    pub async fn mark_shared_with(
632        &self,
633        user_id: &UserId,
634        device_id: &DeviceId,
635        sender_key: Curve25519PublicKey,
636    ) {
637        let share_info =
638            ShareInfo::new_shared(sender_key, self.message_index().await, Default::default());
639        self.shared_with_set
640            .write()
641            .entry(user_id.to_owned())
642            .or_default()
643            .insert(device_id.to_owned(), share_info);
644    }
645
646    /// Get the list of requests that need to be sent out for this session to be
647    /// marked as shared.
648    pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
649        self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect()
650    }
651
652    /// Get the list of request ids this session is waiting for to be sent out.
653    pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
654        self.to_share_with_set.read().keys().cloned().collect()
655    }
656
657    /// Restore a Session from a previously pickled string.
658    ///
659    /// Returns the restored group session or a `OlmGroupSessionError` if there
660    /// was an error.
661    ///
662    /// # Arguments
663    ///
664    /// * `device_id` - The device ID of the device that created this session.
665    ///   Put differently, our own device ID.
666    ///
667    /// * `identity_keys` - The identity keys of the device that created this
668    ///   session, our own identity keys.
669    ///
670    /// * `pickle` - The pickled version of the `OutboundGroupSession`.
671    ///
672    /// * `pickle_mode` - The mode that was used to pickle the session, either
673    ///   an unencrypted mode or an encrypted using passphrase.
674    pub fn from_pickle(
675        device_id: OwnedDeviceId,
676        identity_keys: Arc<IdentityKeys>,
677        pickle: PickledOutboundGroupSession,
678    ) -> Result<Self, PickleError> {
679        let inner: GroupSession = pickle.pickle.into();
680        let session_id = inner.session_id();
681
682        Ok(Self {
683            inner: Arc::new(RwLock::new(inner)),
684            device_id,
685            account_identity_keys: identity_keys,
686            session_id: session_id.into(),
687            room_id: pickle.room_id,
688            creation_time: pickle.creation_time,
689            message_count: AtomicU64::from(pickle.message_count).into(),
690            shared: AtomicBool::from(pickle.shared).into(),
691            invalidated: AtomicBool::from(pickle.invalidated).into(),
692            settings: pickle.settings,
693            shared_with_set: Arc::new(StdRwLock::new(pickle.shared_with_set)),
694            to_share_with_set: Arc::new(StdRwLock::new(pickle.requests)),
695        })
696    }
697
698    /// Store the group session as a base64 encoded string and associated data
699    /// belonging to the session.
700    ///
701    /// # Arguments
702    ///
703    /// * `pickle_mode` - The mode that should be used to pickle the group
704    ///   session, either an unencrypted mode or an encrypted using passphrase.
705    pub async fn pickle(&self) -> PickledOutboundGroupSession {
706        let pickle = self.inner.read().await.pickle();
707
708        PickledOutboundGroupSession {
709            pickle,
710            room_id: self.room_id.clone(),
711            settings: self.settings.clone(),
712            creation_time: self.creation_time,
713            message_count: self.message_count.load(Ordering::SeqCst),
714            shared: self.shared(),
715            invalidated: self.invalidated(),
716            shared_with_set: self.shared_with_set.read().clone(),
717            requests: self.to_share_with_set.read().clone(),
718        }
719    }
720}
721
722#[derive(Clone, Debug, Serialize, Deserialize)]
723pub struct OutboundGroupSessionPickle(String);
724
725impl From<String> for OutboundGroupSessionPickle {
726    fn from(p: String) -> Self {
727        Self(p)
728    }
729}
730
731#[cfg(not(tarpaulin_include))]
732impl fmt::Debug for OutboundGroupSession {
733    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
734        f.debug_struct("OutboundGroupSession")
735            .field("session_id", &self.session_id)
736            .field("room_id", &self.room_id)
737            .field("creation_time", &self.creation_time)
738            .field("message_count", &self.message_count)
739            .finish()
740    }
741}
742
743/// A pickled version of an `InboundGroupSession`.
744///
745/// Holds all the information that needs to be stored in a database to restore
746/// an InboundGroupSession.
747#[derive(Deserialize, Serialize)]
748#[allow(missing_debug_implementations)]
749pub struct PickledOutboundGroupSession {
750    /// The pickle string holding the OutboundGroupSession.
751    pub pickle: GroupSessionPickle,
752    /// The settings this session adheres to.
753    pub settings: Arc<EncryptionSettings>,
754    /// The room id this session is used for.
755    pub room_id: OwnedRoomId,
756    /// The timestamp when this session was created.
757    pub creation_time: SecondsSinceUnixEpoch,
758    /// The number of messages this session has already encrypted.
759    pub message_count: u64,
760    /// Is the session shared.
761    pub shared: bool,
762    /// Has the session been invalidated.
763    pub invalidated: bool,
764    /// The set of users the session has been already shared with.
765    pub shared_with_set: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
766    /// Requests that need to be sent out to share the session.
767    pub requests: BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>,
768}
769
770#[cfg(test)]
771mod tests {
772    use std::time::Duration;
773
774    use ruma::{
775        events::room::{
776            encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility,
777        },
778        uint, EventEncryptionAlgorithm,
779    };
780
781    use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD};
782    use crate::CollectStrategy;
783
784    #[test]
785    fn test_encryption_settings_conversion() {
786        let mut content =
787            RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);
788        let settings = EncryptionSettings::new(
789            content.clone(),
790            HistoryVisibility::Joined,
791            CollectStrategy::AllDevices,
792        );
793
794        assert_eq!(settings.rotation_period, ROTATION_PERIOD);
795        assert_eq!(settings.rotation_period_msgs, ROTATION_MESSAGES);
796
797        content.rotation_period_ms = Some(uint!(3600));
798        content.rotation_period_msgs = Some(uint!(500));
799
800        let settings = EncryptionSettings::new(
801            content,
802            HistoryVisibility::Shared,
803            CollectStrategy::AllDevices,
804        );
805
806        assert_eq!(settings.rotation_period, Duration::from_millis(3600));
807        assert_eq!(settings.rotation_period_msgs, 500);
808    }
809
810    #[cfg(any(target_os = "linux", target_os = "macos", target_arch = "wasm32"))]
811    mod expiration {
812        use std::{sync::atomic::Ordering, time::Duration};
813
814        use matrix_sdk_test::async_test;
815        use ruma::{
816            device_id, events::room::message::RoomMessageEventContent, room_id, serde::Raw, uint,
817            user_id, SecondsSinceUnixEpoch,
818        };
819
820        use crate::{
821            olm::{OutboundGroupSession, SenderData},
822            Account, EncryptionSettings, MegolmError,
823        };
824
825        const TWO_HOURS: Duration = Duration::from_secs(60 * 60 * 2);
826
827        #[async_test]
828        async fn test_session_is_not_expired_if_no_messages_sent_and_no_time_passed() {
829            // Given a session that expires after one message
830            let session = create_session(EncryptionSettings {
831                rotation_period_msgs: 1,
832                ..Default::default()
833            })
834            .await;
835
836            // When we send no messages at all
837
838            // Then it is not expired
839            assert!(!session.expired());
840        }
841
842        #[async_test]
843        async fn test_session_is_expired_if_we_rotate_every_message_and_one_was_sent(
844        ) -> Result<(), MegolmError> {
845            // Given a session that expires after one message
846            let session = create_session(EncryptionSettings {
847                rotation_period_msgs: 1,
848                ..Default::default()
849            })
850            .await;
851
852            // When we send a message
853            let _ = session
854                .encrypt(
855                    "m.room.message",
856                    &Raw::new(&RoomMessageEventContent::text_plain("Test message"))?.cast(),
857                )
858                .await;
859
860            // Then the session is expired
861            assert!(session.expired());
862
863            Ok(())
864        }
865
866        #[async_test]
867        async fn test_session_with_rotation_period_is_not_expired_after_no_time() {
868            // Given a session with a 2h expiration
869            let session = create_session(EncryptionSettings {
870                rotation_period: TWO_HOURS,
871                ..Default::default()
872            })
873            .await;
874
875            // When we don't allow any time to pass
876
877            // Then it is not expired
878            assert!(!session.expired());
879        }
880
881        #[async_test]
882        async fn test_session_is_expired_after_rotation_period() {
883            // Given a session with a 2h expiration
884            let mut session = create_session(EncryptionSettings {
885                rotation_period: TWO_HOURS,
886                ..Default::default()
887            })
888            .await;
889
890            // When 3 hours have passed
891            let now = SecondsSinceUnixEpoch::now();
892            session.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(10800));
893
894            // Then the session is expired
895            assert!(session.expired());
896        }
897
898        #[async_test]
899        #[cfg(not(feature = "_disable-minimum-rotation-period-ms"))]
900        async fn test_session_does_not_expire_under_one_hour_even_if_we_ask_for_shorter() {
901            // Given a session with a 100ms expiration
902            let mut session = create_session(EncryptionSettings {
903                rotation_period: Duration::from_millis(100),
904                ..Default::default()
905            })
906            .await;
907
908            // When less than an hour has passed
909            let now = SecondsSinceUnixEpoch::now();
910            session.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(1800));
911
912            // Then the session is not expired: we enforce a minimum of 1 hour
913            assert!(!session.expired());
914
915            // But when more than an hour has passed
916            session.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(3601));
917
918            // Then the session is expired
919            assert!(session.expired());
920        }
921
922        #[async_test]
923        #[cfg(feature = "_disable-minimum-rotation-period-ms")]
924        async fn test_with_disable_minrotperiod_feature_sessions_can_expire_quickly() {
925            // Given a session with a 100ms expiration
926            let mut session = create_session(EncryptionSettings {
927                rotation_period: Duration::from_millis(100),
928                ..Default::default()
929            })
930            .await;
931
932            // When less than an hour has passed
933            let now = SecondsSinceUnixEpoch::now();
934            session.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(1800));
935
936            // Then the session is expired: the feature flag has prevented us enforcing a
937            // minimum
938            assert!(session.expired());
939        }
940
941        #[async_test]
942        async fn test_session_with_zero_msgs_rotation_is_not_expired_initially() {
943            // Given a session that is supposed to expire after zero messages
944            let session = create_session(EncryptionSettings {
945                rotation_period_msgs: 0,
946                ..Default::default()
947            })
948            .await;
949
950            // When we send no messages
951
952            // Then the session is not expired: we are protected against this nonsensical
953            // setup
954            assert!(!session.expired());
955        }
956
957        #[async_test]
958        async fn test_session_with_zero_msgs_rotation_expires_after_one_message(
959        ) -> Result<(), MegolmError> {
960            // Given a session that is supposed to expire after zero messages
961            let session = create_session(EncryptionSettings {
962                rotation_period_msgs: 0,
963                ..Default::default()
964            })
965            .await;
966
967            // When we send a message
968            let _ = session
969                .encrypt(
970                    "m.room.message",
971                    &Raw::new(&RoomMessageEventContent::text_plain("Test message"))?.cast(),
972                )
973                .await;
974
975            // Then the session is expired: we treated rotation_period_msgs=0 as if it were
976            // =1
977            assert!(session.expired());
978
979            Ok(())
980        }
981
982        #[async_test]
983        async fn test_session_expires_after_10k_messages_even_if_we_ask_for_more() {
984            // Given we asked to expire after 100K messages
985            let session = create_session(EncryptionSettings {
986                rotation_period_msgs: 100_000,
987                ..Default::default()
988            })
989            .await;
990
991            // Sanity: it does not expire after <10K messages
992            assert!(!session.expired());
993            session.message_count.store(1000, Ordering::SeqCst);
994            assert!(!session.expired());
995            session.message_count.store(9999, Ordering::SeqCst);
996            assert!(!session.expired());
997
998            // When we have sent >= 10K messages
999            session.message_count.store(10_000, Ordering::SeqCst);
1000
1001            // Then it is considered expired: we enforce a maximum of 10K messages before
1002            // rotation.
1003            assert!(session.expired());
1004        }
1005
1006        async fn create_session(settings: EncryptionSettings) -> OutboundGroupSession {
1007            let account =
1008                Account::with_device_id(user_id!("@alice:example.org"), device_id!("DEVICEID"))
1009                    .static_data;
1010            let (session, _) = account
1011                .create_group_session_pair(
1012                    room_id!("!test_room:example.org"),
1013                    settings,
1014                    SenderData::unknown(),
1015                )
1016                .await
1017                .unwrap();
1018            session
1019        }
1020    }
1021}