matrix_sdk_crypto/verification/
cache.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeMap, sync::Arc};
16
17use as_variant::as_variant;
18use matrix_sdk_common::locks::RwLock as StdRwLock;
19use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId};
20#[cfg(feature = "qrcode")]
21use tracing::debug;
22use tracing::{trace, warn};
23
24use super::{event_enums::OutgoingContent, FlowId, Sas, Verification};
25use crate::types::requests::{
26    OutgoingRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest,
27};
28#[cfg(feature = "qrcode")]
29use crate::QrVerification;
30
/// Handle to the in-memory store of ongoing verifications and the outgoing
/// requests they produce.
///
/// The state lives behind an [`Arc`], so cloning this handle yields another
/// reference to the same underlying cache rather than a copy of its contents.
#[derive(Clone, Debug, Default)]
pub struct VerificationCache {
    /// The shared state that every clone of this handle points at.
    inner: Arc<VerificationCacheInner>,
}
35
// See https://github.com/matrix-org/matrix-rust-sdk/pull/3749#issuecomment-2312939823.
//
// This manual impl is compiled out when the `test-send-sync` feature is
// enabled; presumably so that the `Send + Sync` assertion gated on that
// feature exercises the auto-derived traits instead of this override.
//
// NOTE(review): no `SAFETY:` justification is recorded next to this
// `unsafe impl`; the invariant that makes it sound is assumed to be covered by
// the linked discussion — confirm and document it here.
#[cfg(not(feature = "test-send-sync"))]
unsafe impl Sync for VerificationCache {}
39
#[cfg(feature = "test-send-sync")]
#[test]
// See https://github.com/matrix-org/matrix-rust-sdk/pull/3749#issuecomment-2312939823.
//
// Compile-time check: the test body does nothing at runtime, it only has to
// type-check for `VerificationCache` to be proven `Send + Sync`.
fn test_send_sync_for_room() {
    fn assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    assert_send_sync::<VerificationCache>();
}
48
/// The lock-protected state shared by all clones of a [`VerificationCache`].
#[derive(Debug, Default)]
struct VerificationCacheInner {
    /// Known verifications, grouped by the other user taking part and then
    /// keyed by the flow ID of the verification.
    verification: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<String, Verification>>>,
    /// Requests that are queued up to be sent out, keyed by their request ID.
    outgoing_requests: StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>,
    /// Maps the ID of a queued request to the recipient and verification flow
    /// it belongs to, so the matching verification can be looked up once the
    /// request has been marked as sent.
    flow_ids_waiting_for_response: StdRwLock<BTreeMap<OwnedTransactionId, (OwnedUserId, FlowId)>>,
}
55
/// Ties an outgoing request to the verification flow it was created for.
///
/// Passed to [`VerificationCache::queue_up_content`] when the cache should
/// remember which flow a request belongs to.
#[derive(Debug)]
pub struct RequestInfo {
    /// The flow ID of the verification the request belongs to.
    pub flow_id: FlowId,
    /// The ID that the queued outgoing request will be sent with.
    pub request_id: OwnedTransactionId,
}
61
62impl VerificationCache {
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    #[cfg(test)]
68    #[allow(dead_code)]
69    pub fn is_empty(&self) -> bool {
70        self.inner.verification.read().values().all(|m| m.is_empty())
71    }
72
73    /// Add a new `Verification` object to the cache, this will cancel any
74    /// duplicates we have going on, including the newly inserted one, with a
75    /// given user.
76    pub fn insert(&self, verification: impl Into<Verification>) {
77        let verification = verification.into();
78
79        let mut verification_write_guard = self.inner.verification.write();
80        let user_verifications =
81            verification_write_guard.entry(verification.other_user().to_owned()).or_default();
82
83        // Cancel all the old verifications as well as the new one we have for
84        // this user if someone tries to have two verifications going on at
85        // once.
86        for old_verification in user_verifications.values() {
87            if !old_verification.is_cancelled() {
88                warn!(
89                    user_id = ?verification.other_user(),
90                    old_flow_id = old_verification.flow_id(),
91                    new_flow_id = verification.flow_id(),
92                    "Received a new verification whilst another one with \
93                    the same user is ongoing. Cancelling both verifications"
94                );
95
96                if let Some(r) = old_verification.cancel() {
97                    self.add_request(r.into())
98                }
99
100                if let Some(r) = verification.cancel() {
101                    self.add_request(r.into())
102                }
103            }
104        }
105
106        // We still want to add the new verification, in case users want to
107        // inspect the verification object a matching `m.key.verification.start`
108        // produced.
109        user_verifications.insert(verification.flow_id().to_owned(), verification);
110    }
111
112    pub fn insert_sas(&self, sas: Sas) {
113        self.insert(sas);
114    }
115
116    pub fn replace_sas(&self, sas: Sas) {
117        let verification: Verification = sas.into();
118        self.replace(verification);
119    }
120
121    #[cfg(feature = "qrcode")]
122    pub fn insert_qr(&self, qr: QrVerification) {
123        debug!(
124            user_id = ?qr.other_user_id(),
125            flow_id = qr.flow_id().as_str(),
126            "Inserting new QR verification"
127        );
128        self.insert(qr)
129    }
130
131    #[cfg(feature = "qrcode")]
132    pub fn replace_qr(&self, qr: QrVerification) {
133        debug!(
134            user_id = ?qr.other_user_id(),
135            flow_id = qr.flow_id().as_str(),
136            "Replacing existing QR verification"
137        );
138        let verification: Verification = qr.into();
139        self.replace(verification);
140    }
141
142    #[cfg(feature = "qrcode")]
143    pub fn get_qr(&self, sender: &UserId, flow_id: &str) -> Option<QrVerification> {
144        self.get(sender, flow_id).and_then(as_variant!(Verification::QrV1))
145    }
146
147    pub fn replace(&self, verification: Verification) {
148        self.inner
149            .verification
150            .write()
151            .entry(verification.other_user().to_owned())
152            .or_default()
153            .insert(verification.flow_id().to_owned(), verification.clone());
154    }
155
156    pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
157        self.inner.verification.read().get(sender)?.get(flow_id).cloned()
158    }
159
160    pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
161        self.inner.outgoing_requests.read().values().cloned().collect()
162    }
163
164    pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
165        let verification = &mut self.inner.verification.write();
166
167        for user_verification in verification.values_mut() {
168            user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
169        }
170
171        verification.retain(|_, m| !m.is_empty());
172
173        verification
174            .values()
175            .flat_map(BTreeMap::values)
176            .filter_map(|s| as_variant!(s, Verification::SasV1)?.cancel_if_timed_out())
177            .collect()
178    }
179
180    pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
181        self.get(user_id, flow_id).and_then(as_variant!(Verification::SasV1))
182    }
183
184    pub fn add_request(&self, request: OutgoingRequest) {
185        trace!("Adding an outgoing request {:?}", request);
186        self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
187    }
188
189    pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
190        let request = OutgoingRequest {
191            request_id: request.request_id().to_owned(),
192            request: Arc::new(request.into()),
193        };
194        self.add_request(request);
195    }
196
197    pub fn queue_up_content(
198        &self,
199        recipient: &UserId,
200        recipient_device: &DeviceId,
201        content: OutgoingContent,
202        request_info: Option<RequestInfo>,
203    ) {
204        let request_id = if let Some(request_info) = request_info {
205            trace!(
206                ?recipient,
207                ?request_info,
208                "Storing the request info, waiting for the request to be marked as sent"
209            );
210
211            self.inner.flow_ids_waiting_for_response.write().insert(
212                request_info.request_id.to_owned(),
213                (recipient.to_owned(), request_info.flow_id),
214            );
215            request_info.request_id
216        } else {
217            TransactionId::new()
218        };
219
220        match content {
221            OutgoingContent::ToDevice(c) => {
222                let request = ToDeviceRequest::with_id(
223                    recipient,
224                    recipient_device.to_owned(),
225                    &c,
226                    request_id,
227                );
228                let request_id = request.txn_id.clone();
229
230                let request = OutgoingRequest {
231                    request_id: request_id.clone(),
232                    request: Arc::new(request.into()),
233                };
234
235                self.inner.outgoing_requests.write().insert(request_id, request);
236            }
237
238            OutgoingContent::Room(r, c) => {
239                let request = OutgoingRequest {
240                    request: Arc::new(
241                        RoomMessageRequest { room_id: r, txn_id: request_id.clone(), content: c }
242                            .into(),
243                    ),
244                    request_id: request_id.clone(),
245                };
246
247                self.inner.outgoing_requests.write().insert(request_id, request);
248            }
249        }
250    }
251
252    pub fn mark_request_as_sent(&self, request_id: &TransactionId) {
253        if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) {
254            trace!(?request_id, "Marking a verification HTTP request as sent");
255        }
256
257        if let Some((user_id, flow_id)) =
258            self.inner.flow_ids_waiting_for_response.read().get(request_id)
259        {
260            if let Some(verification) = self.get(user_id, flow_id.as_str()) {
261                match verification {
262                    Verification::SasV1(s) => s.mark_request_as_sent(request_id),
263                    #[cfg(feature = "qrcode")]
264                    Verification::QrV1(_) => (),
265                }
266            }
267        }
268    }
269}