matrix_sdk_crypto/file_encryption/
attachments.rs1use std::{
16 collections::BTreeMap,
17 io::{Error as IoError, ErrorKind, Read},
18};
19
20use aes::{
21 cipher::{generic_array::GenericArray, KeyIvInit, StreamCipher},
22 Aes256,
23};
24use rand::{thread_rng, RngCore};
25use ruma::{
26 events::room::{EncryptedFile, JsonWebKey, JsonWebKeyInit},
27 serde::Base64,
28};
29use serde::{Deserialize, Serialize};
30use sha2::{Digest, Sha256};
31use thiserror::Error;
32use zeroize::Zeroize;
33
/// Length in bytes of the AES-CTR initialization vector (128 bits).
const IV_SIZE: usize = 128 / 8;
/// Length in bytes of the AES-256 key (256 bits).
const KEY_SIZE: usize = 256 / 8;
/// The encrypted-attachment format version produced by the encryptor and the
/// only one accepted by the decryptor.
const VERSION: &str = "v2";
37
38type Aes256Ctr = ctr::Ctr128BE<Aes256>;
39
40pub struct AttachmentDecryptor<'a, R: Read> {
43 inner: &'a mut R,
44 expected_hash: Vec<u8>,
45 sha: Sha256,
46 aes: Aes256Ctr,
47}
48
49#[cfg(not(tarpaulin_include))]
50impl<'a, R: 'a + Read + std::fmt::Debug> std::fmt::Debug for AttachmentDecryptor<'a, R> {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("AttachmentDecryptor")
53 .field("inner", &self.inner)
54 .field("expected_hash", &self.expected_hash)
55 .finish()
56 }
57}
58
59impl<R: Read> Read for AttachmentDecryptor<'_, R> {
60 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
61 let read_bytes = self.inner.read(buf)?;
62
63 if read_bytes == 0 {
64 let hash = self.sha.finalize_reset();
65
66 if hash.as_slice() == self.expected_hash.as_slice() {
67 Ok(0)
68 } else {
69 Err(IoError::new(ErrorKind::Other, "Hash mismatch while decrypting"))
70 }
71 } else {
72 self.sha.update(&buf[0..read_bytes]);
73 self.aes.apply_keystream(&mut buf[0..read_bytes]);
74
75 Ok(read_bytes)
76 }
77 }
78}
79
80#[derive(Error, Debug)]
82pub enum DecryptorError {
83 #[error(transparent)]
86 Decode(#[from] vodozemac::Base64DecodeError),
87 #[error("The encryption info is missing a hash")]
89 MissingHash,
90 #[error("The supplied key or IV has an invalid length.")]
92 KeyNonceLength,
93 #[error("Unknown version for the encrypted attachment.")]
96 UnknownVersion,
97}
98
99impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
100 pub fn new(
130 input: &'a mut R,
131 info: MediaEncryptionInfo,
132 ) -> Result<AttachmentDecryptor<'a, R>, DecryptorError> {
133 if info.version != VERSION {
134 return Err(DecryptorError::UnknownVersion);
135 }
136
137 let hash =
138 info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?.as_bytes().to_owned();
139 let mut key = info.key.k.into_inner();
140 let iv = info.iv.into_inner();
141
142 if key.len() != KEY_SIZE {
143 return Err(DecryptorError::KeyNonceLength);
144 }
145
146 let key_array = GenericArray::from_slice(&key);
147 let iv = GenericArray::from_exact_iter(iv).ok_or(DecryptorError::KeyNonceLength)?;
148
149 let sha = Sha256::default();
150
151 let aes = Aes256Ctr::new(key_array, &iv);
152 key.zeroize();
153
154 Ok(AttachmentDecryptor { inner: input, expected_hash: hash, sha, aes })
155 }
156}
157
158pub struct AttachmentEncryptor<'a, R: Read + ?Sized> {
160 finished: bool,
161 inner: &'a mut R,
162 web_key: JsonWebKey,
163 iv: Base64,
164 hashes: BTreeMap<String, Base64>,
165 aes: Aes256Ctr,
166 sha: Sha256,
167}
168
169#[cfg(not(tarpaulin_include))]
170impl<'a, R: 'a + Read + std::fmt::Debug + ?Sized> std::fmt::Debug for AttachmentEncryptor<'a, R> {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("AttachmentEncryptor")
173 .field("inner", &self.inner)
174 .field("finished", &self.finished)
175 .finish()
176 }
177}
178
179impl<'a, R: Read + ?Sized + 'a> Read for AttachmentEncryptor<'a, R> {
180 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
181 let read_bytes = self.inner.read(buf)?;
182
183 if read_bytes == 0 {
184 let hash = self.sha.finalize_reset();
185 self.hashes
186 .entry("sha256".to_owned())
187 .or_insert_with(|| Base64::new(hash.as_slice().to_owned()));
188 Ok(0)
189 } else {
190 self.aes.apply_keystream(&mut buf[0..read_bytes]);
191 self.sha.update(&buf[0..read_bytes]);
192
193 Ok(read_bytes)
194 }
195 }
196}
197
198impl<'a, R: Read + ?Sized + 'a> AttachmentEncryptor<'a, R> {
199 pub fn new(reader: &'a mut R) -> Self {
228 let mut key = [0u8; KEY_SIZE];
229 let mut iv = [0u8; IV_SIZE];
230
231 let mut rng = thread_rng();
232
233 rng.fill_bytes(&mut key);
234 rng.fill_bytes(&mut iv[0..8]);
237
238 let web_key = JsonWebKey::from(JsonWebKeyInit {
239 kty: "oct".to_owned(),
240 key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
241 alg: "A256CTR".to_owned(),
242 #[allow(clippy::unnecessary_to_owned)]
243 k: Base64::new(key.to_vec()),
244 ext: true,
245 });
246 #[allow(clippy::unnecessary_to_owned)]
247 let encoded_iv = Base64::new(iv.to_vec());
248
249 let key_array = &key.into();
250
251 let aes = Aes256Ctr::new(key_array, &iv.into());
252 key.zeroize();
253
254 AttachmentEncryptor {
255 finished: false,
256 inner: reader,
257 iv: encoded_iv,
258 web_key,
259 hashes: BTreeMap::new(),
260 aes,
261 sha: Sha256::default(),
262 }
263 }
264
265 pub fn finish(mut self) -> MediaEncryptionInfo {
267 let hash = self.sha.finalize();
268 self.hashes
269 .entry("sha256".to_owned())
270 .or_insert_with(|| Base64::new(hash.as_slice().to_owned()));
271
272 MediaEncryptionInfo {
273 version: VERSION.to_owned(),
274 hashes: self.hashes,
275 iv: self.iv,
276 key: self.web_key,
277 }
278 }
279}
280
281#[derive(Debug, Serialize, Deserialize)]
284pub struct MediaEncryptionInfo {
285 #[serde(rename = "v")]
287 pub version: String,
288 pub key: JsonWebKey,
290 pub iv: Base64,
292 pub hashes: BTreeMap<String, Base64>,
294}
295
296impl From<EncryptedFile> for MediaEncryptionInfo {
297 fn from(file: EncryptedFile) -> Self {
298 Self { version: file.v, key: file.key, iv: file.iv, hashes: file.hashes }
299 }
300}
301
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read};

    use serde_json::json;

    use super::{AttachmentDecryptor, AttachmentEncryptor, MediaEncryptionInfo};

    /// Ciphertext that decrypts to "It's a secret to everybody" under `example_key`.
    const EXAMPLE_DATA: &[u8] = &[
        179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215,
        172, 36, 26, 75, 47, 33, 160,
    ];

    /// Encryption info matching `EXAMPLE_DATA`.
    fn example_key() -> MediaEncryptionInfo {
        serde_json::from_value(json!({
            "v": "v2",
            "key": {
                "kty": "oct",
                "alg": "A256CTR",
                "ext": true,
                "k": "Voq2nkPme_x8no5-Tjq_laDAdxE6iDbxnlQXxwFPgE4",
                "key_ops": ["encrypt", "decrypt"]
            },
            "iv": "i0DovxYdJEcAAAAAAAAAAA",
            "hashes": {
                "sha256": "ANdt819a8bZl4jKy3Z+jcqtiNICa2y0AW4BBJ/iQRAU"
            }
        }))
        .unwrap()
    }

    #[test]
    fn encrypt_decrypt_cycle() {
        let plaintext = "Hello world".to_owned();

        // Encrypt, then collect the info needed for decryption.
        let mut source = Cursor::new(plaintext.clone());
        let mut encryptor = AttachmentEncryptor::new(&mut source);
        let mut ciphertext = Vec::new();
        encryptor.read_to_end(&mut ciphertext).unwrap();
        let info = encryptor.finish();
        assert_ne!(ciphertext.as_slice(), plaintext.as_bytes());

        // Decrypt the ciphertext back.
        let mut source = Cursor::new(ciphertext);
        let mut decryptor = AttachmentDecryptor::new(&mut source, info).unwrap();
        let mut roundtripped = Vec::new();
        decryptor.read_to_end(&mut roundtripped).unwrap();

        assert_eq!(plaintext, String::from_utf8(roundtripped).unwrap());
    }

    #[test]
    fn real_decrypt() {
        let mut source = Cursor::new(EXAMPLE_DATA.to_vec());
        let info = example_key();

        let mut decryptor = AttachmentDecryptor::new(&mut source, info).unwrap();
        let mut plaintext = Vec::new();
        decryptor.read_to_end(&mut plaintext).unwrap();

        assert_eq!("It's a secret to everybody", String::from_utf8(plaintext).unwrap());
    }

    #[test]
    fn decrypt_invalid_hash() {
        let mut source = Cursor::new("fake message");
        let info = example_key();

        let mut decryptor = AttachmentDecryptor::new(&mut source, info).unwrap();
        let mut sink = Vec::new();

        // The ciphertext does not match the expected hash, so reaching the end
        // of input must surface an error.
        decryptor.read_to_end(&mut sink).unwrap_err();
    }
}