1use std::{fmt, sync::Arc};
16
17use ruma::{serde::Raw, SecondsSinceUnixEpoch};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use tokio::sync::Mutex;
21use tracing::{debug, Span};
22use vodozemac::{
23 olm::{DecryptionError, OlmMessage, Session as InnerSession, SessionConfig, SessionPickle},
24 Curve25519PublicKey,
25};
26
27#[cfg(feature = "experimental-algorithms")]
28use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content;
29use crate::{
30 error::{EventError, OlmResult, SessionUnpickleError},
31 types::{
32 events::room::encrypted::{OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent},
33 DeviceKeys, EventEncryptionAlgorithm,
34 },
35 DeviceData,
36};
37
38#[derive(Clone)]
41pub struct Session {
42 pub inner: Arc<Mutex<InnerSession>>,
44 pub session_id: Arc<str>,
46 pub sender_key: Curve25519PublicKey,
48 pub our_device_keys: DeviceKeys,
50 pub created_using_fallback_key: bool,
52 pub creation_time: SecondsSinceUnixEpoch,
54 pub last_use_time: SecondsSinceUnixEpoch,
56}
57
58#[cfg(not(tarpaulin_include))]
59impl fmt::Debug for Session {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("Session")
62 .field("session_id", &self.session_id())
63 .field("sender_key", &self.sender_key)
64 .finish()
65 }
66}
67
68impl Session {
69 pub async fn decrypt(&mut self, message: &OlmMessage) -> Result<String, DecryptionError> {
78 let mut inner = self.inner.lock().await;
79 Span::current().record("session_id", inner.session_id());
80
81 let plaintext = inner.decrypt(message)?;
82 debug!(session=?inner, "Decrypted an Olm message");
83
84 let plaintext = String::from_utf8_lossy(&plaintext).to_string();
85
86 self.last_use_time = SecondsSinceUnixEpoch::now();
87
88 Ok(plaintext)
89 }
90
91 pub fn sender_key(&self) -> Curve25519PublicKey {
93 self.sender_key
94 }
95
96 pub async fn session_config(&self) -> SessionConfig {
98 self.inner.lock().await.session_config()
99 }
100
101 #[allow(clippy::unused_async)] pub async fn algorithm(&self) -> EventEncryptionAlgorithm {
104 #[cfg(feature = "experimental-algorithms")]
105 if self.session_config().await.version() == 2 {
106 EventEncryptionAlgorithm::OlmV2Curve25519AesSha2
107 } else {
108 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
109 }
110
111 #[cfg(not(feature = "experimental-algorithms"))]
112 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
113 }
114
115 pub(crate) async fn encrypt_helper(&mut self, plaintext: &str) -> OlmMessage {
123 let mut session = self.inner.lock().await;
124 let message = session.encrypt(plaintext);
125 self.last_use_time = SecondsSinceUnixEpoch::now();
126 debug!(?session, "Successfully encrypted an event");
127 message
128 }
129
130 pub async fn encrypt(
143 &mut self,
144 recipient_device: &DeviceData,
145 event_type: &str,
146 content: impl Serialize,
147 message_id: Option<String>,
148 ) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
149 let plaintext = {
150 let recipient_signing_key =
151 recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?;
152
153 let payload = json!({
154 "sender": &self.our_device_keys.user_id,
155 "sender_device": &self.our_device_keys.device_id,
156 "keys": {
157 "ed25519": self.our_device_keys.ed25519_key().expect("Device doesn't have ed25519 key").to_base64(),
158 },
159 "org.matrix.msc4147.device_keys": self.our_device_keys,
160 "recipient": recipient_device.user_id(),
161 "recipient_keys": {
162 "ed25519": recipient_signing_key.to_base64(),
163 },
164 "type": event_type,
165 "content": content,
166 });
167
168 serde_json::to_string(&payload)?
169 };
170
171 let ciphertext = self.encrypt_helper(&plaintext).await;
172
173 let content = self.build_encrypted_event(ciphertext, message_id).await?;
174 let content = Raw::new(&content)?;
175 Ok(content)
176 }
177
178 pub(crate) async fn build_encrypted_event(
187 &self,
188 ciphertext: OlmMessage,
189 message_id: Option<String>,
190 ) -> OlmResult<ToDeviceEncryptedEventContent> {
191 let content = match self.algorithm().await {
192 EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content {
193 ciphertext,
194 recipient_key: self.sender_key,
195 sender_key: self
196 .our_device_keys
197 .curve25519_key()
198 .expect("Device doesn't have curve25519 key"),
199 message_id,
200 }
201 .into(),
202 #[cfg(feature = "experimental-algorithms")]
203 EventEncryptionAlgorithm::OlmV2Curve25519AesSha2 => OlmV2Curve25519AesSha2Content {
204 ciphertext,
205 sender_key: self
206 .our_device_keys
207 .curve25519_key()
208 .expect("Device doesn't have curve25519 key"),
209 message_id,
210 }
211 .into(),
212 _ => unreachable!(),
213 };
214
215 Ok(content)
216 }
217
218 pub fn session_id(&self) -> &str {
220 &self.session_id
221 }
222
223 pub async fn pickle(&self) -> PickledSession {
230 let pickle = self.inner.lock().await.pickle();
231
232 PickledSession {
233 pickle,
234 sender_key: self.sender_key,
235 created_using_fallback_key: self.created_using_fallback_key,
236 creation_time: self.creation_time,
237 last_use_time: self.last_use_time,
238 }
239 }
240
241 pub fn from_pickle(
252 our_device_keys: DeviceKeys,
253 pickle: PickledSession,
254 ) -> Result<Self, SessionUnpickleError> {
255 if our_device_keys.curve25519_key().is_none() {
256 return Err(SessionUnpickleError::MissingIdentityKey);
257 }
258 if our_device_keys.ed25519_key().is_none() {
259 return Err(SessionUnpickleError::MissingSigningKey);
260 }
261
262 let session: vodozemac::olm::Session = pickle.pickle.into();
263 let session_id = session.session_id();
264
265 Ok(Session {
266 inner: Arc::new(Mutex::new(session)),
267 session_id: session_id.into(),
268 created_using_fallback_key: pickle.created_using_fallback_key,
269 sender_key: pickle.sender_key,
270 our_device_keys,
271 creation_time: pickle.creation_time,
272 last_use_time: pickle.last_use_time,
273 })
274 }
275}
276
277impl PartialEq for Session {
278 fn eq(&self, other: &Self) -> bool {
279 self.session_id() == other.session_id()
280 }
281}
282
283#[derive(Serialize, Deserialize)]
288#[allow(missing_debug_implementations)]
289pub struct PickledSession {
290 pub pickle: SessionPickle,
292 pub sender_key: Curve25519PublicKey,
294 #[serde(default)]
296 pub created_using_fallback_key: bool,
297 pub creation_time: SecondsSinceUnixEpoch,
299 pub last_use_time: SecondsSinceUnixEpoch,
301}
302
#[cfg(test)]
mod tests {
    use assert_matches2::assert_let;
    use matrix_sdk_test::async_test;
    use ruma::{device_id, user_id};
    use serde_json::{self, Value};
    use vodozemac::olm::{OlmMessage, SessionConfig};

    use crate::{
        identities::DeviceData,
        olm::Account,
        types::events::{
            dummy::DummyEventContent, olm_v1::DecryptedOlmV1Event,
            room::encrypted::ToDeviceEncryptedEventContent,
        },
    };

    /// Alice encrypts a dummy to-device event over a fresh outbound session;
    /// Bob builds the matching inbound session from the pre-key message and the
    /// decrypted payload must carry Alice's device keys.
    #[async_test]
    async fn test_encryption_and_decryption() {
        use ruma::events::dummy::ToDeviceDummyEventContent;

        let alice =
            Account::with_device_id(user_id!("@alice:localhost"), device_id!("ALICEDEVICE"));
        let mut bob = Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"));

        // Bob offers a single one-time key, which Alice uses to open the session.
        bob.generate_one_time_keys(1);
        let bob_one_time_key = *bob.one_time_keys().values().next().unwrap();
        let bob_identity_key = bob.identity_keys().curve25519;
        let mut outbound = alice.create_outbound_session_helper(
            SessionConfig::default(),
            bob_identity_key,
            bob_one_time_key,
            false,
            alice.device_keys(),
        );

        let alice_device = DeviceData::from_account(&alice);

        let encrypted = outbound
            .encrypt(&alice_device, "m.dummy", ToDeviceDummyEventContent::new(), None)
            .await
            .unwrap()
            .deserialize()
            .unwrap();

        #[cfg(feature = "experimental-algorithms")]
        assert_let!(ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(content) = encrypted);
        #[cfg(not(feature = "experimental-algorithms"))]
        assert_let!(ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(content) = encrypted);

        // The first message of a brand-new outbound session is a pre-key message.
        let OlmMessage::PreKey(pre_key_message) = content.ciphertext else {
            panic!("Wrong Olm message type");
        };

        let inbound = bob
            .create_inbound_session(
                alice_device.curve25519_key().unwrap(),
                bob.device_keys(),
                &pre_key_message,
            )
            .unwrap();

        // Inspect the plaintext first as untyped JSON...
        let decrypted: Value = serde_json::from_str(&inbound.plaintext).unwrap();
        assert_eq!(
            decrypted["org.matrix.msc4147.device_keys"]["user_id"].as_str(),
            Some("@alice:localhost")
        );

        // ...and then as a typed decrypted Olm event.
        let event: DecryptedOlmV1Event<DummyEventContent> =
            serde_json::from_str(&inbound.plaintext).unwrap();
        assert_eq!(event.sender_device_keys.unwrap(), alice.device_keys());
    }
}