Skip to main content

vodozemac/olm/messages/
message.rs

// Copyright 2021 Damir Jelić
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Debug;
16
17use prost::Message as ProstMessage;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21    Curve25519PublicKey, DecodeError,
22    cipher::{Mac, MessageMac},
23    utilities::{VarInt, base64_decode, base64_encode, extract_mac},
24};
25
/// Olm message version whose MAC is truncated to `Mac::TRUNCATED_LEN` bytes.
pub(crate) const MAC_TRUNCATED_VERSION: u8 = 3;
/// Olm message version that carries the full, `Mac::LENGTH`-byte MAC.
pub(crate) const VERSION: u8 = 4;
28
29/// An encrypted Olm message.
30///
31/// Contains metadata that is required to find the correct ratchet state of a
32/// [`Session`] necessary to decrypt the message.
33///
34/// [`Session`]: crate::olm::Session
35#[derive(Clone, PartialEq, Eq)]
36pub struct Message {
37    pub(crate) version: u8,
38    pub(crate) ratchet_key: Curve25519PublicKey,
39    pub(crate) chain_index: u64,
40    pub(crate) ciphertext: Vec<u8>,
41    pub(crate) mac: MessageMac,
42}
43
44impl Message {
45    /// The public part of the ratchet key, that was used when the message was
46    /// encrypted.
47    pub const fn ratchet_key(&self) -> Curve25519PublicKey {
48        self.ratchet_key
49    }
50
51    /// The index of the chain that was used when the message was encrypted.
52    pub const fn chain_index(&self) -> u64 {
53        self.chain_index
54    }
55
56    /// The actual ciphertext of the message.
57    pub fn ciphertext(&self) -> &[u8] {
58        &self.ciphertext
59    }
60
61    /// The version of the Olm message.
62    pub const fn version(&self) -> u8 {
63        self.version
64    }
65
66    /// Try to decode the given byte slice as a Olm [`Message`].
67    ///
68    /// The expected format of the byte array is described in the
69    /// [`Message::to_bytes()`] method.
70    pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
71        Self::try_from(bytes)
72    }
73
74    /// Encode the `Message` as an array of bytes.
75    ///
76    /// Olm `Message`s consist of a one-byte version, followed by a variable
77    /// length payload and a fixed length message authentication code.
78    ///
79    /// ```text
80    /// +--------------+------------------------------------+-----------+
81    /// | Version Byte | Payload Bytes                      | MAC Bytes |
82    /// +--------------+------------------------------------+-----------+
83    /// ```
84    ///
85    /// The payload uses a format based on the Protocol Buffers encoding. It
86    /// consists of the following key-value pairs:
87    ///
88    /// **Name**   |**Tag**|**Type**|               **Meaning**
89    /// :---------:|:-----:|:------:|:-----------------------------------------:
90    /// Ratchet-Key|  0x0A | String |The public part of the ratchet key
91    /// Chain-Index|  0x10 | Integer|The chain index, of the message
92    /// Cipher-Text|  0x22 | String |The cipher-text of the message
93    pub fn to_bytes(&self) -> Vec<u8> {
94        let mut message = self.encode();
95        message.extend(self.mac.as_bytes());
96
97        message
98    }
99
100    /// Try to decode the given string as a Olm [`Message`].
101    ///
102    /// The string needs to be a base64 encoded byte array that follows the
103    /// format described in the [`Message::to_bytes()`] method.
104    pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
105        Self::try_from(message)
106    }
107
108    /// Encode the [`Message`] as a string.
109    ///
110    /// This method first calls [`Message::to_bytes()`] and then encodes the
111    /// resulting byte array as a string using base64 encoding.
112    pub fn to_base64(&self) -> String {
113        base64_encode(self.to_bytes())
114    }
115
116    #[cfg(feature = "experimental-session-config")]
117    pub(crate) fn new(
118        ratchet_key: Curve25519PublicKey,
119        chain_index: u64,
120        ciphertext: Vec<u8>,
121    ) -> Self {
122        Self {
123            version: VERSION,
124            ratchet_key,
125            chain_index,
126            ciphertext,
127            mac: Mac([0u8; Mac::LENGTH]).into(),
128        }
129    }
130
131    pub(crate) fn new_truncated_mac(
132        ratchet_key: Curve25519PublicKey,
133        chain_index: u64,
134        ciphertext: Vec<u8>,
135    ) -> Self {
136        Self {
137            version: MAC_TRUNCATED_VERSION,
138            ratchet_key,
139            chain_index,
140            ciphertext,
141            mac: [0u8; Mac::TRUNCATED_LEN].into(),
142        }
143    }
144
145    fn encode(&self) -> Vec<u8> {
146        ProtoBufMessage {
147            ratchet_key: self.ratchet_key.to_bytes().to_vec(),
148            chain_index: self.chain_index,
149            ciphertext: self.ciphertext.clone(),
150        }
151        .encode_manual(self.version)
152    }
153
154    pub(crate) fn to_mac_bytes(&self) -> Vec<u8> {
155        self.encode()
156    }
157
158    pub(crate) fn set_mac(&mut self, mac: Mac) {
159        match self.mac {
160            MessageMac::Truncated(_) => self.mac = mac.truncate().into(),
161            MessageMac::Full(_) => self.mac = mac.into(),
162        }
163    }
164}
165
166impl Serialize for Message {
167    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168    where
169        S: serde::Serializer,
170    {
171        let message = self.to_base64();
172        serializer.serialize_str(&message)
173    }
174}
175
176impl<'de> Deserialize<'de> for Message {
177    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
178        let ciphertext = String::deserialize(d)?;
179        Message::from_base64(&ciphertext).map_err(serde::de::Error::custom)
180    }
181}
182
183impl TryFrom<&str> for Message {
184    type Error = DecodeError;
185
186    fn try_from(value: &str) -> Result<Self, Self::Error> {
187        let decoded = base64_decode(value)?;
188
189        Self::try_from(decoded)
190    }
191}
192
193impl TryFrom<Vec<u8>> for Message {
194    type Error = DecodeError;
195
196    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
197        Self::try_from(value.as_slice())
198    }
199}
200
201impl TryFrom<&[u8]> for Message {
202    type Error = DecodeError;
203
204    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
205        let version = *value.first().ok_or(DecodeError::MissingVersion)?;
206
207        let mac_length = match version {
208            VERSION => Mac::LENGTH,
209            MAC_TRUNCATED_VERSION => Mac::TRUNCATED_LEN,
210            _ => return Err(DecodeError::InvalidVersion(VERSION, version)),
211        };
212
213        if value.len() < mac_length + 2 {
214            Err(DecodeError::MessageTooShort(value.len()))
215        } else {
216            let inner = ProtoBufMessage::decode(
217                value
218                    .get(1..value.len() - mac_length)
219                    .ok_or_else(|| DecodeError::MessageTooShort(value.len()))?,
220            )?;
221
222            let mac_slice = &value[value.len() - mac_length..];
223
224            if mac_slice.len() != mac_length {
225                Err(DecodeError::InvalidMacLength(mac_length, mac_slice.len()))
226            } else {
227                let mac = extract_mac(mac_slice, version == MAC_TRUNCATED_VERSION);
228
229                let chain_index = inner.chain_index;
230                let ciphertext = inner.ciphertext;
231                let ratchet_key = Curve25519PublicKey::from_slice(&inner.ratchet_key)?;
232
233                let message = Message { version, ratchet_key, chain_index, ciphertext, mac };
234
235                Ok(message)
236            }
237        }
238    }
239}
240
241impl Debug for Message {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        let Self { version, ratchet_key, chain_index, ciphertext: _, mac: _ } = self;
244
245        f.debug_struct("Message")
246            .field("version", version)
247            .field("ratchet_key", ratchet_key)
248            .field("chain_index", chain_index)
249            .finish_non_exhaustive()
250    }
251}
252
253#[derive(ProstMessage, PartialEq, Eq)]
254struct ProtoBufMessage {
255    #[prost(bytes, tag = "1")]
256    ratchet_key: Vec<u8>,
257    #[prost(uint64, tag = "2")]
258    chain_index: u64,
259    #[prost(bytes, tag = "4")]
260    ciphertext: Vec<u8>,
261}
262
263impl ProtoBufMessage {
264    const RATCHET_TAG: &'static [u8; 1] = b"\x0A";
265    const INDEX_TAG: &'static [u8; 1] = b"\x10";
266    const CIPHER_TAG: &'static [u8; 1] = b"\x22";
267
268    fn encode_manual(&self, version: u8) -> Vec<u8> {
269        let index = self.chain_index.to_var_int();
270        let ratchet_len = self.ratchet_key.len().to_var_int();
271        let ciphertext_len = self.ciphertext.len().to_var_int();
272
273        [
274            [version].as_ref(),
275            Self::RATCHET_TAG.as_ref(),
276            &ratchet_len,
277            &self.ratchet_key,
278            Self::INDEX_TAG.as_ref(),
279            &index,
280            Self::CIPHER_TAG.as_ref(),
281            &ciphertext_len,
282            &self.ciphertext,
283        ]
284        .concat()
285    }
286}
287
#[cfg(test)]
mod test {
    use assert_matches2::{assert_let, assert_matches};

    use super::Message;
    use crate::{
        Curve25519PublicKey, DecodeError,
        olm::messages::message::{MAC_TRUNCATED_VERSION, VERSION},
    };

    #[test]
    fn encode() {
        let message = b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertext";
        let message_mac =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";

        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");
        let ciphertext = b"ciphertext";
        let chain_index = 2;

        let mut encoded = Message::new_truncated_mac(ratchet_key, chain_index, ciphertext.to_vec());
        encoded.mac = (*b"MACHEREE").into();

        assert_eq!(encoded.to_mac_bytes(), message.as_ref());
        assert_eq!(encoded.to_bytes(), message_mac.as_ref());
        assert_eq!(encoded.ciphertext(), ciphertext.to_vec());
        assert_eq!(encoded.chain_index(), chain_index);
        assert_eq!(encoded.version(), MAC_TRUNCATED_VERSION);
    }

    #[test]
    fn from_bytes_normal_message() {
        let bytes = b"\x04\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMAC_01234567890_01234567890_HERE";
        let result = Message::try_from(bytes.as_slice());
        assert_let!(Ok(message) = result);

        let Message { version, ratchet_key, chain_index, ciphertext, mac } = message;

        assert_eq!(version, VERSION);
        assert_eq!(ratchet_key.as_bytes(), b"ratchetkeyhereprettyplease123456");
        assert_eq!(chain_index, 2);
        assert_eq!(ciphertext, b"ciphertext");
        assert_eq!(mac.as_bytes(), b"MAC_01234567890_01234567890_HERE");
    }

    #[test]
    fn from_bytes_empty() {
        assert_matches!(Message::from_bytes(&[]), Err(DecodeError::MissingVersion));
    }

    #[test]
    fn from_bytes_unknown_version() {
        // The version is checked before the length, so a short input still
        // reports the unknown version.
        let bytes = [5u8, 0, 0, 0];
        assert_matches!(Message::from_bytes(&bytes), Err(DecodeError::InvalidVersion(VERSION, 5)));
    }

    #[test]
    fn from_bytes_too_short() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::MessageTooShort(9)));
    }

    #[test]
    fn from_bytes_decoding_error() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::ProtoBufError(_)));
    }

    #[test]
    fn from_bytes_invalid_tag() {
        let bytes = [
            VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let result = Message::try_from(bytes.to_vec());
        assert_matches!(result, Err(DecodeError::ProtoBufError(_)));
    }

    #[test]
    fn decode_encode_round_trip() {
        let truncated: &[u8] =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";
        let full: &[u8] = b"\x04\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMAC_01234567890_01234567890_HERE";

        for bytes in [truncated, full] {
            assert_let!(Ok(message) = Message::from_bytes(bytes));
            assert_eq!(message.to_bytes(), bytes);
        }
    }
}