matrix_sdk_ffi/session_verification.rs

1use std::sync::{Arc, RwLock};
2
3use async_compat::get_runtime_handle;
4use futures_util::StreamExt;
5use matrix_sdk::{
6    encryption::{
7        identities::UserIdentity,
8        verification::{SasState, SasVerification, VerificationRequest, VerificationRequestState},
9        Encryption,
10    },
11    ruma::events::key::verification::VerificationMethod,
12    Account,
13};
14use ruma::UserId;
15use tracing::{error, warn};
16
17use crate::{client::UserProfile, error::ClientError, utils::Timestamp};
18
19#[derive(uniffi::Object)]
20pub struct SessionVerificationEmoji {
21    symbol: String,
22    description: String,
23}
24
25#[matrix_sdk_ffi_macros::export]
26impl SessionVerificationEmoji {
27    pub fn symbol(&self) -> String {
28        self.symbol.clone()
29    }
30
31    pub fn description(&self) -> String {
32        self.description.clone()
33    }
34}
35
36#[derive(uniffi::Enum)]
37pub enum SessionVerificationData {
38    Emojis { emojis: Vec<Arc<SessionVerificationEmoji>>, indices: Vec<u8> },
39    Decimals { values: Vec<u16> },
40}
41
42/// Details about the incoming verification request
43#[derive(uniffi::Record)]
44pub struct SessionVerificationRequestDetails {
45    sender_profile: UserProfile,
46    flow_id: String,
47    device_id: String,
48    device_display_name: Option<String>,
49    /// First time this device was seen in milliseconds since epoch.
50    first_seen_timestamp: Timestamp,
51}
52
53#[matrix_sdk_ffi_macros::export(callback_interface)]
54pub trait SessionVerificationControllerDelegate: Sync + Send {
55    fn did_receive_verification_request(&self, details: SessionVerificationRequestDetails);
56    fn did_accept_verification_request(&self);
57    fn did_start_sas_verification(&self);
58    fn did_receive_verification_data(&self, data: SessionVerificationData);
59    fn did_fail(&self);
60    fn did_cancel(&self);
61    fn did_finish(&self);
62}
63
64pub type Delegate = Arc<RwLock<Option<Box<dyn SessionVerificationControllerDelegate>>>>;
65
66#[derive(Clone, uniffi::Object)]
67pub struct SessionVerificationController {
68    encryption: Encryption,
69    user_identity: UserIdentity,
70    account: Account,
71    delegate: Delegate,
72    verification_request: Arc<RwLock<Option<VerificationRequest>>>,
73    sas_verification: Arc<RwLock<Option<SasVerification>>>,
74}
75
76#[matrix_sdk_ffi_macros::export]
77impl SessionVerificationController {
78    pub fn set_delegate(&self, delegate: Option<Box<dyn SessionVerificationControllerDelegate>>) {
79        *self.delegate.write().unwrap() = delegate;
80    }
81
82    /// Set this particular request as the currently active one and register for
83    /// events pertaining it.
84    /// * `sender_id` - The user requesting verification.
85    /// * `flow_id` - - The ID that uniquely identifies the verification flow.
86    pub async fn acknowledge_verification_request(
87        &self,
88        sender_id: String,
89        flow_id: String,
90    ) -> Result<(), ClientError> {
91        let sender_id = UserId::parse(sender_id.clone())?;
92
93        let verification_request = self
94            .encryption
95            .get_verification_request(&sender_id, flow_id)
96            .await
97            .ok_or(ClientError::new("Unknown session verification request"))?;
98
99        self.set_ongoing_verification_request(verification_request)
100    }
101
102    /// Accept the previously acknowledged verification request
103    pub async fn accept_verification_request(&self) -> Result<(), ClientError> {
104        let verification_request = self.verification_request.read().unwrap().clone();
105
106        if let Some(verification_request) = verification_request {
107            let methods = vec![VerificationMethod::SasV1];
108            verification_request.accept_with_methods(methods).await?;
109        }
110
111        Ok(())
112    }
113
114    /// Request verification for the current device
115    pub async fn request_device_verification(&self) -> Result<(), ClientError> {
116        let methods = vec![VerificationMethod::SasV1];
117        let verification_request = self
118            .user_identity
119            .request_verification_with_methods(methods)
120            .await
121            .map_err(anyhow::Error::from)?;
122
123        self.set_ongoing_verification_request(verification_request)
124    }
125
126    /// Request verification for the given user
127    pub async fn request_user_verification(&self, user_id: String) -> Result<(), ClientError> {
128        let user_id = UserId::parse(user_id)?;
129
130        let user_identity = self
131            .encryption
132            .get_user_identity(&user_id)
133            .await?
134            .ok_or(ClientError::new("Unknown user identity"))?;
135
136        if user_identity.is_verified() {
137            return Err(ClientError::new("User is already verified"));
138        }
139
140        let methods = vec![VerificationMethod::SasV1];
141
142        let verification_request = user_identity
143            .request_verification_with_methods(methods)
144            .await
145            .map_err(anyhow::Error::from)?;
146
147        self.set_ongoing_verification_request(verification_request)
148    }
149
150    /// Transition the current verification request into a SAS verification
151    /// flow.
152    pub async fn start_sas_verification(&self) -> Result<(), ClientError> {
153        let verification_request = self.verification_request.read().unwrap().clone();
154
155        let Some(verification_request) = verification_request else {
156            return Err(ClientError::new("Verification request missing."));
157        };
158
159        match verification_request.start_sas().await {
160            Ok(Some(verification)) => {
161                *self.sas_verification.write().unwrap() = Some(verification.clone());
162
163                if let Some(delegate) = &*self.delegate.read().unwrap() {
164                    delegate.did_start_sas_verification()
165                }
166
167                let delegate = self.delegate.clone();
168                get_runtime_handle()
169                    .spawn(Self::listen_to_sas_verification_changes(verification, delegate));
170            }
171            _ => {
172                if let Some(delegate) = &*self.delegate.read().unwrap() {
173                    delegate.did_fail()
174                }
175            }
176        }
177
178        Ok(())
179    }
180
181    /// Confirm that the short auth strings match on both sides.
182    pub async fn approve_verification(&self) -> Result<(), ClientError> {
183        let sas_verification = self.sas_verification.read().unwrap().clone();
184
185        let Some(sas_verification) = sas_verification else {
186            return Err(ClientError::new("SAS verification missing"));
187        };
188
189        Ok(sas_verification.confirm().await?)
190    }
191
192    /// Reject the short auth string
193    pub async fn decline_verification(&self) -> Result<(), ClientError> {
194        let sas_verification = self.sas_verification.read().unwrap().clone();
195
196        let Some(sas_verification) = sas_verification else {
197            return Err(ClientError::new("SAS verification missing"));
198        };
199
200        Ok(sas_verification.mismatch().await?)
201    }
202
203    /// Cancel the current verification request
204    pub async fn cancel_verification(&self) -> Result<(), ClientError> {
205        let verification_request = self.verification_request.read().unwrap().clone();
206
207        let Some(verification_request) = verification_request else {
208            return Err(ClientError::new("Verification request missing."));
209        };
210
211        Ok(verification_request.cancel().await?)
212    }
213}
214
215impl SessionVerificationController {
216    pub(crate) fn new(
217        encryption: Encryption,
218        user_identity: UserIdentity,
219        account: Account,
220    ) -> Self {
221        SessionVerificationController {
222            encryption,
223            user_identity,
224            account,
225            delegate: Arc::new(RwLock::new(None)),
226            verification_request: Arc::new(RwLock::new(None)),
227            sas_verification: Arc::new(RwLock::new(None)),
228        }
229    }
230
231    /// Ask the controller to process an incoming request based on the sender
232    /// and flow identifier. It will fetch the request, verify that it's in the
233    /// correct state and then and notify the delegate.
234    pub(crate) async fn process_incoming_verification_request(
235        &self,
236        sender: &UserId,
237        flow_id: impl AsRef<str>,
238    ) {
239        if sender != self.user_identity.user_id() {
240            if let Some(status) = self.encryption.cross_signing_status().await {
241                if !status.is_complete() {
242                    warn!("Cannot verify other users until our own device's cross-signing status is complete: {:?}", status);
243                    return;
244                }
245            }
246        }
247
248        let Some(request) = self.encryption.get_verification_request(sender, flow_id).await else {
249            error!("Failed retrieving verification request");
250            return;
251        };
252
253        let VerificationRequestState::Requested { other_device_data, .. } = request.state() else {
254            error!("Received verification request event but the request is in the wrong state.");
255            return;
256        };
257
258        let Ok(sender_profile) = self.account.fetch_user_profile_of(sender).await else {
259            error!("Failed fetching user profile for verification request");
260            return;
261        };
262
263        if let Some(delegate) = &*self.delegate.read().unwrap() {
264            delegate.did_receive_verification_request(SessionVerificationRequestDetails {
265                sender_profile: UserProfile {
266                    user_id: request.other_user_id().to_string(),
267                    display_name: sender_profile.displayname,
268                    avatar_url: sender_profile.avatar_url.as_ref().map(|url| url.to_string()),
269                },
270                flow_id: request.flow_id().into(),
271                device_id: other_device_data.device_id().into(),
272                device_display_name: other_device_data.display_name().map(str::to_string),
273                first_seen_timestamp: other_device_data.first_time_seen_ts().into(),
274            });
275        }
276    }
277
278    fn set_ongoing_verification_request(
279        &self,
280        verification_request: VerificationRequest,
281    ) -> Result<(), ClientError> {
282        if let Some(ongoing_verification_request) =
283            self.verification_request.read().unwrap().clone()
284        {
285            if !ongoing_verification_request.is_done()
286                && !ongoing_verification_request.is_cancelled()
287            {
288                return Err(ClientError::new("There is another verification flow ongoing."));
289            }
290        }
291
292        *self.verification_request.write().unwrap() = Some(verification_request.clone());
293
294        get_runtime_handle().spawn(Self::listen_to_verification_request_changes(
295            verification_request,
296            self.sas_verification.clone(),
297            self.delegate.clone(),
298        ));
299
300        Ok(())
301    }
302
303    async fn listen_to_verification_request_changes(
304        verification_request: VerificationRequest,
305        sas_verification: Arc<RwLock<Option<SasVerification>>>,
306        delegate: Delegate,
307    ) {
308        let mut stream = verification_request.changes();
309
310        while let Some(state) = stream.next().await {
311            match state {
312                VerificationRequestState::Transitioned { verification } => {
313                    let Some(verification) = verification.sas() else {
314                        error!("Invalid, non-sas verification flow. Returning.");
315                        return;
316                    };
317
318                    *sas_verification.write().unwrap() = Some(verification.clone());
319
320                    if verification.accept().await.is_ok() {
321                        if let Some(delegate) = &*delegate.read().unwrap() {
322                            delegate.did_start_sas_verification()
323                        }
324
325                        let delegate = delegate.clone();
326                        get_runtime_handle().spawn(Self::listen_to_sas_verification_changes(
327                            verification,
328                            delegate,
329                        ));
330                    } else if let Some(delegate) = &*delegate.read().unwrap() {
331                        delegate.did_fail()
332                    }
333                }
334                VerificationRequestState::Ready { .. } => {
335                    if let Some(delegate) = &*delegate.read().unwrap() {
336                        delegate.did_accept_verification_request()
337                    }
338                }
339                VerificationRequestState::Cancelled(..) => {
340                    if let Some(delegate) = &*delegate.read().unwrap() {
341                        delegate.did_cancel();
342                    }
343                }
344                _ => {}
345            }
346        }
347    }
348
349    async fn listen_to_sas_verification_changes(sas: SasVerification, delegate: Delegate) {
350        let mut stream = sas.changes();
351
352        while let Some(state) = stream.next().await {
353            match state {
354                SasState::KeysExchanged { emojis, decimals } => {
355                    if let Some(delegate) = &*delegate.read().unwrap() {
356                        if let Some(emojis) = emojis {
357                            delegate.did_receive_verification_data(
358                                SessionVerificationData::Emojis {
359                                    emojis: emojis
360                                        .emojis
361                                        .into_iter()
362                                        .map(|emoji| {
363                                            Arc::new(SessionVerificationEmoji {
364                                                symbol: emoji.symbol.to_owned(),
365                                                description: emoji.description.to_owned(),
366                                            })
367                                        })
368                                        .collect(),
369                                    indices: emojis.indices.to_vec(),
370                                },
371                            );
372                        } else {
373                            delegate.did_receive_verification_data(
374                                SessionVerificationData::Decimals {
375                                    values: vec![decimals.0, decimals.1, decimals.2],
376                                },
377                            )
378                        }
379                    }
380                }
381                SasState::Done { .. } => {
382                    if let Some(delegate) = &*delegate.read().unwrap() {
383                        delegate.did_finish()
384                    }
385                    break;
386                }
387                SasState::Cancelled(_cancel_info) => {
388                    // TODO: The cancel_info is usable, we should tell the user why we were
389                    // cancelled.
390                    if let Some(delegate) = &*delegate.read().unwrap() {
391                        delegate.did_cancel()
392                    }
393                    break;
394                }
395                SasState::Created { .. }
396                | SasState::Started { .. }
397                | SasState::Accepted { .. }
398                | SasState::Confirmed => (),
399            }
400        }
401    }
402}