use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::{Arc, RwLock as StdRwLock},
time::Duration,
};
use itertools::Itertools;
use matrix_sdk_common::deserialized_responses::{
AlgorithmInfo, DeviceLinkProblem, EncryptionInfo, TimelineEvent, VerificationLevel,
VerificationState,
};
use ruma::{
api::client::{
dehydrated_device::DehydratedDeviceData,
keys::{
claim_keys::v3::Request as KeysClaimRequest,
get_keys::v3::Response as KeysQueryResponse,
upload_keys::v3::{Request as UploadKeysRequest, Response as UploadKeysResponse},
upload_signatures::v3::Request as UploadSignaturesRequest,
},
sync::sync_events::DeviceLists,
},
assign,
events::{
secret::request::SecretName, AnyMessageLikeEvent, AnyMessageLikeEventContent,
AnyToDeviceEvent, MessageLikeEventContent,
},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UInt, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::Mutex;
use tracing::{
debug, error,
field::{debug, display},
info, instrument, warn, Span,
};
use vodozemac::{
megolm::{DecryptionError, SessionOrdering},
Curve25519PublicKey, Ed25519Signature,
};
use crate::{
backups::{BackupMachine, MegolmV1BackupKey},
dehydrated_devices::{DehydratedDevices, DehydrationError},
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult, SetRoomSettingsError},
gossiping::GossipMachine,
identities::{user::UserIdentities, Device, IdentityManager, UserDevices},
olm::{
Account, CrossSigningStatus, EncryptionSettings, ExportedRoomKey, IdentityKeys,
InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, SessionType,
StaticAccountData,
},
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager},
store::{
Changes, CryptoStoreWrapper, DeviceChanges, IdentityChanges, IntoCryptoStore, MemoryStore,
PendingChanges, Result as StoreResult, RoomKeyInfo, RoomSettings, SecretImportError, Store,
StoreCache, StoreTransaction,
},
types::{
events::{
olm_v1::{AnyDecryptedOlmEvent, DecryptedRoomKeyEvent},
room::encrypted::{
EncryptedEvent, EncryptedToDeviceEvent, RoomEncryptedEventContent,
RoomEventEncryptionScheme, SupportedEventEncryptionSchemes,
},
room_key::{MegolmV1AesSha2Content, RoomKeyContent},
room_key_withheld::{
MegolmV1AesSha2WithheldContent, RoomKeyWithheldContent, RoomKeyWithheldEvent,
},
ToDeviceEvents,
},
EventEncryptionAlgorithm, Signatures,
},
utilities::timestamp_to_iso8601,
verification::{Verification, VerificationMachine, VerificationRequest},
CrossSigningKeyExport, CryptoStoreError, KeysQueryRequest, LocalTrust, ReadOnlyDevice,
RoomKeyImportResult, SignatureError, ToDeviceRequest,
};
/// Entry point for the end-to-end encryption logic of one user's device.
///
/// This is a thin handle around an `Arc` to the shared [`OlmMachineInner`].
/// `#[derive(Clone)]` therefore only clones the `Arc`, so every clone of an
/// `OlmMachine` operates on the same underlying state.
#[derive(Clone)]
pub struct OlmMachine {
/// Shared state; `pub(crate)` so other modules of this crate can reach it.
pub(crate) inner: Arc<OlmMachineInner>,
}
/// State shared by all clones of an [`OlmMachine`].
///
/// All fields are built together in `OlmMachine::new_helper`. Most of the
/// sub-machines receive a clone of `store`.
pub struct OlmMachineInner {
/// The user this machine belongs to (copied from the store in `new_helper`).
user_id: OwnedUserId,
/// The device this machine represents.
device_id: OwnedDeviceId,
/// Our private cross-signing identity. It sits behind an async (tokio) mutex
/// because it is locked across `.await` points, e.g. in `bootstrap_cross_signing`.
user_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
/// Store used for all persistence and caching.
store: Store,
/// Olm session handling: claiming missing sessions, processing `/keys/claim`
/// responses and marking wedged devices.
session_manager: SessionManager,
/// Outbound Megolm session handling: encrypting room events and sharing room keys.
pub(crate) group_session_manager: GroupSessionManager,
/// Interactive verification flows and their outgoing messages.
verification_machine: VerificationMachine,
/// Gossiping: room-key and secret requests, and forwarded keys and secrets.
pub(crate) key_request_machine: GossipMachine,
/// User/device tracking and `/keys/query` handling.
identity_manager: IdentityManager,
/// Key-backup state, constructed with the stored backup key (if any).
backup_machine: BackupMachine,
}
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for OlmMachine {
    /// Shows only the identifying user and device IDs, never any key material.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("OlmMachine");
        builder.field("user_id", &self.user_id());
        builder.field("device_id", &self.device_id());
        builder.finish()
    }
}
impl OlmMachine {
/// Store key with the value `"generation-counter"`.
/// NOTE(review): not referenced in the visible part of this file; presumably
/// used for a generation counter kept in the crypto store — confirm at the use site.
const CURRENT_GENERATION_STORE_KEY: &'static str = "generation-counter";
/// Create an `OlmMachine` for `user_id`/`device_id` backed by a fresh
/// in-memory [`MemoryStore`].
///
/// # Panics
///
/// Panics if setting the machine up on the memory store fails. That is
/// treated as an invariant violation.
pub async fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
    let memory_store = MemoryStore::new();
    let created = OlmMachine::with_store(user_id, device_id, memory_store).await;
    created.expect("Reading and writing to the memory store always succeeds")
}
/// Build a new `OlmMachine` for a dehydrated device.
///
/// The account is restored from `device_data` using `pickle_key`, saved
/// into a new in-memory store, and combined with this machine's private
/// cross-signing identity. No backup key is attached.
///
/// # Errors
///
/// Returns a [`DehydrationError`] if the account cannot be rehydrated or
/// persisted.
pub(crate) async fn rehydrate(
    &self,
    pickle_key: &[u8; 32],
    device_id: &DeviceId,
    device_data: Raw<DehydratedDeviceData>,
) -> Result<OlmMachine, DehydrationError> {
    let user_id = self.user_id();
    let rehydrated = Account::rehydrate(pickle_key, user_id, device_id, device_data).await?;
    let static_data = rehydrated.static_data().clone();

    let memory_store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new()));
    let pending = PendingChanges { account: Some(rehydrated) };
    memory_store.save_pending_changes(pending).await?;

    let private_identity = self.store().private_identity();
    let machine =
        Self::new_helper(device_id, memory_store, static_data, private_identity, None);
    Ok(machine)
}
/// Wire up all sub-machines around `store` and return the finished machine.
///
/// The construction order matters because later components receive clones
/// of earlier ones (e.g. the gossip machine needs the identity manager and
/// the group session cache).
fn new_helper(
    device_id: &DeviceId,
    store: Arc<CryptoStoreWrapper>,
    account: StaticAccountData,
    user_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
    maybe_backup_key: Option<MegolmV1BackupKey>,
) -> Self {
    let verification_machine =
        VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
    let shared_store =
        Store::new(account, user_identity.clone(), store, verification_machine.clone());

    let group_session_manager = GroupSessionManager::new(shared_store.clone());
    let identity_manager = IdentityManager::new(shared_store.clone());

    // Shared between the gossip machine and the session manager.
    let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
    let key_request_machine = GossipMachine::new(
        shared_store.clone(),
        identity_manager.clone(),
        group_session_manager.session_cache(),
        users_for_key_claim.clone(),
    );
    let session_manager = SessionManager::new(
        users_for_key_claim,
        key_request_machine.clone(),
        shared_store.clone(),
    );
    let backup_machine = BackupMachine::new(shared_store.clone(), maybe_backup_key);

    let user_id = shared_store.user_id().to_owned();
    Self {
        inner: Arc::new(OlmMachineInner {
            user_id,
            device_id: device_id.to_owned(),
            user_identity,
            store: shared_store,
            session_manager,
            group_session_manager,
            verification_machine,
            key_request_machine,
            identity_manager,
            backup_machine,
        }),
    }
}
/// Create an `OlmMachine` on top of the given crypto store.
///
/// If the store already holds an account, it is restored after checking
/// that it belongs to `user_id`/`device_id`. Otherwise a new account is
/// created, our own device is saved as locally verified, and the account is
/// persisted. The private cross-signing identity is loaded, or an empty stub
/// is created. A backup key is attached only when both a decryption key and
/// a backup version are stored.
///
/// # Errors
///
/// Returns [`CryptoStoreError::MismatchedAccount`] when the stored account
/// belongs to a different user or device. Store errors are propagated.
#[instrument(skip(store), fields(ed25519_key, curve25519_key))]
pub async fn with_store(
    user_id: &UserId,
    device_id: &DeviceId,
    store: impl IntoCryptoStore,
) -> StoreResult<Self> {
    let store = store.into_crypto_store();

    let static_account = if let Some(account) = store.load_account().await? {
        let (stored_user_id, stored_device_id) = (account.user_id(), account.device_id());
        if user_id != stored_user_id || device_id != stored_device_id {
            return Err(CryptoStoreError::MismatchedAccount {
                expected: (stored_user_id.to_owned(), stored_device_id.to_owned()),
                got: (user_id.to_owned(), device_id.to_owned()),
            });
        }

        Span::current()
            .record("ed25519_key", display(account.identity_keys().ed25519))
            .record("curve25519_key", display(account.identity_keys().curve25519));
        debug!("Restored an Olm account");

        account.static_data().clone()
    } else {
        let account = Account::with_device_id(user_id, device_id);
        let static_account = account.static_data().clone();

        Span::current()
            .record("ed25519_key", display(account.identity_keys().ed25519))
            .record("curve25519_key", display(account.identity_keys().curve25519));

        // Our own freshly created device is marked as locally verified.
        let own_device = ReadOnlyDevice::from_account(&account);
        own_device.set_trust_state(LocalTrust::Verified);
        let device_changes = DeviceChanges { new: vec![own_device], ..Default::default() };
        store.save_changes(Changes { devices: device_changes, ..Default::default() }).await?;
        store.save_pending_changes(PendingChanges { account: Some(account) }).await?;

        debug!("Created a new Olm account");
        static_account
    };

    let identity = if let Some(restored) = store.load_identity().await? {
        let master_key = restored
            .master_public_key()
            .await
            .and_then(|master| master.get_first_key().map(|key| key.to_owned()));
        debug!(?master_key, "Restored the cross signing identity");
        restored
    } else {
        debug!("Creating an empty cross signing identity stub");
        PrivateCrossSigningIdentity::empty(user_id)
    };

    // The stored decryption key is only used if we also know its backup version.
    let saved_keys = store.load_backup_keys().await?;
    let maybe_backup_key =
        saved_keys.decryption_key.zip(saved_keys.backup_version).map(|(key, version)| {
            let megolm_v1_backup_key = key.megolm_v1_public_key();
            megolm_v1_backup_key.set_version(version);
            megolm_v1_backup_key
        });

    let identity = Arc::new(Mutex::new(identity));
    let store = Arc::new(CryptoStoreWrapper::new(user_id, store));
    Ok(OlmMachine::new_helper(device_id, store, static_account, identity, maybe_backup_key))
}
pub fn store(&self) -> &Store {
&self.inner.store
}
/// The ID of the user that owns this machine.
pub fn user_id(&self) -> &UserId {
    &*self.inner.user_id
}
/// The ID of the device this machine represents.
pub fn device_id(&self) -> &DeviceId {
    &*self.inner.device_id
}
/// Local creation time of this device's account, as recorded in the static
/// account data.
pub fn device_creation_time(&self) -> MilliSecondsSinceUnixEpoch {
    let static_account = self.inner.store.static_account();
    static_account.creation_local_time()
}
/// The public identity keys of this device's account.
pub fn identity_keys(&self) -> IdentityKeys {
    self.inner.store.static_account().identity_keys()
}
/// The display name of this device, as known to the store.
pub async fn display_name(&self) -> StoreResult<Option<String>> {
    let store = self.store();
    store.device_display_name().await
}
/// The set of users whose devices are tracked by the key query manager.
pub async fn tracked_users(&self) -> StoreResult<HashSet<OwnedUserId>> {
    let cache = self.store().cache().await?;
    let key_query_manager = &self.inner.identity_manager.key_query_manager;
    let synced = key_query_manager.synced(&cache).await?;
    Ok(synced.tracked_users())
}
/// Turn sending room key requests on or off. This forwards to the gossip machine.
#[cfg(feature = "automatic-room-key-forwarding")]
pub fn set_room_key_requests_enabled(&self, enable: bool) {
    let gossip = &self.inner.key_request_machine;
    gossip.set_room_key_requests_enabled(enable);
}
/// Whether room key requests are enabled, according to the gossip machine.
pub fn are_room_key_requests_enabled(&self) -> bool {
    let gossip = &self.inner.key_request_machine;
    gossip.are_room_key_requests_enabled()
}
/// Turn forwarding of room keys on or off. This forwards to the gossip machine.
#[cfg(feature = "automatic-room-key-forwarding")]
pub fn set_room_key_forwarding_enabled(&self, enable: bool) {
    let gossip = &self.inner.key_request_machine;
    gossip.set_room_key_forwarding_enabled(enable);
}
/// Whether room key forwarding is enabled, according to the gossip machine.
pub fn is_room_key_forwarding_enabled(&self) -> bool {
    let gossip = &self.inner.key_request_machine;
    gossip.is_room_key_forwarding_enabled()
}
/// Collect every request that currently needs to be sent to the homeserver.
///
/// In order, the result contains: a keys upload (if the account has keys to
/// upload), pending `/keys/query` requests, outgoing verification messages,
/// and outgoing to-device requests from the gossip machine.
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
    let mut requests = Vec::new();

    // Scoped so the store cache and account are released before continuing.
    {
        let store_cache = self.inner.store.cache().await?;
        let account = store_cache.account().await?;
        if let Some(upload) = self.keys_for_upload(&account).await {
            requests.push(OutgoingRequest {
                request_id: TransactionId::new(),
                request: Arc::new(upload.into()),
            });
        }
    }

    let key_queries = self.inner.identity_manager.users_for_key_query().await?;
    requests.extend(key_queries.into_iter().map(|(request_id, query)| OutgoingRequest {
        request_id,
        request: Arc::new(query.into()),
    }));

    requests.extend(self.inner.verification_machine.outgoing_messages());
    requests.extend(self.inner.key_request_machine.outgoing_to_device_requests().await?);

    Ok(requests)
}
/// Build a `/keys/query` request for the given users, together with the
/// transaction ID the identity manager assigns to it.
pub fn query_keys_for_users<'a>(
    &self,
    users: impl IntoIterator<Item = &'a UserId>,
) -> (OwnedTransactionId, KeysQueryRequest) {
    let identity_manager = &self.inner.identity_manager;
    identity_manager.build_key_query_for_users(users)
}
/// Feed the server's response for the request `request_id` back into the
/// machine, dispatching on the kind of response.
///
/// # Errors
///
/// Propagates errors from the individual response handlers.
pub async fn mark_request_as_sent<'a>(
    &self,
    request_id: &TransactionId,
    response: impl Into<IncomingResponse<'a>>,
) -> OlmResult<()> {
    // Handlers are awaited through `Box::pin`, presumably to keep the size of
    // this future small. The boxing is preserved as-is.
    let incoming: IncomingResponse<'a> = response.into();
    match incoming {
        IncomingResponse::KeysUpload(upload) => {
            Box::pin(self.receive_keys_upload_response(upload)).await?;
        }
        IncomingResponse::KeysQuery(query) => {
            Box::pin(self.receive_keys_query_response(request_id, query)).await?;
        }
        IncomingResponse::KeysClaim(claim) => {
            let session_manager = &self.inner.session_manager;
            Box::pin(session_manager.receive_keys_claim_response(request_id, claim)).await?;
        }
        IncomingResponse::ToDevice(_) => {
            Box::pin(self.mark_to_device_request_as_sent(request_id)).await?;
        }
        IncomingResponse::SigningKeysUpload(_) => {
            Box::pin(self.receive_cross_signing_upload_response()).await?;
        }
        IncomingResponse::SignatureUpload(_) | IncomingResponse::RoomMessage(_) => {
            self.inner.verification_machine.mark_request_as_sent(request_id);
        }
        IncomingResponse::KeysBackup(_) => {
            Box::pin(self.inner.backup_machine.mark_request_as_sent(request_id)).await?;
        }
    }
    Ok(())
}
/// Mark our private cross-signing identity as shared and persist it.
///
/// The identity lock is held until the save has finished.
async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> {
    let guard = self.inner.user_identity.lock().await;
    guard.mark_as_shared();

    let snapshot = guard.clone();
    self.store().save_changes(Changes { private_identity: Some(snapshot), ..Default::default() }).await
}
/// Create (or re-upload) our cross-signing identity.
///
/// If `reset` is `true`, or the private cross-signing identity is empty, a
/// new identity is bootstrapped from the account. Its public part and the
/// new private identity are then saved to the store. Otherwise, upload
/// requests are built from the existing identity, and the account's device
/// keys are signed with it.
///
/// If the account has not been marked as shared, the returned
/// [`CrossSigningBootstrapRequests`] may also contain a device-keys upload
/// request (only when the account has keys to upload).
///
/// # Errors
///
/// Propagates store errors.
///
/// # Panics
///
/// Panics if no public identity can be derived from a freshly created
/// private identity, or if signing our own device keys fails (see the
/// `expect` calls below).
pub async fn bootstrap_cross_signing(
&self,
reset: bool,
) -> StoreResult<CrossSigningBootstrapRequests> {
// The identity lock is held until the end of this function.
let mut identity = self.inner.user_identity.lock().await;
let (upload_signing_keys_req, upload_signatures_req) = if reset || identity.is_empty().await
{
info!("Creating new cross signing identity");
// Scoped so the store cache/account are released before saving below.
let (new_identity, upload_signing_keys_req, upload_signatures_req) = {
let cache = self.inner.store.cache().await?;
let account = cache.account().await?;
account.bootstrap_cross_signing().await
};
*identity = new_identity;
let public = identity.to_public_identity().await.expect(
"Couldn't create a public version of the identity from a new private identity",
);
// Persist both the public identity and the new private identity.
self.store()
.save_changes(Changes {
identities: IdentityChanges { new: vec![public.into()], ..Default::default() },
private_identity: Some(identity.clone()),
..Default::default()
})
.await?;
(upload_signing_keys_req, upload_signatures_req)
} else {
info!("Trying to upload the existing cross signing identity");
let upload_signing_keys_req = identity.as_upload_request().await;
let upload_signatures_req = identity
.sign_account(self.inner.store.static_account())
.await
.expect("Can't sign device keys");
(upload_signing_keys_req, upload_signatures_req)
};
// Device keys are only included if the account hasn't been marked as shared.
let upload_keys_req = {
let cache = self.store().cache().await?;
let account = cache.account().await?;
if account.shared() {
None
} else {
self.keys_for_upload(&account).await.map(OutgoingRequest::from)
}
};
Ok(CrossSigningBootstrapRequests {
upload_signing_keys_req,
upload_keys_req,
upload_signatures_req,
})
}
/// Apply a `/keys/upload` response to the account inside a store transaction.
async fn receive_keys_upload_response(&self, response: &UploadKeysResponse) -> OlmResult<()> {
    let store = &self.inner.store;
    store
        .with_transaction(|mut transaction| async {
            transaction.account().await?.receive_keys_upload_response(response)?;
            Ok((transaction, ()))
        })
        .await
}
/// Build a `/keys/claim` request for the given users' devices that lack an
/// Olm session, or `None` if there is nothing to claim. Delegates to the
/// session manager.
#[instrument(skip_all)]
pub async fn get_missing_sessions(
    &self,
    users: impl Iterator<Item = &UserId>,
) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
    let session_manager = &self.inner.session_manager;
    session_manager.get_missing_sessions(users).await
}
/// Hand a `/keys/query` response to the identity manager and return the
/// resulting device and identity changes.
async fn receive_keys_query_response(
    &self,
    request_id: &TransactionId,
    response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
    let identity_manager = &self.inner.identity_manager;
    identity_manager.receive_keys_query_response(request_id, response).await
}
/// Build a `/keys/upload` request from the account's pending keys.
///
/// Returns `None` when there are no device keys, one-time keys or fallback
/// keys to upload.
async fn keys_for_upload(&self, account: &Account) -> Option<UploadKeysRequest> {
    let (device_keys, one_time_keys, fallback_keys) = account.keys_for_upload();

    let nothing_to_upload =
        device_keys.is_none() && one_time_keys.is_empty() && fallback_keys.is_empty();
    if nothing_to_upload {
        return None;
    }

    let device_keys = device_keys.map(|keys| keys.to_raw());
    Some(assign!(UploadKeysRequest::new(), { device_keys, one_time_keys, fallback_keys }))
}
/// Decrypt an Olm-encrypted to-device event with the account from
/// `transaction`, then process the decrypted content.
///
/// Any store changes produced while handling the content are recorded in
/// `changes`.
async fn decrypt_to_device_event(
    &self,
    transaction: &mut StoreTransaction,
    event: &EncryptedToDeviceEvent,
    changes: &mut Changes,
) -> OlmResult<OlmDecryptionInfo> {
    let mut decrypted = {
        let account = transaction.account().await?;
        account.decrypt_to_device_event(&self.inner.store, event).await?
    };

    self.handle_decrypted_to_device_event(transaction.cache(), &mut decrypted, changes).await?;
    Ok(decrypted)
}
/// Turn a received Megolm room key into an inbound group session.
///
/// Returns `Ok(None)` if the session key is invalid, or if the store
/// already holds a session that is not worse than this one. Otherwise the
/// new session is returned so the caller can persist it.
#[instrument(
    skip_all,
    fields(room_id = ?content.room_id, session_id)
)]
async fn handle_key(
    &self,
    sender_key: Curve25519PublicKey,
    event: &DecryptedRoomKeyEvent,
    content: &MegolmV1AesSha2Content,
) -> OlmResult<Option<InboundGroupSession>> {
    let created = InboundGroupSession::new(
        sender_key,
        event.keys.ed25519,
        &content.room_id,
        &content.session_key,
        event.content.algorithm(),
        None,
    );

    let session = match created {
        Ok(session) => session,
        Err(e) => {
            Span::current().record("session_id", &content.session_id);
            warn!("Received a room key event which contained an invalid session key: {e}");
            return Ok(None);
        }
    };

    Span::current().record("session_id", session.session_id());

    let ordering = self.store().compare_group_session(&session).await?;
    if ordering != SessionOrdering::Better {
        warn!(
            "Received a megolm room key that we already have a better version of, \
             discarding",
        );
        return Ok(None);
    }

    info!("Received a new megolm room key");
    Ok(Some(session))
}
/// Dispatch a decrypted `m.room_key` event on its algorithm.
///
/// Keys with an unsupported algorithm are logged and ignored (`Ok(None)`).
#[instrument(skip_all, fields(algorithm = ?event.content.algorithm()))]
async fn add_room_key(
    &self,
    sender_key: Curve25519PublicKey,
    event: &DecryptedRoomKeyEvent,
) -> OlmResult<Option<InboundGroupSession>> {
    match &event.content {
        RoomKeyContent::Unknown(_) => {
            warn!("Received a room key with an unsupported algorithm");
            Ok(None)
        }
        RoomKeyContent::MegolmV1AesSha2(megolm) => {
            self.handle_key(sender_key, event, megolm).await
        }
        #[cfg(feature = "experimental-algorithms")]
        RoomKeyContent::MegolmV2AesSha2(megolm) => {
            self.handle_key(sender_key, event, megolm).await
        }
    }
}
/// Record a Megolm v1 `m.room_key.withheld` event in `changes`, keyed by
/// room ID and then session ID.
///
/// Only the `BlackListed` and `Unverified` withheld codes are recorded; all
/// other withheld contents are ignored.
async fn add_withheld_info(&self, changes: &mut Changes, event: &RoomKeyWithheldEvent) {
    match &event.content {
        RoomKeyWithheldContent::MegolmV1AesSha2(
            MegolmV1AesSha2WithheldContent::BlackListed(withheld)
            | MegolmV1AesSha2WithheldContent::Unverified(withheld),
        ) => {
            let per_room = changes
                .withheld_session_info
                .entry(withheld.room_id.to_owned())
                .or_default();
            per_room.insert(withheld.session_id.to_owned(), event.to_owned());
        }
        _ => {}
    }
}
/// Test helper: create an outbound group session for `room_id` with default
/// settings and save its inbound counterpart to the store.
#[cfg(test)]
pub(crate) async fn create_outbound_group_session_with_defaults_test_helper(
    &self,
    room_id: &RoomId,
) -> OlmResult<()> {
    let settings = EncryptionSettings::default();
    let group_sessions = &self.inner.group_session_manager;
    let (_, inbound) = group_sessions.create_outbound_group_session(room_id, settings).await?;

    self.store().save_inbound_group_sessions(&[inbound]).await?;
    Ok(())
}
/// Test helper: create an outbound group session for `room_id` with default
/// settings and return its inbound counterpart without saving it.
#[cfg(test)]
// NOTE(review): presumably unused in some test configurations, hence the allow.
#[allow(dead_code)]
pub(crate) async fn create_inbound_session_test_helper(
    &self,
    room_id: &RoomId,
) -> OlmResult<InboundGroupSession> {
    let settings = EncryptionSettings::default();
    let group_sessions = &self.inner.group_session_manager;
    let (_, inbound) = group_sessions.create_outbound_group_session(room_id, settings).await?;
    Ok(inbound)
}
pub async fn encrypt_room_event(
&self,
room_id: &RoomId,
content: impl MessageLikeEventContent,
) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
let event_type = content.event_type().to_string();
let content = Raw::new(&content)?.cast();
self.encrypt_room_event_raw(room_id, &event_type, &content).await
}
/// Encrypt raw event content of type `event_type` for `room_id`, using the
/// group session manager.
pub async fn encrypt_room_event_raw(
    &self,
    room_id: &RoomId,
    event_type: &str,
    content: &Raw<AnyMessageLikeEventContent>,
) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
    let group_sessions = &self.inner.group_session_manager;
    group_sessions.encrypt(room_id, event_type, content).await
}
/// Invalidate the current outbound group session for `room_id`. Returns the
/// group session manager's result.
pub async fn discard_room_key(&self, room_id: &RoomId) -> StoreResult<bool> {
    let group_sessions = &self.inner.group_session_manager;
    group_sessions.invalidate_group_session(room_id).await
}
/// Build the to-device requests that share the room key for `room_id` with
/// `users`, according to `encryption_settings`.
pub async fn share_room_key(
    &self,
    room_id: &RoomId,
    users: impl Iterator<Item = &UserId>,
    encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
    let group_sessions = &self.inner.group_session_manager;
    group_sessions.share_room_key(room_id, users, encryption_settings).await
}
/// Deprecated alias of [`Self::receive_verification_event`].
#[deprecated(note = "Use OlmMachine::receive_verification_event instead", since = "0.7.0")]
pub async fn receive_unencrypted_verification_event(
    &self,
    event: &AnyMessageLikeEvent,
) -> StoreResult<()> {
    self.receive_verification_event(event).await
}
/// Pass an in-room verification event to the verification machine.
pub async fn receive_verification_event(&self, event: &AnyMessageLikeEvent) -> StoreResult<()> {
    let verification = &self.inner.verification_machine;
    verification.receive_any_event(event).await
}
/// Act on the content of a successfully Olm-decrypted to-device event.
///
/// Room keys and forwarded room keys set `decrypted.inbound_group_session`.
/// Secrets are handed to the gossip machine. The raw event is then rewritten
/// to carry the secret name that the gossip machine returned. Dummy and
/// custom events are only logged.
#[instrument(
    skip_all,
    fields(
        sender_key = ?decrypted.result.sender_key,
        event_type = decrypted.result.event.event_type(),
    ),
)]
async fn handle_decrypted_to_device_event(
    &self,
    cache: &StoreCache,
    decrypted: &mut OlmDecryptionInfo,
    changes: &mut Changes,
) -> OlmResult<()> {
    debug!("Received a decrypted to-device event");

    let sender_key = decrypted.result.sender_key;

    match &*decrypted.result.event {
        AnyDecryptedOlmEvent::Dummy(_) => {
            debug!("Received an `m.dummy` event");
        }
        AnyDecryptedOlmEvent::Custom(_) => {
            warn!("Received an unexpected encrypted to-device event");
        }
        AnyDecryptedOlmEvent::RoomKey(room_key) => {
            decrypted.inbound_group_session = self.add_room_key(sender_key, room_key).await?;
        }
        AnyDecryptedOlmEvent::ForwardedRoomKey(forwarded) => {
            decrypted.inbound_group_session = self
                .inner
                .key_request_machine
                .receive_forwarded_room_key(sender_key, forwarded)
                .await?;
        }
        AnyDecryptedOlmEvent::SecretSend(secret) => {
            let secret_name = self
                .inner
                .key_request_machine
                .receive_secret_event(cache, sender_key, secret, changes)
                .await?;

            // Re-serialize the raw event with the secret name from the gossip machine.
            if let Ok(ToDeviceEvents::SecretSend(mut rewritten)) =
                decrypted.result.raw_event.deserialize_as()
            {
                rewritten.content.secret_name = secret_name;
                decrypted.result.raw_event = Raw::from_json(to_raw_value(&rewritten)?);
            }
        }
    }

    Ok(())
}
async fn handle_verification_event(&self, event: &ToDeviceEvents) {
if let Err(e) = self.inner.verification_machine.receive_any_event(event).await {
error!("Error handling a verification event: {e:?}");
}
}
/// Tell every component that may own the to-device request `request_id`
/// that it has been sent.
///
/// The components are notified in this order: the verification machine, the
/// gossip machine, the group session manager, then the session manager.
async fn mark_to_device_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
    let inner = &self.inner;
    inner.verification_machine.mark_request_as_sent(request_id);
    inner.key_request_machine.mark_outgoing_request_as_sent(request_id).await?;
    inner.group_session_manager.mark_request_as_sent(request_id).await?;
    inner.session_manager.mark_outgoing_request_as_sent(request_id);
    Ok(())
}
/// Look up a verification flow by user and flow ID.
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
    let verification = &self.inner.verification_machine;
    verification.get_verification(user_id, flow_id)
}
/// Look up a verification request by user and flow ID.
pub fn get_verification_request(
    &self,
    user_id: &UserId,
    flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
    let verification = &self.inner.verification_machine;
    verification.get_request(user_id, flow_id)
}
/// All verification requests the verification machine knows for `user_id`.
pub fn get_verification_requests(&self, user_id: &UserId) -> Vec<VerificationRequest> {
    let verification = &self.inner.verification_machine;
    verification.get_requests(user_id)
}
async fn handle_to_device_event(&self, changes: &mut Changes, event: &ToDeviceEvents) {
use crate::types::events::ToDeviceEvents::*;
match event {
RoomKeyRequest(e) => self.inner.key_request_machine.receive_incoming_key_request(e),
SecretRequest(e) => self.inner.key_request_machine.receive_incoming_secret_request(e),
RoomKeyWithheld(e) => self.add_withheld_info(changes, e).await,
KeyVerificationAccept(..)
| KeyVerificationCancel(..)
| KeyVerificationKey(..)
| KeyVerificationMac(..)
| KeyVerificationRequest(..)
| KeyVerificationReady(..)
| KeyVerificationDone(..)
| KeyVerificationStart(..) => {
self.handle_verification_event(event).await;
}
Dummy(_) | RoomKey(_) | ForwardedRoomKey(_) | RoomEncrypted(_) => {}
_ => {}
}
}
/// Record the sender, event type and `org.matrix.msgid` of a raw to-device
/// event on the current tracing span.
///
/// Events that don't match the minimal shape are silently skipped.
fn record_message_id(event: &Raw<AnyToDeviceEvent>) {
    use serde::Deserialize;

    // Minimal view of a to-device event: only the fields we log, borrowed
    // from the raw JSON.
    #[derive(Deserialize)]
    struct ToDeviceStub<'a> {
        sender: &'a str,
        #[serde(rename = "type")]
        event_type: &'a str,
        #[serde(borrow)]
        content: ContentStub<'a>,
    }

    #[derive(Deserialize)]
    struct ContentStub<'a> {
        #[serde(borrow, rename = "org.matrix.msgid")]
        message_id: Option<&'a str>,
    }

    if let Ok(stub) = event.deserialize_as::<ToDeviceStub<'_>>() {
        let span = Span::current();
        span.record("sender", stub.sender);
        span.record("event_type", stub.event_type);
        span.record("message_id", stub.content.message_id);
    }
}
#[instrument(skip_all, fields(sender, event_type, message_id))]
async fn receive_to_device_event(
&self,
transaction: &mut StoreTransaction,
changes: &mut Changes,
mut raw_event: Raw<AnyToDeviceEvent>,
) -> Raw<AnyToDeviceEvent> {
Self::record_message_id(&raw_event);
let event: ToDeviceEvents = match raw_event.deserialize_as() {
Ok(e) => e,
Err(e) => {
warn!("Received an invalid to-device event: {e}");
return raw_event;
}
};
debug!("Received a to-device event");
match event {
ToDeviceEvents::RoomEncrypted(e) => {
let decrypted = match self.decrypt_to_device_event(transaction, &e, changes).await {
Ok(e) => e,
Err(err) => {
if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
.inner
.session_manager
.mark_device_as_wedged(&sender, curve_key)
.await
{
error!(
error = ?e,
"Couldn't mark device from to be unwedged",
);
}
}
return raw_event;
}
};
match decrypted.session {
SessionType::New(s) | SessionType::Existing(s) => {
changes.sessions.push(s);
}
}
changes.message_hashes.push(decrypted.message_hash);
if let Some(group_session) = decrypted.inbound_group_session {
changes.inbound_group_sessions.push(group_session);
}
match decrypted.result.raw_event.deserialize_as() {
Ok(event) => {
self.handle_to_device_event(changes, &event).await;
raw_event = event
.serialize_zeroized()
.expect("Zeroizing and reserializing our events should always work")
.cast();
}
Err(e) => {
warn!("Received an invalid encrypted to-device event: {e}");
raw_event = decrypted.result.raw_event;
}
}
}
e => self.handle_to_device_event(changes, &e).await,
}
raw_event
}
/// Process the encryption-related parts of a sync response.
///
/// All changes are saved to the store, then the store transaction is
/// committed. Returns the processed to-device events (decrypted where
/// possible) and info about every inbound group session received.
#[instrument(skip_all)]
pub async fn receive_sync_changes(
    &self,
    sync_changes: EncryptionSyncChanges<'_>,
) -> OlmResult<(Vec<Raw<AnyToDeviceEvent>>, Vec<RoomKeyInfo>)> {
    let mut transaction = self.inner.store.transaction().await;
    let (to_device_events, changes) =
        self.preprocess_sync_changes(&mut transaction, sync_changes).await?;

    let room_key_updates =
        changes.inbound_group_sessions.iter().map(RoomKeyInfo::from).collect::<Vec<_>>();

    self.store().save_changes(changes).await?;
    transaction.commit().await?;

    Ok((to_device_events, room_key_updates))
}
pub(crate) async fn preprocess_sync_changes(
&self,
transaction: &mut StoreTransaction,
sync_changes: EncryptionSyncChanges<'_>,
) -> OlmResult<(Vec<Raw<AnyToDeviceEvent>>, Changes)> {
let mut events = self.inner.verification_machine.garbage_collect();
let mut changes = Default::default();
{
let account = transaction.account().await?;
account.update_key_counts(
sync_changes.one_time_keys_counts,
sync_changes.unused_fallback_keys,
)
}
if let Err(e) = self
.inner
.identity_manager
.receive_device_changes(
transaction.cache(),
sync_changes.changed_devices.changed.iter().map(|u| u.as_ref()),
)
.await
{
error!(error = ?e, "Error marking a tracked user as changed");
}
for raw_event in sync_changes.to_device_events {
let raw_event =
Box::pin(self.receive_to_device_event(transaction, &mut changes, raw_event)).await;
events.push(raw_event);
}
let changed_sessions = self
.inner
.key_request_machine
.collect_incoming_key_requests(transaction.cache())
.await?;
changes.sessions.extend(changed_sessions);
changes.next_batch_token = sync_changes.next_batch_token;
Ok((events, changes))
}
/// Build a request for the room key that `event` was encrypted with.
///
/// Returns an optional request plus a mandatory one, as produced by the
/// gossip machine.
pub async fn request_room_key(
    &self,
    event: &Raw<EncryptedEvent>,
    room_id: &RoomId,
) -> MegolmResult<(Option<OutgoingRequest>, OutgoingRequest)> {
    let encrypted = event.deserialize()?;
    let gossip = &self.inner.key_request_machine;
    gossip.request_key(room_id, &encrypted).await
}
/// Work out how trustworthy a Megolm session is, based on the sender device
/// that owns the session's Curve25519 key.
///
/// Returns the verification state and, when a device with that key was
/// found, the device's ID.
async fn get_verification_state(
    &self,
    session: &InboundGroupSession,
    sender: &UserId,
) -> MegolmResult<(VerificationState, Option<OwnedDeviceId>)> {
    let sender_devices = self.get_user_devices(sender, None).await?;
    let claimed_device =
        sender_devices.devices().find(|d| d.curve25519_key() == Some(session.sender_key()));

    match claimed_device {
        None => {
            let link_problem = if session.has_been_imported() {
                DeviceLinkProblem::InsecureSource
            } else {
                DeviceLinkProblem::MissingDevice
            };
            Ok((VerificationState::Unverified(VerificationLevel::None(link_problem)), None))
        }
        Some(device) => {
            let device_id = device.device_id().to_owned();

            // `None` here means every check passed and the session is verified.
            let unverified_level = if !device.is_owner_of_session(session)? {
                Some(VerificationLevel::None(DeviceLinkProblem::InsecureSource))
            } else if !device.is_cross_signed_by_owner() {
                Some(VerificationLevel::UnsignedDevice)
            } else if !device.is_device_owner_verified() {
                Some(VerificationLevel::UnverifiedIdentity)
            } else {
                None
            };

            let state = match unverified_level {
                Some(level) => VerificationState::Unverified(level),
                None => VerificationState::Verified,
            };
            Ok((state, Some(device_id)))
        }
    }
}
/// Queue secret requests for secrets we are missing.
///
/// The missing set is the cross-signing secrets the identity reports, plus
/// the recovery key when no backup decryption key is stored. Requests that
/// already exist among the unsent secret requests are not queued again.
/// Returns `true` if new requests were saved. The identity lock is held for
/// the whole operation.
pub async fn query_missing_secrets_from_other_sessions(&self) -> StoreResult<bool> {
    let identity = self.inner.user_identity.lock().await;
    let mut missing = identity.get_missing_secrets().await;

    if self.store().load_backup_keys().await?.decryption_key.is_none() {
        missing.push(SecretName::RecoveryKey);
    }

    if missing.is_empty() {
        debug!("No missing requests to query");
        return Ok(false);
    }

    let candidates = GossipMachine::request_missing_secrets(self.user_id(), missing);
    let already_pending = self.store().get_unsent_secret_requests().await?;
    let not_yet_requested: Vec<_> = candidates
        .into_iter()
        .filter(|candidate| !already_pending.iter().any(|pending| pending.info == candidate.info))
        .collect();

    if not_yet_requested.is_empty() {
        debug!("The missing secrets have already been requested");
        return Ok(false);
    }

    debug!("Requesting missing secrets");
    self.store()
        .save_changes(Changes { key_requests: not_yet_requested, ..Default::default() })
        .await?;
    Ok(true)
}
/// Build the [`EncryptionInfo`] for an event from `sender` that was
/// encrypted with `session`.
async fn get_encryption_info(
    &self,
    session: &InboundGroupSession,
    sender: &UserId,
) -> MegolmResult<EncryptionInfo> {
    let (verification_state, sender_device) =
        self.get_verification_state(session, sender).await?;

    let algorithm_info = AlgorithmInfo::MegolmV1AesSha2 {
        curve25519_key: session.sender_key().to_base64(),
        sender_claimed_keys: session
            .signing_keys()
            .iter()
            .map(|(algorithm, key)| (algorithm.to_owned(), key.to_base64()))
            .collect(),
    };

    Ok(EncryptionInfo {
        sender: sender.to_owned(),
        sender_device,
        algorithm_info,
        verification_state,
    })
}
/// Look up the inbound group session referenced by `content` and build the
/// encryption info for `event` from it.
async fn get_megolm_encryption_info(
    &self,
    room_id: &RoomId,
    event: &EncryptedEvent,
    content: &SupportedEventEncryptionSchemes<'_>,
) -> MegolmResult<EncryptionInfo> {
    let session_id = content.session_id();
    let session = self.get_inbound_group_session_or_error(room_id, session_id).await?;
    self.get_encryption_info(&session, &event.sender).await
}
async fn decrypt_megolm_events(
&self,
room_id: &RoomId,
event: &EncryptedEvent,
content: &SupportedEventEncryptionSchemes<'_>,
) -> MegolmResult<TimelineEvent> {
let session =
self.get_inbound_group_session_or_error(room_id, content.session_id()).await?;
Span::current().record("sender_key", debug(session.sender_key()));
let result = session.decrypt(event).await;
match result {
Ok((decrypted_event, _)) => {
let encryption_info = self.get_encryption_info(&session, &event.sender).await?;
Ok(TimelineEvent {
encryption_info: Some(encryption_info),
event: decrypted_event,
push_actions: None,
})
}
Err(error) => Err(
if let MegolmError::Decryption(DecryptionError::UnknownMessageIndex(_, _)) = error {
let withheld_code = self
.inner
.store
.get_withheld_info(room_id, content.session_id())
.await?
.map(|e| e.content.withheld_code());
if withheld_code.is_some() {
MegolmError::MissingRoomKey(withheld_code)
} else {
error
}
} else {
error
},
),
}
}
/// Load an inbound group session.
///
/// # Errors
///
/// Returns [`MegolmError::MissingRoomKey`] when the session is not in the
/// store. The error carries the withheld code if withheld info is stored
/// for the session.
async fn get_inbound_group_session_or_error(
    &self,
    room_id: &RoomId,
    session_id: &str,
) -> MegolmResult<InboundGroupSession> {
    if let Some(session) = self.store().get_inbound_group_session(room_id, session_id).await? {
        return Ok(session);
    }

    let withheld_code = self
        .inner
        .store
        .get_withheld_info(room_id, session_id)
        .await?
        .map(|e| e.content.withheld_code());
    Err(MegolmError::MissingRoomKey(withheld_code))
}
/// Decrypt an encrypted room event received in `room_id`.
///
/// Event metadata is recorded on the tracing span. With the
/// `automatic-room-key-forwarding` feature, a missing key or an unknown
/// message index also queues an outgoing key request.
///
/// # Errors
///
/// Returns [`EventError::UnsupportedAlgorithm`] for unknown schemes, any
/// decryption error, or an error from creating the key request.
#[instrument(skip_all, fields(?room_id, event_id, origin_server_ts, sender, algorithm, session_id, sender_key))]
pub async fn decrypt_room_event(
    &self,
    event: &Raw<EncryptedEvent>,
    room_id: &RoomId,
) -> MegolmResult<TimelineEvent> {
    let event = event.deserialize()?;

    let origin_server_ts = timestamp_to_iso8601(event.origin_server_ts)
        .unwrap_or_else(|| "<out of range>".to_owned());
    Span::current()
        .record("sender", debug(&event.sender))
        .record("event_id", debug(&event.event_id))
        .record("origin_server_ts", origin_server_ts)
        .record("algorithm", debug(event.content.algorithm()));

    let content: SupportedEventEncryptionSchemes<'_> = match &event.content.scheme {
        RoomEventEncryptionScheme::MegolmV1AesSha2(scheme) => {
            Span::current().record("sender_key", debug(scheme.sender_key));
            scheme.into()
        }
        #[cfg(feature = "experimental-algorithms")]
        RoomEventEncryptionScheme::MegolmV2AesSha2(scheme) => scheme.into(),
        RoomEventEncryptionScheme::Unknown(_) => {
            warn!("Received an encrypted room event with an unsupported algorithm");
            return Err(EventError::UnsupportedAlgorithm.into());
        }
    };
    Span::current().record("session_id", content.session_id());

    let result = self.decrypt_megolm_events(room_id, &event, &content).await;

    if let Err(e) = &result {
        // Queue a key request only for errors that mean we lack (part of) the session.
        #[cfg(feature = "automatic-room-key-forwarding")]
        if matches!(
            e,
            MegolmError::MissingRoomKey(_)
                | MegolmError::Decryption(DecryptionError::UnknownMessageIndex(_, _))
        ) {
            self.inner.key_request_machine.create_outgoing_key_request(room_id, &event).await?;
        }
        warn!("Failed to decrypt a room event: {e}");
    }

    result
}
/// Whether the store holds the inbound group session `session_id` for `room_id`.
pub async fn is_room_key_available(
    &self,
    room_id: &RoomId,
    session_id: &str,
) -> Result<bool, CryptoStoreError> {
    let session = self.store().get_inbound_group_session(room_id, session_id).await?;
    Ok(session.is_some())
}
/// Get the encryption info for an encrypted room event without decrypting it.
///
/// # Errors
///
/// Returns [`EventError::UnsupportedAlgorithm`] for unknown schemes, and
/// propagates deserialization and session-lookup errors.
pub async fn get_room_event_encryption_info(
    &self,
    event: &Raw<EncryptedEvent>,
    room_id: &RoomId,
) -> MegolmResult<EncryptionInfo> {
    let event = event.deserialize()?;

    let content: SupportedEventEncryptionSchemes<'_> = match &event.content.scheme {
        RoomEventEncryptionScheme::Unknown(_) => {
            return Err(EventError::UnsupportedAlgorithm.into());
        }
        RoomEventEncryptionScheme::MegolmV1AesSha2(scheme) => scheme.into(),
        #[cfg(feature = "experimental-algorithms")]
        RoomEventEncryptionScheme::MegolmV2AesSha2(scheme) => scheme.into(),
    };

    self.get_megolm_encryption_info(room_id, &event, &content).await
}
/// Forward `users` to the identity manager's `update_tracked_users`.
pub async fn update_tracked_users(
    &self,
    users: impl IntoIterator<Item = &UserId>,
) -> StoreResult<()> {
    let identity_manager = &self.inner.identity_manager;
    identity_manager.update_tracked_users(users).await
}
/// When a `timeout` is given, wait up to that long for a pending key query
/// for `user_id` to complete. Without a timeout this returns immediately.
async fn wait_if_user_pending(
    &self,
    user_id: &UserId,
    timeout: Option<Duration>,
) -> StoreResult<()> {
    let timeout = match timeout {
        Some(limit) => limit,
        None => return Ok(()),
    };

    let cache = self.store().cache().await?;
    let key_query_manager = &self.inner.identity_manager.key_query_manager;
    key_query_manager.wait_if_user_key_query_pending(cache, timeout, user_id).await?;
    Ok(())
}
/// Fetch the device `device_id` of `user_id` from the store.
///
/// If `timeout` is set, first wait (up to that long) for a pending key query
/// for the user.
#[instrument(skip(self))]
pub async fn get_device(
    &self,
    user_id: &UserId,
    device_id: &DeviceId,
    timeout: Option<Duration>,
) -> StoreResult<Option<Device>> {
    self.wait_if_user_pending(user_id, timeout).await?;
    let store = self.store();
    store.get_device(user_id, device_id).await
}
/// Fetch the cross-signing identity of `user_id` from the store.
///
/// If `timeout` is set, first wait (up to that long) for a pending key query
/// for the user.
#[instrument(skip(self))]
pub async fn get_identity(
    &self,
    user_id: &UserId,
    timeout: Option<Duration>,
) -> StoreResult<Option<UserIdentities>> {
    self.wait_if_user_pending(user_id, timeout).await?;
    let store = self.store();
    store.get_identity(user_id).await
}
/// Fetch all known devices of `user_id` from the store.
///
/// If `timeout` is set, first wait (up to that long) for a pending key query
/// for the user.
#[instrument(skip(self))]
pub async fn get_user_devices(
    &self,
    user_id: &UserId,
    timeout: Option<Duration>,
) -> StoreResult<UserDevices> {
    self.wait_if_user_pending(user_id, timeout).await?;
    let store = self.store();
    store.get_user_devices(user_id).await
}
/// Import exported room keys into the store.
///
/// Deprecated thin wrapper around the store's `import_room_keys`; all
/// arguments are forwarded unchanged.
#[deprecated(
    since = "0.7.0",
    note = "Use the OlmMachine::store::import_exported_room_keys method instead"
)]
pub async fn import_room_keys(
    &self,
    exported_keys: Vec<ExportedRoomKey>,
    from_backup: bool,
    progress_listener: impl Fn(usize, usize),
) -> StoreResult<RoomKeyImportResult> {
    let store = self.store();
    store.import_room_keys(exported_keys, from_backup, progress_listener).await
}
/// Report the status of our private cross-signing identity, as returned by
/// its `status()` method while the identity lock is held.
pub async fn cross_signing_status(&self) -> CrossSigningStatus {
    let identity = self.inner.user_identity.lock().await;
    identity.status().await
}
/// Export the private cross-signing keys (master, self-signing, user-signing)
/// held in the store.
///
/// Returns `Ok(None)` when none of the three keys is available.
///
/// # Errors
///
/// Propagates store errors raised while exporting any of the secrets.
pub async fn export_cross_signing_keys(&self) -> StoreResult<Option<CrossSigningKeyExport>> {
    let store = self.store();
    let master_key = store.export_secret(&SecretName::CrossSigningMasterKey).await?;
    let self_signing_key = store.export_secret(&SecretName::CrossSigningSelfSigningKey).await?;
    let user_signing_key = store.export_secret(&SecretName::CrossSigningUserSigningKey).await?;

    let export =
        if master_key.is_some() || self_signing_key.is_some() || user_signing_key.is_some() {
            Some(CrossSigningKeyExport { master_key, self_signing_key, user_signing_key })
        } else {
            None
        };

    Ok(export)
}
/// Import previously exported private cross-signing keys via the store and
/// return the resulting cross-signing status.
pub async fn import_cross_signing_keys(
    &self,
    export: CrossSigningKeyExport,
) -> Result<CrossSigningStatus, SecretImportError> {
    let store = self.store();
    store.import_cross_signing_keys(export).await
}
/// Sign `message` with the private cross-signing master key.
///
/// # Errors
///
/// Returns [`SignatureError::MissingSigningKey`] when no master key id is
/// available, or whatever error the signing operation produces.
async fn sign_with_master_key(
    &self,
    message: &str,
) -> Result<(OwnedDeviceKeyId, Ed25519Signature), SignatureError> {
    // The identity stays locked for both the key-id lookup and the signing.
    let identity = self.inner.user_identity.lock().await;

    let master_key_id =
        identity.master_key_id().await.ok_or(SignatureError::MissingSigningKey)?;
    let master_signature = identity.sign(message).await?;

    Ok((master_key_id, master_signature))
}
/// Sign `message` with the account's signing key and, if possible, with the
/// cross-signing master key.
///
/// A failure to sign with the master key is logged and otherwise ignored; the
/// account signature is always included.
///
/// # Errors
///
/// Fails if the store cache or the account cannot be loaded.
pub async fn sign(&self, message: &str) -> Result<Signatures, CryptoStoreError> {
    let mut signatures = Signatures::new();

    // Keep the cache/account access scoped to this block.
    let (account_key_id, account_signature) = {
        let cache = self.inner.store.cache().await?;
        let account = cache.account().await?;
        (account.signing_key_id(), account.sign(message))
    };
    signatures.add_signature(self.user_id().to_owned(), account_key_id, account_signature);

    match self.sign_with_master_key(message).await {
        Err(e) => {
            warn!(error = ?e, "Couldn't sign the message using the cross signing master key")
        }
        Ok((master_key_id, master_signature)) => {
            signatures.add_signature(self.user_id().to_owned(), master_key_id, master_signature);
        }
    }

    Ok(signatures)
}
/// Borrow the [`BackupMachine`] owned by this machine.
pub fn backup_machine(&self) -> &BackupMachine {
    let inner = &self.inner;
    &inner.backup_machine
}
/// Initialize the crypto-store generation counter.
///
/// Reads the counter persisted under `CURRENT_GENERATION_STORE_KEY`. It uses 0
/// when absent and the stored value plus one (wrapping) otherwise. The result
/// is written back to the store and then remembered in `generation`.
///
/// # Errors
///
/// Fails with `CryptoStoreError::InvalidLockGeneration` if the stored value is
/// not exactly 8 bytes, or with any error raised by the store.
pub async fn initialize_crypto_store_generation(
    &self,
    generation: &Mutex<Option<u64>>,
) -> StoreResult<()> {
    // The in-memory counter stays locked for the whole read-increment-write sequence.
    let mut known_generation = generation.lock().await;

    let stored = self.inner.store.get_custom_value(Self::CURRENT_GENERATION_STORE_KEY).await?;

    let next_generation = if let Some(bytes) = stored {
        let bytes: [u8; 8] = bytes.try_into().map_err(|_| {
            CryptoStoreError::InvalidLockGeneration("invalid format".to_owned())
        })?;
        u64::from_le_bytes(bytes).wrapping_add(1)
    } else {
        0
    };

    tracing::debug!("Initialising crypto store generation at {}", next_generation);

    self.inner
        .store
        .set_custom_value(
            Self::CURRENT_GENERATION_STORE_KEY,
            next_generation.to_le_bytes().to_vec(),
        )
        .await?;

    *known_generation = Some(next_generation);
    Ok(())
}
/// Compare the crypto-store generation counter persisted in the store with the
/// one remembered in `generation`, and bump it when they disagree.
///
/// Returns `(false, generation)` when both values match. Otherwise it computes
/// a new generation:
/// - with a remembered value, one past the larger of the two;
/// - with nothing remembered, one past the stored value.
///
/// The new generation is written to the store, then remembered in
/// `generation`, and `(true, new_generation)` is returned.
///
/// # Errors
///
/// Fails with [`CryptoStoreError::InvalidLockGeneration`] when the counter is
/// missing from the store or is not exactly 8 bytes, or with any error raised
/// by the store. On error the remembered generation is left unchanged.
pub async fn maintain_crypto_store_generation(
    &self,
    generation: &Mutex<Option<u64>>,
) -> StoreResult<(bool, u64)> {
    let mut known_generation = generation.lock().await;

    let stored = self
        .inner
        .store
        .get_custom_value(Self::CURRENT_GENERATION_STORE_KEY)
        .await?
        .ok_or_else(|| {
            CryptoStoreError::InvalidLockGeneration("counter missing in store".to_owned())
        })?;
    let stored: [u8; 8] = stored.try_into().map_err(|_| {
        CryptoStoreError::InvalidLockGeneration("invalid format".to_owned())
    })?;
    let actual_generation = u64::from_le_bytes(stored);

    let new_generation = match *known_generation {
        Some(expected) if expected == actual_generation => {
            return Ok((false, actual_generation));
        }
        // Increment the larger of the two values.
        Some(expected) => actual_generation.max(expected).wrapping_add(1),
        // Nothing remembered yet: increment the stored value.
        None => actual_generation.wrapping_add(1),
    };

    tracing::debug!(
        "Crypto store generation mismatch: previously known was {:?}, actual is {:?}, next is {}",
        *known_generation,
        actual_generation,
        new_generation
    );

    // Persist first, so the remembered value never runs ahead of the store
    // when the write fails (same order as `initialize_crypto_store_generation`).
    self.inner
        .store
        .set_custom_value(
            Self::CURRENT_GENERATION_STORE_KEY,
            new_generation.to_le_bytes().to_vec(),
        )
        .await?;
    *known_generation = Some(new_generation);

    Ok((true, new_generation))
}
/// Build a [`DehydratedDevices`] handle holding a clone of this machine.
pub fn dehydrated_devices(&self) -> DehydratedDevices {
    let inner = self.clone();
    DehydratedDevices { inner }
}
/// Read the stored [`RoomSettings`] for `room_id`, if any.
pub async fn room_settings(&self, room_id: &RoomId) -> StoreResult<Option<RoomSettings>> {
    let store = &self.inner.store;
    store.get_room_settings(room_id).await
}
/// Store `new_settings` for `room_id`.
///
/// If settings already exist, nothing changes. The call returns `Ok(())` when
/// they are equal to `new_settings` and
/// [`SetRoomSettingsError::EncryptionDowngrade`] otherwise. New settings are
/// only accepted with a Megolm v1 algorithm, or Megolm v2 when the
/// `experimental-algorithms` feature is enabled.
///
/// # Errors
///
/// `EncryptionDowngrade` or `InvalidSettings` as described above, or any
/// store error.
pub async fn set_room_settings(
    &self,
    room_id: &RoomId,
    new_settings: &RoomSettings,
) -> Result<(), SetRoomSettingsError> {
    let store = &self.inner.store;

    // Never committed: the transaction is kept alive until return only as a guard.
    // NOTE(review): presumably this serializes concurrent calls — confirm.
    let _store_transaction = store.transaction().await;

    if let Some(existing) = store.get_room_settings(room_id).await? {
        return if existing == *new_settings {
            Ok(())
        } else {
            Err(SetRoomSettingsError::EncryptionDowngrade)
        };
    }

    match new_settings.algorithm {
        EventEncryptionAlgorithm::MegolmV1AesSha2 => {}
        #[cfg(feature = "experimental-algorithms")]
        EventEncryptionAlgorithm::MegolmV2AesSha2 => {}
        _ => {
            warn!(?room_id, "Rejecting invalid encryption algorithm {}", new_settings.algorithm);
            return Err(SetRoomSettingsError::InvalidSettings);
        }
    }

    let changes = Changes {
        room_settings: HashMap::from([(room_id.to_owned(), new_settings.clone())]),
        ..Default::default()
    };
    store.save_changes(changes).await?;
    Ok(())
}
/// Return `true` when `self` and `other` share the same inner state
/// (pointer equality of the inner `Arc`).
#[cfg(any(feature = "testing", test))]
pub fn same_as(&self, other: &OlmMachine) -> bool {
    let (mine, theirs) = (&self.inner, &other.inner);
    Arc::ptr_eq(mine, theirs)
}
/// Report the account's `uploaded_key_count()` (test/testing builds only).
#[cfg(any(feature = "testing", test))]
pub async fn uploaded_key_count(&self) -> Result<u64, CryptoStoreError> {
    let cache = self.inner.store.cache().await?;
    let count = cache.account().await?.uploaded_key_count();
    Ok(count)
}
}
/// The requests that together make up a cross-signing bootstrap.
///
/// NOTE(review): presumably meant to be sent in field order (keys, signing
/// keys, signatures) — confirm against the code that builds this struct.
#[derive(Debug)]
pub struct CrossSigningBootstrapRequests {
/// Optional outgoing request; `None` when there is no such request to send.
/// NOTE(review): presumably a `/keys/upload` of this device's keys — confirm.
pub upload_keys_req: Option<OutgoingRequest>,
/// Request uploading the cross-signing (signing) keys.
pub upload_signing_keys_req: UploadSigningKeysRequest,
/// Request uploading signatures.
pub upload_signatures_req: UploadSignaturesRequest,
}
/// Encryption-related data taken from a sync response and handed to
/// `OlmMachine::receive_sync_changes`.
#[derive(Debug)]
pub struct EncryptionSyncChanges<'a> {
/// Raw to-device events received in the sync.
pub to_device_events: Vec<Raw<AnyToDeviceEvent>>,
/// Device list changes reported by the sync.
pub changed_devices: &'a DeviceLists,
/// Server-reported one-time key counts, per key algorithm.
pub one_time_keys_counts: &'a BTreeMap<DeviceKeyAlgorithm, UInt>,
/// Algorithms of fallback keys reported as unused, if the server sent them.
pub unused_fallback_keys: Option<&'a [DeviceKeyAlgorithm]>,
/// The sync's next-batch token, if any.
pub next_batch_token: Option<String>,
}
/// Helpers shared by the tests of this crate (and the `testing` feature).
#[cfg(any(feature = "testing", test))]
// Not every feature/test combination uses every helper.
#[allow(dead_code)]
pub(crate) mod testing {
    use http::Response;

    /// Wrap `json` in an HTTP `200` response whose body is the serialized JSON,
    /// ready to be parsed with ruma's `try_from_http_response`.
    pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
        Response::builder()
            .status(200)
            // `into_bytes` reuses the string's buffer instead of copying it.
            .body(json.to_string().into_bytes())
            .expect("Building a 200 response with a byte body should never fail")
    }
}
#[cfg(test)]
pub(crate) mod tests {
use std::{
collections::BTreeMap,
iter,
sync::Arc,
time::{Duration, SystemTime},
};
use assert_matches::assert_matches;
use assert_matches2::assert_let;
use futures_util::{FutureExt, StreamExt};
use itertools::Itertools;
use matrix_sdk_common::deserialized_responses::{
DeviceLinkProblem, ShieldState, VerificationLevel, VerificationState,
};
use matrix_sdk_test::{async_test, message_like_event_content, test_json};
use ruma::{
api::{
client::{
keys::{
claim_keys, get_keys, get_keys::v3::Response as KeyQueryResponse, upload_keys,
},
sync::sync_events::DeviceLists,
to_device::send_event_to_device::v3::Response as ToDeviceResponse,
},
IncomingResponse,
},
device_id,
encryption::OneTimeKey,
events::{
dummy::ToDeviceDummyEventContent,
key::verification::VerificationMethod,
room::message::{MessageType, RoomMessageEventContent},
AnyMessageLikeEvent, AnyMessageLikeEventContent, AnyTimelineEvent, AnyToDeviceEvent,
MessageLikeEvent, OriginalMessageLikeEvent,
},
room_id,
serde::Raw,
to_device::DeviceIdOrAllDevices,
uint, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch,
OwnedDeviceKeyId, SecondsSinceUnixEpoch, TransactionId, UserId,
};
use serde_json::{json, value::to_raw_value};
use vodozemac::{
megolm::{GroupSession, SessionConfig},
Curve25519PublicKey, Ed25519PublicKey,
};
use super::{testing::response_from_file, CrossSigningBootstrapRequests};
use crate::{
error::{EventError, SetRoomSettingsError},
machine::{EncryptionSyncChanges, OlmMachine},
olm::{
BackedUpRoomKey, ExportedRoomKey, InboundGroupSession, OutboundGroupSession, VerifyJson,
},
store::{BackupDecryptionKey, Changes, CryptoStore, MemoryStore, RoomSettings},
types::{
events::{
room::encrypted::{EncryptedToDeviceEvent, ToDeviceEncryptedEventContent},
room_key_withheld::{RoomKeyWithheldContent, WithheldCode},
ToDeviceEvent,
},
CrossSigningKey, DeviceKeys, EventEncryptionAlgorithm, SignedKey, SigningKeys,
},
utilities::json_convert,
verification::tests::{bob_id, outgoing_request_to_event, request_to_event},
Account, EncryptionSettings, LocalTrust, MegolmError, OlmError, OutgoingRequests,
ReadOnlyDevice, ToDeviceRequest, UserIdentities,
};
/// Keys taken from a prepared upload request (one-time or fallback keys), keyed by key ID.
type OneTimeKeys = BTreeMap<OwnedDeviceKeyId, Raw<OneTimeKey>>;
/// User ID used for Alice's machine in these tests.
fn alice_id() -> &'static UserId {
user_id!("@alice:example.org")
}
/// Device ID used for Alice's machine in these tests.
fn alice_device_id() -> &'static DeviceId {
device_id!("JLAFKJWSCS")
}
/// Device ID used for Bob's machine in these tests.
fn bob_device_id() -> &'static DeviceId {
device_id!("NTHHPZDPRN")
}
/// Default user ID for the machine under test (a Bob user on example.com).
fn user_id() -> &'static UserId {
user_id!("@bob:example.com")
}
/// Parse the canned `KEYS_UPLOAD` fixture into a `/keys/upload` response.
fn keys_upload_response() -> upload_keys::v3::Response {
    let parsed = upload_keys::v3::Response::try_from_http_response(response_from_file(
        &test_json::KEYS_UPLOAD,
    ));
    parsed.expect("Can't parse the `/keys/upload` response")
}
/// Parse the canned `KEYS_QUERY` fixture into a `/keys/query` response.
fn keys_query_response() -> get_keys::v3::Response {
    let data = response_from_file(&test_json::KEYS_QUERY);
    get_keys::v3::Response::try_from_http_response(data)
        .expect("Can't parse the `/keys/query` response")
}
/// Deserialize the first message of the first request into encrypted
/// to-device content. Panics if any of these is missing.
pub fn to_device_requests_to_content(
    requests: Vec<Arc<ToDeviceRequest>>,
) -> ToDeviceEncryptedEventContent {
    let first_request = &requests[0];
    let messages_for_user = first_request.messages.values().next().unwrap();
    let message_for_device = messages_for_user.values().next().unwrap();
    message_for_device.deserialize_as().unwrap()
}
/// Create a machine for `user_id` on Bob's device ID with freshly generated
/// keys, and feed it the canned `/keys/upload` response.
///
/// Returns the machine and, from the prepared upload request, either its
/// fallback keys (`use_fallback_key == true`) or its one-time keys.
pub(crate) async fn get_prepared_machine_test_helper(
user_id: &UserId,
use_fallback_key: bool,
) -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(user_id, bob_device_id()).await;
// Generate keys and build the upload request inside one store transaction.
let request = machine
.store()
.with_transaction(|mut tr| async {
let account = tr.account().await.unwrap();
account.generate_fallback_key_if_needed();
// NOTE(review): resetting the count to 0 presumably makes the next call
// generate a full batch of one-time keys — confirm.
account.update_uploaded_key_count(0);
account.generate_one_time_keys_if_needed();
let request = machine
.keys_for_upload(account)
.await
.expect("Can't prepare initial key upload");
Ok((tr, request))
})
.await
.unwrap();
let response = keys_upload_response();
machine.receive_keys_upload_response(&response).await.unwrap();
let keys = if use_fallback_key { request.fallback_keys } else { request.one_time_keys };
(machine, keys)
}
/// Prepare a machine (one-time keys, not fallback keys) and feed it the canned
/// `/keys/query` response.
async fn get_machine_after_query_test_helper() -> (OlmMachine, OneTimeKeys) {
    let (machine, one_time_keys) = get_prepared_machine_test_helper(user_id(), false).await;
    let query_response = keys_query_response();
    machine
        .receive_keys_query_response(&TransactionId::new(), &query_response)
        .await
        .unwrap();
    (machine, one_time_keys)
}
/// Create an Alice machine and a prepared Bob machine that know each other's
/// devices. Also returns Bob's keys from the prepared upload request.
pub async fn get_machine_pair(
    alice: &UserId,
    bob: &UserId,
    use_fallback_key: bool,
) -> (OlmMachine, OlmMachine, OneTimeKeys) {
    let (bob_machine, one_time_keys) =
        get_prepared_machine_test_helper(bob, use_fallback_key).await;
    let alice_machine = OlmMachine::new(alice, alice_device_id()).await;

    // Save each machine's own device into the other machine's store.
    let alice_device = ReadOnlyDevice::from_machine_test_helper(&alice_machine).await.unwrap();
    let bob_device = ReadOnlyDevice::from_machine_test_helper(&bob_machine).await.unwrap();
    alice_machine.store().save_devices(&[bob_device]).await.unwrap();
    bob_machine.store().save_devices(&[alice_device]).await.unwrap();

    (alice_machine, bob_machine, one_time_keys)
}
async fn get_machine_pair_with_session(
alice: &UserId,
bob: &UserId,
use_fallback_key: bool,
) -> (OlmMachine, OlmMachine) {
let (alice, bob, mut one_time_keys) = get_machine_pair(alice, bob, use_fallback_key).await;
let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap();
let one_time_keys = BTreeMap::from([(
bob.user_id().to_owned(),
BTreeMap::from([(
bob.device_id().to_owned(),
BTreeMap::from([(device_key_id, one_time_key)]),
)]),
)]);
let response = claim_keys::v3::Response::new(one_time_keys);
alice.inner.session_manager.create_sessions(&response).await.unwrap();
(alice, bob)
}
/// Build an Alice/Bob pair where both sides hold the Olm session.
///
/// Alice encrypts an `m.dummy` event to Bob's device. Bob decrypts it, and
/// both machines save the resulting sessions to their stores.
pub(crate) async fn get_machine_pair_with_setup_sessions_test_helper(
alice: &UserId,
bob: &UserId,
use_fallback_key: bool,
) -> (OlmMachine, OlmMachine) {
let (alice, bob) = get_machine_pair_with_session(alice, bob, use_fallback_key).await;
let bob_device =
alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
let (session, content) =
bob_device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await.unwrap();
alice.store().save_sessions(&[session]).await.unwrap();
let event =
ToDeviceEvent::new(alice.user_id().to_owned(), content.deserialize_as().unwrap());
// Bob decrypting the message is what creates his side of the session.
let decrypted = bob
.store()
.with_transaction(|mut tr| async {
let res =
bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
Ok((tr, res))
})
.await
.unwrap();
bob.store().save_sessions(&[decrypted.session.session()]).await.unwrap();
(alice, bob)
}
/// A new machine reports a creation time within the test window, its account
/// is not yet marked as shared, and its own device is stored and locally trusted.
#[async_test]
async fn test_create_olm_machine() {
let test_start_ts = MilliSecondsSinceUnixEpoch::now();
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let device_creation_time = machine.device_creation_time();
assert!(device_creation_time <= MilliSecondsSinceUnixEpoch::now());
assert!(device_creation_time >= test_start_ts);
let cache = machine.store().cache().await.unwrap();
let account = cache.account().await.unwrap();
assert!(!account.shared());
let own_device = machine
.get_device(machine.user_id(), machine.device_id(), None)
.await
.unwrap()
.expect("We should always have our own device in the store");
assert!(own_device.is_locally_trusted(), "Our own device should always be locally trusted");
}
/// One-time keys are generated until the server reports 50 signed curve25519
/// keys, after which no more are generated.
#[async_test]
async fn test_generate_one_time_keys() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
// Fresh account: keys are generated.
machine
.store()
.with_transaction(|mut tr| async {
let account = tr.account().await.unwrap();
assert!(account.generate_one_time_keys_if_needed().is_some());
Ok((tr, ()))
})
.await
.unwrap();
let mut response = keys_upload_response();
machine.receive_keys_upload_response(&response).await.unwrap();
// After the canned upload response, keys are still generated.
machine
.store()
.with_transaction(|mut tr| async {
let account = tr.account().await.unwrap();
assert!(account.generate_one_time_keys_if_needed().is_some());
Ok((tr, ()))
})
.await
.unwrap();
// Report 50 uploaded signed curve25519 keys: no new keys are expected.
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap();
machine
.store()
.with_transaction(|mut tr| async {
let account = tr.account().await.unwrap();
assert!(account.generate_one_time_keys_if_needed().is_none());
Ok((tr, ()))
})
.await
.unwrap();
}
/// The account's device keys carry a valid signature from its own Ed25519 key.
#[async_test]
async fn test_device_key_signing() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
// Copy the keys out so the cache/account borrows end with this block.
let (device_keys, identity_keys) = {
let cache = machine.store().cache().await.unwrap();
let account = cache.account().await.unwrap();
let device_keys = account.device_keys();
let identity_keys = account.identity_keys();
(device_keys, identity_keys)
};
let ed25519_key = identity_keys.ed25519;
let ret = ed25519_key.verify_json(
machine.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
&device_keys,
);
ret.unwrap();
}
/// `discard_room_key` keeps the outbound group session but marks it invalidated.
#[async_test]
async fn test_session_invalidation() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let room_id = room_id!("!test:example.org");
machine.create_outbound_group_session_with_defaults_test_helper(room_id).await.unwrap();
assert!(machine.inner.group_session_manager.get_outbound_group_session(room_id).is_some());
machine.discard_room_key(room_id).await.unwrap();
assert!(machine
.inner
.group_session_manager
.get_outbound_group_session(room_id)
.unwrap()
.invalidated());
}
/// Verifying device keys against an unrelated (all-zero) Ed25519 key must fail.
#[test]
fn test_invalid_signature() {
let account = Account::with_device_id(user_id(), alice_device_id());
let device_keys = account.device_keys();
let key = Ed25519PublicKey::from_slice(&[0u8; 32]).unwrap();
let ret = key.verify_json(
account.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, account.device_id()),
&device_keys,
);
ret.unwrap_err();
}
/// Signed one-time keys verify against the account's own Ed25519 key.
#[test]
fn test_one_time_key_signing() {
    let mut account = Account::with_device_id(user_id(), alice_device_id());
    account.update_uploaded_key_count(49);
    account.generate_one_time_keys_if_needed();

    let one_time_keys = account.signed_one_time_keys();
    let ed25519_key = account.identity_keys().ed25519;

    // Deserializing only needs shared access, so no mutable binding/borrow is required.
    let one_time_key: SignedKey = one_time_keys
        .values()
        .next()
        .expect("One time keys should be generated")
        .deserialize_as()
        .unwrap();

    ed25519_key
        .verify_json(
            account.user_id(),
            &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, account.device_id()),
            &one_time_key,
        )
        .expect("One-time key has been signed successfully");
}
/// `keys_for_upload` produces correctly signed one-time and device keys.
/// Once the server reports the account's maximum number of one-time keys,
/// it produces nothing.
#[async_test]
async fn test_keys_for_upload() {
    let machine = OlmMachine::new(user_id(), alice_device_id()).await;

    let key_counts = BTreeMap::from([(DeviceKeyAlgorithm::SignedCurve25519, 49u8.into())]);
    machine
        .receive_sync_changes(EncryptionSyncChanges {
            to_device_events: Vec::new(),
            changed_devices: &Default::default(),
            one_time_keys_counts: &key_counts,
            unused_fallback_keys: None,
            next_batch_token: None,
        })
        .await
        .expect("We should be able to update our one-time key counts");

    let (ed25519_key, request) = {
        let cache = machine.store().cache().await.unwrap();
        let account = cache.account().await.unwrap();
        let ed25519_key = account.identity_keys().ed25519;
        let request =
            machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload");
        (ed25519_key, request)
    };

    // Deserializing only needs shared access, so `request` need not be mutable.
    let one_time_key: SignedKey = request
        .one_time_keys
        .values()
        .next()
        .expect("One time keys should be generated")
        .deserialize_as()
        .unwrap();
    let ret = ed25519_key.verify_json(
        machine.user_id(),
        &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
        &one_time_key,
    );
    ret.unwrap();

    let device_keys: DeviceKeys = request.device_keys.unwrap().deserialize_as().unwrap();
    let ret = ed25519_key.verify_json(
        machine.user_id(),
        &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
        &device_keys,
    );
    ret.unwrap();

    // Report the maximum number of one-time keys as uploaded.
    let response = {
        let cache = machine.store().cache().await.unwrap();
        let account = cache.account().await.unwrap();
        let mut response = keys_upload_response();
        response.one_time_key_counts.insert(
            DeviceKeyAlgorithm::SignedCurve25519,
            account.max_one_time_keys().try_into().unwrap(),
        );
        response
    };
    machine.receive_keys_upload_response(&response).await.unwrap();

    {
        let cache = machine.store().cache().await.unwrap();
        let account = cache.account().await.unwrap();
        let ret = machine.keys_for_upload(&account).await;
        assert!(ret.is_none());
    }
}
/// Alice's device is unknown before, and stored after, the canned `/keys/query` response.
#[async_test]
async fn test_keys_query() {
let (machine, _) = get_prepared_machine_test_helper(user_id(), false).await;
let response = keys_query_response();
let alice_id = user_id!("@alice:example.org");
let alice_device_id: &DeviceId = device_id!("JLAFKJWSCS");
let alice_devices = machine.store().get_user_devices(alice_id).await.unwrap();
assert!(alice_devices.devices().peekable().peek().is_none());
let req_id = TransactionId::new();
machine.receive_keys_query_response(&req_id, &response).await.unwrap();
let device = machine.store().get_device(alice_id, alice_device_id).await.unwrap().unwrap();
assert_eq!(device.user_id(), alice_id);
assert_eq!(device.device_id(), alice_device_id);
}
/// A key query built for Alice includes her user ID in `device_keys`.
#[async_test]
async fn test_query_keys_for_users() {
let (machine, _) = get_prepared_machine_test_helper(user_id(), false).await;
let alice_id = user_id!("@alice:example.org");
let (_, request) = machine.query_keys_for_users(vec![alice_id]);
assert!(request.device_keys.contains_key(alice_id));
}
/// After learning Alice's device via a key query, a one-time key claim is
/// requested for that device.
#[async_test]
async fn test_missing_sessions_calculation() {
let (machine, _) = get_machine_after_query_test_helper().await;
let alice = alice_id();
let alice_device = alice_device_id();
let (_, missing_sessions) =
machine.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
assert!(missing_sessions.one_time_keys.contains_key(alice));
let user_sessions = missing_sessions.one_time_keys.get(alice).unwrap();
assert!(user_sessions.contains_key(alice_device));
}
/// Feed `machine` a `/keys/claim` response that holds the single key
/// `key_id` → `one_time_key` for `user_id`'s `device_id`. This makes the
/// machine create an Olm session with that device.
pub async fn create_session(
    machine: &OlmMachine,
    user_id: &UserId,
    device_id: &DeviceId,
    key_id: OwnedDeviceKeyId,
    one_time_key: Raw<OneTimeKey>,
) {
    let keys_for_device = BTreeMap::from([(key_id, one_time_key)]);
    let keys_for_user = BTreeMap::from([(device_id.to_owned(), keys_for_device)]);
    let claimed_keys = BTreeMap::from([(user_id.to_owned(), keys_for_user)]);

    let response = claim_keys::v3::Response::new(claimed_keys);
    machine.inner.session_manager.create_sessions(&response).await.unwrap();
}
/// Claiming one of Bob's keys leaves Alice with a non-empty session list for
/// Bob's curve25519 identity key.
#[async_test]
async fn test_session_creation() {
let (alice_machine, bob_machine, mut one_time_keys) =
get_machine_pair(alice_id(), user_id(), false).await;
let (key_id, one_time_key) = one_time_keys.pop_first().unwrap();
create_session(
&alice_machine,
bob_machine.user_id(),
bob_machine.device_id(),
key_id,
one_time_key,
)
.await;
let session = alice_machine
.store()
.get_sessions(
&bob_machine.store().static_account().identity_keys().curve25519.to_base64(),
)
.await
.unwrap()
.unwrap();
assert!(!session.lock().await.is_empty())
}
/// `get_most_recent_session` returns the session with the newest creation time.
///
/// `create_session` is called eleven times. The creation times of every stored
/// session except the first and the last are then rewritten to increase in
/// list order, so the second-to-last stored session is the newest.
#[async_test]
async fn test_getting_most_recent_session() {
    let (alice_machine, bob_machine, mut one_time_keys) =
        get_machine_pair(alice_id(), user_id(), false).await;
    let (key_id, one_time_key) = one_time_keys.pop_first().unwrap();
    let device = alice_machine
        .get_device(bob_machine.user_id(), bob_machine.device_id(), None)
        .await
        .unwrap()
        .unwrap();
    assert!(device.get_most_recent_session().await.unwrap().is_none());

    // `pop_first` yields owned keys, so they are passed on without cloning.
    create_session(
        &alice_machine,
        bob_machine.user_id(),
        bob_machine.device_id(),
        key_id,
        one_time_key,
    )
    .await;
    for _ in 0..10 {
        let (key_id, one_time_key) = one_time_keys.pop_first().unwrap();
        create_session(
            &alice_machine,
            bob_machine.user_id(),
            bob_machine.device_id(),
            key_id,
            one_time_key,
        )
        .await;
    }

    let session_id = {
        let sessions = alice_machine
            .store()
            .get_sessions(&bob_machine.identity_keys().curve25519.to_base64())
            .await
            .unwrap()
            .unwrap();
        let mut use_time = SystemTime::now();
        let mut sessions = sessions.lock().await;
        let mut session_id = None;
        // Leave the last session untouched; skip the first one as well.
        let (_, sessions_slice) = sessions.as_mut_slice().split_last_mut().unwrap();
        for session in sessions_slice.iter_mut().skip(1) {
            session.creation_time = SecondsSinceUnixEpoch::from_system_time(use_time).unwrap();
            use_time += Duration::from_secs(10);
            session_id = Some(session.session_id().to_owned());
        }
        session_id.unwrap()
    };

    let newest_session = device.get_most_recent_session().await.unwrap().unwrap();
    assert_eq!(
        newest_session.session_id(),
        session_id,
        "The session we found is the one that was most recently created"
    );
}
/// Alice encrypts an `m.dummy` to-device event for Bob, who must decrypt it
/// once and must reject a replay of the same event.
async fn olm_encryption_test_helper(use_fallback_key: bool) {
let (alice, bob) =
get_machine_pair_with_session(alice_id(), user_id(), use_fallback_key).await;
let bob_device =
alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
let (_, content) = bob_device
.encrypt("m.dummy", ToDeviceDummyEventContent::new())
.await
.expect("We should be able to encrypt a dummy event.");
let event = ToDeviceEvent::new(
alice.user_id().to_owned(),
content
.deserialize_as()
.expect("We should be able to deserialize the encrypted content"),
);
let decrypted = bob
.store()
.with_transaction(|mut tr| async {
let res =
bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
Ok((tr, res))
})
.await
.expect("We should be able to decrypt the event.")
.result
.raw_event
.deserialize()
.expect("We should be able to deserialize the decrypted event.");
assert_let!(AnyToDeviceEvent::Dummy(decrypted) = decrypted);
assert_eq!(&decrypted.sender, alice.user_id());
// Decrypting the very same event a second time must fail.
bob.store()
.with_transaction(|mut tr| async {
let res =
bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
Ok((tr, res))
})
.await
.expect_err(
"Decrypting a replayed event should not succeed, even if it's a pre-key message",
);
}
/// Olm round trip using a session created from one of Bob's one-time keys.
#[async_test]
async fn test_olm_encryption() {
olm_encryption_test_helper(false).await;
}
/// Olm round trip using a session created from Bob's fallback key.
#[async_test]
async fn test_olm_encryption_with_fallback_key() {
olm_encryption_test_helper(true).await;
}
/// A room key shared by Alice arrives at Bob through `receive_sync_changes`.
/// It shows up as a decrypted `m.room_key` event, is stored as an inbound
/// group session, and is reported once in the room-key updates.
#[async_test]
async fn test_room_key_sharing() {
let (alice, bob) = get_machine_pair_with_session(alice_id(), user_id(), false).await;
let room_id = room_id!("!test:example.org");
let to_device_requests = alice
.share_room_key(room_id, iter::once(bob.user_id()), EncryptionSettings::default())
.await
.unwrap();
let event = ToDeviceEvent::new(
alice.user_id().to_owned(),
to_device_requests_to_content(to_device_requests),
);
let event = json_convert(&event).unwrap();
let alice_session =
alice.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
let (decrypted, room_key_updates) = bob
.receive_sync_changes(EncryptionSyncChanges {
to_device_events: vec![event],
changed_devices: &Default::default(),
one_time_keys_counts: &Default::default(),
unused_fallback_keys: None,
next_batch_token: None,
})
.await
.unwrap();
let event = decrypted[0].deserialize().unwrap();
if let AnyToDeviceEvent::RoomKey(event) = event {
assert_eq!(&event.sender, alice.user_id());
// NOTE(review): the returned event's session key is expected to be empty,
// presumably cleared once the key was imported — confirm.
assert!(event.content.session_key.is_empty());
} else {
panic!("expected RoomKeyEvent found {event:?}");
}
let session =
bob.store().get_inbound_group_session(room_id, alice_session.session_id()).await;
assert!(session.unwrap().is_some());
assert_eq!(room_key_updates.len(), 1);
assert_eq!(room_key_updates[0].room_id, room_id);
assert_eq!(room_key_updates[0].session_id, alice_session.session_id());
}
/// Without cross-signing, Alice queues four `m.secret.request` to-device
/// requests. A second query call reports nothing left to request.
#[async_test]
async fn test_request_missing_secrets() {
let (alice, _) = get_machine_pair_with_session(alice_id(), bob_id(), false).await;
let should_query_secrets = alice.query_missing_secrets_from_other_sessions().await.unwrap();
assert!(should_query_secrets);
let outgoing_to_device = alice
.outgoing_requests()
.await
.unwrap()
.into_iter()
.filter(|outgoing| match outgoing.request.as_ref() {
OutgoingRequests::ToDeviceRequest(request) => {
request.event_type.to_string() == "m.secret.request"
}
_ => false,
})
.collect_vec();
assert_eq!(outgoing_to_device.len(), 4);
let should_query_secrets_now =
alice.query_missing_secrets_from_other_sessions().await.unwrap();
assert!(!should_query_secrets_now);
}
/// With cross-signing set up, Alice queues a single `m.secret.request`
/// to-device request. A second query call reports nothing left to request.
#[async_test]
async fn test_request_missing_secrets_cross_signed() {
let (alice, bob) = get_machine_pair_with_session(alice_id(), bob_id(), false).await;
setup_cross_signing_for_machine_test_helper(&alice, &bob).await;
let should_query_secrets = alice.query_missing_secrets_from_other_sessions().await.unwrap();
assert!(should_query_secrets);
let outgoing_to_device = alice
.outgoing_requests()
.await
.unwrap()
.into_iter()
.filter(|outgoing| match outgoing.request.as_ref() {
OutgoingRequests::ToDeviceRequest(request) => {
request.event_type.to_string() == "m.secret.request"
}
_ => false,
})
.collect_vec();
assert_eq!(outgoing_to_device.len(), 1);
let should_query_secrets_now =
alice.query_missing_secrets_from_other_sessions().await.unwrap();
assert!(!should_query_secrets_now);
}
/// Full Megolm round trip.
///
/// 1. Bob receives Alice's room key over Olm and saves the inbound session,
///    which emits exactly one room-key stream update.
/// 2. Bob decrypts Alice's encrypted `m.room.message`.
/// 3. No further stream update appears.
#[async_test]
async fn test_megolm_encryption() {
let (alice, bob) =
get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
let room_id = room_id!("!test:example.org");
let to_device_requests = alice
.share_room_key(room_id, iter::once(bob.user_id()), EncryptionSettings::default())
.await
.unwrap();
let event = ToDeviceEvent::new(
alice.user_id().to_owned(),
to_device_requests_to_content(to_device_requests),
);
// Subscribe before saving the session so the update is observed.
let mut room_keys_received_stream = Box::pin(bob.store().room_keys_received_stream());
let group_session = bob
.store()
.with_transaction(|mut tr| async {
let res =
bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
Ok((tr, res))
})
.await
.unwrap()
.inbound_group_session
.unwrap();
bob.store().save_inbound_group_sessions(&[group_session.clone()]).await.unwrap();
let room_keys = room_keys_received_stream
.next()
.now_or_never()
.flatten()
.expect("We should have received an update of room key infos");
assert_eq!(room_keys.len(), 1);
assert_eq!(room_keys[0].session_id, group_session.session_id());
let plaintext = "It is a secret to everybody";
let content = RoomMessageEventContent::text_plain(plaintext);
let encrypted_content = alice
.encrypt_room_event(room_id, AnyMessageLikeEventContent::RoomMessage(content.clone()))
.await
.unwrap();
// Wrap the encrypted content in a minimal room event as a client would receive it.
let event = json!({
"event_id": "$xxxxx:example.org",
"origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
"sender": alice.user_id(),
"type": "m.room.encrypted",
"content": encrypted_content,
});
let event = json_convert(&event).unwrap();
let decrypted_event =
bob.decrypt_room_event(&event, room_id).await.unwrap().event.deserialize().unwrap();
if let AnyTimelineEvent::MessageLike(AnyMessageLikeEvent::RoomMessage(
MessageLikeEvent::Original(OriginalMessageLikeEvent { sender, content, .. }),
)) = decrypted_event
{
assert_eq!(&sender, alice.user_id());
if let MessageType::Text(c) = &content.msgtype {
assert_eq!(&c.body, plaintext);
} else {
panic!("Decrypted event has a mismatched content");
}
} else {
panic!("Decrypted room event has the wrong type");
}
// Decrypting must not have produced another room-key update.
if let Some(igs) = room_keys_received_stream.next().now_or_never() {
panic!("Session stream unexpectedly returned update: {igs:?}");
}
}
/// When Alice only shares with trusted devices, Bob (unverified) receives a
/// withheld notice instead of the room key. Decryption must then fail with
/// `MissingRoomKey(Some(WithheldCode::Unverified))`.
#[async_test]
async fn test_withheld_unverified() {
    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    let room_id = room_id!("!test:example.org");

    let encryption_settings =
        EncryptionSettings { only_allow_trusted_devices: true, ..Default::default() };
    let to_device_requests = alice
        .share_room_key(room_id, iter::once(bob.user_id()), encryption_settings)
        .await
        .expect("Share room key should be ok");

    // The single message sent to Bob is a withheld notice.
    let wh_content = to_device_requests[0]
        .messages
        .values()
        .next()
        .unwrap()
        .values()
        .next()
        .unwrap()
        .deserialize_as::<RoomKeyWithheldContent>()
        .expect("Deserialize should work");
    let event = ToDeviceEvent::new(alice.user_id().to_owned(), wh_content);
    let event = json_convert(&event).unwrap();
    bob.receive_sync_changes(EncryptionSyncChanges {
        to_device_events: vec![event],
        changed_devices: &Default::default(),
        one_time_keys_counts: &Default::default(),
        unused_fallback_keys: None,
        next_batch_token: None,
    })
    .await
    .unwrap();

    let plaintext = "You shouldn't be able to decrypt that message";
    let content = RoomMessageEventContent::text_plain(plaintext);
    let content = alice
        .encrypt_room_event(room_id, AnyMessageLikeEventContent::RoomMessage(content))
        .await
        .unwrap();
    let room_event = json!({
        "event_id": "$xxxxx:example.org",
        "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
        "sender": alice.user_id(),
        "type": "m.room.encrypted",
        "content": content,
    });
    let room_event = json_convert(&room_event).unwrap();

    let decrypt_result = bob.decrypt_room_event(&room_event, room_id).await;
    // One precise check covers both "it is MissingRoomKey" and "with the Unverified code".
    assert_matches!(
        decrypt_result,
        Err(MegolmError::MissingRoomKey(Some(WithheldCode::Unverified)))
    );
}
// Walks Bob's view of Alice's device through every trust level and checks the
// verification state (and the derived lax/strict shield states) reported for
// one Megolm-encrypted event at each step.
#[async_test]
async fn test_decryption_verification_state() {
    // Assert the lax and strict shield states derived from
    // `$foo.verification_state`.
    macro_rules! assert_shield {
        ($foo: ident, $strict: ident, $lax: ident) => {
            let lax = $foo.verification_state.to_shield_state_lax();
            let strict = $foo.verification_state.to_shield_state_strict();
            assert_matches!(lax, ShieldState::$lax { .. });
            assert_matches!(strict, ShieldState::$strict { .. });
        };
    }
    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    let room_id = room_id!("!test:example.org");
    let to_device_requests = alice
        .share_room_key(room_id, iter::once(bob.user_id()), EncryptionSettings::default())
        .await
        .unwrap();
    let event = ToDeviceEvent::new(
        alice.user_id().to_owned(),
        to_device_requests_to_content(to_device_requests),
    );
    // Decrypt the Olm-encrypted room key to obtain Bob's inbound group session.
    let group_session = bob
        .store()
        .with_transaction(|mut tr| async {
            let res =
                bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
            Ok((tr, res))
        })
        .await
        .unwrap()
        .inbound_group_session;
    // Keep an export of the session; re-importing it at the end of the test
    // changes the reported verification level (asserted below).
    let export = group_session.as_ref().unwrap().clone().export().await;
    bob.store().save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
    let plaintext = "It is a secret to everybody";
    let content = RoomMessageEventContent::text_plain(plaintext);
    let encrypted_content = alice
        .encrypt_room_event(room_id, AnyMessageLikeEventContent::RoomMessage(content.clone()))
        .await
        .unwrap();
    let event = json!({
        "event_id": "$xxxxx:example.org",
        "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
        "sender": alice.user_id(),
        "type": "m.room.encrypted",
        "content": encrypted_content,
    });
    let event = json_convert(&event).unwrap();
    // No cross-signing yet: Alice's device is reported as unsigned.
    let encryption_info =
        bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap();
    assert_eq!(
        VerificationState::Unverified(VerificationLevel::UnsignedDevice),
        encryption_info.verification_state
    );
    assert_shield!(encryption_info, Red, Red);
    // The encryption-info lookup must agree with the decryption result.
    let encryption_info = bob.get_room_event_encryption_info(&event, room_id).await.unwrap();
    assert_eq!(
        VerificationState::Unverified(VerificationLevel::UnsignedDevice),
        encryption_info.verification_state
    );
    assert_shield!(encryption_info, Red, Red);
    // Locally trusting Alice's device does not change the reported state.
    bob.get_device(alice.user_id(), alice_device_id(), None)
        .await
        .unwrap()
        .unwrap()
        .set_trust_state(LocalTrust::Verified);
    let encryption_info = bob.get_room_event_encryption_info(&event, room_id).await.unwrap();
    assert_eq!(
        VerificationState::Unverified(VerificationLevel::UnsignedDevice),
        encryption_info.verification_state
    );
    assert_shield!(encryption_info, Red, Red);
    // Both users now know each other's cross-signing identity, but Alice's
    // device still reports as unsigned.
    setup_cross_signing_for_machine_test_helper(&alice, &bob).await;
    let bob_id_from_alice = alice.get_identity(bob.user_id(), None).await.unwrap();
    assert_matches!(bob_id_from_alice, Some(UserIdentities::Other(_)));
    let alice_id_from_bob = bob.get_identity(alice.user_id(), None).await.unwrap();
    assert_matches!(alice_id_from_bob, Some(UserIdentities::Other(_)));
    let encryption_info = bob.get_room_event_encryption_info(&event, room_id).await.unwrap();
    assert_eq!(
        VerificationState::Unverified(VerificationLevel::UnsignedDevice),
        encryption_info.verification_state
    );
    assert_shield!(encryption_info, Red, Red);
    // Once Alice's device is signed by her identity, the problem becomes the
    // unverified identity, and the lax shield clears.
    sign_alice_device_for_machine_test_helper(&alice, &bob).await;
    let encryption_info = bob.get_room_event_encryption_info(&event, room_id).await.unwrap();
    assert_eq!(
        VerificationState::Unverified(VerificationLevel::UnverifiedIdentity),
        encryption_info.verification_state
    );
    assert_shield!(encryption_info, Red, None);
    // Bob verifies Alice's identity: fully verified, no shields.
    mark_alice_identity_as_verified_test_helper(&alice, &bob).await;
    let encryption_info = bob.get_room_event_encryption_info(&event, room_id).await.unwrap();
    assert_eq!(VerificationState::Verified, encryption_info.verification_state);
    assert_shield!(encryption_info, None, None);
    // Overwriting the session with one re-imported from an export downgrades
    // the state to `InsecureSource`, even though Alice is verified.
    let imported = InboundGroupSession::from_export(&export).unwrap();
    bob.store().save_inbound_group_sessions(&[imported]).await.unwrap();
    let encryption_info = bob.get_room_event_encryption_info(&event, room_id).await.unwrap();
    assert_eq!(
        VerificationState::Unverified(VerificationLevel::None(
            DeviceLinkProblem::InsecureSource
        )),
        encryption_info.verification_state
    );
    assert_shield!(encryption_info, Red, Grey);
}
#[async_test]
async fn test_decrypt_unencrypted_event() {
let (bob, _) = get_prepared_machine_test_helper(user_id(), false).await;
let room_id = room_id!("!test:example.org");
let event = json!({
"event_id": "$xxxxx:example.org",
"origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
"sender": user_id(),
"type": "m.room.encrypted",
"content": RoomMessageEventContent::text_plain("plain"),
});
let event = json_convert(&event).unwrap();
assert_matches!(
bob.decrypt_room_event(&event, room_id).await,
Err(MegolmError::JsonError(..))
);
assert_matches!(
bob.get_room_event_encryption_info(&event, room_id).await,
Err(MegolmError::JsonError(..))
);
}
/// Bootstrap cross-signing for both `alice` and `bob`, then deliver to both
/// machines a synthetic `/keys/query` response containing each user's current
/// device keys and public master, user-signing and self-signing keys, so each
/// side learns the other's cross-signing identity.
async fn setup_cross_signing_for_machine_test_helper(alice: &OlmMachine, bob: &OlmMachine) {
    let CrossSigningBootstrapRequests { upload_signing_keys_req: alice_upload_signing, .. } =
        alice.bootstrap_cross_signing(false).await.expect("Expect Alice x-signing key request");
    let CrossSigningBootstrapRequests { upload_signing_keys_req: bob_upload_signing, .. } =
        bob.bootstrap_cross_signing(false).await.expect("Expect Bob x-signing key request");
    let bob_device_keys = bob
        .get_device(bob.user_id(), bob.device_id(), None)
        .await
        .unwrap()
        .unwrap()
        .as_device_keys()
        .to_owned();
    let alice_device_keys = alice
        .get_device(alice.user_id(), alice.device_id(), None)
        .await
        .unwrap()
        .unwrap()
        .as_device_keys()
        .to_owned();
    let json = json!({
        "device_keys": {
            bob.user_id() : { bob.device_id() : bob_device_keys},
            alice.user_id() : { alice.device_id(): alice_device_keys }
        },
        "failures": {},
        "master_keys": {
            bob.user_id() : bob_upload_signing.master_key.unwrap(),
            alice.user_id() : alice_upload_signing.master_key.unwrap()
        },
        "user_signing_keys": {
            bob.user_id() : bob_upload_signing.user_signing_key.unwrap(),
            alice.user_id() : alice_upload_signing.user_signing_key.unwrap()
        },
        "self_signing_keys": {
            bob.user_id() : bob_upload_signing.self_signing_key.unwrap(),
            alice.user_id() : alice_upload_signing.self_signing_key.unwrap()
        },
    }
    );
    // The fixture is a `/keys/query` response, not a `/keys/upload` one.
    let kq_response = KeyQueryResponse::try_from_http_response(response_from_file(&json))
        .expect("Can't parse the `/keys/query` response");
    alice.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap();
    bob.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap();
}
/// Make `alice`'s device appear signed by her self-signing key.
///
/// The self-signing signature is taken from the signature-upload request that
/// `alice.bootstrap_cross_signing` produces. It is attached to Alice's device
/// keys, and the result is delivered to both machines as a `/keys/query`
/// response.
async fn sign_alice_device_for_machine_test_helper(alice: &OlmMachine, bob: &OlmMachine) {
    let CrossSigningBootstrapRequests {
        upload_signing_keys_req: upload_signing,
        upload_signatures_req: upload_signature,
        ..
    } = alice.bootstrap_cross_signing(false).await.expect("Expect Alice x-signing key request");
    let mut device_keys = alice
        .get_device(alice.user_id(), alice.device_id(), None)
        .await
        .unwrap()
        .unwrap()
        .as_device_keys()
        .to_owned();
    // The first signed object for Alice in the upload request is her device
    // keys, carrying the new self-signing signature.
    let raw_extracted = upload_signature
        .signed_keys
        .get(alice.user_id())
        .unwrap()
        .iter()
        .next()
        .unwrap()
        .1
        .get();
    let new_signature: DeviceKeys = serde_json::from_str(raw_extracted).unwrap();
    let self_sign_key_id = upload_signing
        .self_signing_key
        .as_ref()
        .unwrap()
        .get_first_key_and_id()
        .unwrap()
        .0
        .to_owned();
    // The key id is cloned for the second argument because it is still borrowed
    // by the third.
    device_keys.signatures.add_signature(
        alice.user_id().to_owned(),
        self_sign_key_id.to_owned(),
        new_signature.signatures.get_signature(alice.user_id(), &self_sign_key_id).unwrap(),
    );
    let updated_keys_with_x_signing = json!({ device_keys.device_id.to_string(): device_keys });
    let json = json!({
        "device_keys": {
            alice.user_id() : updated_keys_with_x_signing
        },
        "failures": {},
        "master_keys": {
            alice.user_id() : upload_signing.master_key.unwrap(),
        },
        "user_signing_keys": {
            alice.user_id() : upload_signing.user_signing_key.unwrap(),
        },
        "self_signing_keys": {
            alice.user_id() : upload_signing.self_signing_key.unwrap(),
        },
    }
    );
    let kq_response = KeyQueryResponse::try_from_http_response(response_from_file(&json))
        .expect("Can't parse the `/keys/query` response");
    alice.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap();
    bob.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap();
}
/// Make `bob` treat `alice`'s cross-signing identity as verified.
///
/// Bob signs Alice's master key with his user-signing key (via `verify()` on
/// his view of Alice's identity). That signature is attached to Alice's master
/// key, and a `/keys/query` response carrying the updated key is delivered to
/// both machines.
///
/// # Panics
///
/// Panics if Bob does not consider Alice's identity verified afterwards.
async fn mark_alice_identity_as_verified_test_helper(alice: &OlmMachine, bob: &OlmMachine) {
    let alice_device =
        bob.get_device(alice.user_id(), alice.device_id(), None).await.unwrap().unwrap();
    let alice_identity =
        bob.get_identity(alice.user_id(), None).await.unwrap().unwrap().other().unwrap();
    let upload_request = alice_identity.verify().await.unwrap();
    let raw_extracted =
        upload_request.signed_keys.get(alice.user_id()).unwrap().iter().next().unwrap().1.get();
    let new_signature: CrossSigningKey = serde_json::from_str(raw_extracted).unwrap();
    // Bob's user-signing key id, under which his signature on Alice's master key
    // was made.
    let user_key_id = bob
        .bootstrap_cross_signing(false)
        .await
        .expect("Expect Bob x-signing key request")
        .upload_signing_keys_req
        .user_signing_key
        .unwrap()
        .get_first_key_and_id()
        .unwrap()
        .0
        .to_owned();
    let mut alice_updated_msk =
        alice_device.device_owner_identity.as_ref().unwrap().master_key().as_ref().to_owned();
    alice_updated_msk.signatures.add_signature(
        bob.user_id().to_owned(),
        user_key_id.to_owned(),
        new_signature.signatures.get_signature(bob.user_id(), &user_key_id).unwrap(),
    );
    let alice_x_keys = alice
        .bootstrap_cross_signing(false)
        .await
        .expect("Expect Alice x-signing key request")
        .upload_signing_keys_req;
    let json = json!({
        "device_keys": {
            alice.user_id() : { alice.device_id(): alice_device.as_device_keys().to_owned() }
        },
        "failures": {},
        "master_keys": {
            alice.user_id() : alice_updated_msk,
        },
        "user_signing_keys": {
            alice.user_id() : alice_x_keys.user_signing_key.unwrap(),
        },
        "self_signing_keys": {
            alice.user_id() : alice_x_keys.self_signing_key.unwrap(),
        },
    }
    );
    let kq_response = KeyQueryResponse::try_from_http_response(response_from_file(&json))
        .expect("Can't parse the `/keys/query` response");
    alice.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap();
    bob.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap();
    assert!(bob
        .get_identity(alice.user_id(), None)
        .await
        .unwrap()
        .unwrap()
        .other()
        .unwrap()
        .is_verified());
}
// Bob receives a `/keys/query` fixture describing two devices of another user.
// Inbound sessions claiming to come from each of those devices must report
// different verification levels.
#[async_test]
async fn test_verification_states_multiple_device() {
    let (bob, _) = get_prepared_machine_test_helper(user_id(), false).await;
    let other_user_id = user_id!("@web2:localhost:8482");
    let data = response_from_file(&test_json::KEYS_QUERY_TWO_DEVICES_ONE_SIGNED);
    let response = get_keys::v3::Response::try_from_http_response(data)
        .expect("Can't parse the `/keys/query` response");
    let (device_change, identity_change) =
        bob.receive_keys_query_response(&TransactionId::new(), &response).await.unwrap();
    assert_eq!(device_change.new.len(), 2);
    assert_eq!(identity_change.new.len(), 1);
    let devices = bob.store().get_user_devices(other_user_id).await.unwrap();
    assert_eq!(devices.devices().count(), 2);

    // Both inbound sessions below share this Megolm session key and differ only
    // in the sender keys they claim.
    let fake_room_id = room_id!("!roomid:example.com");
    let id_keys = bob.identity_keys();
    let fake_device_id = bob.device_id().into();
    let session_key = OutboundGroupSession::new(
        fake_device_id,
        Arc::new(id_keys),
        fake_room_id,
        EncryptionSettings::default(),
    )
    .unwrap()
    .session_key()
    .await;

    // Session claiming the keys of the device that is not signed.
    let web_unverified_inbound_session = InboundGroupSession::new(
        Curve25519PublicKey::from_base64("LTpv2DGMhggPAXO02+7f68CNEp6A40F0Yl8B094Y8gc")
            .unwrap(),
        Ed25519PublicKey::from_base64("loz5i40dP+azDtWvsD0L/xpnCjNkmrcvtXVXzCHX8Vw").unwrap(),
        fake_room_id,
        &session_key,
        EventEncryptionAlgorithm::MegolmV1AesSha2,
        None,
    )
    .unwrap();
    let (state, _) = bob
        .get_verification_state(&web_unverified_inbound_session, other_user_id)
        .await
        .unwrap();
    assert_eq!(VerificationState::Unverified(VerificationLevel::UnsignedDevice), state);

    // Session claiming the keys of the signed device: only the identity itself
    // remains unverified.
    let web_signed_inbound_session = InboundGroupSession::new(
        Curve25519PublicKey::from_base64("XJixbpnfIk+RqcK5T6moqVY9d9Q1veR8WjjSlNiQNT0")
            .unwrap(),
        Ed25519PublicKey::from_base64("48f3WQAMGwYLBg5M5qUhqnEVA8yeibjZpPsShoWMFT8").unwrap(),
        fake_room_id,
        &session_key,
        EventEncryptionAlgorithm::MegolmV1AesSha2,
        None,
    )
    .unwrap();
    let (state, _) =
        bob.get_verification_state(&web_signed_inbound_session, other_user_id).await.unwrap();
    assert_eq!(VerificationState::Unverified(VerificationLevel::UnverifiedIdentity), state);
}
// Bob receives a copy of Alice's session that cannot decrypt an earlier message
// (`UnknownMessageIndex`). With automatic key forwarding enabled, this must make
// Bob queue exactly one outgoing room-key request.
#[async_test]
#[cfg(feature = "automatic-room-key-forwarding")]
async fn test_query_ratcheted_key() {
    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    let room_id = room_id!("!test:example.org");

    // Give Bob a second, locally verified device he could request keys from.
    let bob_id = user_id();
    let bob_other_device_id = device_id!("OTHERBOB");
    let bob_other_machine = OlmMachine::new(bob_id, bob_other_device_id).await;
    let bob_other_device =
        ReadOnlyDevice::from_machine_test_helper(&bob_other_machine).await.unwrap();
    bob.store().save_devices(&[bob_other_device]).await.unwrap();
    bob.get_device(bob_id, bob_other_device_id, None)
        .await
        .unwrap()
        .expect("should exist")
        .set_trust_state(LocalTrust::Verified);

    // Alice encrypts a message *before* sharing the session with Bob.
    alice.create_outbound_group_session_with_defaults_test_helper(room_id).await.unwrap();
    let plaintext = "It is a secret to everybody";
    let content = RoomMessageEventContent::text_plain(plaintext);
    let content = alice
        .encrypt_room_event(room_id, AnyMessageLikeEventContent::RoomMessage(content))
        .await
        .unwrap();
    let room_event = json!({
        "event_id": "$xxxxx:example.org",
        "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
        "sender": alice.user_id(),
        "type": "m.room.encrypted",
        "content": content,
    });

    let to_device_requests = alice
        .share_room_key(room_id, iter::once(bob.user_id()), EncryptionSettings::default())
        .await
        .unwrap();
    let event = ToDeviceEvent::new(
        alice.user_id().to_owned(),
        to_device_requests_to_content(to_device_requests),
    );
    let group_session = bob
        .store()
        .with_transaction(|mut tr| async {
            let res =
                bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
            Ok((tr, res))
        })
        .await
        .unwrap()
        .inbound_group_session;
    bob.store().save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();

    // Bob's copy of the session starts after the message's index...
    let room_event = json_convert(&room_event).unwrap();
    let decrypt_error = bob.decrypt_room_event(&room_event, room_id).await.unwrap_err();
    assert_matches!(
        decrypt_error,
        MegolmError::Decryption(vodozemac::megolm::DecryptionError::UnknownMessageIndex(_, _))
    );

    // ...so the failed decryption must have queued a key request.
    let outgoing_to_devices =
        bob.inner.key_request_machine.outgoing_to_device_requests().await.unwrap();
    assert_eq!(1, outgoing_to_devices.len());
}
// Full SAS verification started directly from a device via
// `start_verification`. Each to-device message is delivered by hand between the
// two machines, and each step checks the emoji, done and verified states.
#[async_test]
async fn test_interactive_verification() {
    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    let bob_device =
        alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
    assert!(!bob_device.is_verified());

    // Alice starts SAS; Bob picks up the flow from the `start` event.
    let (alice_sas, request) = bob_device.start_verification().await.unwrap();
    let event = request_to_event(alice.user_id(), &request.into());
    bob.handle_verification_event(&event).await;
    let bob_sas = bob
        .get_verification(alice.user_id(), alice_sas.flow_id().as_str())
        .unwrap()
        .sas_v1()
        .unwrap();
    assert!(alice_sas.emoji().is_none());
    assert!(bob_sas.emoji().is_none());

    // Bob accepts; the next message each side queues is delivered to the other
    // side, after which both sides can show emoji.
    let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
    alice.handle_verification_event(&event).await;
    let (event, request_id) = alice
        .inner
        .verification_machine
        .outgoing_messages()
        .first()
        .map(|r| (outgoing_request_to_event(alice.user_id(), r), r.request_id.to_owned()))
        .unwrap();
    alice.mark_request_as_sent(&request_id, &ToDeviceResponse::new()).await.unwrap();
    bob.handle_verification_event(&event).await;
    let (event, request_id) = bob
        .inner
        .verification_machine
        .outgoing_messages()
        .first()
        .map(|r| (outgoing_request_to_event(bob.user_id(), r), r.request_id.to_owned()))
        .unwrap();
    alice.handle_verification_event(&event).await;
    bob.mark_request_as_sent(&request_id, &ToDeviceResponse::new()).await.unwrap();
    assert!(alice_sas.emoji().is_some());
    assert!(bob_sas.emoji().is_some());
    assert_eq!(alice_sas.emoji(), bob_sas.emoji());
    assert_eq!(alice_sas.decimals(), bob_sas.decimals());

    // Bob confirms first; neither side is done until Alice confirms as well.
    let contents = bob_sas.confirm().await.unwrap().0;
    assert_eq!(contents.len(), 1);
    let event = request_to_event(bob.user_id(), &contents[0]);
    alice.handle_verification_event(&event).await;
    assert!(!alice_sas.is_done());
    assert!(!bob_sas.is_done());

    // Alice's confirmation completes her side and verifies Bob's device.
    let contents = alice_sas.confirm().await.unwrap().0;
    assert_eq!(contents.len(), 1);
    let event = request_to_event(alice.user_id(), &contents[0]);
    assert!(alice_sas.is_done());
    assert!(bob_device.is_verified());

    // Bob only verifies Alice's device once he receives her message.
    let alice_device =
        bob.get_device(alice.user_id(), alice.device_id(), None).await.unwrap().unwrap();
    assert!(!alice_device.is_verified());
    bob.handle_verification_event(&event).await;
    assert!(bob_sas.is_done());
    assert!(alice_device.is_verified());
}
// SAS verification that begins with a verification *request*. Both sides start
// SAS concurrently, and the test checks that the flow still converges and
// that both devices end up verified.
#[async_test]
async fn test_interactive_verification_started_from_request() {
    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    let bob_device =
        alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
    assert!(!bob_device.is_verified());

    // Alice requests verification; Bob accepts and immediately starts SAS.
    let (alice_ver_req, request) =
        bob_device.request_verification_with_methods(vec![VerificationMethod::SasV1]).await;
    let event = request_to_event(alice.user_id(), &request);
    bob.handle_verification_event(&event).await;
    let flow_id = alice_ver_req.flow_id().as_str();
    let verification_request = bob.get_verification_request(alice.user_id(), flow_id).unwrap();
    let accept_request =
        verification_request.accept_with_methods(vec![VerificationMethod::SasV1]).unwrap();
    let (_, start_request_from_bob) = verification_request.start_sas().await.unwrap().unwrap();
    let event = request_to_event(bob.user_id(), &accept_request);
    alice.handle_verification_event(&event).await;

    // Alice also starts SAS, so both sides have sent a `start`.
    let verification_request = alice.get_verification_request(bob.user_id(), flow_id).unwrap();
    let (alice_sas, start_request_from_alice) =
        verification_request.start_sas().await.unwrap().unwrap();
    let event = request_to_event(bob.user_id(), &start_request_from_bob);
    alice.handle_verification_event(&event).await;
    // Precondition for the rest of the test: Alice's user ID sorts first, which
    // presumably makes her `start` win the spec's tie-break.
    assert!(alice.user_id() < bob.user_id());
    let event = request_to_event(alice.user_id(), &start_request_from_alice);
    bob.handle_verification_event(&event).await;
    let bob_sas = bob
        .get_verification(alice.user_id(), alice_sas.flow_id().as_str())
        .unwrap()
        .sas_v1()
        .unwrap();
    assert!(alice_sas.emoji().is_none());
    assert!(bob_sas.emoji().is_none());

    // Bob accepts; exchange the single message each side queues next.
    let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
    alice.handle_verification_event(&event).await;
    let msgs = alice.inner.verification_machine.outgoing_messages();
    assert_eq!(msgs.len(), 1);
    let msg = &msgs[0];
    let event = outgoing_request_to_event(alice.user_id(), msg);
    alice.inner.verification_machine.mark_request_as_sent(&msg.request_id);
    bob.handle_verification_event(&event).await;
    let msgs = bob.inner.verification_machine.outgoing_messages();
    assert_eq!(msgs.len(), 1);
    let msg = &msgs[0];
    let event = outgoing_request_to_event(bob.user_id(), msg);
    bob.inner.verification_machine.mark_request_as_sent(&msg.request_id);
    alice.handle_verification_event(&event).await;
    assert!(alice_sas.emoji().is_some());
    assert!(bob_sas.emoji().is_some());
    assert_eq!(alice_sas.emoji(), bob_sas.emoji());
    assert_eq!(alice_sas.decimals(), bob_sas.decimals());

    // Bob confirms first; neither side is done yet.
    let contents = bob_sas.confirm().await.unwrap().0;
    assert_eq!(contents.len(), 1);
    let event = request_to_event(bob.user_id(), &contents[0]);
    alice.handle_verification_event(&event).await;
    assert!(!alice_sas.is_done());
    assert!(!bob_sas.is_done());

    // Alice confirms; she emits two messages, delivered to Bob one at a time.
    let contents = alice_sas.confirm().await.unwrap().0;
    assert_eq!(contents.len(), 2);
    let event_mac = request_to_event(alice.user_id(), &contents[0]);
    let event_done = request_to_event(alice.user_id(), &contents[1]);
    bob.handle_verification_event(&event_mac).await;
    let msgs = bob.inner.verification_machine.outgoing_messages();
    assert_eq!(msgs.len(), 1);
    let event = msgs.first().map(|r| outgoing_request_to_event(bob.user_id(), r)).unwrap();
    let alice_device =
        bob.get_device(alice.user_id(), alice.device_id(), None).await.unwrap().unwrap();
    assert!(!bob_sas.is_done());
    assert!(!alice_device.is_verified());

    // Alice's second message completes Bob's side.
    bob.handle_verification_event(&event_done).await;
    assert!(bob_sas.is_done());
    assert!(alice_device.is_verified());
    assert!(!alice_sas.is_done());
    assert!(!bob_device.is_verified());

    // Bob's queued reply completes Alice's side.
    alice.handle_verification_event(&event).await;
    assert!(alice_sas.is_done());
    assert!(bob_device.is_verified());
}
// An `m.room_key` event encrypted with Megolm (instead of Olm) must be rejected.
// Decrypting it directly fails with `UnsupportedAlgorithm`, and receiving it
// through a sync does not store the session it carries.
#[async_test]
async fn test_room_key_over_megolm() {
    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    let room_id = room_id!("!test:example.org");
    // Share Alice's outbound Megolm session with Bob over Olm, presumably so the
    // `m.room_key` below is encrypted with a session Bob can decrypt.
    let to_device_requests = alice
        .share_room_key(room_id, iter::once(bob.user_id()), EncryptionSettings::default())
        .await
        .unwrap();
    let event = ToDeviceEvent {
        sender: alice.user_id().to_owned(),
        content: to_device_requests_to_content(to_device_requests),
        other: Default::default(),
    };
    let event = json_convert(&event).unwrap();
    let changed_devices = DeviceLists::new();
    let key_counts: BTreeMap<_, _> = Default::default();
    let _ = bob
        .receive_sync_changes(EncryptionSyncChanges {
            to_device_events: vec![event],
            changed_devices: &changed_devices,
            one_time_keys_counts: &key_counts,
            unused_fallback_keys: None,
            next_batch_token: None,
        })
        .await
        .unwrap();
    // An unrelated Megolm session whose key we try to deliver over Megolm.
    let group_session = GroupSession::new(SessionConfig::version_1());
    let session_key = group_session.session_key();
    let session_id = group_session.session_id();
    let content = message_like_event_content!({
        "algorithm": "m.megolm.v1.aes-sha2",
        "room_id": room_id,
        "session_id": session_id,
        "session_key": session_key.to_base64(),
    });
    // Wrap the `m.room_key` payload in a Megolm-encrypted `m.room.encrypted` event.
    let encrypted_content =
        alice.encrypt_room_event_raw(room_id, "m.room_key", &content).await.unwrap();
    let event = json!({
        "sender": alice.user_id(),
        "content": encrypted_content,
        "type": "m.room.encrypted",
    });
    let event: EncryptedToDeviceEvent = serde_json::from_value(event).unwrap();
    // Direct to-device decryption rejects the Megolm algorithm.
    let decrypt_result = bob
        .store()
        .with_transaction(|mut tr| async {
            let res =
                bob.decrypt_to_device_event(&mut tr, &event, &mut Changes::default()).await?;
            Ok((tr, res))
        })
        .await;
    assert_matches!(
        decrypt_result,
        Err(OlmError::EventError(EventError::UnsupportedAlgorithm))
    );
    // Through a sync the event must not cause an error, nor must the smuggled
    // session end up in the store.
    let event: Raw<AnyToDeviceEvent> = json_convert(&event).unwrap();
    bob.receive_sync_changes(EncryptionSyncChanges {
        to_device_events: vec![event],
        changed_devices: &changed_devices,
        one_time_keys_counts: &key_counts,
        unused_fallback_keys: None,
        next_batch_token: None,
    })
    .await
    .unwrap();
    let session = bob.store().get_inbound_group_session(room_id, &session_id).await;
    assert!(session.unwrap().is_none());
}
// An inbound group session whose recorded Ed25519 signing key does not match
// the sender's real device keys must make decryption fail with
// `MismatchedIdentityKeys`.
#[async_test]
async fn test_room_key_with_fake_identity_keys() {
    let room_id = room_id!("!test:localhost");
    let (alice, _) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;

    // Store Alice's own device so the machine knows her genuine keys.
    let own_device = ReadOnlyDevice::from_machine_test_helper(&alice).await.unwrap();
    alice.store().save_devices(&[own_device]).await.unwrap();

    let (outbound, mut inbound) = alice
        .store()
        .static_account()
        .create_group_session_pair(room_id, Default::default())
        .await
        .unwrap();

    // Overwrite the session's claimed signing keys with a foreign Ed25519 key.
    let forged_key = Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE")
        .unwrap()
        .into();
    let forged_signing_keys = SigningKeys::from([(DeviceKeyAlgorithm::Ed25519, forged_key)]);
    inbound.creator_info.signing_keys = forged_signing_keys.into();

    let empty_content = message_like_event_content!({});
    let encrypted = outbound.encrypt("m.dummy", &empty_content).await;
    alice.store().save_inbound_group_sessions(&[inbound]).await.unwrap();

    let raw_event = json_convert(&json!({
        "sender": alice.user_id(),
        "event_id": "$xxxxx:example.org",
        "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
        "type": "m.room.encrypted",
        "content": encrypted,
    }))
    .unwrap();

    let outcome = alice.decrypt_room_event(&raw_event, room_id).await;
    assert_matches!(outcome, Err(MegolmError::MismatchedIdentityKeys { .. }));
}
// Importing the private cross-signing keys on an additional device marks our
// own public identity as verified. When the imported keys do not match that
// identity, the import fails and the identity stays unverified.
#[async_test]
async fn importing_private_cross_signing_keys_verifies_the_public_identity() {
    // Create a new device for `machine`'s user that already knows the user's
    // own public identity, explicitly marked as unverified.
    async fn create_additional_machine(machine: &OlmMachine) -> OlmMachine {
        let additional =
            OlmMachine::new(machine.user_id(), "ADDITIONAL_MACHINE".into()).await;
        let own_identity = machine
            .get_identity(machine.user_id(), None)
            .await
            .unwrap()
            .expect("We should know about our own user identity if we bootstrapped it")
            .own()
            .unwrap();
        own_identity.mark_as_unverified();

        let mut changes = Changes::default();
        changes.identities.new.push(crate::ReadOnlyUserIdentities::Own(own_identity.inner));
        additional.store().save_changes(changes).await.unwrap();

        additional
    }

    let (alice, bob) =
        get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
    setup_cross_signing_for_machine_test_helper(&alice, &bob).await;

    // Alice's own keys, imported on her second device, verify her identity.
    let second_alice = create_additional_machine(&alice).await;
    let alice_keys = alice
        .export_cross_signing_keys()
        .await
        .unwrap()
        .expect("We should be able to export our cross-signing keys");

    let before_import = second_alice
        .get_identity(second_alice.user_id(), None)
        .await
        .unwrap()
        .expect("We should know about our own user identity")
        .own()
        .unwrap();
    assert!(!before_import.is_verified(), "Initially our identity should not be verified");

    second_alice
        .import_cross_signing_keys(alice_keys)
        .await
        .expect("We should be able to import our cross-signing keys");

    let after_import = second_alice
        .get_identity(second_alice.user_id(), None)
        .await
        .unwrap()
        .expect("We should know about our own user identity")
        .own()
        .unwrap();
    assert!(
        after_import.is_verified(),
        "Our identity should be verified after we imported the private cross-signing keys"
    );

    // Alice's keys imported on Bob's second device must be refused.
    let second_bob = create_additional_machine(&bob).await;
    let mismatched_keys = second_alice
        .export_cross_signing_keys()
        .await
        .unwrap()
        .expect("The machine should now be able to export cross-signing keys as well");
    second_bob.import_cross_signing_keys(mismatched_keys).await.expect_err(
        "Importing cross-signing keys that don't match our public identity should fail",
    );

    let bob_identity = second_bob
        .get_identity(second_bob.user_id(), None)
        .await
        .unwrap()
        .expect("We should know about our own user identity")
        .own()
        .unwrap();
    assert!(
        !bob_identity.is_verified(),
        "Our identity should not be verified when there's a mismatch in the cross-signing keys"
    );
}
// A task that waits (with a timeout) for a user's devices must not prevent the
// same machine from completing store writes, such as marking a keys upload as
// sent.
#[async_test]
async fn test_wait_on_key_query_doesnt_block_store() {
    let machine = OlmMachine::new(bob_id(), bob_device_id()).await;
    machine.update_tracked_users([alice_id()]).await.unwrap();

    // Background task: wait up to 10s for Alice's devices to show up.
    let waiter = tokio::spawn({
        let machine = machine.clone();
        async move {
            let user_devices =
                machine.get_user_devices(alice_id(), Some(Duration::from_secs(10))).await.unwrap();
            assert!(user_devices.devices().next().is_some());
        }
    });

    // Let the waiter start before touching the store.
    tokio::task::yield_now().await;

    let bootstrap = machine.bootstrap_cross_signing(false).await.unwrap();
    let upload_keys = bootstrap.upload_keys_req.expect("upload keys request should be there");
    let upload_response = keys_upload_response();
    let mark_sent = machine.mark_request_as_sent(&upload_keys.request_id, &upload_response);
    tokio::time::timeout(Duration::from_secs(5), mark_sent)
        .await
        .expect("no timeout")
        .expect("the underlying request has been marked as sent");

    // Answer every pending key query so the waiter can observe Alice's devices.
    let query_response = keys_query_response();
    let pending_queries = machine.inner.identity_manager.users_for_key_query().await.unwrap();
    for (request_id, _) in pending_queries {
        machine.mark_request_as_sent(&request_id, &query_response).await.unwrap();
    }

    waiter.await.unwrap();
}
// Rooms for which no settings were ever stored yield `None`.
#[async_test]
async fn room_settings_returns_none_for_unknown_room() {
    let machine = OlmMachine::new(user_id(), alice_device_id()).await;
    let unknown_room = room_id!("!test2:localhost");
    assert!(machine.room_settings(unknown_room).await.unwrap().is_none());
}
// Settings written with `set_room_settings` are returned unchanged by
// `room_settings`.
#[async_test]
async fn stores_and_returns_room_settings() {
    let machine = OlmMachine::new(user_id(), alice_device_id()).await;
    let room_id = room_id!("!test:localhost");

    let expected = RoomSettings {
        algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
        only_allow_trusted_devices: true,
        session_rotation_period: Some(Duration::from_secs(10)),
        session_rotation_period_messages: Some(1234),
    };
    machine.set_room_settings(room_id, &expected).await.unwrap();

    let stored = machine.room_settings(room_id).await.unwrap();
    assert_eq!(stored, Some(expected));
}
// Olm is not a valid room encryption algorithm, so such settings are refused.
#[async_test]
async fn set_room_settings_rejects_invalid_algorithms() {
    let machine = OlmMachine::new(user_id(), alice_device_id()).await;
    let room_id = room_id!("!test:localhost");

    let invalid_settings = RoomSettings {
        algorithm: EventEncryptionAlgorithm::OlmV1Curve25519AesSha2,
        ..Default::default()
    };
    let err = machine.set_room_settings(room_id, &invalid_settings).await.unwrap_err();

    assert_matches!(err, SetRoomSettingsError::InvalidSettings);
}
// Raising the message-based rotation period from 100 to 1000 messages is
// reported as an encryption downgrade and rejected.
#[async_test]
async fn set_room_settings_rejects_changes() {
    let machine = OlmMachine::new(user_id(), alice_device_id()).await;
    let room_id = room_id!("!test:localhost");

    let initial =
        RoomSettings { session_rotation_period_messages: Some(100), ..Default::default() };
    machine.set_room_settings(room_id, &initial).await.unwrap();

    let relaxed =
        RoomSettings { session_rotation_period_messages: Some(1000), ..Default::default() };
    let err = machine.set_room_settings(room_id, &relaxed).await.unwrap_err();

    assert_matches!(err, SetRoomSettingsError::EncryptionDowngrade);
}
// Re-applying identical settings is not a change and must succeed.
#[async_test]
async fn set_room_settings_accepts_noop_changes() {
    let machine = OlmMachine::new(user_id(), alice_device_id()).await;
    let room_id = room_id!("!test:localhost");
    let settings =
        RoomSettings { session_rotation_period_messages: Some(100), ..Default::default() };

    // The first call stores the settings; the second is a no-op update.
    for _ in 0..2 {
        machine.set_room_settings(room_id, &settings).await.unwrap();
    }
}
// A custom event encrypted for a single device with `encrypt_event_raw` must
// produce a targeted `m.room.encrypted` request that the recipient decrypts
// back to the original event type and content.
#[async_test]
async fn test_send_encrypted_to_device() {
    let (alice, bob) = get_machine_pair_with_session(alice_id(), user_id(), false).await;
    let custom_event_type = "m.new_device";
    let custom_content = json!({
            "device_id": "XYZABCDE",
            "rooms": ["!726s6s6q:example.com"]
    });

    let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
    let raw_encrypted = device
        .encrypt_event_raw(custom_event_type, &custom_content)
        .await
        .expect("Should have encrypted the content");

    let request = ToDeviceRequest::new(
        bob.user_id(),
        DeviceIdOrAllDevices::DeviceId(bob_device_id().to_owned()),
        "m.room.encrypted",
        raw_encrypted.cast(),
    );
    assert_eq!("m.room.encrypted", request.event_type.to_string());

    // Exactly one user and one device must be targeted: Bob's device.
    let messages = &request.messages;
    assert_eq!(1, messages.len());
    let target_devices = messages.get(bob.user_id()).expect("the request should target Bob");
    assert_eq!(1, target_devices.len());
    assert!(target_devices
        .contains_key(&DeviceIdOrAllDevices::DeviceId(bob_device_id().to_owned())));

    let event = ToDeviceEvent::new(
        alice.user_id().to_owned(),
        to_device_requests_to_content(vec![request.clone().into()]),
    );
    let event = json_convert(&event).unwrap();
    let sync_changes = EncryptionSyncChanges {
        to_device_events: vec![event],
        changed_devices: &Default::default(),
        one_time_keys_counts: &Default::default(),
        unused_fallback_keys: None,
        next_batch_token: None,
    };
    let (decrypted, _) = bob.receive_sync_changes(sync_changes).await.unwrap();
    assert_eq!(1, decrypted.len());

    // The decrypted event keeps the custom type and content.
    let decrypted_event = decrypted[0].deserialize().unwrap();
    assert_eq!(decrypted_event.event_type().to_string(), custom_event_type);
    let decrypted_value = to_raw_value(&decrypted[0]).unwrap();
    let decrypted_value = serde_json::to_value(decrypted_value).unwrap();
    assert_eq!(
        decrypted_value.get("content").unwrap().get("device_id").unwrap().as_str().unwrap(),
        custom_content.get("device_id").unwrap().as_str().unwrap(),
    );
    assert_eq!(
        decrypted_value.get("content").unwrap().get("rooms").unwrap().as_array().unwrap(),
        custom_content.get("rooms").unwrap().as_array().unwrap(),
    );
}
// Without an established Olm session to the target device,
// `encrypt_event_raw` fails with `MissingSession`.
#[async_test]
async fn test_send_encrypted_to_device_no_session() {
    let (alice, bob, _) = get_machine_pair(alice_id(), user_id(), false).await;
    let custom_event_type = "m.new_device";
    let custom_content = json!({
            "device_id": "XYZABCDE",
            "rooms": ["!726s6s6q:example.com"]
    });

    let bob_device =
        alice.get_device(bob.user_id(), bob_device_id(), None).await.unwrap().unwrap();
    let encryption_result =
        bob_device.encrypt_event_raw(custom_event_type, &custom_content).await;

    assert_matches!(encryption_result, Err(OlmError::MissingSession));
}
/// Regression test for room-key backups.
///
/// A `BackupDecryptionKey` and backup version are saved into the store
/// *before* the `OlmMachine` is created from it. A room key that originally
/// came from a backup is then imported, and the machine's next backup upload
/// request is inspected. The session data in that request must be decryptable
/// with the very decryption key that was stored up front.
#[async_test]
async fn test_fix_incorrect_usage_of_backup_key_causing_decryption_errors() {
    // Single source of truth for the identifiers used in the fixture and in
    // the lookups below, so they cannot drift apart.
    let room_id = room_id!("!room:id");
    let session_id = "/2K+V777vipCxPZ0gpY9qcpz1DYaXwuMRIu0UEP0Wa0";

    // Persist the backup decryption key and version before the machine exists.
    let store = MemoryStore::new();
    let backup_decryption_key = BackupDecryptionKey::new().unwrap();
    store
        .save_changes(Changes {
            backup_decryption_key: Some(backup_decryption_key.clone()),
            backup_version: Some("1".to_owned()),
            ..Default::default()
        })
        .await
        .unwrap();

    // A room key in the `BackedUpRoomKey` shape.
    let data = json!({
        "algorithm": "m.megolm.v1.aes-sha2",
        "room_id": room_id.as_str(),
        "sender_key": "FOvlmz18LLI3k/llCpqRoKT90+gFF8YhuL+v1YBXHlw",
        "session_id": session_id,
        "session_key": "AQAAAAAclzWVMeWBKH+B/WMowa3rb4ma3jEl6n5W4GCs9ue65CruzD3ihX+85pZ9hsV9Bf6fvhjp76WNRajoJYX0UIt7aosjmu0i+H+07hEQ0zqTKpVoSH0ykJ6stAMhdr6Q4uW5crBmdTTBIsqmoWsNJZKKoE2+ldYrZ1lrFeaJbjBIY/9ivle++74qQsT2dIKWPanKc9Q2Gl8LjESLtFBD9Fmt",
        "sender_claimed_keys": {
            "ed25519": "F4P7f1Z0RjbiZMgHk1xBCG3KC4/Ng9PmxLJ4hQ13sHA"
        },
        "forwarding_curve25519_key_chain": ["DBPC2zr6c9qimo9YRFK3RVr0Two/I6ODb9mbsToZN3Q", "bBc/qzZFOOKshMMT+i4gjS/gWPDoKfGmETs9yfw9430"]
    });
    let backed_up_room_key: BackedUpRoomKey = serde_json::from_value(data).unwrap();

    // The machine is created from the pre-populated store.
    let alice = OlmMachine::with_store(user_id(), alice_device_id(), store).await.unwrap();
    let exported_key = ExportedRoomKey::from_backed_up_room_key(
        room_id.to_owned(),
        session_id.into(),
        backed_up_room_key,
    );
    alice.store().import_exported_room_keys(vec![exported_key], |_, _| {}).await.unwrap();

    // Take the next backup upload request and pull out the imported session.
    // `OwnedRoomId: Borrow<RoomId>`, so the borrowed ID indexes the map directly.
    let (_, request) = alice.backup_machine().backup().await.unwrap().unwrap();
    let key_backup_data = request.rooms[room_id]
        .sessions
        .get(session_id)
        .expect("The imported session should be part of the backup request")
        .deserialize()
        .unwrap();

    // The uploaded data must decrypt with the key stored before machine creation.
    let session_data = &key_backup_data.session_data;
    let ephemeral = session_data.ephemeral.encode();
    let ciphertext = session_data.ciphertext.encode();
    let mac = session_data.mac.encode();
    backup_decryption_key
        .decrypt_v1(&ephemeral, &mac, &ciphertext)
        .expect("The backed up key should be decrypted successfully");
}
}