Skip to main content

matrix_sdk_sqlite/
media_store.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! An SQLite-based backend for the [`MediaStore`].
16
17use std::{
18    fmt,
19    path::{Path, PathBuf},
20    sync::Arc,
21};
22
23use async_trait::async_trait;
24use deadpool::managed::PoolConfig;
25use matrix_sdk_base::{
26    cross_process_lock::CrossProcessLockGeneration,
27    media::{
28        MediaRequestParameters, UniqueKey,
29        store::{
30            IgnoreMediaRetentionPolicy, MediaRetentionPolicy, MediaService, MediaStore,
31            MediaStoreInner,
32        },
33    },
34    timer,
35};
36use matrix_sdk_store_encryption::StoreCipher;
37use ruma::{MilliSecondsSinceUnixEpoch, MxcUri, time::SystemTime};
38use rusqlite::{OptionalExtension, params_from_iter};
39use tokio::{
40    fs,
41    sync::{Mutex, OwnedMutexGuard},
42};
43use tracing::{debug, instrument};
44
45use crate::{
46    OpenStoreError, RuntimeConfig, Secret, SqliteStoreConfig,
47    connection::{self, Connection as SqliteAsyncConn, Pool as SqlitePool, SqliteConnections},
48    error::{Error, Result},
49    utils::{
50        EncryptableStore, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
51        SqliteKeyValueStoreConnExt, SqliteTransactionExt, repeat_vars, time_to_timestamp,
52    },
53};
54
mod keys {
    // ---- Entries of the key-value store ----

    /// Key under which the serialized media retention policy is persisted.
    pub const MEDIA_RETENTION_POLICY: &str = "media_retention_policy";
    /// Key under which the time of the last media cleanup is persisted.
    pub const LAST_MEDIA_CLEANUP_TIME: &str = "last_media_cleanup_time";

    // ---- Table names ----

    /// Table holding cached media; also the namespace used when encoding
    /// media keys.
    pub const MEDIA: &str = "media";
}
63
/// File name of the SQLite database, created inside the store directory.
const DATABASE_NAME: &str = "matrix-sdk-media.sqlite3";
66
67/// An SQLite-based media store.
68#[derive(Clone)]
69pub struct SqliteMediaStore {
70    store_cipher: Option<Arc<StoreCipher>>,
71
72    /// `Some` when active, `None` when closed.
73    connections: Arc<Mutex<Option<SqliteConnections>>>,
74
75    /// Retained so we can rebuild the pool on reopen.
76    db_path: PathBuf,
77
78    /// Retained so we can rebuild the pool on reopen.
79    pool_config: PoolConfig,
80
81    /// Retained so we can re-apply runtime config on reopen.
82    runtime_config: RuntimeConfig,
83
84    media_service: MediaService,
85}
86
87#[cfg(not(tarpaulin_include))]
88impl fmt::Debug for SqliteMediaStore {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        f.debug_struct("SqliteMediaStore").finish_non_exhaustive()
91    }
92}
93
94impl EncryptableStore for SqliteMediaStore {
95    fn get_cypher(&self) -> Option<&StoreCipher> {
96        self.store_cipher.as_deref()
97    }
98}
99
100impl SqliteMediaStore {
101    /// Open the SQLite-based media store at the given path using the
102    /// given passphrase to encrypt private data.
103    pub async fn open(
104        path: impl AsRef<Path>,
105        passphrase: Option<&str>,
106    ) -> Result<Self, OpenStoreError> {
107        Self::open_with_config(&SqliteStoreConfig::new(path).passphrase(passphrase)).await
108    }
109
110    /// Open the SQLite-based media store at the given path using the given
111    /// key to encrypt private data.
112    pub async fn open_with_key(
113        path: impl AsRef<Path>,
114        key: Option<&[u8; 32]>,
115    ) -> Result<Self, OpenStoreError> {
116        Self::open_with_config(&SqliteStoreConfig::new(path).key(key)).await
117    }
118
119    /// Open the SQLite-based media store with the config open config.
120    #[instrument(skip(config), fields(path = ?config.path))]
121    pub async fn open_with_config(config: &SqliteStoreConfig) -> Result<Self, OpenStoreError> {
122        debug!(?config);
123
124        let _timer = timer!("open_with_config");
125
126        fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir)?;
127
128        let db_path = config.path.join(DATABASE_NAME);
129        let pool_config = config.pool_config();
130        let runtime_config = config.runtime_config();
131
132        let pool = config.build_pool_of_connections(DATABASE_NAME)?;
133
134        let this =
135            Self::open_with_pool(pool, db_path, pool_config, runtime_config, config.secret.clone())
136                .await?;
137
138        // Apply runtime config on the write connection.
139        this.write().await?.apply_runtime_config(runtime_config).await?;
140
141        Ok(this)
142    }
143
144    /// Open an SQLite-based media store using the given SQLite database
145    /// pool. The given passphrase will be used to encrypt private data.
146    async fn open_with_pool(
147        pool: SqlitePool,
148        db_path: PathBuf,
149        pool_config: PoolConfig,
150        runtime_config: RuntimeConfig,
151        secret: Option<Secret>,
152    ) -> Result<Self, OpenStoreError> {
153        let conn = pool.get().await?;
154
155        let version = conn.db_version().await?;
156        run_migrations(&conn, version).await?;
157
158        conn.wal_checkpoint().await;
159
160        let store_cipher = match &secret {
161            Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s.clone()).await?)),
162            None => None,
163        };
164
165        let media_service = MediaService::new();
166        let media_retention_policy = conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await?;
167        let last_media_cleanup_time = conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await?;
168        media_service.restore(media_retention_policy, last_media_cleanup_time);
169
170        let connections = SqliteConnections {
171            pool,
172            // Use `conn` as our selected write connection.
173            write_connection: Arc::new(Mutex::new(conn)),
174        };
175
176        Ok(Self {
177            store_cipher,
178            connections: Arc::new(Mutex::new(Some(connections))),
179            db_path,
180            pool_config,
181            runtime_config,
182            media_service,
183        })
184    }
185
186    // Acquire a connection for executing read operations.
187    #[instrument(skip_all)]
188    async fn read(&self) -> Result<SqliteAsyncConn> {
189        let pool = {
190            let guard = self.connections.lock().await;
191            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
192            conns.pool.clone()
193        };
194
195        let connection = pool.get().await?;
196
197        // Per https://www.sqlite.org/foreignkeys.html#fk_enable, foreign key
198        // support must be enabled on a per-connection basis. Execute it every
199        // time we try to get a connection, since we can't guarantee a previous
200        // connection did enable it before.
201        connection.execute_batch("PRAGMA foreign_keys = ON;").await?;
202
203        Ok(connection)
204    }
205
206    // Acquire a connection for executing write operations.
207    #[instrument(skip_all)]
208    async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
209        let write_connection = {
210            let guard = self.connections.lock().await;
211            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
212            conns.write_connection.clone()
213        };
214
215        let connection = write_connection.lock_owned().await;
216
217        // Per https://www.sqlite.org/foreignkeys.html#fk_enable, foreign key
218        // support must be enabled on a per-connection basis. Execute it every
219        // time we try to get a connection, since we can't guarantee a previous
220        // connection did enable it before.
221        connection.execute_batch("PRAGMA foreign_keys = ON;").await?;
222
223        Ok(connection)
224    }
225
226    pub async fn vacuum(&self) -> Result<()> {
227        let write_connection = {
228            let guard = self.connections.lock().await;
229            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
230            conns.write_connection.clone()
231        };
232        write_connection.lock().await.vacuum().await
233    }
234
235    async fn get_db_size(&self) -> Result<Option<usize>> {
236        let pool = {
237            let guard = self.connections.lock().await;
238            let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
239            conns.pool.clone()
240        };
241        Ok(Some(pool.get().await?.get_db_size().await?))
242    }
243
244    pub async fn close(&self) -> Result<()> {
245        connection::close_connections(&self.connections, "Media store").await;
246        Ok(())
247    }
248
249    pub async fn reopen(&self) -> Result<()> {
250        connection::reopen_connections(
251            &self.connections,
252            self.db_path.clone(),
253            self.pool_config,
254            self.runtime_config,
255        )
256        .await?;
257        Ok(())
258    }
259
260    /// Returns the pool size status, for testing purposes.
261    #[cfg(test)]
262    async fn pool_max_size(&self) -> Option<usize> {
263        let guard = self.connections.lock().await;
264        guard.as_ref().map(|conns| conns.pool.status().max_size)
265    }
266}
267
268/// Run migrations for the given version of the database.
269async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
270    // Always enable foreign keys for the current connection.
271    conn.execute_batch("PRAGMA foreign_keys = ON;").await?;
272
273    if version < 1 {
274        debug!("Creating database");
275        // First turn on WAL mode, this can't be done in the transaction, it fails with
276        // the error message: "cannot change into wal mode from within a transaction".
277        conn.execute_batch("PRAGMA journal_mode = wal;").await?;
278        conn.with_transaction(|txn| {
279            txn.execute_batch(include_str!("../migrations/media_store/001_init.sql"))?;
280            txn.set_db_version(1)
281        })
282        .await?;
283    }
284
285    if version < 2 {
286        debug!("Upgrading database to version 2");
287        conn.with_transaction(|txn| {
288            txn.execute_batch(include_str!(
289                "../migrations/media_store/002_lease_locks_with_generation.sql"
290            ))?;
291            txn.set_db_version(2)
292        })
293        .await?;
294    }
295
296    Ok(())
297}
298
299#[async_trait]
300impl MediaStore for SqliteMediaStore {
301    type Error = Error;
302
303    #[instrument(skip(self))]
304    async fn try_take_leased_lock(
305        &self,
306        lease_duration_ms: u32,
307        key: &str,
308        holder: &str,
309    ) -> Result<Option<CrossProcessLockGeneration>> {
310        let key = key.to_owned();
311        let holder = holder.to_owned();
312
313        let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
314        let expiration = now + lease_duration_ms as u64;
315
316        // Learn about the `excluded` keyword in https://sqlite.org/lang_upsert.html.
317        let generation = self
318            .write()
319            .await?
320            .with_transaction(move |txn| {
321                txn.query_row(
322                    "INSERT INTO lease_locks (key, holder, expiration)
323                    VALUES (?1, ?2, ?3)
324                    ON CONFLICT (key)
325                    DO
326                        UPDATE SET
327                            holder = excluded.holder,
328                            expiration = excluded.expiration,
329                            generation =
330                                CASE holder
331                                    WHEN excluded.holder THEN generation
332                                    ELSE generation + 1
333                                END
334                        WHERE
335                            holder = excluded.holder
336                            OR expiration < ?4
337                    RETURNING generation
338                    ",
339                    (key, holder, expiration, now),
340                    |row| row.get(0),
341                )
342                .optional()
343            })
344            .await?;
345
346        Ok(generation)
347    }
348
349    async fn add_media_content(
350        &self,
351        request: &MediaRequestParameters,
352        content: Vec<u8>,
353        ignore_policy: IgnoreMediaRetentionPolicy,
354    ) -> Result<()> {
355        let _timer = timer!("method");
356
357        self.media_service.add_media_content(self, request, content, ignore_policy).await
358    }
359
360    #[instrument(skip_all)]
361    async fn replace_media_key(
362        &self,
363        from: &MediaRequestParameters,
364        to: &MediaRequestParameters,
365    ) -> Result<(), Self::Error> {
366        let _timer = timer!("method");
367
368        let prev_uri = self.encode_key(keys::MEDIA, from.source.unique_key());
369        let prev_format = self.encode_key(keys::MEDIA, from.format.unique_key());
370
371        let new_uri = self.encode_key(keys::MEDIA, to.source.unique_key());
372        let new_format = self.encode_key(keys::MEDIA, to.format.unique_key());
373
374        let conn = self.write().await?;
375        conn.execute(
376            r#"UPDATE media SET uri = ?, format = ? WHERE uri = ? AND format = ?"#,
377            (new_uri, new_format, prev_uri, prev_format),
378        )
379        .await?;
380
381        Ok(())
382    }
383
384    #[instrument(skip_all)]
385    async fn get_media_content(&self, request: &MediaRequestParameters) -> Result<Option<Vec<u8>>> {
386        let _timer = timer!("method");
387
388        self.media_service.get_media_content(self, request).await
389    }
390
391    #[instrument(skip_all)]
392    async fn remove_media_content(&self, request: &MediaRequestParameters) -> Result<()> {
393        let _timer = timer!("method");
394
395        let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
396        let format = self.encode_key(keys::MEDIA, request.format.unique_key());
397
398        let conn = self.write().await?;
399        conn.execute("DELETE FROM media WHERE uri = ? AND format = ?", (uri, format)).await?;
400
401        Ok(())
402    }
403
404    #[instrument(skip(self))]
405    async fn get_media_content_for_uri(
406        &self,
407        uri: &MxcUri,
408    ) -> Result<Option<Vec<u8>>, Self::Error> {
409        let _timer = timer!("method");
410
411        self.media_service.get_media_content_for_uri(self, uri).await
412    }
413
414    #[instrument(skip(self))]
415    async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
416        let _timer = timer!("method");
417
418        let uri = self.encode_key(keys::MEDIA, uri);
419
420        let conn = self.write().await?;
421        conn.execute("DELETE FROM media WHERE uri = ?", (uri,)).await?;
422
423        Ok(())
424    }
425
426    #[instrument(skip_all)]
427    async fn set_media_retention_policy(
428        &self,
429        policy: MediaRetentionPolicy,
430    ) -> Result<(), Self::Error> {
431        let _timer = timer!("method");
432
433        self.media_service.set_media_retention_policy(self, policy).await
434    }
435
436    #[instrument(skip_all)]
437    fn media_retention_policy(&self) -> MediaRetentionPolicy {
438        let _timer = timer!("method");
439
440        self.media_service.media_retention_policy()
441    }
442
443    #[instrument(skip_all)]
444    async fn set_ignore_media_retention_policy(
445        &self,
446        request: &MediaRequestParameters,
447        ignore_policy: IgnoreMediaRetentionPolicy,
448    ) -> Result<(), Self::Error> {
449        let _timer = timer!("method");
450
451        self.media_service.set_ignore_media_retention_policy(self, request, ignore_policy).await
452    }
453
454    #[instrument(skip_all)]
455    async fn clean(&self) -> Result<(), Self::Error> {
456        let _timer = timer!("method");
457
458        self.media_service.clean(self).await
459    }
460
461    async fn close(&self) -> Result<(), Self::Error> {
462        SqliteMediaStore::close(self).await
463    }
464
465    async fn reopen(&self) -> Result<(), Self::Error> {
466        SqliteMediaStore::reopen(self).await
467    }
468
469    async fn optimize(&self) -> Result<(), Self::Error> {
470        Ok(self.vacuum().await?)
471    }
472
473    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
474        self.get_db_size().await
475    }
476}
477
478#[cfg_attr(target_family = "wasm", async_trait(?Send))]
479#[cfg_attr(not(target_family = "wasm"), async_trait)]
480impl MediaStoreInner for SqliteMediaStore {
481    type Error = Error;
482
483    async fn media_retention_policy_inner(
484        &self,
485    ) -> Result<Option<MediaRetentionPolicy>, Self::Error> {
486        let conn = self.read().await?;
487        conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await
488    }
489
490    async fn set_media_retention_policy_inner(
491        &self,
492        policy: MediaRetentionPolicy,
493    ) -> Result<(), Self::Error> {
494        let conn = self.write().await?;
495        conn.set_serialized_kv(keys::MEDIA_RETENTION_POLICY, policy).await?;
496        Ok(())
497    }
498
499    async fn add_media_content_inner(
500        &self,
501        request: &MediaRequestParameters,
502        data: Vec<u8>,
503        last_access: SystemTime,
504        policy: MediaRetentionPolicy,
505        ignore_policy: IgnoreMediaRetentionPolicy,
506    ) -> Result<(), Self::Error> {
507        let ignore_policy = ignore_policy.is_yes();
508        let data = self.encode_value(data)?;
509
510        if !ignore_policy && policy.exceeds_max_file_size(data.len() as u64) {
511            return Ok(());
512        }
513
514        let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
515        let format = self.encode_key(keys::MEDIA, request.format.unique_key());
516        let timestamp = time_to_timestamp(last_access);
517
518        let conn = self.write().await?;
519        conn.execute(
520            "INSERT OR REPLACE INTO media (uri, format, data, last_access, ignore_policy) VALUES (?, ?, ?, ?, ?)",
521            (uri, format, data, timestamp, ignore_policy),
522        )
523        .await?;
524
525        Ok(())
526    }
527
528    async fn set_ignore_media_retention_policy_inner(
529        &self,
530        request: &MediaRequestParameters,
531        ignore_policy: IgnoreMediaRetentionPolicy,
532    ) -> Result<(), Self::Error> {
533        let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
534        let format = self.encode_key(keys::MEDIA, request.format.unique_key());
535        let ignore_policy = ignore_policy.is_yes();
536
537        let conn = self.write().await?;
538        conn.execute(
539            r#"UPDATE media SET ignore_policy = ? WHERE uri = ? AND format = ?"#,
540            (ignore_policy, uri, format),
541        )
542        .await?;
543
544        Ok(())
545    }
546
547    async fn get_media_content_inner(
548        &self,
549        request: &MediaRequestParameters,
550        current_time: SystemTime,
551    ) -> Result<Option<Vec<u8>>, Self::Error> {
552        let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
553        let format = self.encode_key(keys::MEDIA, request.format.unique_key());
554        let timestamp = time_to_timestamp(current_time);
555
556        let conn = self.write().await?;
557        let data = conn
558            .with_transaction::<_, rusqlite::Error, _>(move |txn| {
559                // Update the last access.
560                // We need to do this first so the transaction is in write mode right away.
561                // See: https://sqlite.org/lang_transaction.html#read_transactions_versus_write_transactions
562                txn.execute(
563                    "UPDATE media SET last_access = ? WHERE uri = ? AND format = ?",
564                    (timestamp, &uri, &format),
565                )?;
566
567                txn.query_row::<Vec<u8>, _, _>(
568                    "SELECT data FROM media WHERE uri = ? AND format = ?",
569                    (&uri, &format),
570                    |row| row.get(0),
571                )
572                .optional()
573            })
574            .await?;
575
576        data.map(|v| self.decode_value(&v).map(Into::into)).transpose()
577    }
578
579    async fn get_media_content_for_uri_inner(
580        &self,
581        uri: &MxcUri,
582        current_time: SystemTime,
583    ) -> Result<Option<Vec<u8>>, Self::Error> {
584        let uri = self.encode_key(keys::MEDIA, uri);
585        let timestamp = time_to_timestamp(current_time);
586
587        let conn = self.write().await?;
588        let data = conn
589            .with_transaction::<_, rusqlite::Error, _>(move |txn| {
590                // Update the last access.
591                // We need to do this first so the transaction is in write mode right away.
592                // See: https://sqlite.org/lang_transaction.html#read_transactions_versus_write_transactions
593                txn.execute("UPDATE media SET last_access = ? WHERE uri = ?", (timestamp, &uri))?;
594
595                txn.query_row::<Vec<u8>, _, _>(
596                    "SELECT data FROM media WHERE uri = ?",
597                    (&uri,),
598                    |row| row.get(0),
599                )
600                .optional()
601            })
602            .await?;
603
604        data.map(|v| self.decode_value(&v).map(Into::into)).transpose()
605    }
606
607    async fn clean_inner(
608        &self,
609        policy: MediaRetentionPolicy,
610        current_time: SystemTime,
611    ) -> Result<(), Self::Error> {
612        if !policy.has_limitations() {
613            // We can safely skip all the checks.
614            return Ok(());
615        }
616
617        let conn = self.write().await?;
618        let removed = conn
619            .with_transaction::<_, Error, _>(move |txn| {
620                let mut removed = false;
621
622                // First, check media content that exceed the max filesize.
623                if let Some(max_file_size) = policy.computed_max_file_size() {
624                    let count = txn.execute(
625                        "DELETE FROM media WHERE ignore_policy IS FALSE AND length(data) > ?",
626                        (max_file_size,),
627                    )?;
628
629                    if count > 0 {
630                        removed = true;
631                    }
632                }
633
634                // Then, clean up expired media content.
635                if let Some(last_access_expiry) = policy.last_access_expiry {
636                    let current_timestamp = time_to_timestamp(current_time);
637                    let expiry_secs = last_access_expiry.as_secs();
638                    let count = txn.execute(
639                        "DELETE FROM media WHERE ignore_policy IS FALSE AND (? - last_access) >= ?",
640                        (current_timestamp, expiry_secs),
641                    )?;
642
643                    if count > 0 {
644                        removed = true;
645                    }
646                }
647
648                // Finally, if the cache size is too big, remove old items until it fits.
649                if let Some(max_cache_size) = policy.max_cache_size {
650                    // i64 is the integer type used by SQLite, use it here to avoid usize overflow
651                    // during the conversion of the result.
652                    let cache_size = txn
653                        .query_row(
654                            "SELECT sum(length(data)) FROM media WHERE ignore_policy IS FALSE",
655                            (),
656                            |row| {
657                                // `sum()` returns `NULL` if there are no rows.
658                                row.get::<_, Option<u64>>(0)
659                            },
660                        )?
661                        .unwrap_or_default();
662
663                    // If the cache size is overflowing or bigger than max cache size, clean up.
664                    if cache_size > max_cache_size {
665                        // Get the sizes of the media contents ordered by last access.
666                        let mut cached_stmt = txn.prepare_cached(
667                            "SELECT rowid, length(data) FROM media \
668                             WHERE ignore_policy IS FALSE ORDER BY last_access DESC",
669                        )?;
670                        let content_sizes = cached_stmt
671                            .query(())?
672                            .mapped(|row| Ok((row.get::<_, i64>(0)?, row.get::<_, u64>(1)?)));
673
674                        let mut accumulated_items_size = 0u64;
675                        let mut limit_reached = false;
676                        let mut rows_to_remove = Vec::new();
677
678                        for result in content_sizes {
679                            let (row_id, size) = match result {
680                                Ok(content_size) => content_size,
681                                Err(error) => {
682                                    return Err(error.into());
683                                }
684                            };
685
686                            if limit_reached {
687                                rows_to_remove.push(row_id);
688                                continue;
689                            }
690
691                            match accumulated_items_size.checked_add(size) {
692                                Some(acc) if acc > max_cache_size => {
693                                    // We can stop accumulating.
694                                    limit_reached = true;
695                                    rows_to_remove.push(row_id);
696                                }
697                                Some(acc) => accumulated_items_size = acc,
698                                None => {
699                                    // The accumulated size is overflowing but the setting cannot be
700                                    // bigger than usize::MAX, we can stop accumulating.
701                                    limit_reached = true;
702                                    rows_to_remove.push(row_id);
703                                }
704                            }
705                        }
706
707                        if !rows_to_remove.is_empty() {
708                            removed = true;
709                        }
710
711                        txn.chunk_large_query_over(rows_to_remove, None, |txn, row_ids| {
712                            let sql_params = repeat_vars(row_ids.len());
713                            let query = format!("DELETE FROM media WHERE rowid IN ({sql_params})");
714                            txn.prepare(&query)?.execute(params_from_iter(row_ids))?;
715                            Ok(Vec::<()>::new())
716                        })?;
717                    }
718                }
719
720                txn.set_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME, current_time)?;
721
722                Ok(removed)
723            })
724            .await?;
725
726        // If we removed media, defragment the database and free space on the
727        // filesystem.
728        if removed {
729            conn.vacuum().await?;
730        }
731
732        Ok(())
733    }
734
735    async fn last_media_cleanup_time_inner(&self) -> Result<Option<SystemTime>, Self::Error> {
736        let conn = self.read().await?;
737        conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await
738    }
739}
740
#[cfg(test)]
mod tests {
    use std::{
        path::PathBuf,
        sync::{
            LazyLock,
            atomic::{AtomicU32, Ordering::SeqCst},
        },
        time::Duration,
    };

    use matrix_sdk_base::{
        media::{
            MediaFormat, MediaRequestParameters, MediaThumbnailSettings,
            store::{IgnoreMediaRetentionPolicy, MediaStore, MediaStoreError},
        },
        media_store_inner_integration_tests, media_store_integration_tests,
        media_store_integration_tests_time,
    };
    use matrix_sdk_test::async_test;
    use ruma::{events::room::MediaSource, media::Method, mxc_uri, uint};
    use tempfile::{TempDir, tempdir};

    use super::SqliteMediaStore;
    use crate::{SqliteStoreConfig, utils::SqliteAsyncConnExt};

    /// Temporary directory shared by every store created in these tests.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    /// Counter giving each store its own sub-directory of [`TMP_DIR`].
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Reserve a fresh, unique directory path under [`TMP_DIR`].
    fn new_media_store_workspace() -> PathBuf {
        let id = NUM.fetch_add(1, SeqCst);
        TMP_DIR.path().join(id.to_string())
    }

    /// Open an unencrypted media store in a fresh workspace.
    ///
    /// Presumably the constructor the shared integration-test macros below
    /// rely on; keep its name and signature.
    async fn get_media_store() -> Result<SqliteMediaStore, MediaStoreError> {
        let workspace = new_media_store_workspace();
        let path = workspace.to_str().unwrap();

        tracing::info!("using media store @ {}", path);

        let store = SqliteMediaStore::open(path, None).await.unwrap();
        Ok(store)
    }

    media_store_integration_tests!();
    media_store_integration_tests_time!();
    media_store_inner_integration_tests!();

    /// Read back the raw `data` column, most recently accessed first.
    async fn get_media_store_content_sorted_by_last_access(
        media_store: &SqliteMediaStore,
    ) -> Vec<Vec<u8>> {
        let conn = media_store.read().await.expect("accessing sqlite db failed");
        conn.prepare("SELECT data FROM media ORDER BY last_access DESC", |mut stmt| {
            stmt.query(())?.mapped(|row| row.get(0)).collect()
        })
        .await
        .expect("querying media cache content by last access failed")
    }

    #[async_test]
    async fn test_pool_size() {
        let config = SqliteStoreConfig::new(new_media_store_workspace()).pool_max_size(42);

        let store = SqliteMediaStore::open_with_config(&config).await.unwrap();
        let max_size = store.pool_max_size().await.unwrap();

        assert_eq!(max_size, 42);
    }

    #[async_test]
    async fn test_last_access() {
        // Timestamps only have a precision of one second, so wait long enough
        // between operations for them to differ.
        let timestamp_gap = Duration::from_secs(3);

        let media_store = get_media_store().await.expect("creating media cache failed");
        let uri = mxc_uri!("mxc://localhost/media");

        let file_request = MediaRequestParameters {
            source: MediaSource::Plain(uri.to_owned()),
            format: MediaFormat::File,
        };
        let thumbnail_settings =
            MediaThumbnailSettings::with_method(Method::Crop, uint!(100), uint!(100));
        let thumbnail_request = MediaRequestParameters {
            source: MediaSource::Plain(uri.to_owned()),
            format: MediaFormat::Thumbnail(thumbnail_settings),
        };

        let content: Vec<u8> = "hello world".into();
        let thumbnail_content: Vec<u8> = "hello…".into();

        // Store the file first...
        media_store
            .add_media_content(&file_request, content.clone(), IgnoreMediaRetentionPolicy::No)
            .await
            .expect("adding file failed");

        tokio::time::sleep(timestamp_gap).await;

        // ...then the thumbnail, which therefore has the most recent access.
        media_store
            .add_media_content(
                &thumbnail_request,
                thumbnail_content.clone(),
                IgnoreMediaRetentionPolicy::No,
            )
            .await
            .expect("adding thumbnail failed");

        let contents = get_media_store_content_sorted_by_last_access(&media_store).await;
        assert_eq!(contents.len(), 2, "media cache contents length is wrong");
        assert_eq!(contents[0], thumbnail_content, "thumbnail is not last access");
        assert_eq!(contents[1], content, "file is not second-to-last access");

        tokio::time::sleep(timestamp_gap).await;

        // Reading the file bumps its last access past the thumbnail's.
        let _ = media_store
            .get_media_content(&file_request)
            .await
            .expect("getting file failed")
            .expect("file is missing");

        let contents = get_media_store_content_sorted_by_last_access(&media_store).await;
        assert_eq!(contents.len(), 2, "media cache contents length is wrong");
        assert_eq!(contents[0], content, "file is not last access");
        assert_eq!(contents[1], thumbnail_content, "thumbnail is not second-to-last access");
    }
}
874
#[cfg(test)]
mod close_reopen_tests {
    //! Tests for the `close()` / `reopen()` lifecycle of [`SqliteMediaStore`]:
    //! connection teardown, idempotence, failure while closed, and data
    //! persistence across cycles.

    use std::sync::{
        LazyLock,
        atomic::{AtomicU32, Ordering::SeqCst},
    };

    use matrix_sdk_base::media::{
        MediaFormat, MediaRequestParameters,
        store::{IgnoreMediaRetentionPolicy, MediaStore},
    };
    use matrix_sdk_test::async_test;
    use ruma::{events::room::MediaSource, mxc_uri};
    use tempfile::{TempDir, tempdir};

    use super::SqliteMediaStore;

    // Shared temporary directory; every store opened by `new_store` lives in
    // its own numbered subdirectory of it.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    // Counter handing out a unique subdirectory name per store.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens a fresh, unencrypted media store in a unique subdirectory of
    /// `TMP_DIR`.
    async fn new_store() -> SqliteMediaStore {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        let tmpdir_path = TMP_DIR.path().join(name);
        SqliteMediaStore::open(tmpdir_path, None).await.expect("opening media store failed")
    }

    /// Request parameters for the single plain media file used by these tests.
    fn test_request() -> MediaRequestParameters {
        MediaRequestParameters {
            source: MediaSource::Plain(mxc_uri!("mxc://localhost/test_media").to_owned()),
            format: MediaFormat::File,
        }
    }

    #[async_test]
    async fn test_close_completes_without_timeout() {
        let store = new_store().await;

        // Close should complete quickly without hitting any timeout.
        let start = std::time::Instant::now();
        store.close().await.expect("close failed");
        let elapsed = start.elapsed();

        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "close() took {elapsed:?}, expected < 2s (no timeout)"
        );

        // Connections should be None after close.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }

    #[async_test]
    async fn test_reopen_restores_connections() {
        let store = new_store().await;

        store.close().await.expect("close failed");

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_none());
        }

        store.reopen().await.expect("reopen failed");

        {
            let guard = store.connections.lock().await;
            assert!(guard.is_some(), "connections should be Some after reopen");
        }
    }

    #[async_test]
    async fn test_close_is_idempotent() {
        let store = new_store().await;

        store.close().await.expect("first close failed");
        // Second close should be a no-op.
        store.close().await.expect("second close failed");

        let guard = store.connections.lock().await;
        assert!(guard.is_none());
    }

    #[async_test]
    async fn test_reopen_is_idempotent() {
        let store = new_store().await;

        // Reopen on an active store should be a no-op.
        store.reopen().await.expect("reopen on an open store failed");

        let guard = store.connections.lock().await;
        assert!(guard.is_some());
    }

    #[async_test]
    async fn test_read_fails_when_closed() {
        let store = new_store().await;
        store.close().await.expect("close failed");

        // `expect_err` reports the unexpected `Ok` value if the read succeeds.
        let err = store
            .get_media_content(&test_request())
            .await
            .expect_err("read should fail when closed");

        let err_msg = err.to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_write_fails_when_closed() {
        let store = new_store().await;
        store.close().await.expect("close failed");

        let err = store
            .add_media_content(&test_request(), b"data".to_vec(), IgnoreMediaRetentionPolicy::No)
            .await
            .expect_err("write should fail when closed");

        let err_msg = err.to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_data_persists_across_close_reopen() {
        let store = new_store().await;

        // Write some media content.
        store
            .add_media_content(
                &test_request(),
                b"hello world".to_vec(),
                IgnoreMediaRetentionPolicy::Yes,
            )
            .await
            .expect("adding media content failed");

        // Verify it's there.
        let content =
            store.get_media_content(&test_request()).await.expect("reading media content failed");
        assert_eq!(content.as_deref(), Some(b"hello world".as_slice()));

        // Close and reopen.
        store.close().await.expect("close failed");
        store.reopen().await.expect("reopen failed");

        // Content should still be there after reopen.
        let content = store
            .get_media_content(&test_request())
            .await
            .expect("reading media content after reopen failed");
        assert_eq!(
            content.as_deref(),
            Some(b"hello world".as_slice()),
            "media content should persist across close/reopen"
        );
    }

    #[async_test]
    async fn test_multiple_close_reopen_cycles() {
        let store = new_store().await;

        for cycle in 0..5 {
            store.close().await.expect("close failed");
            store.reopen().await.expect("reopen failed");

            // After each cycle, the store should be fully operational; on
            // failure, report which cycle broke and the underlying error.
            store.get_media_content(&test_request()).await.unwrap_or_else(|err| {
                panic!("store should work after close/reopen cycle {cycle}: {err:?}")
            });
        }
    }

    #[async_test]
    async fn test_pool_is_fully_drained_after_close() {
        let store = new_store().await;

        // Do a few reads to exercise the pool. They must succeed, otherwise the
        // pool was never actually exercised.
        store.get_media_content(&test_request()).await.expect("first read failed");
        store.get_media_content(&test_request()).await.expect("second read failed");

        store.close().await.expect("close failed");

        // After close, the connections field should be None (pool and write
        // connection have been fully torn down).
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "all connections should be released after close");
    }

    #[async_test]
    async fn test_operations_work_immediately_after_reopen() {
        let store = new_store().await;

        store.close().await.expect("close failed");
        store.reopen().await.expect("reopen failed");

        // Read should work immediately after reopen.
        store
            .get_media_content(&test_request())
            .await
            .expect("read should succeed immediately after reopen");

        // Write should work immediately after reopen.
        store
            .add_media_content(
                &test_request(),
                b"after_reopen".to_vec(),
                IgnoreMediaRetentionPolicy::No,
            )
            .await
            .expect("write should succeed immediately after reopen");
    }

    #[async_test]
    async fn test_close_waits_for_held_read_connection_to_drain() {
        let store = new_store().await;

        // Acquire a read connection and hold it, simulating an in-flight read.
        let held_conn = store.read().await.expect("acquiring a read connection failed");

        // Spawn close in a background task — it will close the pool and then
        // poll-wait for pool.status().size == 0 in the drain loop.
        let store_clone = store.clone();
        let close_handle = tokio::spawn(async move {
            store_clone.close().await.expect("close failed");
        });

        // Give close() a moment to close the pool and enter the drain loop.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        // The close task should still be running because we hold a connection.
        assert!(!close_handle.is_finished(), "close should be waiting for the held connection");

        // Release the held connection — this lets pool.status().size drop to 0.
        drop(held_conn);

        // Now close should complete promptly (well within the 5s timeout). The
        // outer `expect` covers the timeout elapsing; the inner one covers the
        // close task panicking.
        tokio::time::timeout(std::time::Duration::from_secs(3), close_handle)
            .await
            .expect("close should complete after the held connection is released")
            .expect("close task panicked");

        // Verify the store is fully closed.
        let guard = store.connections.lock().await;
        assert!(guard.is_none(), "connections should be None after close");
    }
}
1111
#[cfg(test)]
mod encrypted_tests {
    //! Runs the shared media store integration test suites against a
    //! [`SqliteMediaStore`] opened with a passphrase.

    use std::sync::{
        LazyLock,
        atomic::{AtomicU32, Ordering::SeqCst},
    };

    use matrix_sdk_base::{
        media::store::MediaStoreError, media_store_inner_integration_tests,
        media_store_integration_tests, media_store_integration_tests_time,
    };
    use tempfile::{TempDir, tempdir};

    use super::SqliteMediaStore;

    // Shared temporary directory; each store gets its own numbered subdirectory.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    // Counter handing out a unique subdirectory name per store.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens a fresh passphrase-protected media store in a unique subdirectory
    /// of `TMP_DIR`.
    ///
    /// Opening failures panic rather than being returned. The `Result` return
    /// type is presumably the signature the integration-test macros below
    /// expect — verify against the macro definitions before changing it.
    async fn get_media_store() -> Result<SqliteMediaStore, MediaStoreError> {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        let tmpdir_path = TMP_DIR.path().join(name);

        // `display()` never panics, unlike `to_str().unwrap()` on a non-UTF-8 path.
        tracing::info!("using media store @ {}", tmpdir_path.display());

        Ok(SqliteMediaStore::open(tmpdir_path, Some("default_test_password"))
            .await
            .expect("opening encrypted media store failed"))
    }

    media_store_integration_tests!();
    media_store_integration_tests_time!();
    media_store_inner_integration_tests!();
}