1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io::{Cursor, Read, Seek, SeekFrom};

use byteorder::{BigEndian, ReadBytesExt};
use rand::{thread_rng, RngCore};
use serde_json::Error as SerdeError;
use thiserror::Error;
use vodozemac::{base64_decode, base64_encode};
use zeroize::Zeroize;

use crate::{
    ciphers::{AesHmacSha2Key, IV_SIZE, MAC_SIZE, SALT_SIZE},
    olm::ExportedRoomKey,
};

/// Version byte of the key export format. It is written as the first byte of
/// every encrypted payload and checked when a payload is decrypted.
const VERSION: u8 = 1;

/// Armor line that must open a key export.
const HEADER: &str = "-----BEGIN MEGOLM SESSION DATA-----";
/// Armor line that must close a key export.
const FOOTER: &str = "-----END MEGOLM SESSION DATA-----";

/// Error representing a failure during key export or import.
#[derive(Error, Debug)]
pub enum KeyExportError {
    /// The key export doesn't contain valid headers.
    #[error("Invalid or missing key export headers.")]
    InvalidHeaders,
    /// The key export has been encrypted with an unsupported version.
    #[error("The key export has been encrypted with an unsupported version.")]
    UnsupportedVersion,
    /// The MAC of the encrypted payload is invalid.
    #[error("The MAC of the encrypted payload is invalid.")]
    InvalidMac,
    /// The decrypted key export isn't valid UTF-8.
    #[error(transparent)]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    /// The decrypted key export doesn't contain valid JSON.
    #[error(transparent)]
    Json(#[from] SerdeError),
    /// The key export string isn't valid base64.
    #[error(transparent)]
    Decode(#[from] vodozemac::Base64DecodeError),
    /// Reading the key export failed, e.g. because the export doesn't contain
    /// all the required fields.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}

/// Read an armored key export from `input` and decrypt it into a list of
/// exported room keys.
///
/// # Arguments
///
/// * `input` - The reader that yields the complete, armored key export.
///
/// * `passphrase` - The passphrase that was used to encrypt the exported keys.
///
/// # Errors
///
/// Returns [`KeyExportError::InvalidHeaders`] if the export doesn't start with
/// the expected header line and end with the expected footer line (ignoring
/// surrounding whitespace). Any failure to read the input, decode or decrypt
/// the payload, or parse the decrypted JSON is reported through the matching
/// [`KeyExportError`] variant.
///
/// # Examples
///
/// ```no_run
/// # use std::io::Cursor;
/// # use matrix_sdk_crypto::{OlmMachine, decrypt_room_key_export};
/// # use ruma::{device_id, user_id};
/// # let alice = user_id!("@alice:example.org");
/// # async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// # let export = Cursor::new("".to_owned());
/// let exported_keys = decrypt_room_key_export(export, "1234").unwrap();
/// machine.store().import_room_keys(exported_keys, None, |_, _| {}).await.unwrap();
/// # };
/// ```
pub fn decrypt_room_key_export(
    mut input: impl Read,
    passphrase: &str,
) -> Result<Vec<ExportedRoomKey>, KeyExportError> {
    let mut export = String::new();
    input.read_to_string(&mut export)?;

    // The armor lines must be the first and last non-whitespace content.
    let has_header = export.trim_start().starts_with(HEADER);
    let has_footer = export.trim_end().ends_with(FOOTER);

    if !has_header || !has_footer {
        return Err(KeyExportError::InvalidHeaders);
    }

    // Glue the base64 lines back together, leaving out the armor lines.
    let mut payload = String::new();

    for line in export.lines() {
        if line.starts_with(HEADER) || line.starts_with(FOOTER) {
            continue;
        }

        payload.push_str(line);
    }

    let mut json = decrypt_helper(&payload, passphrase)?;
    let keys = serde_json::from_str(&json);

    // Wipe the decrypted JSON no matter whether parsing succeeded or not.
    json.zeroize();

    keys.map_err(KeyExportError::Json)
}

/// Serialize the given exported room keys to JSON, encrypt them with a key
/// derived from the passphrase and wrap the result in the key export armor.
///
/// # Arguments
///
/// * `keys` - A list of sessions that should be encrypted.
///
/// * `passphrase` - The passphrase that will be used to encrypt the exported
///   room keys.
///
/// * `rounds` - The number of rounds that should be used for the key derivation
///   when the passphrase gets turned into an AES key. More rounds are
///   increasingly computationally intensive and as such help against
///   brute-force attacks. Should be at least `10_000`, while values in the
///   `100_000` ranges should be preferred.
///
/// # Errors
///
/// Returns an error if the keys can't be serialized to JSON.
///
/// # Panics
///
/// This method will panic if it can't get enough randomness from the OS to
/// encrypt the exported keys securely.
///
/// # Examples
///
/// ```no_run
/// # use matrix_sdk_crypto::{OlmMachine, encrypt_room_key_export};
/// # use ruma::{device_id, user_id, room_id};
/// # let alice = user_id!("@alice:example.org");
/// # async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// let room_id = room_id!("!test:localhost");
/// let exported_keys = machine.store().export_room_keys(|s| s.room_id() == room_id).await.unwrap();
/// let encrypted_export = encrypt_room_key_export(&exported_keys, "1234", 1);
/// # };
/// ```
pub fn encrypt_room_key_export(
    keys: &[ExportedRoomKey],
    passphrase: &str,
    rounds: u32,
) -> Result<String, SerdeError> {
    let mut serialized = serde_json::to_string(keys)?.into_bytes();
    let encoded = encrypt_helper(&serialized, passphrase, rounds);

    // The JSON contains the raw session keys, don't leave it in memory.
    serialized.zeroize();

    // Header, payload and footer, each separated by a single newline.
    let mut export = String::with_capacity(HEADER.len() + encoded.len() + FOOTER.len() + 2);
    export.push_str(HEADER);
    export.push('\n');
    export.push_str(&encoded);
    export.push('\n');
    export.push_str(FOOTER);

    Ok(export)
}

fn encrypt_helper(plaintext: &[u8], passphrase: &str, rounds: u32) -> String {
    let mut salt = [0u8; SALT_SIZE];
    let mut rng = thread_rng();

    rng.fill_bytes(&mut salt);

    let key = AesHmacSha2Key::from_passphrase(passphrase, rounds, &salt);
    let (ciphertext, initialization_vector) = key.encrypt(plaintext.to_owned());

    let mut payload = [
        VERSION.to_be_bytes().as_slice(),
        &salt,
        &initialization_vector,
        rounds.to_be_bytes().as_slice(),
        &ciphertext,
    ]
    .concat();

    let mac = key.create_mac_tag(&payload);
    payload.extend(mac.as_bytes());

    base64_encode(payload)
}

/// Decrypt a base64-encoded key export payload using the given passphrase.
///
/// The decoded payload is expected to have the layout produced by
/// `encrypt_helper`: the format version byte, the salt, the initialization
/// vector, the number of key derivation rounds as a big-endian `u32`, the
/// ciphertext, and a trailing MAC of `MAC_SIZE` bytes that covers everything
/// before it.
///
/// # Errors
///
/// * [`KeyExportError::Decode`] if `ciphertext` isn't valid base64.
/// * [`KeyExportError::Io`] if the payload is too short to contain all the
///   header fields and the MAC.
/// * [`KeyExportError::UnsupportedVersion`] if the version byte doesn't match
///   the supported version.
/// * [`KeyExportError::InvalidMac`] if the MAC doesn't verify, e.g. because of
///   a wrong passphrase or a modified payload.
/// * [`KeyExportError::InvalidUtf8`] if the decrypted data isn't valid UTF-8.
fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result<String, KeyExportError> {
    let decoded = base64_decode(ciphertext)?;

    let mut decoded = Cursor::new(decoded);

    let mut salt = [0u8; SALT_SIZE];
    let mut iv = [0u8; IV_SIZE];
    let mut mac = [0u8; MAC_SIZE];

    let version = decoded.read_u8()?;
    decoded.read_exact(&mut salt)?;
    decoded.read_exact(&mut iv)?;

    let rounds = decoded.read_u32::<BigEndian>()?;
    let ciphertext_start = decoded.position() as usize;

    // The MAC occupies the last `MAC_SIZE` bytes, the ciphertext is whatever
    // sits between the header fields and the MAC.
    decoded.seek(SeekFrom::End(-(MAC_SIZE as i64)))?;
    let ciphertext_end = decoded.position() as usize;

    decoded.read_exact(&mut mac)?;

    let decoded = decoded.into_inner();

    if version != VERSION {
        return Err(KeyExportError::UnsupportedVersion);
    }

    // A truncated payload can be long enough for the reads above to succeed
    // while the MAC overlaps the header fields. Reject it here, before the
    // expensive key derivation, since slicing the ciphertext below would
    // otherwise panic on a reversed range.
    if ciphertext_end < ciphertext_start {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "the key export is too short to contain a MAC",
        )
        .into());
    }

    let key = AesHmacSha2Key::from_passphrase(passphrase, rounds, &salt);
    key.verify_mac(&decoded[..ciphertext_end], &mac).map_err(|_| KeyExportError::InvalidMac)?;

    let plaintext = key.decrypt(decoded[ciphertext_start..ciphertext_end].to_vec(), &iv);

    Ok(String::from_utf8(plaintext)?)
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod proptests {
    use proptest::prelude::*;

    use super::{decrypt_helper, encrypt_helper};

    proptest! {
        // Encrypting an arbitrary string and decrypting the result with the
        // same passphrase must yield the original string.
        #[test]
        fn proptest_encrypt_cycle(plaintext in prop::string::string_regex(".*").unwrap()) {
            let encrypted = encrypt_helper(plaintext.as_bytes(), "test", 1);
            let roundtripped = decrypt_helper(&encrypted, "test").unwrap();

            prop_assert!(plaintext == roundtripped);
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        collections::{BTreeMap, BTreeSet},
        io::Cursor,
    };

    use indoc::indoc;
    use matrix_sdk_test::async_test;
    use ruma::{room_id, user_id};

    use super::{
        base64_decode, decrypt_helper, decrypt_room_key_export, encrypt_helper,
        encrypt_room_key_export,
    };
    use crate::{
        error::OlmResult, machine::test_helpers::get_prepared_machine_test_helper,
        RoomKeyImportResult,
    };

    /// Passphrase used to decrypt `TEST_EXPORT` and for the round-trip tests.
    const PASSPHRASE: &str = "1234";

    /// An armored key export fixture that `test_real_decrypt` decrypts with
    /// `PASSPHRASE`.
    const TEST_EXPORT: &str = indoc! {"
        -----BEGIN MEGOLM SESSION DATA-----
        Af7mGhlzQ+eGvHu93u0YXd3D/+vYMs3E7gQqOhuCtkvGAAAAASH7pEdWvFyAP1JUisAcpEo
        Xke2Q7Kr9hVl/SCc6jXBNeJCZcrUbUV4D/tRQIl3E9L4fOk928YI1J+3z96qiH0uE7hpsCI
        CkHKwjPU+0XTzFdIk1X8H7sZ+MD/2Sg/q3y8rtUjz7uEj4GUTnb+9SCOTVmJsRfqgUpM1CU
        bDLytHf1JkohY4tWEgpsCc67xdzgodjr12qYrfg/zNm3LGpxlrffJknw4rk5QFTj4kMbqbD
        ZZgDTni+HxRTDGge2J620lMOiznvXX+H09Rwruqx5aJvvaaKd86jWRpiO2oSFqHn4u5ONl9
        41uzm62Sj0eIm6ZbA9NQs87jQw4LxsejhZVL+NdjIg80zVSBTWhTdo0DTnbFSNP4ReOiz0U
        XosOF8A5T8Vdx2nvA0GXltfcHKVKQYh/LJAkNQ7P9UYL4ae/5TtQZkhB1KxCLTRWqADCl53
        uBMGpG53EMgY6G6K2DEIOkcv7sdXQF5WpemiSWZqJRWj+cjfs9BpCTbkp/rszWFl2TniWpR
        RqIbT2jORlN4rTvdtF0F4z1pqP4qWyR3sLNTkXm9CFRzWADNG0RDZKxbCoo6RPvtaCTfaHo
        SwfvzBS6CjfAG+FOugpV48o7+XetaUUPZ6/tZSPhCdeV8eP9q5r0QwWeXFogzoNzWt4HYx9
        MdXxzD+f0mtg5gzehrrEEARwI2bCvPpHxlt/Na9oW/GBpkjwR1LSKgg4CtpRyWngPjdEKpZ
        GYW19pdjg0qdXNk/eqZsQTsNWVo6A
        -----END MEGOLM SESSION DATA-----
    "};

    /// Return the base64 payload of `TEST_EXPORT` with the armor lines removed
    /// and the remaining lines joined together.
    fn export_without_headers() -> String {
        TEST_EXPORT.lines().filter(|l| !l.starts_with("-----")).collect()
    }

    /// The payload of the fixture must be valid base64.
    #[test]
    fn test_decode() {
        let export = export_without_headers();
        base64_decode(export).unwrap();
    }

    /// Encrypting and then decrypting with the same passphrase must return the
    /// original plaintext.
    #[test]
    fn test_encrypt_decrypt() {
        let data = "It's a secret to everybody";
        let bytes = data.to_owned().into_bytes();

        let encrypted = encrypt_helper(&bytes, PASSPHRASE, 10);
        let decrypted = decrypt_helper(&encrypted, PASSPHRASE).unwrap();

        assert_eq!(data, decrypted);
    }

    /// Exported room keys must survive an encrypt/decrypt round trip through
    /// the public API.
    #[async_test]
    async fn test_session_encrypt() {
        let user_id = user_id!("@alice:localhost");
        let (machine, _) = get_prepared_machine_test_helper(user_id, false).await;
        let room_id = room_id!("!test:localhost");

        machine.create_outbound_group_session_with_defaults_test_helper(room_id).await.unwrap();
        let export = machine.store().export_room_keys(|s| s.room_id() == room_id).await.unwrap();

        assert!(!export.is_empty());

        let encrypted = encrypt_room_key_export(&export, "1234", 1).unwrap();
        let decrypted = decrypt_room_key_export(Cursor::new(encrypted), "1234").unwrap();

        for (exported, decrypted) in export.iter().zip(decrypted.iter()) {
            assert_eq!(exported.session_key.to_base64(), decrypted.session_key.to_base64());
        }

        // Importing the decrypted keys back into the machine that exported
        // them is expected to report zero imported keys out of one.
        assert_eq!(
            machine.store().import_exported_room_keys(decrypted, |_, _| {}).await.unwrap(),
            RoomKeyImportResult::new(0, 1, BTreeMap::new())
        );
    }

    /// Checks the import results for a session exported at index 10, the same
    /// export again, a full export of the same session, and an unrelated
    /// session. The steps depend on each other, so their order matters.
    #[async_test]
    async fn test_importing_better_session() -> OlmResult<()> {
        let user_id = user_id!("@alice:localhost");

        let (machine, _) = get_prepared_machine_test_helper(user_id, false).await;
        let room_id = room_id!("!test:localhost");
        let session = machine.create_inbound_session_test_helper(room_id).await?;

        let export = vec![session.export_at_index(10).await];

        // Expected result when the session gets imported: one key imported
        // out of one, keyed by room ID, sender key and session ID.
        let keys = RoomKeyImportResult::new(
            1,
            1,
            BTreeMap::from([(
                session.room_id().to_owned(),
                BTreeMap::from([(
                    session.sender_key().to_base64(),
                    BTreeSet::from([session.session_id().to_owned()]),
                )]),
            )]),
        );

        assert_eq!(machine.store().import_exported_room_keys(export, |_, _| {}).await?, keys);

        // Importing the very same export a second time must not import
        // anything.
        let export = vec![session.export_at_index(10).await];
        assert_eq!(
            machine.store().import_exported_room_keys(export, |_, _| {}).await?,
            RoomKeyImportResult::new(0, 1, BTreeMap::new())
        );

        // A full export of the same session (not limited to index 10) must be
        // imported again.
        let better_export = vec![session.export().await];

        assert_eq!(
            machine.store().import_exported_room_keys(better_export, |_, _| {}).await?,
            keys
        );

        // A different session must be imported as well.
        let another_session = machine.create_inbound_session_test_helper(room_id).await?;
        let export = vec![another_session.export_at_index(10).await];

        let keys = RoomKeyImportResult::new(
            1,
            1,
            BTreeMap::from([(
                another_session.room_id().to_owned(),
                BTreeMap::from([(
                    another_session.sender_key().to_base64(),
                    BTreeSet::from([another_session.session_id().to_owned()]),
                )]),
            )]),
        );

        assert_eq!(machine.store().import_exported_room_keys(export, |_, _| {}).await?, keys);

        Ok(())
    }

    /// The armored fixture must decrypt with `PASSPHRASE` into at least one
    /// room key.
    #[test]
    fn test_real_decrypt() {
        let reader = Cursor::new(TEST_EXPORT);
        let imported =
            decrypt_room_key_export(reader, PASSPHRASE).expect("Can't decrypt key export");
        assert!(!imported.is_empty())
    }
}