matrix_sdk_crypto/verification/
cache.rs1use std::{collections::BTreeMap, sync::Arc};
16
17use as_variant::as_variant;
18use matrix_sdk_common::locks::RwLock as StdRwLock;
19use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId};
20#[cfg(feature = "qrcode")]
21use tracing::debug;
22use tracing::{trace, warn};
23
24use super::{event_enums::OutgoingContent, FlowId, Sas, Verification};
25use crate::types::requests::{
26 OutgoingRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest,
27};
28#[cfg(feature = "qrcode")]
29use crate::QrVerification;
30
/// Cache of the verification flows currently known to this client, together
/// with the outgoing requests those flows have queued up.
///
/// All state lives behind an [`Arc`], so clones of this type share the same
/// underlying cache rather than copying it.
#[derive(Clone, Debug, Default)]
pub struct VerificationCache {
    inner: Arc<VerificationCacheInner>,
}
35
// SAFETY: NOTE(review): no justification is visible in this file. All mutable
// state in `VerificationCacheInner` sits behind `RwLock`s, and this manual impl
// is presumably needed because some stored type (e.g. `Verification` or
// `OutgoingRequest`) is not automatically `Sync`. TODO confirm those types are
// actually safe to share across threads.
// The impl is compiled out under the `test-send-sync` feature so that the
// `test_send_sync_for_room` test can check whether `Send + Sync` holds on its
// own.
#[cfg(not(feature = "test-send-sync"))]
unsafe impl Sync for VerificationCache {}
39
/// Compile-time check that `VerificationCache` is `Send + Sync` without relying
/// on the manual `unsafe impl Sync`, which this feature disables.
#[cfg(feature = "test-send-sync")]
#[test]
fn test_send_sync_for_room() {
    // Instantiating this generic function only type-checks if `T` satisfies
    // both bounds; calling it does nothing at runtime.
    fn require_send_and_sync<T>()
    where
        T: Send + Sync,
    {
    }

    require_send_and_sync::<VerificationCache>();
}
48
/// Shared state behind [`VerificationCache`]; each map is guarded by its own
/// lock.
#[derive(Debug, Default)]
struct VerificationCacheInner {
    /// Verification flows, keyed first by the other user's ID and then by the
    /// flow ID string.
    verification: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<String, Verification>>>,
    /// Requests waiting to be sent, keyed by their request (transaction) ID.
    outgoing_requests: StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>,
    /// Flows that want to be notified once a given request has been sent:
    /// request ID -> (recipient user ID, flow ID).
    flow_ids_waiting_for_response: StdRwLock<BTreeMap<OwnedTransactionId, (OwnedUserId, FlowId)>>,
}
55
/// Associates an outgoing request with the verification flow that produced it,
/// so the flow can be notified when the request has been sent.
#[derive(Debug)]
pub struct RequestInfo {
    /// The verification flow the request belongs to.
    pub flow_id: FlowId,
    /// The ID to use for the outgoing request.
    pub request_id: OwnedTransactionId,
}
61
62impl VerificationCache {
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 #[cfg(test)]
68 #[allow(dead_code)]
69 pub fn is_empty(&self) -> bool {
70 self.inner.verification.read().values().all(|m| m.is_empty())
71 }
72
73 pub fn insert(&self, verification: impl Into<Verification>) {
77 let verification = verification.into();
78
79 let mut verification_write_guard = self.inner.verification.write();
80 let user_verifications =
81 verification_write_guard.entry(verification.other_user().to_owned()).or_default();
82
83 for old_verification in user_verifications.values() {
87 if !old_verification.is_cancelled() {
88 warn!(
89 user_id = ?verification.other_user(),
90 old_flow_id = old_verification.flow_id(),
91 new_flow_id = verification.flow_id(),
92 "Received a new verification whilst another one with \
93 the same user is ongoing. Cancelling both verifications"
94 );
95
96 if let Some(r) = old_verification.cancel() {
97 self.add_request(r.into())
98 }
99
100 if let Some(r) = verification.cancel() {
101 self.add_request(r.into())
102 }
103 }
104 }
105
106 user_verifications.insert(verification.flow_id().to_owned(), verification);
110 }
111
112 pub fn insert_sas(&self, sas: Sas) {
113 self.insert(sas);
114 }
115
116 pub fn replace_sas(&self, sas: Sas) {
117 let verification: Verification = sas.into();
118 self.replace(verification);
119 }
120
121 #[cfg(feature = "qrcode")]
122 pub fn insert_qr(&self, qr: QrVerification) {
123 debug!(
124 user_id = ?qr.other_user_id(),
125 flow_id = qr.flow_id().as_str(),
126 "Inserting new QR verification"
127 );
128 self.insert(qr)
129 }
130
131 #[cfg(feature = "qrcode")]
132 pub fn replace_qr(&self, qr: QrVerification) {
133 debug!(
134 user_id = ?qr.other_user_id(),
135 flow_id = qr.flow_id().as_str(),
136 "Replacing existing QR verification"
137 );
138 let verification: Verification = qr.into();
139 self.replace(verification);
140 }
141
142 #[cfg(feature = "qrcode")]
143 pub fn get_qr(&self, sender: &UserId, flow_id: &str) -> Option<QrVerification> {
144 self.get(sender, flow_id).and_then(as_variant!(Verification::QrV1))
145 }
146
147 pub fn replace(&self, verification: Verification) {
148 self.inner
149 .verification
150 .write()
151 .entry(verification.other_user().to_owned())
152 .or_default()
153 .insert(verification.flow_id().to_owned(), verification.clone());
154 }
155
156 pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
157 self.inner.verification.read().get(sender)?.get(flow_id).cloned()
158 }
159
160 pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
161 self.inner.outgoing_requests.read().values().cloned().collect()
162 }
163
164 pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
165 let verification = &mut self.inner.verification.write();
166
167 for user_verification in verification.values_mut() {
168 user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
169 }
170
171 verification.retain(|_, m| !m.is_empty());
172
173 verification
174 .values()
175 .flat_map(BTreeMap::values)
176 .filter_map(|s| as_variant!(s, Verification::SasV1)?.cancel_if_timed_out())
177 .collect()
178 }
179
180 pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
181 self.get(user_id, flow_id).and_then(as_variant!(Verification::SasV1))
182 }
183
184 pub fn add_request(&self, request: OutgoingRequest) {
185 trace!("Adding an outgoing request {:?}", request);
186 self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
187 }
188
189 pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
190 let request = OutgoingRequest {
191 request_id: request.request_id().to_owned(),
192 request: Arc::new(request.into()),
193 };
194 self.add_request(request);
195 }
196
197 pub fn queue_up_content(
198 &self,
199 recipient: &UserId,
200 recipient_device: &DeviceId,
201 content: OutgoingContent,
202 request_info: Option<RequestInfo>,
203 ) {
204 let request_id = if let Some(request_info) = request_info {
205 trace!(
206 ?recipient,
207 ?request_info,
208 "Storing the request info, waiting for the request to be marked as sent"
209 );
210
211 self.inner.flow_ids_waiting_for_response.write().insert(
212 request_info.request_id.to_owned(),
213 (recipient.to_owned(), request_info.flow_id),
214 );
215 request_info.request_id
216 } else {
217 TransactionId::new()
218 };
219
220 match content {
221 OutgoingContent::ToDevice(c) => {
222 let request = ToDeviceRequest::with_id(
223 recipient,
224 recipient_device.to_owned(),
225 &c,
226 request_id,
227 );
228 let request_id = request.txn_id.clone();
229
230 let request = OutgoingRequest {
231 request_id: request_id.clone(),
232 request: Arc::new(request.into()),
233 };
234
235 self.inner.outgoing_requests.write().insert(request_id, request);
236 }
237
238 OutgoingContent::Room(r, c) => {
239 let request = OutgoingRequest {
240 request: Arc::new(
241 RoomMessageRequest { room_id: r, txn_id: request_id.clone(), content: c }
242 .into(),
243 ),
244 request_id: request_id.clone(),
245 };
246
247 self.inner.outgoing_requests.write().insert(request_id, request);
248 }
249 }
250 }
251
252 pub fn mark_request_as_sent(&self, request_id: &TransactionId) {
253 if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) {
254 trace!(?request_id, "Marking a verification HTTP request as sent");
255 }
256
257 if let Some((user_id, flow_id)) =
258 self.inner.flow_ids_waiting_for_response.read().get(request_id)
259 {
260 if let Some(verification) = self.get(user_id, flow_id.as_str()) {
261 match verification {
262 Verification::SasV1(s) => s.mark_request_as_sent(request_id),
263 #[cfg(feature = "qrcode")]
264 Verification::QrV1(_) => (),
265 }
266 }
267 }
268 }
269}