1use std::sync::Arc;
16
17use api::{
18 download::{
19 encrypted::DownloadAndScanEncryptedMediaRequest, unencrypted::DownloadAndScanMediaRequest,
20 },
21 public_server_key::PublicServerKeyRequest,
22};
23use matrix_sdk::{
24 BoxFuture, Client, Error, IdParseError,
25 encryption::vodozemac::pk_encryption::Message,
26 locks::Mutex,
27 media::{MediaFetcher, MediaRequestParameters},
28 ruma::events::room::MediaSource,
29};
30use matrix_sdk_crypto::olm::Curve25519PublicKey;
31use ruma::{
32 events::room::EncryptedFile,
33 serde::{Base64, base64::Standard},
34};
35use serde::{Deserialize, Serialize};
36use tracing::trace;
37
// Generate the UniFFI scaffolding for this crate's exported types (such as
// `ErrorReason`, which derives `uniffi::Enum`) when the `uniffi` feature is on.
#[cfg(feature = "uniffi")]
uniffi::setup_scaffolding!();
40
41pub use crate::api::scan::MediaScanResponse;
42use crate::api::{
43 DownloadAndScanMediaResponse,
44 scan::{encrypted::EncryptedMediaScanRequest, unencrypted::MediaScanRequest},
45};
46
47mod api;
48
/// A client for a content scanner service, used to download and scan Matrix
/// media (plain or encrypted) through that service.
#[derive(Debug)]
pub struct ContentScanner {
    /// Base URL of the content scanner; it is passed to every request this
    /// type builds.
    scanner_url: String,
    /// Lazily populated cache of the scanner's base64-encoded public key.
    /// Kept behind a mutex so it can be filled in from `&self`.
    public_server_key: Arc<Mutex<Option<String>>>,
}
55
56impl ContentScanner {
57 pub fn new(scanner_url: impl Into<String>) -> Self {
59 Self { scanner_url: scanner_url.into(), public_server_key: Arc::new(Mutex::new(None)) }
60 }
61
62 pub(crate) async fn fetch_public_server_key(&self, client: &Client) -> Result<String, Error> {
63 let response = client.send(PublicServerKeyRequest::new(self.scanner_url.clone())).await?;
64 Ok(response.public_key)
65 }
66
67 async fn get_or_fetch_public_server_key(&self, client: &Client) -> Option<Curve25519PublicKey> {
68 let public_server_key =
69 if let Some(public_server_key) = (*self.public_server_key.lock()).clone() {
70 trace!("Using cached public server key");
71 Some(public_server_key)
72 } else {
73 trace!("Using cached public server key");
74 let ret = self.fetch_public_server_key(client).await.ok();
75
76 if let Some(public_server_key) = &ret {
77 trace!("Saved new public server key");
78 let mut guard = self.public_server_key.lock();
79 let _ = guard.insert(public_server_key.clone());
80 }
81
82 ret
83 };
84
85 public_server_key.and_then(|key| Curve25519PublicKey::from_base64(&key).ok())
86 }
87
88 pub(crate) async fn get_media(
89 &self,
90 client: &Client,
91 media_source: &MediaSource,
92 ) -> Result<DownloadAndScanMediaResponse, Error> {
93 match &media_source {
94 MediaSource::Encrypted(encrypted) => {
95 let public_server_key = self.get_or_fetch_public_server_key(client).await;
97
98 Ok(client
99 .send(DownloadAndScanEncryptedMediaRequest::new(
100 self.scanner_url.clone(),
101 public_server_key,
102 *encrypted.clone(),
103 ))
104 .await?)
105 }
106 MediaSource::Plain(mxc) => {
107 let (server_name, media_id) =
108 mxc.parts().map_err(|e| Error::Identifier(IdParseError::InvalidMxcUri(e)))?;
109 Ok(client
110 .send(DownloadAndScanMediaRequest::new(
111 &self.scanner_url,
112 server_name.as_str(),
113 media_id,
114 ))
115 .await?)
116 }
117 }
118 }
119
120 pub async fn scan(
123 &self,
124 client: &Client,
125 media_source: &MediaSource,
126 ) -> Result<MediaScanResponse, Error> {
127 match &media_source {
128 MediaSource::Encrypted(encrypted) => {
129 let public_server_key = self.get_or_fetch_public_server_key(client).await;
131
132 Ok(client
133 .send(EncryptedMediaScanRequest::new(
134 self.scanner_url.clone(),
135 public_server_key,
136 *encrypted.clone(),
137 ))
138 .await?)
139 }
140 MediaSource::Plain(mxc) => {
141 let (server_name, media_id) =
142 mxc.parts().map_err(|e| Error::Identifier(IdParseError::InvalidMxcUri(e)))?;
143 Ok(client
144 .send(MediaScanRequest::new(
145 self.scanner_url.clone(),
146 server_name.to_string(),
147 media_id.to_owned(),
148 ))
149 .await?)
150 }
151 }
152 }
153}
154
/// JSON wire form of a vodozemac `pk_encryption` [`Message`].
#[derive(Debug, Clone, Serialize)]
struct EncryptedBody {
    /// The message ciphertext, standard base64.
    ciphertext: String,
    /// The message MAC, standard base64.
    mac: String,
    /// The sender's ephemeral Curve25519 public key, base64 as produced by
    /// `Curve25519PublicKey::to_base64`.
    ephemeral: String,
}
161
162impl From<Message> for EncryptedBody {
163 fn from(value: Message) -> Self {
164 Self {
165 ciphertext: Base64::<Standard>::new(value.ciphertext).to_string(),
166 mac: Base64::<Standard>::new(value.mac).to_string(),
167 ephemeral: value.ephemeral_key.to_base64(),
168 }
169 }
170}
171
/// JSON body for encrypted-media requests.
///
/// The `from_*` constructors set exactly one of the two fields. Fields that
/// are `None` are left out of the serialized JSON entirely.
#[derive(Debug, Clone, Serialize)]
pub(crate) struct EncryptedFileRequest {
    /// The encrypted file's metadata, sent without an extra encryption layer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file: Option<EncryptedFile>,
    /// An encrypted payload; presumably it wraps the file metadata encrypted
    /// to the scanner's public key — confirm in the `api` module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encrypted_body: Option<EncryptedBody>,
}
179
180impl EncryptedFileRequest {
181 pub(crate) fn from_file_info(file_info: EncryptedFile) -> Self {
182 Self { file: Some(file_info), encrypted_body: None }
183 }
184
185 pub(crate) fn from_encrypted_body(encrypted_body: EncryptedBody) -> Self {
186 Self { file: None, encrypted_body: Some(encrypted_body) }
187 }
188}
189
190pub struct ContentScannerMediaFetcher {
192 pub content_scanner: ContentScanner,
193}
194
195impl ContentScannerMediaFetcher {
196 pub fn new(scanner_url: impl Into<String>) -> Self {
198 Self { content_scanner: ContentScanner::new(scanner_url.into()) }
199 }
200}
201
202impl MediaFetcher for ContentScannerMediaFetcher {
203 fn fetch_media_content<'a>(
204 &'a self,
205 client: &'a Client,
206 request: &'a MediaRequestParameters,
207 ) -> BoxFuture<'a, matrix_sdk::Result<Vec<u8>, Error>> {
208 Box::pin(async move {
209 Ok(self.content_scanner.get_media(client, &request.source).await?.content)
210 })
211 }
212}
213
/// Error body returned by the content scanner, for example
/// `{"info": "***VIRUS DETECTED***", "reason": "MCS_MEDIA_NOT_CLEAN"}`.
#[derive(Debug, Deserialize)]
pub struct ContentScannerError {
    /// Free-form description of the failure, e.g.
    /// `File type: application/octet-stream not allowed`.
    pub info: String,
    /// Machine-readable error code.
    pub reason: ErrorReason,
}
220
/// Error code in the `reason` field of a [`ContentScannerError`].
///
/// Descriptions marked "presumably" are inferred from the code names only;
/// confirm them against the content scanner's API documentation.
// The variant names intentionally match the wire codes (e.g.
// "MCS_MEDIA_NOT_CLEAN"), which the derived `Deserialize` matches by variant
// name; hence the non-camel-case allowance.
#[allow(non_camel_case_types)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
#[derive(Clone, Debug, Deserialize)]
pub enum ErrorReason {
    /// Presumably: the request body was not valid JSON.
    MCS_MALFORMED_JSON,
    /// Presumably: the scanner could not decrypt the media.
    MCS_MEDIA_FAILED_TO_DECRYPT,
    /// Presumably the Matrix `M_MISSING_TOKEN` code: no access token was sent.
    M_MISSING_TOKEN,
    /// Presumably the Matrix `M_UNKNOWN_TOKEN` code: the access token was not
    /// recognized.
    M_UNKNOWN_TOKEN,
    /// Presumably the Matrix `M_NOT_FOUND` code: the media was not found.
    M_NOT_FOUND,
    /// The scanner flagged the media as not clean (e.g. info
    /// `***VIRUS DETECTED***`).
    MCS_MEDIA_NOT_CLEAN,
    /// The media's MIME type is not allowed (e.g. info
    /// `File type: application/octet-stream not allowed`).
    MCS_MIME_TYPE_FORBIDDEN,
    /// Presumably: the decryption parameters supplied were invalid.
    MCS_BAD_DECRYPTION,
    /// Presumably the generic Matrix `M_UNKNOWN` error code.
    M_UNKNOWN,
    /// Presumably: the scanner's own request for the media failed.
    MCS_MEDIA_REQUEST_FAILED,
}
247
#[cfg(test)]
mod tests {
    // Every test points a `ContentScanner` at a wiremock server that stands in
    // for the content scanner service.
    use std::ops::Not;

    use assert_matches2::assert_matches;
    use matrix_sdk::{HttpError, RumaApiError, test_utils::mocks::MatrixMockServer};
    use matrix_sdk_test::async_test;
    use ruma::{
        api::{
            MatrixVersion,
            error::{ErrorBody, FromHttpResponseError},
        },
        events::room::{
            EncryptedFile, EncryptedFileHash, EncryptedFileHashes, EncryptedFileInfo, MediaSource,
            V2EncryptedFileInfo,
        },
        exports::{http::StatusCode, serde_json::json},
        owned_mxc_uri,
        serde::Base64,
    };
    use serde::Deserialize;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{header_exists, method, path, path_regex},
    };

    use crate::{ContentScanner, ContentScannerError, ErrorReason};

    /// The public key is fetched from the scanner and returned verbatim.
    #[async_test]
    async fn test_fetch_public_key() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/_matrix/media_proxy/unstable/public_key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "public_key": "1234567890"
            })))
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let public_key =
            content_scanner.fetch_public_server_key(&client).await.expect("Load public key");
        assert_eq!(public_key, "1234567890");
    }

    /// Plain media is downloaded through the scanner and its bytes returned.
    #[async_test]
    async fn test_get_media() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path_regex(r"/_matrix/media_proxy/unstable/download/.+/.+"))
            .and(header_exists("Authorization"))
            .respond_with(
                ResponseTemplate::new(200).set_body_bytes(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]),
            )
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let media_source =
            MediaSource::Plain(owned_mxc_uri!("mxc://matrix.org/RhfpOXOzAwzkuqcmbgMwQUrJ"));
        let response =
            content_scanner.get_media(&client, &media_source).await.expect("Get media");
        assert_eq!(response.content, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// A scanner rejection of plain media surfaces as a client API error.
    #[async_test]
    async fn test_get_media_unsupported() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path_regex(r"/_matrix/media_proxy/unstable/download/.+/.+"))
            .and(header_exists("Authorization"))
            .respond_with(ResponseTemplate::new(403).set_body_json(json!({
                "reason": "MCS_MIME_TYPE_FORBIDDEN",
                "info": "File type: application/octet-stream not allowed",
            })))
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let media_source =
            MediaSource::Plain(owned_mxc_uri!("mxc://matrix.org/ckTaStcNnFXLzKApkBmgRDoC"));
        let err =
            content_scanner.get_media(&client, &media_source).await.expect_err("Get media error");
        let client_error = err.as_client_api_error().expect("Get client error");
        assert_eq!(client_error.status_code, StatusCode::FORBIDDEN);
        assert_eq!(
            client_error.to_string(),
            "[403] {\"info\":\"File type: application/octet-stream not allowed\",\"reason\":\"MCS_MIME_TYPE_FORBIDDEN\"}"
        );
    }

    /// Encrypted media is downloaded through the scanner's encrypted endpoint.
    #[async_test]
    async fn test_get_encrypted_media() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/_matrix/media_proxy/unstable/download_encrypted"))
            .and(header_exists("Authorization"))
            .respond_with(
                ResponseTemplate::new(200).set_body_bytes(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]),
            )
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let file_info = EncryptedFileInfo::V2(V2EncryptedFileInfo::new(
            Base64::parse("9lpOscZyMOZRCF3v867nPPo3WPNMZt9JXMsuYiWRszc".as_bytes()).expect("k"),
            Base64::parse("czvdfKSjfLEAAAAAAAAAAA".as_bytes()).expect("iv"),
        ));
        let mut hashes = EncryptedFileHashes::new();
        hashes.insert(EncryptedFileHash::Sha256(
            Base64::parse("SBbJ3hINT2LgwXK8ev82enjnhubUy5UuKGDF3SezAhs".as_bytes()).expect("hash"),
        ));
        let media_source = MediaSource::Encrypted(Box::new(EncryptedFile::new(
            owned_mxc_uri!(
                "mxc://element.io/b50f38aa8ae820c75992370e4e944a045481e3932057062074730676224"
            ),
            file_info,
            hashes,
        )));
        content_scanner.get_media(&client, &media_source).await.expect("Get media");
    }

    /// A scanner rejection of encrypted media surfaces as a client API error.
    #[async_test]
    async fn test_get_encrypted_media_unsupported() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/_matrix/media_proxy/unstable/download_encrypted"))
            .and(header_exists("Authorization"))
            .respond_with(ResponseTemplate::new(403).set_body_json(json!({
                "reason": "MCS_MIME_TYPE_FORBIDDEN",
                "info": "File type: application/octet-stream not allowed",
            })))
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let file_info = EncryptedFileInfo::V2(V2EncryptedFileInfo::new(
            Base64::parse("tdHdCI5mc-g29IYfhYx2wkA5o-bILP9-nXY6Np1uSnM".as_bytes()).expect("k"),
            Base64::parse("IBFdH65KqhoAAAAAAAAAAA".as_bytes()).expect("iv"),
        ));
        let mut hashes = EncryptedFileHashes::new();
        hashes.insert(EncryptedFileHash::Sha256(
            Base64::parse("HSkkamvMSvF3Q30HInorh0ccPrxjgu+wp1vyUOmov/8".as_bytes()).expect("hash"),
        ));
        let media_source = MediaSource::Encrypted(Box::new(EncryptedFile::new(
            owned_mxc_uri!("mxc://matrix.org/WlfuejQQdpvWiWVpAGwfIKJL"),
            file_info,
            hashes,
        )));
        let err = content_scanner
            .get_media(&client, &media_source)
            .await
            .expect_err("Invalid type error");
        let client_error = err.as_client_api_error().expect("Invalid error");
        assert_eq!(client_error.status_code, StatusCode::FORBIDDEN);
        assert_eq!(
            client_error.to_string(),
            "[403] {\"info\":\"File type: application/octet-stream not allowed\",\"reason\":\"MCS_MIME_TYPE_FORBIDDEN\"}"
        );
    }

    /// A clean verdict for plain media is reported as `clean == true`.
    #[async_test]
    async fn test_scan_media() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path_regex(r"/_matrix/media_proxy/unstable/scan/.+/.+"))
            .and(header_exists("Authorization"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "clean": true,
                "info": "All clear!"
            })))
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let media_source =
            MediaSource::Plain(owned_mxc_uri!("mxc://matrix.org/RhfpOXOzAwzkuqcmbgMwQUrJ"));
        let response = content_scanner.scan(&client, &media_source).await.expect("Scan media");
        assert!(response.clean);
    }

    /// A clean verdict for encrypted media is reported as `clean == true`.
    #[async_test]
    async fn test_scan_encrypted_media() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/_matrix/media_proxy/unstable/scan_encrypted"))
            .and(header_exists("Authorization"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "clean": true,
                "info": "All clear!"
            })))
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let file_info = EncryptedFileInfo::V2(V2EncryptedFileInfo::new(
            Base64::parse("9lpOscZyMOZRCF3v867nPPo3WPNMZt9JXMsuYiWRszc".as_bytes()).expect("k"),
            Base64::parse("czvdfKSjfLEAAAAAAAAAAA".as_bytes()).expect("iv"),
        ));
        let mut hashes = EncryptedFileHashes::new();
        hashes.insert(EncryptedFileHash::Sha256(
            Base64::parse("SBbJ3hINT2LgwXK8ev82enjnhubUy5UuKGDF3SezAhs".as_bytes()).expect("hash"),
        ));
        let media_source = MediaSource::Encrypted(Box::new(EncryptedFile::new(
            owned_mxc_uri!(
                "mxc://element.io/b50f38aa8ae820c75992370e4e944a045481e3932057062074730676224"
            ),
            file_info,
            hashes,
        )));
        let response = content_scanner.scan(&client, &media_source).await.expect("Scan media");
        assert!(response.clean);
    }

    /// A "not clean" verdict is a successful response with `clean == false`,
    /// not an error.
    #[async_test]
    async fn test_scan_media_not_clean() {
        let server = MatrixMockServer::new().await;
        let client =
            server.client_builder().server_versions(vec![MatrixVersion::V1_11]).build().await;

        let content_scanner_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path_regex(r"/_matrix/media_proxy/unstable/scan/.+/.+"))
            .and(header_exists("Authorization"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "clean": false,
                "info": "***VIRUS DETECTED***"
            })))
            .mount(&content_scanner_server)
            .await;

        let content_scanner = ContentScanner::new(content_scanner_server.uri());
        let media_source =
            MediaSource::Plain(owned_mxc_uri!("mxc://matrix.org/RhfpOXOzAwzkuqcmbgMwQUrJ"));
        let response = content_scanner.scan(&client, &media_source).await.expect("Scan media");
        assert!(response.clean.not());
    }

    /// A scanner error body deserializes into a `ContentScannerError`.
    #[test]
    fn test_error_mapping() {
        let error = HttpError::Api(Box::new(FromHttpResponseError::Server(
            RumaApiError::MatrixError(ruma::api::error::Error::new(
                StatusCode::FORBIDDEN,
                ErrorBody::Json(json!({
                    "info": "***VIRUS DETECTED***",
                    "reason": "MCS_MEDIA_NOT_CLEAN"
                })),
            )),
        )));
        let api_error = error.as_client_api_error().expect("error as api error");
        assert_eq!(
            api_error.to_string(),
            "[403] {\"info\":\"***VIRUS DETECTED***\",\"reason\":\"MCS_MEDIA_NOT_CLEAN\"}"
        );
        assert_matches!(&api_error.body, ErrorBody::Json(json_body));
        let content_scanner_error =
            ContentScannerError::deserialize(json_body).expect("deserialize");
        assert_eq!(content_scanner_error.info, "***VIRUS DETECTED***");
        assert_matches!(content_scanner_error.reason, ErrorReason::MCS_MEDIA_NOT_CLEAN);
    }
}