mod share_strategy;
use std::{
collections::{BTreeMap, BTreeSet},
fmt::Debug,
sync::{Arc, RwLock as StdRwLock},
};
use futures_util::future::join_all;
use itertools::Itertools;
use matrix_sdk_common::{deserialized_responses::WithheldCode, executor::spawn};
use ruma::{
events::{AnyMessageLikeEventContent, ToDeviceEventType},
serde::Raw,
to_device::DeviceIdOrAllDevices,
OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
};
pub(crate) use share_strategy::CollectRecipientsResult;
pub use share_strategy::CollectStrategy;
use tracing::{debug, error, info, instrument, trace};
use crate::{
error::{EventError, MegolmResult, OlmResult},
identities::device::MaybeEncryptedRoomKey,
olm::{
InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
ShareInfo, ShareState,
},
store::{Changes, CryptoStoreWrapper, Result as StoreResult, Store},
types::{events::room::encrypted::RoomEncryptedEventContent, requests::ToDeviceRequest},
Device, DeviceData, EncryptionSettings, OlmError,
};
/// Cache of outbound group sessions, keyed by room ID, which falls back to
/// loading sessions from the crypto [`Store`] on a cache miss.
///
/// The maps live behind `Arc`s, so clones of the cache share them.
#[derive(Clone, Debug)]
pub(crate) struct GroupSessionCache {
/// Store used to load sessions that aren't cached yet.
store: Store,
/// The outbound group session of each room.
sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
/// The session each to-device request (keyed by its transaction ID) belongs
/// to, for requests that have been created but not yet marked as sent.
sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
}
impl GroupSessionCache {
    /// Create an empty cache backed by `store`.
    pub(crate) fn new(store: Store) -> Self {
        Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
    }

    /// Cache `session` for its room, replacing any session previously cached
    /// for that room.
    pub(crate) fn insert(&self, session: OutboundGroupSession) {
        self.sessions.write().unwrap().insert(session.room_id().to_owned(), session);
    }

    /// Get the outbound group session for `room_id`, loading it from the store
    /// and caching it if it isn't cached yet.
    ///
    /// Every pending to-device request of a session loaded from the store is
    /// registered as being shared, so that it can later be found by its
    /// transaction ID.
    ///
    /// Store errors are logged and reported as `None`.
    pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
        if let Some(session) = self.get(room_id) {
            return Some(session);
        }

        match self.store.get_outbound_group_session(room_id).await {
            Ok(Some(loaded)) => {
                // The store lookup above is an `.await` point, so another task
                // may have cached (or even created) a session for this room in
                // the meantime. Prefer that one: overwriting it with our copy
                // from the store would leave two distinct session objects in
                // use, and could resurrect an already replaced session.
                //
                // Lock order: `sessions` before `sessions_being_shared`. No
                // other method holds both locks at once.
                let mut sessions = self.sessions.write().unwrap();

                if let Some(existing) = sessions.get(room_id) {
                    return Some(existing.clone());
                }

                {
                    let mut sessions_being_shared = self.sessions_being_shared.write().unwrap();
                    for request_id in loaded.pending_request_ids() {
                        sessions_being_shared.insert(request_id, loaded.clone());
                    }
                }

                sessions.insert(room_id.to_owned(), loaded.clone());

                Some(loaded)
            }
            Ok(None) => None,
            Err(e) => {
                error!("Couldn't restore an outbound group session: {e:?}");
                None
            }
        }
    }

    /// Get the cached session for `room_id`, without consulting the store.
    fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
        self.sessions.read().unwrap().get(room_id).cloned()
    }

    /// Whether any cached session reports that its key is withheld from
    /// `device` with `code`.
    fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
        self.sessions.read().unwrap().values().any(|s| s.is_withheld_to(device, code))
    }

    /// Stop tracking the to-device request `id`, returning the session it
    /// belonged to, if any.
    fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
        self.sessions_being_shared.write().unwrap().remove(id)
    }

    /// Record that the to-device request `id` belongs to `session`.
    fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
        self.sessions_being_shared.write().unwrap().insert(id, session);
    }
}
/// Manages outbound group sessions: creating and rotating them, encrypting
/// room events with them, and sharing their room keys (or withheld notices)
/// with other devices via to-device requests.
#[derive(Debug, Clone)]
pub(crate) struct GroupSessionManager {
/// Crypto store used to load and persist sessions and devices.
store: Store,
/// Cache of outbound sessions and of the to-device requests being shared.
sessions: GroupSessionCache,
}
impl GroupSessionManager {
/// Maximum number of to-device messages put into a single to-device request.
/// It is used to chunk both room key requests and withheld requests.
const MAX_TO_DEVICE_MESSAGES: usize = 250;
/// Create a manager that loads and persists its sessions through `store`.
pub fn new(store: Store) -> Self {
    let sessions = GroupSessionCache::new(store.clone());
    Self { store, sessions }
}
/// Invalidate the outbound group session of `room_id` and persist it, so that
/// the next attempt to get or create a session for the room creates a new one.
///
/// The session is loaded from the store if it isn't cached yet. Otherwise a
/// session that only lives in the store would stay valid and be reused.
///
/// Returns `Ok(true)` if a session was found and invalidated, and `Ok(false)`
/// if the room has no outbound group session.
///
/// # Errors
///
/// Fails if persisting the invalidated session fails.
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
    let Some(session) = self.sessions.get_or_load(room_id).await else {
        return Ok(false);
    };

    session.invalidate_session();

    let mut changes = Changes::default();
    changes.outbound_group_sessions.push(session);
    self.store.save_changes(changes).await?;

    Ok(true)
}
/// Mark the to-device request `request_id` as sent.
///
/// If the request belongs to an outbound group session, the session is told
/// the request was sent. Every device the session reports back (those that
/// were sent a withheld code in that request) is flagged as having received
/// its withheld code. The session and the flagged devices are then persisted.
/// Unknown request IDs are ignored.
///
/// Failing to load one of the devices is logged and skipped (best effort), so
/// that the session itself still gets persisted.
///
/// # Errors
///
/// Fails if persisting the changes fails.
pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
    let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
        return Ok(());
    };

    let no_olm = session.mark_request_as_sent(request_id);

    let mut changes = Changes::default();

    for (user_id, devices) in &no_olm {
        for device_id in devices {
            match self.store.get_device(user_id, device_id).await {
                Ok(Some(device)) => {
                    device.mark_withheld_code_as_sent();
                    changes.devices.changed.push(device.inner.clone());
                }
                Ok(None) => {
                    error!(
                        ?request_id,
                        ?user_id,
                        ?device_id,
                        "Marking to-device no olm as sent but device not found, might \
                         have been deleted?"
                    );
                }
                Err(error) => {
                    // Don't report a store failure as a missing device.
                    error!(
                        ?request_id,
                        ?user_id,
                        ?device_id,
                        ?error,
                        "Marking to-device no olm as sent but failed to load the device"
                    );
                }
            }
        }
    }

    changes.outbound_group_sessions.push(session);
    self.store.save_changes(changes).await
}
/// Test helper: the cached outbound group session of `room_id`, if any.
/// The store is not consulted.
#[cfg(test)]
pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
self.sessions.get(room_id)
}
/// Encrypt `content`, an event of type `event_type`, with the outbound group
/// session of `room_id`, then persist the session.
///
/// # Panics
///
/// Panics if the room has no outbound group session (neither cached nor in
/// the store), or if that session has expired.
///
/// # Errors
///
/// Fails if persisting the session fails.
pub async fn encrypt(
    &self,
    room_id: &RoomId,
    event_type: &str,
    content: &Raw<AnyMessageLikeEventContent>,
) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
    let maybe_session = self.sessions.get_or_load(room_id).await;
    let session = maybe_session.expect("Session wasn't created nor shared");
    assert!(!session.expired(), "Session expired");

    let encrypted = session.encrypt(event_type, content).await;

    let mut to_save = Changes::default();
    to_save.outbound_group_sessions.push(session);
    self.store.save_changes(to_save).await?;

    Ok(encrypted)
}
/// Create a new outbound/inbound group session pair for `room_id` and cache
/// the outbound half, replacing any session previously cached for the room.
///
/// Neither session is persisted here.
///
/// # Errors
///
/// Any failure to create the pair is reported as
/// [`EventError::UnsupportedAlgorithm`].
pub async fn create_outbound_group_session(
    &self,
    room_id: &RoomId,
    settings: EncryptionSettings,
    own_sender_data: SenderData,
) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
    let account = self.store.static_account();
    let created = account.create_group_session_pair(room_id, settings, own_sender_data).await;
    let (outbound, inbound) = created.map_err(|_| EventError::UnsupportedAlgorithm)?;

    self.sessions.insert(outbound.clone());

    Ok((outbound, inbound))
}
/// Get the usable outbound group session of `room_id`, creating a new one
/// when there is none or when the existing one has expired or was
/// invalidated.
///
/// The inbound half is returned only when a new session was created.
pub async fn get_or_create_outbound_session(
    &self,
    room_id: &RoomId,
    settings: EncryptionSettings,
    own_sender_data: SenderData,
) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
    match self.sessions.get_or_load(room_id).await {
        Some(existing) if !existing.expired() && !existing.invalidated() => Ok((existing, None)),
        _ => {
            let (outbound, inbound) =
                self.create_outbound_group_session(room_id, settings, own_sender_data).await?;
            Ok((outbound, Some(inbound)))
        }
    }
}
/// Encrypt the room key of `group_session` for each of `devices`, using one
/// spawned task per device.
///
/// Returns, in order:
/// - the transaction ID of the created to-device request,
/// - the `m.room.encrypted` to-device request (it may contain no messages),
/// - the [`ShareInfo`] of each device that received the key, grouped by user,
/// - the Olm sessions used while encrypting,
/// - the devices the key has to be withheld from, with the reason.
///
/// # Errors
///
/// Fails with the first encryption error among the (in-order) results.
///
/// # Panics
///
/// Panics if one of the spawned encryption tasks panicked.
async fn encrypt_session_for(
    store: Arc<CryptoStoreWrapper>,
    group_session: OutboundGroupSession,
    devices: Vec<DeviceData>,
) -> OlmResult<(
    OwnedTransactionId,
    ToDeviceRequest,
    BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
    Vec<Session>,
    Vec<(DeviceData, WithheldCode)>,
)> {
    /// Outcome of trying to encrypt the room key for a single device.
    struct DeviceResult {
        device: DeviceData,
        maybe_encrypted_room_key: MaybeEncryptedRoomKey,
    }

    let encrypt_one = |store: Arc<CryptoStoreWrapper>,
                       device: DeviceData,
                       session: OutboundGroupSession| async move {
        let maybe_encrypted_room_key =
            device.maybe_encrypt_room_key(store.as_ref(), session).await?;
        Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key })
    };

    // The devices are owned, so hand each one to its task instead of cloning.
    let handles: Vec<_> = devices
        .into_iter()
        .map(|device| spawn(encrypt_one(store.clone(), device, group_session.clone())))
        .collect();

    let mut messages = BTreeMap::new();
    let mut share_infos = BTreeMap::new();
    let mut changed_sessions = Vec::new();
    let mut withheld_devices = Vec::new();

    for joined in join_all(handles).await {
        let DeviceResult { device, maybe_encrypted_room_key } =
            joined.expect("Encryption task panicked")?;

        match maybe_encrypted_room_key {
            MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
                changed_sessions.push(used_session);

                let user_id = device.user_id().to_owned();
                let device_id = device.device_id().to_owned();

                messages
                    .entry(user_id.clone())
                    .or_insert_with(BTreeMap::new)
                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.clone()), message);
                share_infos
                    .entry(user_id)
                    .or_insert_with(BTreeMap::new)
                    .insert(device_id, share_info);
            }
            MaybeEncryptedRoomKey::Withheld { code } => withheld_devices.push((device, code)),
        }
    }

    let txn_id = TransactionId::new();
    let request = ToDeviceRequest {
        event_type: ToDeviceEventType::RoomEncrypted,
        txn_id: txn_id.clone(),
        messages,
    };

    Ok((txn_id, request, share_infos, changed_sessions, withheld_devices))
}
/// Determine, for the devices of `users` under `settings`, which ones should
/// receive the room key of `outbound`, which ones it should be withheld from,
/// and whether the session should be rotated.
///
/// Thin wrapper around [`share_strategy::collect_session_recipients`].
#[instrument(skip_all)]
pub async fn collect_session_recipients(
&self,
users: impl Iterator<Item = &UserId>,
settings: &EncryptionSettings,
outbound: &OutboundGroupSession,
) -> OlmResult<CollectRecipientsResult> {
share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
}
/// Encrypt the room key of `outbound` for one `chunk` of devices.
///
/// If at least one device got an encrypted message, the resulting to-device
/// request is added to `outbound` and registered in `sessions` as being
/// shared.
///
/// Returns the Olm sessions that were used and the devices whose key has to
/// be withheld.
async fn encrypt_request(
    store: Arc<CryptoStoreWrapper>,
    chunk: Vec<DeviceData>,
    outbound: OutboundGroupSession,
    sessions: GroupSessionCache,
) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
    let (txn_id, request, share_infos, used_sessions, withheld) =
        Self::encrypt_session_for(store, outbound.clone(), chunk).await?;

    if request.messages.is_empty() {
        return Ok((used_sessions, withheld));
    }

    trace!(
        recipient_count = request.message_count(),
        transaction_id = ?txn_id,
        "Created a to-device request carrying a room_key"
    );

    outbound.add_request(txn_id.clone(), request.into(), share_infos);
    sessions.mark_as_being_shared(txn_id, outbound);

    Ok((used_sessions, withheld))
}
/// Get a handle to the session cache. The clone shares the cache's
/// `Arc`-wrapped maps.
pub(crate) fn session_cache(&self) -> GroupSessionCache {
self.sessions.clone()
}
/// Rotate the room key when `should_rotate` is set. Otherwise `outbound` is
/// returned unchanged.
///
/// On rotation, a new session pair is created (the outbound half gets cached).
/// The inbound half's sender data is derived from `own_device`, and both
/// halves are queued in `changes` for persistence.
async fn maybe_rotate_group_session(
    &self,
    should_rotate: bool,
    room_id: &RoomId,
    outbound: OutboundGroupSession,
    encryption_settings: EncryptionSettings,
    changes: &mut Changes,
    own_device: Option<Device>,
) -> OlmResult<OutboundGroupSession> {
    if !should_rotate {
        return Ok(outbound);
    }

    let old_session_id = outbound.session_id();

    let (rotated_outbound, mut rotated_inbound) = self
        .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
        .await?;

    let own_sender_data = match own_device {
        Some(device) => {
            SenderDataFinder::find_using_device_data(
                &self.store,
                device.inner.clone(),
                &rotated_inbound,
            )
            .await?
        }
        None => {
            error!("Unable to find our own device!");
            SenderData::unknown()
        }
    };
    rotated_inbound.sender_data = own_sender_data;

    changes.outbound_group_sessions.push(rotated_outbound.clone());
    changes.inbound_group_sessions.push(rotated_inbound);

    debug!(
        old_session_id = old_session_id,
        session_id = rotated_outbound.session_id(),
        "A user or device has left the room since we last sent a \
        message, or the encryption settings have changed. Rotating the \
        room key.",
    );

    Ok(rotated_outbound)
}
/// Encrypt the room key of `group_session` for `recipient_devices`.
///
/// The devices are split into chunks of at most
/// [`Self::MAX_TO_DEVICE_MESSAGES`], each encrypted in its own spawned task.
/// If there is at least one recipient, `changes.outbound_group_sessions` is
/// *replaced* by `group_session` alone, discarding anything queued there
/// before. The Olm sessions used while encrypting are appended to
/// `changes.sessions`.
///
/// Returns the devices the key has to be withheld from.
///
/// # Panics
///
/// Panics if one of the spawned encryption tasks panicked.
async fn encrypt_for_devices(
    &self,
    recipient_devices: Vec<DeviceData>,
    group_session: &OutboundGroupSession,
    changes: &mut Changes,
) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
    if !recipient_devices.is_empty() {
        let mut recipients: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
        for device in &recipient_devices {
            recipients.entry(device.user_id()).or_default().insert(device.device_id());
        }

        changes.outbound_group_sessions = vec![group_session.clone()];

        let message_index = group_session.message_index().await;

        info!(
            ?recipients,
            message_index,
            room_id = ?group_session.room_id(),
            session_id = group_session.session_id(),
            "Trying to encrypt a room key",
        );
    }

    let tasks: Vec<_> = recipient_devices
        .chunks(Self::MAX_TO_DEVICE_MESSAGES)
        .map(|chunk| {
            spawn(Self::encrypt_request(
                self.store.crypto_store(),
                chunk.to_vec(),
                group_session.clone(),
                self.sessions.clone(),
            ))
        })
        .collect();

    let mut withheld_devices = Vec::new();

    for joined in join_all(tasks).await {
        let (used_sessions, no_olm) = joined.expect("Encryption task panicked")?;
        changes.sessions.extend(used_sessions);
        withheld_devices.extend(no_olm);
    }

    Ok(withheld_devices)
}
/// Whether `device` already counts as notified that the key is withheld with
/// `code`.
///
/// For [`WithheldCode::NoOlm`], the device's own "withheld code sent" flag and
/// every cached session are consulted. For any other code, only
/// `group_session` is.
fn is_withheld_to(
    &self,
    group_session: &OutboundGroupSession,
    device: &DeviceData,
    code: &WithheldCode,
) -> bool {
    if code != &WithheldCode::NoOlm {
        return group_session.is_withheld_to(device, code);
    }

    device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
}
/// Queue `m.room_key.withheld` to-device requests on `group_session` for
/// `withheld_devices`.
///
/// Devices that already count as notified with the same code (see
/// [`Self::is_withheld_to`]) are skipped. The rest are batched into requests
/// of at most [`Self::MAX_TO_DEVICE_MESSAGES`] messages. Each request is added
/// to the session together with a withheld [`ShareInfo`] per device, and is
/// registered as being shared.
///
/// Currently always returns `Ok(())`.
fn handle_withheld_devices(
    &self,
    group_session: &OutboundGroupSession,
    withheld_devices: Vec<(DeviceData, WithheldCode)>,
) -> OlmResult<()> {
    // Build every batch (and thus evaluate every filter) before queueing any
    // request on the session.
    let batches: Vec<_> = withheld_devices
        .into_iter()
        .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
        .chunks(Self::MAX_TO_DEVICE_MESSAGES)
        .into_iter()
        .map(|batch| {
            let mut messages = BTreeMap::new();
            let mut share_infos = BTreeMap::new();

            for (device, code) in batch {
                let share_info = ShareInfo::new_withheld(code.clone());
                let withheld_content = group_session.withheld_code(code);
                let content = Raw::new(&withheld_content)
                    .expect("We can always serialize a withheld content info")
                    .cast();

                let user_id = device.user_id().to_owned();
                let device_id = device.device_id().to_owned();

                messages
                    .entry(user_id.clone())
                    .or_insert_with(BTreeMap::new)
                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.clone()), content);
                share_infos
                    .entry(user_id)
                    .or_insert_with(BTreeMap::new)
                    .insert(device_id, share_info);
            }

            let request = ToDeviceRequest {
                event_type: ToDeviceEventType::from("m.room_key.withheld"),
                txn_id: TransactionId::new(),
                messages,
            };

            (request, share_infos)
        })
        .collect();

    for (request, share_infos) in batches {
        if request.messages.is_empty() {
            continue;
        }

        let txn_id = request.txn_id.clone();
        group_session.add_request(txn_id.clone(), request.into(), share_infos);
        self.sessions.mark_as_being_shared(txn_id, group_session.clone());
    }

    Ok(())
}
/// Log each of `requests` at info level. Each entry lists its transaction ID,
/// its event type and, per message, the message ID and recipient.
fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
    requests.iter().for_each(|request| {
        let message_list = Self::to_device_request_to_log_list(request);
        info!(
            request_id = ?request.txn_id,
            ?message_list,
            "Created batch of to-device messages of type {}",
            request.event_type
        );
    });
}
/// Summarise the messages of `request` for logging.
///
/// Returns one `(message_id, user_id, device)` tuple per message.
/// `message_id` is the content's `org.matrix.msgid` field, or `"<undefined>"`
/// when the field is missing or the content can't be read.
///
/// This is a logging helper, so unexpected content never causes a panic.
fn to_device_request_to_log_list(
    request: &Arc<ToDeviceRequest>,
) -> Vec<(String, String, String)> {
    #[derive(serde::Deserialize)]
    struct ContentStub {
        // Owned rather than a borrowed `&str`: serde cannot borrow a `&str`
        // from a JSON string that contains escape sequences, and would fail.
        #[serde(default, rename = "org.matrix.msgid")]
        message_id: Option<String>,
    }

    request
        .messages
        .iter()
        .flat_map(|(user_id, device_map)| {
            device_map.iter().map(move |(device, content)| {
                let message_id = content
                    .deserialize_as::<ContentStub>()
                    .ok()
                    .and_then(|stub| stub.message_id)
                    .unwrap_or_else(|| "<undefined>".to_owned());

                (message_id, user_id.to_string(), device.to_string())
            })
        })
        .collect()
}
/// Get or create the outbound group session of `room_id` and produce the
/// to-device requests needed to share its room key with the devices of
/// `users`.
///
/// Steps:
/// 1. Get the current outbound session or create one. For a new session, the
///    inbound half gets our own device's sender data.
/// 2. Collect the recipient devices and rotate the session if the collection
///    result says so.
/// 3. Encrypt the room key for devices that don't have it yet, or whose Olm
///    wedging index grew since it was shared. Queue withheld notices for the
///    devices that can't or shouldn't get it.
/// 4. Persist all accumulated changes.
///
/// Returns all pending requests of the session (`outbound.pending_requests()`).
/// If there are none and the session wasn't shared yet, it is marked as shared.
#[instrument(skip(self, users, encryption_settings), fields(session_id))]
pub async fn share_room_key(
&self,
room_id: &RoomId,
users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
trace!("Checking if a room key needs to be shared");
let account = self.store.static_account();
// Our own device, used to fill in the sender data of new inbound sessions.
let device = self.store.get_device(account.user_id(), account.device_id()).await?;
let encryption_settings = encryption_settings.into();
// Everything that must be persisted is accumulated here and saved once at
// the end.
let mut changes = Changes::default();
let (outbound, inbound) = self
.get_or_create_outbound_session(
room_id,
encryption_settings.clone(),
SenderData::unknown(),
)
.await?;
tracing::Span::current().record("session_id", outbound.session_id());
// `inbound` is only `Some` when a new session was just created.
if let Some(mut inbound) = inbound {
let own_sender_data = if let Some(device) = &device {
SenderDataFinder::find_using_device_data(
&self.store,
device.inner.clone(),
&inbound,
)
.await?
} else {
error!("Unable to find our own device!");
SenderData::unknown()
};
inbound.sender_data = own_sender_data;
changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound);
}
let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
// From here on, `outbound` is the session actually used for sharing (the
// rotated one, if rotation happened).
let outbound = self
.maybe_rotate_group_session(
should_rotate,
room_id,
outbound,
encryption_settings,
&mut changes,
device,
)
.await?;
// Keep only the devices that haven't been sent this room key yet, plus
// those whose Olm wedging index grew since it was shared (they tried to
// unwedge their Olm session, so the key is sent again).
let devices: Vec<_> = devices
.into_iter()
.flat_map(|(_, d)| {
d.into_iter().filter(|d| match outbound.is_shared_with(d) {
ShareState::NotShared => true,
ShareState::Shared { message_index: _, olm_wedging_index } => {
olm_wedging_index < d.olm_wedging_index
}
_ => false,
})
})
.collect();
let unable_to_encrypt_devices =
self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
withheld_devices.extend(unable_to_encrypt_devices);
self.handle_withheld_devices(&outbound, withheld_devices)?;
let requests = outbound.pending_requests();
if requests.is_empty() {
// Nothing to send: nobody needs the key, so consider it shared.
if !outbound.shared() {
debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
outbound.mark_as_shared();
changes.outbound_group_sessions.push(outbound.clone());
}
} else {
Self::log_room_key_sharing_result(&requests)
}
if !changes.is_empty() {
let session_count = changes.sessions.len();
self.store.save_changes(changes).await?;
trace!(
session_count = session_count,
"Stored the changed sessions after encrypting an room key"
);
}
Ok(requests)
}
}
#[cfg(test)]
mod tests {
use std::{
collections::{BTreeMap, BTreeSet},
iter,
ops::Deref,
sync::Arc,
};
use assert_matches2::assert_let;
use matrix_sdk_common::deserialized_responses::WithheldCode;
use matrix_sdk_test::{async_test, ruma_response_from_json};
use ruma::{
api::client::{
keys::{claim_keys, get_keys, upload_keys},
to_device::send_event_to_device::v3::Response as ToDeviceResponse,
},
device_id,
events::room::history_visibility::HistoryVisibility,
room_id,
to_device::DeviceIdOrAllDevices,
user_id, DeviceId, OneTimeKeyAlgorithm, TransactionId, UInt, UserId,
};
use serde_json::{json, Value};
use crate::{
identities::DeviceData,
machine::EncryptionSyncChanges,
olm::{Account, SenderData},
session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
types::{
events::{
room::encrypted::EncryptedToDeviceEvent,
room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
},
requests::ToDeviceRequest,
DeviceKeys, EventEncryptionAlgorithm,
},
EncryptionSettings, LocalTrust, OlmMachine,
};
/// User ID of the machine under test in most of these tests.
fn alice_id() -> &'static UserId {
user_id!("@alice:example.org")
}
/// Device ID used together with [`alice_id`].
fn alice_device_id() -> &'static DeviceId {
device_id!("JLAFKJWSCS")
}
/// Keys query response fixture, shared with the crypto benchmarks.
fn keys_query_response() -> get_keys::v3::Response {
    let raw = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
    let json = serde_json::from_slice::<Value>(raw).unwrap();
    ruma_response_from_json(&json)
}
/// Keys query response announcing a single device, `BOBDEVICE`, for
/// `@bob:localhost`.
fn bob_keys_query_response() -> get_keys::v3::Response {
let data = json!({
"device_keys": {
"@bob:localhost": {
"BOBDEVICE": {
"user_id": "@bob:localhost",
"device_id": "BOBDEVICE",
"algorithms": [
"m.olm.v1.curve25519-aes-sha2",
"m.megolm.v1.aes-sha2",
"m.megolm.v2.aes-sha2"
],
"keys": {
"curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
"ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
},
"signatures": {
"@bob:localhost": {
"ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
}
}
}
}
}
});
ruma_response_from_json(&data)
}
/// Keys claim response carrying one signed curve25519 one-time key for Bob's
/// `BOBDEVICE`.
fn bob_one_time_key() -> claim_keys::v3::Response {
let data = json!({
"failures": {},
"one_time_keys":{
"@bob:localhost":{
"BOBDEVICE":{
"signed_curve25519:AAAAAAAAAAA": {
"key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
"signatures":{
"@bob:localhost":{
"ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
}
}
}
}
}
}
});
ruma_response_from_json(&data)
}
/// Keys claim response fixture, shared with the crypto benchmarks.
fn keys_claim_response() -> claim_keys::v3::Response {
    let raw = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
    let json = serde_json::from_slice::<Value>(raw).unwrap();
    ruma_response_from_json(&json)
}
/// Create an [`OlmMachine`] for `user_id`/`device_id` that knows the devices
/// from the keys query fixture and has Olm sessions with `@example:localhost`
/// (from the keys claim fixture) and with Bob's `BOBDEVICE`.
async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
let keys_query = keys_query_response();
let txn_id = TransactionId::new();
let machine = OlmMachine::new(user_id, device_id).await;
machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
let (txn_id, _keys_claim_request) = machine
.get_missing_sessions(iter::once(user_id!("@example:localhost")))
.await
.unwrap()
.unwrap();
let keys_claim = keys_claim_response();
machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
// NOTE(review): reuses the keys claim transaction ID for a keys query
// response; presumably the ID isn't checked for this response type.
machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
let (txn_id, _keys_claim_request) = machine
.get_missing_sessions(iter::once(user_id!("@bob:localhost")))
.await
.unwrap()
.unwrap();
machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
machine
}
/// The standard test machine for Alice, see [`machine_with_user_test_helper`].
async fn machine() -> OlmMachine {
machine_with_user_test_helper(alice_id(), alice_device_id()).await
}
/// Alice's machine after sharing a room key for `!test:localhost` with the
/// keys claim fixture users and marking every resulting request as sent.
///
/// Asserts that the session only counts as shared once all requests were
/// sent.
async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
let machine = machine().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let requests =
machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
let outbound =
machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
assert!(!outbound.pending_requests().is_empty());
assert!(!outbound.shared());
let response = ToDeviceResponse::new();
for request in requests {
machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
}
assert!(outbound.shared());
assert!(outbound.pending_requests().is_empty());
machine
}
/// Sharing a room key with the keys claim fixture users yields 148 encrypted
/// room key messages and 2 withheld notices.
#[async_test]
async fn test_sharing() {
let machine = machine().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let requests =
machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
assert_eq!(event_count, 148);
let withheld_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room_key.withheld".into())
.map(|r| r.message_count())
.sum();
assert_eq!(withheld_count, 2);
}
/// Count the messages of the `m.room_key.withheld` requests in `requests`
/// whose `MegolmV1AesSha2` withheld content carries `code`.
///
/// Panics if a withheld message can't be deserialized.
fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
    requests
        .iter()
        .filter(|r| r.event_type == "m.room_key.withheld".into())
        .flat_map(|r| r.messages.values())
        .flat_map(|device_map| device_map.values())
        .filter(|content| {
            let withheld = content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
            matches!(withheld, MegolmV1AesSha2(inner) if inner.withheld_code() == code)
        })
        .count()
}
/// `m.no_olm` withheld notices are produced once per device: sharing again
/// before sending still returns the first call's pending notices, and after
/// they were sent, sharing in another room produces none.
#[async_test]
async fn test_no_olm_sent_once() {
let machine = machine().await;
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let first_room_id = room_id!("!test:localhost");
let requests = machine
.share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
.await
.unwrap();
let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
assert_eq!(withheld_count, 2);
// Same room, requests not sent yet: the session's pending requests are
// returned again.
let new_requests = machine
.share_room_key(first_room_id, users, EncryptionSettings::default())
.await
.unwrap();
let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
assert_eq!(withheld_count, 2);
let response = ToDeviceResponse::new();
for request in requests {
machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
}
// A different room must not notify the same devices again.
let second_room_id = room_id!("!other:localhost");
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let requests = machine
.share_room_key(second_room_id, users, EncryptionSettings::default())
.await
.unwrap();
let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
assert_eq!(withheld_count, 0);
}
/// After the room key was shared and sent, adding a late joiner (Bob) shares
/// the key with only that one device.
#[async_test]
async fn test_ratcheted_sharing() {
let machine = machine_with_shared_room_key_test_helper().await;
let room_id = room_id!("!test:localhost");
let late_joiner = user_id!("@bob:localhost");
let keys_claim = keys_claim_response();
let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
users.insert(late_joiner);
let requests = machine
.share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
.await
.unwrap();
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
let outbound =
machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
assert_eq!(event_count, 1);
assert!(!outbound.pending_requests().is_empty());
}
/// Unchanged settings don't request a rotation. Changing the history
/// visibility or the algorithm does.
#[async_test]
async fn test_changing_encryption_settings() {
let machine = machine_with_shared_room_key_test_helper().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let outbound =
machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
let CollectRecipientsResult { should_rotate, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
.await
.unwrap();
assert!(!should_rotate);
// Different history visibility.
let settings = EncryptionSettings {
history_visibility: HistoryVisibility::Invited,
..Default::default()
};
let CollectRecipientsResult { should_rotate, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users.clone(), &settings, &outbound)
.await
.unwrap();
assert!(should_rotate);
// Different algorithm.
let settings = EncryptionSettings {
algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
..Default::default()
};
let CollectRecipientsResult { should_rotate, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
.unwrap();
assert!(should_rotate);
}
/// Recipient collection for our own user. Our own device is never a
/// recipient. With `only_allow_trusted_devices`, only locally verified
/// devices are recipients, and every other device is withheld as
/// blacklisted or unverified.
#[async_test]
async fn test_key_recipient_collecting() {
let user_id = user_id!("@example:localhost");
let device_id = device_id!("TESTDEVICE");
let room_id = room_id!("!test:localhost");
let machine = machine_with_user_test_helper(user_id, device_id).await;
let (outbound, _) = machine
.inner
.group_session_manager
.get_or_create_outbound_session(
room_id,
EncryptionSettings::default(),
SenderData::unknown(),
)
.await
.expect("We should be able to create a new session");
let history_visibility = HistoryVisibility::Joined;
let settings = EncryptionSettings { history_visibility, ..Default::default() };
let users = [user_id].into_iter();
let CollectRecipientsResult { devices: recipients, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
.expect("We should be able to collect the session recipients");
// Our other devices are recipients, our own device is not.
assert!(!recipients[user_id].is_empty());
assert!(!recipients[user_id]
.iter()
.any(|d| d.user_id() == user_id && d.device_id() == device_id));
let settings = EncryptionSettings {
sharing_strategy: CollectStrategy::DeviceBasedStrategy {
only_allow_trusted_devices: true,
error_on_verified_user_problem: false,
},
..Default::default()
};
let users = [user_id].into_iter();
let CollectRecipientsResult { devices: recipients, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
.expect("We should be able to collect the session recipients");
// No device is trusted yet.
assert!(recipients[user_id].is_empty());
let device_id = "AFGUOBTZWM".into();
let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
device.set_local_trust(LocalTrust::Verified).await.unwrap();
let users = [user_id].into_iter();
let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
.expect("We should be able to collect the session recipients");
// The now verified device is a recipient.
assert!(recipients[user_id]
.iter()
.any(|d| d.user_id() == user_id && d.device_id() == device_id));
let devices = machine.get_user_devices(user_id, None).await.unwrap();
// Every other device is withheld with a code matching its trust state.
devices
.devices()
.filter(|d| d.device_id() != device_id!("TESTDEVICE"))
.for_each(|d| {
if d.is_blacklisted() {
assert!(withheld.iter().any(|(dev, w)| {
dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
}));
} else if !d.is_verified() {
assert!(withheld.iter().any(|(dev, w)| {
dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
}));
}
});
assert_eq!(149, withheld.len());
}
/// With `only_allow_trusted_devices`, one verified device gets the key. All
/// other devices (149) are withheld in a single request, and the blacklisted
/// device gets the `Blacklisted` code.
#[async_test]
async fn test_sharing_withheld_only_trusted() {
let machine = machine().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let settings = EncryptionSettings {
sharing_strategy: CollectStrategy::DeviceBasedStrategy {
only_allow_trusted_devices: true,
error_on_verified_user_problem: false,
},
..Default::default()
};
let user_id = user_id!("@example:localhost");
let device_id = "MWFXPINOAO".into();
let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
device.set_local_trust(LocalTrust::Verified).await.unwrap();
machine
.get_device(user_id, "MWVTUXDNNM".into(), None)
.await
.unwrap()
.unwrap()
.set_local_trust(LocalTrust::BlackListed)
.await
.unwrap();
let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
let room_key_count =
requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
assert_eq!(1, room_key_count);
let withheld_count =
requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
assert_eq!(1, withheld_count);
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room_key.withheld".into())
.map(|r| r.message_count())
.sum();
assert_eq!(event_count, 149);
let has_blacklist =
requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
let content = &r.messages[user_id][&device_key];
let withheld: RoomKeyWithheldContent =
content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
if let MegolmV1AesSha2(content) = withheld {
content.withheld_code() == WithheldCode::Blacklisted
} else {
false
}
});
assert!(has_blacklist);
}
/// Bob's device is known, but no one-time key was claimed for it. The first
/// room produces a single withheld request for it, a second room produces
/// none, and the device is flagged only once the request was marked as sent.
#[async_test]
async fn test_no_olm_withheld_only_sent_once() {
let keys_query = keys_query_response();
let txn_id = TransactionId::new();
let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
let first_room = room_id!("!test:localhost");
let second_room = room_id!("!test2:localhost");
let bob_id = user_id!("@bob:localhost");
let settings = EncryptionSettings::default();
let users = [bob_id];
let requests = machine
.share_room_key(first_room, users.into_iter(), settings.to_owned())
.await
.unwrap();
let withheld_count =
requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
assert_eq!(withheld_count, 1);
assert_eq!(requests.len(), 1);
// Not sent yet, but the first room's session already withholds the key
// from Bob's device, so the second room doesn't queue another notice.
let second_requests =
machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
let withheld_count =
second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
assert_eq!(withheld_count, 0);
assert_eq!(second_requests.len(), 0);
let response = ToDeviceResponse::new();
let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
assert!(!device.was_withheld_code_sent());
for request in requests {
machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
}
let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
assert!(device.was_withheld_code_sent());
}
/// Once Bob creates a new Olm session with us (unwedging), the room key is
/// shared with his device again, exactly once.
#[async_test]
async fn test_resend_session_after_unwedging() {
let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
OneTimeKeyAlgorithm::SignedCurve25519,
UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
)]));
machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
let room_id = room_id!("!test:localhost");
let bob_id = user_id!("@bob:localhost");
let bob_account = Account::new(bob_id);
let keys_query_data = json!({
"device_keys": {
"@bob:localhost": {
bob_account.device_id.clone(): bob_account.device_keys()
}
}
});
let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
let txn_id = TransactionId::new();
machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
let alice_device_keys =
device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
let mut alice_otks = device_keys_request.one_time_keys.iter();
let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
// Bob creates a first Olm session with Alice and sends her a message
// over it.
{
let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
let mut session = bob_account
.create_outbound_session(
&alice_device,
&BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
bob_account.device_keys(),
)
.unwrap();
let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
let to_device =
EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
let sync_changes = EncryptionSyncChanges {
to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
changed_devices: &Default::default(),
one_time_keys_counts: &Default::default(),
unused_fallback_keys: None,
next_batch_token: None,
};
let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();
assert_eq!(1, decrypted.len());
}
// The first share sends the room key to Bob.
{
let requests = machine
.share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
.await
.unwrap();
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
assert_eq!(event_count, 1);
let response = ToDeviceResponse::new();
for request in requests {
machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
}
}
// Sharing again sends nothing: Bob already has the key.
{
let requests = machine
.share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
.await
.unwrap();
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
assert_eq!(event_count, 0);
}
// Bob creates a second Olm session with Alice (unwedging) and sends a
// message over it.
{
let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
let mut session = bob_account
.create_outbound_session(
&alice_device,
&BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
bob_account.device_keys(),
)
.unwrap();
let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
let to_device =
EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
let sync_changes = EncryptionSyncChanges {
to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
changed_devices: &Default::default(),
one_time_keys_counts: &Default::default(),
unused_fallback_keys: None,
next_batch_token: None,
};
let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();
assert_eq!(1, decrypted.len());
}
// The room key is shared with Bob again after the unwedging.
{
let requests = machine
.share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
.await
.unwrap();
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
assert_eq!(event_count, 1);
let response = ToDeviceResponse::new();
for request in requests {
machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
}
}
// And only once.
{
let requests = machine
.share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
.await
.unwrap();
let event_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
assert_eq!(event_count, 0);
}
}
}