vodozemac/olm/messages/
message.rs

1// Copyright 2021 Damir Jelić
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16
17use prost::Message as ProstMessage;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21    Curve25519PublicKey, DecodeError,
22    cipher::{Mac, MessageMac},
23    utilities::{VarInt, base64_decode, base64_encode, extract_mac},
24};
25
/// Version byte of Olm messages that carry a truncated MAC.
const MAC_TRUNCATED_VERSION: u8 = 0x03;
/// Version byte of Olm messages that carry a full-length MAC.
const VERSION: u8 = 0x04;
28
29/// An encrypted Olm message.
30///
31/// Contains metadata that is required to find the correct ratchet state of a
32/// [`Session`] necessary to decrypt the message.
33///
34/// [`Session`]: crate::olm::Session
35#[derive(Clone, PartialEq, Eq)]
36pub struct Message {
37    pub(crate) version: u8,
38    pub(crate) ratchet_key: Curve25519PublicKey,
39    pub(crate) chain_index: u64,
40    pub(crate) ciphertext: Vec<u8>,
41    pub(crate) mac: MessageMac,
42}
43
44impl Message {
45    /// The public part of the ratchet key, that was used when the message was
46    /// encrypted.
47    pub const fn ratchet_key(&self) -> Curve25519PublicKey {
48        self.ratchet_key
49    }
50
51    /// The index of the chain that was used when the message was encrypted.
52    pub const fn chain_index(&self) -> u64 {
53        self.chain_index
54    }
55
56    /// The actual ciphertext of the message.
57    pub fn ciphertext(&self) -> &[u8] {
58        &self.ciphertext
59    }
60
61    /// The version of the Olm message.
62    pub const fn version(&self) -> u8 {
63        self.version
64    }
65
66    /// Has the MAC been truncated in this Olm message.
67    pub const fn mac_truncated(&self) -> bool {
68        self.version == MAC_TRUNCATED_VERSION
69    }
70
71    /// Try to decode the given byte slice as a Olm [`Message`].
72    ///
73    /// The expected format of the byte array is described in the
74    /// [`Message::to_bytes()`] method.
75    pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
76        Self::try_from(bytes)
77    }
78
79    /// Encode the `Message` as an array of bytes.
80    ///
81    /// Olm `Message`s consist of a one-byte version, followed by a variable
82    /// length payload and a fixed length message authentication code.
83    ///
84    /// ```text
85    /// +--------------+------------------------------------+-----------+
86    /// | Version Byte | Payload Bytes                      | MAC Bytes |
87    /// +--------------+------------------------------------+-----------+
88    /// ```
89    ///
90    /// The payload uses a format based on the Protocol Buffers encoding. It
91    /// consists of the following key-value pairs:
92    ///
93    /// **Name**   |**Tag**|**Type**|               **Meaning**
94    /// :---------:|:-----:|:------:|:-----------------------------------------:
95    /// Ratchet-Key|  0x0A | String |The public part of the ratchet key
96    /// Chain-Index|  0x10 | Integer|The chain index, of the message
97    /// Cipher-Text|  0x22 | String |The cipher-text of the message
98    pub fn to_bytes(&self) -> Vec<u8> {
99        let mut message = self.encode();
100        message.extend(self.mac.as_bytes());
101
102        message
103    }
104
105    /// Try to decode the given string as a Olm [`Message`].
106    ///
107    /// The string needs to be a base64 encoded byte array that follows the
108    /// format described in the [`Message::to_bytes()`] method.
109    pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
110        Self::try_from(message)
111    }
112
113    /// Encode the [`Message`] as a string.
114    ///
115    /// This method first calls [`Message::to_bytes()`] and then encodes the
116    /// resulting byte array as a string using base64 encoding.
117    pub fn to_base64(&self) -> String {
118        base64_encode(self.to_bytes())
119    }
120
121    pub(crate) fn new(
122        ratchet_key: Curve25519PublicKey,
123        chain_index: u64,
124        ciphertext: Vec<u8>,
125    ) -> Self {
126        Self {
127            version: VERSION,
128            ratchet_key,
129            chain_index,
130            ciphertext,
131            mac: Mac([0u8; Mac::LENGTH]).into(),
132        }
133    }
134
135    pub(crate) fn new_truncated_mac(
136        ratchet_key: Curve25519PublicKey,
137        chain_index: u64,
138        ciphertext: Vec<u8>,
139    ) -> Self {
140        Self {
141            version: MAC_TRUNCATED_VERSION,
142            ratchet_key,
143            chain_index,
144            ciphertext,
145            mac: [0u8; Mac::TRUNCATED_LEN].into(),
146        }
147    }
148
149    fn encode(&self) -> Vec<u8> {
150        ProtoBufMessage {
151            ratchet_key: self.ratchet_key.to_bytes().to_vec(),
152            chain_index: self.chain_index,
153            ciphertext: self.ciphertext.clone(),
154        }
155        .encode_manual(self.version)
156    }
157
158    pub(crate) fn to_mac_bytes(&self) -> Vec<u8> {
159        self.encode()
160    }
161
162    pub(crate) fn set_mac(&mut self, mac: Mac) {
163        match self.mac {
164            MessageMac::Truncated(_) => self.mac = mac.truncate().into(),
165            MessageMac::Full(_) => self.mac = mac.into(),
166        }
167    }
168}
169
170impl Serialize for Message {
171    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
172    where
173        S: serde::Serializer,
174    {
175        let message = self.to_base64();
176        serializer.serialize_str(&message)
177    }
178}
179
180impl<'de> Deserialize<'de> for Message {
181    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
182        let ciphertext = String::deserialize(d)?;
183        Message::from_base64(&ciphertext).map_err(serde::de::Error::custom)
184    }
185}
186
187impl TryFrom<&str> for Message {
188    type Error = DecodeError;
189
190    fn try_from(value: &str) -> Result<Self, Self::Error> {
191        let decoded = base64_decode(value)?;
192
193        Self::try_from(decoded)
194    }
195}
196
197impl TryFrom<Vec<u8>> for Message {
198    type Error = DecodeError;
199
200    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
201        Self::try_from(value.as_slice())
202    }
203}
204
205impl TryFrom<&[u8]> for Message {
206    type Error = DecodeError;
207
208    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
209        let version = *value.first().ok_or(DecodeError::MissingVersion)?;
210
211        let mac_length = match version {
212            VERSION => Mac::LENGTH,
213            MAC_TRUNCATED_VERSION => Mac::TRUNCATED_LEN,
214            _ => return Err(DecodeError::InvalidVersion(VERSION, version)),
215        };
216
217        if value.len() < mac_length + 2 {
218            Err(DecodeError::MessageTooShort(value.len()))
219        } else {
220            let inner = ProtoBufMessage::decode(
221                value
222                    .get(1..value.len() - mac_length)
223                    .ok_or_else(|| DecodeError::MessageTooShort(value.len()))?,
224            )?;
225
226            let mac_slice = &value[value.len() - mac_length..];
227
228            if mac_slice.len() != mac_length {
229                Err(DecodeError::InvalidMacLength(mac_length, mac_slice.len()))
230            } else {
231                let mac = extract_mac(mac_slice, version == MAC_TRUNCATED_VERSION);
232
233                let chain_index = inner.chain_index;
234                let ciphertext = inner.ciphertext;
235                let ratchet_key = Curve25519PublicKey::from_slice(&inner.ratchet_key)?;
236
237                let message = Message { version, ratchet_key, chain_index, ciphertext, mac };
238
239                Ok(message)
240            }
241        }
242    }
243}
244
245impl Debug for Message {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        let Self { version, ratchet_key, chain_index, ciphertext: _, mac: _ } = self;
248
249        f.debug_struct("Message")
250            .field("version", version)
251            .field("ratchet_key", ratchet_key)
252            .field("chain_index", chain_index)
253            .finish_non_exhaustive()
254    }
255}
256
257#[derive(ProstMessage, PartialEq, Eq)]
258struct ProtoBufMessage {
259    #[prost(bytes, tag = "1")]
260    ratchet_key: Vec<u8>,
261    #[prost(uint64, tag = "2")]
262    chain_index: u64,
263    #[prost(bytes, tag = "4")]
264    ciphertext: Vec<u8>,
265}
266
267impl ProtoBufMessage {
268    const RATCHET_TAG: &'static [u8; 1] = b"\x0A";
269    const INDEX_TAG: &'static [u8; 1] = b"\x10";
270    const CIPHER_TAG: &'static [u8; 1] = b"\x22";
271
272    fn encode_manual(&self, version: u8) -> Vec<u8> {
273        let index = self.chain_index.to_var_int();
274        let ratchet_len = self.ratchet_key.len().to_var_int();
275        let ciphertext_len = self.ciphertext.len().to_var_int();
276
277        [
278            [version].as_ref(),
279            Self::RATCHET_TAG.as_ref(),
280            &ratchet_len,
281            &self.ratchet_key,
282            Self::INDEX_TAG.as_ref(),
283            &index,
284            Self::CIPHER_TAG.as_ref(),
285            &ciphertext_len,
286            &self.ciphertext,
287        ]
288        .concat()
289    }
290}
291
#[cfg(test)]
mod test {
    use assert_matches2::{assert_let, assert_matches};

    use super::Message;
    use crate::{
        Curve25519PublicKey, DecodeError,
        olm::messages::message::{MAC_TRUNCATED_VERSION, VERSION},
    };

    /// A message with a truncated MAC encodes to the expected bytes, both
    /// with and without the MAC appended.
    #[test]
    fn encode() {
        let expected_without_mac =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertext";
        let expected_with_mac =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";

        let chain_index = 2;
        let ciphertext = b"ciphertext";
        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");

        let mut encoded = Message::new_truncated_mac(ratchet_key, chain_index, ciphertext.to_vec());
        encoded.mac = (*b"MACHEREE").into();

        assert_eq!(encoded.version(), MAC_TRUNCATED_VERSION);
        assert_eq!(encoded.chain_index(), chain_index);
        assert_eq!(encoded.ciphertext(), ciphertext.to_vec());
        assert_eq!(encoded.to_mac_bytes(), expected_without_mac.as_ref());
        assert_eq!(encoded.to_bytes(), expected_with_mac.as_ref());
    }

    /// A well-formed message with a full-length MAC decodes into the
    /// expected fields.
    #[test]
    fn from_bytes_normal_message() {
        let bytes = b"\x04\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMAC_01234567890_01234567890_HERE";

        assert_let!(Ok(message) = Message::try_from(bytes.as_slice()));

        assert_eq!(message.version, VERSION);
        assert_eq!(message.ratchet_key.as_bytes(), b"ratchetkeyhereprettyplease123456");
        assert_eq!(message.chain_index, 2);
        assert_eq!(message.ciphertext, b"ciphertext");
        assert_eq!(message.mac.as_bytes(), b"MAC_01234567890_01234567890_HERE");
    }

    /// Nine bytes cannot hold a version byte, a payload byte and a truncated
    /// MAC.
    #[test]
    fn from_bytes_too_short() {
        let too_short = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0];

        assert_matches!(Message::try_from(too_short), Err(DecodeError::MessageTooShort(9)));
    }

    /// Ten bytes are long enough, but the payload is not valid protobuf.
    #[test]
    fn from_bytes_decoding_error() {
        let undecodable = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0];

        assert_matches!(Message::try_from(undecodable), Err(DecodeError::ProtoBufError(_)));
    }

    /// An all-zero payload in front of a full-length MAC fails to decode as
    /// protobuf.
    #[test]
    fn from_bytes_invalid_tag() {
        let zeroed = [
            VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];

        assert_matches!(Message::try_from(zeroed.to_vec()), Err(DecodeError::ProtoBufError(_)));
    }
}