Skip to main content

matrix_sdk_crypto/file_encryption/
attachments.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io::{Error as IoError, Read};
16
17use aes::{
18    Aes256,
19    cipher::{KeyIvInit, StreamCipher},
20};
21use rand::{Rng, rng};
22use ruma::{
23    events::room::{
24        EncryptedFile, EncryptedFileHash, EncryptedFileHashAlgorithm, EncryptedFileHashes,
25        EncryptedFileInfo, V2EncryptedFileInfo,
26    },
27    serde::Base64,
28};
29use serde::{Deserialize, Serialize};
30use sha2::{Digest, Sha256};
31use thiserror::Error;
32
/// Length in bytes of the AES-CTR initialization vector (one 128-bit block).
const IV_SIZE: usize = 128 / 8;
/// Length in bytes of an AES-256 key.
const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of a SHA-256 digest.
const HASH_SIZE: usize = 256 / 8;
36
37type Aes256Ctr = ctr::Ctr128BE<Aes256>;
38
39/// A wrapper that transparently encrypts anything that implements `Read` as an
40/// Matrix attachment.
41pub struct AttachmentDecryptor<'a, R: Read> {
42    inner: &'a mut R,
43    expected_hash: [u8; HASH_SIZE],
44    sha: Sha256,
45    aes: Aes256Ctr,
46}
47
48#[cfg(not(tarpaulin_include))]
49impl<'a, R: 'a + Read + std::fmt::Debug> std::fmt::Debug for AttachmentDecryptor<'a, R> {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("AttachmentDecryptor")
52            .field("inner", &self.inner)
53            .field("expected_hash", &self.expected_hash)
54            .finish()
55    }
56}
57
58impl<R: Read> Read for AttachmentDecryptor<'_, R> {
59    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
60        let read_bytes = self.inner.read(buf)?;
61
62        if read_bytes == 0 {
63            let hash = self.sha.finalize_reset();
64
65            if hash.as_slice() == self.expected_hash.as_slice() {
66                Ok(0)
67            } else {
68                Err(IoError::other("Hash mismatch while decrypting"))
69            }
70        } else {
71            self.sha.update(&buf[0..read_bytes]);
72            self.aes.apply_keystream(&mut buf[0..read_bytes]);
73
74            Ok(read_bytes)
75        }
76    }
77}
78
79/// Error type for attachment decryption.
80#[derive(Error, Debug)]
81pub enum DecryptorError {
82    /// Some data in the encrypted attachment coldn't be decoded, this may be a
83    /// hash, the secret key, or the initialization vector.
84    #[error(transparent)]
85    Decode(#[from] vodozemac::Base64DecodeError),
86    /// A hash is missing from the encryption info.
87    #[error("The encryption info is missing a hash")]
88    MissingHash,
89    /// The supplied data was encrypted with an unknown version of the
90    /// attachment encryption spec.
91    #[error("Unknown version for the encrypted attachment.")]
92    UnknownVersion,
93}
94
95impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
96    /// Wrap the given reader decrypting all the data we read from it.
97    ///
98    /// # Arguments
99    ///
100    /// * `reader` - The `Reader` that should be wrapped and decrypted.
101    ///
102    /// * `info` - The encryption info that is necessary to decrypt data from
103    ///   the reader.
104    ///
105    /// # Examples
106    /// ```
107    /// # use std::io::{Cursor, Read};
108    /// # use matrix_sdk_crypto::{AttachmentEncryptor, AttachmentDecryptor};
109    /// let data = "Hello world".to_owned();
110    /// let mut cursor = Cursor::new(data.clone());
111    ///
112    /// let mut encryptor = AttachmentEncryptor::new(&mut cursor);
113    ///
114    /// let mut encrypted = Vec::new();
115    /// encryptor.read_to_end(&mut encrypted).unwrap();
116    /// let info = encryptor.finish();
117    ///
118    /// let mut cursor = Cursor::new(encrypted);
119    /// let mut decryptor = AttachmentDecryptor::new(&mut cursor, info).unwrap();
120    /// let mut decrypted_data = Vec::new();
121    /// decryptor.read_to_end(&mut decrypted_data).unwrap();
122    ///
123    /// let decrypted = String::from_utf8(decrypted_data).unwrap();
124    /// ```
125    pub fn new(
126        input: &'a mut R,
127        info: MediaEncryptionInfo,
128    ) -> Result<AttachmentDecryptor<'a, R>, DecryptorError> {
129        let EncryptedFileInfo::V2(encryption_info) = info.encryption_info else {
130            return Err(DecryptorError::UnknownVersion);
131        };
132
133        let Some(EncryptedFileHash::Sha256(hash)) =
134            info.hashes.get(&EncryptedFileHashAlgorithm::Sha256)
135        else {
136            return Err(DecryptorError::MissingHash);
137        };
138        let hash = hash.clone().into_inner();
139        let key = encryption_info.k.as_inner();
140        let iv = encryption_info.iv.as_inner();
141
142        let sha = Sha256::default();
143
144        let aes = Aes256Ctr::new(key.into(), iv.into());
145
146        Ok(AttachmentDecryptor { inner: input, expected_hash: hash, sha, aes })
147    }
148}
149
150/// A wrapper that transparently encrypts anything that implements `Read`.
151pub struct AttachmentEncryptor<'a, R: Read + ?Sized> {
152    finished: bool,
153    inner: &'a mut R,
154    key: [u8; KEY_SIZE],
155    iv: [u8; IV_SIZE],
156    hashes: EncryptedFileHashes,
157    aes: Aes256Ctr,
158    sha: Sha256,
159}
160
161#[cfg(not(tarpaulin_include))]
162impl<'a, R: 'a + Read + std::fmt::Debug + ?Sized> std::fmt::Debug for AttachmentEncryptor<'a, R> {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.debug_struct("AttachmentEncryptor")
165            .field("inner", &self.inner)
166            .field("finished", &self.finished)
167            .finish()
168    }
169}
170
171impl<'a, R: Read + ?Sized + 'a> Read for AttachmentEncryptor<'a, R> {
172    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
173        let read_bytes = self.inner.read(buf)?;
174
175        if read_bytes == 0 {
176            Ok(0)
177        } else {
178            self.aes.apply_keystream(&mut buf[0..read_bytes]);
179            self.sha.update(&buf[0..read_bytes]);
180
181            Ok(read_bytes)
182        }
183    }
184}
185
186impl<'a, R: Read + ?Sized + 'a> AttachmentEncryptor<'a, R> {
187    /// Wrap the given reader encrypting all the data we read from it.
188    ///
189    /// After all the reads are done, and all the data is encrypted that we wish
190    /// to encrypt a call to [`finish()`](#method.finish) is necessary to get
191    /// the decryption key for the data.
192    ///
193    /// # Arguments
194    ///
195    /// * `reader` - The `Reader` that should be wrapped and encrypted.
196    ///
197    /// # Panics
198    ///
199    /// Panics if we can't generate enough random data to create a fresh
200    /// encryption key.
201    ///
202    /// # Examples
203    /// ```
204    /// # use std::io::{Cursor, Read};
205    /// # use matrix_sdk_crypto::AttachmentEncryptor;
206    /// let data = "Hello world".to_owned();
207    /// let mut cursor = Cursor::new(data.clone());
208    ///
209    /// let mut encryptor = AttachmentEncryptor::new(&mut cursor);
210    ///
211    /// let mut encrypted = Vec::new();
212    /// encryptor.read_to_end(&mut encrypted).unwrap();
213    /// let key = encryptor.finish();
214    /// ```
215    pub fn new(reader: &'a mut R) -> Self {
216        let mut key = [0u8; KEY_SIZE];
217        let mut iv = [0u8; IV_SIZE];
218
219        let mut rng = rng();
220
221        rng.fill_bytes(&mut key);
222        // Only populate the first 8 bytes with randomness, the rest is 0
223        // initialized for the counter.
224        rng.fill_bytes(&mut iv[0..8]);
225
226        let key_array = &key.into();
227
228        let aes = Aes256Ctr::new(key_array, &iv.into());
229
230        AttachmentEncryptor {
231            finished: false,
232            inner: reader,
233            iv,
234            key,
235            hashes: EncryptedFileHashes::new(),
236            aes,
237            sha: Sha256::default(),
238        }
239    }
240
241    /// Consume the encryptor and get the encryption key.
242    pub fn finish(mut self) -> MediaEncryptionInfo {
243        let hash = self.sha.finalize();
244        self.hashes.insert(EncryptedFileHash::Sha256(Base64::new(hash.into())));
245
246        MediaEncryptionInfo {
247            encryption_info: V2EncryptedFileInfo::encode(self.key, self.iv).into(),
248            hashes: self.hashes,
249        }
250    }
251}
252
253/// Struct holding all the information that is needed to decrypt an encrypted
254/// file.
255#[derive(Debug, Serialize, Deserialize)]
256pub struct MediaEncryptionInfo {
257    /// The information about the file's encryption.
258    #[serde(flatten)]
259    pub encryption_info: EncryptedFileInfo,
260    /// The hashes that can be used to check the validity of the file.
261    pub hashes: EncryptedFileHashes,
262}
263
264impl From<EncryptedFile> for MediaEncryptionInfo {
265    fn from(file: EncryptedFile) -> Self {
266        Self { encryption_info: file.info, hashes: file.hashes }
267    }
268}
269
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read};

    use serde_json::json;

    use super::{AttachmentDecryptor, AttachmentEncryptor, MediaEncryptionInfo};

    /// Ciphertext of "It's a secret to everybody" under the key described by
    /// `example_key_json()`.
    const EXAMPLE_DATA: &[u8] = &[
        179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215,
        172, 36, 26, 75, 47, 33, 160,
    ];

    /// Serialized encryption info matching `EXAMPLE_DATA`.
    fn example_key_json() -> serde_json::Value {
        json!({
            "v": "v2",
            "key": {
                "kty": "oct",
                "alg": "A256CTR",
                "ext": true,
                "k": "Voq2nkPme_x8no5-Tjq_laDAdxE6iDbxnlQXxwFPgE4",
                "key_ops": ["decrypt", "encrypt"]
            },
            "iv": "i0DovxYdJEcAAAAAAAAAAA",
            "hashes": {
                "sha256": "ANdt819a8bZl4jKy3Z+jcqtiNICa2y0AW4BBJ/iQRAU"
            }
        })
    }

    /// `example_key_json()` parsed into a `MediaEncryptionInfo`.
    fn example_key() -> MediaEncryptionInfo {
        serde_json::from_value(example_key_json()).unwrap()
    }

    /// Build a decryptor over `source` (panicking if that fails) and read it
    /// to the end, returning the plaintext or the read error.
    fn decrypt_all<R: Read>(
        source: &mut R,
        info: MediaEncryptionInfo,
    ) -> std::io::Result<Vec<u8>> {
        let mut decryptor = AttachmentDecryptor::new(source, info).unwrap();
        let mut plaintext = Vec::new();
        decryptor.read_to_end(&mut plaintext)?;
        Ok(plaintext)
    }

    #[test]
    fn media_encryption_info_serde_roundtrip() {
        let expected = example_key_json();

        let info: MediaEncryptionInfo = serde_json::from_value(expected.clone()).unwrap();

        assert_eq!(serde_json::to_value(&info).unwrap(), expected);
    }

    #[test]
    fn encrypt_decrypt_cycle() {
        let plaintext = "Hello world".to_owned();
        let mut source = Cursor::new(plaintext.clone());

        let mut encryptor = AttachmentEncryptor::new(&mut source);
        let mut ciphertext = Vec::new();
        encryptor.read_to_end(&mut ciphertext).unwrap();
        let info = encryptor.finish();

        assert_ne!(ciphertext.as_slice(), plaintext.as_bytes());

        let decrypted = decrypt_all(&mut Cursor::new(ciphertext), info).unwrap();

        assert_eq!(plaintext, String::from_utf8(decrypted).unwrap());
    }

    #[test]
    fn real_decrypt() {
        let mut source = Cursor::new(EXAMPLE_DATA.to_vec());

        let decrypted = decrypt_all(&mut source, example_key()).unwrap();

        assert_eq!("It's a secret to everybody", String::from_utf8(decrypted).unwrap());
    }

    #[test]
    fn decrypt_invalid_hash() {
        let mut source = Cursor::new("fake message");

        let outcome = decrypt_all(&mut source, example_key());

        assert!(outcome.is_err());
    }
}