1use std::{
16 cmp::max,
17 collections::{BTreeMap, BTreeSet},
18 fmt,
19 ops::Bound,
20 sync::{
21 Arc, RwLockReadGuard,
22 atomic::{AtomicBool, AtomicU64, Ordering},
23 },
24 time::Duration,
25};
26
27use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock};
28#[cfg(feature = "experimental-encrypted-state-events")]
29use ruma::events::AnyStateEventContent;
30use ruma::{
31 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId,
32 SecondsSinceUnixEpoch, TransactionId, UserId,
33 events::{
34 AnyMessageLikeEventContent,
35 room::{
36 encryption::{PossiblyRedactedRoomEncryptionEventContent, RoomEncryptionEventContent},
37 history_visibility::HistoryVisibility,
38 },
39 },
40 serde::Raw,
41};
42use serde::{Deserialize, Serialize};
43use tokio::sync::RwLock;
44use tracing::{debug, error, info};
45use vodozemac::{Curve25519PublicKey, megolm::SessionConfig};
46pub use vodozemac::{
47 PickleError,
48 megolm::{GroupSession, GroupSessionPickle, MegolmMessage, SessionKey},
49 olm::IdentityKeys,
50};
51
52use super::SessionCreationError;
53#[cfg(feature = "experimental-algorithms")]
54use crate::types::events::room::encrypted::MegolmV2AesSha2Content;
55use crate::{
56 DeviceData,
57 olm::account::shared_history_from_history_visibility,
58 session_manager::CollectStrategy,
59 store::caches::SequenceNumber,
60 types::{
61 EventEncryptionAlgorithm,
62 events::{
63 room::encrypted::{
64 MegolmV1AesSha2Content, RoomEncryptedEventContent, RoomEventEncryptionScheme,
65 },
66 room_key::{MegolmV1AesSha2Content as MegolmV1AesSha2RoomKeyContent, RoomKeyContent},
67 room_key_withheld::RoomKeyWithheldContent,
68 },
69 requests::ToDeviceRequest,
70 },
71};
72
/// Sixty minutes. `OutboundGroupSession::safe_rotation_period` never returns a
/// shorter period unless the `_disable-minimum-rotation-period-ms` feature is on.
const ONE_HOUR: Duration = Duration::from_secs(3_600);
/// Seven days.
const ONE_WEEK: Duration = Duration::from_secs(7 * 24 * 3_600);

/// Rotation period used when the room's encryption event does not specify one.
const ROTATION_PERIOD: Duration = ONE_WEEK;
/// Rotation message count used when the room's encryption event does not
/// specify one.
const ROTATION_MESSAGES: u64 = 100;
78
/// How far the room key of an [`OutboundGroupSession`] has been shared with a
/// particular device.
///
/// The derived ordering follows the variant declaration order
/// (`NotShared` < `SharedButChangedSenderKey` < `Shared { .. }`).
/// `SharingView::get_share_state` relies on this to pick the most advanced
/// state among several entries via `max()`, so do not reorder the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum ShareState {
    /// There is no share entry for the device, or the key was withheld from it.
    NotShared,
    /// The key was shared with the device, but the Curve25519 key recorded at
    /// share time no longer matches the device's current Curve25519 key.
    SharedButChangedSenderKey,
    /// The key was shared with the device, and the recorded Curve25519 key
    /// still matches the device's current one.
    Shared {
        /// The Megolm message index the session was shared at.
        message_index: u32,
        /// NOTE(review): a store-cache sequence number recorded at share time,
        /// presumably used to detect wedged Olm sessions — confirm with callers.
        olm_wedging_index: SequenceNumber,
    },
}
94
/// Settings for an encrypted room.
///
/// These select the algorithm and the rotation limits used for
/// [`OutboundGroupSession`]s created for the room, and how room keys are
/// distributed to participants.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EncryptionSettings {
    /// The encryption algorithm that should be used in the room.
    pub algorithm: EventEncryptionAlgorithm,
    /// Whether state events should be encrypted as well.
    ///
    /// Defaults to `false` when missing from serialized data.
    #[cfg(feature = "experimental-encrypted-state-events")]
    #[serde(default)]
    pub encrypt_state_events: bool,
    /// How long a session may be used before it is considered expired.
    ///
    /// [`OutboundGroupSession::expired`] never applies a period shorter than
    /// one hour unless the `_disable-minimum-rotation-period-ms` feature is
    /// enabled.
    pub rotation_period: Duration,
    /// How many messages may be encrypted with a session before it is
    /// considered expired.
    ///
    /// [`OutboundGroupSession::expired`] clamps this value to `1..=10_000`.
    pub rotation_period_msgs: u64,
    /// The history visibility of the room; the outbound session uses it to
    /// derive the `shared_history` flag of the room key it shares.
    pub history_visibility: HistoryVisibility,
    /// The strategy used to collect the devices that should receive the room
    /// key.
    ///
    /// Falls back to `CollectStrategy::default()` when missing from serialized
    /// data.
    #[serde(default)]
    pub sharing_strategy: CollectStrategy,
}
117
118impl Default for EncryptionSettings {
119 fn default() -> Self {
120 Self {
121 algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
122 #[cfg(feature = "experimental-encrypted-state-events")]
123 encrypt_state_events: false,
124 rotation_period: ROTATION_PERIOD,
125 rotation_period_msgs: ROTATION_MESSAGES,
126 history_visibility: HistoryVisibility::Shared,
127 sharing_strategy: CollectStrategy::default(),
128 }
129 }
130}
131
132impl EncryptionSettings {
133 pub fn new(
136 content: RoomEncryptionEventContent,
137 history_visibility: HistoryVisibility,
138 sharing_strategy: CollectStrategy,
139 ) -> Self {
140 let rotation_period: Duration =
141 content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
142 let rotation_period_msgs: u64 =
143 content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
144
145 Self {
146 algorithm: EventEncryptionAlgorithm::from(content.algorithm.as_str()),
147 #[cfg(feature = "experimental-encrypted-state-events")]
148 encrypt_state_events: false,
149 rotation_period,
150 rotation_period_msgs,
151 history_visibility,
152 sharing_strategy,
153 }
154 }
155
156 pub fn from_possibly_redacted(
162 content: PossiblyRedactedRoomEncryptionEventContent,
163 history_visibility: HistoryVisibility,
164 sharing_strategy: CollectStrategy,
165 ) -> Option<Self> {
166 let rotation_period: Duration =
167 content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
168 let rotation_period_msgs: u64 =
169 content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
170
171 Some(Self {
172 algorithm: EventEncryptionAlgorithm::from(content.algorithm?.as_str()),
173 #[cfg(feature = "experimental-encrypted-state-events")]
174 encrypt_state_events: false,
175 rotation_period,
176 rotation_period_msgs,
177 history_visibility,
178 sharing_strategy,
179 })
180 }
181}
182
/// The result of encrypting an event with an [`OutboundGroupSession`].
///
/// Besides the encrypted content, it carries the algorithm and session ID that
/// were used, so callers don't have to query the session again.
#[derive(Debug)]
pub struct OutboundGroupSessionEncryptionResult {
    /// The `m.room.encrypted` event content.
    pub content: Raw<RoomEncryptedEventContent>,
    /// The algorithm used to encrypt the event.
    pub algorithm: EventEncryptionAlgorithm,
    /// The ID of the session used to encrypt the event.
    pub session_id: Arc<str>,
}
195
/// Outbound group session.
///
/// An outbound group session is used to encrypt room messages. Cloning is
/// cheap: every field except `device_id`, `room_id` and `creation_time` is
/// behind an `Arc`. Clones therefore share the Megolm session, the counters,
/// the flags and the sharing state.
#[derive(Clone)]
pub struct OutboundGroupSession {
    /// The underlying Megolm session; locked asynchronously for encryption.
    inner: Arc<RwLock<GroupSession>>,
    /// Our own device ID, embedded in encrypted content and withheld notices.
    device_id: OwnedDeviceId,
    /// Our account's identity keys; the Curve25519 key is the sender key.
    account_identity_keys: Arc<IdentityKeys>,
    /// Cached ID of `inner`.
    session_id: Arc<str>,
    /// The room this session encrypts messages for.
    room_id: OwnedRoomId,
    /// When the session was created; used for time-based rotation.
    pub(crate) creation_time: SecondsSinceUnixEpoch,
    /// Number of messages encrypted so far; used for count-based rotation.
    message_count: Arc<AtomicU64>,
    /// Set once all pending room-key requests have been marked as sent.
    shared: Arc<AtomicBool>,
    /// Set by `invalidate_session`.
    invalidated: Arc<AtomicBool>,
    /// The settings this session was created with.
    settings: Arc<EncryptionSettings>,
    /// Share infos for devices whose to-device requests were sent.
    /// Lock order: always acquire this before `to_share_with_set`.
    shared_with_set: Arc<StdRwLock<ShareInfoSet>>,
    /// Pending to-device requests, keyed by transaction ID, with the share
    /// infos they will produce once sent.
    to_share_with_set: Arc<StdRwLock<ToShareMap>>,
}
216
/// A set of share infos, keyed first by user ID and then by device ID.
///
/// Used both for devices a session was already shared with and for the share
/// infos attached to a pending to-device request.
pub type ShareInfoSet = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>;
222
/// Pending to-device requests, keyed by their transaction ID, together with
/// the share infos that become effective once the request is sent.
type ToShareMap = BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>;
224
/// What happened to a room key for a particular device.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ShareInfo {
    /// The room key was shared with the device.
    Shared(SharedWith),
    /// The room key was withheld from the device with the given code.
    Withheld(WithheldCode),
}
233
234impl ShareInfo {
235 pub fn new_shared(
237 sender_key: Curve25519PublicKey,
238 message_index: u32,
239 olm_wedging_index: SequenceNumber,
240 ) -> Self {
241 ShareInfo::Shared(SharedWith { sender_key, message_index, olm_wedging_index })
242 }
243
244 pub fn new_withheld(code: WithheldCode) -> Self {
246 ShareInfo::Withheld(code)
247 }
248}
249
/// Details about a room key that was shared with a device.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SharedWith {
    /// The Curve25519 key recorded for the receiving device at share time.
    /// `SharingView::get_share_state` compares it with the device's current
    /// Curve25519 key to detect key changes.
    pub sender_key: Curve25519PublicKey,
    /// The Megolm message index the key was shared at.
    pub message_index: u32,
    /// NOTE(review): store-cache sequence number recorded at share time,
    /// presumably for Olm wedging detection — confirm. `#[serde(default)]`
    /// lets data serialized without this field still deserialize.
    #[serde(default)]
    pub olm_wedging_index: SequenceNumber,
}
260
/// A read-only view into the sharing state of an [`OutboundGroupSession`].
///
/// The view holds read guards on both the already-shared set and the pending
/// request map for its whole lifetime. It therefore sees a single consistent
/// state, and writers to either map are blocked until it is dropped.
pub(crate) struct SharingView<'a> {
    /// Share infos of requests that were already sent.
    shared_with_set: RwLockReadGuard<'a, ShareInfoSet>,
    /// Pending requests together with their share infos.
    to_share_with_set: RwLockReadGuard<'a, ToShareMap>,
}
267
268impl SharingView<'_> {
269 pub(crate) fn get_share_state(&self, device: &DeviceData) -> ShareState {
272 self.iter_shares(Some(device.user_id()), Some(device.device_id()))
273 .map(|(_, _, info)| match info {
274 ShareInfo::Shared(info) => {
275 if device.curve25519_key() == Some(info.sender_key) {
276 ShareState::Shared {
277 message_index: info.message_index,
278 olm_wedging_index: info.olm_wedging_index,
279 }
280 } else {
281 ShareState::SharedButChangedSenderKey
282 }
283 }
284 ShareInfo::Withheld(_) => ShareState::NotShared,
285 })
286 .max()
289 .unwrap_or(ShareState::NotShared)
290 }
291
292 pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
295 self.iter_shares(Some(device.user_id()), Some(device.device_id()))
296 .any(|(_, _, info)| matches!(info, ShareInfo::Withheld(c) if c == code))
297 }
298
299 pub(crate) fn iter_shares<'b, 'c>(
303 &self,
304 user_id: Option<&'b UserId>,
305 device_id: Option<&'c DeviceId>,
306 ) -> impl Iterator<Item = (&UserId, &DeviceId, &ShareInfo)> + use<'_, 'b, 'c> {
307 fn iter_share_info_set<'a, 'b, 'c>(
308 set: &'a ShareInfoSet,
309 user_ids: (Bound<&'b UserId>, Bound<&'b UserId>),
310 device_ids: (Bound<&'c DeviceId>, Bound<&'c DeviceId>),
311 ) -> impl Iterator<Item = (&'a UserId, &'a DeviceId, &'a ShareInfo)> + use<'a, 'b, 'c>
312 {
313 set.range::<UserId, _>(user_ids).flat_map(move |(uid, d)| {
314 d.range::<DeviceId, _>(device_ids)
315 .map(|(id, info)| (uid.as_ref(), id.as_ref(), info))
316 })
317 }
318
319 let user_ids = user_id
320 .map(|u| (Bound::Included(u), Bound::Included(u)))
321 .unwrap_or((Bound::Unbounded, Bound::Unbounded));
322 let device_ids = device_id
323 .map(|d| (Bound::Included(d), Bound::Included(d)))
324 .unwrap_or((Bound::Unbounded, Bound::Unbounded));
325
326 let already_shared = iter_share_info_set(&self.shared_with_set, user_ids, device_ids);
327 let pending = self
328 .to_share_with_set
329 .values()
330 .flat_map(move |(_, set)| iter_share_info_set(set, user_ids, device_ids));
331 already_shared.chain(pending)
332 }
333
334 pub(crate) fn shared_with_users(&self) -> impl Iterator<Item = &UserId> {
338 self.iter_shares(None, None).filter_map(|(u, _, info)| match info {
339 ShareInfo::Shared(_) => Some(u),
340 ShareInfo::Withheld(_) => None,
341 })
342 }
343}
344
345impl OutboundGroupSession {
346 pub(super) fn session_config(
347 algorithm: &EventEncryptionAlgorithm,
348 ) -> Result<SessionConfig, SessionCreationError> {
349 match algorithm {
350 EventEncryptionAlgorithm::MegolmV1AesSha2 => Ok(SessionConfig::version_1()),
351 #[cfg(feature = "experimental-algorithms")]
352 EventEncryptionAlgorithm::MegolmV2AesSha2 => Ok(SessionConfig::version_2()),
353 _ => Err(SessionCreationError::Algorithm(algorithm.to_owned())),
354 }
355 }
356
357 pub fn new(
373 device_id: OwnedDeviceId,
374 identity_keys: Arc<IdentityKeys>,
375 room_id: &RoomId,
376 settings: EncryptionSettings,
377 ) -> Result<Self, SessionCreationError> {
378 let config = Self::session_config(&settings.algorithm)?;
379
380 let session = GroupSession::new(config);
381 let session_id = session.session_id();
382
383 Ok(OutboundGroupSession {
384 inner: RwLock::new(session).into(),
385 room_id: room_id.into(),
386 device_id,
387 account_identity_keys: identity_keys,
388 session_id: session_id.into(),
389 creation_time: SecondsSinceUnixEpoch::now(),
390 message_count: Arc::new(AtomicU64::new(0)),
391 shared: Arc::new(AtomicBool::new(false)),
392 invalidated: Arc::new(AtomicBool::new(false)),
393 settings: Arc::new(settings),
394 shared_with_set: Default::default(),
395 to_share_with_set: Default::default(),
396 })
397 }
398
399 pub fn add_request(
409 &self,
410 request_id: OwnedTransactionId,
411 request: Arc<ToDeviceRequest>,
412 share_infos: ShareInfoSet,
413 ) {
414 self.to_share_with_set.write().insert(request_id, (request, share_infos));
415 }
416
417 pub fn withheld_code(&self, code: WithheldCode) -> RoomKeyWithheldContent {
420 RoomKeyWithheldContent::new(
421 self.settings().algorithm.to_owned(),
422 code,
423 self.room_id().to_owned(),
424 self.session_id().to_owned(),
425 self.sender_key().to_owned(),
426 self.device_id.clone(),
427 )
428 }
429
430 pub fn invalidate_session(&self) {
432 self.invalidated.store(true, Ordering::Relaxed)
433 }
434
435 pub fn settings(&self) -> &EncryptionSettings {
437 &self.settings
438 }
439
440 pub fn mark_request_as_sent(
445 &self,
446 request_id: &TransactionId,
447 ) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
448 let mut no_olm_devices = BTreeMap::new();
449
450 let removed = self.to_share_with_set.write().remove(request_id);
451 if let Some((to_device, request)) = removed {
452 let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
453 .iter()
454 .map(|(u, d)| (u.as_ref(), d.keys().map(|d| d.as_ref()).collect()))
455 .collect();
456
457 info!(
458 ?request_id,
459 ?recipients,
460 ?to_device.event_type,
461 "Marking to-device request carrying a room key or a withheld as sent"
462 );
463
464 for (user_id, info) in request {
465 let no_olms: BTreeSet<OwnedDeviceId> = info
466 .iter()
467 .filter(|(_, info)| matches!(info, ShareInfo::Withheld(WithheldCode::NoOlm)))
468 .map(|(d, _)| d.to_owned())
469 .collect();
470 no_olm_devices.insert(user_id.to_owned(), no_olms);
471
472 self.shared_with_set.write().entry(user_id).or_default().extend(info);
473 }
474
475 if self.to_share_with_set.read().is_empty() {
476 debug!(
477 session_id = self.session_id(),
478 room_id = ?self.room_id,
479 "All m.room_key and withheld to-device requests were sent out, marking \
480 session as shared.",
481 );
482
483 self.mark_as_shared();
484 }
485 } else {
486 let request_ids: Vec<String> =
487 self.to_share_with_set.read().keys().map(|k| k.to_string()).collect();
488
489 error!(
490 all_request_ids = ?request_ids,
491 ?request_id,
492 "Marking to-device request carrying a room key as sent but no \
493 request found with the given id"
494 );
495 }
496
497 no_olm_devices
498 }
499
500 pub(crate) async fn encrypt_helper(&self, plaintext: String) -> MegolmMessage {
508 let mut session = self.inner.write().await;
509 self.message_count.fetch_add(1, Ordering::SeqCst);
510 session.encrypt(&plaintext)
511 }
512
513 async fn encrypt_inner<T: Serialize>(
527 &self,
528 payload: &T,
529 relates_to: Option<serde_json::Value>,
530 ) -> OutboundGroupSessionEncryptionResult {
531 let ciphertext = self
532 .encrypt_helper(
533 serde_json::to_string(payload).expect("payload serialization never fails"),
534 )
535 .await;
536 let scheme: RoomEventEncryptionScheme = match self.settings.algorithm {
537 EventEncryptionAlgorithm::MegolmV1AesSha2 => MegolmV1AesSha2Content {
538 ciphertext,
539 sender_key: Some(self.account_identity_keys.curve25519),
540 session_id: self.session_id().to_owned(),
541 device_id: Some(self.device_id.clone()),
542 }
543 .into(),
544 #[cfg(feature = "experimental-algorithms")]
545 EventEncryptionAlgorithm::MegolmV2AesSha2 => {
546 MegolmV2AesSha2Content { ciphertext, session_id: self.session_id().to_owned() }
547 .into()
548 }
549 _ => unreachable!(
550 "An outbound group session is always using one of the supported algorithms"
551 ),
552 };
553 let content = RoomEncryptedEventContent { scheme, relates_to, other: Default::default() };
554
555 OutboundGroupSessionEncryptionResult {
556 content: Raw::new(&content)
557 .expect("m.room.encrypted event content can always be serialized"),
558 algorithm: self.settings.algorithm.to_owned(),
559 session_id: self.session_id.clone(),
560 }
561 }
562
563 pub async fn encrypt(
580 &self,
581 event_type: &str,
582 content: &Raw<AnyMessageLikeEventContent>,
583 ) -> OutboundGroupSessionEncryptionResult {
584 #[derive(Serialize)]
585 struct Payload<'a> {
586 #[serde(rename = "type")]
587 event_type: &'a str,
588 content: &'a Raw<AnyMessageLikeEventContent>,
589 room_id: &'a RoomId,
590 }
591
592 let payload = Payload { event_type, content, room_id: &self.room_id };
593
594 let relates_to = content
595 .get_field::<serde_json::Value>("m.relates_to")
596 .expect("serde_json::Value deserialization with valid JSON input never fails");
597
598 self.encrypt_inner(&payload, relates_to).await
599 }
600
601 #[cfg(feature = "experimental-encrypted-state-events")]
621 pub async fn encrypt_state(
622 &self,
623 event_type: &str,
624 state_key: &str,
625 content: &Raw<AnyStateEventContent>,
626 ) -> Raw<RoomEncryptedEventContent> {
627 #[derive(Serialize)]
628 struct Payload<'a> {
629 #[serde(rename = "type")]
630 event_type: &'a str,
631 state_key: &'a str,
632 content: &'a Raw<AnyStateEventContent>,
633 room_id: &'a RoomId,
634 }
635
636 let payload = Payload { event_type, state_key, content, room_id: &self.room_id };
637 self.encrypt_inner(&payload, None).await.content
638 }
639
640 fn elapsed(&self) -> bool {
641 let creation_time = Duration::from_secs(self.creation_time.get().into());
642 let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
643 now.checked_sub(creation_time)
644 .map(|elapsed| elapsed >= self.safe_rotation_period())
645 .unwrap_or(true)
646 }
647
648 fn safe_rotation_period(&self) -> Duration {
657 if cfg!(feature = "_disable-minimum-rotation-period-ms") {
658 self.settings.rotation_period
659 } else {
660 max(self.settings.rotation_period, ONE_HOUR)
661 }
662 }
663
664 pub fn expired(&self) -> bool {
669 let count = self.message_count.load(Ordering::SeqCst);
670 let rotation_period_msgs = self.settings.rotation_period_msgs.clamp(1, 10_000);
676
677 count >= rotation_period_msgs || self.elapsed()
678 }
679
680 pub fn invalidated(&self) -> bool {
682 self.invalidated.load(Ordering::Relaxed)
683 }
684
685 pub fn mark_as_shared(&self) {
690 self.shared.store(true, Ordering::Relaxed);
691 }
692
693 pub fn shared(&self) -> bool {
695 self.shared.load(Ordering::Relaxed)
696 }
697
698 pub async fn session_key(&self) -> SessionKey {
702 let session = self.inner.read().await;
703 session.session_key()
704 }
705
706 pub fn sender_key(&self) -> Curve25519PublicKey {
708 self.account_identity_keys.as_ref().curve25519.to_owned()
709 }
710
711 pub fn room_id(&self) -> &RoomId {
713 &self.room_id
714 }
715
716 pub fn session_id(&self) -> &str {
718 &self.session_id
719 }
720
721 pub async fn message_index(&self) -> u32 {
726 let session = self.inner.read().await;
727 session.message_index()
728 }
729
730 pub(crate) async fn as_content(&self) -> RoomKeyContent {
731 let session_key = self.session_key().await;
732 let shared_history =
733 shared_history_from_history_visibility(&self.settings.history_visibility);
734
735 RoomKeyContent::MegolmV1AesSha2(
736 MegolmV1AesSha2RoomKeyContent::new(
737 self.room_id().to_owned(),
738 self.session_id().to_owned(),
739 session_key,
740 shared_history,
741 )
742 .into(),
743 )
744 }
745
746 pub(crate) fn sharing_view(&self) -> SharingView<'_> {
750 SharingView {
751 shared_with_set: self.shared_with_set.read(),
752 to_share_with_set: self.to_share_with_set.read(),
753 }
754 }
755
756 #[cfg(test)]
759 pub fn mark_shared_with_from_index(
760 &self,
761 user_id: &UserId,
762 device_id: &DeviceId,
763 sender_key: Curve25519PublicKey,
764 index: u32,
765 ) {
766 self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert(
767 device_id.to_owned(),
768 ShareInfo::new_shared(sender_key, index, Default::default()),
769 );
770 }
771
772 #[cfg(test)]
775 pub async fn mark_shared_with(
776 &self,
777 user_id: &UserId,
778 device_id: &DeviceId,
779 sender_key: Curve25519PublicKey,
780 ) {
781 let share_info =
782 ShareInfo::new_shared(sender_key, self.message_index().await, Default::default());
783 self.shared_with_set
784 .write()
785 .entry(user_id.to_owned())
786 .or_default()
787 .insert(device_id.to_owned(), share_info);
788 }
789
790 pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
793 self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect()
794 }
795
796 pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
798 self.to_share_with_set.read().keys().cloned().collect()
799 }
800
801 pub fn from_pickle(
819 device_id: OwnedDeviceId,
820 identity_keys: Arc<IdentityKeys>,
821 pickle: PickledOutboundGroupSession,
822 ) -> Result<Self, PickleError> {
823 let inner: GroupSession = pickle.pickle.into();
824 let session_id = inner.session_id();
825
826 Ok(Self {
827 inner: Arc::new(RwLock::new(inner)),
828 device_id,
829 account_identity_keys: identity_keys,
830 session_id: session_id.into(),
831 room_id: pickle.room_id,
832 creation_time: pickle.creation_time,
833 message_count: AtomicU64::from(pickle.message_count).into(),
834 shared: AtomicBool::from(pickle.shared).into(),
835 invalidated: AtomicBool::from(pickle.invalidated).into(),
836 settings: pickle.settings,
837 shared_with_set: Arc::new(StdRwLock::new(pickle.shared_with_set)),
838 to_share_with_set: Arc::new(StdRwLock::new(pickle.requests)),
839 })
840 }
841
842 pub async fn pickle(&self) -> PickledOutboundGroupSession {
850 let pickle = self.inner.read().await.pickle();
851
852 PickledOutboundGroupSession {
853 pickle,
854 room_id: self.room_id.clone(),
855 settings: self.settings.clone(),
856 creation_time: self.creation_time,
857 message_count: self.message_count.load(Ordering::SeqCst),
858 shared: self.shared(),
859 invalidated: self.invalidated(),
860 shared_with_set: self.shared_with_set.read().clone(),
861 requests: self.to_share_with_set.read().clone(),
862 }
863 }
864}
865
866#[cfg(not(tarpaulin_include))]
867impl fmt::Debug for OutboundGroupSession {
868 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
869 f.debug_struct("OutboundGroupSession")
870 .field("session_id", &self.session_id)
871 .field("room_id", &self.room_id)
872 .field("creation_time", &self.creation_time)
873 .field("message_count", &self.message_count)
874 .finish()
875 }
876}
877
/// A pickled version of an [`OutboundGroupSession`].
///
/// Holds everything needed to restore the session with
/// [`OutboundGroupSession::from_pickle`]. The serialized field layout is a
/// persistence format, so keep fields and their serde behavior stable.
#[derive(Deserialize, Serialize)]
// NOTE(review): `Debug` is presumably not implemented to avoid printing the
// pickled Megolm session — confirm before adding it.
#[allow(missing_debug_implementations)]
pub struct PickledOutboundGroupSession {
    /// The pickle of the inner Megolm session.
    pub pickle: GroupSessionPickle,
    /// The settings the session was created with.
    pub settings: Arc<EncryptionSettings>,
    /// The room the session encrypts messages for.
    pub room_id: OwnedRoomId,
    /// When the session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// Number of messages encrypted with the session.
    pub message_count: u64,
    /// Whether the session was marked as shared.
    pub shared: bool,
    /// Whether the session was invalidated.
    pub invalidated: bool,
    /// Share infos of to-device requests that were already sent.
    pub shared_with_set: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
    /// Pending to-device requests, keyed by transaction ID, with their share
    /// infos.
    pub requests: BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>,
}
904
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ruma::{
        EventEncryptionAlgorithm,
        events::room::{
            encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility,
        },
        uint,
    };

    use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD, ShareState};
    use crate::CollectStrategy;

    #[test]
    fn test_encryption_settings_conversion() {
        // Without explicit rotation values, the crate defaults are used.
        let mut event_content =
            RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);
        let defaults = EncryptionSettings::new(
            event_content.clone(),
            HistoryVisibility::Joined,
            CollectStrategy::AllDevices,
        );

        assert_eq!(defaults.rotation_period, ROTATION_PERIOD);
        assert_eq!(defaults.rotation_period_msgs, ROTATION_MESSAGES);

        // Explicit values from the event are taken over verbatim.
        event_content.rotation_period_ms = Some(uint!(3600));
        event_content.rotation_period_msgs = Some(uint!(500));

        let custom = EncryptionSettings::new(
            event_content,
            HistoryVisibility::Shared,
            CollectStrategy::AllDevices,
        );

        assert_eq!(custom.rotation_period, Duration::from_millis(3600));
        assert_eq!(custom.rotation_period_msgs, 500);
    }

    #[test]
    fn test_share_state_ordering() {
        let ordered = [
            ShareState::NotShared,
            ShareState::SharedButChangedSenderKey,
            ShareState::Shared { message_index: 1, olm_wedging_index: Default::default() },
        ];

        // Exhaustive on purpose: adding a variant forces this test to be revisited.
        match ordered[0] {
            ShareState::NotShared
            | ShareState::SharedButChangedSenderKey
            | ShareState::Shared { .. } => {}
        }

        assert!(ordered.is_sorted());
    }

    #[cfg(any(target_os = "linux", target_os = "macos", target_family = "wasm"))]
    mod expiration {
        use std::{sync::atomic::Ordering, time::Duration};

        use matrix_sdk_test::async_test;
        use ruma::{
            SecondsSinceUnixEpoch, device_id, events::room::message::RoomMessageEventContent,
            room_id, serde::Raw, uint, user_id,
        };

        use crate::{
            Account, EncryptionSettings, MegolmError,
            olm::{OutboundGroupSession, SenderData},
        };

        const TWO_HOURS: Duration = Duration::from_secs(2 * 60 * 60);

        #[async_test]
        async fn test_session_is_not_expired_if_no_messages_sent_and_no_time_passed() {
            let outbound = create_session(EncryptionSettings {
                rotation_period_msgs: 1,
                ..Default::default()
            })
            .await;

            assert!(!outbound.expired());
        }

        #[async_test]
        async fn test_session_is_expired_if_we_rotate_every_message_and_one_was_sent()
        -> Result<(), MegolmError> {
            let outbound = create_session(EncryptionSettings {
                rotation_period_msgs: 1,
                ..Default::default()
            })
            .await;

            encrypt_test_message(&outbound).await?;

            assert!(outbound.expired());
            Ok(())
        }

        #[async_test]
        async fn test_session_with_rotation_period_is_not_expired_after_no_time() {
            let outbound = create_session(EncryptionSettings {
                rotation_period: TWO_HOURS,
                ..Default::default()
            })
            .await;

            assert!(!outbound.expired());
        }

        #[async_test]
        async fn test_session_is_expired_after_rotation_period() {
            let mut outbound = create_session(EncryptionSettings {
                rotation_period: TWO_HOURS,
                ..Default::default()
            })
            .await;

            // Pretend the session was created three hours ago.
            let now = SecondsSinceUnixEpoch::now();
            outbound.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(10800));

            assert!(outbound.expired());
        }

        #[async_test]
        #[cfg(not(feature = "_disable-minimum-rotation-period-ms"))]
        async fn test_session_does_not_expire_under_one_hour_even_if_we_ask_for_shorter() {
            let mut outbound = create_session(EncryptionSettings {
                rotation_period: Duration::from_millis(100),
                ..Default::default()
            })
            .await;
            let now = SecondsSinceUnixEpoch::now();

            // Half an hour old: still within the one-hour minimum.
            outbound.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(1800));
            assert!(!outbound.expired());

            // Just over an hour old: expired.
            outbound.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(3601));
            assert!(outbound.expired());
        }

        #[async_test]
        #[cfg(feature = "_disable-minimum-rotation-period-ms")]
        async fn test_with_disable_minrotperiod_feature_sessions_can_expire_quickly() {
            let mut outbound = create_session(EncryptionSettings {
                rotation_period: Duration::from_millis(100),
                ..Default::default()
            })
            .await;

            // Half an hour is far longer than the requested 100ms period.
            let now = SecondsSinceUnixEpoch::now();
            outbound.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(1800));

            assert!(outbound.expired());
        }

        #[async_test]
        async fn test_session_with_zero_msgs_rotation_is_not_expired_initially() {
            let outbound = create_session(EncryptionSettings {
                rotation_period_msgs: 0,
                ..Default::default()
            })
            .await;

            assert!(!outbound.expired());
        }

        #[async_test]
        async fn test_session_with_zero_msgs_rotation_expires_after_one_message()
        -> Result<(), MegolmError> {
            let outbound = create_session(EncryptionSettings {
                rotation_period_msgs: 0,
                ..Default::default()
            })
            .await;

            encrypt_test_message(&outbound).await?;

            assert!(outbound.expired());
            Ok(())
        }

        #[async_test]
        async fn test_session_expires_after_10k_messages_even_if_we_ask_for_more() {
            let outbound = create_session(EncryptionSettings {
                rotation_period_msgs: 100_000,
                ..Default::default()
            })
            .await;

            // Below the 10k cap the session stays usable...
            for count in [0, 1000, 9999] {
                outbound.message_count.store(count, Ordering::SeqCst);
                assert!(!outbound.expired());
            }

            // ...but once the cap is reached it must be rotated.
            outbound.message_count.store(10_000, Ordering::SeqCst);
            assert!(outbound.expired());
        }

        /// Encrypt a single plain-text room message with `session`.
        async fn encrypt_test_message(session: &OutboundGroupSession) -> Result<(), MegolmError> {
            let content = Raw::new(&RoomMessageEventContent::text_plain("Test message"))?.cast();
            session.encrypt("m.room.message", &content).await;
            Ok(())
        }

        /// Create an outbound session for a fixed test account and room.
        async fn create_session(settings: EncryptionSettings) -> OutboundGroupSession {
            let account =
                Account::with_device_id(user_id!("@alice:example.org"), device_id!("DEVICEID"))
                    .static_data;
            let (outbound, _) = account
                .create_group_session_pair(
                    room_id!("!test_room:example.org"),
                    settings,
                    SenderData::unknown(),
                )
                .await
                .unwrap();

            outbound
        }
    }
}