1use std::fmt::Debug;
16
17use prost::Message as ProstMessage;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21 Curve25519PublicKey, DecodeError,
22 cipher::{Mac, MessageMac},
23 utilities::{VarInt, base64_decode, base64_encode, extract_mac},
24};
25
/// Message format version whose trailing MAC is truncated to `Mac::TRUNCATED_LEN` bytes.
pub(crate) const MAC_TRUNCATED_VERSION: u8 = 0x03;
/// Message format version whose trailing MAC is the full `Mac::LENGTH` bytes.
pub(crate) const VERSION: u8 = 0x04;
28
29#[derive(Clone, PartialEq, Eq)]
36pub struct Message {
37 pub(crate) version: u8,
38 pub(crate) ratchet_key: Curve25519PublicKey,
39 pub(crate) chain_index: u64,
40 pub(crate) ciphertext: Vec<u8>,
41 pub(crate) mac: MessageMac,
42}
43
44impl Message {
45 pub const fn ratchet_key(&self) -> Curve25519PublicKey {
48 self.ratchet_key
49 }
50
51 pub const fn chain_index(&self) -> u64 {
53 self.chain_index
54 }
55
56 pub fn ciphertext(&self) -> &[u8] {
58 &self.ciphertext
59 }
60
61 pub const fn version(&self) -> u8 {
63 self.version
64 }
65
66 pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
71 Self::try_from(bytes)
72 }
73
74 pub fn to_bytes(&self) -> Vec<u8> {
94 let mut message = self.encode();
95 message.extend(self.mac.as_bytes());
96
97 message
98 }
99
100 pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
105 Self::try_from(message)
106 }
107
108 pub fn to_base64(&self) -> String {
113 base64_encode(self.to_bytes())
114 }
115
116 #[cfg(feature = "experimental-session-config")]
117 pub(crate) fn new(
118 ratchet_key: Curve25519PublicKey,
119 chain_index: u64,
120 ciphertext: Vec<u8>,
121 ) -> Self {
122 Self {
123 version: VERSION,
124 ratchet_key,
125 chain_index,
126 ciphertext,
127 mac: Mac([0u8; Mac::LENGTH]).into(),
128 }
129 }
130
131 pub(crate) fn new_truncated_mac(
132 ratchet_key: Curve25519PublicKey,
133 chain_index: u64,
134 ciphertext: Vec<u8>,
135 ) -> Self {
136 Self {
137 version: MAC_TRUNCATED_VERSION,
138 ratchet_key,
139 chain_index,
140 ciphertext,
141 mac: [0u8; Mac::TRUNCATED_LEN].into(),
142 }
143 }
144
145 fn encode(&self) -> Vec<u8> {
146 ProtoBufMessage {
147 ratchet_key: self.ratchet_key.to_bytes().to_vec(),
148 chain_index: self.chain_index,
149 ciphertext: self.ciphertext.clone(),
150 }
151 .encode_manual(self.version)
152 }
153
154 pub(crate) fn to_mac_bytes(&self) -> Vec<u8> {
155 self.encode()
156 }
157
158 pub(crate) fn set_mac(&mut self, mac: Mac) {
159 match self.mac {
160 MessageMac::Truncated(_) => self.mac = mac.truncate().into(),
161 MessageMac::Full(_) => self.mac = mac.into(),
162 }
163 }
164}
165
166impl Serialize for Message {
167 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168 where
169 S: serde::Serializer,
170 {
171 let message = self.to_base64();
172 serializer.serialize_str(&message)
173 }
174}
175
176impl<'de> Deserialize<'de> for Message {
177 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
178 let ciphertext = String::deserialize(d)?;
179 Message::from_base64(&ciphertext).map_err(serde::de::Error::custom)
180 }
181}
182
183impl TryFrom<&str> for Message {
184 type Error = DecodeError;
185
186 fn try_from(value: &str) -> Result<Self, Self::Error> {
187 let decoded = base64_decode(value)?;
188
189 Self::try_from(decoded)
190 }
191}
192
193impl TryFrom<Vec<u8>> for Message {
194 type Error = DecodeError;
195
196 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
197 Self::try_from(value.as_slice())
198 }
199}
200
201impl TryFrom<&[u8]> for Message {
202 type Error = DecodeError;
203
204 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
205 let version = *value.first().ok_or(DecodeError::MissingVersion)?;
206
207 let mac_length = match version {
208 VERSION => Mac::LENGTH,
209 MAC_TRUNCATED_VERSION => Mac::TRUNCATED_LEN,
210 _ => return Err(DecodeError::InvalidVersion(VERSION, version)),
211 };
212
213 if value.len() < mac_length + 2 {
214 Err(DecodeError::MessageTooShort(value.len()))
215 } else {
216 let inner = ProtoBufMessage::decode(
217 value
218 .get(1..value.len() - mac_length)
219 .ok_or_else(|| DecodeError::MessageTooShort(value.len()))?,
220 )?;
221
222 let mac_slice = &value[value.len() - mac_length..];
223
224 if mac_slice.len() != mac_length {
225 Err(DecodeError::InvalidMacLength(mac_length, mac_slice.len()))
226 } else {
227 let mac = extract_mac(mac_slice, version == MAC_TRUNCATED_VERSION);
228
229 let chain_index = inner.chain_index;
230 let ciphertext = inner.ciphertext;
231 let ratchet_key = Curve25519PublicKey::from_slice(&inner.ratchet_key)?;
232
233 let message = Message { version, ratchet_key, chain_index, ciphertext, mac };
234
235 Ok(message)
236 }
237 }
238 }
239}
240
241impl Debug for Message {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 let Self { version, ratchet_key, chain_index, ciphertext: _, mac: _ } = self;
244
245 f.debug_struct("Message")
246 .field("version", version)
247 .field("ratchet_key", ratchet_key)
248 .field("chain_index", chain_index)
249 .finish_non_exhaustive()
250 }
251}
252
253#[derive(ProstMessage, PartialEq, Eq)]
254struct ProtoBufMessage {
255 #[prost(bytes, tag = "1")]
256 ratchet_key: Vec<u8>,
257 #[prost(uint64, tag = "2")]
258 chain_index: u64,
259 #[prost(bytes, tag = "4")]
260 ciphertext: Vec<u8>,
261}
262
263impl ProtoBufMessage {
264 const RATCHET_TAG: &'static [u8; 1] = b"\x0A";
265 const INDEX_TAG: &'static [u8; 1] = b"\x10";
266 const CIPHER_TAG: &'static [u8; 1] = b"\x22";
267
268 fn encode_manual(&self, version: u8) -> Vec<u8> {
269 let index = self.chain_index.to_var_int();
270 let ratchet_len = self.ratchet_key.len().to_var_int();
271 let ciphertext_len = self.ciphertext.len().to_var_int();
272
273 [
274 [version].as_ref(),
275 Self::RATCHET_TAG.as_ref(),
276 &ratchet_len,
277 &self.ratchet_key,
278 Self::INDEX_TAG.as_ref(),
279 &index,
280 Self::CIPHER_TAG.as_ref(),
281 &ciphertext_len,
282 &self.ciphertext,
283 ]
284 .concat()
285 }
286}
287
#[cfg(test)]
mod test {
    use assert_matches2::{assert_let, assert_matches};

    use super::Message;
    use crate::{
        Curve25519PublicKey, DecodeError,
        olm::messages::message::{MAC_TRUNCATED_VERSION, VERSION},
    };

    #[test]
    fn encode() {
        let message = b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertext";
        let message_mac =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";

        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");
        let ciphertext = b"ciphertext";
        let chain_index = 2;

        let mut encoded = Message::new_truncated_mac(ratchet_key, chain_index, ciphertext.to_vec());
        encoded.mac = (*b"MACHEREE").into();

        assert_eq!(encoded.to_mac_bytes(), message.as_ref());
        assert_eq!(encoded.to_bytes(), message_mac.as_ref());
        assert_eq!(encoded.ciphertext(), ciphertext.to_vec());
        assert_eq!(encoded.chain_index(), chain_index);
        assert_eq!(encoded.version(), MAC_TRUNCATED_VERSION);
    }

    #[test]
    fn from_bytes_normal_message() {
        let bytes = b"\x04\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMAC_01234567890_01234567890_HERE";
        let result = Message::try_from(bytes.as_slice());
        assert_let!(Ok(message) = result);

        let Message { version, ratchet_key, chain_index, ciphertext, mac } = message;

        assert_eq!(version, VERSION);
        assert_eq!(ratchet_key.as_bytes(), b"ratchetkeyhereprettyplease123456");
        assert_eq!(chain_index, 2);
        assert_eq!(ciphertext, b"ciphertext");
        assert_eq!(mac.as_bytes(), b"MAC_01234567890_01234567890_HERE");
    }

    #[test]
    fn normal_message_reencodes_identically() {
        let bytes = b"\x04\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMAC_01234567890_01234567890_HERE";
        let message = Message::from_bytes(bytes).expect("a well-formed message should decode");

        assert_eq!(message.to_bytes(), bytes.as_slice());
    }

    #[test]
    fn truncated_mac_roundtrip() {
        let ratchet_key = Curve25519PublicKey::from(*b"ratchetkeyhereprettyplease123456");
        let mut message = Message::new_truncated_mac(ratchet_key, 7, b"ciphertext".to_vec());
        message.mac = (*b"MACHEREE").into();

        let from_bytes =
            Message::from_bytes(&message.to_bytes()).expect("an encoded message should decode");
        assert_eq!(from_bytes, message);

        let from_base64 = Message::from_base64(&message.to_base64())
            .expect("a base64-encoded message should decode");
        assert_eq!(from_base64, message);
    }

    #[test]
    fn from_bytes_empty() {
        let result = Message::try_from(Vec::<u8>::new());
        assert_matches!(result, Err(DecodeError::MissingVersion));
    }

    #[test]
    fn from_bytes_unsupported_version() {
        let bytes = vec![2u8; 40];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::InvalidVersion(VERSION, 2)));
    }

    #[test]
    fn from_bytes_too_short() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::MessageTooShort(9)));
    }

    #[test]
    fn from_bytes_too_short_full_mac() {
        // Version byte plus 32 bytes: room for the full MAC but no protobuf body.
        let mut bytes = vec![0u8; 33];
        bytes[0] = VERSION;
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::MessageTooShort(33)));
    }

    #[test]
    fn from_bytes_decoding_error() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::ProtoBufError(_)));
    }

    #[test]
    fn from_bytes_invalid_tag() {
        let bytes = [
            VERSION, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let result = Message::try_from(bytes.to_vec());
        assert_matches!(result, Err(DecodeError::ProtoBufError(_)));
    }
}