use matrix_sdk_base::crypto::types::SecretsBundle;
use openidconnect::{
core::CoreDeviceAuthorizationResponse, EndUserVerificationUrl, VerificationUriComplete,
};
use ruma::serde::StringEnum;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use url::Url;
use vodozemac::Curve25519PublicKey;
#[cfg(doc)]
use crate::authentication::qrcode::QRCodeLoginError::SecureChannel;
/// Messages exchanged between the two devices taking part in a QR code login
/// (see also [`SecureChannel`]).
///
/// Every message is a JSON object whose `type` field selects the variant, for
/// example `{"type": "m.login.success"}`; any fields of the variant sit next to
/// the `type` tag in the same object.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum QrAuthMessage {
    /// `m.login.protocols`: the list of login protocols on offer, together
    /// with a homeserver URL.
    #[serde(rename = "m.login.protocols")]
    LoginProtocols {
        /// The login protocols on offer.
        protocols: Vec<LoginProtocolType>,
        /// The homeserver URL.
        homeserver: Url,
    },

    /// `m.login.protocol`: selects a single login protocol and carries the
    /// data needed to proceed with it.
    #[serde(rename = "m.login.protocol")]
    LoginProtocol {
        /// The verification URIs of the device authorization grant.
        device_authorization_grant: AuthorizationGrant,
        /// The selected login protocol.
        protocol: LoginProtocolType,
        /// The `device_id` field, held as a Curve25519 public key and encoded
        /// on the wire as a base64 string (see `serialize_curve_key` /
        /// `deserialize_curve_key`).
        #[serde(
            deserialize_with = "deserialize_curve_key",
            serialize_with = "serialize_curve_key"
        )]
        device_id: Curve25519PublicKey,
    },

    /// `m.login.protocol_accepted`: carries no fields besides the `type` tag.
    #[serde(rename = "m.login.protocol_accepted")]
    LoginProtocolAccepted,

    /// `m.login.success`: carries no fields besides the `type` tag.
    #[serde(rename = "m.login.success")]
    LoginSuccess,

    /// `m.login.declined`: carries no fields besides the `type` tag.
    #[serde(rename = "m.login.declined")]
    LoginDeclined,

    /// `m.login.failure`: the login failed for the given reason.
    #[serde(rename = "m.login.failure")]
    LoginFailure {
        /// Why the login failed.
        reason: LoginFailureReason,
        /// An optional homeserver URL.
        ///
        /// When `None` the field is omitted from the serialized message rather
        /// than being sent as `null`; a missing field deserializes to `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        homeserver: Option<Url>,
    },

    /// `m.login.secrets`: a [`SecretsBundle`] whose fields (`cross_signing`,
    /// `backup`) appear directly next to the `type` tag.
    #[serde(rename = "m.login.secrets")]
    LoginSecrets(SecretsBundle),
}
impl QrAuthMessage {
    /// Builds a [`QrAuthMessage::LoginProtocol`] message that selects the
    /// [`LoginProtocolType::DeviceAuthorizationGrant`] protocol.
    ///
    /// * `device_authorization_grant` - the verification URIs to send along.
    /// * `device_id` - the Curve25519 public key placed in the `device_id`
    ///   field of the message.
    pub fn authorization_grant_login_protocol(
        device_authorization_grant: AuthorizationGrant,
        device_id: Curve25519PublicKey,
    ) -> QrAuthMessage {
        let protocol = LoginProtocolType::DeviceAuthorizationGrant;

        Self::LoginProtocol { device_authorization_grant, protocol, device_id }
    }
}
impl From<&CoreDeviceAuthorizationResponse> for AuthorizationGrant {
    /// Copies the verification URI and, when present, the complete
    /// verification URI out of a device authorization response.
    fn from(response: &CoreDeviceAuthorizationResponse) -> Self {
        let verification_uri = response.verification_uri().to_owned();
        let verification_uri_complete = response.verification_uri_complete().cloned();

        Self { verification_uri, verification_uri_complete }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationGrant {
pub verification_uri: EndUserVerificationUrl,
pub verification_uri_complete: Option<VerificationUriComplete>,
}
/// Reason carried in the `reason` field of an `m.login.failure` message
/// ([`QrAuthMessage::LoginFailure`]).
///
/// Known variants are (de)serialized as their snake_case names, e.g.
/// `UnsupportedProtocol` <=> `"unsupported_protocol"`. Following the ruma
/// `StringEnum` convention, unrecognized strings are kept in the hidden
/// `_Custom` variant.
///
/// NOTE: the variant order is significant because it defines the derived
/// `PartialOrd`/`Ord` ordering.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[ruma_enum(rename_all = "snake_case")]
pub enum LoginFailureReason {
    /// Wire value: `authorization_expired`.
    AuthorizationExpired,
    /// Wire value: `device_already_exists`.
    DeviceAlreadyExists,
    /// Wire value: `device_not_found`.
    DeviceNotFound,
    /// Wire value: `unexpected_message_received`.
    UnexpectedMessageReceived,
    /// Wire value: `unsupported_protocol`.
    UnsupportedProtocol,
    /// Wire value: `user_cancelled`.
    UserCancelled,
    // Catch-all for reason strings not listed above.
    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
/// A login protocol, as listed in `m.login.protocols` and selected in
/// `m.login.protocol` messages.
///
/// Known variants are (de)serialized as their snake_case names, e.g.
/// `DeviceAuthorizationGrant` <=> `"device_authorization_grant"`. Following the
/// ruma `StringEnum` convention, unrecognized strings are kept in the hidden
/// `_Custom` variant.
///
/// NOTE: the variant order is significant because it defines the derived
/// `PartialOrd`/`Ord` ordering.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[ruma_enum(rename_all = "snake_case")]
pub enum LoginProtocolType {
    /// Wire value: `device_authorization_grant`.
    DeviceAuthorizationGrant,
    // Catch-all for protocol strings not listed above.
    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
pub(crate) fn deserialize_curve_key<'de, D>(de: D) -> Result<Curve25519PublicKey, D::Error>
where
D: Deserializer<'de>,
{
let key: String = Deserialize::deserialize(de)?;
Curve25519PublicKey::from_base64(&key).map_err(serde::de::Error::custom)
}
/// Serde `serialize_with` helper: writes a [`Curve25519PublicKey`] as its base64
/// string form (the counterpart of `deserialize_curve_key`).
pub(crate) fn serialize_curve_key<S>(
    key: &Curve25519PublicKey,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let encoded = key.to_base64();
    serializer.serialize_str(&encoded)
}
/// Opaque owned string used as the payload of the hidden `_Custom` variants of
/// the string enums in this module.
#[doc(hidden)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PrivOwnedStr(Box<str>);
#[cfg(test)]
mod test {
    use assert_matches2::assert_let;
    use matrix_sdk_base::crypto::types::BackupSecrets;
    use serde_json::json;
    use similar_asserts::assert_eq;

    use super::*;

    // `m.login.protocols` round-trips and exposes the advertised protocol.
    #[test]
    fn test_protocols_serialization() {
        let json = json!({
            "type": "m.login.protocols",
            "protocols": ["device_authorization_grant"],
            "homeserver": "https://matrix-client.matrix.org/"
        });

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(QrAuthMessage::LoginProtocols { protocols, .. } = &message);
        assert!(protocols.contains(&LoginProtocolType::DeviceAuthorizationGrant));

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }

    // `m.login.protocol` round-trips, including the base64-encoded `device_id`.
    #[test]
    fn test_protocol_serialization() {
        // Per RFC 8628 the `_complete` URI is the one with the user code
        // pre-filled, so it is the one carrying the `code` query parameter.
        let json = json!({
            "type": "m.login.protocol",
            "protocol": "device_authorization_grant",
            "device_authorization_grant": {
                "verification_uri": "https://id.matrix.org/device/abcde",
                "verification_uri_complete": "https://id.matrix.org/device/abcde?code=ABCDE"
            },
            "device_id": "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4"
        });

        let curve_key =
            Curve25519PublicKey::from_base64("wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4")
                .unwrap();

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(QrAuthMessage::LoginProtocol { protocol, device_id, .. } = &message);
        assert_eq!(protocol, &LoginProtocolType::DeviceAuthorizationGrant);
        assert_eq!(device_id, &curve_key);

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }

    // Field-less `m.login.protocol_accepted` round-trips as just the type tag.
    #[test]
    fn test_protocol_accepted_serialization() {
        let json = json!({
            "type": "m.login.protocol_accepted",
        });

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(QrAuthMessage::LoginProtocolAccepted = &message);

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }

    // Field-less `m.login.success` round-trips as just the type tag.
    #[test]
    fn test_login_success() {
        let json = json!({
            "type": "m.login.success",
        });

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(QrAuthMessage::LoginSuccess = &message);

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }

    // Field-less `m.login.declined` round-trips as just the type tag.
    #[test]
    fn test_login_declined() {
        let json = json!({
            "type": "m.login.declined",
        });

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(QrAuthMessage::LoginDeclined = &message);

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }

    // `m.login.failure` with a homeserver round-trips and parses the reason.
    #[test]
    fn test_login_failure() {
        let json = json!({
            "type": "m.login.failure",
            "reason": "unsupported_protocol",
            "homeserver": "https://matrix-client.matrix.org/"
        });

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(QrAuthMessage::LoginFailure { reason, .. } = &message);
        assert_eq!(reason, &LoginFailureReason::UnsupportedProtocol);

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }

    // `m.login.secrets` round-trips with the bundle fields next to the tag.
    #[test]
    fn test_login_secrets() {
        let json = json!({
            "type": "m.login.secrets",
            "cross_signing": {
                "master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
                "self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
                "user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
            },
            "backup": {
                "algorithm": "m.megolm_backup.v1.curve25519-aes-sha2",
                "backup_version": "2",
                "key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
            },
        });

        let message: QrAuthMessage = serde_json::from_value(json.clone()).unwrap();
        assert_let!(
            QrAuthMessage::LoginSecrets(SecretsBundle { cross_signing, backup }) = &message
        );

        assert_eq!(cross_signing.master_key, "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw");
        assert_eq!(cross_signing.self_signing_key, "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk");
        assert_eq!(cross_signing.user_signing_key, "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM");

        assert_let!(Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(backup)) = backup);
        assert_eq!(backup.backup_version, "2");
        assert_eq!(&backup.key.to_base64(), "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA");

        let serialized = serde_json::to_value(&message).unwrap();
        assert_eq!(json, serialized);
    }
}