1use std::{fmt, path::Path, sync::Arc};
18
19use async_trait::async_trait;
20use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
21use matrix_sdk_base::{
22 media::{
23 store::{
24 IgnoreMediaRetentionPolicy, MediaRetentionPolicy, MediaService, MediaStore,
25 MediaStoreInner,
26 },
27 MediaRequestParameters, UniqueKey,
28 },
29 timer,
30};
31use matrix_sdk_store_encryption::StoreCipher;
32use ruma::{time::SystemTime, MilliSecondsSinceUnixEpoch, MxcUri};
33use rusqlite::{params_from_iter, OptionalExtension};
34use tokio::{
35 fs,
36 sync::{Mutex, OwnedMutexGuard},
37};
38use tracing::{debug, instrument, trace};
39
40use crate::{
41 error::{Error, Result},
42 utils::{
43 repeat_vars, time_to_timestamp, EncryptableStore, SqliteAsyncConnExt,
44 SqliteKeyValueStoreAsyncConnExt, SqliteKeyValueStoreConnExt, SqliteTransactionExt,
45 },
46 OpenStoreError, SqliteStoreConfig,
47};
48
/// String keys used by the media store: entries of the key-value table, and
/// the table name handed to `encode_key` for rows of the `media` table.
mod keys {
    /// Table name passed to `encode_key` when encoding `uri`/`format` keys of
    /// the `media` table.
    pub const MEDIA: &str = "media";

    /// Key-value entry holding the serialized media retention policy.
    pub const MEDIA_RETENTION_POLICY: &str = "media_retention_policy";

    /// Key-value entry holding the serialized time of the last media cleanup.
    pub const LAST_MEDIA_CLEANUP_TIME: &str = "last_media_cleanup_time";
}
57
/// File name of the SQLite database, created inside the store directory
/// given to `open`/`open_with_config`.
const DATABASE_NAME: &str = "matrix-sdk-media.sqlite3";

/// Schema version this code expects; `run_migrations` upgrades databases
/// whose stored version is lower than this.
const DATABASE_VERSION: u8 = 1;
67
/// A SQLite-backed media store.
///
/// Reads use connections taken from `pool`; all writes go through a single
/// mutex-protected connection.
#[derive(Clone)]
pub struct SqliteMediaStore {
    /// Cipher used by `EncryptableStore` to encode keys and values. It is
    /// `Some` only when the store was opened with a passphrase.
    store_cipher: Option<Arc<StoreCipher>>,

    /// Pool from which `read()` connections are taken.
    pool: SqlitePool,

    /// The single connection used for writes (also the one migrations ran on).
    /// `write()` locks it and hands out the owned guard, so writers are
    /// serialized for as long as a caller holds the guard.
    write_connection: Arc<Mutex<SqliteAsyncConn>>,

    /// Service to which the policy-aware `MediaStore` operations delegate,
    /// passing this store. Its retention policy and last cleanup time are
    /// restored from the database when the store is opened.
    media_service: MediaService,
}
84
85#[cfg(not(tarpaulin_include))]
86impl fmt::Debug for SqliteMediaStore {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("SqliteMediaStore").finish_non_exhaustive()
89 }
90}
91
92impl EncryptableStore for SqliteMediaStore {
93 fn get_cypher(&self) -> Option<&StoreCipher> {
94 self.store_cipher.as_deref()
95 }
96}
97
98impl SqliteMediaStore {
99 pub async fn open(
102 path: impl AsRef<Path>,
103 passphrase: Option<&str>,
104 ) -> Result<Self, OpenStoreError> {
105 Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await
106 }
107
108 #[instrument(skip(config), fields(path = ?config.path))]
110 pub async fn open_with_config(config: SqliteStoreConfig) -> Result<Self, OpenStoreError> {
111 debug!(?config);
112
113 let _timer = timer!("open_with_config");
114
115 let SqliteStoreConfig { path, passphrase, pool_config, runtime_config } = config;
116
117 fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?;
118
119 let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME));
120 config.pool = Some(pool_config);
121
122 let pool = config.create_pool(Runtime::Tokio1)?;
123
124 let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
125 this.write().await?.apply_runtime_config(runtime_config).await?;
126
127 Ok(this)
128 }
129
130 async fn open_with_pool(
133 pool: SqlitePool,
134 passphrase: Option<&str>,
135 ) -> Result<Self, OpenStoreError> {
136 let conn = pool.get().await?;
137
138 let version = conn.db_version().await?;
139 run_migrations(&conn, version).await?;
140
141 let store_cipher = match passphrase {
142 Some(p) => Some(Arc::new(conn.get_or_create_store_cipher(p).await?)),
143 None => None,
144 };
145
146 let media_service = MediaService::new();
147 let media_retention_policy = conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await?;
148 let last_media_cleanup_time = conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await?;
149 media_service.restore(media_retention_policy, last_media_cleanup_time);
150
151 Ok(Self {
152 store_cipher,
153 pool,
154 write_connection: Arc::new(Mutex::new(conn)),
156 media_service,
157 })
158 }
159
160 #[instrument(skip_all)]
162 async fn read(&self) -> Result<SqliteAsyncConn> {
163 trace!("Taking a `read` connection");
164 let _timer = timer!("connection");
165
166 let connection = self.pool.get().await?;
167
168 connection.execute_batch("PRAGMA foreign_keys = ON;").await?;
173
174 Ok(connection)
175 }
176
177 #[instrument(skip_all)]
179 async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
180 trace!("Taking a `write` connection");
181 let _timer = timer!("connection");
182
183 let connection = self.write_connection.clone().lock_owned().await;
184
185 connection.execute_batch("PRAGMA foreign_keys = ON;").await?;
190
191 Ok(connection)
192 }
193}
194
195async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
197 if version == 0 {
198 debug!("Creating database");
199 } else if version < DATABASE_VERSION {
200 debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
201 } else {
202 return Ok(());
203 }
204
205 conn.execute_batch("PRAGMA foreign_keys = ON;").await?;
207
208 if version < 1 {
209 conn.execute_batch("PRAGMA journal_mode = wal;").await?;
212 conn.with_transaction(|txn| {
213 txn.execute_batch(include_str!("../migrations/media_store/001_init.sql"))?;
214 txn.set_db_version(1)
215 })
216 .await?;
217 }
218
219 Ok(())
220}
221
222#[async_trait]
223impl MediaStore for SqliteMediaStore {
224 type Error = Error;
225
226 #[instrument(skip(self))]
227 async fn try_take_leased_lock(
228 &self,
229 lease_duration_ms: u32,
230 key: &str,
231 holder: &str,
232 ) -> Result<bool> {
233 let _timer = timer!("method");
234
235 let key = key.to_owned();
236 let holder = holder.to_owned();
237
238 let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
239 let expiration = now + lease_duration_ms as u64;
240
241 let num_touched = self
242 .write()
243 .await?
244 .with_transaction(move |txn| {
245 txn.execute(
246 "INSERT INTO lease_locks (key, holder, expiration)
247 VALUES (?1, ?2, ?3)
248 ON CONFLICT (key)
249 DO
250 UPDATE SET holder = ?2, expiration = ?3
251 WHERE holder = ?2
252 OR expiration < ?4
253 ",
254 (key, holder, expiration, now),
255 )
256 })
257 .await?;
258
259 Ok(num_touched == 1)
260 }
261
262 async fn add_media_content(
263 &self,
264 request: &MediaRequestParameters,
265 content: Vec<u8>,
266 ignore_policy: IgnoreMediaRetentionPolicy,
267 ) -> Result<()> {
268 let _timer = timer!("method");
269
270 self.media_service.add_media_content(self, request, content, ignore_policy).await
271 }
272
273 #[instrument(skip_all)]
274 async fn replace_media_key(
275 &self,
276 from: &MediaRequestParameters,
277 to: &MediaRequestParameters,
278 ) -> Result<(), Self::Error> {
279 let _timer = timer!("method");
280
281 let prev_uri = self.encode_key(keys::MEDIA, from.source.unique_key());
282 let prev_format = self.encode_key(keys::MEDIA, from.format.unique_key());
283
284 let new_uri = self.encode_key(keys::MEDIA, to.source.unique_key());
285 let new_format = self.encode_key(keys::MEDIA, to.format.unique_key());
286
287 let conn = self.write().await?;
288 conn.execute(
289 r#"UPDATE media SET uri = ?, format = ? WHERE uri = ? AND format = ?"#,
290 (new_uri, new_format, prev_uri, prev_format),
291 )
292 .await?;
293
294 Ok(())
295 }
296
297 #[instrument(skip_all)]
298 async fn get_media_content(&self, request: &MediaRequestParameters) -> Result<Option<Vec<u8>>> {
299 let _timer = timer!("method");
300
301 self.media_service.get_media_content(self, request).await
302 }
303
304 #[instrument(skip_all)]
305 async fn remove_media_content(&self, request: &MediaRequestParameters) -> Result<()> {
306 let _timer = timer!("method");
307
308 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
309 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
310
311 let conn = self.write().await?;
312 conn.execute("DELETE FROM media WHERE uri = ? AND format = ?", (uri, format)).await?;
313
314 Ok(())
315 }
316
317 #[instrument(skip(self))]
318 async fn get_media_content_for_uri(
319 &self,
320 uri: &MxcUri,
321 ) -> Result<Option<Vec<u8>>, Self::Error> {
322 let _timer = timer!("method");
323
324 self.media_service.get_media_content_for_uri(self, uri).await
325 }
326
327 #[instrument(skip(self))]
328 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
329 let _timer = timer!("method");
330
331 let uri = self.encode_key(keys::MEDIA, uri);
332
333 let conn = self.write().await?;
334 conn.execute("DELETE FROM media WHERE uri = ?", (uri,)).await?;
335
336 Ok(())
337 }
338
339 #[instrument(skip_all)]
340 async fn set_media_retention_policy(
341 &self,
342 policy: MediaRetentionPolicy,
343 ) -> Result<(), Self::Error> {
344 let _timer = timer!("method");
345
346 self.media_service.set_media_retention_policy(self, policy).await
347 }
348
349 #[instrument(skip_all)]
350 fn media_retention_policy(&self) -> MediaRetentionPolicy {
351 let _timer = timer!("method");
352
353 self.media_service.media_retention_policy()
354 }
355
356 #[instrument(skip_all)]
357 async fn set_ignore_media_retention_policy(
358 &self,
359 request: &MediaRequestParameters,
360 ignore_policy: IgnoreMediaRetentionPolicy,
361 ) -> Result<(), Self::Error> {
362 let _timer = timer!("method");
363
364 self.media_service.set_ignore_media_retention_policy(self, request, ignore_policy).await
365 }
366
367 #[instrument(skip_all)]
368 async fn clean(&self) -> Result<(), Self::Error> {
369 let _timer = timer!("method");
370
371 self.media_service.clean(self).await
372 }
373}
374
375#[cfg_attr(target_family = "wasm", async_trait(?Send))]
376#[cfg_attr(not(target_family = "wasm"), async_trait)]
377impl MediaStoreInner for SqliteMediaStore {
378 type Error = Error;
379
380 async fn media_retention_policy_inner(
381 &self,
382 ) -> Result<Option<MediaRetentionPolicy>, Self::Error> {
383 let conn = self.read().await?;
384 conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await
385 }
386
387 async fn set_media_retention_policy_inner(
388 &self,
389 policy: MediaRetentionPolicy,
390 ) -> Result<(), Self::Error> {
391 let conn = self.write().await?;
392 conn.set_serialized_kv(keys::MEDIA_RETENTION_POLICY, policy).await?;
393 Ok(())
394 }
395
396 async fn add_media_content_inner(
397 &self,
398 request: &MediaRequestParameters,
399 data: Vec<u8>,
400 last_access: SystemTime,
401 policy: MediaRetentionPolicy,
402 ignore_policy: IgnoreMediaRetentionPolicy,
403 ) -> Result<(), Self::Error> {
404 let ignore_policy = ignore_policy.is_yes();
405 let data = self.encode_value(data)?;
406
407 if !ignore_policy && policy.exceeds_max_file_size(data.len() as u64) {
408 return Ok(());
409 }
410
411 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
412 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
413 let timestamp = time_to_timestamp(last_access);
414
415 let conn = self.write().await?;
416 conn.execute(
417 "INSERT OR REPLACE INTO media (uri, format, data, last_access, ignore_policy) VALUES (?, ?, ?, ?, ?)",
418 (uri, format, data, timestamp, ignore_policy),
419 )
420 .await?;
421
422 Ok(())
423 }
424
425 async fn set_ignore_media_retention_policy_inner(
426 &self,
427 request: &MediaRequestParameters,
428 ignore_policy: IgnoreMediaRetentionPolicy,
429 ) -> Result<(), Self::Error> {
430 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
431 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
432 let ignore_policy = ignore_policy.is_yes();
433
434 let conn = self.write().await?;
435 conn.execute(
436 r#"UPDATE media SET ignore_policy = ? WHERE uri = ? AND format = ?"#,
437 (ignore_policy, uri, format),
438 )
439 .await?;
440
441 Ok(())
442 }
443
444 async fn get_media_content_inner(
445 &self,
446 request: &MediaRequestParameters,
447 current_time: SystemTime,
448 ) -> Result<Option<Vec<u8>>, Self::Error> {
449 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
450 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
451 let timestamp = time_to_timestamp(current_time);
452
453 let conn = self.write().await?;
454 let data = conn
455 .with_transaction::<_, rusqlite::Error, _>(move |txn| {
456 txn.execute(
460 "UPDATE media SET last_access = ? WHERE uri = ? AND format = ?",
461 (timestamp, &uri, &format),
462 )?;
463
464 txn.query_row::<Vec<u8>, _, _>(
465 "SELECT data FROM media WHERE uri = ? AND format = ?",
466 (&uri, &format),
467 |row| row.get(0),
468 )
469 .optional()
470 })
471 .await?;
472
473 data.map(|v| self.decode_value(&v).map(Into::into)).transpose()
474 }
475
476 async fn get_media_content_for_uri_inner(
477 &self,
478 uri: &MxcUri,
479 current_time: SystemTime,
480 ) -> Result<Option<Vec<u8>>, Self::Error> {
481 let uri = self.encode_key(keys::MEDIA, uri);
482 let timestamp = time_to_timestamp(current_time);
483
484 let conn = self.write().await?;
485 let data = conn
486 .with_transaction::<_, rusqlite::Error, _>(move |txn| {
487 txn.execute("UPDATE media SET last_access = ? WHERE uri = ?", (timestamp, &uri))?;
491
492 txn.query_row::<Vec<u8>, _, _>(
493 "SELECT data FROM media WHERE uri = ?",
494 (&uri,),
495 |row| row.get(0),
496 )
497 .optional()
498 })
499 .await?;
500
501 data.map(|v| self.decode_value(&v).map(Into::into)).transpose()
502 }
503
504 async fn clean_inner(
505 &self,
506 policy: MediaRetentionPolicy,
507 current_time: SystemTime,
508 ) -> Result<(), Self::Error> {
509 if !policy.has_limitations() {
510 return Ok(());
512 }
513
514 let conn = self.write().await?;
515 let removed = conn
516 .with_transaction::<_, Error, _>(move |txn| {
517 let mut removed = false;
518
519 if let Some(max_file_size) = policy.computed_max_file_size() {
521 let count = txn.execute(
522 "DELETE FROM media WHERE ignore_policy IS FALSE AND length(data) > ?",
523 (max_file_size,),
524 )?;
525
526 if count > 0 {
527 removed = true;
528 }
529 }
530
531 if let Some(last_access_expiry) = policy.last_access_expiry {
533 let current_timestamp = time_to_timestamp(current_time);
534 let expiry_secs = last_access_expiry.as_secs();
535 let count = txn.execute(
536 "DELETE FROM media WHERE ignore_policy IS FALSE AND (? - last_access) >= ?",
537 (current_timestamp, expiry_secs),
538 )?;
539
540 if count > 0 {
541 removed = true;
542 }
543 }
544
545 if let Some(max_cache_size) = policy.max_cache_size {
547 let cache_size = txn
550 .query_row(
551 "SELECT sum(length(data)) FROM media WHERE ignore_policy IS FALSE",
552 (),
553 |row| {
554 row.get::<_, Option<u64>>(0)
556 },
557 )?
558 .unwrap_or_default();
559
560 if cache_size > max_cache_size {
562 let mut cached_stmt = txn.prepare_cached(
564 "SELECT rowid, length(data) FROM media \
565 WHERE ignore_policy IS FALSE ORDER BY last_access DESC",
566 )?;
567 let content_sizes = cached_stmt
568 .query(())?
569 .mapped(|row| Ok((row.get::<_, i64>(0)?, row.get::<_, u64>(1)?)));
570
571 let mut accumulated_items_size = 0u64;
572 let mut limit_reached = false;
573 let mut rows_to_remove = Vec::new();
574
575 for result in content_sizes {
576 let (row_id, size) = match result {
577 Ok(content_size) => content_size,
578 Err(error) => {
579 return Err(error.into());
580 }
581 };
582
583 if limit_reached {
584 rows_to_remove.push(row_id);
585 continue;
586 }
587
588 match accumulated_items_size.checked_add(size) {
589 Some(acc) if acc > max_cache_size => {
590 limit_reached = true;
592 rows_to_remove.push(row_id);
593 }
594 Some(acc) => accumulated_items_size = acc,
595 None => {
596 limit_reached = true;
599 rows_to_remove.push(row_id);
600 }
601 }
602 }
603
604 if !rows_to_remove.is_empty() {
605 removed = true;
606 }
607
608 txn.chunk_large_query_over(rows_to_remove, None, |txn, row_ids| {
609 let sql_params = repeat_vars(row_ids.len());
610 let query = format!("DELETE FROM media WHERE rowid IN ({sql_params})");
611 txn.prepare(&query)?.execute(params_from_iter(row_ids))?;
612 Ok(Vec::<()>::new())
613 })?;
614 }
615 }
616
617 txn.set_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME, current_time)?;
618
619 Ok(removed)
620 })
621 .await?;
622
623 if removed {
626 conn.vacuum().await?;
627 }
628
629 Ok(())
630 }
631
632 async fn last_media_cleanup_time_inner(&self) -> Result<Option<SystemTime>, Self::Error> {
633 let conn = self.read().await?;
634 conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await
635 }
636}
637
/// Tests for an unencrypted `SqliteMediaStore`: the shared integration-test
/// suites, plus store-specific tests for pool sizing and `last_access` updates.
#[cfg(test)]
mod tests {
    use std::{
        path::PathBuf,
        sync::atomic::{AtomicU32, Ordering::SeqCst},
        time::Duration,
    };

    use matrix_sdk_base::{
        media::{
            store::{IgnoreMediaRetentionPolicy, MediaStore, MediaStoreError},
            MediaFormat, MediaRequestParameters, MediaThumbnailSettings,
        },
        media_store_inner_integration_tests, media_store_integration_tests,
        media_store_integration_tests_time,
    };
    use matrix_sdk_test::async_test;
    use once_cell::sync::Lazy;
    use ruma::{events::room::MediaSource, media::Method, mxc_uri, uint};
    use tempfile::{tempdir, TempDir};

    use super::SqliteMediaStore;
    use crate::{utils::SqliteAsyncConnExt, SqliteStoreConfig};

    // One temporary directory shared by this module. Each store gets its own
    // numbered subdirectory (from `NUM`), so stores never share a database.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Returns a fresh, unique subdirectory path under `TMP_DIR`.
    fn new_media_store_workspace() -> PathBuf {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        TMP_DIR.path().join(name)
    }

    /// Opens an unencrypted store in a fresh workspace.
    ///
    /// Panics if opening fails, so it never actually returns `Err`.
    /// NOTE(review): the `Result` return type is presumably what the
    /// integration-test macros below expect; confirm.
    async fn get_media_store() -> Result<SqliteMediaStore, MediaStoreError> {
        let tmpdir_path = new_media_store_workspace();

        tracing::info!("using media store @ {}", tmpdir_path.to_str().unwrap());

        Ok(SqliteMediaStore::open(tmpdir_path.to_str().unwrap(), None).await.unwrap())
    }

    media_store_integration_tests!();
    media_store_integration_tests_time!();
    media_store_inner_integration_tests!();

    /// Returns the stored `data` column of every row of `media`, as-is (not
    /// passed through `decode_value`), most recently accessed first.
    async fn get_media_store_content_sorted_by_last_access(
        media_store: &SqliteMediaStore,
    ) -> Vec<Vec<u8>> {
        let sqlite_db = media_store.read().await.expect("accessing sqlite db failed");
        sqlite_db
            .prepare("SELECT data FROM media ORDER BY last_access DESC", |mut stmt| {
                stmt.query(())?.mapped(|row| row.get(0)).collect()
            })
            .await
            .expect("querying media cache content by last access failed")
    }

    /// The configured pool size is applied to the store's connection pool.
    #[async_test]
    async fn test_pool_size() {
        let tmpdir_path = new_media_store_workspace();
        let store_open_config = SqliteStoreConfig::new(tmpdir_path).pool_max_size(42);

        let store = SqliteMediaStore::open_with_config(store_open_config).await.unwrap();

        assert_eq!(store.pool.status().max_size, 42);
    }

    /// Adding and then reading content updates `last_access`, which changes
    /// the ordering of the rows.
    #[async_test]
    async fn test_last_access() {
        let media_store = get_media_store().await.expect("creating media cache failed");
        let uri = mxc_uri!("mxc://localhost/media");
        let file_request = MediaRequestParameters {
            source: MediaSource::Plain(uri.to_owned()),
            format: MediaFormat::File,
        };
        let thumbnail_request = MediaRequestParameters {
            source: MediaSource::Plain(uri.to_owned()),
            format: MediaFormat::Thumbnail(MediaThumbnailSettings::with_method(
                Method::Crop,
                uint!(100),
                uint!(100),
            )),
        };

        let content: Vec<u8> = "hello world".into();
        let thumbnail_content: Vec<u8> = "hello…".into();

        // Add the full file first.
        media_store
            .add_media_content(&file_request, content.clone(), IgnoreMediaRetentionPolicy::No)
            .await
            .expect("adding file failed");

        // The two `last_access` values must differ for the ordering checks to
        // mean anything. `clean_inner` compares them against `as_secs()`,
        // which suggests second precision, hence the multi-second wait.
        // NOTE(review): confirm the unit of `time_to_timestamp`.
        tokio::time::sleep(Duration::from_secs(3)).await;

        // Then add the thumbnail, which becomes the most recent access.
        media_store
            .add_media_content(
                &thumbnail_request,
                thumbnail_content.clone(),
                IgnoreMediaRetentionPolicy::No,
            )
            .await
            .expect("adding thumbnail failed");

        let contents = get_media_store_content_sorted_by_last_access(&media_store).await;

        assert_eq!(contents.len(), 2, "media cache contents length is wrong");
        assert_eq!(contents[0], thumbnail_content, "thumbnail is not last access");
        assert_eq!(contents[1], content, "file is not second-to-last access");

        // Wait again so the read below yields a strictly newer timestamp.
        tokio::time::sleep(Duration::from_secs(3)).await;

        // Reading the file must bump its `last_access` above the thumbnail's.
        let _ = media_store
            .get_media_content(&file_request)
            .await
            .expect("getting file failed")
            .expect("file is missing");

        let contents = get_media_store_content_sorted_by_last_access(&media_store).await;

        assert_eq!(contents.len(), 2, "media cache contents length is wrong");
        assert_eq!(contents[0], content, "file is not last access");
        assert_eq!(contents[1], thumbnail_content, "thumbnail is not second-to-last access");
    }
}
769
/// Runs the shared media-store integration suites against a store opened
/// with a passphrase, so keys and values go through the store cipher.
#[cfg(test)]
mod encrypted_tests {
    use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

    use matrix_sdk_base::{
        media::store::MediaStoreError, media_store_inner_integration_tests,
        media_store_integration_tests, media_store_integration_tests_time,
    };
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};

    use super::SqliteMediaStore;

    // Shared temp directory; each store gets its own numbered subdirectory.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens an encrypted store in a fresh subdirectory of `TMP_DIR`.
    ///
    /// Panics if opening fails, so it never actually returns `Err`.
    /// NOTE(review): the `Result` return type is presumably what the
    /// integration-test macros below expect; confirm.
    async fn get_media_store() -> Result<SqliteMediaStore, MediaStoreError> {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        let tmpdir_path = TMP_DIR.path().join(name);

        tracing::info!("using media store @ {}", tmpdir_path.to_str().unwrap());

        Ok(SqliteMediaStore::open(tmpdir_path.to_str().unwrap(), Some("default_test_password"))
            .await
            .unwrap())
    }

    media_store_integration_tests!();
    media_store_integration_tests_time!();
    media_store_inner_integration_tests!();
}