matrix_sdk_crypto/file_encryption/
attachments.rs1use std::io::{Error as IoError, Read};
16
17use aes::{
18 Aes256,
19 cipher::{KeyIvInit, StreamCipher},
20};
21use rand::{Rng, rng};
22use ruma::{
23 events::room::{
24 EncryptedFile, EncryptedFileHash, EncryptedFileHashAlgorithm, EncryptedFileHashes,
25 EncryptedFileInfo, V2EncryptedFileInfo,
26 },
27 serde::Base64,
28};
29use serde::{Deserialize, Serialize};
30use sha2::{Digest, Sha256};
31use thiserror::Error;
32
33const IV_SIZE: usize = 16;
34const KEY_SIZE: usize = 32;
35const HASH_SIZE: usize = 32;
36
37type Aes256Ctr = ctr::Ctr128BE<Aes256>;
38
39pub struct AttachmentDecryptor<'a, R: Read> {
42 inner: &'a mut R,
43 expected_hash: [u8; HASH_SIZE],
44 sha: Sha256,
45 aes: Aes256Ctr,
46}
47
48#[cfg(not(tarpaulin_include))]
49impl<'a, R: 'a + Read + std::fmt::Debug> std::fmt::Debug for AttachmentDecryptor<'a, R> {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("AttachmentDecryptor")
52 .field("inner", &self.inner)
53 .field("expected_hash", &self.expected_hash)
54 .finish()
55 }
56}
57
58impl<R: Read> Read for AttachmentDecryptor<'_, R> {
59 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
60 let read_bytes = self.inner.read(buf)?;
61
62 if read_bytes == 0 {
63 let hash = self.sha.finalize_reset();
64
65 if hash.as_slice() == self.expected_hash.as_slice() {
66 Ok(0)
67 } else {
68 Err(IoError::other("Hash mismatch while decrypting"))
69 }
70 } else {
71 self.sha.update(&buf[0..read_bytes]);
72 self.aes.apply_keystream(&mut buf[0..read_bytes]);
73
74 Ok(read_bytes)
75 }
76 }
77}
78
79#[derive(Error, Debug)]
81pub enum DecryptorError {
82 #[error(transparent)]
85 Decode(#[from] vodozemac::Base64DecodeError),
86 #[error("The encryption info is missing a hash")]
88 MissingHash,
89 #[error("Unknown version for the encrypted attachment.")]
92 UnknownVersion,
93}
94
95impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
96 pub fn new(
126 input: &'a mut R,
127 info: MediaEncryptionInfo,
128 ) -> Result<AttachmentDecryptor<'a, R>, DecryptorError> {
129 let EncryptedFileInfo::V2(encryption_info) = info.encryption_info else {
130 return Err(DecryptorError::UnknownVersion);
131 };
132
133 let Some(EncryptedFileHash::Sha256(hash)) =
134 info.hashes.get(&EncryptedFileHashAlgorithm::Sha256)
135 else {
136 return Err(DecryptorError::MissingHash);
137 };
138 let hash = hash.clone().into_inner();
139 let key = encryption_info.k.as_inner();
140 let iv = encryption_info.iv.as_inner();
141
142 let sha = Sha256::default();
143
144 let aes = Aes256Ctr::new(key.into(), iv.into());
145
146 Ok(AttachmentDecryptor { inner: input, expected_hash: hash, sha, aes })
147 }
148}
149
150pub struct AttachmentEncryptor<'a, R: Read + ?Sized> {
152 finished: bool,
153 inner: &'a mut R,
154 key: [u8; KEY_SIZE],
155 iv: [u8; IV_SIZE],
156 hashes: EncryptedFileHashes,
157 aes: Aes256Ctr,
158 sha: Sha256,
159}
160
161#[cfg(not(tarpaulin_include))]
162impl<'a, R: 'a + Read + std::fmt::Debug + ?Sized> std::fmt::Debug for AttachmentEncryptor<'a, R> {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("AttachmentEncryptor")
165 .field("inner", &self.inner)
166 .field("finished", &self.finished)
167 .finish()
168 }
169}
170
171impl<'a, R: Read + ?Sized + 'a> Read for AttachmentEncryptor<'a, R> {
172 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
173 let read_bytes = self.inner.read(buf)?;
174
175 if read_bytes == 0 {
176 Ok(0)
177 } else {
178 self.aes.apply_keystream(&mut buf[0..read_bytes]);
179 self.sha.update(&buf[0..read_bytes]);
180
181 Ok(read_bytes)
182 }
183 }
184}
185
186impl<'a, R: Read + ?Sized + 'a> AttachmentEncryptor<'a, R> {
187 pub fn new(reader: &'a mut R) -> Self {
216 let mut key = [0u8; KEY_SIZE];
217 let mut iv = [0u8; IV_SIZE];
218
219 let mut rng = rng();
220
221 rng.fill_bytes(&mut key);
222 rng.fill_bytes(&mut iv[0..8]);
225
226 let key_array = &key.into();
227
228 let aes = Aes256Ctr::new(key_array, &iv.into());
229
230 AttachmentEncryptor {
231 finished: false,
232 inner: reader,
233 iv,
234 key,
235 hashes: EncryptedFileHashes::new(),
236 aes,
237 sha: Sha256::default(),
238 }
239 }
240
241 pub fn finish(mut self) -> MediaEncryptionInfo {
243 let hash = self.sha.finalize();
244 self.hashes.insert(EncryptedFileHash::Sha256(Base64::new(hash.into())));
245
246 MediaEncryptionInfo {
247 encryption_info: V2EncryptedFileInfo::encode(self.key, self.iv).into(),
248 hashes: self.hashes,
249 }
250 }
251}
252
253#[derive(Debug, Serialize, Deserialize)]
256pub struct MediaEncryptionInfo {
257 #[serde(flatten)]
259 pub encryption_info: EncryptedFileInfo,
260 pub hashes: EncryptedFileHashes,
262}
263
264impl From<EncryptedFile> for MediaEncryptionInfo {
265 fn from(file: EncryptedFile) -> Self {
266 Self { encryption_info: file.info, hashes: file.hashes }
267 }
268}
269
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read};

    use serde_json::json;

    use super::{AttachmentDecryptor, AttachmentEncryptor, MediaEncryptionInfo};

    /// Ciphertext of `"It's a secret to everybody"` under the key from `example_key`.
    const EXAMPLE_DATA: &[u8] = &[
        179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215,
        172, 36, 26, 75, 47, 33, 160,
    ];

    /// JSON form of the encryption info matching `EXAMPLE_DATA`.
    fn example_key_json() -> serde_json::Value {
        json!({
            "v": "v2",
            "key": {
                "kty": "oct",
                "alg": "A256CTR",
                "ext": true,
                "k": "Voq2nkPme_x8no5-Tjq_laDAdxE6iDbxnlQXxwFPgE4",
                "key_ops": ["decrypt", "encrypt"]
            },
            "iv": "i0DovxYdJEcAAAAAAAAAAA",
            "hashes": {
                "sha256": "ANdt819a8bZl4jKy3Z+jcqtiNICa2y0AW4BBJ/iQRAU"
            }
        })
    }

    /// Deserialized form of `example_key_json`.
    fn example_key() -> MediaEncryptionInfo {
        serde_json::from_value(example_key_json()).expect("the example key JSON should deserialize")
    }

    #[test]
    fn media_encryption_info_serde_roundtrip() {
        let original = example_key_json();

        let parsed: MediaEncryptionInfo = serde_json::from_value(original.clone()).unwrap();

        assert_eq!(serde_json::to_value(&parsed).unwrap(), original);
    }

    #[test]
    fn encrypt_decrypt_cycle() {
        let plaintext = "Hello world".to_owned();

        let mut plain_source = Cursor::new(plaintext.clone());
        let mut encryptor = AttachmentEncryptor::new(&mut plain_source);
        let mut ciphertext = Vec::new();
        encryptor.read_to_end(&mut ciphertext).unwrap();
        let info = encryptor.finish();

        assert_ne!(ciphertext.as_slice(), plaintext.as_bytes());

        let mut cipher_source = Cursor::new(ciphertext);
        let mut decryptor = AttachmentDecryptor::new(&mut cipher_source, info).unwrap();
        let mut roundtripped = String::new();
        decryptor.read_to_string(&mut roundtripped).unwrap();

        assert_eq!(plaintext, roundtripped);
    }

    #[test]
    fn real_decrypt() {
        let mut cipher_source = Cursor::new(EXAMPLE_DATA.to_vec());

        let mut decryptor = AttachmentDecryptor::new(&mut cipher_source, example_key()).unwrap();
        let mut plaintext = String::new();
        decryptor.read_to_string(&mut plaintext).unwrap();

        assert_eq!("It's a secret to everybody", plaintext);
    }

    #[test]
    fn decrypt_invalid_hash() {
        let mut bogus_source = Cursor::new("fake message");

        let mut decryptor = AttachmentDecryptor::new(&mut bogus_source, example_key()).unwrap();
        let mut sink = Vec::new();

        assert!(decryptor.read_to_end(&mut sink).is_err());
    }
}