Skip to main content

vodozemac/types/
curve25519.rs

1// Copyright 2021 Denis Kasak, Damir Jelić
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use base64::decoded_len_estimate;
18use matrix_pickle::{Decode, DecodeError};
19use rand::thread_rng;
20use serde::{Deserialize, Serialize};
21use x25519_dalek::{EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret};
22use zeroize::Zeroize;
23
24use super::KeyError;
25use crate::utilities::{base64_decode, base64_encode};
26
27/// Struct representing a Curve25519 secret key.
28#[derive(Clone, Deserialize, Serialize)]
29#[serde(transparent)]
30pub struct Curve25519SecretKey(Box<StaticSecret>);
31
32impl Curve25519SecretKey {
33    /// Generate a new, random, Curve25519SecretKey.
34    pub fn new() -> Self {
35        let rng = thread_rng();
36
37        Self(Box::new(StaticSecret::random_from_rng(rng)))
38    }
39
40    /// Create a `Curve25519SecretKey` from the given slice of bytes.
41    pub fn from_slice(bytes: &[u8; 32]) -> Self {
42        // XXX: Passing in secret array as value.
43        Self(Box::new(StaticSecret::from(*bytes)))
44    }
45
46    /// Perform a Diffie-Hellman key exchange between the given
47    /// `Curve25519PublicKey` and this `Curve25519SecretKey` and return a shared
48    /// secret.
49    pub fn diffie_hellman(&self, their_public_key: &Curve25519PublicKey) -> SharedSecret {
50        self.0.diffie_hellman(&their_public_key.inner)
51    }
52
53    /// Convert the `Curve25519SecretKey` to a byte array.
54    ///
55    /// **Note**: This creates a copy of the key which won't be zeroized, the
56    /// caller of the method needs to make sure to zeroize the returned array.
57    pub fn to_bytes(&self) -> Box<[u8; 32]> {
58        let mut key = Box::new([0u8; 32]);
59        let mut bytes = self.0.to_bytes();
60        key.copy_from_slice(&bytes);
61
62        bytes.zeroize();
63
64        key
65    }
66}
67
68impl Default for Curve25519SecretKey {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74#[derive(Serialize, Deserialize, Clone)]
75#[serde(from = "Curve25519KeypairPickle")]
76#[serde(into = "Curve25519KeypairPickle")]
77pub(crate) struct Curve25519Keypair {
78    pub secret_key: Curve25519SecretKey,
79    pub public_key: Curve25519PublicKey,
80}
81
82impl Curve25519Keypair {
83    pub fn new() -> Self {
84        let secret_key = Curve25519SecretKey::new();
85        let public_key = Curve25519PublicKey::from(&secret_key);
86
87        Self { secret_key, public_key }
88    }
89
90    pub fn from_secret_key(key: &[u8; 32]) -> Self {
91        let secret_key = Curve25519SecretKey::from_slice(key);
92        let public_key = Curve25519PublicKey::from(&secret_key);
93
94        Curve25519Keypair { secret_key, public_key }
95    }
96
97    pub const fn secret_key(&self) -> &Curve25519SecretKey {
98        &self.secret_key
99    }
100
101    pub const fn public_key(&self) -> Curve25519PublicKey {
102        self.public_key
103    }
104}
105
106/// Struct representing a Curve25519 public key.
107#[derive(PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize)]
108#[serde(transparent)]
109pub struct Curve25519PublicKey {
110    pub(crate) inner: PublicKey,
111}
112
113impl Decode for Curve25519PublicKey {
114    fn decode(reader: &mut impl std::io::Read) -> Result<Self, DecodeError> {
115        let key = <[u8; 32]>::decode(reader)?;
116
117        Ok(Curve25519PublicKey::from(key))
118    }
119}
120
121impl Curve25519PublicKey {
122    /// The number of bytes a Curve25519 public key has.
123    pub const LENGTH: usize = 32;
124
125    const BASE64_LENGTH: usize = 43;
126    const PADDED_BASE64_LENGTH: usize = 44;
127
128    /// Convert this public key to a byte array.
129    #[inline]
130    pub fn to_bytes(&self) -> [u8; Self::LENGTH] {
131        self.inner.to_bytes()
132    }
133
134    /// View this public key as a byte array.
135    #[inline]
136    pub fn as_bytes(&self) -> &[u8; Self::LENGTH] {
137        self.inner.as_bytes()
138    }
139
140    /// Convert the public key to a vector of bytes.
141    pub fn to_vec(&self) -> Vec<u8> {
142        self.inner.as_bytes().to_vec()
143    }
144
145    /// Create a `Curve25519PublicKey` from a byte array.
146    pub fn from_bytes(bytes: [u8; 32]) -> Self {
147        Self { inner: PublicKey::from(bytes) }
148    }
149
150    /// Instantiate a Curve25519 public key from an unpadded base64
151    /// representation.
152    pub fn from_base64(input: &str) -> Result<Curve25519PublicKey, KeyError> {
153        if input.len() != Self::BASE64_LENGTH && input.len() != Self::PADDED_BASE64_LENGTH {
154            Err(KeyError::InvalidKeyLength {
155                key_type: "Curve25519",
156                expected_length: Self::LENGTH,
157                length: decoded_len_estimate(input.len()),
158            })
159        } else {
160            let key = base64_decode(input)?;
161            Self::from_slice(&key)
162        }
163    }
164
165    /// Try to create a `Curve25519PublicKey` from a slice of bytes.
166    pub fn from_slice(slice: &[u8]) -> Result<Curve25519PublicKey, KeyError> {
167        let key_len = slice.len();
168
169        if key_len == Self::LENGTH {
170            let mut key = [0u8; Self::LENGTH];
171            key.copy_from_slice(slice);
172
173            Ok(Self::from(key))
174        } else {
175            Err(KeyError::InvalidKeyLength {
176                key_type: "Curve25519",
177                expected_length: Self::LENGTH,
178                length: key_len,
179            })
180        }
181    }
182
183    /// Serialize a Curve25519 public key to an unpadded base64 representation.
184    pub fn to_base64(&self) -> String {
185        base64_encode(self.inner.as_bytes())
186    }
187}
188
189impl Display for Curve25519PublicKey {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        write!(f, "{}", self.to_base64())
192    }
193}
194
195impl std::fmt::Debug for Curve25519PublicKey {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        let s = format!("curve25519:{self}");
198        <str as std::fmt::Debug>::fmt(&s, f)
199    }
200}
201
202impl From<[u8; Self::LENGTH]> for Curve25519PublicKey {
203    fn from(bytes: [u8; Self::LENGTH]) -> Curve25519PublicKey {
204        Curve25519PublicKey { inner: PublicKey::from(bytes) }
205    }
206}
207
208impl<'a> From<&'a Curve25519SecretKey> for Curve25519PublicKey {
209    fn from(secret: &'a Curve25519SecretKey) -> Curve25519PublicKey {
210        Curve25519PublicKey { inner: PublicKey::from(secret.0.as_ref()) }
211    }
212}
213
214impl<'a> From<&'a EphemeralSecret> for Curve25519PublicKey {
215    fn from(secret: &'a EphemeralSecret) -> Curve25519PublicKey {
216        Curve25519PublicKey { inner: PublicKey::from(secret) }
217    }
218}
219
220impl<'a> From<&'a ReusableSecret> for Curve25519PublicKey {
221    fn from(secret: &'a ReusableSecret) -> Curve25519PublicKey {
222        Curve25519PublicKey { inner: PublicKey::from(secret) }
223    }
224}
225
226#[derive(Serialize, Deserialize)]
227#[serde(transparent)]
228pub(crate) struct Curve25519KeypairPickle(Curve25519SecretKey);
229
230impl From<Curve25519KeypairPickle> for Curve25519Keypair {
231    fn from(pickle: Curve25519KeypairPickle) -> Self {
232        let secret_key = pickle.0;
233        let public_key = Curve25519PublicKey::from(&secret_key);
234
235        Self { secret_key, public_key }
236    }
237}
238
239impl From<Curve25519Keypair> for Curve25519KeypairPickle {
240    fn from(key: Curve25519Keypair) -> Self {
241        Curve25519KeypairPickle(key.secret_key)
242    }
243}
244
// Unit tests for Curve25519 key parsing, byte round-trips and Debug output.
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use insta::assert_debug_snapshot;

    use super::Curve25519PublicKey;
    use crate::{Curve25519SecretKey, KeyError, utilities::DecodeError};

    /// Malformed base64 input must be rejected with the matching error.
    #[test]
    fn decoding_invalid_base64_fails() {
        // A 1-character input fails the up-front length check, before decoding.
        let base64_payload = "a";
        assert_matches!(
            Curve25519PublicKey::from_base64(base64_payload),
            Err(KeyError::InvalidKeyLength { .. })
        );

        // Accepted length, but the trailing space is not a base64 symbol.
        let base64_payload = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ";
        assert_matches!(
            Curve25519PublicKey::from_base64(base64_payload),
            Err(KeyError::Base64Error(DecodeError::InvalidByte(..)))
        );

        // Accepted length and alphabet, but the final symbol is rejected;
        // presumably because 'Z' sets bits beyond the 32 encoded bytes.
        let base64_payload = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZ";
        assert_matches!(
            Curve25519PublicKey::from_base64(base64_payload),
            Err(KeyError::Base64Error(DecodeError::InvalidLastSymbol(..)))
        );
    }

    /// Valid base64 of the wrong size is reported as a key-length error.
    #[test]
    fn decoding_incorrect_num_of_bytes_fails() {
        // 4 characters: rejected by the length check in `from_base64`.
        let base64_payload = "aaaa";
        assert_matches!(
            Curve25519PublicKey::from_base64(base64_payload),
            Err(KeyError::InvalidKeyLength { .. })
        );
    }

    /// A 43-character unpadded encoding of 32 bytes decodes successfully.
    #[test]
    fn decoding_of_correct_num_of_bytes_succeeds() {
        let base64_payload = "MDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDA";
        assert_matches!(Curve25519PublicKey::from_base64(base64_payload), Ok(..));
    }

    /// `from_bytes` followed by each byte accessor returns the original bytes.
    #[test]
    fn byte_decoding_roundtrip_succeeds_for_public_key() {
        let bytes = *b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let key = Curve25519PublicKey::from_bytes(bytes);
        assert_eq!(key.to_bytes(), bytes);
        assert_eq!(key.as_bytes(), &bytes);
        assert_eq!(key.to_vec(), bytes.to_vec());
    }

    /// `from_slice` followed by `to_bytes` returns the original secret bytes.
    #[test]
    fn byte_decoding_roundtrip_succeeds_for_secret_key() {
        let bytes = *b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let key = Curve25519SecretKey::from_slice(&bytes);
        assert_eq!(*(key.to_bytes()), bytes);
    }

    /// Pins the `Debug` output of an all-zero public key.
    ///
    /// insta names the stored snapshot after this test function by default,
    /// so renaming the function would orphan the existing snapshot file.
    #[test]
    fn snapshot_public_key_debug() {
        let key = Curve25519PublicKey::from_bytes([0; 32]);
        assert_debug_snapshot!(key);
    }
}