1mod message;
16mod pre_key;
17
18pub use message::Message;
19pub use pre_key::PreKeyMessage;
20use serde::{Deserialize, Serialize};
21
22use crate::{DecodeError, base64_decode, base64_encode};
23
/// An Olm message of either kind: a normal [`Message`] or a [`PreKeyMessage`].
///
/// The `Serialize` / `Deserialize` impls in this module map it to and from a
/// JSON object `{"type": <usize>, "body": "<base64 of the encoded message>"}`,
/// where the numeric type follows [`MessageType`] (`0` = pre-key, `1` = normal).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OlmMessage {
    /// A normal Olm message (message type `1`).
    Normal(Message),
    /// A pre-key Olm message (message type `0`). It wraps an inner message
    /// (its `message` field), whose ciphertext is what
    /// [`OlmMessage::message`] returns for this variant.
    PreKey(PreKeyMessage),
}
44
45impl From<Message> for OlmMessage {
46 fn from(m: Message) -> Self {
47 Self::Normal(m)
48 }
49}
50
51impl From<PreKeyMessage> for OlmMessage {
52 fn from(m: PreKeyMessage) -> Self {
53 Self::PreKey(m)
54 }
55}
56
/// Private wire representation used by the `Serialize` / `Deserialize` impls
/// of `OlmMessage`: a JSON object of the form `{"type": ..., "body": ...}`.
#[derive(Serialize, Deserialize)]
struct MessageSerdeHelper {
    /// Numeric message type (`0` = pre-key, `1` = normal; see `MessageType`).
    #[serde(rename = "type")]
    message_type: usize,
    /// The encoded message bytes as a base64 string; the `OlmMessage` serde
    /// impls convert with `base64_encode` / `base64_decode`.
    #[serde(rename = "body")]
    ciphertext: String,
}
64
65impl Serialize for OlmMessage {
66 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
67 where
68 S: serde::Serializer,
69 {
70 let (message_type, ciphertext) = self.to_parts();
71 let message = MessageSerdeHelper { message_type, ciphertext: base64_encode(ciphertext) };
72
73 message.serialize(serializer)
74 }
75}
76
77impl<'de> Deserialize<'de> for OlmMessage {
78 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
79 let value = MessageSerdeHelper::deserialize(d)?;
80 let ciphertext_bytes = base64_decode(value.ciphertext).map_err(serde::de::Error::custom)?;
81
82 OlmMessage::from_parts(value.message_type, ciphertext_bytes.as_slice())
83 .map_err(serde::de::Error::custom)
84 }
85}
86
87impl OlmMessage {
88 pub fn from_parts(message_type: usize, ciphertext: &[u8]) -> Result<Self, DecodeError> {
90 match message_type {
91 0 => Ok(Self::PreKey(PreKeyMessage::from_bytes(ciphertext)?)),
92 1 => Ok(Self::Normal(Message::from_bytes(ciphertext)?)),
93 m => Err(DecodeError::MessageType(m)),
94 }
95 }
96
97 pub fn message(&self) -> &[u8] {
99 match self {
100 OlmMessage::Normal(m) => &m.ciphertext,
101 OlmMessage::PreKey(m) => &m.message.ciphertext,
102 }
103 }
104
105 pub const fn message_type(&self) -> MessageType {
107 match self {
108 OlmMessage::Normal(_) => MessageType::Normal,
109 OlmMessage::PreKey(_) => MessageType::PreKey,
110 }
111 }
112
113 pub fn to_parts(&self) -> (usize, Vec<u8>) {
116 let message_type = self.message_type();
117
118 match self {
119 OlmMessage::Normal(m) => (message_type.into(), m.to_bytes()),
120 OlmMessage::PreKey(m) => (message_type.into(), m.to_bytes()),
121 }
122 }
123}
124
/// The numeric type tag of an Olm message.
///
/// The explicit discriminants are the wire values: converting to `usize`
/// yields them, and `TryFrom<usize>` accepts exactly them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Type `0`: a pre-key message.
    PreKey = 0,
    /// Type `1`: a normal message.
    Normal = 1,
}

impl TryFrom<usize> for MessageType {
    /// An unknown type value; no further detail is reported.
    type Error = ();

    fn try_from(value: usize) -> Result<Self, Self::Error> {
        // Compare against the discriminants themselves so the mapping is
        // spelled out only once, in the enum definition.
        if value == Self::PreKey as usize {
            Ok(Self::PreKey)
        } else if value == Self::Normal as usize {
            Ok(Self::Normal)
        } else {
            Err(())
        }
    }
}

impl From<MessageType> for usize {
    /// Returns the variant's discriminant (`PreKey` → `0`, `Normal` → `1`).
    fn from(message_type: MessageType) -> Self {
        message_type as Self
    }
}
151
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_matches2::assert_matches;
    use olm_rs::session::OlmMessage as LibolmMessage;
    use serde_json::json;

    use super::*;
    use crate::run_corpus;

    // Base64 of a complete encoded pre-key message (message type 0).
    const PRE_KEY_MESSAGE: &str = "AwoghAEuxPZ+w7M3pgUae4tDNiggUpOsQ/zci457VAti\
                                   AEYSIO3xOKRDBWKicIfxjSmYCYZ9DD4RMLjvvclbMlE5\
                                   yIEWGiApLrCr853CKlPpW4Bi7S8ykRcejJ0lq7AfYLXK\
                                   CjKdHSJPAwoghw3+P+cajhWj9Qzp5g87h+tbpiuh5wEa\
                                   eUppqmWqug4QASIgRhZ2cgZcIWQbIa23R7U4y1Mo1R/t\
                                   LCaMU+xjzRV5smGsCrJ6AHwktg";

    // Expected `OlmMessage::message()` output for PRE_KEY_MESSAGE: the inner
    // message's ciphertext only, not the whole encoding.
    const PRE_KEY_MESSAGE_CIPHERTEXT: [u8; 32] = [
        70, 22, 118, 114, 6, 92, 33, 100, 27, 33, 173, 183, 71, 181, 56, 203, 83, 40, 213, 31, 237,
        44, 38, 140, 83, 236, 99, 205, 21, 121, 178, 97,
    ];

    // Base64 of a complete encoded normal message (message type 1).
    const MESSAGE: &str = "AwogI7JhE/UsMZqXKb3xV6kUZWoJc6jTm2+AIgWYmaETIR0QASIQ\
                           +X2zb7kEX/3JvoLspcNBcLWOFXYpV0nS";

    // Expected `OlmMessage::message()` output for MESSAGE.
    const MESSAGE_CIPHERTEXT: [u8; 16] =
        [249, 125, 179, 111, 185, 4, 95, 253, 201, 190, 130, 236, 165, 195, 65, 112];

    // Test-only bridge to libolm (via `olm_rs`): re-encode the message as
    // base64 under its numeric type tag (0 = pre-key, 1 = normal).
    impl From<OlmMessage> for LibolmMessage {
        fn from(value: OlmMessage) -> LibolmMessage {
            match value {
                OlmMessage::Normal(m) => LibolmMessage::from_type_and_ciphertext(1, m.to_base64())
                    .expect("Can't create a valid libolm message"),
                OlmMessage::PreKey(m) => LibolmMessage::from_type_and_ciphertext(0, m.to_base64())
                    .expect("Can't create a valid libolm pre-key message"),
            }
        }
    }

    // Reverse bridge: split the libolm message into (type, base64 body),
    // decode the base64 and parse it with `OlmMessage::from_parts`. Panics on
    // any failure, which is acceptable in test code.
    impl From<LibolmMessage> for OlmMessage {
        fn from(other: LibolmMessage) -> Self {
            let (message_type, ciphertext) = other.to_tuple();
            let ciphertext_bytes = base64_decode(ciphertext).expect("Can't decode base64");

            Self::from_parts(message_type.into(), ciphertext_bytes.as_slice())
                .expect("Can't decode a libolm message")
        }
    }

    // Pins the numeric-to-`MessageType` mapping, including rejection of an
    // unknown value.
    #[test]
    fn message_type_from_usize() {
        assert_eq!(
            MessageType::try_from(0),
            Ok(MessageType::PreKey),
            "0 should denote a pre-key Olm message"
        );
        assert_eq!(
            MessageType::try_from(1),
            Ok(MessageType::Normal),
            "1 should denote a normal Olm message"
        );
        assert!(
            MessageType::try_from(2).is_err(),
            "2 should be recognized as an unknown Olm message type"
        );
    }

    // Checks that the `{"type", "body"}` JSON form deserializes to the right
    // variant and re-serializes to the identical JSON value, for both types.
    #[test]
    fn from_json() -> Result<()> {
        let value = json!({
            "type": 0u8,
            "body": PRE_KEY_MESSAGE,
        });

        let message: OlmMessage = serde_json::from_value(value.clone())?;
        assert_matches!(message.clone(), OlmMessage::PreKey(_));

        let serialized = serde_json::to_value(message)?;
        assert_eq!(value, serialized, "The serialization cycle isn't a noop");

        let value = json!({
            "type": 1u8,
            "body": MESSAGE,
        });

        let message: OlmMessage = serde_json::from_value(value.clone())?;
        assert_matches!(message.clone(), OlmMessage::Normal(_));

        let serialized = serde_json::to_value(message)?;
        assert_eq!(value, serialized, "The serialization cycle isn't a noop");

        Ok(())
    }

    // Exercises `from_parts` / `to_parts` round-trips, `message_type()` and
    // `message()` for both variants, plus rejection of an unknown type (3).
    #[test]
    fn from_parts() -> Result<()> {
        let message = OlmMessage::from_parts(0, base64_decode(PRE_KEY_MESSAGE)?.as_slice())?;
        assert_matches!(message.clone(), OlmMessage::PreKey(_));
        assert_eq!(
            message.message_type(),
            MessageType::PreKey,
            "Expected message to be recognized as a pre-key Olm message."
        );
        assert_eq!(message.message(), PRE_KEY_MESSAGE_CIPHERTEXT);
        assert_eq!(
            message.to_parts(),
            (0, base64_decode(PRE_KEY_MESSAGE)?),
            "Roundtrip not identity."
        );

        let message = OlmMessage::from_parts(1, base64_decode(MESSAGE)?.as_slice())?;
        assert_matches!(message.clone(), OlmMessage::Normal(_));
        assert_eq!(
            message.message_type(),
            MessageType::Normal,
            "Expected message to be recognized as a normal Olm message."
        );
        assert_eq!(message.message(), MESSAGE_CIPHERTEXT);
        assert_eq!(message.to_parts(), (1, base64_decode(MESSAGE)?), "Roundtrip not identity.");

        OlmMessage::from_parts(3, base64_decode(PRE_KEY_MESSAGE)?.as_slice())
            .expect_err("Unknown message types can't be parsed");

        Ok(())
    }

    // Feeds inputs from the `olm-message-decoding` corpus (presumably one
    // closure call per corpus entry; see `run_corpus`) to
    // `PreKeyMessage::from_bytes`. The result is discarded, so the test only
    // fails if decoding panics.
    #[test]
    fn fuzz_corpus_decoding() {
        run_corpus("olm-message-decoding", |data| {
            let _ = PreKeyMessage::from_bytes(data);
        });
    }
}