use std::{
borrow::Borrow,
collections::{
btree_map::{IntoIter, Iter},
BTreeMap,
},
};
use as_variant::as_variant;
use matrix_sdk_common::deserialized_responses::PrivOwnedStr;
use ruma::{
serde::StringEnum, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceKeyId, OwnedUserId, UserId,
};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use vodozemac::{Curve25519PublicKey, Ed25519PublicKey, Ed25519Signature, KeyError};
use zeroize::{Zeroize, ZeroizeOnDrop};
mod backup;
mod cross_signing;
mod device_keys;
pub mod events;
mod one_time_keys;
pub mod qr_login;
pub mod requests;
pub use self::{backup::*, cross_signing::*, device_keys::*, one_time_keys::*};
use crate::store::BackupDecryptionKey;
/// Generates a `pub(crate)` serde deserialization helper called `$fn_name`
/// for the type `$ty`.
///
/// The generated function deserializes a `String`, decodes it with
/// `$ty::from_base64`, and wipes the intermediate string with `zeroize()`
/// before returning. A decoding failure is reported through
/// `serde::de::Error::custom`.
macro_rules! from_base64 {
    ($ty:ident, $fn_name:ident) => {
        pub(crate) fn $fn_name<'de, D: Deserializer<'de>>(
            deserializer: D,
        ) -> Result<$ty, D::Error> {
            let mut encoded = String::deserialize(deserializer)?;
            // Decode first, then scrub the encoded copy whatever the outcome
            // of the decoding was.
            let decoded = $ty::from_base64(&encoded).map_err(serde::de::Error::custom);
            encoded.zeroize();
            decoded
        }
    };
}
/// Generates a `pub(crate)` serde serialization helper called `$fn_name`
/// for the type `$ty`.
///
/// The generated function encodes the value with its `to_base64()` method,
/// serializes the resulting string, and wipes that string with `zeroize()`
/// before handing back the serializer's result.
macro_rules! to_base64 {
    ($ty:ident, $fn_name:ident) => {
        pub(crate) fn $fn_name<S: Serializer>(
            v: &$ty,
            serializer: S,
        ) -> Result<S::Ok, S::Error> {
            let mut encoded = v.to_base64();
            let outcome = encoded.serialize(serializer);
            // The encoded form may be secret material, don't leave it around.
            encoded.zeroize();
            outcome
        }
    };
}
/// A bundle of secrets: the cross-signing secrets plus, optionally, the
/// secrets of a backup.
///
/// The type derives `ZeroizeOnDrop`, so its contents are wiped when a value
/// is dropped.
#[derive(Debug, Deserialize, Clone, Serialize, ZeroizeOnDrop)]
pub struct SecretsBundle {
/// The cross-signing secrets.
pub cross_signing: CrossSigningSecrets,
/// The backup secrets; `None` when the bundle carries no backup.
pub backup: Option<BackupSecrets>,
}
/// The three cross-signing secrets, each stored as a string.
///
/// `Debug` is implemented by hand for this type so that the secret values
/// are never printed.
///
/// NOTE(review): the fixtures in this file's tests look like unpadded base64
/// strings; confirm the expected encoding of the fields.
#[derive(Deserialize, Clone, Serialize, ZeroizeOnDrop)]
pub struct CrossSigningSecrets {
/// The secret of the master key.
pub master_key: String,
/// The secret of the user-signing key.
pub user_signing_key: String,
/// The secret of the self-signing key.
pub self_signing_key: String,
}
impl std::fmt::Debug for CrossSigningSecrets {
    /// Formats the struct with every secret replaced by the placeholder
    /// `"..."`, so the key material never ends up in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut redacted = f.debug_struct("CrossSigningSecrets");

        for field_name in ["master_key", "user_signing_key", "self_signing_key"] {
            redacted.field(field_name, &"...");
        }

        redacted.finish()
    }
}
/// Backup secrets carried by the
/// [`BackupSecrets::MegolmBackupV1Curve25519AesSha2`] variant.
#[derive(Debug, Deserialize, Clone, Serialize, ZeroizeOnDrop)]
pub struct MegolmBackupV1Curve25519AesSha2Secrets {
/// The backup decryption key; (de)serialized as a base64 string by the
/// `backup_key_to_base64` / `backup_key_from_base64` helpers.
#[serde(serialize_with = "backup_key_to_base64", deserialize_with = "backup_key_from_base64")]
pub key: BackupDecryptionKey,
/// The backup version, as a string (e.g. `"2"` in this file's test
/// fixture).
pub backup_version: String,
}
// Generate the base64 serde helpers referenced by the `key` field above.
from_base64!(BackupDecryptionKey, backup_key_from_base64);
to_base64!(BackupDecryptionKey, backup_key_to_base64);
/// Backup secrets, one variant per backup algorithm.
///
/// The `#[serde(tag = "algorithm")]` attribute stores the variant name in an
/// `algorithm` field next to the variant's own fields.
#[derive(Debug, Clone, ZeroizeOnDrop, Serialize, Deserialize)]
#[serde(tag = "algorithm")]
pub enum BackupSecrets {
/// Secrets for the `m.megolm_backup.v1.curve25519-aes-sha2` algorithm.
#[serde(rename = "m.megolm_backup.v1.curve25519-aes-sha2")]
MegolmBackupV1Curve25519AesSha2(MegolmBackupV1Curve25519AesSha2Secrets),
}
impl BackupSecrets {
pub fn algorithm(&self) -> &str {
match &self {
BackupSecrets::MegolmBackupV1Curve25519AesSha2(_) => {
"m.megolm_backup.v1.curve25519-aes-sha2"
}
}
}
}
/// A single signature in its decoded form.
///
/// The type only represents a signature; nothing in it verifies the
/// signature against a key or a message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Signature {
/// A decoded Ed25519 signature.
Ed25519(Ed25519Signature),
/// A signature kept as its raw string. The deserializer of [`Signatures`]
/// uses this variant for key ids whose algorithm is not Ed25519.
Other(String),
}
/// Stand-in for a signature that could not be decoded.
///
/// The deserializer of [`Signatures`] produces this when an entry under an
/// Ed25519 key id fails to decode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidSignature {
/// The string exactly as it was received; the serializer of [`Signatures`]
/// writes it back unchanged.
pub source: String,
}
impl Signature {
pub fn ed25519(&self) -> Option<Ed25519Signature> {
as_variant!(self, Self::Ed25519).copied()
}
pub fn to_base64(&self) -> String {
match self {
Signature::Ed25519(s) => s.to_base64(),
Signature::Other(s) => s.to_owned(),
}
}
}
impl From<Ed25519Signature> for Signature {
fn from(signature: Ed25519Signature) -> Self {
Self::Ed25519(signature)
}
}
/// A collection of signatures, keyed first by the signing user and then by
/// the id of the signing key.
///
/// Each entry is `Ok(Signature)` when the signature could be represented, or
/// `Err(InvalidSignature)` holding the raw string when decoding it failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signatures(
BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceKeyId, Result<Signature, InvalidSignature>>>,
);
impl Signatures {
pub fn new() -> Self {
Signatures(Default::default())
}
pub fn add_signature(
&mut self,
signer: OwnedUserId,
key_id: OwnedDeviceKeyId,
signature: Ed25519Signature,
) -> Option<Result<Signature, InvalidSignature>> {
self.0.entry(signer).or_default().insert(key_id, Ok(signature.into()))
}
pub fn get_signature(&self, signer: &UserId, key_id: &DeviceKeyId) -> Option<Ed25519Signature> {
self.get(signer)?.get(key_id)?.as_ref().ok()?.ed25519()
}
pub fn get(
&self,
signer: &UserId,
) -> Option<&BTreeMap<OwnedDeviceKeyId, Result<Signature, InvalidSignature>>> {
self.0.get(signer)
}
pub fn clear(&mut self) {
self.0.clear()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn signature_count(&self) -> usize {
self.0.values().map(|u| u.len()).sum()
}
}
impl Default for Signatures {
fn default() -> Self {
Self::new()
}
}
impl IntoIterator for Signatures {
type Item = (OwnedUserId, BTreeMap<OwnedDeviceKeyId, Result<Signature, InvalidSignature>>);
type IntoIter =
IntoIter<OwnedUserId, BTreeMap<OwnedDeviceKeyId, Result<Signature, InvalidSignature>>>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'de> Deserialize<'de> for Signatures {
    /// Deserializes a `user id -> key id -> string` map into [`Signatures`].
    ///
    /// Every string is converted according to the algorithm of its key id:
    ///
    /// * Ed25519 key ids: the string is decoded as an [`Ed25519Signature`]. A
    ///   string that fails to decode is kept as an [`InvalidSignature`]
    ///   instead of failing the whole deserialization.
    /// * Any other algorithm: the string is stored untouched as
    ///   [`Signature::Other`].
    ///
    /// # Errors
    ///
    /// Only the deserialization of the outer string map can fail; the
    /// per-signature conversion is infallible.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceKeyId, String>> =
            Deserialize::deserialize(deserializer)?;

        // The conversion below cannot fail, so the maps are collected
        // directly rather than through an intermediate `Result`.
        let decoded: BTreeMap<_, _> = raw
            .into_iter()
            .map(|(user, signatures)| {
                let signatures = signatures
                    .into_iter()
                    .map(|(key_id, encoded)| {
                        let signature = match key_id.algorithm() {
                            DeviceKeyAlgorithm::Ed25519 => Ed25519Signature::from_base64(&encoded)
                                .map(Signature::from)
                                // Keep the raw string so it can be written
                                // back unchanged when serializing.
                                .map_err(|_| InvalidSignature { source: encoded }),
                            _ => Ok(Signature::Other(encoded)),
                        };

                        (key_id, signature)
                    })
                    .collect::<BTreeMap<_, _>>();

                (user, signatures)
            })
            .collect();

        Ok(Signatures(decoded))
    }
}
impl Serialize for Signatures {
    /// Serializes the collection as a `user id -> key id -> string` map.
    ///
    /// Valid signatures are written with [`Signature::to_base64`]; invalid
    /// ones are written back as the raw string they were created from.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut encoded: BTreeMap<&OwnedUserId, BTreeMap<&OwnedDeviceKeyId, String>> =
            BTreeMap::new();

        for (user_id, user_signatures) in &self.0 {
            // Create the entry up front so that a user with an empty
            // signature map is still emitted.
            let entries = encoded.entry(user_id).or_default();

            for (key_id, signature) in user_signatures {
                let value = match signature {
                    Ok(valid) => valid.to_base64(),
                    Err(invalid) => invalid.source.clone(),
                };
                entries.insert(key_id, value);
            }
        }

        encoded.serialize(serializer)
    }
}
/// A collection of signing keys: a map from a key id of type `T` to the
/// [`SigningKey`] stored under that id.
///
/// The serde implementations in this file (de)serialize it as a map from the
/// key id to the string form of the key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SigningKeys<T: Ord>(BTreeMap<T, SigningKey>);
impl<T: Ord> SigningKeys<T> {
pub fn new() -> Self {
Self(BTreeMap::new())
}
pub fn insert(&mut self, key_id: T, key: SigningKey) -> Option<SigningKey> {
self.0.insert(key_id, key)
}
pub fn get<Q>(&self, key_id: &Q) -> Option<&SigningKey>
where
T: Borrow<Q>,
Q: Ord + ?Sized,
{
self.0.get(key_id)
}
pub fn iter(&self) -> Iter<'_, T, SigningKey> {
self.0.iter()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
impl<T: Ord> Default for SigningKeys<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Ord> IntoIterator for SigningKeys<T> {
type Item = (T, SigningKey);
type IntoIter = IntoIter<T, SigningKey>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<K: Ord> FromIterator<(K, SigningKey)> for SigningKeys<K> {
fn from_iter<T: IntoIterator<Item = (K, SigningKey)>>(iter: T) -> Self {
let map = BTreeMap::from_iter(iter);
Self(map)
}
}
impl<K: Ord, const N: usize> From<[(K, SigningKey); N]> for SigningKeys<K> {
fn from(v: [(K, SigningKey); N]) -> Self {
let map = BTreeMap::from(v);
Self(map)
}
}
/// Helper trait for types that can report a [`DeviceKeyAlgorithm`].
///
/// The `Deserialize` implementation of [`SigningKeys`] requires it on the map
/// key type so it can look up the algorithm of each key id.
trait Algorithm {
/// Returns the algorithm associated with this value.
fn algorithm(&self) -> DeviceKeyAlgorithm;
}
impl Algorithm for OwnedDeviceKeyId {
fn algorithm(&self) -> DeviceKeyAlgorithm {
DeviceKeyId::algorithm(self)
}
}
impl Algorithm for DeviceKeyAlgorithm {
    /// An algorithm is its own algorithm: returns a copy of `self`.
    fn algorithm(&self) -> DeviceKeyAlgorithm {
        self.clone()
    }
}
/// The algorithm used to encrypt an event.
///
/// The string form of each variant is given by its `ruma_enum(rename = ..)`
/// attribute, handled by ruma's `StringEnum` derive. The enum is
/// `#[non_exhaustive]`, so code matching on it needs a wildcard arm.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[non_exhaustive]
pub enum EventEncryptionAlgorithm {
/// The `m.olm.v1.curve25519-aes-sha2` algorithm.
#[ruma_enum(rename = "m.olm.v1.curve25519-aes-sha2")]
OlmV1Curve25519AesSha2,
/// The `m.olm.v2.curve25519-aes-sha2` algorithm; only available with the
/// `experimental-algorithms` feature.
#[cfg(feature = "experimental-algorithms")]
#[ruma_enum(rename = "m.olm.v2.curve25519-aes-sha2")]
OlmV2Curve25519AesSha2,
/// The `m.megolm.v1.aes-sha2` algorithm.
#[ruma_enum(rename = "m.megolm.v1.aes-sha2")]
MegolmV1AesSha2,
/// The `m.megolm.v2.aes-sha2` algorithm; only available with the
/// `experimental-algorithms` feature.
#[cfg(feature = "experimental-algorithms")]
#[ruma_enum(rename = "m.megolm.v2.aes-sha2")]
MegolmV2AesSha2,
// NOTE(review): presumably the catch-all that `StringEnum` uses for strings
// not listed above; confirm against the ruma documentation.
#[doc(hidden)]
_Custom(PrivOwnedStr),
}
impl<T: Ord + Serialize> Serialize for SigningKeys<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let keys: BTreeMap<&T, String> =
self.0.iter().map(|(key_id, key)| (key_id, key.to_base64())).collect();
keys.serialize(serializer)
}
}
impl<'de, T: Algorithm + Ord + Deserialize<'de>> Deserialize<'de> for SigningKeys<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let map: BTreeMap<T, String> = Deserialize::deserialize(deserializer)?;
let map: Result<_, _> = map
.into_iter()
.map(|(key_id, key)| {
let key = SigningKey::from_parts(&key_id.algorithm(), key)
.map_err(serde::de::Error::custom)?;
Ok((key_id, key))
})
.collect();
Ok(SigningKeys(map?))
}
}
// Base64 serde helpers for Curve25519 public keys.
from_base64!(Curve25519PublicKey, deserialize_curve_key);
to_base64!(Curve25519PublicKey, serialize_curve_key);
// Base64 serde helpers for Ed25519 public keys.
from_base64!(Ed25519PublicKey, deserialize_ed25519_key);
to_base64!(Ed25519PublicKey, serialize_ed25519_key);
pub(crate) fn deserialize_curve_key_vec<'de, D>(de: D) -> Result<Vec<Curve25519PublicKey>, D::Error>
where
D: Deserializer<'de>,
{
let keys: Vec<String> = Deserialize::deserialize(de)?;
let keys: Result<Vec<Curve25519PublicKey>, KeyError> =
keys.iter().map(|k| Curve25519PublicKey::from_base64(k)).collect();
keys.map_err(serde::de::Error::custom)
}
/// Serializes a slice of Curve25519 public keys as a sequence of the strings
/// returned by each key's `to_base64()`.
pub(crate) fn serialize_curve_key_vec<S>(
    keys: &[Curve25519PublicKey],
    s: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let mut encoded = Vec::with_capacity(keys.len());
    for key in keys {
        encoded.push(key.to_base64());
    }

    encoded.serialize(s)
}
// Unit tests for the serde behaviour of the types in this module. Most of
// them are `insta` snapshot tests, so the test names and the order of the
// snapshot assertions are kept as they are.
#[cfg(test)]
mod test {
use insta::{assert_debug_snapshot, assert_json_snapshot, with_settings};
use ruma::{device_id, user_id};
use serde_json::json;
use similar_asserts::assert_eq;
use super::*;
// A secrets bundle must survive a JSON deserialize/serialize round trip
// unchanged.
#[test]
fn serialize_secrets_bundle() {
let json = json!({
"cross_signing": {
"master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
"self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
"user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
},
"backup": {
"algorithm": "m.megolm_backup.v1.curve25519-aes-sha2",
"backup_version": "2",
"key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
},
});
let deserialized: SecretsBundle = serde_json::from_value(json.clone())
.expect("We should be able to deserialize the secrets bundle");
let serialized = serde_json::to_value(&deserialized)
.expect("We should be able to serialize a secrets bundle");
assert_eq!(json, serialized, "A serialization cycle should yield the same result");
}
// Snapshots both the JSON and the `Debug` output of a backup decryption key
// made of 32 bytes of value 1.
#[test]
fn snapshot_backup_decryption_key() {
let decryption_key = BackupDecryptionKey { inner: Box::new([1u8; 32]) };
assert_json_snapshot!(decryption_key);
assert_debug_snapshot!(decryption_key);
}
// Snapshots the JSON form of a `Signatures` collection that holds two valid
// entries for one user and an invalid entry for another.
#[test]
fn snapshot_signatures() {
let signatures = Signatures(BTreeMap::from([
(
user_id!("@alice:localhost").to_owned(),
BTreeMap::from([
(
DeviceKeyId::from_parts(
DeviceKeyAlgorithm::Ed25519,
device_id!("ABCDEFGH"),
),
Ok(Signature::from(Ed25519Signature::from_slice(&[0u8; 64]).unwrap())),
),
(
// A Curve25519 key id paired with the Ed25519 variant: this can be
// built by hand, whereas the deserializer would produce
// `Signature::Other` for a non-Ed25519 key id.
DeviceKeyId::from_parts(
DeviceKeyAlgorithm::Curve25519,
device_id!("IJKLMNOP"),
),
Ok(Signature::from(Ed25519Signature::from_slice(&[1u8; 64]).unwrap())),
),
]),
),
(
user_id!("@bob:localhost").to_owned(),
BTreeMap::from([(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, device_id!("ABCDEFGH")),
Err(InvalidSignature { source: "SOME+B64+SOME+B64+SOME+B64+==".to_owned() }),
)]),
),
]));
with_settings!({sort_maps =>true}, {
assert_json_snapshot!(signatures)
});
}
// Snapshots the JSON form of a secrets bundle, first with backup secrets and
// then without them.
#[test]
fn snapshot_secret_bundle() {
let secret_bundle = SecretsBundle {
cross_signing: CrossSigningSecrets {
master_key: "MSKMSKMSKMSKMSKMSKMSKMSKMSKMSKMSKMSK".to_owned(),
user_signing_key: "USKUSKUSKUSKUSKUSKUSKUSKUSKUSKUSKUSK".to_owned(),
self_signing_key: "SSKSSKSSKSSKSSKSSKSSKSSKSSKSSKSSK".to_owned(),
},
backup: Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(
MegolmBackupV1Curve25519AesSha2Secrets {
key: BackupDecryptionKey::from_bytes(&[0u8; 32]),
backup_version: "v1.1".to_owned(),
},
)),
};
assert_json_snapshot!(secret_bundle);
let secret_bundle = SecretsBundle {
cross_signing: CrossSigningSecrets {
master_key: "MSKMSKMSKMSKMSKMSKMSKMSKMSKMSKMSKMSK".to_owned(),
user_signing_key: "USKUSKUSKUSKUSKUSKUSKUSKUSKUSKUSKUSK".to_owned(),
self_signing_key: "SSKSSKSSKSSKSSKSSKSSKSSKSSKSSKSSK".to_owned(),
},
backup: None,
};
assert_json_snapshot!(secret_bundle);
}
}