1use std::fmt::Debug;
16
17use prost::Message as ProstMessage;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21 Curve25519PublicKey, DecodeError,
22 cipher::{Mac, MessageMac},
23 utilities::{VarInt, base64_decode, base64_encode, extract_mac},
24};
25
/// Version byte of messages whose trailing MAC is truncated to `Mac::TRUNCATED_LEN` bytes.
const MAC_TRUNCATED_VERSION: u8 = 3;
/// Version byte of messages whose trailing MAC is the full `Mac::LENGTH` bytes.
const VERSION: u8 = 4;
28
29#[derive(Clone, PartialEq, Eq)]
36pub struct Message {
37 pub(crate) version: u8,
38 pub(crate) ratchet_key: Curve25519PublicKey,
39 pub(crate) chain_index: u64,
40 pub(crate) ciphertext: Vec<u8>,
41 pub(crate) mac: MessageMac,
42}
43
44impl Message {
45 pub const fn ratchet_key(&self) -> Curve25519PublicKey {
48 self.ratchet_key
49 }
50
51 pub const fn chain_index(&self) -> u64 {
53 self.chain_index
54 }
55
56 pub fn ciphertext(&self) -> &[u8] {
58 &self.ciphertext
59 }
60
61 pub const fn version(&self) -> u8 {
63 self.version
64 }
65
66 pub const fn mac_truncated(&self) -> bool {
68 self.version == MAC_TRUNCATED_VERSION
69 }
70
71 pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
76 Self::try_from(bytes)
77 }
78
79 pub fn to_bytes(&self) -> Vec<u8> {
99 let mut message = self.encode();
100 message.extend(self.mac.as_bytes());
101
102 message
103 }
104
105 pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
110 Self::try_from(message)
111 }
112
113 pub fn to_base64(&self) -> String {
118 base64_encode(self.to_bytes())
119 }
120
121 pub(crate) fn new(
122 ratchet_key: Curve25519PublicKey,
123 chain_index: u64,
124 ciphertext: Vec<u8>,
125 ) -> Self {
126 Self {
127 version: VERSION,
128 ratchet_key,
129 chain_index,
130 ciphertext,
131 mac: Mac([0u8; Mac::LENGTH]).into(),
132 }
133 }
134
135 pub(crate) fn new_truncated_mac(
136 ratchet_key: Curve25519PublicKey,
137 chain_index: u64,
138 ciphertext: Vec<u8>,
139 ) -> Self {
140 Self {
141 version: MAC_TRUNCATED_VERSION,
142 ratchet_key,
143 chain_index,
144 ciphertext,
145 mac: [0u8; Mac::TRUNCATED_LEN].into(),
146 }
147 }
148
149 fn encode(&self) -> Vec<u8> {
150 ProtoBufMessage {
151 ratchet_key: self.ratchet_key.to_bytes().to_vec(),
152 chain_index: self.chain_index,
153 ciphertext: self.ciphertext.clone(),
154 }
155 .encode_manual(self.version)
156 }
157
158 pub(crate) fn to_mac_bytes(&self) -> Vec<u8> {
159 self.encode()
160 }
161
162 pub(crate) fn set_mac(&mut self, mac: Mac) {
163 match self.mac {
164 MessageMac::Truncated(_) => self.mac = mac.truncate().into(),
165 MessageMac::Full(_) => self.mac = mac.into(),
166 }
167 }
168}
169
170impl Serialize for Message {
171 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
172 where
173 S: serde::Serializer,
174 {
175 let message = self.to_base64();
176 serializer.serialize_str(&message)
177 }
178}
179
180impl<'de> Deserialize<'de> for Message {
181 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
182 let ciphertext = String::deserialize(d)?;
183 Message::from_base64(&ciphertext).map_err(serde::de::Error::custom)
184 }
185}
186
187impl TryFrom<&str> for Message {
188 type Error = DecodeError;
189
190 fn try_from(value: &str) -> Result<Self, Self::Error> {
191 let decoded = base64_decode(value)?;
192
193 Self::try_from(decoded)
194 }
195}
196
197impl TryFrom<Vec<u8>> for Message {
198 type Error = DecodeError;
199
200 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
201 Self::try_from(value.as_slice())
202 }
203}
204
205impl TryFrom<&[u8]> for Message {
206 type Error = DecodeError;
207
208 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
209 let version = *value.first().ok_or(DecodeError::MissingVersion)?;
210
211 let mac_length = match version {
212 VERSION => Mac::LENGTH,
213 MAC_TRUNCATED_VERSION => Mac::TRUNCATED_LEN,
214 _ => return Err(DecodeError::InvalidVersion(VERSION, version)),
215 };
216
217 if value.len() < mac_length + 2 {
218 Err(DecodeError::MessageTooShort(value.len()))
219 } else {
220 let inner = ProtoBufMessage::decode(
221 value
222 .get(1..value.len() - mac_length)
223 .ok_or_else(|| DecodeError::MessageTooShort(value.len()))?,
224 )?;
225
226 let mac_slice = &value[value.len() - mac_length..];
227
228 if mac_slice.len() != mac_length {
229 Err(DecodeError::InvalidMacLength(mac_length, mac_slice.len()))
230 } else {
231 let mac = extract_mac(mac_slice, version == MAC_TRUNCATED_VERSION);
232
233 let chain_index = inner.chain_index;
234 let ciphertext = inner.ciphertext;
235 let ratchet_key = Curve25519PublicKey::from_slice(&inner.ratchet_key)?;
236
237 let message = Message { version, ratchet_key, chain_index, ciphertext, mac };
238
239 Ok(message)
240 }
241 }
242 }
243}
244
245impl Debug for Message {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 let Self { version, ratchet_key, chain_index, ciphertext: _, mac: _ } = self;
248
249 f.debug_struct("Message")
250 .field("version", version)
251 .field("ratchet_key", ratchet_key)
252 .field("chain_index", chain_index)
253 .finish_non_exhaustive()
254 }
255}
256
257#[derive(ProstMessage, PartialEq, Eq)]
258struct ProtoBufMessage {
259 #[prost(bytes, tag = "1")]
260 ratchet_key: Vec<u8>,
261 #[prost(uint64, tag = "2")]
262 chain_index: u64,
263 #[prost(bytes, tag = "4")]
264 ciphertext: Vec<u8>,
265}
266
267impl ProtoBufMessage {
268 const RATCHET_TAG: &'static [u8; 1] = b"\x0A";
269 const INDEX_TAG: &'static [u8; 1] = b"\x10";
270 const CIPHER_TAG: &'static [u8; 1] = b"\x22";
271
272 fn encode_manual(&self, version: u8) -> Vec<u8> {
273 let index = self.chain_index.to_var_int();
274 let ratchet_len = self.ratchet_key.len().to_var_int();
275 let ciphertext_len = self.ciphertext.len().to_var_int();
276
277 [
278 [version].as_ref(),
279 Self::RATCHET_TAG.as_ref(),
280 &ratchet_len,
281 &self.ratchet_key,
282 Self::INDEX_TAG.as_ref(),
283 &index,
284 Self::CIPHER_TAG.as_ref(),
285 &ciphertext_len,
286 &self.ciphertext,
287 ]
288 .concat()
289 }
290}
291
#[cfg(test)]
mod test {
    use assert_matches2::{assert_let, assert_matches};

    use super::Message;
    use crate::{
        Curve25519PublicKey, DecodeError,
        olm::messages::message::{MAC_TRUNCATED_VERSION, VERSION},
    };

    #[test]
    fn encode() {
        let message = b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertext";
        let message_mac =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";

        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");
        let ciphertext = b"ciphertext";
        let chain_index = 2;

        let mut encoded = Message::new_truncated_mac(ratchet_key, chain_index, ciphertext.to_vec());
        encoded.mac = (*b"MACHEREE").into();

        assert_eq!(encoded.to_mac_bytes(), message.as_ref());
        assert_eq!(encoded.to_bytes(), message_mac.as_ref());
        assert_eq!(encoded.ciphertext(), ciphertext.to_vec());
        assert_eq!(encoded.chain_index(), chain_index);
        assert_eq!(encoded.version(), MAC_TRUNCATED_VERSION);
    }

    #[test]
    fn from_bytes_normal_message() {
        let bytes = b"\x04\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMAC_01234567890_01234567890_HERE";
        let result = Message::try_from(bytes.as_slice());
        assert_let!(Ok(message) = result);

        let Message { version, ratchet_key, chain_index, ciphertext, mac } = message;

        assert_eq!(version, VERSION);
        assert_eq!(ratchet_key.as_bytes(), b"ratchetkeyhereprettyplease123456");
        assert_eq!(chain_index, 2);
        assert_eq!(ciphertext, b"ciphertext");
        assert_eq!(mac.as_bytes(), b"MAC_01234567890_01234567890_HERE");
    }

    #[test]
    fn from_bytes_too_short() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::MessageTooShort(9)));
    }

    #[test]
    fn from_bytes_decoding_error() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::ProtoBufError(_)));
    }

    #[test]
    fn from_bytes_invalid_tag() {
        let bytes = [
            VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let result = Message::try_from(bytes.to_vec());
        assert_matches!(result, Err(DecodeError::ProtoBufError(_)));
    }

    /// An empty buffer has no version byte at all.
    #[test]
    fn from_bytes_missing_version() {
        let result = Message::try_from(Vec::new());
        assert_matches!(result, Err(DecodeError::MissingVersion));
    }

    /// Version bytes other than 3 and 4 are rejected before any length check.
    #[test]
    fn from_bytes_unknown_version() {
        let bytes = vec![5, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::InvalidVersion(VERSION, 5)));
    }

    #[test]
    fn from_base64_rejects_invalid_base64() {
        assert!(Message::from_base64("@@@@").is_err());
    }

    /// Encoding and decoding a truncated-MAC message yields the same message,
    /// both through raw bytes and through base64.
    #[test]
    fn round_trip_truncated_mac() {
        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");
        let mut message = Message::new_truncated_mac(ratchet_key, 7, b"ciphertext".to_vec());
        message.mac = (*b"MACHEREE").into();

        let decoded =
            Message::from_bytes(&message.to_bytes()).expect("a freshly encoded message decodes");
        assert_eq!(decoded, message);
        assert!(decoded.mac_truncated());

        let decoded = Message::from_base64(&message.to_base64())
            .expect("a freshly base64-encoded message decodes");
        assert_eq!(decoded, message);
    }

    /// Round trip of a full-MAC message with a chain index that needs a
    /// multi-byte varint.
    #[test]
    fn round_trip_full_mac() {
        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");
        let message = Message::new(ratchet_key, 300, b"ciphertext".to_vec());

        let decoded =
            Message::from_bytes(&message.to_bytes()).expect("a freshly encoded message decodes");
        assert_eq!(decoded, message);
        assert_eq!(decoded.version(), VERSION);
        assert!(!decoded.mac_truncated());
        assert_eq!(decoded.chain_index(), 300);
    }
}
360}