use std::{fmt, sync::Arc};
use ruma::{serde::Raw, SecondsSinceUnixEpoch};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::sync::Mutex;
use tracing::{debug, Span};
use vodozemac::{
olm::{DecryptionError, OlmMessage, Session as InnerSession, SessionConfig, SessionPickle},
Curve25519PublicKey,
};
#[cfg(feature = "experimental-algorithms")]
use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content;
use crate::{
error::{EventError, OlmResult, SessionUnpickleError},
types::{
events::room::encrypted::{OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent},
DeviceKeys, EventEncryptionAlgorithm,
},
DeviceData,
};
/// An Olm session shared with one other device.
///
/// Cloning is cheap: the underlying vodozemac session lives behind an
/// `Arc<Mutex<_>>`, so clones share the same ratchet state. The remaining
/// fields are plain values and are copied per clone.
#[derive(Clone)]
pub struct Session {
    /// The underlying vodozemac Olm session, shared between clones and
    /// guarded by a tokio (async) mutex.
    pub inner: Arc<Mutex<InnerSession>>,
    /// The ID of the inner session, stored separately so it can be read
    /// without locking `inner`.
    pub session_id: Arc<str>,
    /// The curve25519 key of the other party of this session. It is used as
    /// `recipient_key` when building Olm v1 encrypted content.
    pub sender_key: Curve25519PublicKey,
    /// Our own device keys. Their ed25519 and curve25519 keys are embedded in
    /// encrypted payloads and content.
    pub our_device_keys: DeviceKeys,
    /// Whether this session was created using a fallback key.
    pub created_using_fallback_key: bool,
    /// When the session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// When the session was last used to successfully encrypt or decrypt a
    /// message.
    pub last_use_time: SecondsSinceUnixEpoch,
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for Session {
    /// Prints only the session ID and the other party's curve25519 key.
    ///
    /// The inner ratchet state, our device keys, the fallback flag and the
    /// timestamps are not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Session");
        builder.field("session_id", &self.session_id());
        builder.field("sender_key", &self.sender_key);
        builder.finish()
    }
}
impl Session {
    /// Decrypt the given Olm message.
    ///
    /// Returns the decrypted plaintext as a `String`. Byte sequences that are
    /// not valid UTF-8 are replaced with U+FFFD rather than causing an error.
    /// On success, `last_use_time` is set to the current time.
    ///
    /// # Arguments
    ///
    /// * `message` - The Olm message that should be decrypted.
    ///
    /// # Errors
    ///
    /// Returns the [`DecryptionError`] reported by the inner vodozemac session.
    /// `last_use_time` is left untouched in that case.
    pub async fn decrypt(&mut self, message: &OlmMessage) -> Result<String, DecryptionError> {
        let mut inner = self.inner.lock().await;
        Span::current().record("session_id", inner.session_id());

        let plaintext = inner.decrypt(message)?;
        debug!(session=?inner, "Decrypted an Olm message");

        // Reuse the decrypted buffer when it is valid UTF-8 (the common case)
        // instead of copying it. Fall back to the lossy conversion only for
        // invalid input; the result matches `from_utf8_lossy` in both cases.
        let plaintext = String::from_utf8(plaintext)
            .unwrap_or_else(|error| String::from_utf8_lossy(error.as_bytes()).into_owned());

        self.last_use_time = SecondsSinceUnixEpoch::now();

        Ok(plaintext)
    }

    /// Get the curve25519 key of the other party of this session.
    pub fn sender_key(&self) -> Curve25519PublicKey {
        self.sender_key
    }

    /// Get the [`SessionConfig`] of the inner vodozemac session.
    pub async fn session_config(&self) -> SessionConfig {
        self.inner.lock().await.session_config()
    }

    /// Get the [`EventEncryptionAlgorithm`] used by this session.
    ///
    /// Without the `experimental-algorithms` feature this is always Olm v1.
    /// With the feature, sessions whose config reports version 2 use Olm v2.
    // The body only awaits when `experimental-algorithms` is enabled, so
    // clippy's `unused_async` lint fires for the default build.
    #[allow(clippy::unused_async)]
    pub async fn algorithm(&self) -> EventEncryptionAlgorithm {
        #[cfg(feature = "experimental-algorithms")]
        if self.session_config().await.version() == 2 {
            EventEncryptionAlgorithm::OlmV2Curve25519AesSha2
        } else {
            EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
        }

        #[cfg(not(feature = "experimental-algorithms"))]
        EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
    }

    /// Encrypt an already-serialized `plaintext` with the inner Olm session.
    ///
    /// Sets `last_use_time` to the current time and returns the raw Olm
    /// message.
    pub(crate) async fn encrypt_helper(&mut self, plaintext: &str) -> OlmMessage {
        let mut session = self.inner.lock().await;
        let message = session.encrypt(plaintext);
        self.last_use_time = SecondsSinceUnixEpoch::now();
        debug!(?session, "Successfully encrypted an event");
        message
    }

    /// Encrypt an event for `recipient_device` and return the encrypted
    /// to-device content.
    ///
    /// The encrypted plaintext is a JSON object containing:
    /// - our user and device IDs;
    /// - our ed25519 key;
    /// - our full device keys (under `org.matrix.msc4147.device_keys`);
    /// - the recipient's user ID and ed25519 key;
    /// - the event type and the event content.
    ///
    /// # Arguments
    ///
    /// * `recipient_device` - The device the event is encrypted for.
    /// * `event_type` - The type of the event being encrypted.
    /// * `content` - The content of the event.
    /// * `message_id` - Optional message ID placed into the encrypted content.
    ///
    /// # Errors
    ///
    /// Returns [`EventError::MissingSigningKey`] if the recipient device has no
    /// ed25519 key. Returns a serialization error if `content` fails to
    /// serialize, or if the payload or encrypted content cannot be serialized.
    ///
    /// # Panics
    ///
    /// Panics if our own device keys lack an ed25519 or curve25519 key.
    /// [`Session::from_pickle`] rejects such keys, but the struct's fields are
    /// public, so other construction paths are not checked.
    pub async fn encrypt(
        &mut self,
        recipient_device: &DeviceData,
        event_type: &str,
        content: impl Serialize,
        message_id: Option<String>,
    ) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
        let plaintext = {
            let recipient_signing_key =
                recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?;

            // Serialize the caller-supplied content up front. `json!` calls
            // `to_value(..).unwrap()` on interpolated expressions, so a
            // failing `Serialize` impl would otherwise panic instead of
            // returning an error.
            let content_value = serde_json::to_value(content)?;

            let payload = json!({
                "sender": &self.our_device_keys.user_id,
                "sender_device": &self.our_device_keys.device_id,
                "keys": {
                    "ed25519": self.our_device_keys.ed25519_key().expect("Device doesn't have ed25519 key").to_base64(),
                },
                "org.matrix.msc4147.device_keys": self.our_device_keys,
                "recipient": recipient_device.user_id(),
                "recipient_keys": {
                    "ed25519": recipient_signing_key.to_base64(),
                },
                "type": event_type,
                "content": content_value,
            });

            serde_json::to_string(&payload)?
        };

        let ciphertext = self.encrypt_helper(&plaintext).await;
        let content = self.build_encrypted_event(ciphertext, message_id).await?;
        let content = Raw::new(&content)?;
        Ok(content)
    }

    /// Wrap an Olm `ciphertext` into the encrypted to-device content matching
    /// this session's [`Session::algorithm`].
    ///
    /// # Panics
    ///
    /// Panics if our own device keys lack a curve25519 key.
    pub(crate) async fn build_encrypted_event(
        &self,
        ciphertext: OlmMessage,
        message_id: Option<String>,
    ) -> OlmResult<ToDeviceEncryptedEventContent> {
        let content = match self.algorithm().await {
            EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content {
                ciphertext,
                recipient_key: self.sender_key,
                sender_key: self
                    .our_device_keys
                    .curve25519_key()
                    .expect("Device doesn't have curve25519 key"),
                message_id,
            }
            .into(),
            #[cfg(feature = "experimental-algorithms")]
            EventEncryptionAlgorithm::OlmV2Curve25519AesSha2 => OlmV2Curve25519AesSha2Content {
                ciphertext,
                sender_key: self
                    .our_device_keys
                    .curve25519_key()
                    .expect("Device doesn't have curve25519 key"),
                message_id,
            }
            .into(),
            // `algorithm()` only ever returns the variants handled above.
            _ => unreachable!(),
        };

        Ok(content)
    }

    /// Get the ID of this session. This does not lock the inner session.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    /// Convert this session into a [`PickledSession`].
    ///
    /// The result can later be restored with [`Session::from_pickle`].
    pub async fn pickle(&self) -> PickledSession {
        let pickle = self.inner.lock().await.pickle();

        PickledSession {
            pickle,
            sender_key: self.sender_key,
            created_using_fallback_key: self.created_using_fallback_key,
            creation_time: self.creation_time,
            last_use_time: self.last_use_time,
        }
    }

    /// Restore a [`Session`] from a [`PickledSession`].
    ///
    /// # Arguments
    ///
    /// * `our_device_keys` - Our own device keys. They must contain both a
    ///   curve25519 and an ed25519 key.
    /// * `pickle` - The pickled session, as produced by [`Session::pickle`].
    ///
    /// # Errors
    ///
    /// Returns [`SessionUnpickleError::MissingIdentityKey`] if
    /// `our_device_keys` has no curve25519 key. Returns
    /// [`SessionUnpickleError::MissingSigningKey`] if it has no ed25519 key.
    pub fn from_pickle(
        our_device_keys: DeviceKeys,
        pickle: PickledSession,
    ) -> Result<Self, SessionUnpickleError> {
        if our_device_keys.curve25519_key().is_none() {
            return Err(SessionUnpickleError::MissingIdentityKey);
        }

        if our_device_keys.ed25519_key().is_none() {
            return Err(SessionUnpickleError::MissingSigningKey);
        }

        let session: vodozemac::olm::Session = pickle.pickle.into();
        let session_id = session.session_id();

        Ok(Session {
            inner: Arc::new(Mutex::new(session)),
            session_id: session_id.into(),
            created_using_fallback_key: pickle.created_using_fallback_key,
            sender_key: pickle.sender_key,
            our_device_keys,
            creation_time: pickle.creation_time,
            last_use_time: pickle.last_use_time,
        })
    }
}
impl PartialEq for Session {
    /// Sessions compare equal when their session IDs are equal.
    ///
    /// No other field (keys, timestamps, inner state) takes part in the
    /// comparison.
    fn eq(&self, other: &Self) -> bool {
        let (ours, theirs) = (self.session_id(), other.session_id());
        ours == theirs
    }
}
/// A serializable snapshot of a [`Session`].
///
/// It is produced by `Session::pickle` and turned back into a session by
/// `Session::from_pickle`. The field names are part of the serialized format.
#[derive(Serialize, Deserialize)]
// NOTE(review): presumably opted out of `Debug` to avoid printing the pickled
// session state — confirm.
#[allow(missing_debug_implementations)]
pub struct PickledSession {
    /// The pickled vodozemac Olm session.
    pub pickle: SessionPickle,
    /// The curve25519 key of the other party of the session.
    pub sender_key: Curve25519PublicKey,
    /// Whether the session was created using a fallback key. Pickles that
    /// lack this field deserialize with `false`.
    #[serde(default)]
    pub created_using_fallback_key: bool,
    /// When the session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// When the session was last used.
    pub last_use_time: SecondsSinceUnixEpoch,
}
#[cfg(test)]
mod tests {
    use assert_matches2::assert_let;
    use matrix_sdk_test::async_test;
    use ruma::{device_id, user_id};
    use serde_json::{self, Value};
    use vodozemac::olm::{OlmMessage, SessionConfig};

    use crate::{
        identities::DeviceData,
        olm::Account,
        types::events::{
            dummy::DummyEventContent, olm_v1::DecryptedOlmV1Event,
            room::encrypted::ToDeviceEncryptedEventContent,
        },
    };

    /// Round trip between two accounts.
    ///
    /// Alice encrypts a dummy to-device event for Bob over a new outbound
    /// session. Bob builds the matching inbound session from the pre-key
    /// message. The decrypted payload must carry Alice's device keys.
    #[async_test]
    async fn test_encryption_and_decryption() {
        use ruma::events::dummy::ToDeviceDummyEventContent;

        // Two accounts; Bob publishes one one-time key for Alice to use.
        let alice =
            Account::with_device_id(user_id!("@alice:localhost"), device_id!("ALICEDEVICE"));
        let mut bob = Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"));
        bob.generate_one_time_keys(1);

        let bob_one_time_key = *bob.one_time_keys().values().next().unwrap();
        let bob_identity_key = bob.identity_keys().curve25519;

        let mut outbound = alice.create_outbound_session_helper(
            SessionConfig::default(),
            bob_identity_key,
            bob_one_time_key,
            false,
            alice.device_keys(),
        );
        let alice_device = DeviceData::from_account(&alice);

        // Encrypt, then deserialize the raw encrypted content.
        let raw_content = outbound
            .encrypt(&alice_device, "m.dummy", ToDeviceDummyEventContent::new(), None)
            .await
            .unwrap();
        let message = raw_content.deserialize().unwrap();

        #[cfg(feature = "experimental-algorithms")]
        assert_let!(ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(content) = message);
        #[cfg(not(feature = "experimental-algorithms"))]
        assert_let!(ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(content) = message);

        // The test expects the first message of the session to be a pre-key message.
        let prekey = match content.ciphertext {
            OlmMessage::PreKey(prekey_message) => prekey_message,
            _ => panic!("Wrong Olm message type"),
        };

        let alice_identity_key = alice_device.curve25519_key().unwrap();
        let bob_device_keys = bob.device_keys();
        let inbound_result =
            bob.create_inbound_session(alice_identity_key, bob_device_keys, &prekey).unwrap();

        // The raw JSON carries Alice's device keys under the MSC4147 field...
        let decrypted_json: Value = serde_json::from_str(&inbound_result.plaintext).unwrap();
        assert_eq!(
            decrypted_json["org.matrix.msc4147.device_keys"]["user_id"].as_str(),
            Some("@alice:localhost")
        );

        // ...and the typed event exposes exactly Alice's device keys.
        let decrypted_event: DecryptedOlmV1Event<DummyEventContent> =
            serde_json::from_str(&inbound_result.plaintext).unwrap();
        assert_eq!(decrypted_event.sender_device_keys.unwrap(), alice.device_keys());
    }
}