vodozemac/olm/account/
mod.rs

// Copyright 2021 Damir Jelić, Denis Kasak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15mod fallback_keys;
16mod one_time_keys;
17
18use std::collections::HashMap;
19
20use chacha20poly1305::{
21    ChaCha20Poly1305,
22    aead::{Aead, AeadCore, KeyInit},
23};
24use rand::thread_rng;
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27use x25519_dalek::ReusableSecret;
28use zeroize::Zeroize;
29
30pub use self::one_time_keys::OneTimeKeyGenerationResult;
31use self::{
32    fallback_keys::FallbackKeys,
33    one_time_keys::{OneTimeKeys, OneTimeKeysPickle},
34};
35use super::{
36    SessionConfig,
37    messages::PreKeyMessage,
38    session::{DecryptionError, Session},
39    session_keys::SessionKeys,
40    shared_secret::{RemoteShared3DHSecret, Shared3DHSecret},
41};
42use crate::{
43    Ed25519Signature, PickleError,
44    types::{
45        Curve25519Keypair, Curve25519KeypairPickle, Curve25519PublicKey, Curve25519SecretKey,
46        Ed25519Keypair, Ed25519KeypairPickle, Ed25519PublicKey, KeyId,
47    },
48    utilities::{pickle, unpickle},
49};
50
/// Number of one-time keys clients are told to keep published on the server.
///
/// Deliberately smaller than what the local one-time key store can hold; see
/// `Account::max_number_of_one_time_keys` for the reasoning.
const PUBLIC_MAX_ONE_TIME_KEYS: usize = 50;

53/// Error describing failure modes when creating a Olm [`Session`] from an
54/// incoming Olm message.
55#[derive(Error, Debug)]
56pub enum SessionCreationError {
57    /// The pre-key message contained an unknown one-time key. This happens
58    /// either because we never had such a one-time key, or because it has
59    /// already been used up.
60    #[error("The pre-key message contained an unknown one-time key: {0}")]
61    MissingOneTimeKey(Curve25519PublicKey),
62    /// The pre-key message contains a Curve25519 identity key that doesn't
63    /// match to the identity key that was given.
64    #[error(
65        "The given identity key doesn't match the one in the pre-key message: \
66        expected {0}, got {1}"
67    )]
68    MismatchedIdentityKey(Curve25519PublicKey, Curve25519PublicKey),
69    /// The pre-key message that was used to establish the [`Session`] couldn't
70    /// be decrypted. The message needs to be decryptable, otherwise we will
71    /// have created a Session that wasn't used to encrypt the pre-key
72    /// message.
73    #[error("The message that was used to establish the Session couldn't be decrypted")]
74    Decryption(#[from] DecryptionError),
75}
76
77/// Struct holding the two public identity keys of an [`Account`].
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79pub struct IdentityKeys {
80    /// The Ed25519 key, used for signing.
81    pub ed25519: Ed25519PublicKey,
82    /// The Curve25519 key, used for to establish shared secrets.
83    pub curve25519: Curve25519PublicKey,
84}
85
86/// Return type for the creation of inbound [`Session`] objects.
87#[derive(Debug)]
88pub struct InboundCreationResult {
89    /// The [`Session`] that was created from a pre-key message.
90    pub session: Session,
91    /// The plaintext of the pre-key message.
92    pub plaintext: Vec<u8>,
93}
94
/// Output of turning an [`Account`] into a dehydrated device.
#[derive(Debug)]
pub struct DehydratedDeviceResult {
    /// Base64 encoding of the encrypted dehydrated device.
    pub ciphertext: String,
    /// Base64 encoding of the nonce the device was encrypted with.
    pub nonce: String,
}

104/// An Olm [`Account`] manages all cryptographic keys used on a device.
105pub struct Account {
106    /// A permanent Ed25519 key used for signing. Also known as the fingerprint
107    /// key.
108    signing_key: Ed25519Keypair,
109    /// The permanent Curve25519 key used for triple Diffie-Hellman (3DH). Also
110    /// known as the sender key or the identity key.
111    diffie_hellman_key: Curve25519Keypair,
112    /// The ephemeral (one-time) Curve25519 keys used as part of the triple
113    /// Diffie-Hellman (3DH).
114    one_time_keys: OneTimeKeys,
115    /// The ephemeral Curve25519 keys used in lieu of a one-time key as part of
116    /// the 3DH, in case we run out of those. We keep track of both the current
117    /// and the previous fallback key in any given moment.
118    fallback_keys: FallbackKeys,
119}
120
121impl Account {
122    /// Create a new [`Account`] with new random identity keys.
123    pub fn new() -> Self {
124        Self {
125            signing_key: Ed25519Keypair::new(),
126            diffie_hellman_key: Curve25519Keypair::new(),
127            one_time_keys: OneTimeKeys::new(),
128            fallback_keys: FallbackKeys::new(),
129        }
130    }
131
132    /// Get the [`IdentityKeys`] of this Account
133    pub const fn identity_keys(&self) -> IdentityKeys {
134        IdentityKeys { ed25519: self.ed25519_key(), curve25519: self.curve25519_key() }
135    }
136
137    /// Get a copy of the account's public Ed25519 key
138    pub const fn ed25519_key(&self) -> Ed25519PublicKey {
139        self.signing_key.public_key()
140    }
141
142    /// Get a copy of the account's public Curve25519 key
143    pub const fn curve25519_key(&self) -> Curve25519PublicKey {
144        self.diffie_hellman_key.public_key()
145    }
146
147    /// Sign the given message using our Ed25519 fingerprint key.
148    pub fn sign(&self, message: impl AsRef<[u8]>) -> Ed25519Signature {
149        self.signing_key.sign(message.as_ref())
150    }
151
152    /// Get the maximum number of one-time keys the client should keep on the
153    /// server.
154    ///
155    /// **Note**: this differs from the libolm method of the same name, the
156    /// libolm method returned the maximum amount of one-time keys the
157    /// [`Account`] could hold and only half of those should be uploaded.
158    pub const fn max_number_of_one_time_keys(&self) -> usize {
159        // We tell clients to upload a limited amount of one-time keys, this
160        // amount is smaller than what we can store.
161        //
162        // We do this because a client might receive the count of uploaded keys
163        // from the server before they receive all the pre-key messages that
164        // used some of our one-time keys. This would mean that we would forget
165        // private one-time keys, since we're generating new ones, while we
166        // didn't yet receive the pre-key messages that used those one-time
167        // keys.
168        PUBLIC_MAX_ONE_TIME_KEYS
169    }
170
171    /// Create a [`Session`] with the given identity key and one-time key.
172    pub fn create_outbound_session(
173        &self,
174        session_config: SessionConfig,
175        identity_key: Curve25519PublicKey,
176        one_time_key: Curve25519PublicKey,
177    ) -> Session {
178        let rng = thread_rng();
179
180        let base_key = ReusableSecret::random_from_rng(rng);
181        let public_base_key = Curve25519PublicKey::from(&base_key);
182
183        let shared_secret = Shared3DHSecret::new(
184            self.diffie_hellman_key.secret_key(),
185            &base_key,
186            &identity_key,
187            &one_time_key,
188        );
189
190        let session_keys = SessionKeys {
191            identity_key: self.curve25519_key(),
192            base_key: public_base_key,
193            one_time_key,
194        };
195
196        Session::new(session_config, shared_secret, session_keys)
197    }
198
199    /// Try to find a [`Curve25519SecretKey`] that forms a pair with the given
200    /// [`Curve25519PublicKey`].
201    fn find_one_time_key(&self, public_key: &Curve25519PublicKey) -> Option<&Curve25519SecretKey> {
202        self.one_time_keys
203            .get_secret_key(public_key)
204            .or_else(|| self.fallback_keys.get_secret_key(public_key))
205    }
206
207    /// Remove a one-time key that has previously been published but not yet
208    /// used.
209    ///
210    /// **Note**: This function is only rarely useful and you'll know if you
211    /// need it. Notably, you do *not* need to call it manually when using up
212    /// a key via [`Account::create_inbound_session`] since the key is
213    /// automatically removed in that case.
214    #[cfg(feature = "low-level-api")]
215    pub fn remove_one_time_key(
216        &mut self,
217        public_key: Curve25519PublicKey,
218    ) -> Option<Curve25519SecretKey> {
219        self.remove_one_time_key_helper(public_key)
220    }
221
222    fn remove_one_time_key_helper(
223        &mut self,
224        public_key: Curve25519PublicKey,
225    ) -> Option<Curve25519SecretKey> {
226        self.one_time_keys.remove_secret_key(&public_key)
227    }
228
229    /// Create a [`Session`] from the given [`PreKeyMessage`] message and
230    /// identity key
231    pub fn create_inbound_session(
232        &mut self,
233        their_identity_key: Curve25519PublicKey,
234        pre_key_message: &PreKeyMessage,
235    ) -> Result<InboundCreationResult, SessionCreationError> {
236        if their_identity_key != pre_key_message.identity_key() {
237            Err(SessionCreationError::MismatchedIdentityKey(
238                their_identity_key,
239                pre_key_message.identity_key(),
240            ))
241        } else {
242            // Find the matching private part of the OTK that the message claims
243            // was used to create the session that encrypted it.
244            let public_otk = pre_key_message.one_time_key();
245            let private_otk = self
246                .find_one_time_key(&public_otk)
247                .ok_or(SessionCreationError::MissingOneTimeKey(public_otk))?;
248
249            // Construct a 3DH shared secret from the various curve25519 keys.
250            let shared_secret = RemoteShared3DHSecret::new(
251                self.diffie_hellman_key.secret_key(),
252                private_otk,
253                &pre_key_message.identity_key(),
254                &pre_key_message.base_key(),
255            );
256
257            // These will be used to uniquely identify the Session.
258            let session_keys = SessionKeys {
259                identity_key: pre_key_message.identity_key(),
260                base_key: pre_key_message.base_key(),
261                one_time_key: pre_key_message.one_time_key(),
262            };
263
264            let config = if pre_key_message.message.mac_truncated() {
265                SessionConfig::version_1()
266            } else {
267                SessionConfig::version_2()
268            };
269
270            // Create a Session, AKA a double ratchet, this one will have an
271            // inactive sending chain until we decide to encrypt a message.
272            let mut session = Session::new_remote(
273                config,
274                shared_secret,
275                pre_key_message.message.ratchet_key,
276                session_keys,
277            );
278
279            // Decrypt the message to check if the Session is actually valid.
280            let plaintext = session.decrypt_decoded(&pre_key_message.message)?;
281
282            // We only drop the one-time key now, this is why we can't use a
283            // one-time key type that takes `self`. If we didn't do this,
284            // someone could maliciously pretend to use up our one-time key and
285            // make us drop the private part. Unsuspecting users that actually
286            // try to use such an one-time key won't be able to communicate with
287            // us. This is strictly worse than the one-time key exhaustion
288            // scenario.
289            self.remove_one_time_key_helper(pre_key_message.one_time_key());
290
291            Ok(InboundCreationResult { session, plaintext })
292        }
293    }
294
295    /// Generates the supplied number of one time keys.
296    /// Returns the public parts of the one-time keys that were created and
297    /// discarded.
298    ///
299    /// Our one-time key store inside the [`Account`] has a limited amount of
300    /// places for one-time keys, If we try to generate new ones while the store
301    /// is completely populated, the oldest one-time keys will get discarded
302    /// to make place for new ones.
303    pub fn generate_one_time_keys(&mut self, count: usize) -> OneTimeKeyGenerationResult {
304        self.one_time_keys.generate(count)
305    }
306
307    /// Get the number of one-time keys we have stored locally.
308    ///
309    /// This will be equal or greater to the number of one-time keys we have
310    /// published. Each time a new [`Session`] is created using the
311    /// [`Account::create_inbound_session()`] a one-time key will be used up
312    /// and removed.
313    pub fn stored_one_time_key_count(&self) -> usize {
314        self.one_time_keys.private_keys.len()
315    }
316
317    /// Get the currently unpublished one-time keys.
318    ///
319    /// The one-time keys should be published to a server and marked as
320    /// published using the `mark_keys_as_published()` method.
321    pub fn one_time_keys(&self) -> HashMap<KeyId, Curve25519PublicKey> {
322        self.one_time_keys
323            .unpublished_public_keys
324            .iter()
325            .map(|(key_id, key)| (*key_id, *key))
326            .collect()
327    }
328
329    /// Generate a single new fallback key.
330    ///
331    /// The fallback key will be used by other users to establish a [`Session`]
332    /// if all the one-time keys on the server have been used up.
333    ///
334    /// Returns the public Curve25519 key of the *previous* fallback key, that
335    /// is, the one that will get removed from the [`Account`] when this method
336    /// is called. This return value is mostly useful for logging purposes.
337    pub fn generate_fallback_key(&mut self) -> Option<Curve25519PublicKey> {
338        self.fallback_keys.generate_fallback_key()
339    }
340
341    /// Get the currently unpublished fallback key.
342    ///
343    /// The fallback key should be published just like the one-time keys, after
344    /// it has been successfully published it needs to be marked as published
345    /// using the `mark_keys_as_published()` method as well.
346    pub fn fallback_key(&self) -> HashMap<KeyId, Curve25519PublicKey> {
347        let fallback_key = self.fallback_keys.unpublished_fallback_key();
348
349        if let Some(fallback_key) = fallback_key {
350            HashMap::from([(fallback_key.key_id(), fallback_key.public_key())])
351        } else {
352            HashMap::new()
353        }
354    }
355
356    /// The [`Account`] stores at most two private parts of the fallback key.
357    /// This method lets us forget the previously used fallback key.
358    pub fn forget_fallback_key(&mut self) -> bool {
359        self.fallback_keys.forget_previous_fallback_key().is_some()
360    }
361
362    /// Mark all currently unpublished one-time and fallback keys as published.
363    pub fn mark_keys_as_published(&mut self) {
364        self.one_time_keys.mark_as_published();
365        self.fallback_keys.mark_as_published();
366    }
367
368    /// Convert the account into a struct which implements [`serde::Serialize`]
369    /// and [`serde::Deserialize`].
370    pub fn pickle(&self) -> AccountPickle {
371        AccountPickle {
372            signing_key: self.signing_key.clone().into(),
373            diffie_hellman_key: self.diffie_hellman_key.clone().into(),
374            one_time_keys: self.one_time_keys.clone().into(),
375            fallback_keys: self.fallback_keys.clone(),
376        }
377    }
378
379    /// Restore an [`Account`] from a previously saved [`AccountPickle`].
380    pub fn from_pickle(pickle: AccountPickle) -> Self {
381        pickle.into()
382    }
383
384    /// Create an [`Account`] object by unpickling an account pickle in libolm
385    /// legacy pickle format.
386    ///
387    /// Such pickles are encrypted and need to first be decrypted using
388    /// `pickle_key`.
389    #[cfg(feature = "libolm-compat")]
390    pub fn from_libolm_pickle(
391        pickle: &str,
392        pickle_key: &[u8],
393    ) -> Result<Self, crate::LibolmPickleError> {
394        use self::libolm::Pickle;
395        use crate::utilities::unpickle_libolm;
396
397        const PICKLE_VERSION: u32 = 4;
398        unpickle_libolm::<Pickle, _>(pickle, pickle_key, PICKLE_VERSION)
399    }
400
401    /// Pickle an [`Account`] into a libolm pickle format.
402    ///
403    /// This pickle can be restored using the [`Account::from_libolm_pickle()`]
404    /// method, or can be used in the [`libolm`] C library.
405    ///
406    /// The pickle will be encrypted using the pickle key.
407    ///
408    /// *Note*: This method might be lossy, the vodozemac [`Account`] has the
409    /// ability to hold more one-time keys compared to the [`libolm`]
410    /// variant.
411    ///
412    /// ⚠️  ***Security Warning***: The pickle key will get expanded into both
413    /// an AES key and an IV in a deterministic manner. If the same pickle
414    /// key is reused, this will lead to IV reuse. To prevent this, users
415    /// have to ensure that they always use a globally (probabilistically)
416    /// unique pickle key.
417    ///
418    /// [`libolm`]: https://gitlab.matrix.org/matrix-org/olm/
419    ///
420    /// # Examples
421    /// ```
422    /// use vodozemac::olm::Account;
423    /// use olm_rs::{account::OlmAccount, PicklingMode};
424    /// let account = Account::new();
425    ///
426    /// let export = account
427    ///     .to_libolm_pickle(&[0u8; 32])
428    ///     .expect("We should be able to pickle a freshly created Account");
429    ///
430    /// let unpickled = OlmAccount::unpickle(
431    ///     export,
432    ///     PicklingMode::Encrypted { key: [0u8; 32].to_vec() },
433    /// ).expect("We should be able to unpickle our exported Account");
434    /// ```
435    #[cfg(feature = "libolm-compat")]
436    pub fn to_libolm_pickle(&self, pickle_key: &[u8]) -> Result<String, crate::LibolmPickleError> {
437        use self::libolm::Pickle;
438        use crate::utilities::pickle_libolm;
439        pickle_libolm::<Pickle>(self.into(), pickle_key)
440    }
441
442    #[cfg(all(any(fuzzing, test), feature = "libolm-compat"))]
443    #[doc(hidden)]
444    pub fn from_decrypted_libolm_pickle(pickle: &[u8]) -> Result<Self, crate::LibolmPickleError> {
445        use std::io::Cursor;
446
447        use matrix_pickle::Decode;
448
449        use self::libolm::Pickle;
450
451        let mut cursor = Cursor::new(&pickle);
452        let pickle = Pickle::decode(&mut cursor)?;
453
454        pickle.try_into()
455    }
456
457    /// Create a dehydrated device from the account.
458    ///
459    /// A dehydrated device is a device that is stored encrypted on the server
460    /// that can receive messages when the user has no other active devices.
461    /// Upon login, the user can rehydrate the device (using
462    /// [`Account::from_dehydrated_device`]) and decrypt the messages sent to
463    /// the dehydrated device.
464    ///
465    /// The account must be a newly-created account that does not have any Olm
466    /// sessions, since the dehydrated device format does not store sessions.
467    ///
468    /// Returns the ciphertext and nonce.  `key` is a 256-bit (32-byte) key for
469    /// encrypting the device.
470    ///
471    /// The format used here is defined in
472    /// [MSC3814](https://github.com/matrix-org/matrix-spec-proposals/pull/3814).
473    pub fn to_dehydrated_device(
474        &self,
475        key: &[u8; 32],
476    ) -> Result<DehydratedDeviceResult, crate::DehydratedDeviceError> {
477        use matrix_pickle::Encode;
478
479        use self::dehydrated_device::Pickle;
480        use crate::{DehydratedDeviceError, LibolmPickleError, utilities::base64_encode};
481
482        let pickle: Pickle = self.try_into()?;
483        let mut encoded = pickle
484            .encode_to_vec()
485            .map_err(|e| DehydratedDeviceError::LibolmPickle(LibolmPickleError::Encode(e)))?;
486
487        let cipher = ChaCha20Poly1305::new(key.into());
488        let rng = thread_rng();
489        let nonce = ChaCha20Poly1305::generate_nonce(rng);
490        let ciphertext = cipher.encrypt(&nonce, encoded.as_slice());
491
492        encoded.zeroize();
493
494        let ciphertext = ciphertext?;
495
496        Ok(DehydratedDeviceResult {
497            ciphertext: base64_encode(ciphertext),
498            nonce: base64_encode(nonce),
499        })
500    }
501
502    /// Create an [`Account`] object from a dehydrated device.
503    ///
504    /// `ciphertext` and `nonce` are the ciphertext and nonce returned by
505    /// [`Account::to_dehydrated_device`]. `key` is a 256-bit (32-byte) key for
506    /// decrypting the device, and must be the same key used when
507    /// [`Account::to_dehydrated_device`] was called.
508    pub fn from_dehydrated_device(
509        ciphertext: &str,
510        nonce: &str,
511        key: &[u8; 32],
512    ) -> Result<Self, crate::DehydratedDeviceError> {
513        use self::dehydrated_device::PICKLE_VERSION;
514        use crate::utilities::{base64_decode, get_pickle_version};
515
516        let cipher = ChaCha20Poly1305::new(key.into());
517        let ciphertext = base64_decode(ciphertext)?;
518        let nonce = base64_decode(nonce)?;
519
520        if nonce.len() != 12 {
521            Err(crate::DehydratedDeviceError::InvalidNonce)
522        } else {
523            let mut plaintext = cipher.decrypt(nonce.as_slice().into(), ciphertext.as_slice())?;
524            let version = get_pickle_version(&plaintext)
525                .ok_or(crate::DehydratedDeviceError::MissingVersion)?;
526
527            if version != PICKLE_VERSION {
528                Err(crate::DehydratedDeviceError::Version(PICKLE_VERSION, version))
529            } else {
530                let pickle = Self::from_decrypted_dehydrated_device(&plaintext);
531                plaintext.zeroize();
532                pickle
533            }
534        }
535    }
536
537    // This function is public for fuzzing, but should not be used by anything
538    // else
539    #[doc(hidden)]
540    pub fn from_decrypted_dehydrated_device(
541        pickle: &[u8],
542    ) -> Result<Self, crate::DehydratedDeviceError> {
543        use std::io::Cursor;
544
545        use matrix_pickle::Decode;
546
547        use self::dehydrated_device::Pickle;
548        use crate::{DehydratedDeviceError, LibolmPickleError};
549
550        let mut cursor = Cursor::new(&pickle);
551        let pickle = Pickle::decode(&mut cursor)
552            .map_err(|e| DehydratedDeviceError::LibolmPickle(LibolmPickleError::Decode(e)))?;
553
554        pickle.try_into()
555    }
556}
557
558impl Default for Account {
559    fn default() -> Self {
560        Self::new()
561    }
562}
563
564/// A format suitable for serialization which implements [`serde::Serialize`]
565/// and [`serde::Deserialize`]. Obtainable by calling [`Account::pickle`].
566#[derive(Serialize, Deserialize)]
567pub struct AccountPickle {
568    signing_key: Ed25519KeypairPickle,
569    diffie_hellman_key: Curve25519KeypairPickle,
570    one_time_keys: OneTimeKeysPickle,
571    fallback_keys: FallbackKeys,
572}
573
574/// A format suitable for serialization which implements [`serde::Serialize`]
575/// and [`serde::Deserialize`]. Obtainable by calling [`Account::pickle`].
576impl AccountPickle {
577    /// Serialize and encrypt the pickle using the given key.
578    ///
579    /// This is the inverse of [`AccountPickle::from_encrypted`].
580    pub fn encrypt(self, pickle_key: &[u8; 32]) -> String {
581        pickle(&self, pickle_key)
582    }
583
584    /// Obtain a pickle from a ciphertext by decrypting and deserializing using
585    /// the given key.
586    ///
587    /// This is the inverse of [`AccountPickle::encrypt`].
588    pub fn from_encrypted(ciphertext: &str, pickle_key: &[u8; 32]) -> Result<Self, PickleError> {
589        unpickle(ciphertext, pickle_key)
590    }
591}
592
593impl From<AccountPickle> for Account {
594    fn from(pickle: AccountPickle) -> Self {
595        Self {
596            signing_key: pickle.signing_key.into(),
597            diffie_hellman_key: pickle.diffie_hellman_key.into(),
598            one_time_keys: pickle.one_time_keys.into(),
599            fallback_keys: pickle.fallback_keys,
600        }
601    }
602}
603
#[cfg(feature = "libolm-compat")]
mod libolm {
    //! Conversion between an [`Account`] and the legacy libolm account pickle
    //! format.

    use matrix_pickle::{Decode, DecodeError, Encode, EncodeError};
    use zeroize::{Zeroize, ZeroizeOnDrop};

    use super::{
        Account,
        fallback_keys::{FallbackKey, FallbackKeys},
        one_time_keys::OneTimeKeys,
    };
    use crate::{
        Curve25519PublicKey, Ed25519Keypair, KeyId,
        types::{Curve25519Keypair, Curve25519SecretKey},
        utilities::LibolmEd25519Keypair,
    };

    /// Version number written into every pickle produced by this module.
    const LIBOLM_PICKLE_VERSION: u32 = 4;

    /// A single one-time or fallback key as stored in a libolm pickle.
    ///
    /// The derived `Encode`/`Decode` implementations follow the field
    /// declaration order, so the order below is part of the pickle format.
    #[derive(Encode, Decode, Zeroize, ZeroizeOnDrop)]
    struct OneTimeKey {
        key_id: u32,
        published: bool,
        public_key: [u8; 32],
        private_key: Box<[u8; 32]>,
    }

    impl From<&OneTimeKey> for FallbackKey {
        fn from(pickled: &OneTimeKey) -> Self {
            let key_id = KeyId(pickled.key_id.into());
            let key = Curve25519SecretKey::from_slice(&pickled.private_key);

            FallbackKey { key_id, key, published: pickled.published }
        }
    }

    /// Fails if the key ID does not fit into libolm's 32-bit key IDs.
    impl TryFrom<&FallbackKey> for OneTimeKey {
        type Error = ();

        fn try_from(fallback: &FallbackKey) -> Result<Self, ()> {
            let key_id: u32 = fallback.key_id.0.try_into().map_err(|_| ())?;
            let published = fallback.published();
            let public_key = fallback.public_key().to_bytes();
            let private_key = fallback.secret_key().to_bytes();

            Ok(OneTimeKey { key_id, published, public_key, private_key })
        }
    }

    /// The fallback keys of a libolm pickle. On the wire this is a one-byte
    /// count followed by that many [`OneTimeKey`]s.
    #[derive(Zeroize, ZeroizeOnDrop)]
    struct FallbackKeysArray {
        fallback_key: Option<OneTimeKey>,
        previous_fallback_key: Option<OneTimeKey>,
    }

    impl Decode for FallbackKeysArray {
        fn decode(reader: &mut impl std::io::Read) -> Result<Self, DecodeError> {
            let count = u8::decode(reader)?;

            // The first stored key is the current fallback key, the second
            // one (if any) the previous fallback key.
            let fallback_key = if count == 0 { None } else { Some(OneTimeKey::decode(reader)?) };
            let previous_fallback_key =
                if count < 2 { None } else { Some(OneTimeKey::decode(reader)?) };

            Ok(Self { fallback_key, previous_fallback_key })
        }
    }

    impl Encode for FallbackKeysArray {
        fn encode(&self, writer: &mut impl std::io::Write) -> Result<usize, EncodeError> {
            // Whichever keys are present are written in order, preceded by
            // their count. A lone previous key therefore ends up in the slot
            // of the current key.
            let count = u8::from(self.fallback_key.is_some())
                + u8::from(self.previous_fallback_key.is_some());

            let mut written = count.encode(writer)?;

            for key in self.fallback_key.iter().chain(self.previous_fallback_key.iter()) {
                written += key.encode(writer)?;
            }

            Ok(written)
        }
    }

    /// The complete libolm account pickle.
    ///
    /// The field order below is part of the pickle format.
    #[derive(Encode, Decode, Zeroize, ZeroizeOnDrop)]
    pub(super) struct Pickle {
        version: u32,
        ed25519_keypair: LibolmEd25519Keypair,
        public_curve25519_key: [u8; 32],
        private_curve25519_key: Box<[u8; 32]>,
        one_time_keys: Vec<OneTimeKey>,
        fallback_keys: FallbackKeysArray,
        next_key_id: u32,
    }

    impl From<&Account> for Pickle {
        fn from(account: &Account) -> Self {
            let otks = &account.one_time_keys;

            // One-time keys whose ID doesn't fit into 32 bits are skipped.
            let one_time_keys: Vec<OneTimeKey> = otks
                .secret_keys()
                .iter()
                .filter_map(|(id, secret_key)| {
                    let key_id: u32 = id.0.try_into().ok()?;

                    Some(OneTimeKey {
                        key_id,
                        published: otks.is_secret_key_published(id),
                        public_key: Curve25519PublicKey::from(secret_key).to_bytes(),
                        private_key: secret_key.to_bytes(),
                    })
                })
                .collect();

            // Fallback keys that can't be represented are dropped as well.
            let to_pickled = |key: &Option<FallbackKey>| {
                key.as_ref().and_then(|fallback| OneTimeKey::try_from(fallback).ok())
            };

            let fallback_keys = FallbackKeysArray {
                fallback_key: to_pickled(&account.fallback_keys.fallback_key),
                previous_fallback_key: to_pickled(&account.fallback_keys.previous_fallback_key),
            };

            // An out-of-range counter is written as 0.
            let next_key_id: u32 = otks.next_key_id.try_into().unwrap_or_default();

            let ed25519_keypair = LibolmEd25519Keypair {
                private_key: account.signing_key.expanded_secret_key(),
                public_key: account.signing_key.public_key().as_bytes().to_owned(),
            };

            Self {
                version: LIBOLM_PICKLE_VERSION,
                ed25519_keypair,
                public_curve25519_key: account.diffie_hellman_key.public_key().to_bytes(),
                private_curve25519_key: account.diffie_hellman_key.secret_key().to_bytes(),
                one_time_keys,
                fallback_keys,
                next_key_id,
            }
        }
    }

    impl TryFrom<Pickle> for Account {
        type Error = crate::LibolmPickleError;

        fn try_from(pickle: Pickle) -> Result<Self, Self::Error> {
            let mut one_time_keys = OneTimeKeys::new();

            for pickled_key in &pickle.one_time_keys {
                let secret_key = Curve25519SecretKey::from_slice(&pickled_key.private_key);
                let key_id = KeyId(pickled_key.key_id.into());
                one_time_keys.insert_secret_key(key_id, secret_key, pickled_key.published);
            }

            one_time_keys.next_key_id = pickle.next_key_id.into();

            let pickled_fallbacks = &pickle.fallback_keys;

            // The next fallback key ID continues after the current fallback
            // key, or starts at 0 when there is none.
            let next_fallback_key_id = pickled_fallbacks
                .fallback_key
                .as_ref()
                .map_or(0, |key| u64::from(key.key_id.wrapping_add(1)));

            let fallback_keys = FallbackKeys {
                key_id: next_fallback_key_id,
                fallback_key: pickled_fallbacks.fallback_key.as_ref().map(FallbackKey::from),
                previous_fallback_key: pickled_fallbacks
                    .previous_fallback_key
                    .as_ref()
                    .map(FallbackKey::from),
            };

            Ok(Self {
                signing_key: Ed25519Keypair::from_expanded_key(
                    &pickle.ed25519_keypair.private_key,
                )?,
                diffie_hellman_key: Curve25519Keypair::from_secret_key(
                    &pickle.private_curve25519_key,
                ),
                one_time_keys,
                fallback_keys,
            })
        }
    }
}

799mod dehydrated_device {
800    use matrix_pickle::{Decode, DecodeError, Encode, EncodeError};
801    use zeroize::{Zeroize, ZeroizeOnDrop};
802
803    use super::{
804        Account,
805        fallback_keys::{FallbackKey, FallbackKeys},
806        one_time_keys::OneTimeKeys,
807    };
808    use crate::{
809        DehydratedDeviceError, Ed25519Keypair, KeyId,
810        types::{Curve25519Keypair, Curve25519SecretKey},
811    };
812
813    #[derive(Encode, Decode, Zeroize, ZeroizeOnDrop)]
814    pub(crate) struct OneTimeKey {
815        #[secret]
816        pub(crate) private_key: Box<[u8; 32]>,
817    }
818
819    impl From<&OneTimeKey> for FallbackKey {
820        fn from(key: &OneTimeKey) -> Self {
821            FallbackKey {
822                key_id: KeyId(0),
823                key: Curve25519SecretKey::from_slice(&key.private_key),
824                published: true,
825            }
826        }
827    }
828
829    impl TryFrom<&FallbackKey> for OneTimeKey {
830        type Error = ();
831
832        fn try_from(key: &FallbackKey) -> Result<Self, ()> {
833            Ok(OneTimeKey { private_key: key.secret_key().to_bytes() })
834        }
835    }
836
837    #[derive(Zeroize, ZeroizeOnDrop)]
838    pub(crate) struct OptFallbackKey {
839        pub(crate) fallback_key: Option<OneTimeKey>,
840    }
841
842    impl Decode for OptFallbackKey {
843        fn decode(reader: &mut impl std::io::Read) -> Result<Self, DecodeError> {
844            let present = bool::decode(reader)?;
845
846            let fallback_key = if present {
847                let fallback_key = OneTimeKey::decode(reader)?;
848
849                Some(fallback_key)
850            } else {
851                None
852            };
853
854            Ok(Self { fallback_key })
855        }
856    }
857
858    impl Encode for OptFallbackKey {
859        fn encode(&self, writer: &mut impl std::io::Write) -> Result<usize, EncodeError> {
860            let ret = match &self.fallback_key {
861                None => false.encode(writer)?,
862                Some(key) => {
863                    let mut ret = true.encode(writer)?;
864                    ret += key.encode(writer)?;
865
866                    ret
867                }
868            };
869
870            Ok(ret)
871        }
872    }
873
874    #[derive(Encode, Decode, Zeroize, ZeroizeOnDrop)]
875    /// Pickle used for dehydrated devices.
876    ///
877    /// Dehydrated devices are used for receiving encrypted messages when the
878    /// user has no other devices logged in, and are defined in
879    /// [MSC3814](https://github.com/matrix-org/matrix-spec-proposals/pull/3814).
880    pub(super) struct Pickle {
881        version: u32,
882        #[secret]
883        private_curve25519_key: Box<[u8; 32]>,
884        #[secret]
885        private_ed25519_key: Box<[u8; 32]>,
886        one_time_keys: Vec<OneTimeKey>,
887        opt_fallback_key: OptFallbackKey,
888    }
889
890    pub(super) const PICKLE_VERSION: u32 = 1;
891
892    impl TryFrom<&Account> for Pickle {
893        type Error = DehydratedDeviceError;
894
895        fn try_from(account: &Account) -> Result<Self, Self::Error> {
896            let one_time_keys: Vec<_> = account
897                .one_time_keys
898                .secret_keys()
899                .values()
900                .map(|secret_key| OneTimeKey { private_key: secret_key.to_bytes() })
901                .collect();
902
903            let fallback_key =
904                account.fallback_keys.fallback_key.as_ref().and_then(|f| f.try_into().ok());
905
906            Ok(Self {
907                version: PICKLE_VERSION,
908                private_curve25519_key: account.diffie_hellman_key.secret_key().to_bytes(),
909                private_ed25519_key: account
910                    .signing_key
911                    .unexpanded_secret_key()
912                    .ok_or_else(|| DehydratedDeviceError::InvalidAccount)?,
913                one_time_keys,
914                opt_fallback_key: OptFallbackKey { fallback_key },
915            })
916        }
917    }
918
919    impl TryFrom<Pickle> for Account {
920        type Error = DehydratedDeviceError;
921
922        fn try_from(pickle: Pickle) -> Result<Self, Self::Error> {
923            use crate::{DehydratedDeviceError, LibolmPickleError};
924            let mut one_time_keys = OneTimeKeys::new();
925
926            for (num, key) in pickle.one_time_keys.iter().enumerate() {
927                let secret_key = Curve25519SecretKey::from_slice(&key.private_key);
928                let key_id = KeyId(num as u64);
929                one_time_keys.insert_secret_key(key_id, secret_key, true);
930            }
931            one_time_keys.next_key_id = pickle.one_time_keys.len().try_into().unwrap_or_default();
932
933            let fallback_keys = FallbackKeys {
934                key_id: 1,
935                fallback_key: pickle.opt_fallback_key.fallback_key.as_ref().map(|otk| otk.into()),
936                previous_fallback_key: None,
937            };
938
939            Ok(Self {
940                signing_key: Ed25519Keypair::from_unexpanded_key(&pickle.private_ed25519_key)
941                    .map_err(|e| {
942                        DehydratedDeviceError::LibolmPickle(LibolmPickleError::PublicKey(e))
943                    })?,
944                diffie_hellman_key: Curve25519Keypair::from_secret_key(
945                    &pickle.private_curve25519_key,
946                ),
947                one_time_keys,
948                fallback_keys,
949            })
950        }
951    }
952}
953
#[cfg(test)]
mod test {
    use anyhow::{Context, Result, bail};
    use assert_matches2::assert_matches;
    use matrix_pickle::{Decode, Encode};
    use olm_rs::{account::OlmAccount, session::OlmMessage as LibolmOlmMessage};

    #[cfg(feature = "libolm-compat")]
    use super::libolm::Pickle;
    use super::{
        Account, InboundCreationResult, SessionConfig, SessionCreationError, dehydrated_device,
    };
    use crate::{
        Curve25519PublicKey as PublicKey,
        cipher::Mac,
        olm::{
            AccountPickle,
            account::PUBLIC_MAX_ONE_TIME_KEYS,
            messages::{OlmMessage, PreKeyMessage},
        },
    };

    /// Pickle key shared by the pickling and dehydration tests.
    const PICKLE_KEY: [u8; 32] = [0u8; 32];

    #[test]
    fn max_number_of_one_time_keys_matches_global_constant() {
        assert_eq!(Account::new().max_number_of_one_time_keys(), PUBLIC_MAX_ONE_TIME_KEYS);
    }

    #[test]
    #[cfg(feature = "low-level-api")]
    fn generate_and_remove_one_time_key() {
        let mut alice = Account::new();
        assert_eq!(alice.stored_one_time_key_count(), 0);

        alice.generate_one_time_keys(1);
        assert_eq!(alice.stored_one_time_key_count(), 1);

        let public_key = alice
            .one_time_keys()
            .values()
            .next()
            .copied()
            .expect("Should have an unpublished one-time key");
        let secret_key_bytes = alice
            .find_one_time_key(&public_key)
            .expect("Should find secret key for public one-time key")
            .to_bytes();
        let removed_key_bytes = alice
            .remove_one_time_key(public_key)
            .expect("Should be able to remove one-time key")
            .to_bytes();

        assert_eq!(removed_key_bytes, secret_key_bytes);
        assert_eq!(alice.stored_one_time_key_count(), 0);
    }

    #[test]
    fn generate_and_forget_fallback_keys() {
        let mut alice = Account::default();
        assert!(!alice.forget_fallback_key());
        alice.generate_fallback_key();
        assert!(!alice.forget_fallback_key());
        alice.generate_fallback_key();
        assert!(alice.forget_fallback_key());
    }

    #[test]
    fn vodozemac_libolm_communication() -> Result<()> {
        // vodozemac account
        let alice = Account::new();
        // libolm account
        let bob = OlmAccount::new();

        bob.generate_one_time_keys(1);

        let one_time_key = bob
            .parsed_one_time_keys()
            .curve25519()
            .values()
            .next()
            .cloned()
            .expect("Didn't find a valid one-time key");

        bob.mark_keys_as_published();

        let identity_keys = bob.parsed_identity_keys();
        let curve25519_key = PublicKey::from_base64(identity_keys.curve25519())?;
        let one_time_key = PublicKey::from_base64(&one_time_key)?;
        let mut alice_session =
            alice.create_outbound_session(SessionConfig::version_1(), curve25519_key, one_time_key);

        let message = "It's a secret to everybody";
        let olm_message: LibolmOlmMessage = alice_session.encrypt(message).into();

        if let LibolmOlmMessage::PreKey(m) = olm_message.clone() {
            let libolm_session =
                bob.create_inbound_session_from(&alice.curve25519_key().to_base64(), m)?;
            assert_eq!(alice_session.session_id(), libolm_session.session_id());

            let plaintext = libolm_session.decrypt(olm_message)?;
            assert_eq!(message, plaintext);

            let second_text = "Here's another secret to everybody";
            let olm_message = alice_session.encrypt(second_text).into();

            let plaintext = libolm_session.decrypt(olm_message)?;
            assert_eq!(second_text, plaintext);

            let reply_plain = "Yes, take this, it's dangerous out there";
            let reply = libolm_session.encrypt(reply_plain).into();
            let plaintext = alice_session.decrypt(&reply)?;

            assert_eq!(plaintext, reply_plain.as_bytes());

            let another_reply = "Last one";
            let reply = libolm_session.encrypt(another_reply).into();
            let plaintext = alice_session.decrypt(&reply)?;
            assert_eq!(plaintext, another_reply.as_bytes());

            let last_text = "Nope, I'll have the last word";
            let olm_message = alice_session.encrypt(last_text).into();

            let plaintext = libolm_session.decrypt(olm_message)?;
            assert_eq!(last_text, plaintext);
        } else {
            bail!("Received a invalid message type {:?}", olm_message);
        }

        Ok(())
    }

    #[test]
    fn vodozemac_vodozemac_communication() -> Result<()> {
        // Both of these are vodozemac accounts.
        let alice = Account::new();
        let mut bob = Account::new();

        bob.generate_one_time_keys(1);

        let mut alice_session = alice.create_outbound_session(
            SessionConfig::version_2(),
            bob.curve25519_key(),
            *bob.one_time_keys()
                .iter()
                .next()
                .context("Failed getting bob's OTK, which should never happen here.")?
                .1,
        );

        assert!(!bob.one_time_keys().is_empty());
        bob.mark_keys_as_published();
        assert!(bob.one_time_keys().is_empty());

        let message = "It's a secret to everybody";
        let olm_message = alice_session.encrypt(message);

        // The first message of a fresh outbound session must be a pre-key
        // message. Fail loudly instead of silently skipping every check below.
        assert_matches!(olm_message, OlmMessage::PreKey(m));
        assert_eq!(m.session_keys(), alice_session.session_keys());
        assert_eq!(m.session_id(), alice_session.session_id());

        let InboundCreationResult { session: mut bob_session, plaintext } =
            bob.create_inbound_session(alice.curve25519_key(), &m)?;
        assert_eq!(alice_session.session_id(), bob_session.session_id());
        assert_eq!(m.session_keys(), bob_session.session_keys());

        assert_eq!(message.as_bytes(), plaintext);

        let second_text = "Here's another secret to everybody";
        let olm_message = alice_session.encrypt(second_text);

        let plaintext = bob_session.decrypt(&olm_message)?;
        assert_eq!(second_text.as_bytes(), plaintext);

        let reply_plain = "Yes, take this, it's dangerous out there";
        let reply = bob_session.encrypt(reply_plain);
        let plaintext = alice_session.decrypt(&reply)?;

        assert_eq!(plaintext, reply_plain.as_bytes());

        let another_reply = "Last one";
        let reply = bob_session.encrypt(another_reply);
        let plaintext = alice_session.decrypt(&reply)?;
        assert_eq!(plaintext, another_reply.as_bytes());

        let last_text = "Nope, I'll have the last word";
        let olm_message = alice_session.encrypt(last_text);

        let plaintext = bob_session.decrypt(&olm_message)?;
        assert_eq!(last_text.as_bytes(), plaintext);

        Ok(())
    }

    #[test]
    fn inbound_session_creation() -> Result<()> {
        let alice = OlmAccount::new();
        let mut bob = Account::new();

        bob.generate_one_time_keys(1);

        let one_time_key =
            bob.one_time_keys().values().next().cloned().expect("Didn't find a valid one-time key");

        let alice_session = alice.create_outbound_session(
            &bob.curve25519_key().to_base64(),
            &one_time_key.to_base64(),
        )?;

        let text = "It's a secret to everybody";
        let message = alice_session.encrypt(text).into();

        let identity_key = PublicKey::from_base64(alice.parsed_identity_keys().curve25519())?;

        let InboundCreationResult { session, plaintext } = if let OlmMessage::PreKey(m) = &message {
            bob.create_inbound_session(identity_key, m)?
        } else {
            bail!("Got invalid message type from olm_rs {:?}", message);
        };

        assert_eq!(alice_session.session_id(), session.session_id());
        // The one-time key is consumed by a successful inbound session.
        assert!(bob.one_time_keys.private_keys.is_empty());

        assert_eq!(text.as_bytes(), plaintext);

        Ok(())
    }

    #[test]
    fn inbound_session_creation_using_fallback_keys() -> Result<()> {
        let alice = OlmAccount::new();
        let mut bob = Account::new();

        bob.generate_fallback_key();

        let one_time_key =
            bob.fallback_key().values().next().cloned().expect("Didn't find a valid fallback key");
        assert!(bob.one_time_keys.private_keys.is_empty());

        let alice_session = alice.create_outbound_session(
            &bob.curve25519_key().to_base64(),
            &one_time_key.to_base64(),
        )?;

        let text = "It's a secret to everybody";

        let message = alice_session.encrypt(text).into();
        let identity_key = PublicKey::from_base64(alice.parsed_identity_keys().curve25519())?;

        if let OlmMessage::PreKey(m) = &message {
            let InboundCreationResult { session, plaintext } =
                bob.create_inbound_session(identity_key, m)?;

            assert_eq!(m.session_keys(), session.session_keys());
            assert_eq!(alice_session.session_id(), session.session_id());
            // Unlike one-time keys, the fallback key survives its use.
            assert!(bob.fallback_keys.fallback_key.is_some());

            assert_eq!(text.as_bytes(), plaintext);
        } else {
            bail!("Got invalid message type from olm_rs");
        };

        Ok(())
    }

    #[test]
    fn account_pickling_roundtrip_is_identity() -> Result<()> {
        let mut account = Account::new();

        account.generate_one_time_keys(50);

        // Generate two fallback keys so the previous fallback key field gets populated.
        account.generate_fallback_key();
        account.generate_fallback_key();

        let pickle = account.pickle().encrypt(&PICKLE_KEY);

        let decrypted_pickle = AccountPickle::from_encrypted(&pickle, &PICKLE_KEY)?;
        let unpickled_account = Account::from_pickle(decrypted_pickle);
        let repickle = unpickled_account.pickle();

        assert_eq!(account.identity_keys(), unpickled_account.identity_keys());

        let decrypted_pickle = AccountPickle::from_encrypted(&pickle, &PICKLE_KEY)?;
        let pickle = serde_json::to_value(decrypted_pickle)?;
        let repickle = serde_json::to_value(repickle)?;

        assert_eq!(pickle, repickle);

        Ok(())
    }

    #[test]
    #[cfg(feature = "libolm-compat")]
    fn libolm_unpickling() -> Result<()> {
        let olm = OlmAccount::new();
        olm.generate_one_time_keys(10);
        olm.generate_fallback_key();

        let key = b"DEFAULT_PICKLE_KEY";
        let pickle = olm.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() });

        let unpickled = Account::from_libolm_pickle(&pickle, key)?;

        assert_eq!(olm.parsed_identity_keys().ed25519(), unpickled.ed25519_key().to_base64());
        assert_eq!(olm.parsed_identity_keys().curve25519(), unpickled.curve25519_key().to_base64());

        let mut olm_one_time_keys: Vec<_> =
            olm.parsed_one_time_keys().curve25519().values().map(|k| k.to_owned()).collect();
        let mut one_time_keys: Vec<_> =
            unpickled.one_time_keys().values().map(|k| k.to_base64()).collect();

        // We generated 10 one-time keys on the libolm side, we expect the next key id
        // to be 11.
        assert_eq!(unpickled.one_time_keys.next_key_id, 11);

        olm_one_time_keys.sort();
        one_time_keys.sort();
        assert_eq!(olm_one_time_keys, one_time_keys);

        let olm_fallback_key =
            olm.parsed_fallback_key().expect("libolm should have a fallback key");
        assert_eq!(
            olm_fallback_key.curve25519(),
            unpickled
                .fallback_key()
                .values()
                .next()
                .expect("We should have a fallback key")
                .to_base64()
        );

        Ok(())
    }

    /// Encodes `account` as a libolm pickle, decodes it again, and asserts that
    /// the current fallback key survived the round trip unchanged.
    #[cfg(feature = "libolm-compat")]
    fn assert_fallback_key_survives_libolm_pickle_cycle(account: &Account) {
        let mut encoded = Vec::<u8>::new();
        let pickle = Pickle::from(account);
        let size = pickle.encode(&mut encoded).expect("Should encode pickle");
        assert_eq!(size, encoded.len());

        let decoded =
            Account::from_decrypted_libolm_pickle(&encoded).expect("Should unpickle account");

        let key_bytes = account
            .fallback_key()
            .values()
            .next()
            .expect("Should have a fallback key before encoding")
            .to_bytes();
        let decoded_key_bytes = decoded
            .fallback_key()
            .values()
            .next()
            .expect("Should have a fallback key after decoding")
            .to_bytes();
        assert_eq!(key_bytes, decoded_key_bytes);
    }

    #[test]
    #[cfg(feature = "libolm-compat")]
    fn pickle_cycle_with_one_fallback_key() {
        let mut alice = Account::new();
        alice.generate_fallback_key();

        assert_fallback_key_survives_libolm_pickle_cycle(&alice);
    }

    #[test]
    #[cfg(feature = "libolm-compat")]
    fn pickle_cycle_with_two_fallback_keys() {
        let mut alice = Account::new();
        alice.generate_fallback_key();
        alice.generate_fallback_key();

        assert_fallback_key_survives_libolm_pickle_cycle(&alice);
    }

    #[test]
    #[cfg(feature = "libolm-compat")]
    fn signing_with_expanded_key() -> Result<()> {
        let olm = OlmAccount::new();
        olm.generate_one_time_keys(10);
        olm.generate_fallback_key();

        let key = b"DEFAULT_PICKLE_KEY";
        let pickle = olm.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() });

        let account_with_expanded_key = Account::from_libolm_pickle(&pickle, key)?;
        let message = "You met with a terrible fate, haven’t you?".as_bytes();
        let public_key = account_with_expanded_key.ed25519_key();

        // The clone is needed since we're later on using the account.
        #[allow(clippy::redundant_clone)]
        let signing_key_clone = account_with_expanded_key.signing_key.clone();
        let clone_signature = signing_key_clone.sign(message);
        let account_signature = account_with_expanded_key.sign(message);

        // Signing alone is not enough; the signatures must also be valid.
        public_key
            .verify(message, &clone_signature)
            .expect("The cloned expanded key should produce a valid signature");
        public_key
            .verify(message, &account_signature)
            .expect("The account with an expanded key should produce a valid signature");

        Ok(())
    }

    #[test]
    fn invalid_session_creation_does_not_remove_otk() -> Result<()> {
        let mut alice = Account::new();
        let malory = Account::new();
        alice.generate_one_time_keys(1);

        let mut session = malory.create_outbound_session(
            SessionConfig::default(),
            alice.curve25519_key(),
            *alice.one_time_keys().values().next().expect("Should have one-time key"),
        );

        let message = session.encrypt("Test");

        if let OlmMessage::PreKey(m) = message {
            let mut message = m.to_bytes();
            let message_len = message.len();

            // We mangle the MAC so decryption fails but creating a Session
            // succeeds.
            message[message_len - Mac::TRUNCATED_LEN..message_len]
                .copy_from_slice(&[0u8; Mac::TRUNCATED_LEN]);

            let message = PreKeyMessage::try_from(message)?;

            match alice.create_inbound_session(malory.curve25519_key(), &message) {
                Err(SessionCreationError::Decryption(_)) => {}
                e => bail!("Expected a decryption error, got {:?}", e),
            }
            assert!(
                !alice.one_time_keys.private_keys.is_empty(),
                "The one-time key was removed when it shouldn't"
            );

            Ok(())
        } else {
            bail!("Invalid message type");
        }
    }

    #[test]
    #[cfg(feature = "libolm-compat")]
    fn fuzz_corpus_unpickling() {
        crate::run_corpus("olm-account-unpickling", |data| {
            let _ = Account::from_decrypted_libolm_pickle(data);
        });
    }

    #[test]
    #[cfg(feature = "libolm-compat")]
    fn libolm_pickle_cycle() -> Result<()> {
        let message = "It's a secret to everybody";

        let olm = OlmAccount::new();
        olm.generate_one_time_keys(10);
        olm.generate_fallback_key();

        let olm_signature = olm.sign(message);

        let key = b"DEFAULT_PICKLE_KEY";
        let pickle = olm.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() });

        let account = Account::from_libolm_pickle(&pickle, key)?;
        let vodozemac_pickle = account.to_libolm_pickle(key)?;
        let _ = Account::from_libolm_pickle(&vodozemac_pickle, key)?;

        let vodozemac_signature = account.sign(message.as_bytes());
        let olm_signature = crate::types::Ed25519Signature::from_base64(&olm_signature)
            .expect("We should be able to parse a signature produced by libolm");
        account
            .identity_keys()
            .ed25519
            .verify(message.as_bytes(), &olm_signature)
            .expect("We should be able to verify the libolm signature with our vodozemac Account");

        let unpickled = OlmAccount::unpickle(
            vodozemac_pickle,
            olm_rs::PicklingMode::Encrypted { key: key.to_vec() },
        )
        .expect("libolm should be able to unpickle the pickle vodozemac produced");

        let utility = olm_rs::utility::OlmUtility::new();
        utility
            .ed25519_verify(
                unpickled.parsed_identity_keys().ed25519(),
                message,
                vodozemac_signature.to_base64(),
            )
            .expect("We should be able to verify the signature vodozemac created");
        utility
            .ed25519_verify(
                unpickled.parsed_identity_keys().ed25519(),
                message,
                olm_signature.to_base64(),
            )
            .expect("We should be able to verify the original signature from libolm");

        assert_eq!(olm.parsed_identity_keys(), unpickled.parsed_identity_keys());

        Ok(())
    }

    #[test]
    fn decrypt_with_dehydrated_device() {
        let mut alice = Account::new();
        let bob = Account::new();
        let carol = Account::new();

        alice.generate_one_time_keys(alice.max_number_of_one_time_keys());
        alice.generate_fallback_key();

        let alice_dehydrated_result =
            alice.to_dehydrated_device(&PICKLE_KEY).expect("Should be able to dehydrate device");

        // encrypt using a one-time key
        let mut bob_session = bob.create_outbound_session(
            SessionConfig::version_1(),
            alice.curve25519_key(),
            *alice
                .one_time_keys()
                .iter()
                .next()
                .expect("Failed getting alice's OTK, which should never happen here.")
                .1,
        );

        // encrypt using a fallback key
        let mut carol_session = carol.create_outbound_session(
            SessionConfig::version_1(),
            alice.curve25519_key(),
            *alice
                .fallback_key()
                .iter()
                .next()
                .expect("Failed getting alice's fallback key, which should never happen here.")
                .1,
        );

        let message = "It's a secret to everybody";
        let bob_olm_message = bob_session.encrypt(message);
        let carol_olm_message = carol_session.encrypt(message);

        let mut alice_rehydrated = Account::from_dehydrated_device(
            &alice_dehydrated_result.ciphertext,
            &alice_dehydrated_result.nonce,
            &PICKLE_KEY,
        )
        .expect("Should be able to rehydrate device");

        // make sure we can decrypt both messages
        assert_matches!(bob_olm_message, OlmMessage::PreKey(prekey_message));
        let InboundCreationResult { session: alice_session, plaintext } = alice_rehydrated
            .create_inbound_session(bob.curve25519_key(), &prekey_message)
            .expect("Alice should be able to create an inbound session from Bob's pre-key message");
        assert_eq!(alice_session.session_id(), bob_session.session_id());
        assert_eq!(message.as_bytes(), plaintext);

        assert_matches!(carol_olm_message, OlmMessage::PreKey(prekey_message));
        let InboundCreationResult { session: alice_session, plaintext } = alice_rehydrated
            .create_inbound_session(carol.curve25519_key(), &prekey_message)
            .expect(
                "Alice should be able to create an inbound session from Carol's pre-key message",
            );

        assert_eq!(alice_session.session_id(), carol_session.session_id());
        assert_eq!(message.as_bytes(), plaintext);
    }

    #[test]
    fn fails_to_rehydrate_with_wrong_key() {
        let mut alice = Account::new();

        alice.generate_one_time_keys(alice.max_number_of_one_time_keys());
        alice.generate_fallback_key();

        let alice_dehydrated_result =
            alice.to_dehydrated_device(&PICKLE_KEY).expect("Should be able to dehydrate device");

        assert!(
            Account::from_dehydrated_device(
                &alice_dehydrated_result.ciphertext,
                &alice_dehydrated_result.nonce,
                &[1; 32],
            )
            .is_err()
        );

        assert!(
            Account::from_dehydrated_device(
                &alice_dehydrated_result.ciphertext,
                "WrongNonce123456",
                &PICKLE_KEY,
            )
            .is_err()
        );
    }

    /// Two optional fallback keys back to back, to check that the presence
    /// flag framing keeps consecutive values apart.
    #[derive(Encode, Decode)]
    struct OptFallbackPickleTest {
        fallback1: dehydrated_device::OptFallbackKey,
        fallback2: dehydrated_device::OptFallbackKey,
    }

    #[test]
    fn encodes_optional_fallback_key() {
        use std::io::Cursor;

        let data_to_pickle = OptFallbackPickleTest {
            fallback1: dehydrated_device::OptFallbackKey {
                fallback_key: Some(dehydrated_device::OneTimeKey {
                    private_key: Box::new([1; 32]),
                }),
            },
            fallback2: dehydrated_device::OptFallbackKey { fallback_key: None },
        };

        let buffer = Vec::<u8>::new();
        let mut cursor = Cursor::new(buffer);
        let pickle_length = data_to_pickle.encode(&mut cursor).expect("Can pickle data");
        let pickle = cursor.into_inner();
        assert_eq!(pickle.len(), pickle_length);

        let mut cursor = Cursor::new(&pickle);
        let unpickled_data = OptFallbackPickleTest::decode(&mut cursor).expect("Can unpickle");

        assert!(unpickled_data.fallback1.fallback_key.is_some());
        assert!(unpickled_data.fallback2.fallback_key.is_none());
    }

    #[test]
    fn decrypted_dehydration_cycle() {
        use dehydrated_device::Pickle;

        let alice = Account::new();

        let mut encoded = Vec::<u8>::new();
        let pickle = Pickle::try_from(&alice)
            .expect("We should be able to create a dehydrated device from the account");
        let size = pickle.encode(&mut encoded).expect("Should dehydrate");
        assert_eq!(size, encoded.len());

        let account =
            Account::from_decrypted_dehydrated_device(&encoded).expect("Should rehydrate account");

        assert_eq!(alice.identity_keys(), account.identity_keys());
    }
}