1use std::{matches, sync::Arc, time::Duration};
16
17use matrix_sdk_common::locks::Mutex;
18use ruma::{
19 DeviceId, OwnedTransactionId, TransactionId, UserId,
20 events::{
21 AnyMessageLikeEventContent, AnyToDeviceEventContent,
22 key::verification::{
23 HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
24 ShortAuthenticationString,
25 accept::{
26 AcceptMethod, KeyVerificationAcceptEventContent, SasV1Content as AcceptV1Content,
27 SasV1ContentInit as AcceptV1ContentInit, ToDeviceKeyVerificationAcceptEventContent,
28 },
29 cancel::CancelCode,
30 done::{KeyVerificationDoneEventContent, ToDeviceKeyVerificationDoneEventContent},
31 key::{KeyVerificationKeyEventContent, ToDeviceKeyVerificationKeyEventContent},
32 start::{
33 KeyVerificationStartEventContent, SasV1Content, SasV1ContentInit, StartMethod,
34 ToDeviceKeyVerificationStartEventContent,
35 },
36 },
37 relation::Reference,
38 },
39 serde::Base64,
40 time::Instant,
41};
42use serde::{Deserialize, Serialize};
43use tracing::info;
44use vodozemac::{
45 Curve25519PublicKey,
46 sas::{EstablishedSas, Mac, Sas},
47};
48
49use super::{
50 OutgoingContent,
51 helpers::{
52 SasIds, calculate_commitment, get_decimal, get_emoji, get_emoji_index, get_mac_content,
53 receive_mac_event,
54 },
55};
56use crate::{
57 OwnUserIdentityData,
58 identities::{DeviceData, UserIdentityData},
59 olm::StaticAccountData,
60 verification::{
61 Cancelled, Emoji, FlowId,
62 cache::RequestInfo,
63 event_enums::{
64 AcceptContent, DoneContent, KeyContent, MacContent, OwnedAcceptContent,
65 OwnedStartContent, StartContent,
66 },
67 },
68};
69
70const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
71 &[KeyAgreementProtocol::Curve25519HkdfSha256];
72const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
73const STRINGS: &[ShortAuthenticationString] =
74 &[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
75
76fn the_protocol_definitions(
77 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
78) -> SasV1Content {
79 SasV1ContentInit {
80 short_authentication_string: short_auth_strings.unwrap_or_else(|| STRINGS.to_owned()),
81 key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(),
82 message_authentication_codes: vec![
83 #[allow(deprecated)]
84 MessageAuthenticationCode::HkdfHmacSha256,
85 MessageAuthenticationCode::HkdfHmacSha256V2,
86 MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
88 ],
89 hashes: HASHES.to_vec(),
90 }
91 .into()
92}
93
/// Maximum lifetime of a SAS flow, counted from its creation (five minutes).
const MAX_AGE: Duration = Duration::from_secs(300);

/// Maximum time allowed since the last processed event of a SAS flow.
const MAX_EVENT_TIMEOUT: Duration = Duration::from_secs(60);
99
/// The MAC methods this implementation can use for SAS verification.
///
/// The serde names are the same wire identifiers returned by the
/// `AsRef<str>` implementation.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum SupportedMacMethod {
    /// `hkdf-hmac-sha256`: the original method. MACs are produced with
    /// vodozemac's `calculate_mac_invalid_base64`.
    #[serde(rename = "hkdf-hmac-sha256")]
    HkdfHmacSha256,
    /// `hkdf-hmac-sha256.v2`: MACs are produced with vodozemac's regular
    /// `calculate_mac`.
    #[serde(rename = "hkdf-hmac-sha256.v2")]
    HkdfHmacSha256V2,
    /// `org.matrix.msc3783.hkdf-hmac-sha256`: the unstable MSC3783 identifier.
    /// It is computed and verified exactly like `HkdfHmacSha256V2`.
    #[serde(rename = "org.matrix.msc3783.hkdf-hmac-sha256")]
    Msc3783HkdfHmacSha256V2,
}
112
113impl AsRef<str> for SupportedMacMethod {
114 fn as_ref(&self) -> &str {
115 match self {
116 SupportedMacMethod::HkdfHmacSha256 => "hkdf-hmac-sha256",
117 SupportedMacMethod::HkdfHmacSha256V2 => "hkdf-hmac-sha256.v2",
118 SupportedMacMethod::Msc3783HkdfHmacSha256V2 => "org.matrix.msc3783.hkdf-hmac-sha256",
119 }
120 }
121}
122
123impl From<SupportedMacMethod> for MessageAuthenticationCode {
124 fn from(m: SupportedMacMethod) -> Self {
125 MessageAuthenticationCode::from(m.as_ref())
126 }
127}
128
129impl TryFrom<&MessageAuthenticationCode> for SupportedMacMethod {
130 type Error = ();
131
132 fn try_from(value: &MessageAuthenticationCode) -> Result<Self, Self::Error> {
133 match value.as_str() {
134 "hkdf-hmac-sha256" => Ok(Self::HkdfHmacSha256),
135 "org.matrix.msc3783.hkdf-hmac-sha256" => Ok(Self::Msc3783HkdfHmacSha256V2),
136 "hkdf-hmac-sha256.v2" => Ok(Self::HkdfHmacSha256V2),
137 _ => Err(()),
138 }
139 }
140}
141
142impl SupportedMacMethod {
143 pub fn verify_mac(
149 &self,
150 sas: &EstablishedSas,
151 input: &str,
152 info: &str,
153 mac: &Base64,
154 ) -> Result<(), CancelCode> {
155 match self {
156 SupportedMacMethod::HkdfHmacSha256 => {
157 let calculated_mac = sas.calculate_mac_invalid_base64(input, info);
158 let calculated_mac = Base64::parse(calculated_mac)
159 .expect("We can always decode a Mac from vodozemac");
160
161 if calculated_mac != *mac { Err(CancelCode::KeyMismatch) } else { Ok(()) }
162 }
163 SupportedMacMethod::HkdfHmacSha256V2 | SupportedMacMethod::Msc3783HkdfHmacSha256V2 => {
164 let mac = Mac::from_slice(mac.as_bytes());
165 sas.verify_mac(input, info, &mac).map_err(|_| CancelCode::MismatchedSas)
166 }
167 }
168 }
169
170 pub fn calculate_mac(&self, sas: &EstablishedSas, input: &str, info: &str) -> Base64 {
176 match self {
177 SupportedMacMethod::HkdfHmacSha256 => {
178 Base64::parse(sas.calculate_mac_invalid_base64(input, info))
179 .expect("We can always decode our newly generated Mac")
180 }
181 SupportedMacMethod::HkdfHmacSha256V2 | SupportedMacMethod::Msc3783HkdfHmacSha256V2 => {
182 let mac = sas.calculate_mac(input, info);
183 Base64::new(mac.as_bytes().to_vec())
184 }
185 }
186 }
187}
188
/// The protocols chosen for a SAS verification, either from an accept event
/// or selected from the methods offered in a start event.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AcceptedProtocols {
    /// The key agreement protocol to use.
    pub key_agreement_protocol: KeyAgreementProtocol,
    /// The hash algorithm to use.
    pub hash: HashAlgorithm,
    /// The MAC method used to calculate and verify MACs.
    pub message_auth_code: SupportedMacMethod,
    /// The short authentication string methods that may be shown to the user.
    pub short_auth_string: Vec<ShortAuthenticationString>,
}
203
204impl TryFrom<AcceptV1Content> for AcceptedProtocols {
205 type Error = CancelCode;
206
207 fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
208 if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
209 || !HASHES.contains(&content.hash)
210 || (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
211 && !content
212 .short_authentication_string
213 .contains(&ShortAuthenticationString::Decimal))
214 {
215 Err(CancelCode::UnknownMethod)
216 } else {
217 let message_auth_code = (&content.message_authentication_code)
218 .try_into()
219 .map_err(|_| CancelCode::UnknownMethod)?;
220
221 Ok(Self {
222 hash: content.hash,
223 key_agreement_protocol: content.key_agreement_protocol,
224 message_auth_code,
225 short_auth_string: content.short_authentication_string,
226 })
227 }
228 }
229}
230
231impl TryFrom<&SasV1Content> for AcceptedProtocols {
232 type Error = CancelCode;
233
234 fn try_from(method_content: &SasV1Content) -> Result<Self, Self::Error> {
235 if !method_content
236 .key_agreement_protocols
237 .contains(&KeyAgreementProtocol::Curve25519HkdfSha256)
238 || !method_content.hashes.contains(&HashAlgorithm::Sha256)
239 || (!method_content
240 .short_authentication_string
241 .contains(&ShortAuthenticationString::Decimal)
242 && !method_content
243 .short_authentication_string
244 .contains(&ShortAuthenticationString::Emoji))
245 {
246 Err(CancelCode::UnknownMethod)
247 } else {
248 let mac_methods: Vec<SupportedMacMethod> = method_content
249 .message_authentication_codes
250 .iter()
251 .filter_map(|m| SupportedMacMethod::try_from(m).ok())
252 .collect();
253
254 let message_auth_code =
255 if mac_methods.contains(&SupportedMacMethod::HkdfHmacSha256V2) {
256 Some(SupportedMacMethod::HkdfHmacSha256V2)
257 } else if mac_methods.contains(&SupportedMacMethod::Msc3783HkdfHmacSha256V2) {
258 Some(SupportedMacMethod::Msc3783HkdfHmacSha256V2)
259 } else {
260 mac_methods.first().copied()
261 }
262 .ok_or(CancelCode::UnknownMethod)?;
263
264 let mut short_auth_string = vec![];
265
266 if method_content
267 .short_authentication_string
268 .contains(&ShortAuthenticationString::Decimal)
269 {
270 short_auth_string.push(ShortAuthenticationString::Decimal)
271 }
272
273 if method_content
274 .short_authentication_string
275 .contains(&ShortAuthenticationString::Emoji)
276 {
277 short_auth_string.push(ShortAuthenticationString::Emoji);
278 }
279
280 Ok(Self {
281 hash: HashAlgorithm::Sha256,
282 key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
283 message_auth_code,
284 short_auth_string,
285 })
286 }
287 }
288}
289
290#[cfg(not(tarpaulin_include))]
291impl Default for AcceptedProtocols {
292 fn default() -> Self {
293 AcceptedProtocols {
294 hash: HashAlgorithm::Sha256,
295 key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
296 message_auth_code: SupportedMacMethod::HkdfHmacSha256V2,
297 short_auth_string: vec![
298 ShortAuthenticationString::Decimal,
299 ShortAuthenticationString::Emoji,
300 ],
301 }
302 }
303}
304
/// A SAS verification flow whose current state is described by `S`.
///
/// State transitions consume the value and return a `SasState` of the next
/// state type.
#[derive(Clone)]
pub struct SasState<S: Clone> {
    /// The vodozemac SAS object. `handle_key_content` `take`s it (leaving
    /// `None`) when the other side's public key arrives. Cancelled states
    /// built by `from_start_event` start out as `None`.
    inner: Arc<Mutex<Option<Sas>>>,

    /// Our Curve25519 public key for this flow.
    our_public_key: Curve25519PublicKey,

    /// Our account data plus the other device and the identities involved.
    ids: Box<SasIds>,

    /// When the flow was created. `timed_out` compares it against `MAX_AGE`.
    creation_time: Arc<Instant>,

    /// When the last event was processed. `timed_out` compares it against
    /// `MAX_EVENT_TIMEOUT`.
    last_event_time: Arc<Instant>,

    /// The flow ID: either a to-device transaction ID or an in-room
    /// (room ID, event ID) pair.
    pub verification_flow_id: Arc<FlowId>,

    /// Data specific to the current state.
    pub state: Arc<S>,

    /// Whether this SAS flow was started from a verification request.
    pub started_from_request: bool,
}
341
342impl<S: Clone> SasState<S> {
343 fn handle_key_content(
344 &self,
345 sender: &UserId,
346 content: &KeyContent<'_>,
347 ) -> Result<EstablishedSas, CancelCode> {
348 self.check_event(sender, content.flow_id())?;
349
350 let their_public_key = Curve25519PublicKey::from_slice(content.public_key().as_bytes())
351 .map_err(|_| CancelCode::from("Invalid public key"))?;
352
353 if let Some(sas) = self.inner.lock().take() {
354 sas.diffie_hellman(their_public_key).map_err(|_| "Invalid public key".into())
355 } else {
356 Err(CancelCode::UnexpectedMessage)
357 }
358 }
359}
360
361#[cfg(not(tarpaulin_include))]
362impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.debug_struct("SasState")
365 .field("ids", &self.ids)
366 .field("flow_id", &self.verification_flow_id)
367 .field("state", &self.state)
368 .finish()
369 }
370}
371
/// Initial state of a flow that we start ourselves.
#[derive(Clone, Debug)]
pub struct Created {
    /// The `m.sas.v1` method content we put into our start event.
    pub protocol_definitions: SasV1Content,
}
377
/// State after receiving a start event from the other side.
#[derive(Clone, Debug)]
pub struct Started {
    /// Commitment computed from our public key and the received start
    /// content. It is sent back in our accept event.
    commitment: Base64,
    /// The `m.sas.v1` method content from the received start event.
    pub protocol_definitions: SasV1Content,
    /// The protocols selected from `protocol_definitions`.
    pub accepted_protocols: AcceptedProtocols,
}
385
/// State after the other side accepted the start content we hold.
#[derive(Clone, Debug)]
pub struct Accepted {
    /// The protocols from the other side's accept event.
    pub accepted_protocols: AcceptedProtocols,
    /// Our start content. It is used to recompute the commitment once the
    /// other side's key arrives.
    start_content: Arc<OwnedStartContent>,
    /// Request ID for our outgoing key message.
    pub request_id: OwnedTransactionId,
    /// The commitment the other side sent in its accept event.
    commitment: Base64,
}
395
/// State after we accepted a start event from the other side.
#[derive(Clone, Debug)]
pub struct WeAccepted {
    /// Whether we started the verification (set to `false` when entering
    /// this state from `Started`).
    we_started: bool,
    /// The protocols we put into our accept event.
    pub accepted_protocols: AcceptedProtocols,
    /// The commitment we put into our accept event.
    commitment: Base64,
}
404
/// State after we received the other side's key but before our own key was
/// sent.
#[derive(Clone, Debug)]
pub struct KeyReceived {
    /// The SAS established from the other side's key.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the verification.
    we_started: bool,
    /// Request ID for our outgoing key message.
    pub request_id: OwnedTransactionId,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
416
/// State after our key was sent while we still wait for the other side's
/// key.
#[derive(Clone, Debug)]
pub struct KeySent {
    /// Whether we started the verification.
    we_started: bool,
    /// Our start content, used to recompute the commitment.
    start_content: Arc<OwnedStartContent>,
    /// The commitment the other side sent in its accept event.
    commitment: Base64,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
424
/// State once both public keys were exchanged. The short authentication
/// strings can now be shown.
#[derive(Clone, Debug)]
pub struct KeysExchanged {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the verification.
    we_started: bool,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
431
/// State entered from `KeysExchanged::confirm`, before the other side's MAC
/// has been received.
#[derive(Clone, Debug)]
pub struct Confirmed {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
440
/// State after the other side's MAC was processed but before we confirmed
/// the SAS.
#[derive(Clone, Debug)]
pub struct MacReceived {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Whether we started the verification.
    we_started: bool,
    /// Devices returned by `receive_mac_event` for the received MAC.
    verified_devices: Arc<[DeviceData]>,
    /// User identities returned by `receive_mac_event` for the received MAC.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
452
/// State after we confirmed and processed the other side's MAC.
///
/// NOTE(review): presumably this waits for the other side's done message;
/// confirm against the transitions defined later in the file.
#[derive(Clone, Debug)]
pub struct WaitingForDone {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Devices returned by `receive_mac_event` for the received MAC.
    verified_devices: Arc<[DeviceData]>,
    /// User identities returned by `receive_mac_event` for the received MAC.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
463
/// Final state: the SAS was confirmed and the other side's MAC processed.
#[derive(Clone, Debug)]
pub struct Done {
    /// The established SAS.
    sas: Arc<Mutex<EstablishedSas>>,
    /// Devices returned by `receive_mac_event` for the received MAC.
    verified_devices: Arc<[DeviceData]>,
    /// User identities returned by `receive_mac_event` for the received MAC.
    verified_master_keys: Arc<[UserIdentityData]>,
    /// The protocols in use.
    pub accepted_protocols: AcceptedProtocols,
}
476
477impl<S: Clone> SasState<S> {
478 #[cfg(test)]
480 pub fn user_id(&self) -> &UserId {
481 &self.ids.account.user_id
482 }
483
484 pub fn device_id(&self) -> &DeviceId {
486 &self.ids.account.device_id
487 }
488
489 #[cfg(test)]
490 pub fn other_device(&self) -> DeviceData {
491 self.ids.other_device.clone()
492 }
493
494 pub fn cancel(self, cancelled_by_us: bool, cancel_code: CancelCode) -> SasState<Cancelled> {
495 SasState {
496 inner: self.inner,
497 our_public_key: self.our_public_key,
498 ids: self.ids,
499 creation_time: self.creation_time,
500 last_event_time: self.last_event_time,
501 verification_flow_id: self.verification_flow_id,
502 state: Arc::new(Cancelled::new(cancelled_by_us, cancel_code)),
503 started_from_request: self.started_from_request,
504 }
505 }
506
507 pub fn timed_out(&self) -> bool {
509 self.creation_time.elapsed() > MAX_AGE || self.last_event_time.elapsed() > MAX_EVENT_TIMEOUT
510 }
511
512 #[allow(dead_code)]
514 pub fn is_dm_verification(&self) -> bool {
515 matches!(&*self.verification_flow_id, FlowId::InRoom(_, _))
516 }
517
518 #[cfg(test)]
519 #[allow(dead_code)]
520 pub fn set_creation_time(&mut self, time: Instant) {
521 self.creation_time = Arc::new(time);
522 }
523
524 fn check_event(&self, sender: &UserId, flow_id: &str) -> Result<(), CancelCode> {
525 if *flow_id != *self.verification_flow_id.as_str() {
526 Err(CancelCode::UnknownTransaction)
527 } else if sender != self.ids.other_device.user_id() {
528 Err(CancelCode::UserMismatch)
529 } else if self.timed_out() {
530 Err(CancelCode::Timeout)
531 } else {
532 Ok(())
533 }
534 }
535}
536
537impl SasState<Created> {
538 pub fn new(
548 account: StaticAccountData,
549 other_device: DeviceData,
550 own_identity: Option<OwnUserIdentityData>,
551 other_identity: Option<UserIdentityData>,
552 flow_id: FlowId,
553 started_from_request: bool,
554 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
555 ) -> SasState<Created> {
556 Self::new_helper(
557 flow_id,
558 account,
559 other_device,
560 own_identity,
561 other_identity,
562 started_from_request,
563 short_auth_strings,
564 )
565 }
566
567 fn new_helper(
568 flow_id: FlowId,
569 account: StaticAccountData,
570 other_device: DeviceData,
571 own_identity: Option<OwnUserIdentityData>,
572 other_identity: Option<UserIdentityData>,
573 started_from_request: bool,
574 short_auth_strings: Option<Vec<ShortAuthenticationString>>,
575 ) -> SasState<Created> {
576 let sas = Sas::new();
577 let our_public_key = sas.public_key();
578
579 let protocol_definitions = the_protocol_definitions(short_auth_strings);
580
581 SasState {
582 inner: Arc::new(Mutex::new(Some(sas))),
583 our_public_key,
584 ids: Box::new(SasIds { account, other_device, other_identity, own_identity }),
585 verification_flow_id: flow_id.into(),
586
587 creation_time: Arc::new(Instant::now()),
588 last_event_time: Arc::new(Instant::now()),
589 started_from_request,
590
591 state: Arc::new(Created { protocol_definitions }),
592 }
593 }
594
595 pub fn as_content(&self) -> OwnedStartContent {
596 match self.verification_flow_id.as_ref() {
597 FlowId::ToDevice(s) => {
598 OwnedStartContent::ToDevice(ToDeviceKeyVerificationStartEventContent::new(
599 self.device_id().into(),
600 s.clone(),
601 StartMethod::SasV1(self.state.protocol_definitions.clone()),
602 ))
603 }
604 FlowId::InRoom(r, e) => OwnedStartContent::Room(
605 r.clone(),
606 KeyVerificationStartEventContent::new(
607 self.device_id().into(),
608 StartMethod::SasV1(self.state.protocol_definitions.clone()),
609 Reference::new(e.clone()),
610 ),
611 ),
612 }
613 }
614
615 pub fn into_accepted(
623 self,
624 sender: &UserId,
625 content: &AcceptContent<'_>,
626 ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
627 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
628
629 let AcceptMethod::SasV1(content) = content.method() else {
630 return Err(self.cancel(true, CancelCode::UnknownMethod));
631 };
632
633 let accepted_protocols = AcceptedProtocols::try_from(content.clone())
634 .map_err(|c| self.clone().cancel(true, c))?;
635
636 let start_content = self.as_content().into();
637
638 Ok(SasState {
639 inner: self.inner,
640 our_public_key: self.our_public_key,
641 ids: self.ids,
642 verification_flow_id: self.verification_flow_id,
643 creation_time: self.creation_time,
644 last_event_time: Instant::now().into(),
645 started_from_request: self.started_from_request,
646 state: Arc::new(Accepted {
647 start_content,
648 commitment: content.commitment.clone(),
649 request_id: TransactionId::new(),
650 accepted_protocols,
651 }),
652 })
653 }
654}
655
656impl SasState<Started> {
657 pub fn from_start_event(
671 account: StaticAccountData,
672 other_device: DeviceData,
673 own_identity: Option<OwnUserIdentityData>,
674 other_identity: Option<UserIdentityData>,
675 flow_id: FlowId,
676 content: &StartContent<'_>,
677 started_from_request: bool,
678 ) -> Result<SasState<Started>, SasState<Cancelled>> {
679 let flow_id = Arc::new(flow_id);
680
681 let sas = Sas::new();
682 let our_public_key = sas.public_key();
683
684 let canceled = || SasState {
685 inner: Arc::new(Mutex::new(None)),
686 our_public_key,
687
688 creation_time: Arc::new(Instant::now()),
689 last_event_time: Arc::new(Instant::now()),
690 started_from_request,
691
692 ids: Box::new(SasIds {
693 account: account.clone(),
694 other_device: other_device.clone(),
695 own_identity: own_identity.clone(),
696 other_identity: other_identity.clone(),
697 }),
698
699 verification_flow_id: flow_id.clone(),
700 state: Arc::new(Cancelled::new(true, CancelCode::UnknownMethod)),
701 };
702
703 let state = match content.method() {
704 StartMethod::SasV1(method_content) => {
705 let commitment = calculate_commitment(our_public_key, content);
706
707 info!(
708 public_key = our_public_key.to_base64(),
709 ?commitment,
710 ?content,
711 "Calculated SAS commitment",
712 );
713
714 let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) else {
715 return Err(canceled());
716 };
717
718 Started {
719 protocol_definitions: method_content.to_owned(),
720 accepted_protocols,
721 commitment,
722 }
723 }
724 _ => return Err(canceled()),
725 };
726
727 Ok(SasState {
728 inner: Arc::new(Mutex::new(Some(sas))),
729 our_public_key,
730
731 ids: Box::new(SasIds { account, other_device, other_identity, own_identity }),
732
733 creation_time: Arc::new(Instant::now()),
734 last_event_time: Arc::new(Instant::now()),
735 started_from_request,
736
737 verification_flow_id: flow_id,
738
739 state: Arc::new(state),
740 })
741 }
742
743 #[cfg(test)]
744 fn into_we_accepted_with_mac_method(
745 self,
746 methods: Vec<ShortAuthenticationString>,
747 mac_method: Option<SupportedMacMethod>,
748 ) -> SasState<WeAccepted> {
749 let mut accepted_protocols = self.state.accepted_protocols.to_owned();
750
751 if let Some(mac_method) = mac_method {
752 accepted_protocols.message_auth_code = mac_method;
753 }
754
755 self.into_we_accepted_helper(accepted_protocols, methods)
756 }
757
758 fn into_we_accepted_helper(
759 self,
760 mut accepted_protocols: AcceptedProtocols,
761 methods: Vec<ShortAuthenticationString>,
762 ) -> SasState<WeAccepted> {
763 accepted_protocols.short_auth_string = methods;
764
765 if !accepted_protocols.short_auth_string.contains(&ShortAuthenticationString::Decimal) {
767 accepted_protocols.short_auth_string.push(ShortAuthenticationString::Decimal);
768 }
769
770 SasState {
771 inner: self.inner,
772 our_public_key: self.our_public_key,
773 ids: self.ids,
774 verification_flow_id: self.verification_flow_id,
775 creation_time: self.creation_time,
776 last_event_time: self.last_event_time,
777 started_from_request: self.started_from_request,
778 state: Arc::new(WeAccepted {
779 we_started: false,
780 accepted_protocols,
781 commitment: self.state.commitment.clone(),
782 }),
783 }
784 }
785
786 pub fn into_we_accepted(self, methods: Vec<ShortAuthenticationString>) -> SasState<WeAccepted> {
787 let accepted_protocols = self.state.accepted_protocols.to_owned();
788 self.into_we_accepted_helper(accepted_protocols, methods)
789 }
790
791 fn as_content(&self) -> OwnedStartContent {
792 match self.verification_flow_id.as_ref() {
793 FlowId::ToDevice(s) => {
794 OwnedStartContent::ToDevice(ToDeviceKeyVerificationStartEventContent::new(
795 self.device_id().into(),
796 s.clone(),
797 StartMethod::SasV1(self.state.protocol_definitions.to_owned()),
798 ))
799 }
800 FlowId::InRoom(r, e) => OwnedStartContent::Room(
801 r.clone(),
802 KeyVerificationStartEventContent::new(
803 self.device_id().into(),
804 StartMethod::SasV1(self.state.protocol_definitions.to_owned()),
805 Reference::new(e.clone()),
806 ),
807 ),
808 }
809 }
810
811 pub fn into_accepted(
824 self,
825 sender: &UserId,
826 content: &AcceptContent<'_>,
827 ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
828 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
829
830 let AcceptMethod::SasV1(content) = content.method() else {
831 return Err(self.cancel(true, CancelCode::UnknownMethod));
832 };
833
834 let accepted_protocols = AcceptedProtocols::try_from(content.clone())
835 .map_err(|c| self.clone().cancel(true, c))?;
836
837 let start_content = self.as_content().into();
838
839 Ok(SasState {
840 inner: self.inner,
841 our_public_key: self.our_public_key,
842 ids: self.ids,
843 verification_flow_id: self.verification_flow_id,
844 creation_time: self.creation_time,
845 last_event_time: Instant::now().into(),
846 started_from_request: self.started_from_request,
847 state: Arc::new(Accepted {
848 start_content,
849 commitment: content.commitment.clone(),
850 request_id: TransactionId::new(),
851 accepted_protocols,
852 }),
853 })
854 }
855}
856
857impl SasState<WeAccepted> {
858 pub fn as_content(&self) -> OwnedAcceptContent {
866 let method = AcceptMethod::SasV1(
867 AcceptV1ContentInit {
868 commitment: self.state.commitment.clone(),
869 hash: self.state.accepted_protocols.hash.clone(),
870 key_agreement_protocol: self
871 .state
872 .accepted_protocols
873 .key_agreement_protocol
874 .clone(),
875 message_authentication_code: self.state.accepted_protocols.message_auth_code.into(),
876 short_authentication_string: self
877 .state
878 .accepted_protocols
879 .short_auth_string
880 .clone(),
881 }
882 .into(),
883 );
884
885 match self.verification_flow_id.as_ref() {
886 FlowId::ToDevice(s) => {
887 ToDeviceKeyVerificationAcceptEventContent::new(s.clone(), method).into()
888 }
889 FlowId::InRoom(r, e) => (
890 r.clone(),
891 KeyVerificationAcceptEventContent::new(method, Reference::new(e.clone())),
892 )
893 .into(),
894 }
895 }
896
897 pub fn into_key_received(
906 self,
907 sender: &UserId,
908 content: &KeyContent<'_>,
909 ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
910 let established =
911 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
912
913 Ok(SasState {
914 inner: self.inner,
915 our_public_key: self.our_public_key,
916 ids: self.ids,
917 verification_flow_id: self.verification_flow_id,
918 creation_time: self.creation_time,
919 last_event_time: Instant::now().into(),
920 started_from_request: self.started_from_request,
921 state: Arc::new(KeyReceived {
922 sas: Mutex::new(established).into(),
923 we_started: self.state.we_started,
924 request_id: TransactionId::new(),
925 accepted_protocols: self.state.accepted_protocols.clone(),
926 }),
927 })
928 }
929}
930
931impl SasState<Accepted> {
932 pub fn into_key_received(
941 self,
942 sender: &UserId,
943 content: &KeyContent<'_>,
944 ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
945 let established =
946 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
947
948 let their_public_key = established.their_public_key();
949
950 let commitment =
951 calculate_commitment(their_public_key, &self.state.start_content.as_start_content());
952
953 if self.state.commitment == commitment {
954 Ok(SasState {
955 inner: self.inner,
956 our_public_key: self.our_public_key,
957 ids: self.ids,
958 verification_flow_id: self.verification_flow_id,
959 creation_time: self.creation_time,
960 last_event_time: Instant::now().into(),
961 started_from_request: self.started_from_request,
962 state: Arc::new(KeyReceived {
963 sas: Mutex::new(established).into(),
964 we_started: true,
965 request_id: self.state.request_id.to_owned(),
966 accepted_protocols: self.state.accepted_protocols.clone(),
967 }),
968 })
969 } else {
970 Err(self.cancel(true, CancelCode::KeyMismatch))
971 }
972 }
973
974 pub fn into_key_sent(self, request_id: &TransactionId) -> Option<SasState<KeySent>> {
975 (self.state.request_id == request_id).then(|| SasState {
976 inner: self.inner,
977 our_public_key: self.our_public_key,
978 ids: self.ids,
979 verification_flow_id: self.verification_flow_id,
980 creation_time: self.creation_time,
981 last_event_time: Instant::now().into(),
982 started_from_request: self.started_from_request,
983 state: Arc::new(KeySent {
984 we_started: true,
985 start_content: self.state.start_content.clone(),
986 commitment: self.state.commitment.clone(),
987 accepted_protocols: self.state.accepted_protocols.clone(),
988 }),
989 })
990 }
991
992 pub fn as_content(&self) -> (OutgoingContent, RequestInfo) {
996 let content = match &*self.verification_flow_id {
997 FlowId::ToDevice(s) => AnyToDeviceEventContent::KeyVerificationKey(
998 ToDeviceKeyVerificationKeyEventContent::new(
999 s.clone(),
1000 Base64::new(self.our_public_key.to_vec()),
1001 ),
1002 )
1003 .into(),
1004 FlowId::InRoom(r, e) => (
1005 r.clone(),
1006 AnyMessageLikeEventContent::KeyVerificationKey(
1007 KeyVerificationKeyEventContent::new(
1008 Base64::new(self.our_public_key.to_vec()),
1009 Reference::new(e.clone()),
1010 ),
1011 ),
1012 )
1013 .into(),
1014 };
1015
1016 (
1017 content,
1018 RequestInfo {
1019 flow_id: (*self.verification_flow_id).to_owned(),
1020 request_id: self.state.request_id.to_owned(),
1021 },
1022 )
1023 }
1024}
1025
1026impl SasState<KeySent> {
1027 pub fn into_keys_exchanged(
1028 self,
1029 sender: &UserId,
1030 content: &KeyContent<'_>,
1031 ) -> Result<SasState<KeysExchanged>, SasState<Cancelled>> {
1032 let established =
1033 self.handle_key_content(sender, content).map_err(|c| self.clone().cancel(true, c))?;
1034
1035 let their_public_key = established.their_public_key();
1036 let commitment =
1037 calculate_commitment(their_public_key, &self.state.start_content.as_start_content());
1038
1039 if self.state.commitment == commitment {
1040 Ok(SasState {
1041 inner: self.inner,
1042 our_public_key: self.our_public_key,
1043 ids: self.ids,
1044 verification_flow_id: self.verification_flow_id,
1045 creation_time: self.creation_time,
1046 last_event_time: Instant::now().into(),
1047 started_from_request: self.started_from_request,
1048 state: Arc::new(KeysExchanged {
1049 sas: Mutex::new(established).into(),
1050 we_started: self.state.we_started,
1051 accepted_protocols: self.state.accepted_protocols.clone(),
1052 }),
1053 })
1054 } else {
1055 Err(self.cancel(true, CancelCode::KeyMismatch))
1056 }
1057 }
1058}
1059
1060impl SasState<KeyReceived> {
1061 pub fn as_content(&self) -> (OutgoingContent, RequestInfo) {
1066 let content = match &*self.verification_flow_id {
1067 FlowId::ToDevice(s) => AnyToDeviceEventContent::KeyVerificationKey(
1068 ToDeviceKeyVerificationKeyEventContent::new(
1069 s.clone(),
1070 Base64::new(self.our_public_key.to_vec()),
1071 ),
1072 )
1073 .into(),
1074 FlowId::InRoom(r, e) => (
1075 r.clone(),
1076 AnyMessageLikeEventContent::KeyVerificationKey(
1077 KeyVerificationKeyEventContent::new(
1078 Base64::new(self.our_public_key.to_vec()),
1079 Reference::new(e.clone()),
1080 ),
1081 ),
1082 )
1083 .into(),
1084 };
1085
1086 (
1087 content,
1088 RequestInfo {
1089 flow_id: (*self.verification_flow_id).to_owned(),
1090 request_id: self.state.request_id.to_owned(),
1091 },
1092 )
1093 }
1094
1095 pub fn into_keys_exchanged(
1096 self,
1097 request_id: &TransactionId,
1098 ) -> Option<SasState<KeysExchanged>> {
1099 (self.state.request_id == request_id).then(|| SasState {
1100 inner: self.inner,
1101 our_public_key: self.our_public_key,
1102 ids: self.ids,
1103 verification_flow_id: self.verification_flow_id,
1104 creation_time: self.creation_time,
1105 last_event_time: Instant::now().into(),
1106 started_from_request: self.started_from_request,
1107 state: KeysExchanged {
1108 sas: self.state.sas.clone(),
1109 we_started: self.state.we_started,
1110 accepted_protocols: self.state.accepted_protocols.clone(),
1111 }
1112 .into(),
1113 })
1114 }
1115}
1116
1117impl SasState<KeysExchanged> {
1118 pub fn get_emoji(&self) -> [Emoji; 7] {
1123 get_emoji(
1124 &self.state.sas.lock(),
1125 &self.ids,
1126 self.verification_flow_id.as_str(),
1127 self.state.we_started,
1128 )
1129 }
1130
1131 pub fn get_emoji_index(&self) -> [u8; 7] {
1136 get_emoji_index(
1137 &self.state.sas.lock(),
1138 &self.ids,
1139 self.verification_flow_id.as_str(),
1140 self.state.we_started,
1141 )
1142 }
1143
1144 pub fn get_decimal(&self) -> (u16, u16, u16) {
1149 get_decimal(
1150 &self.state.sas.lock(),
1151 &self.ids,
1152 self.verification_flow_id.as_str(),
1153 self.state.we_started,
1154 )
1155 }
1156
1157 pub fn into_mac_received(
1165 self,
1166 sender: &UserId,
1167 content: &MacContent<'_>,
1168 ) -> Result<SasState<MacReceived>, SasState<Cancelled>> {
1169 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1170
1171 let (devices, master_keys) = receive_mac_event(
1172 &self.state.sas.lock(),
1173 &self.ids,
1174 self.verification_flow_id.as_str(),
1175 sender,
1176 self.state.accepted_protocols.message_auth_code,
1177 content,
1178 )
1179 .map_err(|c| self.clone().cancel(true, c))?;
1180
1181 Ok(SasState {
1182 inner: self.inner,
1183 our_public_key: self.our_public_key,
1184 verification_flow_id: self.verification_flow_id,
1185 creation_time: self.creation_time,
1186 last_event_time: Instant::now().into(),
1187 ids: self.ids,
1188 started_from_request: self.started_from_request,
1189 state: Arc::new(MacReceived {
1190 sas: self.state.sas.clone(),
1191 we_started: self.state.we_started,
1192 verified_devices: devices.into(),
1193 verified_master_keys: master_keys.into(),
1194 accepted_protocols: self.state.accepted_protocols.clone(),
1195 }),
1196 })
1197 }
1198
1199 pub fn confirm(self) -> SasState<Confirmed> {
1204 SasState {
1205 inner: self.inner,
1206 our_public_key: self.our_public_key,
1207 started_from_request: self.started_from_request,
1208 verification_flow_id: self.verification_flow_id,
1209 creation_time: self.creation_time,
1210 last_event_time: self.last_event_time,
1211 ids: self.ids,
1212 state: Arc::new(Confirmed {
1213 sas: self.state.sas.clone(),
1214 accepted_protocols: self.state.accepted_protocols.clone(),
1215 }),
1216 }
1217 }
1218}
1219
1220impl SasState<Confirmed> {
1221 pub fn into_done(
1229 self,
1230 sender: &UserId,
1231 content: &MacContent<'_>,
1232 ) -> Result<SasState<Done>, SasState<Cancelled>> {
1233 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1234
1235 let (devices, master_keys) = receive_mac_event(
1236 &self.state.sas.lock(),
1237 &self.ids,
1238 self.verification_flow_id.as_str(),
1239 sender,
1240 self.state.accepted_protocols.message_auth_code,
1241 content,
1242 )
1243 .map_err(|c| self.clone().cancel(true, c))?;
1244
1245 Ok(SasState {
1246 inner: self.inner,
1247 our_public_key: self.our_public_key,
1248 creation_time: self.creation_time,
1249 last_event_time: Instant::now().into(),
1250 verification_flow_id: self.verification_flow_id,
1251 started_from_request: self.started_from_request,
1252 ids: self.ids,
1253
1254 state: Arc::new(Done {
1255 sas: self.state.sas.clone(),
1256 verified_devices: devices.into(),
1257 verified_master_keys: master_keys.into(),
1258 accepted_protocols: self.state.accepted_protocols.clone(),
1259 }),
1260 })
1261 }
1262
1263 pub fn into_waiting_for_done(
1273 self,
1274 sender: &UserId,
1275 content: &MacContent<'_>,
1276 ) -> Result<SasState<WaitingForDone>, SasState<Cancelled>> {
1277 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1278
1279 let (devices, master_keys) = receive_mac_event(
1280 &self.state.sas.lock(),
1281 &self.ids,
1282 self.verification_flow_id.as_str(),
1283 sender,
1284 self.state.accepted_protocols.message_auth_code,
1285 content,
1286 )
1287 .map_err(|c| self.clone().cancel(true, c))?;
1288
1289 Ok(SasState {
1290 inner: self.inner,
1291 our_public_key: self.our_public_key,
1292 creation_time: self.creation_time,
1293 last_event_time: Instant::now().into(),
1294 verification_flow_id: self.verification_flow_id,
1295 started_from_request: self.started_from_request,
1296 ids: self.ids,
1297
1298 state: Arc::new(WaitingForDone {
1299 sas: self.state.sas.clone(),
1300 verified_devices: devices.into(),
1301 verified_master_keys: master_keys.into(),
1302 accepted_protocols: self.state.accepted_protocols.clone(),
1303 }),
1304 })
1305 }
1306
1307 pub fn as_content(&self) -> OutgoingContent {
1311 get_mac_content(
1312 &self.state.sas.lock(),
1313 &self.ids,
1314 &self.verification_flow_id,
1315 self.state.accepted_protocols.message_auth_code,
1316 )
1317 }
1318}
1319
1320impl SasState<MacReceived> {
1321 pub fn confirm(self) -> SasState<Done> {
1326 SasState {
1327 inner: self.inner,
1328 our_public_key: self.our_public_key,
1329 verification_flow_id: self.verification_flow_id,
1330 creation_time: self.creation_time,
1331 started_from_request: self.started_from_request,
1332 last_event_time: self.last_event_time,
1333 ids: self.ids,
1334 state: Arc::new(Done {
1335 sas: self.state.sas.clone(),
1336 verified_devices: self.state.verified_devices.clone(),
1337 verified_master_keys: self.state.verified_master_keys.clone(),
1338 accepted_protocols: self.state.accepted_protocols.clone(),
1339 }),
1340 }
1341 }
1342
1343 pub fn confirm_and_wait_for_done(self) -> SasState<WaitingForDone> {
1350 SasState {
1351 inner: self.inner,
1352 our_public_key: self.our_public_key,
1353 verification_flow_id: self.verification_flow_id,
1354 creation_time: self.creation_time,
1355 started_from_request: self.started_from_request,
1356 last_event_time: self.last_event_time,
1357 ids: self.ids,
1358 state: Arc::new(WaitingForDone {
1359 sas: self.state.sas.clone(),
1360 verified_devices: self.state.verified_devices.clone(),
1361 verified_master_keys: self.state.verified_master_keys.clone(),
1362 accepted_protocols: self.state.accepted_protocols.clone(),
1363 }),
1364 }
1365 }
1366
1367 pub fn get_emoji(&self) -> [Emoji; 7] {
1372 get_emoji(
1373 &self.state.sas.lock(),
1374 &self.ids,
1375 self.verification_flow_id.as_str(),
1376 self.state.we_started,
1377 )
1378 }
1379
1380 pub fn get_emoji_index(&self) -> [u8; 7] {
1385 get_emoji_index(
1386 &self.state.sas.lock(),
1387 &self.ids,
1388 self.verification_flow_id.as_str(),
1389 self.state.we_started,
1390 )
1391 }
1392
1393 pub fn get_decimal(&self) -> (u16, u16, u16) {
1398 get_decimal(
1399 &self.state.sas.lock(),
1400 &self.ids,
1401 self.verification_flow_id.as_str(),
1402 self.state.we_started,
1403 )
1404 }
1405}
1406
1407impl SasState<WaitingForDone> {
1408 pub fn as_content(&self) -> OutgoingContent {
1413 get_mac_content(
1414 &self.state.sas.lock(),
1415 &self.ids,
1416 &self.verification_flow_id,
1417 self.state.accepted_protocols.message_auth_code,
1418 )
1419 }
1420
1421 pub fn done_content(&self) -> OutgoingContent {
1422 match self.verification_flow_id.as_ref() {
1423 FlowId::ToDevice(t) => AnyToDeviceEventContent::KeyVerificationDone(
1424 ToDeviceKeyVerificationDoneEventContent::new(t.to_owned()),
1425 )
1426 .into(),
1427 FlowId::InRoom(r, e) => (
1428 r.clone(),
1429 AnyMessageLikeEventContent::KeyVerificationDone(
1430 KeyVerificationDoneEventContent::new(Reference::new(e.clone())),
1431 ),
1432 )
1433 .into(),
1434 }
1435 }
1436
1437 pub fn into_done(
1445 self,
1446 sender: &UserId,
1447 content: &DoneContent<'_>,
1448 ) -> Result<SasState<Done>, SasState<Cancelled>> {
1449 self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
1450
1451 Ok(SasState {
1452 inner: self.inner,
1453 our_public_key: self.our_public_key,
1454 creation_time: self.creation_time,
1455 last_event_time: Instant::now().into(),
1456 verification_flow_id: self.verification_flow_id,
1457 started_from_request: self.started_from_request,
1458 ids: self.ids,
1459
1460 state: Arc::new(Done {
1461 sas: self.state.sas.clone(),
1462 verified_devices: self.state.verified_devices.clone(),
1463 verified_master_keys: self.state.verified_master_keys.clone(),
1464 accepted_protocols: self.state.accepted_protocols.clone(),
1465 }),
1466 })
1467 }
1468}
1469
1470impl SasState<Done> {
1471 pub fn as_content(&self) -> OutgoingContent {
1476 get_mac_content(
1477 &self.state.sas.lock(),
1478 &self.ids,
1479 &self.verification_flow_id,
1480 self.state.accepted_protocols.message_auth_code,
1481 )
1482 }
1483
1484 pub fn verified_devices(&self) -> Arc<[DeviceData]> {
1486 self.state.verified_devices.clone()
1487 }
1488
1489 pub fn verified_identities(&self) -> Arc<[UserIdentityData]> {
1491 self.state.verified_master_keys.clone()
1492 }
1493}
1494
1495impl SasState<Cancelled> {
1496 pub fn as_content(&self) -> OutgoingContent {
1497 self.state.as_content(&self.verification_flow_id)
1498 }
1499}
1500
#[cfg(test)]
mod tests {
    // Unit tests for the SAS state machine: protocol negotiation from a start
    // event, the accept/key/MAC exchange between two in-memory accounts, and
    // cancellation on malformed or unexpected events.

    use matrix_sdk_test::async_test;
    use ruma::{
        DeviceId, TransactionId, UserId, device_id,
        events::key::verification::{
            HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
            ShortAuthenticationString,
            accept::{AcceptMethod, ToDeviceKeyVerificationAcceptEventContent},
            start::{
                SasV1Content, SasV1ContentInit, StartMethod,
                ToDeviceKeyVerificationStartEventContent,
            },
        },
        serde::Base64,
        user_id,
    };
    use serde_json::json;

    use super::{Accepted, Created, SasState, Started, SupportedMacMethod, WeAccepted};
    use crate::{
        AcceptedProtocols, Account, DeviceData,
        verification::{
            FlowId,
            event_enums::{AcceptContent, KeyContent, MacContent, StartContent},
        },
    };

    fn alice_id() -> &'static UserId {
        user_id!("@alice:example.org")
    }

    fn alice_device_id() -> &'static DeviceId {
        device_id!("JLAFKJWSCS")
    }

    fn bob_id() -> &'static UserId {
        user_id!("@bob:example.org")
    }

    fn bob_device_id() -> &'static DeviceId {
        device_id!("BOBDEVICE")
    }

    /// Create a fresh SAS pair.
    ///
    /// Alice's side is in the `Created` state. Bob's side is built from
    /// Alice's start content and has accepted it, offering only the emoji
    /// short authentication string and the optionally given MAC method.
    fn get_sas_pair(
        mac_method: Option<SupportedMacMethod>,
    ) -> (SasState<Created>, SasState<WeAccepted>) {
        let alice = Account::with_device_id(alice_id(), alice_device_id());
        let alice_device = DeviceData::from_account(&alice);

        let bob = Account::with_device_id(bob_id(), bob_device_id());
        let bob_device = DeviceData::from_account(&bob);

        let flow_id = TransactionId::new().into();
        let alice_sas = SasState::<Created>::new(
            alice.static_data().clone(),
            bob_device,
            None,
            None,
            flow_id,
            false,
            None,
        );

        let start_content = alice_sas.as_content();
        let flow_id = start_content.flow_id();

        let bob_sas = SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device,
            None,
            None,
            flow_id,
            &start_content.as_start_content(),
            false,
        );
        let bob_sas = bob_sas
            .unwrap()
            .into_we_accepted_with_mac_method(vec![ShortAuthenticationString::Emoji], mac_method);

        (alice_sas, bob_sas)
    }

    #[test]
    fn test_start_content_accepting() {
        let mut start_content: SasV1Content = SasV1ContentInit {
            key_agreement_protocols: vec![
                KeyAgreementProtocol::Curve25519HkdfSha256,
                KeyAgreementProtocol::Curve25519,
            ],
            hashes: vec![HashAlgorithm::Sha256],
            message_authentication_codes: vec![
                // The deprecated variant is deliberately part of the
                // negotiation input under test.
                #[allow(deprecated)]
                MessageAuthenticationCode::HkdfHmacSha256,
                MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
                MessageAuthenticationCode::HkdfHmacSha256V2,
            ],
            short_authentication_string: vec![
                ShortAuthenticationString::Emoji,
                ShortAuthenticationString::Decimal,
            ],
        }
        .into();

        let accepted_protocols = AcceptedProtocols::try_from(&start_content).unwrap();

        assert_eq!(accepted_protocols.message_auth_code, SupportedMacMethod::HkdfHmacSha256V2);
        assert_eq!(
            accepted_protocols.key_agreement_protocol,
            KeyAgreementProtocol::Curve25519HkdfSha256
        );

        start_content.message_authentication_codes = vec![
            // The deprecated variant is deliberately part of the negotiation
            // input under test.
            #[allow(deprecated)]
            MessageAuthenticationCode::HkdfHmacSha256,
            MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"),
        ];
        let accepted_protocols = AcceptedProtocols::try_from(&start_content).unwrap();
        assert_eq!(
            accepted_protocols.message_auth_code,
            SupportedMacMethod::Msc3783HkdfHmacSha256V2
        );

        start_content.key_agreement_protocols = vec![KeyAgreementProtocol::Curve25519];
        AcceptedProtocols::try_from(&start_content)
            .expect_err("We don't support the old Curve25519 key agreement protocol");
    }

    #[test]
    fn test_create_sas() {
        let (_, _) = get_sas_pair(None);
    }

    #[test]
    fn test_sas_accept() {
        let (alice, bob) = get_sas_pair(None);
        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        alice.into_accepted(bob.user_id(), &content).unwrap();
    }

    #[test]
    fn test_sas_key_share() {
        let (alice, bob) = get_sas_pair(None);

        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();
        let content = alice.as_content();
        let transaction_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();
        let alice = alice.into_key_sent(&transaction_id).unwrap();

        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();

        let content = bob.as_content();
        let transaction_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();

        let bob = bob.into_keys_exchanged(&transaction_id).unwrap();

        let alice = alice.into_keys_exchanged(bob.user_id(), &content).unwrap();

        assert_eq!(alice.get_decimal(), bob.get_decimal());
        assert_eq!(alice.get_emoji(), bob.get_emoji());
    }

    /// Run a complete verification between Alice and Bob with the given MAC
    /// method, and check that both sides end up verifying each other's
    /// device.
    fn full_flow_helper(mac_method: SupportedMacMethod) {
        let (alice, bob) = get_sas_pair(Some(mac_method));

        let content = bob.as_content();
        let content = AcceptContent::from(&content);

        assert_eq!(
            bob.state.accepted_protocols.message_auth_code, mac_method,
            "Bob should be using the specified MAC method."
        );

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();

        assert_eq!(
            alice.state.accepted_protocols.message_auth_code, mac_method,
            "Alice should use the specified MAC method.",
        );

        let content = alice.as_content();
        let request_id = content.1.request_id;
        let content = KeyContent::try_from(&content.0).unwrap();

        let alice = alice.into_key_sent(&request_id).unwrap();
        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();

        let (content, request_info) = bob.as_content();
        let request_id = request_info.request_id;
        let content = KeyContent::try_from(&content).unwrap();
        let bob = bob.into_keys_exchanged(&request_id).unwrap();

        let alice = alice.into_keys_exchanged(bob.user_id(), &content).unwrap();

        assert_eq!(alice.get_decimal(), bob.get_decimal());
        assert_eq!(alice.get_emoji(), bob.get_emoji());

        // Capture Bob's SAS before confirming, since the confirmed state no
        // longer exposes it.
        let bob_decimals = bob.get_decimal();
        let bob_emoji = bob.get_emoji();

        let bob = bob.confirm();

        let content = bob.as_content();
        let content = MacContent::try_from(&content).unwrap();

        let alice = alice.into_mac_received(bob.user_id(), &content).unwrap();
        assert_eq!(alice.get_emoji(), bob_emoji);
        assert_eq!(alice.get_decimal(), bob_decimals);
        let alice = alice.confirm();

        let content = alice.as_content();
        let content = MacContent::try_from(&content).unwrap();
        let bob = bob.into_done(alice.user_id(), &content).unwrap();

        assert!(bob.verified_devices().contains(&bob.other_device()));
        assert!(alice.verified_devices().contains(&alice.other_device()));
    }

    #[test]
    fn test_full_flow() {
        full_flow_helper(SupportedMacMethod::HkdfHmacSha256);
    }

    #[test]
    fn test_full_flow_hkdf_hmac_sha_v2() {
        full_flow_helper(SupportedMacMethod::HkdfHmacSha256V2);
    }

    #[test]
    fn test_full_flow_hkdf_msc3783() {
        full_flow_helper(SupportedMacMethod::Msc3783HkdfHmacSha256V2);
    }

    #[test]
    fn test_sas_invalid_commitment() {
        let (alice, bob) = get_sas_pair(None);

        let mut content = bob.as_content();
        let mut method = content.method_mut();

        match &mut method {
            AcceptMethod::SasV1(c) => {
                c.commitment = Base64::empty();
            }
            _ => panic!("Unknown accept event content"),
        }

        let content = AcceptContent::from(&content);

        let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), &content).unwrap();

        let content = alice.as_content();
        let content = KeyContent::try_from(&content.0).unwrap();
        let bob = bob.into_key_received(alice.user_id(), &content).unwrap();
        let content = bob.as_content();
        let content = KeyContent::try_from(&content.0).unwrap();

        alice
            .into_key_received(bob.user_id(), &content)
            .expect_err("Didn't cancel on invalid commitment");
    }

    #[test]
    fn test_sas_invalid_sender() {
        let (alice, bob) = get_sas_pair(None);

        let content = bob.as_content();
        let content = AcceptContent::from(&content);
        let sender = user_id!("@malory:example.org");
        alice.into_accepted(sender, &content).expect_err("Didn't cancel on an invalid sender");
    }

    #[test]
    fn test_sas_unknown_sas_method() {
        let (alice, bob) = get_sas_pair(None);

        let mut content = bob.as_content();
        let mut method = content.method_mut();

        match &mut method {
            AcceptMethod::SasV1(c) => {
                c.short_authentication_string = vec![];
            }
            _ => panic!("Unknown accept event content"),
        }

        let content = AcceptContent::from(&content);

        alice
            .into_accepted(bob.user_id(), &content)
            .expect_err("Didn't cancel on an invalid SAS method");
    }

    #[test]
    fn test_sas_unknown_method() {
        let (alice, bob) = get_sas_pair(None);

        let content = json!({
            "method": "m.sas.custom",
            "method_data": "something",
            "transaction_id": "some_id",
        });

        let content: ToDeviceKeyVerificationAcceptEventContent =
            serde_json::from_value(content).unwrap();
        let content = AcceptContent::from(&content);

        alice
            .into_accepted(bob.user_id(), &content)
            .expect_err("Didn't cancel on an unknown SAS method");
    }

    #[async_test]
    async fn test_sas_from_start_unknown_method() {
        let alice = Account::with_device_id(alice_id(), alice_device_id());
        let alice_device = DeviceData::from_account(&alice);

        let bob = Account::with_device_id(bob_id(), bob_device_id());
        let bob_device = DeviceData::from_account(&bob);

        let flow_id = TransactionId::new().into();
        let alice_sas = SasState::<Created>::new(
            alice.static_data().clone(),
            bob_device,
            None,
            None,
            flow_id,
            false,
            None,
        );

        // A start event offering no MAC methods at all must be rejected.
        let mut start_content = alice_sas.as_content();
        let method = start_content.method_mut();

        match method {
            StartMethod::SasV1(c) => {
                c.message_authentication_codes = vec![];
            }
            _ => panic!("Unknown SAS start method"),
        }

        let flow_id = start_content.flow_id();
        let content = StartContent::from(&start_content);

        SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device.clone(),
            None,
            None,
            flow_id,
            &content,
            false,
        )
        .expect_err("Didn't cancel on invalid MAC method");

        // A start event using an unknown verification method must be
        // rejected as well.
        let content = json!({
            "method": "m.sas.custom",
            "from_device": "DEVICEID",
            "method_data": "something",
            "transaction_id": "some_id",
        });

        let content: ToDeviceKeyVerificationStartEventContent =
            serde_json::from_value(content).unwrap();
        let content = StartContent::from(&content);
        let flow_id = content.flow_id().to_owned();

        SasState::<Started>::from_start_event(
            bob.static_data().clone(),
            alice_device,
            None,
            None,
            FlowId::ToDevice(flow_id.into()),
            &content,
            false,
        )
        .expect_err("Didn't cancel on unknown sas method");
    }
}