vodozemac/types/
curve25519.rs1use std::fmt::Display;
16
17use base64::decoded_len_estimate;
18use matrix_pickle::{Decode, DecodeError};
19use rand::thread_rng;
20use serde::{Deserialize, Serialize};
21use x25519_dalek::{EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret};
22use zeroize::Zeroize;
23
24use super::KeyError;
25use crate::utilities::{base64_decode, base64_encode};
26
27#[derive(Clone, Deserialize, Serialize)]
29#[serde(transparent)]
30pub struct Curve25519SecretKey(Box<StaticSecret>);
31
32impl Curve25519SecretKey {
33 pub fn new() -> Self {
35 let rng = thread_rng();
36
37 Self(Box::new(StaticSecret::random_from_rng(rng)))
38 }
39
40 pub fn from_slice(bytes: &[u8; 32]) -> Self {
42 Self(Box::new(StaticSecret::from(*bytes)))
44 }
45
46 pub fn diffie_hellman(&self, their_public_key: &Curve25519PublicKey) -> SharedSecret {
50 self.0.diffie_hellman(&their_public_key.inner)
51 }
52
53 pub fn to_bytes(&self) -> Box<[u8; 32]> {
58 let mut key = Box::new([0u8; 32]);
59 let mut bytes = self.0.to_bytes();
60 key.copy_from_slice(&bytes);
61
62 bytes.zeroize();
63
64 key
65 }
66}
67
68impl Default for Curve25519SecretKey {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74#[derive(Serialize, Deserialize, Clone)]
75#[serde(from = "Curve25519KeypairPickle")]
76#[serde(into = "Curve25519KeypairPickle")]
77pub(crate) struct Curve25519Keypair {
78 pub secret_key: Curve25519SecretKey,
79 pub public_key: Curve25519PublicKey,
80}
81
82impl Curve25519Keypair {
83 pub fn new() -> Self {
84 let secret_key = Curve25519SecretKey::new();
85 let public_key = Curve25519PublicKey::from(&secret_key);
86
87 Self { secret_key, public_key }
88 }
89
90 pub fn from_secret_key(key: &[u8; 32]) -> Self {
91 let secret_key = Curve25519SecretKey::from_slice(key);
92 let public_key = Curve25519PublicKey::from(&secret_key);
93
94 Curve25519Keypair { secret_key, public_key }
95 }
96
97 pub const fn secret_key(&self) -> &Curve25519SecretKey {
98 &self.secret_key
99 }
100
101 pub const fn public_key(&self) -> Curve25519PublicKey {
102 self.public_key
103 }
104}
105
106#[derive(PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize)]
108#[serde(transparent)]
109pub struct Curve25519PublicKey {
110 pub(crate) inner: PublicKey,
111}
112
113impl Decode for Curve25519PublicKey {
114 fn decode(reader: &mut impl std::io::Read) -> Result<Self, DecodeError> {
115 let key = <[u8; 32]>::decode(reader)?;
116
117 Ok(Curve25519PublicKey::from(key))
118 }
119}
120
121impl Curve25519PublicKey {
122 pub const LENGTH: usize = 32;
124
125 const BASE64_LENGTH: usize = 43;
126 const PADDED_BASE64_LENGTH: usize = 44;
127
128 #[inline]
130 pub fn to_bytes(&self) -> [u8; Self::LENGTH] {
131 self.inner.to_bytes()
132 }
133
134 #[inline]
136 pub fn as_bytes(&self) -> &[u8; Self::LENGTH] {
137 self.inner.as_bytes()
138 }
139
140 pub fn to_vec(&self) -> Vec<u8> {
142 self.inner.as_bytes().to_vec()
143 }
144
145 pub fn from_bytes(bytes: [u8; 32]) -> Self {
147 Self { inner: PublicKey::from(bytes) }
148 }
149
150 pub fn from_base64(input: &str) -> Result<Curve25519PublicKey, KeyError> {
153 if input.len() != Self::BASE64_LENGTH && input.len() != Self::PADDED_BASE64_LENGTH {
154 Err(KeyError::InvalidKeyLength {
155 key_type: "Curve25519",
156 expected_length: Self::LENGTH,
157 length: decoded_len_estimate(input.len()),
158 })
159 } else {
160 let key = base64_decode(input)?;
161 Self::from_slice(&key)
162 }
163 }
164
165 pub fn from_slice(slice: &[u8]) -> Result<Curve25519PublicKey, KeyError> {
167 let key_len = slice.len();
168
169 if key_len == Self::LENGTH {
170 let mut key = [0u8; Self::LENGTH];
171 key.copy_from_slice(slice);
172
173 Ok(Self::from(key))
174 } else {
175 Err(KeyError::InvalidKeyLength {
176 key_type: "Curve25519",
177 expected_length: Self::LENGTH,
178 length: key_len,
179 })
180 }
181 }
182
183 pub fn to_base64(&self) -> String {
185 base64_encode(self.inner.as_bytes())
186 }
187}
188
189impl Display for Curve25519PublicKey {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 write!(f, "{}", self.to_base64())
192 }
193}
194
195impl std::fmt::Debug for Curve25519PublicKey {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 let s = format!("curve25519:{self}");
198 <str as std::fmt::Debug>::fmt(&s, f)
199 }
200}
201
202impl From<[u8; Self::LENGTH]> for Curve25519PublicKey {
203 fn from(bytes: [u8; Self::LENGTH]) -> Curve25519PublicKey {
204 Curve25519PublicKey { inner: PublicKey::from(bytes) }
205 }
206}
207
208impl<'a> From<&'a Curve25519SecretKey> for Curve25519PublicKey {
209 fn from(secret: &'a Curve25519SecretKey) -> Curve25519PublicKey {
210 Curve25519PublicKey { inner: PublicKey::from(secret.0.as_ref()) }
211 }
212}
213
214impl<'a> From<&'a EphemeralSecret> for Curve25519PublicKey {
215 fn from(secret: &'a EphemeralSecret) -> Curve25519PublicKey {
216 Curve25519PublicKey { inner: PublicKey::from(secret) }
217 }
218}
219
220impl<'a> From<&'a ReusableSecret> for Curve25519PublicKey {
221 fn from(secret: &'a ReusableSecret) -> Curve25519PublicKey {
222 Curve25519PublicKey { inner: PublicKey::from(secret) }
223 }
224}
225
226#[derive(Serialize, Deserialize)]
227#[serde(transparent)]
228pub(crate) struct Curve25519KeypairPickle(Curve25519SecretKey);
229
230impl From<Curve25519KeypairPickle> for Curve25519Keypair {
231 fn from(pickle: Curve25519KeypairPickle) -> Self {
232 let secret_key = pickle.0;
233 let public_key = Curve25519PublicKey::from(&secret_key);
234
235 Self { secret_key, public_key }
236 }
237}
238
239impl From<Curve25519Keypair> for Curve25519KeypairPickle {
240 fn from(key: Curve25519Keypair) -> Self {
241 Curve25519KeypairPickle(key.secret_key)
242 }
243}
244
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use insta::assert_debug_snapshot;

    use super::Curve25519PublicKey;
    use crate::{Curve25519SecretKey, KeyError, utilities::DecodeError};

    /// Malformed base64 input is rejected, either by the length pre-check
    /// or by the base64 decoder itself.
    #[test]
    fn decoding_invalid_base64_fails() {
        // Wrong length: rejected before any decoding is attempted.
        let too_short = Curve25519PublicKey::from_base64("a");
        assert_matches!(too_short, Err(KeyError::InvalidKeyLength { .. }));

        // Acceptable length, but the trailing space is not a base64 symbol.
        let trailing_space =
            Curve25519PublicKey::from_base64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ");
        assert_matches!(
            trailing_space,
            Err(KeyError::Base64Error(DecodeError::InvalidByte(..)))
        );

        // Acceptable length, but the final symbol is invalid in that position.
        let bad_last_symbol =
            Curve25519PublicKey::from_base64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZ");
        assert_matches!(
            bad_last_symbol,
            Err(KeyError::Base64Error(DecodeError::InvalidLastSymbol(..)))
        );
    }

    /// Valid base64 of the wrong size is reported as a key length error.
    #[test]
    fn decoding_incorrect_num_of_bytes_fails() {
        let wrong_size = Curve25519PublicKey::from_base64("aaaa");
        assert_matches!(wrong_size, Err(KeyError::InvalidKeyLength { .. }));
    }

    /// A 43-character unpadded base64 string decodes into a key.
    #[test]
    fn decoding_of_correct_num_of_bytes_succeeds() {
        let decoded =
            Curve25519PublicKey::from_base64("MDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDA");
        assert_matches!(decoded, Ok(..));
    }

    /// Every byte accessor of a public key returns the bytes it was built from.
    #[test]
    fn byte_decoding_roundtrip_succeeds_for_public_key() {
        let raw = *b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let public_key = Curve25519PublicKey::from_bytes(raw);

        assert_eq!(public_key.to_bytes(), raw);
        assert_eq!(public_key.as_bytes(), &raw);
        assert_eq!(public_key.to_vec(), raw.to_vec());
    }

    /// A secret key returns the bytes it was built from.
    #[test]
    fn byte_decoding_roundtrip_succeeds_for_secret_key() {
        let raw = *b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let secret_key = Curve25519SecretKey::from_slice(&raw);

        assert_eq!(*secret_key.to_bytes(), raw);
    }

    /// Pin the `Debug` output of an all-zero public key.
    #[test]
    fn snapshot_public_key_debug() {
        let key = Curve25519PublicKey::from_bytes([0; 32]);
        assert_debug_snapshot!(key);
    }
}