1use std::{matches, sync::Arc, time::Duration};
16
17use matrix_sdk_common::locks::Mutex;
18use ruma::{
19 events::{
20 key::verification::{
21 accept::{
22 AcceptMethod, KeyVerificationAcceptEventContent, SasV1Content as AcceptV1Content,
23 SasV1ContentInit as AcceptV1ContentInit, ToDeviceKeyVerificationAcceptEventContent,
24 },
25 cancel::CancelCode,
26 done::{KeyVerificationDoneEventContent, ToDeviceKeyVerificationDoneEventContent},
27 key::{KeyVerificationKeyEventContent, ToDeviceKeyVerificationKeyEventContent},
28 start::{
29 KeyVerificationStartEventContent, SasV1Content, SasV1ContentInit, StartMethod,
30 ToDeviceKeyVerificationStartEventContent,
31 },
32 HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
33 ShortAuthenticationString,
34 },
35 relation::Reference,
36 AnyMessageLikeEventContent, AnyToDeviceEventContent,
37 },
38 serde::Base64,
39 time::Instant,
40 DeviceId, OwnedTransactionId, TransactionId, UserId,
41};
42use serde::{Deserialize, Serialize};
43use tracing::info;
44use vodozemac::{
45 sas::{EstablishedSas, Mac, Sas},
46 Curve25519PublicKey,
47};
48
49use super::{
50 helpers::{
51 calculate_commitment, get_decimal, get_emoji, get_emoji_index, get_mac_content,
52 receive_mac_event, SasIds,
53 },
54 OutgoingContent,
55};
56use crate::{
57 identities::{DeviceData, UserIdentityData},
58 olm::StaticAccountData,
59 verification::{
60 cache::RequestInfo,
61 event_enums::{
62 AcceptContent, DoneContent, KeyContent, MacContent, OwnedAcceptContent,
63 OwnedStartContent, StartContent,
64 },
65 Cancelled, Emoji, FlowId,
66 },
67 OwnUserIdentityData,
68};
69
/// Key agreement protocols we offer in our start content and accept in the
/// other side's accept content.
const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
    &[KeyAgreementProtocol::Curve25519HkdfSha256];
/// Hash algorithms we offer in our start content and accept in the other
/// side's accept content.
const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
/// Short authentication string methods offered by default, i.e. when no
/// explicit list is passed to `the_protocol_definitions`.
const STRINGS: &[ShortAuthenticationString] =
    &[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
75
76fn the_protocol_definitions(
77 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
78) -> SasV1Content {
79 SasV1ContentInit {
80 short_authentication_string: short_auth_strings.unwrap_or_else(|| STRINGS.to_owned()),
81 key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(),
82 message_authentication_codes: vec![
83 #[allow(deprecated)]
84 MessageAuthenticationCode::HkdfHmacSha256,
85 MessageAuthenticationCode::HkdfHmacSha256V2,
86 MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
88 ],
89 hashes: HASHES.to_vec(),
90 }
91 .into()
92}
93
/// Maximum lifetime of a SAS flow, measured from its creation time
/// (checked by `SasState::timed_out`).
const MAX_AGE: Duration = Duration::from_secs(60 * 5);

/// Maximum time allowed since the last processed event of a SAS flow
/// (checked by `SasState::timed_out`).
const MAX_EVENT_TIMEOUT: Duration = Duration::from_secs(60);
99
/// MAC methods this implementation can negotiate for SAS verification.
///
/// The serde names are the same identifiers produced by the `AsRef<str>`
/// implementation and parsed by `TryFrom<&MessageAuthenticationCode>`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum SupportedMacMethod {
    /// `hkdf-hmac-sha256`. MACs are produced with vodozemac's
    /// `calculate_mac_invalid_base64`.
    #[serde(rename = "hkdf-hmac-sha256")]
    HkdfHmacSha256,
    /// `hkdf-hmac-sha256.v2`.
    #[serde(rename = "hkdf-hmac-sha256.v2")]
    HkdfHmacSha256V2,
    /// The unstable MSC3783 identifier. It is handled exactly like
    /// [`SupportedMacMethod::HkdfHmacSha256V2`] when calculating and verifying MACs.
    #[serde(rename = "org.matrix.msc3783.hkdf-hmac-sha256")]
    Msc3783HkdfHmacSha256V2,
}
112
113impl AsRef<str> for SupportedMacMethod {
114 fn as_ref(&self) -> &str {
115 match self {
116 SupportedMacMethod::HkdfHmacSha256 => "hkdf-hmac-sha256",
117 SupportedMacMethod::HkdfHmacSha256V2 => "hkdf-hmac-sha256.v2",
118 SupportedMacMethod::Msc3783HkdfHmacSha256V2 => "org.matrix.msc3783.hkdf-hmac-sha256",
119 }
120 }
121}
122
123impl From<SupportedMacMethod> for MessageAuthenticationCode {
124 fn from(m: SupportedMacMethod) -> Self {
125 MessageAuthenticationCode::from(m.as_ref())
126 }
127}
128
129impl TryFrom<&MessageAuthenticationCode> for SupportedMacMethod {
130 type Error = ();
131
132 fn try_from(value: &MessageAuthenticationCode) -> Result<Self, Self::Error> {
133 match value.as_str() {
134 "hkdf-hmac-sha256" => Ok(Self::HkdfHmacSha256),
135 "org.matrix.msc3783.hkdf-hmac-sha256" => Ok(Self::Msc3783HkdfHmacSha256V2),
136 "hkdf-hmac-sha256.v2" => Ok(Self::HkdfHmacSha256V2),
137 _ => Err(()),
138 }
139 }
140}
141
142impl SupportedMacMethod {
143 pub fn verify_mac(
149 &self,
150 sas: &EstablishedSas,
151 input: &str,
152 info: &str,
153 mac: &Base64,
154 ) -> Result<(), CancelCode> {
155 match self {
156 SupportedMacMethod::HkdfHmacSha256 => {
157 let calculated_mac = sas.calculate_mac_invalid_base64(input, info);
158 let calculated_mac = Base64::parse(calculated_mac)
159 .expect("We can always decode a Mac from vodozemac");
160
161 if calculated_mac != *mac {
162 Err(CancelCode::KeyMismatch)
163 } else {
164 Ok(())
165 }
166 }
167 SupportedMacMethod::HkdfHmacSha256V2 | SupportedMacMethod::Msc3783HkdfHmacSha256V2 => {
168 let mac = Mac::from_slice(mac.as_bytes());
169 sas.verify_mac(input, info, &mac).map_err(|_| CancelCode::MismatchedSas)
170 }
171 }
172 }
173
174 pub fn calculate_mac(&self, sas: &EstablishedSas, input: &str, info: &str) -> Base64 {
180 match self {
181 SupportedMacMethod::HkdfHmacSha256 => {
182 Base64::parse(sas.calculate_mac_invalid_base64(input, info))
183 .expect("We can always decode our newly generated Mac")
184 }
185 SupportedMacMethod::HkdfHmacSha256V2 | SupportedMacMethod::Msc3783HkdfHmacSha256V2 => {
186 let mac = sas.calculate_mac(input, info);
187 Base64::new(mac.as_bytes().to_vec())
188 }
189 }
190 }
191}
192
/// The SAS protocols selected for a verification flow.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AcceptedProtocols {
    /// The selected key agreement protocol. Only Curve25519-HKDF-SHA256 is
    /// accepted by the `TryFrom` conversions below.
    pub key_agreement_protocol: KeyAgreementProtocol,
    /// The selected hash algorithm. Only SHA-256 is accepted by the `TryFrom`
    /// conversions below.
    pub hash: HashAlgorithm,
    /// The MAC method used to calculate and verify the MAC events.
    pub message_auth_code: SupportedMacMethod,
    /// The selected short authentication string methods.
    pub short_auth_string: Vec<ShortAuthenticationString>,
}
207
208impl TryFrom<AcceptV1Content> for AcceptedProtocols {
209 type Error = CancelCode;
210
211 fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
212 if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
213 || !HASHES.contains(&content.hash)
214 || (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
215 && !content
216 .short_authentication_string
217 .contains(&ShortAuthenticationString::Decimal))
218 {
219 Err(CancelCode::UnknownMethod)
220 } else {
221 let message_auth_code = (&content.message_authentication_code)
222 .try_into()
223 .map_err(|_| CancelCode::UnknownMethod)?;
224
225 Ok(Self {
226 hash: content.hash,
227 key_agreement_protocol: content.key_agreement_protocol,
228 message_auth_code,
229 short_auth_string: content.short_authentication_string,
230 })
231 }
232 }
233}
234
235impl TryFrom<&SasV1Content> for AcceptedProtocols {
236 type Error = CancelCode;
237
238 fn try_from(method_content: &SasV1Content) -> Result<Self, Self::Error> {
239 if !method_content
240 .key_agreement_protocols
241 .contains(&KeyAgreementProtocol::Curve25519HkdfSha256)
242 || !method_content.hashes.contains(&HashAlgorithm::Sha256)
243 || (!method_content
244 .short_authentication_string
245 .contains(&ShortAuthenticationString::Decimal)
246 && !method_content
247 .short_authentication_string
248 .contains(&ShortAuthenticationString::Emoji))
249 {
250 Err(CancelCode::UnknownMethod)
251 } else {
252 let mac_methods: Vec<SupportedMacMethod> = method_content
253 .message_authentication_codes
254 .iter()
255 .filter_map(|m| SupportedMacMethod::try_from(m).ok())
256 .collect();
257
258 let message_auth_code =
259 if mac_methods.contains(&SupportedMacMethod::HkdfHmacSha256V2) {
260 Some(SupportedMacMethod::HkdfHmacSha256V2)
261 } else if mac_methods.contains(&SupportedMacMethod::Msc3783HkdfHmacSha256V2) {
262 Some(SupportedMacMethod::Msc3783HkdfHmacSha256V2)
263 } else {
264 mac_methods.first().copied()
265 }
266 .ok_or(CancelCode::UnknownMethod)?;
267
268 let mut short_auth_string = vec![];
269
270 if method_content
271 .short_authentication_string
272 .contains(&ShortAuthenticationString::Decimal)
273 {
274 short_auth_string.push(ShortAuthenticationString::Decimal)
275 }
276
277 if method_content
278 .short_authentication_string
279 .contains(&ShortAuthenticationString::Emoji)
280 {
281 short_auth_string.push(ShortAuthenticationString::Emoji);
282 }
283
284 Ok(Self {
285 hash: HashAlgorithm::Sha256,
286 key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
287 message_auth_code,
288 short_auth_string,
289 })
290 }
291 }
292}
293
294#[cfg(not(tarpaulin_include))]
295impl Default for AcceptedProtocols {
296 fn default() -> Self {
297 AcceptedProtocols {
298 hash: HashAlgorithm::Sha256,
299 key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
300 message_auth_code: SupportedMacMethod::HkdfHmacSha256V2,
301 short_auth_string: vec![
302 ShortAuthenticationString::Decimal,
303 ShortAuthenticationString::Emoji,
304 ],
305 }
306 }
307}
308
/// A SAS verification flow, parameterised by its typestate `S`.
///
/// Each transition consumes the current value and moves the shared fields
/// into a new `SasState<_>`. Per-state data lives in `state`.
#[derive(Clone)]
pub struct SasState<S: Clone> {
    /// The vodozemac SAS object. It is taken out (leaving `None`) by
    /// `handle_key_content` when the other side's key is processed.
    inner: Arc<Mutex<Option<Sas>>>,

    /// Our public key, generated by `Sas::new()` when the flow is created.
    our_public_key: Curve25519PublicKey,

    /// Our account and identity, plus the other device and its identity.
    ids: SasIds,

    /// When the flow was created; compared against `MAX_AGE` in `timed_out`.
    creation_time: Arc<Instant>,

    /// When the last event was processed; compared against
    /// `MAX_EVENT_TIMEOUT` in `timed_out`.
    last_event_time: Arc<Instant>,

    /// The flow id: either a to-device transaction id or an in-room
    /// (room id, event id) pair.
    pub verification_flow_id: Arc<FlowId>,

    /// The typestate-specific data.
    pub state: Arc<S>,

    /// Flag supplied at construction and carried unchanged through every
    /// transition; presumably whether the flow began from a verification
    /// request (confirm with callers).
    pub started_from_request: bool,
}
344
345impl<S: Clone> SasState<S> {
346 fn handle_key_content(
347 &self,
348 sender: &UserId,
349 content: &KeyContent<'_>,
350 ) -> Result<EstablishedSas, CancelCode> {
351 self.check_event(sender, content.flow_id())?;
352
353 let their_public_key = Curve25519PublicKey::from_slice(content.public_key().as_bytes())
354 .map_err(|_| CancelCode::from("Invalid public key"))?;
355
356 if let Some(sas) = self.inner.lock().take() {
357 sas.diffie_hellman(their_public_key).map_err(|_| "Invalid public key".into())
358 } else {
359 Err(CancelCode::UnexpectedMessage)
360 }
361 }
362}
363
364#[cfg(not(tarpaulin_include))]
365impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.debug_struct("SasState")
368 .field("ids", &self.ids)
369 .field("flow_id", &self.verification_flow_id)
370 .field("state", &self.state)
371 .finish()
372 }
373}
374
/// Typestate: we are starting the flow. `SasState::<Created>::as_content`
/// builds our start content from this state.
#[derive(Clone, Debug)]
pub struct Created {
    /// The SAS v1 method content we offer in our start content.
    pub protocol_definitions: SasV1Content,
}
380
/// Typestate: the other side started the flow. Created by
/// `SasState::<Started>::from_start_event`.
#[derive(Clone, Debug)]
pub struct Started {
    /// Commitment calculated from our public key and their start content.
    commitment: Base64,
    /// The SAS v1 method content from their start content.
    pub protocol_definitions: SasV1Content,
    /// The protocols we selected from their offer.
    pub accepted_protocols: AcceptedProtocols,
}
388
/// Typestate: the other side accepted the start content we sent.
#[derive(Clone, Debug)]
pub struct Accepted {
    /// The protocols from their accept content.
    pub accepted_protocols: AcceptedProtocols,
    /// Our start content, used to check their commitment once their key
    /// arrives.
    start_content: Arc<OwnedStartContent>,
    /// Request id attached to our key event by `as_content`.
    /// `into_key_sent` only proceeds when given the same id.
    pub request_id: OwnedTransactionId,
    /// The commitment from their accept content.
    commitment: Base64,
}
398
/// Typestate: we accepted the other side's start. Our accept content is built
/// from this state.
#[derive(Clone, Debug)]
pub struct WeAccepted {
    /// Whether we started the flow; set to `false` by
    /// `into_we_accepted_helper`.
    we_started: bool,
    /// The protocols we accepted.
    pub accepted_protocols: AcceptedProtocols,
    /// Our commitment, sent in our accept content.
    commitment: Base64,
}
407
/// Typestate: the other side's public key was processed and the SAS is
/// established. Our own key event is built by `as_content` and tracked by
/// `request_id`.
#[derive(Clone, Debug)]
pub struct KeyReceived {
    /// The established SAS; the `Arc` is shared with the following states.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the flow.
    we_started: bool,
    /// Request id attached to our key event. `into_keys_exchanged` only
    /// proceeds when given the same id.
    pub request_id: OwnedTransactionId,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
419
/// Typestate: our key event was sent and we are waiting for the other side's
/// key.
#[derive(Clone, Debug)]
pub struct KeySent {
    /// Whether we started the flow.
    we_started: bool,
    /// Our start content, used to check their commitment.
    start_content: Arc<OwnedStartContent>,
    /// The commitment from their accept content.
    commitment: Base64,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
427
/// Typestate: both public keys were exchanged. The short authentication
/// string can now be computed (`get_emoji`, `get_decimal`).
#[derive(Clone, Debug)]
pub struct KeysExchanged {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the flow.
    we_started: bool,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
434
/// Typestate: we confirmed the short authentication string before receiving
/// the other side's MAC. Our MAC content is built from this state.
#[derive(Clone, Debug)]
pub struct Confirmed {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
443
/// Typestate: the other side's MAC event was received and verified before we
/// confirmed the short authentication string.
#[derive(Clone, Debug)]
pub struct MacReceived {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the flow.
    we_started: bool,
    /// Devices whose keys were covered by the verified MAC event.
    verified_devices: Arc<[DeviceData]>,
    /// User identities whose master keys were covered by the verified MAC
    /// event.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
455
/// Typestate: we confirmed and the other side's MAC was verified. Going by
/// the name, this state waits for the other side's done event (confirm with
/// the transition code outside this view).
#[derive(Clone, Debug)]
pub struct WaitingForDone {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Devices whose keys were covered by the verified MAC event.
    verified_devices: Arc<[DeviceData]>,
    /// User identities whose master keys were covered by the verified MAC
    /// event.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
466
/// Typestate: the flow finished. The MAC was verified and the short
/// authentication string was confirmed.
#[derive(Clone, Debug)]
pub struct Done {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Devices whose keys were covered by the verified MAC event.
    verified_devices: Arc<[DeviceData]>,
    /// User identities whose master keys were covered by the verified MAC
    /// event.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The negotiated protocols.
    pub accepted_protocols: AcceptedProtocols,
}
479
480impl<S: Clone> SasState<S> {
481 #[cfg(test)]
483 pub fn user_id(&self) -> &UserId {
484 &self.ids.account.user_id
485 }
486
487 pub fn device_id(&self) -> &DeviceId {
489 &self.ids.account.device_id
490 }
491
492 #[cfg(test)]
493 pub fn other_device(&self) -> DeviceData {
494 self.ids.other_device.clone()
495 }
496
497 pub fn cancel(self, cancelled_by_us: bool, cancel_code: CancelCode) -> SasState<Cancelled> {
498 SasState {
499 inner: self.inner,
500 our_public_key: self.our_public_key,
501 ids: self.ids,
502 creation_time: self.creation_time,
503 last_event_time: self.last_event_time,
504 verification_flow_id: self.verification_flow_id,
505 state: Arc::new(Cancelled::new(cancelled_by_us, cancel_code)),
506 started_from_request: self.started_from_request,
507 }
508 }
509
510 pub fn timed_out(&self) -> bool {
512 self.creation_time.elapsed() > MAX_AGE || self.last_event_time.elapsed() > MAX_EVENT_TIMEOUT
513 }
514
515 #[allow(dead_code)]
517 pub fn is_dm_verification(&self) -> bool {
518 matches!(&*self.verification_flow_id, FlowId::InRoom(_, _))
519 }
520
521 #[cfg(test)]
522 #[allow(dead_code)]
523 pub fn set_creation_time(&mut self, time: Instant) {
524 self.creation_time = Arc::new(time);
525 }
526
527 fn check_event(&self, sender: &UserId, flow_id: &str) -> Result<(), CancelCode> {
528 if *flow_id != *self.verification_flow_id.as_str() {
529 Err(CancelCode::UnknownTransaction)
530 } else if sender != self.ids.other_device.user_id() {
531 Err(CancelCode::UserMismatch)
532 } else if self.timed_out() {
533 Err(CancelCode::Timeout)
534 } else {
535 Ok(())
536 }
537 }
538}
539
540impl SasState<Created> {
541 pub fn new(
551 account: StaticAccountData,
552 other_device: DeviceData,
553 own_identity: Option<OwnUserIdentityData>,
554 other_identity: Option<UserIdentityData>,
555 flow_id: FlowId,
556 started_from_request: bool,
557 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
558 ) -> SasState<Created> {
559 Self::new_helper(
560 flow_id,
561 account,
562 other_device,
563 own_identity,
564 other_identity,
565 started_from_request,
566 short_auth_strings,
567 )
568 }
569
570 fn new_helper(
571 flow_id: FlowId,
572 account: StaticAccountData,
573 other_device: DeviceData,
574 own_identity: Option<OwnUserIdentityData>,
575 other_identity: Option<UserIdentityData>,
576 started_from_request: bool,
577 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
578 ) -> SasState<Created> {
579 let sas = Sas::new();
580 let our_public_key = sas.public_key();
581
582 let protocol_definitions = the_protocol_definitions(short_auth_strings);
583
584 SasState {
585 inner: Arc::new(Mutex::new(Some(sas))),
586 our_public_key,
587 ids: SasIds { account, other_device, other_identity, own_identity },
588 verification_flow_id: flow_id.into(),
589
590 creation_time: Arc::new(Instant::now()),
591 last_event_time: Arc::new(Instant::now()),
592 started_from_request,
593
594 state: Arc::new(Created { protocol_definitions }),
595 }
596 }
597
598 pub fn as_content(&self) -> OwnedStartContent {
599 match self.verification_flow_id.as_ref() {
600 FlowId::ToDevice(s) => {
601 OwnedStartContent::ToDevice(ToDeviceKeyVerificationStartEventContent::new(
602 self.device_id().into(),
603 s.clone(),
604 StartMethod::SasV1(self.state.protocol_definitions.clone()),
605 ))
606 }
607 FlowId::InRoom(r, e) => OwnedStartContent::Room(
608 r.clone(),
609 KeyVerificationStartEventContent::new(
610 self.device_id().into(),
611 StartMethod::SasV1(self.state.protocol_definitions.clone()),
612 Reference::new(e.clone()),
613 ),
614 ),
615 }
616 }
617
618 pub fn into_accepted(
626 self,
627 sender: &UserId,
628 content: &AcceptContent<'_>,
629 ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
630 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
631
632 let AcceptMethod::SasV1(content) = content.method() else {
633 return Err(self.cancel(true, CancelCode::UnknownMethod));
634 };
635
636 let accepted_protocols = AcceptedProtocols::try_from(content.clone())
637 .map_err(|c| self.clone().cancel(true, c))?;
638
639 let start_content = self.as_content().into();
640
641 Ok(SasState {
642 inner: self.inner,
643 our_public_key: self.our_public_key,
644 ids: self.ids,
645 verification_flow_id: self.verification_flow_id,
646 creation_time: self.creation_time,
647 last_event_time: Instant::now().into(),
648 started_from_request: self.started_from_request,
649 state: Arc::new(Accepted {
650 start_content,
651 commitment: content.commitment.clone(),
652 request_id: TransactionId::new(),
653 accepted_protocols,
654 }),
655 })
656 }
657}
658
659impl SasState<Started> {
660 pub fn from_start_event(
674 account: StaticAccountData,
675 other_device: DeviceData,
676 own_identity: Option<OwnUserIdentityData>,
677 other_identity: Option<UserIdentityData>,
678 flow_id: FlowId,
679 content: &StartContent<'_>,
680 started_from_request: bool,
681 ) -> Result<SasState<Started>, SasState<Cancelled>> {
682 let flow_id = Arc::new(flow_id);
683
684 let sas = Sas::new();
685 let our_public_key = sas.public_key();
686
687 let canceled = || SasState {
688 inner: Arc::new(Mutex::new(None)),
689 our_public_key,
690
691 creation_time: Arc::new(Instant::now()),
692 last_event_time: Arc::new(Instant::now()),
693 started_from_request,
694
695 ids: SasIds {
696 account: account.clone(),
697 other_device: other_device.clone(),
698 own_identity: own_identity.clone(),
699 other_identity: other_identity.clone(),
700 },
701
702 verification_flow_id: flow_id.clone(),
703 state: Arc::new(Cancelled::new(true, CancelCode::UnknownMethod)),
704 };
705
706 let state = match content.method() {
707 StartMethod::SasV1(method_content) => {
708 let commitment = calculate_commitment(our_public_key, content);
709
710 info!(
711 public_key = our_public_key.to_base64(),
712 ?commitment,
713 ?content,
714 "Calculated SAS commitment",
715 );
716
717 let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) else {
718 return Err(canceled());
719 };
720
721 Started {
722 protocol_definitions: method_content.to_owned(),
723 accepted_protocols,
724 commitment,
725 }
726 }
727 _ => return Err(canceled()),
728 };
729
730 Ok(SasState {
731 inner: Arc::new(Mutex::new(Some(sas))),
732 our_public_key,
733
734 ids: SasIds { account, other_device, other_identity, own_identity },
735
736 creation_time: Arc::new(Instant::now()),
737 last_event_time: Arc::new(Instant::now()),
738 started_from_request,
739
740 verification_flow_id: flow_id,
741
742 state: Arc::new(state),
743 })
744 }
745
746 #[cfg(test)]
747 fn into_we_accepted_with_mac_method(
748 self,
749 methods: Vec<ShortAuthenticationString>,
750 mac_method: Option<SupportedMacMethod>,
751 ) -> SasState<WeAccepted> {
752 let mut accepted_protocols = self.state.accepted_protocols.to_owned();
753
754 if let Some(mac_method) = mac_method {
755 accepted_protocols.message_auth_code = mac_method;
756 }
757
758 self.into_we_accepted_helper(accepted_protocols, methods)
759 }
760
761 fn into_we_accepted_helper(
762 self,
763 mut accepted_protocols: AcceptedProtocols,
764 methods: Vec<ShortAuthenticationString>,
765 ) -> SasState<WeAccepted> {
766 accepted_protocols.short_auth_string = methods;
767
768 if !accepted_protocols.short_auth_string.contains(&ShortAuthenticationString::Decimal) {
770 accepted_protocols.short_auth_string.push(ShortAuthenticationString::Decimal);
771 }
772
773 SasState {
774 inner: self.inner,
775 our_public_key: self.our_public_key,
776 ids: self.ids,
777 verification_flow_id: self.verification_flow_id,
778 creation_time: self.creation_time,
779 last_event_time: self.last_event_time,
780 started_from_request: self.started_from_request,
781 state: Arc::new(WeAccepted {
782 we_started: false,
783 accepted_protocols,
784 commitment: self.state.commitment.clone(),
785 }),
786 }
787 }
788
789 pub fn into_we_accepted(self, methods: Vec<ShortAuthenticationString>) -> SasState<WeAccepted> {
790 let accepted_protocols = self.state.accepted_protocols.to_owned();
791 self.into_we_accepted_helper(accepted_protocols, methods)
792 }
793
794 fn as_content(&self) -> OwnedStartContent {
795 match self.verification_flow_id.as_ref() {
796 FlowId::ToDevice(s) => {
797 OwnedStartContent::ToDevice(ToDeviceKeyVerificationStartEventContent::new(
798 self.device_id().into(),
799 s.clone(),
800 StartMethod::SasV1(self.state.protocol_definitions.to_owned()),
801 ))
802 }
803 FlowId::InRoom(r, e) => OwnedStartContent::Room(
804 r.clone(),
805 KeyVerificationStartEventContent::new(
806 self.device_id().into(),
807 StartMethod::SasV1(self.state.protocol_definitions.to_owned()),
808 Reference::new(e.clone()),
809 ),
810 ),
811 }
812 }
813
814 pub fn into_accepted(
827 self,
828 sender: &UserId,
829 content: &AcceptContent<'_>,
830 ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
831 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
832
833 let AcceptMethod::SasV1(content) = content.method() else {
834 return Err(self.cancel(true, CancelCode::UnknownMethod));
835 };
836
837 let accepted_protocols = AcceptedProtocols::try_from(content.clone())
838 .map_err(|c| self.clone().cancel(true, c))?;
839
840 let start_content = self.as_content().into();
841
842 Ok(SasState {
843 inner: self.inner,
844 our_public_key: self.our_public_key,
845 ids: self.ids,
846 verification_flow_id: self.verification_flow_id,
847 creation_time: self.creation_time,
848 last_event_time: Instant::now().into(),
849 started_from_request: self.started_from_request,
850 state: Arc::new(Accepted {
851 start_content,
852 commitment: content.commitment.clone(),
853 request_id: TransactionId::new(),
854 accepted_protocols,
855 }),
856 })
857 }
858}
859
860impl SasState<WeAccepted> {
861 pub fn as_content(&self) -> OwnedAcceptContent {
869 let method = AcceptMethod::SasV1(
870 AcceptV1ContentInit {
871 commitment: self.state.commitment.clone(),
872 hash: self.state.accepted_protocols.hash.clone(),
873 key_agreement_protocol: self
874 .state
875 .accepted_protocols
876 .key_agreement_protocol
877 .clone(),
878 message_authentication_code: self.state.accepted_protocols.message_auth_code.into(),
879 short_authentication_string: self
880 .state
881 .accepted_protocols
882 .short_auth_string
883 .clone(),
884 }
885 .into(),
886 );
887
888 match self.verification_flow_id.as_ref() {
889 FlowId::ToDevice(s) => {
890 ToDeviceKeyVerificationAcceptEventContent::new(s.clone(), method).into()
891 }
892 FlowId::InRoom(r, e) => (
893 r.clone(),
894 KeyVerificationAcceptEventContent::new(method, Reference::new(e.clone())),
895 )
896 .into(),
897 }
898 }
899
900 pub fn into_key_received(
909 self,
910 sender: &UserId,
911 content: &KeyContent<'_>,
912 ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
913 let established =
914 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
915
916 Ok(SasState {
917 inner: self.inner,
918 our_public_key: self.our_public_key,
919 ids: self.ids,
920 verification_flow_id: self.verification_flow_id,
921 creation_time: self.creation_time,
922 last_event_time: Instant::now().into(),
923 started_from_request: self.started_from_request,
924 state: Arc::new(KeyReceived {
925 sas: Mutex::new(established).into(),
926 we_started: self.state.we_started,
927 request_id: TransactionId::new(),
928 accepted_protocols: self.state.accepted_protocols.clone(),
929 }),
930 })
931 }
932}
933
934impl SasState<Accepted> {
935 pub fn into_key_received(
944 self,
945 sender: &UserId,
946 content: &KeyContent<'_>,
947 ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
948 let established =
949 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
950
951 let their_public_key = established.their_public_key();
952
953 let commitment =
954 calculate_commitment(their_public_key, &self.state.start_content.as_start_content());
955
956 if self.state.commitment == commitment {
957 Ok(SasState {
958 inner: self.inner,
959 our_public_key: self.our_public_key,
960 ids: self.ids,
961 verification_flow_id: self.verification_flow_id,
962 creation_time: self.creation_time,
963 last_event_time: Instant::now().into(),
964 started_from_request: self.started_from_request,
965 state: Arc::new(KeyReceived {
966 sas: Mutex::new(established).into(),
967 we_started: true,
968 request_id: self.state.request_id.to_owned(),
969 accepted_protocols: self.state.accepted_protocols.clone(),
970 }),
971 })
972 } else {
973 Err(self.cancel(true, CancelCode::KeyMismatch))
974 }
975 }
976
977 pub fn into_key_sent(self, request_id: &TransactionId) -> Option<SasState<KeySent>> {
978 (self.state.request_id == request_id).then(|| SasState {
979 inner: self.inner,
980 our_public_key: self.our_public_key,
981 ids: self.ids,
982 verification_flow_id: self.verification_flow_id,
983 creation_time: self.creation_time,
984 last_event_time: Instant::now().into(),
985 started_from_request: self.started_from_request,
986 state: Arc::new(KeySent {
987 we_started: true,
988 start_content: self.state.start_content.clone(),
989 commitment: self.state.commitment.clone(),
990 accepted_protocols: self.state.accepted_protocols.clone(),
991 }),
992 })
993 }
994
995 pub fn as_content(&self) -> (OutgoingContent, RequestInfo) {
999 let content = match &*self.verification_flow_id {
1000 FlowId::ToDevice(s) => AnyToDeviceEventContent::KeyVerificationKey(
1001 ToDeviceKeyVerificationKeyEventContent::new(
1002 s.clone(),
1003 Base64::new(self.our_public_key.to_vec()),
1004 ),
1005 )
1006 .into(),
1007 FlowId::InRoom(r, e) => (
1008 r.clone(),
1009 AnyMessageLikeEventContent::KeyVerificationKey(
1010 KeyVerificationKeyEventContent::new(
1011 Base64::new(self.our_public_key.to_vec()),
1012 Reference::new(e.clone()),
1013 ),
1014 ),
1015 )
1016 .into(),
1017 };
1018
1019 (
1020 content,
1021 RequestInfo {
1022 flow_id: (*self.verification_flow_id).to_owned(),
1023 request_id: self.state.request_id.to_owned(),
1024 },
1025 )
1026 }
1027}
1028
1029impl SasState<KeySent> {
1030 pub fn into_keys_exchanged(
1031 self,
1032 sender: &UserId,
1033 content: &KeyContent<'_>,
1034 ) -> Result<SasState<KeysExchanged>, SasState<Cancelled>> {
1035 let established =
1036 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
1037
1038 let their_public_key = established.their_public_key();
1039 let commitment =
1040 calculate_commitment(their_public_key, &self.state.start_content.as_start_content());
1041
1042 if self.state.commitment == commitment {
1043 Ok(SasState {
1044 inner: self.inner,
1045 our_public_key: self.our_public_key,
1046 ids: self.ids,
1047 verification_flow_id: self.verification_flow_id,
1048 creation_time: self.creation_time,
1049 last_event_time: Instant::now().into(),
1050 started_from_request: self.started_from_request,
1051 state: Arc::new(KeysExchanged {
1052 sas: Mutex::new(established).into(),
1053 we_started: self.state.we_started,
1054 accepted_protocols: self.state.accepted_protocols.clone(),
1055 }),
1056 })
1057 } else {
1058 Err(self.cancel(true, CancelCode::KeyMismatch))
1059 }
1060 }
1061}
1062
1063impl SasState<KeyReceived> {
1064 pub fn as_content(&self) -> (OutgoingContent, RequestInfo) {
1069 let content = match &*self.verification_flow_id {
1070 FlowId::ToDevice(s) => AnyToDeviceEventContent::KeyVerificationKey(
1071 ToDeviceKeyVerificationKeyEventContent::new(
1072 s.clone(),
1073 Base64::new(self.our_public_key.to_vec()),
1074 ),
1075 )
1076 .into(),
1077 FlowId::InRoom(r, e) => (
1078 r.clone(),
1079 AnyMessageLikeEventContent::KeyVerificationKey(
1080 KeyVerificationKeyEventContent::new(
1081 Base64::new(self.our_public_key.to_vec()),
1082 Reference::new(e.clone()),
1083 ),
1084 ),
1085 )
1086 .into(),
1087 };
1088
1089 (
1090 content,
1091 RequestInfo {
1092 flow_id: (*self.verification_flow_id).to_owned(),
1093 request_id: self.state.request_id.to_owned(),
1094 },
1095 )
1096 }
1097
1098 pub fn into_keys_exchanged(
1099 self,
1100 request_id: &TransactionId,
1101 ) -> Option<SasState<KeysExchanged>> {
1102 (self.state.request_id == request_id).then(|| SasState {
1103 inner: self.inner,
1104 our_public_key: self.our_public_key,
1105 ids: self.ids,
1106 verification_flow_id: self.verification_flow_id,
1107 creation_time: self.creation_time,
1108 last_event_time: Instant::now().into(),
1109 started_from_request: self.started_from_request,
1110 state: KeysExchanged {
1111 sas: self.state.sas.clone(),
1112 we_started: self.state.we_started,
1113 accepted_protocols: self.state.accepted_protocols.clone(),
1114 }
1115 .into(),
1116 })
1117 }
1118}
1119
1120impl SasState<KeysExchanged> {
1121 pub fn get_emoji(&self) -> [Emoji; 7] {
1126 get_emoji(
1127 &self.state.sas.lock(),
1128 &self.ids,
1129 self.verification_flow_id.as_str(),
1130 self.state.we_started,
1131 )
1132 }
1133
1134 pub fn get_emoji_index(&self) -> [u8; 7] {
1139 get_emoji_index(
1140 &self.state.sas.lock(),
1141 &self.ids,
1142 self.verification_flow_id.as_str(),
1143 self.state.we_started,
1144 )
1145 }
1146
1147 pub fn get_decimal(&self) -> (u16, u16, u16) {
1152 get_decimal(
1153 &self.state.sas.lock(),
1154 &self.ids,
1155 self.verification_flow_id.as_str(),
1156 self.state.we_started,
1157 )
1158 }
1159
1160 pub fn into_mac_received(
1168 self,
1169 sender: &UserId,
1170 content: &MacContent<'_>,
1171 ) -> Result<SasState<MacReceived>, SasState<Cancelled>> {
1172 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1173
1174 let (devices, master_keys) = receive_mac_event(
1175 &self.state.sas.lock(),
1176 &self.ids,
1177 self.verification_flow_id.as_str(),
1178 sender,
1179 self.state.accepted_protocols.message_auth_code,
1180 content,
1181 )
1182 .map_err(|c| self.clone().cancel(true, c))?;
1183
1184 Ok(SasState {
1185 inner: self.inner,
1186 our_public_key: self.our_public_key,
1187 verification_flow_id: self.verification_flow_id,
1188 creation_time: self.creation_time,
1189 last_event_time: Instant::now().into(),
1190 ids: self.ids,
1191 started_from_request: self.started_from_request,
1192 state: Arc::new(MacReceived {
1193 sas: self.state.sas.clone(),
1194 we_started: self.state.we_started,
1195 verified_devices: devices.into(),
1196 verified_master_keys: master_keys.into(),
1197 accepted_protocols: self.state.accepted_protocols.clone(),
1198 }),
1199 })
1200 }
1201
1202 pub fn confirm(self) -> SasState<Confirmed> {
1207 SasState {
1208 inner: self.inner,
1209 our_public_key: self.our_public_key,
1210 started_from_request: self.started_from_request,
1211 verification_flow_id: self.verification_flow_id,
1212 creation_time: self.creation_time,
1213 last_event_time: self.last_event_time,
1214 ids: self.ids,
1215 state: Arc::new(Confirmed {
1216 sas: self.state.sas.clone(),
1217 accepted_protocols: self.state.accepted_protocols.clone(),
1218 }),
1219 }
1220 }
1221}
1222
1223impl SasState<Confirmed> {
1224 pub fn into_done(
1232 self,
1233 sender: &UserId,
1234 content: &MacContent<'_>,
1235 ) -> Result<SasState<Done>, SasState<Cancelled>> {
1236 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1237
1238 let (devices, master_keys) = receive_mac_event(
1239 &self.state.sas.lock(),
1240 &self.ids,
1241 self.verification_flow_id.as_str(),
1242 sender,
1243 self.state.accepted_protocols.message_auth_code,
1244 content,
1245 )
1246 .map_err(|c| self.clone().cancel(true, c))?;
1247
1248 Ok(SasState {
1249 inner: self.inner,
1250 our_public_key: self.our_public_key,
1251 creation_time: self.creation_time,
1252 last_event_time: Instant::now().into(),
1253 verification_flow_id: self.verification_flow_id,
1254 started_from_request: self.started_from_request,
1255 ids: self.ids,
1256
1257 state: Arc::new(Done {
1258 sas: self.state.sas.clone(),
1259 verified_devices: devices.into(),
1260 verified_master_keys: master_keys.into(),
1261 accepted_protocols: self.state.accepted_protocols.clone(),
1262 }),
1263 })
1264 }
1265
1266 pub fn into_waiting_for_done(
1276 self,
1277 sender: &UserId,
1278 content: &MacContent<'_>,
1279 ) -> Result<SasState<WaitingForDone>, SasState<Cancelled>> {
1280 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1281
1282 let (devices, master_keys) = receive_mac_event(
1283 &self.state.sas.lock(),
1284 &self.ids,
1285 self.verification_flow_id.as_str(),
1286 sender,
1287 self.state.accepted_protocols.message_auth_code,
1288 content,
1289 )
1290 .map_err(|c| self.clone().cancel(true, c))?;
1291
1292 Ok(SasState {
1293 inner: self.inner,
1294 our_public_key: self.our_public_key,
1295 creation_time: self.creation_time,
1296 last_event_time: Instant::now().into(),
1297 verification_flow_id: self.verification_flow_id,
1298 started_from_request: self.started_from_request,
1299 ids: self.ids,
1300
1301 state: Arc::new(WaitingForDone {
1302 sas: self.state.sas.clone(),
1303 verified_devices: devices.into(),
1304 verified_master_keys: master_keys.into(),
1305 accepted_protocols: self.state.accepted_protocols.clone(),
1306 }),
1307 })
1308 }
1309
1310 pub fn as_content(&self) -> OutgoingContent {
1314 get_mac_content(
1315 &self.state.sas.lock(),
1316 &self.ids,
1317 &self.verification_flow_id,
1318 self.state.accepted_protocols.message_auth_code,
1319 )
1320 }
1321}
1322
1323impl SasState<MacReceived> {
1324 pub fn confirm(self) -> SasState<Done> {
1329 SasState {
1330 inner: self.inner,
1331 our_public_key: self.our_public_key,
1332 verification_flow_id: self.verification_flow_id,
1333 creation_time: self.creation_time,
1334 started_from_request: self.started_from_request,
1335 last_event_time: self.last_event_time,
1336 ids: self.ids,
1337 state: Arc::new(Done {
1338 sas: self.state.sas.clone(),
1339 verified_devices: self.state.verified_devices.clone(),
1340 verified_master_keys: self.state.verified_master_keys.clone(),
1341 accepted_protocols: self.state.accepted_protocols.clone(),
1342 }),
1343 }
1344 }
1345
1346 pub fn confirm_and_wait_for_done(self) -> SasState<WaitingForDone> {
1353 SasState {
1354 inner: self.inner,
1355 our_public_key: self.our_public_key,
1356 verification_flow_id: self.verification_flow_id,
1357 creation_time: self.creation_time,
1358 started_from_request: self.started_from_request,
1359 last_event_time: self.last_event_time,
1360 ids: self.ids,
1361 state: Arc::new(WaitingForDone {
1362 sas: self.state.sas.clone(),
1363 verified_devices: self.state.verified_devices.clone(),
1364 verified_master_keys: self.state.verified_master_keys.clone(),
1365 accepted_protocols: self.state.accepted_protocols.clone(),
1366 }),
1367 }
1368 }
1369
1370 pub fn get_emoji(&self) -> [Emoji; 7] {
1375 get_emoji(
1376 &self.state.sas.lock(),
1377 &self.ids,
1378 self.verification_flow_id.as_str(),
1379 self.state.we_started,
1380 )
1381 }
1382
1383 pub fn get_emoji_index(&self) -> [u8; 7] {
1388 get_emoji_index(
1389 &self.state.sas.lock(),
1390 &self.ids,
1391 self.verification_flow_id.as_str(),
1392 self.state.we_started,
1393 )
1394 }
1395
1396 pub fn get_decimal(&self) -> (u16, u16, u16) {
1401 get_decimal(
1402 &self.state.sas.lock(),
1403 &self.ids,
1404 self.verification_flow_id.as_str(),
1405 self.state.we_started,
1406 )
1407 }
1408}
1409
1410impl SasState<WaitingForDone> {
1411 pub fn as_content(&self) -> OutgoingContent {
1416 get_mac_content(
1417 &self.state.sas.lock(),
1418 &self.ids,
1419 &self.verification_flow_id,
1420 self.state.accepted_protocols.message_auth_code,
1421 )
1422 }
1423
1424 pub fn done_content(&self) -> OutgoingContent {
1425 match self.verification_flow_id.as_ref() {
1426 FlowId::ToDevice(t) => AnyToDeviceEventContent::KeyVerificationDone(
1427 ToDeviceKeyVerificationDoneEventContent::new(t.to_owned()),
1428 )
1429 .into(),
1430 FlowId::InRoom(r, e) => (
1431 r.clone(),
1432 AnyMessageLikeEventContent::KeyVerificationDone(
1433 KeyVerificationDoneEventContent::new(Reference::new(e.clone())),
1434 ),
1435 )
1436 .into(),
1437 }
1438 }
1439
1440 pub fn into_done(
1448 self,
1449 sender: &UserId,
1450 content: &DoneContent<'_>,
1451 ) -> Result<SasState<Done>, SasState<Cancelled>> {
1452 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1453
1454 Ok(SasState {
1455 inner: self.inner,
1456 our_public_key: self.our_public_key,
1457 creation_time: self.creation_time,
1458 last_event_time: Instant::now().into(),
1459 verification_flow_id: self.verification_flow_id,
1460 started_from_request: self.started_from_request,
1461 ids: self.ids,
1462
1463 state: Arc::new(Done {
1464 sas: self.state.sas.clone(),
1465 verified_devices: self.state.verified_devices.clone(),
1466 verified_master_keys: self.state.verified_master_keys.clone(),
1467 accepted_protocols: self.state.accepted_protocols.clone(),
1468 }),
1469 })
1470 }
1471}
1472
1473impl SasState<Done> {
1474 pub fn as_content(&self) -> OutgoingContent {
1479 get_mac_content(
1480 &self.state.sas.lock(),
1481 &self.ids,
1482 &self.verification_flow_id,
1483 self.state.accepted_protocols.message_auth_code,
1484 )
1485 }
1486
1487 pub fn verified_devices(&self) -> Arc<[DeviceData]> {
1489 self.state.verified_devices.clone()
1490 }
1491
1492 pub fn verified_identities(&self) -> Arc<[UserIdentityData]> {
1494 self.state.verified_master_keys.clone()
1495 }
1496}
1497
1498impl SasState<Cancelled> {
1499 pub fn as_content(&self) -> OutgoingContent {
1500 self.state.as_content(&self.verification_flow_id)
1501 }
1502}
1503
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use ruma::{
        device_id,
        events::key::verification::{
            accept::{AcceptMethod, ToDeviceKeyVerificationAcceptEventContent},
            start::{
                SasV1Content, SasV1ContentInit, StartMethod,
                ToDeviceKeyVerificationStartEventContent,
            },
            HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
            ShortAuthenticationString,
        },
        serde::Base64,
        user_id, DeviceId, TransactionId, UserId,
    };
    use serde_json::json;

    use super::{Accepted, Created, SasState, Started, SupportedMacMethod, WeAccepted};
    use crate::{
        verification::{
            event_enums::{AcceptContent, KeyContent, MacContent, StartContent},
            FlowId,
        },
        AcceptedProtocols, Account, DeviceData,
    };

    /// User id of Alice, the side that creates the verification flow.
    fn alice_id() -> &'static UserId {
        user_id!("@alice:example.org")
    }

    /// Device id of Alice's device.
    fn alice_device_id() -> &'static DeviceId {
        device_id!("JLAFKJWSCS")
    }

    /// User id of Bob, the side that accepts the verification flow.
    fn bob_id() -> &'static UserId {
        user_id!("@bob:example.org")
    }

    /// Device id of Bob's device.
    fn bob_device_id() -> &'static DeviceId {
        device_id!("BOBDEVICE")
    }

    /// Create a fresh Alice/Bob pair of SAS state machines.
    ///
    /// Alice creates the flow. Bob is started from Alice's start content and
    /// accepts it offering only emoji as the short authentication string;
    /// `mac_method` is forwarded to `into_we_accepted_with_mac_method`.
    fn get_sas_pair(
        mac_method: Option<SupportedMacMethod>,
    ) -> (SasState<Created>, SasState<WeAccepted>) {
        let alice = Account::with_device_id(alice_id(), alice_device_id());
        let alice_device = DeviceData::from_account(&alice);

        let bob = Account::with_device_id(bob_id(), bob_device_id());
        let bob_device = DeviceData::from_account(&bob);

        let flow_id = TransactionId::new().into();
        let alice_sas = SasState::<Created>::new(
            alice.static_data().clone(),
            bob_device,
            None,
            None,
            flow_id,
            false,
            None,
        );

        let start_content = alice_sas.as_content();
        let flow_id = start_content.flow_id();

        let bob_sas = SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device,
            None,
            None,
            flow_id,
            &start_content.as_start_content(),
            false,
        );
        let bob_sas = bob_sas
            .unwrap()
            .into_we_accepted_with_mac_method(vec![ShortAuthenticationString::Emoji], mac_method);

        (alice_sas, bob_sas)
    }

    /// `AcceptedProtocols` picks the expected MAC method out of the offered
    /// list and rejects the old Curve25519 key agreement protocol.
    #[test]
    fn test_start_content_accepting() {
        let mut start_content: SasV1Content = SasV1ContentInit {
            key_agreement_protocols: vec![
                KeyAgreementProtocol::Curve25519HkdfSha256,
                KeyAgreementProtocol::Curve25519,
            ],
            hashes: vec![HashAlgorithm::Sha256],
            message_authentication_codes: vec![
                // The deprecated variant is offered on purpose to check that
                // a newer method is preferred over it.
                #[allow(deprecated)]
                MessageAuthenticationCode::HkdfHmacSha256,
                MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
                MessageAuthenticationCode::HkdfHmacSha256V2,
            ],
            short_authentication_string: vec![
                ShortAuthenticationString::Emoji,
                ShortAuthenticationString::Decimal,
            ],
        }
        .into();

        let accepted_protocols = AcceptedProtocols::try_from(&start_content).unwrap();

        assert_eq!(accepted_protocols.message_auth_code, SupportedMacMethod::HkdfHmacSha256V2);
        assert_eq!(
            accepted_protocols.key_agreement_protocol,
            KeyAgreementProtocol::Curve25519HkdfSha256
        );

        start_content.message_authentication_codes = vec![
            #[allow(deprecated)]
            MessageAuthenticationCode::HkdfHmacSha256,
            MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
        ];
        let accepted_protocols = AcceptedProtocols::try_from(&start_content).unwrap();
        assert_eq!(
            accepted_protocols.message_auth_code,
            SupportedMacMethod::Msc3783HkdfHmacSha256V2
        );

        start_content.key_agreement_protocols = vec![KeyAgreementProtocol::Curve25519];
        AcceptedProtocols::try_from(&start_content)
            .expect_err("We don't support the old Curve25519 key agreement protocol");
    }

    #[test]
    fn test_create_sas() {
        let (_, _) = get_sas_pair(None);
    }

    #[test]
    fn test_sas_accept() {
        let (alice, bob) = get_sas_pair(None);
        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        alice.into_accepted(bob.user_id(), &content).unwrap();
    }

    /// After exchanging keys both sides derive the same short auth strings.
    #[test]
    fn test_sas_key_share() {
        let (alice, bob) = get_sas_pair(None);

        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();
        let content = alice.as_content();
        let transaction_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();
        let alice = alice.into_key_sent(&transaction_id).unwrap();

        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();

        let content = bob.as_content();
        let transaction_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();

        let bob = bob.into_keys_exchanged(&transaction_id).unwrap();

        let alice = alice.into_keys_exchanged(bob.user_id(), &content).unwrap();

        assert_eq!(alice.get_decimal(), bob.get_decimal());
        assert_eq!(alice.get_emoji(), bob.get_emoji());
    }

    /// Run a complete verification between Alice and Bob with the given MAC
    /// method and check that both sides end up verifying each other's device.
    fn full_flow_helper(mac_method: SupportedMacMethod) {
        let (alice, bob) = get_sas_pair(Some(mac_method));

        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        assert_eq!(
            bob.state.accepted_protocols.message_auth_code, mac_method,
            "Bob should be using the specified MAC method."
        );

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();

        assert_eq!(
            alice.state.accepted_protocols.message_auth_code, mac_method,
            "Alice should be using the specified MAC method.",
        );

        let content = alice.as_content();
        let request_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();

        let alice = alice.into_key_sent(&request_id).unwrap();
        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();

        let (content, request_info) = bob.as_content();
        let request_id = request_info.request_id;
        let content = KeyContent::try_from(&content).unwrap();
        let bob = bob.into_keys_exchanged(&request_id).unwrap();

        let alice = alice.into_keys_exchanged(bob.user_id(), &content).unwrap();

        assert_eq!(alice.get_decimal(), bob.get_decimal());
        assert_eq!(alice.get_emoji(), bob.get_emoji());

        // Remember Bob's short auth strings; `confirm()` consumes his state.
        let bob_decimals = bob.get_decimal();
        let bob_emoji = bob.get_emoji();

        let bob = bob.confirm();

        let content = bob.as_content();
        let content = MacContent::try_from(&content).unwrap();

        let alice = alice.into_mac_received(bob.user_id(), &content).unwrap();
        // `get_emoji()` returns a fixed-size array, so an emptiness check would
        // be vacuous; compare against what Bob derived instead.
        assert_eq!(alice.get_emoji(), bob_emoji);
        assert_eq!(alice.get_decimal(), bob_decimals);
        let alice = alice.confirm();

        let content = alice.as_content();
        let content = MacContent::try_from(&content).unwrap();
        let bob = bob.into_done(alice.user_id(), &content).unwrap();

        assert!(bob.verified_devices().contains(&bob.other_device()));
        assert!(alice.verified_devices().contains(&alice.other_device()));
    }

    #[test]
    fn test_full_flow() {
        full_flow_helper(SupportedMacMethod::HkdfHmacSha256);
    }

    #[test]
    fn test_full_flow_hkdf_hmac_sha_v2() {
        full_flow_helper(SupportedMacMethod::HkdfHmacSha256V2);
    }

    #[test]
    fn test_full_flow_hkdf_msc3783() {
        full_flow_helper(SupportedMacMethod::Msc3783HkdfHmacSha256V2);
    }

    /// Alice must cancel when Bob's key does not match his commitment.
    #[test]
    fn test_sas_invalid_commitment() {
        let (alice, bob) = get_sas_pair(None);

        let mut content = bob.as_content();
        let mut method = content.method_mut();

        match &mut method {
            AcceptMethod::SasV1(c) => {
                c.commitment = Base64::empty();
            }
            _ => panic!("Unknown accept event content"),
        }

        let content = AcceptContent::from(&content);

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();

        let content = alice.as_content();
        let content = KeyContent::try_from(&content.0).unwrap();
        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();
        let content = bob.as_content();
        let content = KeyContent::try_from(&content.0).unwrap();

        alice
            .into_key_received(bob.user_id(), &content)
            .expect_err("Didn't cancel on invalid commitment");
    }

    #[test]
    fn test_sas_invalid_sender() {
        let (alice, bob) = get_sas_pair(None);

        let content = bob.as_content();
        let content = AcceptContent::from(&content);
        let sender = user_id!("@malory:example.org");
        alice.into_accepted(sender, &content).expect_err("Didn't cancel on an invalid sender");
    }

    #[test]
    fn test_sas_unknown_sas_method() {
        let (alice, bob) = get_sas_pair(None);

        let mut content = bob.as_content();
        let mut method = content.method_mut();

        match &mut method {
            AcceptMethod::SasV1(c) => {
                c.short_authentication_string = vec![];
            }
            _ => panic!("Unknown accept event content"),
        }

        let content = AcceptContent::from(&content);

        alice
            .into_accepted(bob.user_id(), &content)
            .expect_err("Didn't cancel on an invalid SAS method");
    }

    #[test]
    fn test_sas_unknown_method() {
        let (alice, bob) = get_sas_pair(None);

        let content = json!({
            "method": "m.sas.custom",
            "method_data": "something",
            "transaction_id": "some_id",
        });

        let content: ToDeviceKeyVerificationAcceptEventContent =
            serde_json::from_value(content).unwrap();
        let content = AcceptContent::from(&content);

        alice
            .into_accepted(bob.user_id(), &content)
            .expect_err("Didn't cancel on an unknown SAS method");
    }

    /// Starting from a start event with no MAC methods, or with an unknown
    /// verification method, must fail.
    #[async_test]
    async fn test_sas_from_start_unknown_method() {
        let alice = Account::with_device_id(alice_id(), alice_device_id());
        let alice_device = DeviceData::from_account(&alice);

        let bob = Account::with_device_id(bob_id(), bob_device_id());
        let bob_device = DeviceData::from_account(&bob);

        let flow_id = TransactionId::new().into();
        let alice_sas = SasState::<Created>::new(
            alice.static_data().clone(),
            bob_device,
            None,
            None,
            flow_id,
            false,
            None,
        );

        let mut start_content = alice_sas.as_content();
        let method = start_content.method_mut();

        match method {
            StartMethod::SasV1(c) => {
                c.message_authentication_codes = vec![];
            }
            _ => panic!("Unknown SAS start method"),
        }

        let flow_id = start_content.flow_id();
        let content = StartContent::from(&start_content);

        SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device.clone(),
            None,
            None,
            flow_id,
            &content,
            false,
        )
        .expect_err("Didn't cancel on invalid MAC method");

        let content = json!({
            "method": "m.sas.custom",
            "from_device": "DEVICEID",
            "method_data": "something",
            "transaction_id": "some_id",
        });

        let content: ToDeviceKeyVerificationStartEventContent =
            serde_json::from_value(content).unwrap();
        let content = StartContent::from(&content);
        let flow_id = content.flow_id().to_owned();

        SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device,
            None,
            None,
            FlowId::ToDevice(flow_id.into()),
            &content,
            false,
        )
        .expect_err("Didn't cancel on unknown sas method");
    }
}