Skip to main content

vodozemac/olm/messages/
mod.rs

// Copyright 2021 Damir Jelić
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15mod message;
16mod pre_key;
17
18pub use message::Message;
19pub use pre_key::PreKeyMessage;
20use serde::{Deserialize, Serialize};
21
22use crate::{DecodeError, base64_decode, base64_encode};
23
24/// Enum over the different Olm message types.
25///
26/// Olm uses two types of messages. The underlying transport protocol must
27/// provide a means for recipients to distinguish between them.
28///
29/// [`OlmMessage`] provides [`Serialize`] and [`Deserialize`] implementations
30/// that are compatible with [Matrix].
31///
32/// [Matrix]: https://spec.matrix.org/latest/client-server-api/#molmv1curve25519-aes-sha2
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub enum OlmMessage {
35    /// A normal message, contains only the ciphertext and metadata to decrypt
36    /// it.
37    Normal(Message),
38    /// A pre-key message, contains metadata to establish a [`Session`] as well
39    /// as a [`Message`].
40    ///
41    /// [`Session`]: crate::olm::Session
42    PreKey(PreKeyMessage),
43}
44
45impl From<Message> for OlmMessage {
46    fn from(m: Message) -> Self {
47        Self::Normal(m)
48    }
49}
50
51impl From<PreKeyMessage> for OlmMessage {
52    fn from(m: PreKeyMessage) -> Self {
53        Self::PreKey(m)
54    }
55}
56
57#[derive(Serialize, Deserialize)]
58struct MessageSerdeHelper {
59    #[serde(rename = "type")]
60    message_type: usize,
61    #[serde(rename = "body")]
62    ciphertext: String,
63}
64
65impl Serialize for OlmMessage {
66    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
67    where
68        S: serde::Serializer,
69    {
70        let (message_type, ciphertext) = self.to_parts();
71        let message = MessageSerdeHelper { message_type, ciphertext: base64_encode(ciphertext) };
72
73        message.serialize(serializer)
74    }
75}
76
77impl<'de> Deserialize<'de> for OlmMessage {
78    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
79        let value = MessageSerdeHelper::deserialize(d)?;
80        let ciphertext_bytes = base64_decode(value.ciphertext).map_err(serde::de::Error::custom)?;
81
82        OlmMessage::from_parts(value.message_type, ciphertext_bytes.as_slice())
83            .map_err(serde::de::Error::custom)
84    }
85}
86
87impl OlmMessage {
88    /// Create an [`OlmMessage`] from a message type and a ciphertext.
89    pub fn from_parts(message_type: usize, ciphertext: &[u8]) -> Result<Self, DecodeError> {
90        match message_type {
91            0 => Ok(Self::PreKey(PreKeyMessage::from_bytes(ciphertext)?)),
92            1 => Ok(Self::Normal(Message::from_bytes(ciphertext)?)),
93            m => Err(DecodeError::MessageType(m)),
94        }
95    }
96
97    /// Get the message's ciphertext as a byte array.
98    pub fn message(&self) -> &[u8] {
99        match self {
100            OlmMessage::Normal(m) => &m.ciphertext,
101            OlmMessage::PreKey(m) => &m.message.ciphertext,
102        }
103    }
104
105    /// Get the type of the message.
106    pub const fn message_type(&self) -> MessageType {
107        match self {
108            OlmMessage::Normal(_) => MessageType::Normal,
109            OlmMessage::PreKey(_) => MessageType::PreKey,
110        }
111    }
112
113    /// Convert the [`OlmMessage`] into a message type, and ciphertext bytes
114    /// tuple.
115    pub fn to_parts(&self) -> (usize, Vec<u8>) {
116        let message_type = self.message_type();
117
118        match self {
119            OlmMessage::Normal(m) => (message_type.into(), m.to_bytes()),
120            OlmMessage::PreKey(m) => (message_type.into(), m.to_bytes()),
121        }
122    }
123}
124
/// The message types supported by Olm.
///
/// The discriminants are the numeric type values used by
/// [`OlmMessage::from_parts`] and [`OlmMessage::to_parts`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum MessageType {
    /// A pre-key message, numbered `0`.
    PreKey = 0,
    /// A normal message, numbered `1`.
    Normal = 1,
}

impl TryFrom<usize> for MessageType {
    type Error = ();

    /// Maps a numeric message type back to its [`MessageType`], failing
    /// with `()` for any number that does not belong to a known variant.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        // Look the number up among all variants, using the `From` impl below
        // as the single definition of each variant's number.
        [MessageType::PreKey, MessageType::Normal]
            .iter()
            .copied()
            .find(|candidate| usize::from(*candidate) == value)
            .ok_or(())
    }
}

impl From<MessageType> for usize {
    /// Returns the numeric discriminant of the message type.
    fn from(message_type: MessageType) -> Self {
        message_type as usize
    }
}
151
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_matches2::assert_matches;
    use olm_rs::session::OlmMessage as LibolmMessage;
    use serde_json::json;

    use super::*;
    use crate::run_corpus;

    /// A base64-encoded pre-key message.
    const PRE_KEY_MESSAGE: &str = "AwoghAEuxPZ+w7M3pgUae4tDNiggUpOsQ/zci457VAti\
                                   AEYSIO3xOKRDBWKicIfxjSmYCYZ9DD4RMLjvvclbMlE5\
                                   yIEWGiApLrCr853CKlPpW4Bi7S8ykRcejJ0lq7AfYLXK\
                                   CjKdHSJPAwoghw3+P+cajhWj9Qzp5g87h+tbpiuh5wEa\
                                   eUppqmWqug4QASIgRhZ2cgZcIWQbIa23R7U4y1Mo1R/t\
                                   LCaMU+xjzRV5smGsCrJ6AHwktg";

    /// The ciphertext carried by `PRE_KEY_MESSAGE`.
    const PRE_KEY_MESSAGE_CIPHERTEXT: [u8; 32] = [
        70, 22, 118, 114, 6, 92, 33, 100, 27, 33, 173, 183, 71, 181, 56, 203, 83, 40, 213, 31, 237,
        44, 38, 140, 83, 236, 99, 205, 21, 121, 178, 97,
    ];

    /// A base64-encoded normal message.
    const MESSAGE: &str = "AwogI7JhE/UsMZqXKb3xV6kUZWoJc6jTm2+AIgWYmaETIR0QASIQ\
                           +X2zb7kEX/3JvoLspcNBcLWOFXYpV0nS";

    /// The ciphertext carried by `MESSAGE`.
    const MESSAGE_CIPHERTEXT: [u8; 16] =
        [249, 125, 179, 111, 185, 4, 95, 253, 201, 190, 130, 236, 165, 195, 65, 112];

    impl From<OlmMessage> for LibolmMessage {
        fn from(value: OlmMessage) -> LibolmMessage {
            match value {
                OlmMessage::Normal(m) => LibolmMessage::from_type_and_ciphertext(1, m.to_base64())
                    .expect("Can't create a valid libolm message"),
                OlmMessage::PreKey(m) => LibolmMessage::from_type_and_ciphertext(0, m.to_base64())
                    .expect("Can't create a valid libolm pre-key message"),
            }
        }
    }

    impl From<LibolmMessage> for OlmMessage {
        fn from(other: LibolmMessage) -> Self {
            let (message_type, ciphertext) = other.to_tuple();
            let ciphertext_bytes = base64_decode(ciphertext).expect("Can't decode base64");

            Self::from_parts(message_type.into(), ciphertext_bytes.as_slice())
                .expect("Can't decode a libolm message")
        }
    }

    #[test]
    fn message_type_from_usize() {
        assert_eq!(
            MessageType::try_from(0),
            Ok(MessageType::PreKey),
            "0 should denote a pre-key Olm message"
        );
        assert_eq!(
            MessageType::try_from(1),
            Ok(MessageType::Normal),
            "1 should denote a normal Olm message"
        );
        assert!(
            MessageType::try_from(2).is_err(),
            "2 should be recognized as an unknown Olm message type"
        );
    }

    #[test]
    fn from_json() -> Result<()> {
        let value = json!({
            "type": 0u8,
            "body": PRE_KEY_MESSAGE,
        });

        let message: OlmMessage = serde_json::from_value(value.clone())?;
        assert_matches!(message.clone(), OlmMessage::PreKey(_));

        let serialized = serde_json::to_value(message)?;
        assert_eq!(value, serialized, "The serialization cycle isn't a noop");

        let value = json!({
            "type": 1u8,
            "body": MESSAGE,
        });

        let message: OlmMessage = serde_json::from_value(value.clone())?;
        assert_matches!(message.clone(), OlmMessage::Normal(_));

        let serialized = serde_json::to_value(message)?;
        assert_eq!(value, serialized, "The serialization cycle isn't a noop");

        Ok(())
    }

    /// Deserialization must surface errors for an unknown `type` and for a
    /// `body` that is not valid base64, rather than producing a message.
    #[test]
    fn from_json_rejects_invalid_input() {
        let unknown_type = json!({
            "type": 2u8,
            "body": MESSAGE,
        });
        serde_json::from_value::<OlmMessage>(unknown_type)
            .expect_err("Unknown message types can't be deserialized");

        let invalid_base64 = json!({
            "type": 1u8,
            "body": "this is not base64!",
        });
        serde_json::from_value::<OlmMessage>(invalid_base64)
            .expect_err("A body that isn't valid base64 can't be deserialized");
    }

    #[test]
    fn from_parts() -> Result<()> {
        let message = OlmMessage::from_parts(0, base64_decode(PRE_KEY_MESSAGE)?.as_slice())?;
        assert_matches!(message.clone(), OlmMessage::PreKey(_));
        assert_eq!(
            message.message_type(),
            MessageType::PreKey,
            "Expected message to be recognized as a pre-key Olm message."
        );
        assert_eq!(message.message(), PRE_KEY_MESSAGE_CIPHERTEXT);
        assert_eq!(
            message.to_parts(),
            (0, base64_decode(PRE_KEY_MESSAGE)?),
            "Roundtrip not identity."
        );

        let message = OlmMessage::from_parts(1, base64_decode(MESSAGE)?.as_slice())?;
        assert_matches!(message.clone(), OlmMessage::Normal(_));
        assert_eq!(
            message.message_type(),
            MessageType::Normal,
            "Expected message to be recognized as a normal Olm message."
        );
        assert_eq!(message.message(), MESSAGE_CIPHERTEXT);
        assert_eq!(message.to_parts(), (1, base64_decode(MESSAGE)?), "Roundtrip not identity.");

        OlmMessage::from_parts(3, base64_decode(PRE_KEY_MESSAGE)?.as_slice())
            .expect_err("Unknown message types can't be parsed");

        Ok(())
    }

    #[test]
    fn fuzz_corpus_decoding() {
        run_corpus("olm-message-decoding", |data| {
            let _ = PreKeyMessage::from_bytes(data);
        });
    }
}