1use std::{
18 fmt,
19 path::{Path, PathBuf},
20 sync::Arc,
21};
22
23use async_trait::async_trait;
24use deadpool::managed::PoolConfig;
25use matrix_sdk_base::{
26 cross_process_lock::CrossProcessLockGeneration,
27 media::{
28 MediaRequestParameters, UniqueKey,
29 store::{
30 IgnoreMediaRetentionPolicy, MediaRetentionPolicy, MediaService, MediaStore,
31 MediaStoreInner,
32 },
33 },
34 timer,
35};
36use matrix_sdk_store_encryption::StoreCipher;
37use ruma::{MilliSecondsSinceUnixEpoch, MxcUri, time::SystemTime};
38use rusqlite::{OptionalExtension, params_from_iter};
39use tokio::{
40 fs,
41 sync::{Mutex, OwnedMutexGuard},
42};
43use tracing::{debug, instrument};
44
45use crate::{
46 OpenStoreError, RuntimeConfig, Secret, SqliteStoreConfig,
47 connection::{self, Connection as SqliteAsyncConn, Pool as SqlitePool, SqliteConnections},
48 error::{Error, Result},
49 utils::{
50 EncryptableStore, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
51 SqliteKeyValueStoreConnExt, SqliteTransactionExt, repeat_vars, time_to_timestamp,
52 },
53};
54
mod keys {
    // Table names.

    /// Name of the table holding cached media; also passed as the table
    /// argument when encoding media keys with `encode_key`.
    pub const MEDIA: &str = "media";

    // Entries of the key-value table.

    /// Key of the serialized media retention policy.
    pub const MEDIA_RETENTION_POLICY: &str = "media_retention_policy";
    /// Key of the serialized time at which the last media cleanup ran.
    pub const LAST_MEDIA_CLEANUP_TIME: &str = "last_media_cleanup_time";
}
63
/// File name of the SQLite database, created inside the directory given by
/// the store configuration's `path`.
const DATABASE_NAME: &str = "matrix-sdk-media.sqlite3";
66
/// SQLite-backed implementation of the media store.
///
/// The connections live behind an `Arc`, so clones of a store share the same
/// connection pool and write connection (closing one clone closes them for
/// all).
#[derive(Clone)]
pub struct SqliteMediaStore {
    /// Cipher used (through `EncryptableStore`) when encoding keys and values;
    /// `None` when the store was opened without a secret.
    store_cipher: Option<Arc<StoreCipher>>,

    /// The connection pool and the dedicated write connection; `None` while
    /// the store is closed, restored by `reopen`.
    connections: Arc<Mutex<Option<SqliteConnections>>>,

    /// Path to the database file, kept so `reopen` can re-create connections.
    db_path: PathBuf,

    /// Pool configuration, kept so `reopen` can rebuild the pool.
    pool_config: PoolConfig,

    /// Runtime configuration, kept so `reopen` can re-apply it.
    runtime_config: RuntimeConfig,

    /// Media retention logic, restored from the persisted retention policy
    /// and last cleanup time when the store is opened.
    media_service: MediaService,
}
86
87#[cfg(not(tarpaulin_include))]
88impl fmt::Debug for SqliteMediaStore {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 f.debug_struct("SqliteMediaStore").finish_non_exhaustive()
91 }
92}
93
94impl EncryptableStore for SqliteMediaStore {
95 fn get_cypher(&self) -> Option<&StoreCipher> {
96 self.store_cipher.as_deref()
97 }
98}
99
100impl SqliteMediaStore {
101 pub async fn open(
104 path: impl AsRef<Path>,
105 passphrase: Option<&str>,
106 ) -> Result<Self, OpenStoreError> {
107 Self::open_with_config(&SqliteStoreConfig::new(path).passphrase(passphrase)).await
108 }
109
110 pub async fn open_with_key(
113 path: impl AsRef<Path>,
114 key: Option<&[u8; 32]>,
115 ) -> Result<Self, OpenStoreError> {
116 Self::open_with_config(&SqliteStoreConfig::new(path).key(key)).await
117 }
118
119 #[instrument(skip(config), fields(path = ?config.path))]
121 pub async fn open_with_config(config: &SqliteStoreConfig) -> Result<Self, OpenStoreError> {
122 debug!(?config);
123
124 let _timer = timer!("open_with_config");
125
126 fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir)?;
127
128 let db_path = config.path.join(DATABASE_NAME);
129 let pool_config = config.pool_config();
130 let runtime_config = config.runtime_config();
131
132 let pool = config.build_pool_of_connections(DATABASE_NAME)?;
133
134 let this =
135 Self::open_with_pool(pool, db_path, pool_config, runtime_config, config.secret.clone())
136 .await?;
137
138 this.write().await?.apply_runtime_config(runtime_config).await?;
140
141 Ok(this)
142 }
143
144 async fn open_with_pool(
147 pool: SqlitePool,
148 db_path: PathBuf,
149 pool_config: PoolConfig,
150 runtime_config: RuntimeConfig,
151 secret: Option<Secret>,
152 ) -> Result<Self, OpenStoreError> {
153 let conn = pool.get().await?;
154
155 let version = conn.db_version().await?;
156 run_migrations(&conn, version).await?;
157
158 conn.wal_checkpoint().await;
159
160 let store_cipher = match &secret {
161 Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s.clone()).await?)),
162 None => None,
163 };
164
165 let media_service = MediaService::new();
166 let media_retention_policy = conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await?;
167 let last_media_cleanup_time = conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await?;
168 media_service.restore(media_retention_policy, last_media_cleanup_time);
169
170 let connections = SqliteConnections {
171 pool,
172 write_connection: Arc::new(Mutex::new(conn)),
174 };
175
176 Ok(Self {
177 store_cipher,
178 connections: Arc::new(Mutex::new(Some(connections))),
179 db_path,
180 pool_config,
181 runtime_config,
182 media_service,
183 })
184 }
185
186 #[instrument(skip_all)]
188 async fn read(&self) -> Result<SqliteAsyncConn> {
189 let pool = {
190 let guard = self.connections.lock().await;
191 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
192 conns.pool.clone()
193 };
194
195 let connection = pool.get().await?;
196
197 connection.execute_batch("PRAGMA foreign_keys = ON;").await?;
202
203 Ok(connection)
204 }
205
206 #[instrument(skip_all)]
208 async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
209 let write_connection = {
210 let guard = self.connections.lock().await;
211 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
212 conns.write_connection.clone()
213 };
214
215 let connection = write_connection.lock_owned().await;
216
217 connection.execute_batch("PRAGMA foreign_keys = ON;").await?;
222
223 Ok(connection)
224 }
225
226 pub async fn vacuum(&self) -> Result<()> {
227 let write_connection = {
228 let guard = self.connections.lock().await;
229 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
230 conns.write_connection.clone()
231 };
232 write_connection.lock().await.vacuum().await
233 }
234
235 async fn get_db_size(&self) -> Result<Option<usize>> {
236 let pool = {
237 let guard = self.connections.lock().await;
238 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
239 conns.pool.clone()
240 };
241 Ok(Some(pool.get().await?.get_db_size().await?))
242 }
243
244 pub async fn close(&self) -> Result<()> {
245 connection::close_connections(&self.connections, "Media store").await;
246 Ok(())
247 }
248
249 pub async fn reopen(&self) -> Result<()> {
250 connection::reopen_connections(
251 &self.connections,
252 self.db_path.clone(),
253 self.pool_config,
254 self.runtime_config,
255 )
256 .await?;
257 Ok(())
258 }
259
260 #[cfg(test)]
262 async fn pool_max_size(&self) -> Option<usize> {
263 let guard = self.connections.lock().await;
264 guard.as_ref().map(|conns| conns.pool.status().max_size)
265 }
266}
267
/// Brings the database schema up to date.
///
/// `version` is the schema version currently recorded in the database. Every
/// migration step above it is applied in order. Each step runs inside its own
/// transaction, which also records the new version via `set_db_version`.
async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
    // Enforce foreign key constraints while migrating.
    conn.execute_batch("PRAGMA foreign_keys = ON;").await?;

    if version < 1 {
        debug!("Creating database");
        // Switch to WAL journaling before opening a transaction: SQLite does
        // not allow changing the journal mode while a transaction is active.
        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
        conn.with_transaction(|txn| {
            txn.execute_batch(include_str!("../migrations/media_store/001_init.sql"))?;
            txn.set_db_version(1)
        })
        .await?;
    }

    if version < 2 {
        debug!("Upgrading database to version 2");
        // Schema change for the lease locks (per the migration file name:
        // locks with a generation counter).
        conn.with_transaction(|txn| {
            txn.execute_batch(include_str!(
                "../migrations/media_store/002_lease_locks_with_generation.sql"
            ))?;
            txn.set_db_version(2)
        })
        .await?;
    }

    Ok(())
}
298
299#[async_trait]
300impl MediaStore for SqliteMediaStore {
301 type Error = Error;
302
303 #[instrument(skip(self))]
304 async fn try_take_leased_lock(
305 &self,
306 lease_duration_ms: u32,
307 key: &str,
308 holder: &str,
309 ) -> Result<Option<CrossProcessLockGeneration>> {
310 let key = key.to_owned();
311 let holder = holder.to_owned();
312
313 let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
314 let expiration = now + lease_duration_ms as u64;
315
316 let generation = self
318 .write()
319 .await?
320 .with_transaction(move |txn| {
321 txn.query_row(
322 "INSERT INTO lease_locks (key, holder, expiration)
323 VALUES (?1, ?2, ?3)
324 ON CONFLICT (key)
325 DO
326 UPDATE SET
327 holder = excluded.holder,
328 expiration = excluded.expiration,
329 generation =
330 CASE holder
331 WHEN excluded.holder THEN generation
332 ELSE generation + 1
333 END
334 WHERE
335 holder = excluded.holder
336 OR expiration < ?4
337 RETURNING generation
338 ",
339 (key, holder, expiration, now),
340 |row| row.get(0),
341 )
342 .optional()
343 })
344 .await?;
345
346 Ok(generation)
347 }
348
349 async fn add_media_content(
350 &self,
351 request: &MediaRequestParameters,
352 content: Vec<u8>,
353 ignore_policy: IgnoreMediaRetentionPolicy,
354 ) -> Result<()> {
355 let _timer = timer!("method");
356
357 self.media_service.add_media_content(self, request, content, ignore_policy).await
358 }
359
360 #[instrument(skip_all)]
361 async fn replace_media_key(
362 &self,
363 from: &MediaRequestParameters,
364 to: &MediaRequestParameters,
365 ) -> Result<(), Self::Error> {
366 let _timer = timer!("method");
367
368 let prev_uri = self.encode_key(keys::MEDIA, from.source.unique_key());
369 let prev_format = self.encode_key(keys::MEDIA, from.format.unique_key());
370
371 let new_uri = self.encode_key(keys::MEDIA, to.source.unique_key());
372 let new_format = self.encode_key(keys::MEDIA, to.format.unique_key());
373
374 let conn = self.write().await?;
375 conn.execute(
376 r#"UPDATE media SET uri = ?, format = ? WHERE uri = ? AND format = ?"#,
377 (new_uri, new_format, prev_uri, prev_format),
378 )
379 .await?;
380
381 Ok(())
382 }
383
384 #[instrument(skip_all)]
385 async fn get_media_content(&self, request: &MediaRequestParameters) -> Result<Option<Vec<u8>>> {
386 let _timer = timer!("method");
387
388 self.media_service.get_media_content(self, request).await
389 }
390
391 #[instrument(skip_all)]
392 async fn remove_media_content(&self, request: &MediaRequestParameters) -> Result<()> {
393 let _timer = timer!("method");
394
395 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
396 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
397
398 let conn = self.write().await?;
399 conn.execute("DELETE FROM media WHERE uri = ? AND format = ?", (uri, format)).await?;
400
401 Ok(())
402 }
403
404 #[instrument(skip(self))]
405 async fn get_media_content_for_uri(
406 &self,
407 uri: &MxcUri,
408 ) -> Result<Option<Vec<u8>>, Self::Error> {
409 let _timer = timer!("method");
410
411 self.media_service.get_media_content_for_uri(self, uri).await
412 }
413
414 #[instrument(skip(self))]
415 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
416 let _timer = timer!("method");
417
418 let uri = self.encode_key(keys::MEDIA, uri);
419
420 let conn = self.write().await?;
421 conn.execute("DELETE FROM media WHERE uri = ?", (uri,)).await?;
422
423 Ok(())
424 }
425
426 #[instrument(skip_all)]
427 async fn set_media_retention_policy(
428 &self,
429 policy: MediaRetentionPolicy,
430 ) -> Result<(), Self::Error> {
431 let _timer = timer!("method");
432
433 self.media_service.set_media_retention_policy(self, policy).await
434 }
435
436 #[instrument(skip_all)]
437 fn media_retention_policy(&self) -> MediaRetentionPolicy {
438 let _timer = timer!("method");
439
440 self.media_service.media_retention_policy()
441 }
442
443 #[instrument(skip_all)]
444 async fn set_ignore_media_retention_policy(
445 &self,
446 request: &MediaRequestParameters,
447 ignore_policy: IgnoreMediaRetentionPolicy,
448 ) -> Result<(), Self::Error> {
449 let _timer = timer!("method");
450
451 self.media_service.set_ignore_media_retention_policy(self, request, ignore_policy).await
452 }
453
454 #[instrument(skip_all)]
455 async fn clean(&self) -> Result<(), Self::Error> {
456 let _timer = timer!("method");
457
458 self.media_service.clean(self).await
459 }
460
461 async fn close(&self) -> Result<(), Self::Error> {
462 SqliteMediaStore::close(self).await
463 }
464
465 async fn reopen(&self) -> Result<(), Self::Error> {
466 SqliteMediaStore::reopen(self).await
467 }
468
469 async fn optimize(&self) -> Result<(), Self::Error> {
470 Ok(self.vacuum().await?)
471 }
472
473 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
474 self.get_db_size().await
475 }
476}
477
478#[cfg_attr(target_family = "wasm", async_trait(?Send))]
479#[cfg_attr(not(target_family = "wasm"), async_trait)]
480impl MediaStoreInner for SqliteMediaStore {
481 type Error = Error;
482
483 async fn media_retention_policy_inner(
484 &self,
485 ) -> Result<Option<MediaRetentionPolicy>, Self::Error> {
486 let conn = self.read().await?;
487 conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await
488 }
489
490 async fn set_media_retention_policy_inner(
491 &self,
492 policy: MediaRetentionPolicy,
493 ) -> Result<(), Self::Error> {
494 let conn = self.write().await?;
495 conn.set_serialized_kv(keys::MEDIA_RETENTION_POLICY, policy).await?;
496 Ok(())
497 }
498
499 async fn add_media_content_inner(
500 &self,
501 request: &MediaRequestParameters,
502 data: Vec<u8>,
503 last_access: SystemTime,
504 policy: MediaRetentionPolicy,
505 ignore_policy: IgnoreMediaRetentionPolicy,
506 ) -> Result<(), Self::Error> {
507 let ignore_policy = ignore_policy.is_yes();
508 let data = self.encode_value(data)?;
509
510 if !ignore_policy && policy.exceeds_max_file_size(data.len() as u64) {
511 return Ok(());
512 }
513
514 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
515 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
516 let timestamp = time_to_timestamp(last_access);
517
518 let conn = self.write().await?;
519 conn.execute(
520 "INSERT OR REPLACE INTO media (uri, format, data, last_access, ignore_policy) VALUES (?, ?, ?, ?, ?)",
521 (uri, format, data, timestamp, ignore_policy),
522 )
523 .await?;
524
525 Ok(())
526 }
527
528 async fn set_ignore_media_retention_policy_inner(
529 &self,
530 request: &MediaRequestParameters,
531 ignore_policy: IgnoreMediaRetentionPolicy,
532 ) -> Result<(), Self::Error> {
533 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
534 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
535 let ignore_policy = ignore_policy.is_yes();
536
537 let conn = self.write().await?;
538 conn.execute(
539 r#"UPDATE media SET ignore_policy = ? WHERE uri = ? AND format = ?"#,
540 (ignore_policy, uri, format),
541 )
542 .await?;
543
544 Ok(())
545 }
546
547 async fn get_media_content_inner(
548 &self,
549 request: &MediaRequestParameters,
550 current_time: SystemTime,
551 ) -> Result<Option<Vec<u8>>, Self::Error> {
552 let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
553 let format = self.encode_key(keys::MEDIA, request.format.unique_key());
554 let timestamp = time_to_timestamp(current_time);
555
556 let conn = self.write().await?;
557 let data = conn
558 .with_transaction::<_, rusqlite::Error, _>(move |txn| {
559 txn.execute(
563 "UPDATE media SET last_access = ? WHERE uri = ? AND format = ?",
564 (timestamp, &uri, &format),
565 )?;
566
567 txn.query_row::<Vec<u8>, _, _>(
568 "SELECT data FROM media WHERE uri = ? AND format = ?",
569 (&uri, &format),
570 |row| row.get(0),
571 )
572 .optional()
573 })
574 .await?;
575
576 data.map(|v| self.decode_value(&v).map(Into::into)).transpose()
577 }
578
579 async fn get_media_content_for_uri_inner(
580 &self,
581 uri: &MxcUri,
582 current_time: SystemTime,
583 ) -> Result<Option<Vec<u8>>, Self::Error> {
584 let uri = self.encode_key(keys::MEDIA, uri);
585 let timestamp = time_to_timestamp(current_time);
586
587 let conn = self.write().await?;
588 let data = conn
589 .with_transaction::<_, rusqlite::Error, _>(move |txn| {
590 txn.execute("UPDATE media SET last_access = ? WHERE uri = ?", (timestamp, &uri))?;
594
595 txn.query_row::<Vec<u8>, _, _>(
596 "SELECT data FROM media WHERE uri = ?",
597 (&uri,),
598 |row| row.get(0),
599 )
600 .optional()
601 })
602 .await?;
603
604 data.map(|v| self.decode_value(&v).map(Into::into)).transpose()
605 }
606
607 async fn clean_inner(
608 &self,
609 policy: MediaRetentionPolicy,
610 current_time: SystemTime,
611 ) -> Result<(), Self::Error> {
612 if !policy.has_limitations() {
613 return Ok(());
615 }
616
617 let conn = self.write().await?;
618 let removed = conn
619 .with_transaction::<_, Error, _>(move |txn| {
620 let mut removed = false;
621
622 if let Some(max_file_size) = policy.computed_max_file_size() {
624 let count = txn.execute(
625 "DELETE FROM media WHERE ignore_policy IS FALSE AND length(data) > ?",
626 (max_file_size,),
627 )?;
628
629 if count > 0 {
630 removed = true;
631 }
632 }
633
634 if let Some(last_access_expiry) = policy.last_access_expiry {
636 let current_timestamp = time_to_timestamp(current_time);
637 let expiry_secs = last_access_expiry.as_secs();
638 let count = txn.execute(
639 "DELETE FROM media WHERE ignore_policy IS FALSE AND (? - last_access) >= ?",
640 (current_timestamp, expiry_secs),
641 )?;
642
643 if count > 0 {
644 removed = true;
645 }
646 }
647
648 if let Some(max_cache_size) = policy.max_cache_size {
650 let cache_size = txn
653 .query_row(
654 "SELECT sum(length(data)) FROM media WHERE ignore_policy IS FALSE",
655 (),
656 |row| {
657 row.get::<_, Option<u64>>(0)
659 },
660 )?
661 .unwrap_or_default();
662
663 if cache_size > max_cache_size {
665 let mut cached_stmt = txn.prepare_cached(
667 "SELECT rowid, length(data) FROM media \
668 WHERE ignore_policy IS FALSE ORDER BY last_access DESC",
669 )?;
670 let content_sizes = cached_stmt
671 .query(())?
672 .mapped(|row| Ok((row.get::<_, i64>(0)?, row.get::<_, u64>(1)?)));
673
674 let mut accumulated_items_size = 0u64;
675 let mut limit_reached = false;
676 let mut rows_to_remove = Vec::new();
677
678 for result in content_sizes {
679 let (row_id, size) = match result {
680 Ok(content_size) => content_size,
681 Err(error) => {
682 return Err(error.into());
683 }
684 };
685
686 if limit_reached {
687 rows_to_remove.push(row_id);
688 continue;
689 }
690
691 match accumulated_items_size.checked_add(size) {
692 Some(acc) if acc > max_cache_size => {
693 limit_reached = true;
695 rows_to_remove.push(row_id);
696 }
697 Some(acc) => accumulated_items_size = acc,
698 None => {
699 limit_reached = true;
702 rows_to_remove.push(row_id);
703 }
704 }
705 }
706
707 if !rows_to_remove.is_empty() {
708 removed = true;
709 }
710
711 txn.chunk_large_query_over(rows_to_remove, None, |txn, row_ids| {
712 let sql_params = repeat_vars(row_ids.len());
713 let query = format!("DELETE FROM media WHERE rowid IN ({sql_params})");
714 txn.prepare(&query)?.execute(params_from_iter(row_ids))?;
715 Ok(Vec::<()>::new())
716 })?;
717 }
718 }
719
720 txn.set_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME, current_time)?;
721
722 Ok(removed)
723 })
724 .await?;
725
726 if removed {
729 conn.vacuum().await?;
730 }
731
732 Ok(())
733 }
734
735 async fn last_media_cleanup_time_inner(&self) -> Result<Option<SystemTime>, Self::Error> {
736 let conn = self.read().await?;
737 conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await
738 }
739}
740
#[cfg(test)]
mod tests {
    use std::{
        path::PathBuf,
        sync::{
            LazyLock,
            atomic::{AtomicU32, Ordering::SeqCst},
        },
        time::Duration,
    };

    use matrix_sdk_base::{
        media::{
            MediaFormat, MediaRequestParameters, MediaThumbnailSettings,
            store::{IgnoreMediaRetentionPolicy, MediaStore, MediaStoreError},
        },
        media_store_inner_integration_tests, media_store_integration_tests,
        media_store_integration_tests_time,
    };
    use matrix_sdk_test::async_test;
    use ruma::{events::room::MediaSource, media::Method, mxc_uri, uint};
    use tempfile::{TempDir, tempdir};

    use super::SqliteMediaStore;
    use crate::{SqliteStoreConfig, utils::SqliteAsyncConnExt};

    /// Shared temporary directory; each store gets its own numbered
    /// sub-directory in it.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    /// Counter used to name the per-store sub-directories.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Returns a fresh directory path for a new media store.
    fn new_media_store_workspace() -> PathBuf {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        TMP_DIR.path().join(name)
    }

    /// Opens an unencrypted media store in a fresh directory.
    ///
    /// Presumably required by the integration-test macros below under this
    /// name and signature.
    async fn get_media_store() -> Result<SqliteMediaStore, MediaStoreError> {
        let tmpdir_path = new_media_store_workspace();

        tracing::info!("using media store @ {}", tmpdir_path.display());

        // `open` accepts `impl AsRef<Path>`: pass the path directly instead of
        // going through `to_str().unwrap()`, which panics on non-UTF-8 paths.
        Ok(SqliteMediaStore::open(tmpdir_path, None).await.unwrap())
    }

    media_store_integration_tests!();
    media_store_integration_tests_time!();
    media_store_inner_integration_tests!();

    /// Returns the raw `data` column of every media row, most recently
    /// accessed first.
    async fn get_media_store_content_sorted_by_last_access(
        media_store: &SqliteMediaStore,
    ) -> Vec<Vec<u8>> {
        let sqlite_db = media_store.read().await.expect("accessing sqlite db failed");
        sqlite_db
            .prepare("SELECT data FROM media ORDER BY last_access DESC", |mut stmt| {
                stmt.query(())?.mapped(|row| row.get(0)).collect()
            })
            .await
            .expect("querying media cache content by last access failed")
    }

    #[async_test]
    async fn test_pool_size() {
        let tmpdir_path = new_media_store_workspace();
        let store_open_config = SqliteStoreConfig::new(tmpdir_path).pool_max_size(42);

        let store = SqliteMediaStore::open_with_config(&store_open_config).await.unwrap();

        assert_eq!(store.pool_max_size().await.unwrap(), 42);
    }

    #[async_test]
    async fn test_last_access() {
        let media_store = get_media_store().await.expect("creating media cache failed");
        let uri = mxc_uri!("mxc://localhost/media");
        let file_request = MediaRequestParameters {
            source: MediaSource::Plain(uri.to_owned()),
            format: MediaFormat::File,
        };
        let thumbnail_request = MediaRequestParameters {
            source: MediaSource::Plain(uri.to_owned()),
            format: MediaFormat::Thumbnail(MediaThumbnailSettings::with_method(
                Method::Crop,
                uint!(100),
                uint!(100),
            )),
        };

        let content: Vec<u8> = "hello world".into();
        let thumbnail_content: Vec<u8> = "hello…".into();

        // Add the file first.
        media_store
            .add_media_content(&file_request, content.clone(), IgnoreMediaRetentionPolicy::No)
            .await
            .expect("adding file failed");

        // Wait long enough for the stored access timestamps to differ.
        tokio::time::sleep(Duration::from_secs(3)).await;

        // Then add the thumbnail.
        media_store
            .add_media_content(
                &thumbnail_request,
                thumbnail_content.clone(),
                IgnoreMediaRetentionPolicy::No,
            )
            .await
            .expect("adding thumbnail failed");

        // The thumbnail is now the most recently accessed item.
        let contents = get_media_store_content_sorted_by_last_access(&media_store).await;

        assert_eq!(contents.len(), 2, "media cache contents length is wrong");
        assert_eq!(contents[0], thumbnail_content, "thumbnail is not last access");
        assert_eq!(contents[1], content, "file is not second-to-last access");

        // Wait again so the next access gets a later timestamp.
        tokio::time::sleep(Duration::from_secs(3)).await;

        // Reading the file should refresh its last access time.
        let _ = media_store
            .get_media_content(&file_request)
            .await
            .expect("getting file failed")
            .expect("file is missing");

        // The file is now the most recently accessed item.
        let contents = get_media_store_content_sorted_by_last_access(&media_store).await;

        assert_eq!(contents.len(), 2, "media cache contents length is wrong");
        assert_eq!(contents[0], content, "file is not last access");
        assert_eq!(contents[1], thumbnail_content, "thumbnail is not second-to-last access");
    }
}
874
/// Tests for closing and reopening the store's connections.
#[cfg(test)]
mod close_reopen_tests {
    use std::sync::{
        LazyLock,
        atomic::{AtomicU32, Ordering::SeqCst},
    };

    use matrix_sdk_base::media::{
        MediaFormat, MediaRequestParameters,
        store::{IgnoreMediaRetentionPolicy, MediaStore},
    };
    use matrix_sdk_test::async_test;
    use ruma::{events::room::MediaSource, mxc_uri};
    use tempfile::{TempDir, tempdir};

    use super::SqliteMediaStore;

    /// Shared temporary directory; each store gets its own numbered
    /// sub-directory in it.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    /// Counter used to name the per-store sub-directories.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens an unencrypted store in a fresh directory.
    async fn new_store() -> SqliteMediaStore {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        let tmpdir_path = TMP_DIR.path().join(name);
        SqliteMediaStore::open(tmpdir_path, None).await.unwrap()
    }

    /// A fixed request used by the tests below.
    fn test_request() -> MediaRequestParameters {
        MediaRequestParameters {
            source: MediaSource::Plain(mxc_uri!("mxc://localhost/test_media").to_owned()),
            format: MediaFormat::File,
        }
    }

    #[async_test]
    async fn test_close_completes_without_timeout() {
        let store = new_store().await;

        // With no connection held elsewhere, close should return quickly.
        let start = std::time::Instant::now();
        store.close().await.unwrap();
        let elapsed = start.elapsed();

        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "close() took {elapsed:?}, expected < 2s (no timeout)"
        );

        // Closing drops the connections entirely.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }

    #[async_test]
    async fn test_reopen_restores_connections() {
        let store = new_store().await;

        store.close().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_none());
        }

        store.reopen().await.unwrap();

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_some(), "connections should be Some after reopen");
        }
    }

    #[async_test]
    async fn test_close_is_idempotent() {
        let store = new_store().await;

        store.close().await.unwrap();
        // A second close on an already-closed store must also succeed.
        store.close().await.unwrap();

        let guard = store.connections.lock().await;
        assert!(guard.is_none());
    }

    #[async_test]
    async fn test_reopen_is_idempotent() {
        let store = new_store().await;

        // Reopening a store that was never closed must succeed.
        store.reopen().await.unwrap();

        let guard = store.connections.lock().await;
        assert!(guard.is_some());
    }

    #[async_test]
    async fn test_read_fails_when_closed() {
        let store = new_store().await;
        store.close().await.unwrap();

        let err = store.get_media_content(&test_request()).await;
        assert!(err.is_err(), "read should fail when closed");

        let err_msg = err.unwrap_err().to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_write_fails_when_closed() {
        let store = new_store().await;
        store.close().await.unwrap();

        let err = store
            .add_media_content(&test_request(), b"data".to_vec(), IgnoreMediaRetentionPolicy::No)
            .await;
        assert!(err.is_err(), "write should fail when closed");

        let err_msg = err.unwrap_err().to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_data_persists_across_close_reopen() {
        let store = new_store().await;

        // Store content that the retention policy cannot remove.
        store
            .add_media_content(
                &test_request(),
                b"hello world".to_vec(),
                IgnoreMediaRetentionPolicy::Yes,
            )
            .await
            .unwrap();

        // Sanity check before closing.
        let content = store.get_media_content(&test_request()).await.unwrap();
        assert_eq!(content.as_deref(), Some(b"hello world".as_slice()));

        store.close().await.unwrap();
        store.reopen().await.unwrap();

        // The content must still be there after the close/reopen cycle.
        let content = store.get_media_content(&test_request()).await.unwrap();
        assert_eq!(
            content.as_deref(),
            Some(b"hello world".as_slice()),
            "media content should persist across close/reopen"
        );
    }

    #[async_test]
    async fn test_multiple_close_reopen_cycles() {
        let store = new_store().await;

        for _ in 0..5 {
            store.close().await.unwrap();
            store.reopen().await.unwrap();

            // A read must succeed after every cycle.
            let result = store.get_media_content(&test_request()).await;
            assert!(result.is_ok(), "store should work after close/reopen cycle");
        }
    }

    #[async_test]
    async fn test_pool_is_fully_drained_after_close() {
        let store = new_store().await;

        // Use the store a couple of times so connections have been handed out.
        let _ = store.get_media_content(&test_request()).await;
        let _ = store.get_media_content(&test_request()).await;

        store.close().await.unwrap();

        // After close, no connection state remains in the store.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "all connections should be released after close");
    }

    #[async_test]
    async fn test_operations_work_immediately_after_reopen() {
        let store = new_store().await;

        store.close().await.unwrap();
        store.reopen().await.unwrap();

        // Read path.
        let result = store.get_media_content(&test_request()).await;
        assert!(result.is_ok(), "read should succeed immediately after reopen");

        // Write path.
        let result = store
            .add_media_content(
                &test_request(),
                b"after_reopen".to_vec(),
                IgnoreMediaRetentionPolicy::No,
            )
            .await;
        assert!(result.is_ok(), "write should succeed immediately after reopen");
    }

    #[async_test]
    async fn test_close_waits_for_held_read_connection_to_drain() {
        let store = new_store().await;

        // Hold a pooled connection so close cannot finish right away.
        let held_conn = store.read().await.unwrap();

        // Clones share the same connections, so closing the clone closes
        // `store` too. Run it in a task so this test can observe it waiting.
        let store_clone = store.clone();
        let close_handle = tokio::spawn(async move {
            store_clone.close().await.unwrap();
        });

        // Give the close task time to start and block on the held connection.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        assert!(!close_handle.is_finished(), "close should be waiting for the held connection");

        // Release the connection; close should now be able to complete.
        drop(held_conn);

        let timeout = tokio::time::timeout(std::time::Duration::from_secs(3), close_handle).await;
        assert!(timeout.is_ok(), "close should complete after the held connection is released");
        timeout.unwrap().unwrap();

        // The original handle observes the close performed through the clone.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }
}
1111
/// Runs the shared media store integration tests against an encrypted store.
#[cfg(test)]
mod encrypted_tests {
    use std::sync::{
        LazyLock,
        atomic::{AtomicU32, Ordering::SeqCst},
    };

    use matrix_sdk_base::{
        media::store::MediaStoreError, media_store_inner_integration_tests,
        media_store_integration_tests, media_store_integration_tests_time,
    };
    use tempfile::{TempDir, tempdir};

    use super::SqliteMediaStore;

    /// Shared temporary directory; each store gets its own numbered
    /// sub-directory in it.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    /// Counter used to name the per-store sub-directories.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens a passphrase-protected media store in a fresh directory.
    ///
    /// Presumably required by the integration-test macros below under this
    /// name and signature.
    async fn get_media_store() -> Result<SqliteMediaStore, MediaStoreError> {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        let tmpdir_path = TMP_DIR.path().join(name);

        tracing::info!("using media store @ {}", tmpdir_path.display());

        // `open` accepts `impl AsRef<Path>`: pass the path directly instead of
        // going through `to_str().unwrap()`, which panics on non-UTF-8 paths.
        Ok(SqliteMediaStore::open(tmpdir_path, Some("default_test_password")).await.unwrap())
    }

    media_store_integration_tests!();
    media_store_integration_tests_time!();
    media_store_inner_integration_tests!();
}