1use std::{fmt, sync::Arc};
16
17use ruma::{serde::Raw, SecondsSinceUnixEpoch};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use tokio::sync::Mutex;
21use tracing::{debug, Span};
22use vodozemac::{
23 olm::{DecryptionError, OlmMessage, Session as InnerSession, SessionConfig, SessionPickle},
24 Curve25519PublicKey,
25};
26
27#[cfg(feature = "experimental-algorithms")]
28use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content;
29use crate::{
30 error::{EventError, OlmResult, SessionUnpickleError},
31 types::{
32 events::{
33 olm_v1::DecryptedOlmV1Event,
34 room::encrypted::{OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent},
35 EventType,
36 },
37 DeviceKeys, EventEncryptionAlgorithm,
38 },
39 DeviceData,
40};
41
/// A single Olm session with another device.
///
/// Cloning a `Session` is cheap for the cryptographic state: `inner` sits
/// behind an `Arc`, so every clone shares (and serializes access to) the same
/// vodozemac session. The remaining fields are cloned independently, so e.g. a
/// `last_use_time` update on one clone is not visible on another.
#[derive(Clone)]
pub struct Session {
    /// The vodozemac Olm session, behind an async mutex shared by all clones.
    pub inner: Arc<Mutex<InnerSession>>,
    /// The ID of the inner session, kept here so it can be read without
    /// taking the lock.
    pub session_id: Arc<str>,
    /// The Curve25519 key of the other side of this session; it is sent as the
    /// `recipient_key` of events encrypted with this session.
    pub sender_key: Curve25519PublicKey,
    /// Our own device keys, used to fill in the sender fields of events
    /// encrypted with this session.
    pub our_device_keys: DeviceKeys,
    /// Whether the session was created using a fallback key (presumably a
    /// fallback key of the other device rather than a one-time key — TODO
    /// confirm against the session constructors).
    pub created_using_fallback_key: bool,
    /// When this session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// When this session was last used; refreshed after every encryption and
    /// every successful decryption.
    pub last_use_time: SecondsSinceUnixEpoch,
}
61
62#[cfg(not(tarpaulin_include))]
63impl fmt::Debug for Session {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("Session")
66 .field("session_id", &self.session_id())
67 .field("sender_key", &self.sender_key)
68 .finish()
69 }
70}
71
72impl Session {
73 pub async fn decrypt(&mut self, message: &OlmMessage) -> Result<String, DecryptionError> {
82 let mut inner = self.inner.lock().await;
83 Span::current().record("session_id", inner.session_id());
84
85 let plaintext = inner.decrypt(message)?;
86 debug!(session=?inner, "Decrypted an Olm message");
87
88 let plaintext = String::from_utf8_lossy(&plaintext).to_string();
89
90 self.last_use_time = SecondsSinceUnixEpoch::now();
91
92 Ok(plaintext)
93 }
94
95 pub fn sender_key(&self) -> Curve25519PublicKey {
97 self.sender_key
98 }
99
100 pub async fn session_config(&self) -> SessionConfig {
102 self.inner.lock().await.session_config()
103 }
104
105 #[allow(clippy::unused_async)] pub async fn algorithm(&self) -> EventEncryptionAlgorithm {
108 #[cfg(feature = "experimental-algorithms")]
109 if self.session_config().await.version() == 2 {
110 EventEncryptionAlgorithm::OlmV2Curve25519AesSha2
111 } else {
112 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
113 }
114
115 #[cfg(not(feature = "experimental-algorithms"))]
116 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
117 }
118
119 pub(crate) async fn encrypt_helper(&mut self, plaintext: &str) -> OlmMessage {
127 let mut session = self.inner.lock().await;
128 let message = session.encrypt(plaintext);
129 self.last_use_time = SecondsSinceUnixEpoch::now();
130 debug!(?session, "Successfully encrypted an event");
131 message
132 }
133
134 pub async fn encrypt(
147 &mut self,
148 recipient_device: &DeviceData,
149 event_type: &str,
150 content: impl Serialize,
151 message_id: Option<String>,
152 ) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
153 #[derive(Debug)]
154 struct Content<'a> {
155 event_type: &'a str,
156 content: Raw<Value>,
157 }
158
159 impl EventType for Content<'_> {
160 const EVENT_TYPE: &'static str = "";
170
171 fn event_type(&self) -> &str {
172 self.event_type
173 }
174 }
175
176 impl Serialize for Content<'_> {
177 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178 where
179 S: serde::Serializer,
180 {
181 self.content.serialize(serializer)
182 }
183 }
184
185 let plaintext = {
186 let content = serde_json::to_value(content)?;
187 let content = Content { event_type, content: Raw::new(&content)? };
188
189 let recipient_signing_key =
190 recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?;
191
192 let content = DecryptedOlmV1Event {
193 sender: self.our_device_keys.user_id.clone(),
194 recipient: recipient_device.user_id().into(),
195 keys: crate::types::events::olm_v1::OlmV1Keys {
196 ed25519: self
197 .our_device_keys
198 .ed25519_key()
199 .expect("Our own device should have an Ed25519 public key"),
200 },
201 recipient_keys: crate::types::events::olm_v1::OlmV1Keys {
202 ed25519: recipient_signing_key,
203 },
204 sender_device_keys: Some(self.our_device_keys.clone()),
205 content,
206 };
207
208 serde_json::to_string(&content)?
209 };
210
211 let ciphertext = self.encrypt_helper(&plaintext).await;
212
213 let content = self.build_encrypted_event(ciphertext, message_id).await?;
214 let content = Raw::new(&content)?;
215 Ok(content)
216 }
217
218 pub(crate) async fn build_encrypted_event(
227 &self,
228 ciphertext: OlmMessage,
229 message_id: Option<String>,
230 ) -> OlmResult<ToDeviceEncryptedEventContent> {
231 let content = match self.algorithm().await {
232 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content {
233 ciphertext,
234 recipient_key: self.sender_key,
235 sender_key: self
236 .our_device_keys
237 .curve25519_key()
238 .expect("Device doesn't have curve25519 key"),
239 message_id,
240 }
241 .into(),
242 #[cfg(feature = "experimental-algorithms")]
243 EventEncryptionAlgorithm::OlmV2Curve25519AesSha2 => OlmV2Curve25519AesSha2Content {
244 ciphertext,
245 sender_key: self
246 .our_device_keys
247 .curve25519_key()
248 .expect("Device doesn't have curve25519 key"),
249 message_id,
250 }
251 .into(),
252 _ => unreachable!(),
253 };
254
255 Ok(content)
256 }
257
258 pub fn session_id(&self) -> &str {
260 &self.session_id
261 }
262
263 pub async fn pickle(&self) -> PickledSession {
270 let pickle = self.inner.lock().await.pickle();
271
272 PickledSession {
273 pickle,
274 sender_key: self.sender_key,
275 created_using_fallback_key: self.created_using_fallback_key,
276 creation_time: self.creation_time,
277 last_use_time: self.last_use_time,
278 }
279 }
280
281 pub fn from_pickle(
292 our_device_keys: DeviceKeys,
293 pickle: PickledSession,
294 ) -> Result<Self, SessionUnpickleError> {
295 if our_device_keys.curve25519_key().is_none() {
296 return Err(SessionUnpickleError::MissingIdentityKey);
297 }
298 if our_device_keys.ed25519_key().is_none() {
299 return Err(SessionUnpickleError::MissingSigningKey);
300 }
301
302 let session: vodozemac::olm::Session = pickle.pickle.into();
303 let session_id = session.session_id();
304
305 Ok(Session {
306 inner: Arc::new(Mutex::new(session)),
307 session_id: session_id.into(),
308 created_using_fallback_key: pickle.created_using_fallback_key,
309 sender_key: pickle.sender_key,
310 our_device_keys,
311 creation_time: pickle.creation_time,
312 last_use_time: pickle.last_use_time,
313 })
314 }
315}
316
317impl PartialEq for Session {
318 fn eq(&self, other: &Self) -> bool {
319 self.session_id() == other.session_id()
320 }
321}
322
/// A serializable snapshot of a [`Session`], produced by `Session::pickle` and
/// turned back into a session by `Session::from_pickle`.
///
/// Our own device keys are not included; they must be supplied again when
/// unpickling.
#[derive(Serialize, Deserialize)]
// NOTE(review): presumably no `Debug` impl because `SessionPickle` holds the
// session's secret state and should not end up in logs — confirm.
#[allow(missing_debug_implementations)]
pub struct PickledSession {
    /// The pickled vodozemac Olm session.
    pub pickle: SessionPickle,
    /// The Curve25519 key of the other side of the session.
    pub sender_key: Curve25519PublicKey,
    /// Whether the session was created using a fallback key. A missing field
    /// deserializes as `false`, so pickles without it (e.g. ones that predate
    /// the field) still load.
    #[serde(default)]
    pub created_using_fallback_key: bool,
    /// When the session was created.
    pub creation_time: SecondsSinceUnixEpoch,
    /// When the session was last used.
    pub last_use_time: SecondsSinceUnixEpoch,
}
342
#[cfg(test)]
mod tests {
    use assert_matches2::assert_let;
    use matrix_sdk_test::async_test;
    use ruma::{device_id, user_id};
    use serde_json::{self, Value};
    use vodozemac::olm::{OlmMessage, SessionConfig};

    use crate::{
        identities::DeviceData,
        olm::Account,
        types::events::{
            dummy::DummyEventContent, olm_v1::DecryptedOlmV1Event,
            room::encrypted::ToDeviceEncryptedEventContent,
        },
    };

    /// Round-trip a to-device event from Alice to Bob and check that the
    /// decrypted payload is addressed to Bob and carries Alice's device keys.
    #[async_test]
    async fn test_encryption_and_decryption() {
        use ruma::events::dummy::ToDeviceDummyEventContent;

        let alice =
            Account::with_device_id(user_id!("@alice:localhost"), device_id!("ALICEDEVICE"));
        let mut bob = Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"));

        // Alice opens an outbound session using one of Bob's one-time keys.
        bob.generate_one_time_keys(1);
        let one_time_key = *bob.one_time_keys().values().next().unwrap();
        let sender_key = bob.identity_keys().curve25519;
        let mut alice_session = alice.create_outbound_session_helper(
            SessionConfig::default(),
            sender_key,
            one_time_key,
            false,
            alice.device_keys(),
        );

        let alice_device = DeviceData::from_account(&alice);
        // The event is meant for Bob, so his device is the recipient.
        // Encrypting "for" Alice's own device would put the wrong recipient
        // user and Ed25519 key into the plaintext Bob decrypts.
        let bob_device = DeviceData::from_account(&bob);

        let message = alice_session
            .encrypt(&bob_device, "m.dummy", ToDeviceDummyEventContent::new(), None)
            .await
            .unwrap()
            .deserialize()
            .unwrap();

        // The content variant follows the session's algorithm.
        #[cfg(feature = "experimental-algorithms")]
        assert_let!(ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(content) = message);
        #[cfg(not(feature = "experimental-algorithms"))]
        assert_let!(ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(content) = message);

        // The first message of a fresh outbound session must be a pre-key message.
        assert_let!(OlmMessage::PreKey(prekey) = content.ciphertext);

        // Bob creates the matching inbound session, which also yields the
        // decrypted plaintext.
        let bob_session_result = bob
            .create_inbound_session(
                alice_device.curve25519_key().unwrap(),
                bob.device_keys(),
                &prekey,
            )
            .unwrap();

        let plaintext: Value = serde_json::from_str(&bob_session_result.plaintext).unwrap();
        assert_eq!(plaintext["recipient"].as_str(), Some("@bob:localhost"));
        // Alice's device keys are included under the MSC4147 field name.
        assert_eq!(
            plaintext["org.matrix.msc4147.device_keys"]["user_id"].as_str(),
            Some("@alice:localhost")
        );

        let event: DecryptedOlmV1Event<DummyEventContent> =
            serde_json::from_str(&bob_session_result.plaintext).unwrap();
        assert_eq!(event.sender_device_keys.unwrap(), alice.device_keys());
    }
}