1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, BTreeSet, HashMap},
4 fmt, iter,
5 path::{Path, PathBuf},
6 str::FromStr as _,
7 sync::Arc,
8};
9
10use async_trait::async_trait;
11use deadpool::managed::PoolConfig;
12use matrix_sdk_base::{
13 MinimalRoomMemberEvent, ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK, RoomInfo,
14 RoomMemberships, RoomState, StateChanges, StateStore, StateStoreDataKey, StateStoreDataValue,
15 deserialized_responses::{DisplayName, RawAnySyncOrStrippedState, SyncOrStrippedState},
16 store::{
17 ChildTransactionId, DependentQueuedRequest, DependentQueuedRequestKind, QueueWedgeError,
18 QueuedRequest, QueuedRequestKind, RoomLoadSettings, SentRequestKey,
19 StoredThreadSubscription, ThreadSubscriptionStatus, migration_helpers::RoomInfoV1,
20 },
21};
22use matrix_sdk_store_encryption::StoreCipher;
23use ruma::{
24 CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
25 OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UInt, UserId,
26 canonical_json::{RedactedBecause, redact},
27 events::{
28 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent,
29 GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30 presence::PresenceEvent,
31 receipt::{Receipt, ReceiptThread, ReceiptType},
32 room::{
33 create::RoomCreateEventContent,
34 member::{StrippedRoomMemberEvent, SyncRoomMemberEvent},
35 },
36 },
37 serde::Raw,
38};
39use rusqlite::{OptionalExtension, Transaction};
40use serde::{Deserialize, Serialize};
41use tokio::{
42 fs,
43 sync::{Mutex, OwnedMutexGuard},
44};
45use tracing::{debug, instrument, warn};
46
47use crate::{
48 OpenStoreError, RuntimeConfig, Secret, SqliteStoreConfig,
49 connection::{self, Connection as SqliteAsyncConn, Pool as SqlitePool, SqliteConnections},
50 error::{Error, Result},
51 utils::{
52 EncryptableStore, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt,
53 SqliteKeyValueStoreConnExt, repeat_vars,
54 },
55};
56
/// Table names of the state store.
///
/// They are also passed as the "table" argument of `encode_key`, so that the same
/// logical key encodes differently depending on the table it belongs to. The values
/// mirror the names of the SQL tables they refer to.
mod keys {
    // Generic key-value storage.
    pub const KV_BLOB: &str = "kv_blob";

    // Per-room data.
    pub const ROOM_INFO: &str = "room_info";
    pub const STATE_EVENT: &str = "state_event";
    pub const MEMBER: &str = "member";
    pub const PROFILE: &str = "profile";
    pub const RECEIPT: &str = "receipt";
    pub const DISPLAY_NAME: &str = "display_name";
    pub const ROOM_ACCOUNT_DATA: &str = "room_account_data";

    // Account-wide data.
    pub const GLOBAL_ACCOUNT_DATA: &str = "global_account_data";

    // Send queue.
    pub const SEND_QUEUE: &str = "send_queue_events";
    pub const DEPENDENTS_SEND_QUEUE: &str = "dependent_send_queue_events";

    // Threads.
    pub const THREAD_SUBSCRIPTIONS: &str = "thread_subscriptions";
}

/// File name of the SQLite database, passed to `build_pool_of_connections` when the
/// store is opened from a `SqliteStoreConfig`.
pub const DATABASE_NAME: &str = "matrix-sdk-state.sqlite3";
75
76#[derive(Clone)]
78pub struct SqliteStateStore {
79 store_cipher: Option<Arc<StoreCipher>>,
80
81 connections: Arc<Mutex<Option<SqliteConnections>>>,
83
84 db_path: PathBuf,
86
87 pool_config: PoolConfig,
89
90 runtime_config: RuntimeConfig,
92}
93
94#[cfg(not(tarpaulin_include))]
95impl fmt::Debug for SqliteStateStore {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 f.debug_struct("SqliteStateStore").finish_non_exhaustive()
98 }
99}
100
101impl SqliteStateStore {
102 pub async fn open(
105 path: impl AsRef<Path>,
106 passphrase: Option<&str>,
107 ) -> Result<Self, OpenStoreError> {
108 Self::open_with_config(&SqliteStoreConfig::new(path).passphrase(passphrase)).await
109 }
110
111 pub async fn open_with_key(
114 path: impl AsRef<Path>,
115 key: Option<&[u8; 32]>,
116 ) -> Result<Self, OpenStoreError> {
117 Self::open_with_config(&SqliteStoreConfig::new(path).key(key)).await
118 }
119
120 pub async fn open_with_config(config: &SqliteStoreConfig) -> Result<Self, OpenStoreError> {
122 fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir)?;
123
124 let pool = config.build_pool_of_connections(DATABASE_NAME)?;
125 let pool_config = config.pool_config;
126 let runtime_config = config.runtime_config;
127
128 let this =
129 Self::open_with_pool(pool, config.secret.clone(), pool_config, runtime_config).await?;
130 this.read().await?.apply_runtime_config(runtime_config).await?;
131
132 Ok(this)
133 }
134
135 pub(crate) async fn open_with_pool(
138 pool: SqlitePool,
139 secret: Option<Secret>,
140 pool_config: PoolConfig,
141 runtime_config: RuntimeConfig,
142 ) -> Result<Self, OpenStoreError> {
143 let db_path = pool.manager().database_path.clone();
144 let conn = pool.get().await?;
145
146 let mut version = conn.db_version().await?;
147
148 if version == 0 {
149 init(&conn).await?;
150 version = 1;
151 }
152
153 let store_cipher = match secret {
154 Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s).await?)),
155 None => None,
156 };
157 let this = Self {
158 store_cipher,
159 connections: Arc::new(Mutex::new(Some(SqliteConnections {
160 pool,
161 write_connection: Arc::new(Mutex::new(conn)),
162 }))),
163 db_path,
164 pool_config,
165 runtime_config,
166 };
167 this.run_migrations(version, None).await?;
168
169 this.read().await?.wal_checkpoint().await;
170
171 Ok(this)
172 }
173
174 async fn run_migrations(&self, from: u8, to: Option<u8>) -> Result<()> {
179 if to == Some(1) {
180 return Ok(());
181 }
182
183 let conn = self.write().await?;
184
185 if from < 2 {
186 debug!("Upgrading database to version 2");
187 let this = self.clone();
188 conn.with_transaction(move |txn| {
189 txn.execute_batch(include_str!(
191 "../migrations/state_store/002_a_create_new_room_info.sql"
192 ))?;
193
194 for data in txn
196 .prepare("SELECT data FROM room_info")?
197 .query_map((), |row| row.get::<_, Vec<u8>>(0))?
198 {
199 let data = data?;
200 let room_info: RoomInfoV1 = this.deserialize_json(&data)?;
201
202 let room_id = this.encode_key(keys::ROOM_INFO, room_info.room_id());
203 let state = this
204 .encode_key(keys::ROOM_INFO, serde_json::to_string(&room_info.state())?);
205 txn.prepare_cached(
206 "INSERT OR REPLACE INTO new_room_info (room_id, state, data)
207 VALUES (?, ?, ?)",
208 )?
209 .execute((room_id, state, data))?;
210 }
211
212 txn.execute_batch(include_str!(
214 "../migrations/state_store/002_b_replace_room_info.sql"
215 ))?;
216
217 txn.set_db_version(2)?;
218 Result::<_, Error>::Ok(())
219 })
220 .await?;
221 }
222
223 if to == Some(2) {
224 return Ok(());
225 }
226
227 if from < 3 {
229 debug!("Upgrading database to version 3");
230 let this = self.clone();
231 conn.with_transaction(move |txn| {
232 for data in txn
234 .prepare("SELECT data FROM room_info")?
235 .query_map((), |row| row.get::<_, Vec<u8>>(0))?
236 {
237 let data = data?;
238 let room_info_v1: RoomInfoV1 = this.deserialize_json(&data)?;
239
240 let room_id = this.encode_key(keys::STATE_EVENT, room_info_v1.room_id());
242 let event_type =
243 this.encode_key(keys::STATE_EVENT, StateEventType::RoomCreate.to_string());
244 let create_res = txn
245 .prepare(
246 "SELECT stripped, data FROM state_event
247 WHERE room_id = ? AND event_type = ?",
248 )?
249 .query_row([room_id, event_type], |row| {
250 Ok((row.get::<_, bool>(0)?, row.get::<_, Vec<u8>>(1)?))
251 })
252 .optional()?;
253
254 let create = create_res.and_then(|(stripped, data)| {
255 let create = if stripped {
256 SyncOrStrippedState::<RoomCreateEventContent>::Stripped(
257 this.deserialize_json(&data).ok()?,
258 )
259 } else {
260 SyncOrStrippedState::Sync(this.deserialize_json(&data).ok()?)
261 };
262 Some(create)
263 });
264
265 let migrated_room_info = room_info_v1.migrate(create.as_ref());
266
267 let data = this.serialize_json(&migrated_room_info)?;
268 let room_id = this.encode_key(keys::ROOM_INFO, migrated_room_info.room_id());
269 txn.prepare_cached("UPDATE room_info SET data = ? WHERE room_id = ?")?
270 .execute((data, room_id))?;
271 }
272
273 txn.set_db_version(3)?;
274 Result::<_, Error>::Ok(())
275 })
276 .await?;
277 }
278
279 if to == Some(3) {
280 return Ok(());
281 }
282
283 if from < 4 {
284 debug!("Upgrading database to version 4");
285 conn.with_transaction(move |txn| {
286 txn.execute_batch(include_str!("../migrations/state_store/003_send_queue.sql"))?;
288 txn.set_db_version(4)
289 })
290 .await?;
291 }
292
293 if to == Some(4) {
294 return Ok(());
295 }
296
297 if from < 5 {
298 debug!("Upgrading database to version 5");
299 conn.with_transaction(move |txn| {
300 txn.execute_batch(include_str!(
302 "../migrations/state_store/004_send_queue_with_roomid_value.sql"
303 ))?;
304 txn.set_db_version(4)
305 })
306 .await?;
307 }
308
309 if to == Some(5) {
310 return Ok(());
311 }
312
313 if from < 6 {
314 debug!("Upgrading database to version 6");
315 conn.with_transaction(move |txn| {
316 txn.execute_batch(include_str!(
318 "../migrations/state_store/005_send_queue_dependent_events.sql"
319 ))?;
320 txn.set_db_version(6)
321 })
322 .await?;
323 }
324
325 if to == Some(6) {
326 return Ok(());
327 }
328
329 if from < 7 {
330 debug!("Upgrading database to version 7");
331 conn.with_transaction(move |txn| {
332 txn.execute_batch(include_str!("../migrations/state_store/006_drop_media.sql"))?;
334 txn.set_db_version(7)
335 })
336 .await?;
337 }
338
339 if to == Some(7) {
340 return Ok(());
341 }
342
343 if from < 8 {
344 debug!("Upgrading database to version 8");
345 let error = QueueWedgeError::GenericApiError {
347 msg: "local echo failed to send in a previous session".into(),
348 };
349 let default_err = self.serialize_value(&error)?;
350
351 conn.with_transaction(move |txn| {
352 txn.execute_batch(include_str!("../migrations/state_store/007_a_send_queue_wedge_reason.sql"))?;
354
355 for wedged_entries in txn
358 .prepare("SELECT room_id, transaction_id FROM send_queue_events WHERE wedged = 1")?
359 .query_map((), |row| {
360 Ok(
361 (row.get::<_, Vec<u8>>(0)?,row.get::<_, String>(1)?)
362 )
363 })? {
364
365 let (room_id, transaction_id) = wedged_entries?;
366
367 txn.prepare_cached("UPDATE send_queue_events SET wedge_reason = ? WHERE room_id = ? AND transaction_id = ?")?
368 .execute((default_err.clone(), room_id, transaction_id))?;
369 }
370
371
372 txn.execute_batch(include_str!("../migrations/state_store/007_b_send_queue_clean.sql"))?;
374
375 txn.set_db_version(8)
376 })
377 .await?;
378 }
379
380 if to == Some(8) {
381 return Ok(());
382 }
383
384 if from < 9 {
385 debug!("Upgrading database to version 9");
386 conn.with_transaction(move |txn| {
387 txn.execute_batch(include_str!("../migrations/state_store/008_send_queue.sql"))?;
389 txn.set_db_version(9)
390 })
391 .await?;
392 }
393
394 if to == Some(9) {
395 return Ok(());
396 }
397
398 if from < 10 {
399 debug!("Upgrading database to version 10");
400 conn.with_transaction(move |txn| {
401 txn.execute_batch(include_str!(
403 "../migrations/state_store/009_send_queue_priority.sql"
404 ))?;
405 txn.set_db_version(10)
406 })
407 .await?;
408 }
409
410 if to == Some(10) {
411 return Ok(());
412 }
413
414 if from < 11 {
415 debug!("Upgrading database to version 11");
416 conn.with_transaction(move |txn| {
417 txn.execute_batch(include_str!(
419 "../migrations/state_store/010_send_queue_enqueue_time.sql"
420 ))?;
421 txn.set_db_version(11)
422 })
423 .await?;
424 }
425
426 if to == Some(11) {
427 return Ok(());
428 }
429
430 if from < 12 {
431 debug!("Upgrading database to version 12");
432 conn.vacuum().await?;
436 conn.set_kv("version", vec![12]).await?;
437 }
438
439 if to == Some(12) {
440 return Ok(());
441 }
442
443 if from < 13 {
444 debug!("Upgrading database to version 13");
445 conn.with_transaction(move |txn| {
446 txn.execute_batch(include_str!(
448 "../migrations/state_store/011_thread_subscriptions.sql"
449 ))?;
450 txn.set_db_version(13)
451 })
452 .await?;
453 }
454
455 if to == Some(13) {
456 return Ok(());
457 }
458
459 if from < 14 {
460 debug!("Upgrading database to version 14");
461 conn.with_transaction(move |txn| {
462 txn.execute_batch(include_str!(
464 "../migrations/state_store/012_thread_subscriptions_bumpstamp.sql"
465 ))?;
466 txn.set_db_version(14)
467 })
468 .await?;
469 }
470
471 if to == Some(14) {
472 return Ok(());
473 }
474
475 if from < 15 {
476 debug!("Upgrading database to version 15");
477 conn.with_transaction(move |txn| {
478 txn.execute_batch(include_str!(
480 "../migrations/state_store/013_send_queue_new_parent_key_format.sql"
481 ))?;
482 txn.set_db_version(15)
483 })
484 .await?;
485 }
486
487 if to == Some(15) {
488 return Ok(());
489 }
490
491 Ok(())
492 }
493
494 fn encode_state_store_data_key(&self, key: StateStoreDataKey<'_>) -> Key {
495 let key_s = match key {
496 StateStoreDataKey::SyncToken => Cow::Borrowed(StateStoreDataKey::SYNC_TOKEN),
497 StateStoreDataKey::SupportedVersions => {
498 Cow::Borrowed(StateStoreDataKey::SUPPORTED_VERSIONS)
499 }
500 StateStoreDataKey::WellKnown => Cow::Borrowed(StateStoreDataKey::WELL_KNOWN),
501 StateStoreDataKey::Filter(f) => {
502 Cow::Owned(format!("{}:{f}", StateStoreDataKey::FILTER))
503 }
504 StateStoreDataKey::UserAvatarUrl(u) => {
505 Cow::Owned(format!("{}:{u}", StateStoreDataKey::USER_AVATAR_URL))
506 }
507 StateStoreDataKey::RecentlyVisitedRooms(b) => {
508 Cow::Owned(format!("{}:{b}", StateStoreDataKey::RECENTLY_VISITED_ROOMS))
509 }
510 StateStoreDataKey::UtdHookManagerData => {
511 Cow::Borrowed(StateStoreDataKey::UTD_HOOK_MANAGER_DATA)
512 }
513 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
514 Cow::Borrowed(StateStoreDataKey::ONE_TIME_KEY_ALREADY_UPLOADED)
515 }
516 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
517 if let Some(thread_root) = thread_root {
518 Cow::Owned(format!(
519 "{}:{room_id}:{thread_root}",
520 StateStoreDataKey::COMPOSER_DRAFT
521 ))
522 } else {
523 Cow::Owned(format!("{}:{room_id}", StateStoreDataKey::COMPOSER_DRAFT))
524 }
525 }
526 StateStoreDataKey::SeenKnockRequests(room_id) => {
527 Cow::Owned(format!("{}:{room_id}", StateStoreDataKey::SEEN_KNOCK_REQUESTS))
528 }
529 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
530 Cow::Borrowed(StateStoreDataKey::THREAD_SUBSCRIPTIONS_CATCHUP_TOKENS)
531 }
532 StateStoreDataKey::HomeserverCapabilities => {
533 Cow::Borrowed(StateStoreDataKey::HOMESERVER_CAPABILITIES)
534 }
535 };
536
537 self.encode_key(keys::KV_BLOB, &*key_s)
538 }
539
540 fn encode_presence_key(&self, user_id: &UserId) -> Key {
541 self.encode_key(keys::KV_BLOB, format!("presence:{user_id}"))
542 }
543
544 fn encode_custom_key(&self, key: &[u8]) -> Key {
545 let mut full_key = b"custom:".to_vec();
546 full_key.extend(key);
547 self.encode_key(keys::KV_BLOB, full_key)
548 }
549
550 #[instrument(skip_all)]
553 async fn read(&self) -> Result<SqliteAsyncConn> {
554 let pool = {
555 let guard = self.connections.lock().await;
556 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
557 conns.pool.clone()
558 };
559 Ok(pool.get().await?)
560 }
561
562 #[instrument(skip_all)]
565 async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
566 let write_conn = {
567 let guard = self.connections.lock().await;
568 let conns = guard.as_ref().ok_or(Error::StoreClosed)?;
569 conns.write_connection.clone()
570 };
571 Ok(write_conn.lock_owned().await)
572 }
573
574 fn remove_maybe_stripped_room_data(
575 &self,
576 txn: &Transaction<'_>,
577 room_id: &RoomId,
578 stripped: bool,
579 ) -> rusqlite::Result<()> {
580 let state_event_room_id = self.encode_key(keys::STATE_EVENT, room_id);
581 txn.remove_room_state_events(&state_event_room_id, Some(stripped))?;
582
583 let member_room_id = self.encode_key(keys::MEMBER, room_id);
584 txn.remove_room_members(&member_room_id, Some(stripped))
585 }
586
587 pub async fn vacuum(&self) -> Result<()> {
588 self.write().await?.vacuum().await
589 }
590
591 pub async fn get_db_size(&self) -> Result<Option<usize>> {
592 let read_conn = self.read().await?;
593 Ok(Some(read_conn.get_db_size().await?))
594 }
595}
596
597impl EncryptableStore for SqliteStateStore {
598 fn get_cypher(&self) -> Option<&StoreCipher> {
599 self.store_cipher.as_deref()
600 }
601}
602
603async fn init(conn: &SqliteAsyncConn) -> Result<()> {
605 conn.execute_batch("PRAGMA journal_mode = wal;").await?;
608 conn.with_transaction(|txn| {
609 txn.execute_batch(include_str!("../migrations/state_store/001_init.sql"))?;
610 txn.set_db_version(1)?;
611
612 Ok(())
613 })
614 .await
615}
616
617trait SqliteConnectionStateStoreExt {
618 fn set_kv_blob(&self, key: &[u8], value: &[u8]) -> rusqlite::Result<()>;
619
620 fn set_global_account_data(&self, event_type: &[u8], data: &[u8]) -> rusqlite::Result<()>;
621
622 fn set_room_account_data(
623 &self,
624 room_id: &[u8],
625 event_type: &[u8],
626 data: &[u8],
627 ) -> rusqlite::Result<()>;
628 fn remove_room_account_data(&self, room_id: &[u8]) -> rusqlite::Result<()>;
629
630 fn set_room_info(&self, room_id: &[u8], state: &[u8], data: &[u8]) -> rusqlite::Result<()>;
631 fn get_room_info(&self, room_id: &[u8]) -> rusqlite::Result<Option<Vec<u8>>>;
632 fn remove_room_info(&self, room_id: &[u8]) -> rusqlite::Result<()>;
633
634 fn set_state_event(
635 &self,
636 room_id: &[u8],
637 event_type: &[u8],
638 state_key: &[u8],
639 stripped: bool,
640 event_id: Option<&[u8]>,
641 data: &[u8],
642 ) -> rusqlite::Result<()>;
643 fn get_state_event_by_id(
644 &self,
645 room_id: &[u8],
646 event_id: &[u8],
647 ) -> rusqlite::Result<Option<Vec<u8>>>;
648 fn remove_room_state_events(
649 &self,
650 room_id: &[u8],
651 stripped: Option<bool>,
652 ) -> rusqlite::Result<()>;
653
654 fn set_member(
655 &self,
656 room_id: &[u8],
657 user_id: &[u8],
658 membership: &[u8],
659 stripped: bool,
660 data: &[u8],
661 ) -> rusqlite::Result<()>;
662 fn remove_room_members(&self, room_id: &[u8], stripped: Option<bool>) -> rusqlite::Result<()>;
663
664 fn set_profile(&self, room_id: &[u8], user_id: &[u8], data: &[u8]) -> rusqlite::Result<()>;
665 fn remove_room_profiles(&self, room_id: &[u8]) -> rusqlite::Result<()>;
666 fn remove_room_profile(&self, room_id: &[u8], user_id: &[u8]) -> rusqlite::Result<()>;
667
668 fn set_receipt(
669 &self,
670 room_id: &[u8],
671 user_id: &[u8],
672 receipt_type: &[u8],
673 thread_id: &[u8],
674 event_id: &[u8],
675 data: &[u8],
676 ) -> rusqlite::Result<()>;
677 fn remove_room_receipts(&self, room_id: &[u8]) -> rusqlite::Result<()>;
678
679 fn set_display_name(&self, room_id: &[u8], name: &[u8], data: &[u8]) -> rusqlite::Result<()>;
680 fn remove_display_name(&self, room_id: &[u8], name: &[u8]) -> rusqlite::Result<()>;
681 fn remove_room_display_names(&self, room_id: &[u8]) -> rusqlite::Result<()>;
682 fn remove_room_send_queue(&self, room_id: &[u8]) -> rusqlite::Result<()>;
683 fn remove_room_dependent_send_queue(&self, room_id: &[u8]) -> rusqlite::Result<()>;
684}
685
686impl SqliteConnectionStateStoreExt for rusqlite::Connection {
687 fn set_kv_blob(&self, key: &[u8], value: &[u8]) -> rusqlite::Result<()> {
688 self.execute("INSERT OR REPLACE INTO kv_blob VALUES (?, ?)", (key, value))?;
689 Ok(())
690 }
691
692 fn set_global_account_data(&self, event_type: &[u8], data: &[u8]) -> rusqlite::Result<()> {
693 self.prepare_cached(
694 "INSERT OR REPLACE INTO global_account_data (event_type, data)
695 VALUES (?, ?)",
696 )?
697 .execute((event_type, data))?;
698 Ok(())
699 }
700
701 fn set_room_account_data(
702 &self,
703 room_id: &[u8],
704 event_type: &[u8],
705 data: &[u8],
706 ) -> rusqlite::Result<()> {
707 self.prepare_cached(
708 "INSERT OR REPLACE INTO room_account_data (room_id, event_type, data)
709 VALUES (?, ?, ?)",
710 )?
711 .execute((room_id, event_type, data))?;
712 Ok(())
713 }
714
715 fn remove_room_account_data(&self, room_id: &[u8]) -> rusqlite::Result<()> {
716 self.prepare(
717 "DELETE FROM room_account_data
718 WHERE room_id = ?",
719 )?
720 .execute((room_id,))?;
721 Ok(())
722 }
723
724 fn set_room_info(&self, room_id: &[u8], state: &[u8], data: &[u8]) -> rusqlite::Result<()> {
725 self.prepare_cached(
726 "INSERT OR REPLACE INTO room_info (room_id, state, data)
727 VALUES (?, ?, ?)",
728 )?
729 .execute((room_id, state, data))?;
730 Ok(())
731 }
732
733 fn get_room_info(&self, room_id: &[u8]) -> rusqlite::Result<Option<Vec<u8>>> {
734 self.query_row("SELECT data FROM room_info WHERE room_id = ?", (room_id,), |row| row.get(0))
735 .optional()
736 }
737
738 fn remove_room_info(&self, room_id: &[u8]) -> rusqlite::Result<()> {
740 self.prepare_cached("DELETE FROM room_info WHERE room_id = ?")?.execute((room_id,))?;
741 Ok(())
742 }
743
744 fn set_state_event(
745 &self,
746 room_id: &[u8],
747 event_type: &[u8],
748 state_key: &[u8],
749 stripped: bool,
750 event_id: Option<&[u8]>,
751 data: &[u8],
752 ) -> rusqlite::Result<()> {
753 self.prepare_cached(
754 "INSERT OR REPLACE
755 INTO state_event (room_id, event_type, state_key, stripped, event_id, data)
756 VALUES (?, ?, ?, ?, ?, ?)",
757 )?
758 .execute((room_id, event_type, state_key, stripped, event_id, data))?;
759 Ok(())
760 }
761
762 fn get_state_event_by_id(
763 &self,
764 room_id: &[u8],
765 event_id: &[u8],
766 ) -> rusqlite::Result<Option<Vec<u8>>> {
767 self.query_row(
768 "SELECT data FROM state_event WHERE room_id = ? AND event_id = ?",
769 (room_id, event_id),
770 |row| row.get(0),
771 )
772 .optional()
773 }
774
775 fn remove_room_state_events(
781 &self,
782 room_id: &[u8],
783 stripped: Option<bool>,
784 ) -> rusqlite::Result<()> {
785 if let Some(stripped) = stripped {
786 self.prepare_cached("DELETE FROM state_event WHERE room_id = ? AND stripped = ?")?
787 .execute((room_id, stripped))?;
788 } else {
789 self.prepare_cached("DELETE FROM state_event WHERE room_id = ?")?
790 .execute((room_id,))?;
791 }
792 Ok(())
793 }
794
795 fn set_member(
796 &self,
797 room_id: &[u8],
798 user_id: &[u8],
799 membership: &[u8],
800 stripped: bool,
801 data: &[u8],
802 ) -> rusqlite::Result<()> {
803 self.prepare_cached(
804 "INSERT OR REPLACE
805 INTO member (room_id, user_id, membership, stripped, data)
806 VALUES (?, ?, ?, ?, ?)",
807 )?
808 .execute((room_id, user_id, membership, stripped, data))?;
809 Ok(())
810 }
811
812 fn remove_room_members(&self, room_id: &[u8], stripped: Option<bool>) -> rusqlite::Result<()> {
817 if let Some(stripped) = stripped {
818 self.prepare_cached("DELETE FROM member WHERE room_id = ? AND stripped = ?")?
819 .execute((room_id, stripped))?;
820 } else {
821 self.prepare_cached("DELETE FROM member WHERE room_id = ?")?.execute((room_id,))?;
822 }
823 Ok(())
824 }
825
826 fn set_profile(&self, room_id: &[u8], user_id: &[u8], data: &[u8]) -> rusqlite::Result<()> {
827 self.prepare_cached(
828 "INSERT OR REPLACE
829 INTO profile (room_id, user_id, data)
830 VALUES (?, ?, ?)",
831 )?
832 .execute((room_id, user_id, data))?;
833 Ok(())
834 }
835
836 fn remove_room_profiles(&self, room_id: &[u8]) -> rusqlite::Result<()> {
837 self.prepare("DELETE FROM profile WHERE room_id = ?")?.execute((room_id,))?;
838 Ok(())
839 }
840
841 fn remove_room_profile(&self, room_id: &[u8], user_id: &[u8]) -> rusqlite::Result<()> {
842 self.prepare("DELETE FROM profile WHERE room_id = ? AND user_id = ?")?
843 .execute((room_id, user_id))?;
844 Ok(())
845 }
846
847 fn set_receipt(
848 &self,
849 room_id: &[u8],
850 user_id: &[u8],
851 receipt_type: &[u8],
852 thread: &[u8],
853 event_id: &[u8],
854 data: &[u8],
855 ) -> rusqlite::Result<()> {
856 self.prepare_cached(
857 "INSERT OR REPLACE
858 INTO receipt (room_id, user_id, receipt_type, thread, event_id, data)
859 VALUES (?, ?, ?, ?, ?, ?)",
860 )?
861 .execute((room_id, user_id, receipt_type, thread, event_id, data))?;
862 Ok(())
863 }
864
865 fn remove_room_receipts(&self, room_id: &[u8]) -> rusqlite::Result<()> {
866 self.prepare("DELETE FROM receipt WHERE room_id = ?")?.execute((room_id,))?;
867 Ok(())
868 }
869
870 fn set_display_name(&self, room_id: &[u8], name: &[u8], data: &[u8]) -> rusqlite::Result<()> {
871 self.prepare_cached(
872 "INSERT OR REPLACE
873 INTO display_name (room_id, name, data)
874 VALUES (?, ?, ?)",
875 )?
876 .execute((room_id, name, data))?;
877 Ok(())
878 }
879
880 fn remove_display_name(&self, room_id: &[u8], name: &[u8]) -> rusqlite::Result<()> {
881 self.prepare("DELETE FROM display_name WHERE room_id = ? AND name = ?")?
882 .execute((room_id, name))?;
883 Ok(())
884 }
885
886 fn remove_room_display_names(&self, room_id: &[u8]) -> rusqlite::Result<()> {
887 self.prepare("DELETE FROM display_name WHERE room_id = ?")?.execute((room_id,))?;
888 Ok(())
889 }
890
891 fn remove_room_send_queue(&self, room_id: &[u8]) -> rusqlite::Result<()> {
892 self.prepare("DELETE FROM send_queue_events WHERE room_id = ?")?.execute((room_id,))?;
893 Ok(())
894 }
895
896 fn remove_room_dependent_send_queue(&self, room_id: &[u8]) -> rusqlite::Result<()> {
897 self.prepare("DELETE FROM dependent_send_queue_events WHERE room_id = ?")?
898 .execute((room_id,))?;
899 Ok(())
900 }
901}
902
903#[async_trait]
904trait SqliteObjectStateStoreExt: SqliteAsyncConnExt {
905 async fn get_kv_blob(&self, key: Key) -> Result<Option<Vec<u8>>> {
906 Ok(self
907 .query_row("SELECT value FROM kv_blob WHERE key = ?", (key,), |row| row.get(0))
908 .await
909 .optional()?)
910 }
911
912 async fn get_kv_blobs(&self, keys: Vec<Key>) -> Result<Vec<Vec<u8>>> {
913 let keys_length = keys.len();
914
915 self.chunk_large_query_over(keys, Some(keys_length), |txn, keys| {
916 let sql_params = repeat_vars(keys.len());
917 let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})");
918
919 let params = rusqlite::params_from_iter(keys);
920
921 Ok(txn
922 .prepare(&sql)?
923 .query(params)?
924 .mapped(|row| row.get(0))
925 .collect::<Result<_, _>>()?)
926 })
927 .await
928 }
929
930 async fn set_kv_blob(&self, key: Key, value: Vec<u8>) -> Result<()>;
931
932 async fn delete_kv_blob(&self, key: Key) -> Result<()> {
933 self.execute("DELETE FROM kv_blob WHERE key = ?", (key,)).await?;
934 Ok(())
935 }
936
937 async fn get_room_infos(&self, room_id: Option<Key>) -> Result<Vec<Vec<u8>>> {
938 Ok(match room_id {
939 None => {
940 self.prepare("SELECT data FROM room_info", move |mut stmt| {
941 stmt.query_map((), |row| row.get(0))?.collect()
942 })
943 .await?
944 }
945
946 Some(room_id) => {
947 self.prepare("SELECT data FROM room_info WHERE room_id = ?", move |mut stmt| {
948 stmt.query((room_id,))?.mapped(|row| row.get(0)).collect()
949 })
950 .await?
951 }
952 })
953 }
954
955 async fn get_maybe_stripped_state_events_for_keys(
956 &self,
957 room_id: Key,
958 event_type: Key,
959 state_keys: Vec<Key>,
960 ) -> Result<Vec<(bool, Vec<u8>)>> {
961 self.chunk_large_query_over(state_keys, None, move |txn, state_keys: Vec<Key>| {
962 let sql_params = repeat_vars(state_keys.len());
963 let sql = format!(
964 "SELECT stripped, data FROM state_event
965 WHERE room_id = ? AND event_type = ? AND state_key IN ({sql_params})"
966 );
967
968 let params = rusqlite::params_from_iter(
969 [room_id.clone(), event_type.clone()].into_iter().chain(state_keys),
970 );
971
972 Ok(txn
973 .prepare(&sql)?
974 .query(params)?
975 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
976 .collect::<Result<_, _>>()?)
977 })
978 .await
979 }
980
981 async fn get_maybe_stripped_state_events(
982 &self,
983 room_id: Key,
984 event_type: Key,
985 ) -> Result<Vec<(bool, Vec<u8>)>> {
986 Ok(self
987 .prepare(
988 "SELECT stripped, data FROM state_event
989 WHERE room_id = ? AND event_type = ?",
990 |mut stmt| {
991 stmt.query((room_id, event_type))?
992 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
993 .collect()
994 },
995 )
996 .await?)
997 }
998
999 async fn get_profiles(
1000 &self,
1001 room_id: Key,
1002 user_ids: Vec<Key>,
1003 ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
1004 let user_ids_length = user_ids.len();
1005
1006 self.chunk_large_query_over(user_ids, Some(user_ids_length), move |txn, user_ids| {
1007 let sql_params = repeat_vars(user_ids.len());
1008 let sql = format!(
1009 "SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})"
1010 );
1011
1012 let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(user_ids));
1013
1014 Ok(txn
1015 .prepare(&sql)?
1016 .query(params)?
1017 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
1018 .collect::<Result<_, _>>()?)
1019 })
1020 .await
1021 }
1022
1023 async fn get_user_ids(&self, room_id: Key, memberships: Vec<Key>) -> Result<Vec<Vec<u8>>> {
1024 let res = if memberships.is_empty() {
1025 self.prepare("SELECT data FROM member WHERE room_id = ?", |mut stmt| {
1026 stmt.query((room_id,))?.mapped(|row| row.get(0)).collect()
1027 })
1028 .await?
1029 } else {
1030 self.chunk_large_query_over(memberships, None, move |txn, memberships| {
1031 let sql_params = repeat_vars(memberships.len());
1032 let sql = format!(
1033 "SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})"
1034 );
1035
1036 let params =
1037 rusqlite::params_from_iter(iter::once(room_id.clone()).chain(memberships));
1038
1039 Ok(txn
1040 .prepare(&sql)?
1041 .query(params)?
1042 .mapped(|row| row.get(0))
1043 .collect::<Result<_, _>>()?)
1044 })
1045 .await?
1046 };
1047
1048 Ok(res)
1049 }
1050
1051 async fn get_global_account_data(&self, event_type: Key) -> Result<Option<Vec<u8>>> {
1052 Ok(self
1053 .query_row(
1054 "SELECT data FROM global_account_data WHERE event_type = ?",
1055 (event_type,),
1056 |row| row.get(0),
1057 )
1058 .await
1059 .optional()?)
1060 }
1061
1062 async fn get_room_account_data(
1063 &self,
1064 room_id: Key,
1065 event_type: Key,
1066 ) -> Result<Option<Vec<u8>>> {
1067 Ok(self
1068 .query_row(
1069 "SELECT data FROM room_account_data WHERE room_id = ? AND event_type = ?",
1070 (room_id, event_type),
1071 |row| row.get(0),
1072 )
1073 .await
1074 .optional()?)
1075 }
1076
1077 async fn get_display_names(
1078 &self,
1079 room_id: Key,
1080 names: Vec<Key>,
1081 ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
1082 let names_length = names.len();
1083
1084 self.chunk_large_query_over(names, Some(names_length), move |txn, names| {
1085 let sql_params = repeat_vars(names.len());
1086 let sql = format!(
1087 "SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})"
1088 );
1089
1090 let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(names));
1091
1092 Ok(txn
1093 .prepare(&sql)?
1094 .query(params)?
1095 .mapped(|row| Ok((row.get(0)?, row.get(1)?)))
1096 .collect::<Result<_, _>>()?)
1097 })
1098 .await
1099 }
1100
1101 async fn get_user_receipt(
1102 &self,
1103 room_id: Key,
1104 receipt_type: Key,
1105 thread: Key,
1106 user_id: Key,
1107 ) -> Result<Option<Vec<u8>>> {
1108 Ok(self
1109 .query_row(
1110 "SELECT data FROM receipt
1111 WHERE room_id = ? AND receipt_type = ? AND thread = ? and user_id = ?",
1112 (room_id, receipt_type, thread, user_id),
1113 |row| row.get(0),
1114 )
1115 .await
1116 .optional()?)
1117 }
1118
1119 async fn get_event_receipts(
1120 &self,
1121 room_id: Key,
1122 receipt_type: Key,
1123 thread: Key,
1124 event_id: Key,
1125 ) -> Result<Vec<Vec<u8>>> {
1126 Ok(self
1127 .prepare(
1128 "SELECT data FROM receipt
1129 WHERE room_id = ? AND receipt_type = ? AND thread = ? and event_id = ?",
1130 |mut stmt| {
1131 stmt.query((room_id, receipt_type, thread, event_id))?
1132 .mapped(|row| row.get(0))
1133 .collect()
1134 },
1135 )
1136 .await?)
1137 }
1138}
1139
1140#[async_trait]
1141impl SqliteObjectStateStoreExt for SqliteAsyncConn {
1142 async fn set_kv_blob(&self, key: Key, value: Vec<u8>) -> Result<()> {
1143 Ok(self.interact(move |conn| conn.set_kv_blob(&key, &value)).await.unwrap()?)
1144 }
1145}
1146
1147#[async_trait]
1148impl StateStore for SqliteStateStore {
1149 type Error = Error;
1150
1151 async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
1152 self.read()
1153 .await?
1154 .get_kv_blob(self.encode_state_store_data_key(key))
1155 .await?
1156 .map(|data| {
1157 Ok(match key {
1158 StateStoreDataKey::SyncToken => {
1159 StateStoreDataValue::SyncToken(self.deserialize_value(&data)?)
1160 }
1161 StateStoreDataKey::SupportedVersions => {
1162 StateStoreDataValue::SupportedVersions(self.deserialize_value(&data)?)
1163 }
1164 StateStoreDataKey::WellKnown => {
1165 StateStoreDataValue::WellKnown(self.deserialize_value(&data)?)
1166 }
1167 StateStoreDataKey::Filter(_) => {
1168 StateStoreDataValue::Filter(self.deserialize_value(&data)?)
1169 }
1170 StateStoreDataKey::UserAvatarUrl(_) => {
1171 StateStoreDataValue::UserAvatarUrl(self.deserialize_value(&data)?)
1172 }
1173 StateStoreDataKey::RecentlyVisitedRooms(_) => {
1174 StateStoreDataValue::RecentlyVisitedRooms(self.deserialize_value(&data)?)
1175 }
1176 StateStoreDataKey::UtdHookManagerData => {
1177 StateStoreDataValue::UtdHookManagerData(self.deserialize_value(&data)?)
1178 }
1179 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
1180 StateStoreDataValue::OneTimeKeyAlreadyUploaded
1181 }
1182 StateStoreDataKey::ComposerDraft(_, _) => {
1183 StateStoreDataValue::ComposerDraft(self.deserialize_value(&data)?)
1184 }
1185 StateStoreDataKey::SeenKnockRequests(_) => {
1186 StateStoreDataValue::SeenKnockRequests(self.deserialize_value(&data)?)
1187 }
1188 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
1189 StateStoreDataValue::ThreadSubscriptionsCatchupTokens(
1190 self.deserialize_value(&data)?,
1191 )
1192 }
1193 StateStoreDataKey::HomeserverCapabilities => {
1194 StateStoreDataValue::HomeserverCapabilities(self.deserialize_value(&data)?)
1195 }
1196 })
1197 })
1198 .transpose()
1199 }
1200
1201 async fn set_kv_data(
1202 &self,
1203 key: StateStoreDataKey<'_>,
1204 value: StateStoreDataValue,
1205 ) -> Result<()> {
1206 let serialized_value = match key {
1207 StateStoreDataKey::SyncToken => self.serialize_value(
1208 &value.into_sync_token().expect("Session data not a sync token"),
1209 )?,
1210 StateStoreDataKey::SupportedVersions => self.serialize_value(
1211 &value
1212 .into_supported_versions()
1213 .expect("Session data not containing supported versions"),
1214 )?,
1215 StateStoreDataKey::WellKnown => self.serialize_value(
1216 &value.into_well_known().expect("Session data not containing well-known"),
1217 )?,
1218 StateStoreDataKey::Filter(_) => {
1219 self.serialize_value(&value.into_filter().expect("Session data not a filter"))?
1220 }
1221 StateStoreDataKey::UserAvatarUrl(_) => self.serialize_value(
1222 &value.into_user_avatar_url().expect("Session data not an user avatar url"),
1223 )?,
1224 StateStoreDataKey::RecentlyVisitedRooms(_) => self.serialize_value(
1225 &value.into_recently_visited_rooms().expect("Session data not breadcrumbs"),
1226 )?,
1227 StateStoreDataKey::UtdHookManagerData => self.serialize_value(
1228 &value.into_utd_hook_manager_data().expect("Session data not UtdHookManagerData"),
1229 )?,
1230 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
1231 self.serialize_value(&true).expect("We should be able to serialize a boolean")
1232 }
1233 StateStoreDataKey::ComposerDraft(_, _) => self.serialize_value(
1234 &value.into_composer_draft().expect("Session data not a composer draft"),
1235 )?,
1236 StateStoreDataKey::SeenKnockRequests(_) => self.serialize_value(
1237 &value
1238 .into_seen_knock_requests()
1239 .expect("Session data is not a set of seen knock request ids"),
1240 )?,
1241 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => self.serialize_value(
1242 &value
1243 .into_thread_subscriptions_catchup_tokens()
1244 .expect("Session data is not a list of thread subscription catchup tokens"),
1245 )?,
1246 StateStoreDataKey::HomeserverCapabilities => self.serialize_value(
1247 &value
1248 .into_homeserver_capabilities()
1249 .expect("Session data is not the homeserver capabilities"),
1250 )?,
1251 };
1252
1253 self.write()
1254 .await?
1255 .set_kv_blob(self.encode_state_store_data_key(key), serialized_value)
1256 .await
1257 }
1258
1259 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
1260 self.write().await?.delete_kv_blob(self.encode_state_store_data_key(key)).await
1261 }
1262
1263 async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
1264 let changes = changes.to_owned();
1265 let this = self.clone();
1266 self.write()
1267 .await?
1268 .with_transaction(move |txn| {
1269 let StateChanges {
1270 sync_token,
1271 account_data,
1272 presence,
1273 profiles,
1274 profiles_to_delete,
1275 state,
1276 room_account_data,
1277 room_infos,
1278 receipts,
1279 redactions,
1280 stripped_state,
1281 ambiguity_maps,
1282 } = changes;
1283
1284 if let Some(sync_token) = sync_token {
1285 let key = this.encode_state_store_data_key(StateStoreDataKey::SyncToken);
1286 let value = this.serialize_value(&sync_token)?;
1287 txn.set_kv_blob(&key, &value)?;
1288 }
1289
1290 for (event_type, event) in account_data {
1291 let event_type =
1292 this.encode_key(keys::GLOBAL_ACCOUNT_DATA, event_type.to_string());
1293 let data = this.serialize_json(&event)?;
1294 txn.set_global_account_data(&event_type, &data)?;
1295 }
1296
1297 for (room_id, events) in room_account_data {
1298 let room_id = this.encode_key(keys::ROOM_ACCOUNT_DATA, room_id);
1299 for (event_type, event) in events {
1300 let event_type =
1301 this.encode_key(keys::ROOM_ACCOUNT_DATA, event_type.to_string());
1302 let data = this.serialize_json(&event)?;
1303 txn.set_room_account_data(&room_id, &event_type, &data)?;
1304 }
1305 }
1306
1307 for (user_id, event) in presence {
1308 let key = this.encode_presence_key(&user_id);
1309 let value = this.serialize_json(&event)?;
1310 txn.set_kv_blob(&key, &value)?;
1311 }
1312
1313 for (room_id, room_info) in room_infos {
1314 let stripped = room_info.state() == RoomState::Invited;
1315 this.remove_maybe_stripped_room_data(txn, &room_id, !stripped)?;
1317
1318 let room_id = this.encode_key(keys::ROOM_INFO, room_id);
1319 let state = this
1320 .encode_key(keys::ROOM_INFO, serde_json::to_string(&room_info.state())?);
1321 let data = this.serialize_json(&room_info)?;
1322 txn.set_room_info(&room_id, &state, &data)?;
1323 }
1324
1325 for (room_id, user_ids) in profiles_to_delete {
1326 let room_id = this.encode_key(keys::PROFILE, room_id);
1327 for user_id in user_ids {
1328 let user_id = this.encode_key(keys::PROFILE, user_id);
1329 txn.remove_room_profile(&room_id, &user_id)?;
1330 }
1331 }
1332
1333 for (room_id, state_event_types) in state {
1334 let profiles = profiles.get(&room_id);
1335 let encoded_room_id = this.encode_key(keys::STATE_EVENT, &room_id);
1336
1337 for (event_type, state_events) in state_event_types {
1338 let encoded_event_type =
1339 this.encode_key(keys::STATE_EVENT, event_type.to_string());
1340
1341 for (state_key, raw_state_event) in state_events {
1342 let encoded_state_key = this.encode_key(keys::STATE_EVENT, &state_key);
1343 let data = this.serialize_json(&raw_state_event)?;
1344
1345 let event_id: Option<String> =
1346 raw_state_event.get_field("event_id").ok().flatten();
1347 let encoded_event_id =
1348 event_id.as_ref().map(|e| this.encode_key(keys::STATE_EVENT, e));
1349
1350 txn.set_state_event(
1351 &encoded_room_id,
1352 &encoded_event_type,
1353 &encoded_state_key,
1354 false,
1355 encoded_event_id.as_deref(),
1356 &data,
1357 )?;
1358
1359 if event_type == StateEventType::RoomMember {
1360 let member_event = match raw_state_event
1361 .deserialize_as_unchecked::<SyncRoomMemberEvent>()
1362 {
1363 Ok(ev) => ev,
1364 Err(e) => {
1365 debug!(event_id, "Failed to deserialize member event: {e}");
1366 continue;
1367 }
1368 };
1369
1370 let encoded_room_id = this.encode_key(keys::MEMBER, &room_id);
1371 let user_id = this.encode_key(keys::MEMBER, &state_key);
1372 let membership = this
1373 .encode_key(keys::MEMBER, member_event.membership().as_str());
1374 let data = this.serialize_value(&state_key)?;
1375
1376 txn.set_member(
1377 &encoded_room_id,
1378 &user_id,
1379 &membership,
1380 false,
1381 &data,
1382 )?;
1383
1384 if let Some(profile) =
1385 profiles.and_then(|p| p.get(member_event.state_key()))
1386 {
1387 let room_id = this.encode_key(keys::PROFILE, &room_id);
1388 let user_id = this.encode_key(keys::PROFILE, &state_key);
1389 let data = this.serialize_json(&profile)?;
1390 txn.set_profile(&room_id, &user_id, &data)?;
1391 }
1392 }
1393 }
1394 }
1395 }
1396
1397 for (room_id, stripped_state_event_types) in stripped_state {
1398 let encoded_room_id = this.encode_key(keys::STATE_EVENT, &room_id);
1399
1400 for (event_type, stripped_state_events) in stripped_state_event_types {
1401 let encoded_event_type =
1402 this.encode_key(keys::STATE_EVENT, event_type.to_string());
1403
1404 for (state_key, raw_stripped_state_event) in stripped_state_events {
1405 let encoded_state_key = this.encode_key(keys::STATE_EVENT, &state_key);
1406 let data = this.serialize_json(&raw_stripped_state_event)?;
1407 txn.set_state_event(
1408 &encoded_room_id,
1409 &encoded_event_type,
1410 &encoded_state_key,
1411 true,
1412 None,
1413 &data,
1414 )?;
1415
1416 if event_type == StateEventType::RoomMember {
1417 let member_event = match raw_stripped_state_event
1418 .deserialize_as_unchecked::<StrippedRoomMemberEvent>(
1419 ) {
1420 Ok(ev) => ev,
1421 Err(e) => {
1422 debug!("Failed to deserialize stripped member event: {e}");
1423 continue;
1424 }
1425 };
1426
1427 let room_id = this.encode_key(keys::MEMBER, &room_id);
1428 let user_id = this.encode_key(keys::MEMBER, &state_key);
1429 let membership = this.encode_key(
1430 keys::MEMBER,
1431 member_event.content.membership.as_str(),
1432 );
1433 let data = this.serialize_value(&state_key)?;
1434
1435 txn.set_member(&room_id, &user_id, &membership, true, &data)?;
1436 }
1437 }
1438 }
1439 }
1440
1441 for (room_id, receipt_event) in receipts {
1442 let room_id = this.encode_key(keys::RECEIPT, room_id);
1443
1444 for (event_id, receipt_types) in receipt_event {
1445 let encoded_event_id = this.encode_key(keys::RECEIPT, &event_id);
1446
1447 for (receipt_type, receipt_users) in receipt_types {
1448 let receipt_type =
1449 this.encode_key(keys::RECEIPT, receipt_type.as_str());
1450
1451 for (user_id, receipt) in receipt_users {
1452 let encoded_user_id = this.encode_key(keys::RECEIPT, &user_id);
1453 let thread = this.encode_key(
1456 keys::RECEIPT,
1457 rmp_serde::to_vec_named(&receipt.thread)?,
1458 );
1459 let data = this.serialize_json(&ReceiptData {
1460 receipt,
1461 event_id: event_id.clone(),
1462 user_id,
1463 })?;
1464
1465 txn.set_receipt(
1466 &room_id,
1467 &encoded_user_id,
1468 &receipt_type,
1469 &thread,
1470 &encoded_event_id,
1471 &data,
1472 )?;
1473 }
1474 }
1475 }
1476 }
1477
1478 for (room_id, redactions) in redactions {
1479 let make_redaction_rules = || {
1480 let encoded_room_id = this.encode_key(keys::ROOM_INFO, &room_id);
1481 txn.get_room_info(&encoded_room_id)
1482 .ok()
1483 .flatten()
1484 .and_then(|v| this.deserialize_json::<RoomInfo>(&v).ok())
1485 .map(|info| info.room_version_rules_or_default())
1486 .unwrap_or_else(|| {
1487 warn!(
1488 ?room_id,
1489 "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
1490 );
1491 ROOM_VERSION_RULES_FALLBACK
1492 }).redaction
1493 };
1494
1495 let encoded_room_id = this.encode_key(keys::STATE_EVENT, &room_id);
1496 let mut redaction_rules = None;
1497
1498 for (event_id, redaction) in redactions {
1499 let event_id = this.encode_key(keys::STATE_EVENT, event_id);
1500
1501 if let Some(Ok(raw_event)) = txn
1502 .get_state_event_by_id(&encoded_room_id, &event_id)?
1503 .map(|value| this.deserialize_json::<Raw<AnySyncStateEvent>>(&value))
1504 {
1505 let event = raw_event.deserialize()?;
1506 let redacted = redact(
1507 raw_event.deserialize_as::<CanonicalJsonObject>()?,
1508 redaction_rules.get_or_insert_with(make_redaction_rules),
1509 Some(RedactedBecause::from_raw_event(&redaction)?),
1510 )
1511 .map_err(Error::Redaction)?;
1512 let data = this.serialize_json(&redacted)?;
1513
1514 let event_type =
1515 this.encode_key(keys::STATE_EVENT, event.event_type().to_string());
1516 let state_key = this.encode_key(keys::STATE_EVENT, event.state_key());
1517
1518 txn.set_state_event(
1519 &encoded_room_id,
1520 &event_type,
1521 &state_key,
1522 false,
1523 Some(&event_id),
1524 &data,
1525 )?;
1526 }
1527 }
1528 }
1529
1530 for (room_id, display_names) in ambiguity_maps {
1531 let room_id = this.encode_key(keys::DISPLAY_NAME, room_id);
1532
1533 for (name, user_ids) in display_names {
1534 let encoded_name = this.encode_key(
1535 keys::DISPLAY_NAME,
1536 name.as_normalized_str().unwrap_or_else(|| name.as_raw_str()),
1537 );
1538 let data = this.serialize_json(&user_ids)?;
1539
1540 if user_ids.is_empty() {
1541 txn.remove_display_name(&room_id, &encoded_name)?;
1542
1543 let raw_name = this.encode_key(keys::DISPLAY_NAME, name.as_raw_str());
1558 txn.remove_display_name(&room_id, &raw_name)?;
1559 } else {
1560 txn.set_display_name(&room_id, &encoded_name, &data)?;
1562 }
1563 }
1564 }
1565
1566 Ok::<_, Error>(())
1567 })
1568 .await?;
1569
1570 Ok(())
1571 }
1572
1573 async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
1574 self.read()
1575 .await?
1576 .get_kv_blob(self.encode_presence_key(user_id))
1577 .await?
1578 .map(|data| self.deserialize_json(&data))
1579 .transpose()
1580 }
1581
1582 async fn get_presence_events(
1583 &self,
1584 user_ids: &[OwnedUserId],
1585 ) -> Result<Vec<Raw<PresenceEvent>>> {
1586 if user_ids.is_empty() {
1587 return Ok(Vec::new());
1588 }
1589
1590 let user_ids = user_ids.iter().map(|u| self.encode_presence_key(u)).collect();
1591 self.read()
1592 .await?
1593 .get_kv_blobs(user_ids)
1594 .await?
1595 .into_iter()
1596 .map(|data| self.deserialize_json(&data))
1597 .collect()
1598 }
1599
1600 async fn get_state_event(
1601 &self,
1602 room_id: &RoomId,
1603 event_type: StateEventType,
1604 state_key: &str,
1605 ) -> Result<Option<RawAnySyncOrStrippedState>> {
1606 Ok(self
1607 .get_state_events_for_keys(room_id, event_type, &[state_key])
1608 .await?
1609 .into_iter()
1610 .next())
1611 }
1612
1613 async fn get_state_events(
1614 &self,
1615 room_id: &RoomId,
1616 event_type: StateEventType,
1617 ) -> Result<Vec<RawAnySyncOrStrippedState>> {
1618 let room_id = self.encode_key(keys::STATE_EVENT, room_id);
1619 let event_type = self.encode_key(keys::STATE_EVENT, event_type.to_string());
1620 self.read()
1621 .await?
1622 .get_maybe_stripped_state_events(room_id, event_type)
1623 .await?
1624 .into_iter()
1625 .map(|(stripped, data)| {
1626 let ev = if stripped {
1627 RawAnySyncOrStrippedState::Stripped(self.deserialize_json(&data)?)
1628 } else {
1629 RawAnySyncOrStrippedState::Sync(self.deserialize_json(&data)?)
1630 };
1631
1632 Ok(ev)
1633 })
1634 .collect()
1635 }
1636
1637 async fn get_state_events_for_keys(
1638 &self,
1639 room_id: &RoomId,
1640 event_type: StateEventType,
1641 state_keys: &[&str],
1642 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
1643 if state_keys.is_empty() {
1644 return Ok(Vec::new());
1645 }
1646
1647 let room_id = self.encode_key(keys::STATE_EVENT, room_id);
1648 let event_type = self.encode_key(keys::STATE_EVENT, event_type.to_string());
1649 let state_keys = state_keys.iter().map(|k| self.encode_key(keys::STATE_EVENT, k)).collect();
1650 self.read()
1651 .await?
1652 .get_maybe_stripped_state_events_for_keys(room_id, event_type, state_keys)
1653 .await?
1654 .into_iter()
1655 .map(|(stripped, data)| {
1656 let ev = if stripped {
1657 RawAnySyncOrStrippedState::Stripped(self.deserialize_json(&data)?)
1658 } else {
1659 RawAnySyncOrStrippedState::Sync(self.deserialize_json(&data)?)
1660 };
1661
1662 Ok(ev)
1663 })
1664 .collect()
1665 }
1666
1667 async fn get_profile(
1668 &self,
1669 room_id: &RoomId,
1670 user_id: &UserId,
1671 ) -> Result<Option<MinimalRoomMemberEvent>> {
1672 let room_id = self.encode_key(keys::PROFILE, room_id);
1673 let user_ids = vec![self.encode_key(keys::PROFILE, user_id)];
1674
1675 self.read()
1676 .await?
1677 .get_profiles(room_id, user_ids)
1678 .await?
1679 .into_iter()
1680 .next()
1681 .map(|(_, data)| self.deserialize_json(&data))
1682 .transpose()
1683 }
1684
1685 async fn get_profiles<'a>(
1686 &self,
1687 room_id: &RoomId,
1688 user_ids: &'a [OwnedUserId],
1689 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
1690 if user_ids.is_empty() {
1691 return Ok(BTreeMap::new());
1692 }
1693
1694 let room_id = self.encode_key(keys::PROFILE, room_id);
1695 let mut user_ids_map = user_ids
1696 .iter()
1697 .map(|u| (self.encode_key(keys::PROFILE, u), u.as_ref()))
1698 .collect::<BTreeMap<_, _>>();
1699 let user_ids = user_ids_map.keys().cloned().collect();
1700
1701 self.read()
1702 .await?
1703 .get_profiles(room_id, user_ids)
1704 .await?
1705 .into_iter()
1706 .map(|(user_id, data)| {
1707 Ok((
1708 user_ids_map
1709 .remove(user_id.as_slice())
1710 .expect("returned user IDs were requested"),
1711 self.deserialize_json(&data)?,
1712 ))
1713 })
1714 .collect()
1715 }
1716
1717 async fn get_user_ids(
1718 &self,
1719 room_id: &RoomId,
1720 membership: RoomMemberships,
1721 ) -> Result<Vec<OwnedUserId>> {
1722 let room_id = self.encode_key(keys::MEMBER, room_id);
1723 let memberships = membership
1724 .as_vec()
1725 .into_iter()
1726 .map(|m| self.encode_key(keys::MEMBER, m.as_str()))
1727 .collect();
1728 self.read()
1729 .await?
1730 .get_user_ids(room_id, memberships)
1731 .await?
1732 .iter()
1733 .map(|data| self.deserialize_value(data))
1734 .collect()
1735 }
1736
1737 async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
1738 self.read()
1739 .await?
1740 .get_room_infos(match room_load_settings {
1741 RoomLoadSettings::All => None,
1742 RoomLoadSettings::One(room_id) => Some(self.encode_key(keys::ROOM_INFO, room_id)),
1743 })
1744 .await?
1745 .into_iter()
1746 .map(|data| self.deserialize_json(&data))
1747 .collect()
1748 }
1749
1750 async fn get_users_with_display_name(
1751 &self,
1752 room_id: &RoomId,
1753 display_name: &DisplayName,
1754 ) -> Result<BTreeSet<OwnedUserId>> {
1755 let room_id = self.encode_key(keys::DISPLAY_NAME, room_id);
1756 let names = vec![self.encode_key(
1757 keys::DISPLAY_NAME,
1758 display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()),
1759 )];
1760
1761 Ok(self
1762 .read()
1763 .await?
1764 .get_display_names(room_id, names)
1765 .await?
1766 .into_iter()
1767 .next()
1768 .map(|(_, data)| self.deserialize_json(&data))
1769 .transpose()?
1770 .unwrap_or_default())
1771 }
1772
1773 async fn get_users_with_display_names<'a>(
1774 &self,
1775 room_id: &RoomId,
1776 display_names: &'a [DisplayName],
1777 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
1778 let mut result = HashMap::new();
1779
1780 if display_names.is_empty() {
1781 return Ok(result);
1782 }
1783
1784 let room_id = self.encode_key(keys::DISPLAY_NAME, room_id);
1785 let mut names_map = display_names
1786 .iter()
1787 .flat_map(|display_name| {
1788 let raw =
1798 (self.encode_key(keys::DISPLAY_NAME, display_name.as_raw_str()), display_name);
1799 let normalized = display_name.as_normalized_str().map(|normalized| {
1800 (self.encode_key(keys::DISPLAY_NAME, normalized), display_name)
1801 });
1802
1803 iter::once(raw).chain(normalized)
1804 })
1805 .collect::<BTreeMap<_, _>>();
1806 let names = names_map.keys().cloned().collect();
1807
1808 for (name, data) in self.read().await?.get_display_names(room_id, names).await?.into_iter()
1809 {
1810 let display_name =
1811 names_map.remove(name.as_slice()).expect("returned display names were requested");
1812 let user_ids: BTreeSet<_> = self.deserialize_json(&data)?;
1813
1814 result.entry(display_name).or_insert_with(BTreeSet::new).extend(user_ids);
1815 }
1816
1817 Ok(result)
1818 }
1819
1820 async fn get_account_data_event(
1821 &self,
1822 event_type: GlobalAccountDataEventType,
1823 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
1824 let event_type = self.encode_key(keys::GLOBAL_ACCOUNT_DATA, event_type.to_string());
1825 self.read()
1826 .await?
1827 .get_global_account_data(event_type)
1828 .await?
1829 .map(|value| self.deserialize_json(&value))
1830 .transpose()
1831 }
1832
1833 async fn get_room_account_data_event(
1834 &self,
1835 room_id: &RoomId,
1836 event_type: RoomAccountDataEventType,
1837 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
1838 let room_id = self.encode_key(keys::ROOM_ACCOUNT_DATA, room_id);
1839 let event_type = self.encode_key(keys::ROOM_ACCOUNT_DATA, event_type.to_string());
1840 self.read()
1841 .await?
1842 .get_room_account_data(room_id, event_type)
1843 .await?
1844 .map(|value| self.deserialize_json(&value))
1845 .transpose()
1846 }
1847
1848 async fn get_user_room_receipt_event(
1849 &self,
1850 room_id: &RoomId,
1851 receipt_type: ReceiptType,
1852 thread: ReceiptThread,
1853 user_id: &UserId,
1854 ) -> Result<Option<(OwnedEventId, Receipt)>> {
1855 let room_id = self.encode_key(keys::RECEIPT, room_id);
1856 let receipt_type = self.encode_key(keys::RECEIPT, receipt_type.to_string());
1857 let thread = self.encode_key(keys::RECEIPT, rmp_serde::to_vec_named(&thread)?);
1860 let user_id = self.encode_key(keys::RECEIPT, user_id);
1861
1862 self.read()
1863 .await?
1864 .get_user_receipt(room_id, receipt_type, thread, user_id)
1865 .await?
1866 .map(|value| {
1867 self.deserialize_json::<ReceiptData>(&value).map(|d| (d.event_id, d.receipt))
1868 })
1869 .transpose()
1870 }
1871
1872 async fn get_event_room_receipt_events(
1873 &self,
1874 room_id: &RoomId,
1875 receipt_type: ReceiptType,
1876 thread: ReceiptThread,
1877 event_id: &EventId,
1878 ) -> Result<Vec<(OwnedUserId, Receipt)>> {
1879 let room_id = self.encode_key(keys::RECEIPT, room_id);
1880 let receipt_type = self.encode_key(keys::RECEIPT, receipt_type.to_string());
1881 let thread = self.encode_key(keys::RECEIPT, rmp_serde::to_vec_named(&thread)?);
1884 let event_id = self.encode_key(keys::RECEIPT, event_id);
1885
1886 self.read()
1887 .await?
1888 .get_event_receipts(room_id, receipt_type, thread, event_id)
1889 .await?
1890 .iter()
1891 .map(|value| {
1892 self.deserialize_json::<ReceiptData>(value).map(|d| (d.user_id, d.receipt))
1893 })
1894 .collect()
1895 }
1896
1897 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
1898 self.read().await?.get_kv_blob(self.encode_custom_key(key)).await
1899 }
1900
1901 async fn set_custom_value_no_read(&self, key: &[u8], value: Vec<u8>) -> Result<()> {
1902 let conn = self.write().await?;
1903 let key = self.encode_custom_key(key);
1904 conn.set_kv_blob(key, value).await?;
1905 Ok(())
1906 }
1907
1908 async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
1909 let conn = self.write().await?;
1910 let key = self.encode_custom_key(key);
1911 let previous = conn.get_kv_blob(key.clone()).await?;
1912 conn.set_kv_blob(key, value).await?;
1913 Ok(previous)
1914 }
1915
1916 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
1917 let conn = self.write().await?;
1918 let key = self.encode_custom_key(key);
1919 let previous = conn.get_kv_blob(key.clone()).await?;
1920 if previous.is_some() {
1921 conn.delete_kv_blob(key).await?;
1922 }
1923 Ok(previous)
1924 }
1925
1926 async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
1927 let this = self.clone();
1928 let room_id = room_id.to_owned();
1929
1930 let conn = self.write().await?;
1931
1932 conn.with_transaction(move |txn| -> Result<()> {
1933 let room_info_room_id = this.encode_key(keys::ROOM_INFO, &room_id);
1934 txn.remove_room_info(&room_info_room_id)?;
1935
1936 let state_event_room_id = this.encode_key(keys::STATE_EVENT, &room_id);
1937 txn.remove_room_state_events(&state_event_room_id, None)?;
1938
1939 let member_room_id = this.encode_key(keys::MEMBER, &room_id);
1940 txn.remove_room_members(&member_room_id, None)?;
1941
1942 let profile_room_id = this.encode_key(keys::PROFILE, &room_id);
1943 txn.remove_room_profiles(&profile_room_id)?;
1944
1945 let room_account_data_room_id = this.encode_key(keys::ROOM_ACCOUNT_DATA, &room_id);
1946 txn.remove_room_account_data(&room_account_data_room_id)?;
1947
1948 let receipt_room_id = this.encode_key(keys::RECEIPT, &room_id);
1949 txn.remove_room_receipts(&receipt_room_id)?;
1950
1951 let display_name_room_id = this.encode_key(keys::DISPLAY_NAME, &room_id);
1952 txn.remove_room_display_names(&display_name_room_id)?;
1953
1954 let send_queue_room_id = this.encode_key(keys::SEND_QUEUE, &room_id);
1955 txn.remove_room_send_queue(&send_queue_room_id)?;
1956
1957 let dependent_send_queue_room_id =
1958 this.encode_key(keys::DEPENDENTS_SEND_QUEUE, &room_id);
1959 txn.remove_room_dependent_send_queue(&dependent_send_queue_room_id)?;
1960
1961 let thread_subscriptions_room_id =
1962 this.encode_key(keys::THREAD_SUBSCRIPTIONS, &room_id);
1963 txn.execute(
1964 "DELETE FROM thread_subscriptions WHERE room_id = ?",
1965 (thread_subscriptions_room_id,),
1966 )?;
1967
1968 Ok(())
1969 })
1970 .await?;
1971
1972 conn.vacuum().await
1973 }
1974
1975 async fn save_send_queue_request(
1976 &self,
1977 room_id: &RoomId,
1978 transaction_id: OwnedTransactionId,
1979 created_at: MilliSecondsSinceUnixEpoch,
1980 content: QueuedRequestKind,
1981 priority: usize,
1982 ) -> Result<(), Self::Error> {
1983 let room_id_key = self.encode_key(keys::SEND_QUEUE, room_id);
1984 let room_id_value = self.serialize_value(&room_id.to_owned())?;
1985
1986 let content = self.serialize_json(&content)?;
1987 let created_at_ts: u64 = created_at.0.into();
1993 self.write()
1994 .await?
1995 .with_transaction(move |txn| {
1996 txn.prepare_cached("INSERT INTO send_queue_events (room_id, room_id_val, transaction_id, content, priority, created_at) VALUES (?, ?, ?, ?, ?, ?)")?.execute((room_id_key, room_id_value, transaction_id.to_string(), content, priority, created_at_ts))?;
1997 Ok(())
1998 })
1999 .await
2000 }
2001
2002 async fn update_send_queue_request(
2003 &self,
2004 room_id: &RoomId,
2005 transaction_id: &TransactionId,
2006 content: QueuedRequestKind,
2007 ) -> Result<bool, Self::Error> {
2008 let room_id = self.encode_key(keys::SEND_QUEUE, room_id);
2009
2010 let content = self.serialize_json(&content)?;
2011 let transaction_id = transaction_id.to_string();
2014
2015 let num_updated = self.write()
2016 .await?
2017 .with_transaction(move |txn| {
2018 txn.prepare_cached("UPDATE send_queue_events SET wedge_reason = NULL, content = ? WHERE room_id = ? AND transaction_id = ?")?.execute((content, room_id, transaction_id))
2019 })
2020 .await?;
2021
2022 Ok(num_updated > 0)
2023 }
2024
2025 async fn remove_send_queue_request(
2026 &self,
2027 room_id: &RoomId,
2028 transaction_id: &TransactionId,
2029 ) -> Result<bool, Self::Error> {
2030 let room_id = self.encode_key(keys::SEND_QUEUE, room_id);
2031
2032 let transaction_id = transaction_id.to_string();
2034
2035 let num_deleted = self
2036 .write()
2037 .await?
2038 .with_transaction(move |txn| {
2039 txn.prepare_cached(
2040 "DELETE FROM send_queue_events WHERE room_id = ? AND transaction_id = ?",
2041 )?
2042 .execute((room_id, &transaction_id))
2043 })
2044 .await?;
2045
2046 Ok(num_deleted > 0)
2047 }
2048
2049 async fn load_send_queue_requests(
2050 &self,
2051 room_id: &RoomId,
2052 ) -> Result<Vec<QueuedRequest>, Self::Error> {
2053 let room_id = self.encode_key(keys::SEND_QUEUE, room_id);
2054
2055 let res: Vec<(String, Vec<u8>, Option<Vec<u8>>, usize, Option<u64>)> = self
2059 .read()
2060 .await?
2061 .prepare(
2062 "SELECT transaction_id, content, wedge_reason, priority, created_at FROM send_queue_events WHERE room_id = ? ORDER BY priority DESC, ROWID",
2063 |mut stmt| {
2064 stmt.query((room_id,))?
2065 .mapped(|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?)))
2066 .collect()
2067 },
2068 )
2069 .await?;
2070
2071 let mut requests = Vec::with_capacity(res.len());
2072
2073 for entry in res {
2074 let created_at = entry
2075 .4
2076 .and_then(UInt::new)
2077 .map_or_else(MilliSecondsSinceUnixEpoch::now, MilliSecondsSinceUnixEpoch);
2078
2079 requests.push(QueuedRequest {
2080 transaction_id: entry.0.into(),
2081 kind: self.deserialize_json(&entry.1)?,
2082 error: entry.2.map(|v| self.deserialize_value(&v)).transpose()?,
2083 priority: entry.3,
2084 created_at,
2085 });
2086 }
2087
2088 Ok(requests)
2089 }
2090
2091 async fn update_send_queue_request_status(
2092 &self,
2093 room_id: &RoomId,
2094 transaction_id: &TransactionId,
2095 error: Option<QueueWedgeError>,
2096 ) -> Result<(), Self::Error> {
2097 let room_id = self.encode_key(keys::SEND_QUEUE, room_id);
2098
2099 let transaction_id = transaction_id.to_string();
2101
2102 let error_value = error.map(|e| self.serialize_value(&e)).transpose()?;
2104
2105 self.write()
2106 .await?
2107 .with_transaction(move |txn| {
2108 txn.prepare_cached("UPDATE send_queue_events SET wedge_reason = ? WHERE room_id = ? AND transaction_id = ?")?.execute((error_value, room_id, transaction_id))?;
2109 Ok(())
2110 })
2111 .await
2112 }
2113
2114 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
2115 let res: Vec<Vec<u8>> = self
2120 .read()
2121 .await?
2122 .prepare("SELECT room_id_val FROM send_queue_events", |mut stmt| {
2123 stmt.query(())?.mapped(|row| row.get(0)).collect()
2124 })
2125 .await?;
2126
2127 Ok(res
2130 .into_iter()
2131 .map(|entry| self.deserialize_value(&entry))
2132 .collect::<Result<BTreeSet<OwnedRoomId>, _>>()?
2133 .into_iter()
2134 .collect())
2135 }
2136
2137 async fn save_dependent_queued_request(
2138 &self,
2139 room_id: &RoomId,
2140 parent_txn_id: &TransactionId,
2141 own_txn_id: ChildTransactionId,
2142 created_at: MilliSecondsSinceUnixEpoch,
2143 content: DependentQueuedRequestKind,
2144 ) -> Result<()> {
2145 let room_id = self.encode_key(keys::DEPENDENTS_SEND_QUEUE, room_id);
2146 let content = self.serialize_json(&content)?;
2147
2148 let parent_txn_id = parent_txn_id.to_string();
2150 let own_txn_id = own_txn_id.to_string();
2151
2152 let created_at_ts: u64 = created_at.0.into();
2153 self.write()
2154 .await?
2155 .with_transaction(move |txn| {
2156 txn.prepare_cached(
2157 r#"INSERT INTO dependent_send_queue_events
2158 (room_id, parent_transaction_id, own_transaction_id, content, created_at)
2159 VALUES (?, ?, ?, ?, ?)"#,
2160 )?
2161 .execute((
2162 room_id,
2163 parent_txn_id,
2164 own_txn_id,
2165 content,
2166 created_at_ts,
2167 ))?;
2168 Ok(())
2169 })
2170 .await
2171 }
2172
2173 async fn update_dependent_queued_request(
2174 &self,
2175 room_id: &RoomId,
2176 own_transaction_id: &ChildTransactionId,
2177 new_content: DependentQueuedRequestKind,
2178 ) -> Result<bool> {
2179 let room_id = self.encode_key(keys::DEPENDENTS_SEND_QUEUE, room_id);
2180 let content = self.serialize_json(&new_content)?;
2181
2182 let own_txn_id = own_transaction_id.to_string();
2184
2185 let num_updated = self
2186 .write()
2187 .await?
2188 .with_transaction(move |txn| {
2189 txn.prepare_cached(
2190 r#"UPDATE dependent_send_queue_events
2191 SET content = ?
2192 WHERE own_transaction_id = ?
2193 AND room_id = ?"#,
2194 )?
2195 .execute((content, own_txn_id, room_id))
2196 })
2197 .await?;
2198
2199 if num_updated > 1 {
2200 return Err(Error::InconsistentUpdate);
2201 }
2202
2203 Ok(num_updated == 1)
2204 }
2205
2206 async fn mark_dependent_queued_requests_as_ready(
2207 &self,
2208 room_id: &RoomId,
2209 parent_txn_id: &TransactionId,
2210 parent_key: SentRequestKey,
2211 ) -> Result<usize> {
2212 let room_id = self.encode_key(keys::DEPENDENTS_SEND_QUEUE, room_id);
2213 let parent_key = self.serialize_json(&parent_key)?;
2214
2215 let parent_txn_id = parent_txn_id.to_string();
2217
2218 self.write()
2219 .await?
2220 .with_transaction(move |txn| {
2221 Ok(txn.prepare_cached(
2222 "UPDATE dependent_send_queue_events SET parent_key = ? WHERE parent_transaction_id = ? and room_id = ?",
2223 )?
2224 .execute((parent_key, parent_txn_id, room_id))?)
2225 })
2226 .await
2227 }
2228
2229 async fn remove_dependent_queued_request(
2230 &self,
2231 room_id: &RoomId,
2232 txn_id: &ChildTransactionId,
2233 ) -> Result<bool> {
2234 let room_id = self.encode_key(keys::DEPENDENTS_SEND_QUEUE, room_id);
2235
2236 let txn_id = txn_id.to_string();
2238
2239 let num_deleted = self
2240 .write()
2241 .await?
2242 .with_transaction(move |txn| {
2243 txn.prepare_cached(
2244 "DELETE FROM dependent_send_queue_events WHERE own_transaction_id = ? AND room_id = ?",
2245 )?
2246 .execute((txn_id, room_id))
2247 })
2248 .await?;
2249
2250 Ok(num_deleted > 0)
2251 }
2252
2253 async fn load_dependent_queued_requests(
2254 &self,
2255 room_id: &RoomId,
2256 ) -> Result<Vec<DependentQueuedRequest>> {
2257 let room_id = self.encode_key(keys::DEPENDENTS_SEND_QUEUE, room_id);
2258
2259 let res: Vec<(String, String, Option<Vec<u8>>, Vec<u8>, Option<u64>)> = self
2261 .read()
2262 .await?
2263 .prepare(
2264 "SELECT own_transaction_id, parent_transaction_id, parent_key, content, created_at FROM dependent_send_queue_events WHERE room_id = ? ORDER BY ROWID",
2265 |mut stmt| {
2266 stmt.query((room_id,))?
2267 .mapped(|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?)))
2268 .collect()
2269 },
2270 )
2271 .await?;
2272
2273 let mut dependent_events = Vec::with_capacity(res.len());
2274
2275 for entry in res {
2276 let created_at = entry
2277 .4
2278 .and_then(UInt::new)
2279 .map_or_else(MilliSecondsSinceUnixEpoch::now, MilliSecondsSinceUnixEpoch);
2280
2281 dependent_events.push(DependentQueuedRequest {
2282 own_transaction_id: entry.0.into(),
2283 parent_transaction_id: entry.1.into(),
2284 parent_key: entry.2.map(|json| self.deserialize_json(&json)).transpose()?,
2285 kind: self.deserialize_json(&entry.3)?,
2286 created_at,
2287 });
2288 }
2289
2290 Ok(dependent_events)
2291 }
2292
2293 async fn upsert_thread_subscriptions(
2294 &self,
2295 updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
2296 ) -> Result<(), Self::Error> {
2297 let values: Vec<_> = updates
2298 .into_iter()
2299 .map(|(room_id, thread_id, subscription)| {
2300 (
2301 self.encode_key(keys::THREAD_SUBSCRIPTIONS, room_id),
2302 self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id),
2303 subscription.status.as_str(),
2304 subscription.bump_stamp,
2305 )
2306 })
2307 .collect();
2308
2309 self.write()
2310 .await?
2311 .with_transaction(move |txn| {
2312 let mut txn = txn.prepare_cached(
2313 "INSERT INTO thread_subscriptions (room_id, event_id, status, bump_stamp)
2314 VALUES (?, ?, ?, ?)
2315 ON CONFLICT (room_id, event_id) DO UPDATE
2316 SET
2317 status =
2318 CASE
2319 WHEN thread_subscriptions.bump_stamp IS NULL THEN EXCLUDED.status
2320 WHEN EXCLUDED.bump_stamp IS NULL THEN EXCLUDED.status
2321 WHEN thread_subscriptions.bump_stamp < EXCLUDED.bump_stamp THEN EXCLUDED.status
2322 ELSE thread_subscriptions.status
2323 END,
2324 bump_stamp =
2325 CASE
2326 WHEN thread_subscriptions.bump_stamp IS NULL THEN EXCLUDED.bump_stamp
2327 WHEN EXCLUDED.bump_stamp IS NULL THEN thread_subscriptions.bump_stamp
2328 WHEN thread_subscriptions.bump_stamp < EXCLUDED.bump_stamp THEN EXCLUDED.bump_stamp
2329 ELSE thread_subscriptions.bump_stamp
2330 END",
2331 )?;
2332
2333 for value in values {
2334 txn.execute(value)?;
2335 }
2336
2337 Result::<_, Error>::Ok(())
2338 })
2339 .await?;
2340
2341 Ok(())
2342 }
2343
2344 async fn load_thread_subscription(
2345 &self,
2346 room_id: &RoomId,
2347 thread_id: &EventId,
2348 ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
2349 let room_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, room_id);
2350 let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id);
2351
2352 Ok(self
2353 .read()
2354 .await?
2355 .query_row(
2356 "SELECT status, bump_stamp FROM thread_subscriptions WHERE room_id = ? AND event_id = ?",
2357 (room_id, thread_id),
2358 |row| Ok((row.get::<_, String>(0)?, row.get::<_, Option<u64>>(1)?))
2359 )
2360 .await
2361 .optional()?
2362 .map(|(status, bump_stamp)| -> Result<_, Self::Error> {
2363 let status = ThreadSubscriptionStatus::from_str(&status).map_err(|_| {
2364 Error::InvalidData { details: format!("Invalid thread status: {status}") }
2365 })?;
2366 Ok(StoredThreadSubscription { status, bump_stamp })
2367 })
2368 .transpose()?)
2369 }
2370
2371 async fn remove_thread_subscription(
2372 &self,
2373 room_id: &RoomId,
2374 thread_id: &EventId,
2375 ) -> Result<(), Self::Error> {
2376 let room_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, room_id);
2377 let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id);
2378
2379 self.write()
2380 .await?
2381 .execute(
2382 "DELETE FROM thread_subscriptions WHERE room_id = ? AND event_id = ?",
2383 (room_id, thread_id),
2384 )
2385 .await?;
2386
2387 Ok(())
2388 }
2389
2390 async fn optimize(&self) -> Result<(), Self::Error> {
2391 Ok(self.vacuum().await?)
2392 }
2393
2394 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
2395 self.get_db_size().await
2396 }
2397
2398 async fn close(&self) -> Result<(), Self::Error> {
2399 connection::close_connections(&self.connections, "State store").await;
2400 Ok(())
2401 }
2402
2403 async fn reopen(&self) -> Result<(), Self::Error> {
2404 connection::reopen_connections(
2405 &self.connections,
2406 self.db_path.clone(),
2407 self.pool_config,
2408 self.runtime_config,
2409 )
2410 .await?;
2411 Ok(())
2412 }
2413}
2414
/// A receipt bundled with the event it refers to and the user it belongs to.
///
/// NOTE(review): this is presumably the value persisted for receipts
/// elsewhere in this file; confirm against its users. Keep the field order
/// stable, because a positional serializer would make reordering a
/// storage-format change.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ReceiptData {
    /// The receipt payload itself.
    receipt: Receipt,
    /// The event the receipt points at.
    event_id: OwnedEventId,
    /// The user the receipt belongs to.
    user_id: OwnedUserId,
}
2421
#[cfg(test)]
mod tests {
    use std::sync::{
        LazyLock,
        atomic::{AtomicU32, Ordering::SeqCst},
    };

    use matrix_sdk_base::{StateStore, StoreError, statestore_integration_tests};
    use tempfile::{TempDir, tempdir};

    use super::SqliteStateStore;

    /// Shared temporary directory; every store gets its own numbered
    /// sub-directory inside it.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    /// Counter used to name per-store sub-directories.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens a fresh, unencrypted store for the shared state store
    /// integration tests.
    async fn get_store() -> Result<impl StateStore, StoreError> {
        let slot = NUM.fetch_add(1, SeqCst);
        let store_dir = TMP_DIR.path().join(slot.to_string());
        let store_path = store_dir.to_str().unwrap();

        tracing::info!("using store @ {}", store_path);

        let store = SqliteStateStore::open(store_path, None).await.unwrap();
        Ok(store)
    }

    statestore_integration_tests!();
}
2448
#[cfg(test)]
mod encrypted_tests {
    use std::{
        path::PathBuf,
        sync::{
            LazyLock,
            atomic::{AtomicU32, Ordering::SeqCst},
        },
    };

    use matrix_sdk_base::{StateStore, StoreError, statestore_integration_tests};
    use matrix_sdk_test::async_test;
    use tempfile::{TempDir, tempdir};

    use super::SqliteStateStore;
    use crate::{SqliteStoreConfig, utils::SqliteAsyncConnExt};

    /// Shared temporary directory holding one sub-directory per store.
    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    /// Counter used to name per-store sub-directories.
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Returns a fresh, not-yet-used directory path under [`TMP_DIR`].
    fn new_state_store_workspace() -> PathBuf {
        let mut workspace = TMP_DIR.path().to_path_buf();
        workspace.push(NUM.fetch_add(1, SeqCst).to_string());
        workspace
    }

    /// Opens a fresh store encrypted with a fixed test passphrase, for the
    /// shared state store integration tests.
    async fn get_store() -> Result<impl StateStore, StoreError> {
        let workspace = new_state_store_workspace();
        let workspace_str = workspace.to_str().unwrap();

        tracing::info!("using store @ {}", workspace_str);

        let store =
            SqliteStateStore::open(workspace_str, Some("default_test_password")).await.unwrap();
        Ok(store)
    }

    #[async_test]
    async fn test_pool_size() {
        let config = SqliteStoreConfig::new(new_state_store_workspace()).pool_max_size(42);
        let store = SqliteStateStore::open_with_config(&config).await.unwrap();

        let connections = store.connections.lock().await;
        let max_size = connections.as_ref().unwrap().pool.status().max_size;
        assert_eq!(max_size, 42);
    }

    #[async_test]
    async fn test_cache_size() {
        let config = SqliteStoreConfig::new(new_state_store_workspace()).cache_size(1500);
        let store = SqliteStateStore::open_with_config(&config).await.unwrap();

        let conn = store.read().await.unwrap();
        let cache_size =
            conn.query_row("PRAGMA cache_size", (), |row| row.get::<_, i32>(0)).await.unwrap();

        // The configured value is in bytes. The test expects it to be stored as
        // a negative number of KiB; a negative `cache_size` means KiB rather
        // than pages.
        assert_eq!(cache_size, -(1500 / 1024));
    }

    #[async_test]
    async fn test_journal_size_limit() {
        let config = SqliteStoreConfig::new(new_state_store_workspace()).journal_size_limit(1500);
        let store = SqliteStateStore::open_with_config(&config).await.unwrap();

        let conn = store.read().await.unwrap();
        let limit = conn
            .query_row("PRAGMA journal_size_limit", (), |row| row.get::<_, u32>(0))
            .await
            .unwrap();

        assert_eq!(limit, 1500);
    }

    statestore_integration_tests!();
}
2532
#[cfg(test)]
mod migration_tests {
    use std::{
        path::{Path, PathBuf},
        sync::{
            Arc, LazyLock,
            atomic::{AtomicU32, Ordering::SeqCst},
        },
    };

    use as_variant::as_variant;
    use matrix_sdk_base::{
        RoomState, StateStore,
        media::{MediaFormat, MediaRequestParameters},
        store::{
            ChildTransactionId, DependentQueuedRequestKind, RoomLoadSettings,
            SerializableEventContent,
        },
        sync::UnreadNotificationsCount,
    };
    use matrix_sdk_test::async_test;
    use ruma::{
        EventId, MilliSecondsSinceUnixEpoch, OwnedTransactionId, RoomId, TransactionId, UserId,
        events::{
            StateEventType,
            room::{MediaSource, create::RoomCreateEventContent, message::RoomMessageEventContent},
        },
        room_id, server_name, user_id,
    };
    use rusqlite::Transaction;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tempfile::{TempDir, tempdir};
    use tokio::{fs, sync::Mutex};
    use zeroize::Zeroizing;

    use super::{DATABASE_NAME, SqliteStateStore, init, keys};
    use crate::{
        OpenStoreError, Secret, SqliteStoreConfig, connection,
        error::{Error, Result},
        utils::{EncryptableStore as _, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt},
    };

    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    static NUM: AtomicU32 = AtomicU32::new(0);
    /// Passphrase used both by the fake databases and when reopening them.
    const SECRET: &str = "secret";

    /// Returns a fresh, not-yet-used database directory under [`TMP_DIR`].
    fn new_path() -> PathBuf {
        let name = NUM.fetch_add(1, SeqCst).to_string();
        TMP_DIR.path().join(name)
    }

    /// Creates an encrypted store at `path` whose schema has only been migrated
    /// up to `version`.
    ///
    /// Tests seed legacy rows into it. The remaining migrations then run when
    /// the path is reopened with [`SqliteStateStore::open`].
    async fn create_fake_db(path: &Path, version: u8) -> Result<SqliteStateStore> {
        let config = SqliteStoreConfig::new(path);

        fs::create_dir_all(&config.path).await.map_err(OpenStoreError::CreateDir).unwrap();

        let pool = config.build_pool_of_connections(DATABASE_NAME).unwrap();
        let db_path = pool.manager().database_path.clone();
        let conn = pool.get().await?;

        init(&conn).await?;

        let store_cipher = Some(Arc::new(
            conn.get_or_create_store_cipher(Secret::PassPhrase(Zeroizing::new(SECRET.to_owned())))
                .await
                .unwrap(),
        ));
        let this = SqliteStateStore {
            store_cipher,
            connections: Arc::new(Mutex::new(Some(connection::SqliteConnections {
                pool,
                write_connection: Arc::new(Mutex::new(conn)),
            }))),
            db_path,
            pool_config: deadpool::managed::PoolConfig::default(),
            runtime_config: crate::RuntimeConfig::default(),
        };
        this.run_migrations(1, Some(version)).await?;

        Ok(this)
    }

    /// Builds the JSON of a room info in the legacy (v1) layout, with an
    /// optional `m.room.name` and an `m.room.create` content.
    ///
    /// The create content is a v1 one naming `creator` when given, otherwise a
    /// v11 one.
    fn room_info_v1_json(
        room_id: &RoomId,
        state: RoomState,
        name: Option<&str>,
        creator: Option<&UserId>,
    ) -> serde_json::Value {
        let name_content = match name {
            Some(name) => json!({ "name": name }),
            None => json!({ "name": null }),
        };
        let create_content = match creator {
            Some(creator) => RoomCreateEventContent::new_v1(creator.to_owned()),
            None => RoomCreateEventContent::new_v11(),
        };

        json!({
            "room_id": room_id,
            "room_type": state,
            "notification_counts": UnreadNotificationsCount::default(),
            "summary": {
                "heroes": [],
                "joined_member_count": 0,
                "invited_member_count": 0,
            },
            "members_synced": false,
            "base_info": {
                "dm_targets": [],
                "max_power_level": 100,
                "name": {
                    "Original": {
                        "content": name_content,
                    },
                },
                "create": {
                    "Original": {
                        "content": create_content,
                    }
                }
            },
        })
    }

    #[async_test]
    pub async fn test_migrating_v1_to_v2() {
        let path = new_path();
        // Seed a v1 database with three joined and two invited (stripped) rooms.
        {
            let db = create_fake_db(&path, 1).await.unwrap();
            let conn = db.read().await.unwrap();

            let this = db.clone();
            conn.with_transaction(move |txn| {
                for i in 0..5 {
                    let room_id = RoomId::parse(format!("!room_{i}:localhost")).unwrap();
                    let (state, stripped) =
                        if i < 3 { (RoomState::Joined, false) } else { (RoomState::Invited, true) };
                    let info = room_info_v1_json(&room_id, state, None, None);

                    let room_id = this.encode_key(keys::ROOM_INFO, room_id);
                    let data = this.serialize_json(&info)?;

                    txn.prepare_cached(
                        "INSERT INTO room_info (room_id, stripped, data)
                        VALUES (?, ?, ?)",
                    )?
                    .execute((room_id, stripped, data))?;
                }

                Result::<_, Error>::Ok(())
            })
            .await
            .unwrap();
        }

        // Reopening runs the remaining migrations; every room must survive them.
        let store = SqliteStateStore::open(path, Some(SECRET)).await.unwrap();

        assert_eq!(store.get_room_infos(&RoomLoadSettings::default()).await.unwrap().len(), 5);
    }

    /// Inserts a joined room in the v2 `room_info` layout (with a `state`
    /// column).
    ///
    /// When `create_sender` is given, also inserts an `m.room.create` state
    /// event sent by that user.
    fn add_room_v2(
        this: &SqliteStateStore,
        txn: &Transaction<'_>,
        room_id: &RoomId,
        name: Option<&str>,
        create_creator: Option<&UserId>,
        create_sender: Option<&UserId>,
    ) -> Result<(), Error> {
        let room_info_json = room_info_v1_json(room_id, RoomState::Joined, name, create_creator);

        let encoded_room_id = this.encode_key(keys::ROOM_INFO, room_id);
        let encoded_state =
            this.encode_key(keys::ROOM_INFO, serde_json::to_string(&RoomState::Joined)?);
        let data = this.serialize_json(&room_info_json)?;

        txn.prepare_cached(
            "INSERT INTO room_info (room_id, state, data)
            VALUES (?, ?, ?)",
        )?
        .execute((encoded_room_id, encoded_state, data))?;

        // Without a sender there is no create event to store.
        let Some(create_sender) = create_sender else {
            return Ok(());
        };

        let create_content = match create_creator {
            Some(creator) => RoomCreateEventContent::new_v1(creator.to_owned()),
            None => RoomCreateEventContent::new_v11(),
        };

        let event_id = EventId::new_v1(server_name!("dummy.local"));
        let create_event = json!({
            "content": create_content,
            "event_id": event_id,
            "sender": create_sender.to_owned(),
            "origin_server_ts": MilliSecondsSinceUnixEpoch::now(),
            "state_key": "",
            "type": "m.room.create",
            "unsigned": {},
        });

        let encoded_room_id = this.encode_key(keys::STATE_EVENT, room_id);
        let encoded_event_type =
            this.encode_key(keys::STATE_EVENT, StateEventType::RoomCreate.to_string());
        let encoded_state_key = this.encode_key(keys::STATE_EVENT, "");
        let stripped = false;
        let encoded_event_id = this.encode_key(keys::STATE_EVENT, event_id);
        let data = this.serialize_json(&create_event)?;

        txn.prepare_cached(
            "INSERT
            INTO state_event (room_id, event_type, state_key, stripped, event_id, data)
            VALUES (?, ?, ?, ?, ?, ?)",
        )?
        .execute((
            encoded_room_id,
            encoded_event_type,
            encoded_state_key,
            stripped,
            encoded_event_id,
            data,
        ))?;

        Ok(())
    }

    #[async_test]
    pub async fn test_migrating_v2_to_v3() {
        let path = new_path();

        // Room A: name, creator, and a create event from a *different* sender,
        // so the test can tell which of the two the migration keeps.
        let room_a_id = room_id!("!room_a:dummy.local");
        let room_a_name = "Room A";
        let room_a_creator = user_id!("@creator:dummy.local");
        let room_a_create_sender = user_id!("@sender:dummy.local");

        // Room B: no name, no creator, no create event.
        let room_b_id = room_id!("!room_b:dummy.local");

        // Room C: only a create event sender.
        let room_c_id = room_id!("!room_c:dummy.local");
        let room_c_create_sender = user_id!("@creator:dummy.local");

        {
            let db = create_fake_db(&path, 2).await.unwrap();
            let conn = db.read().await.unwrap();

            let this = db.clone();
            conn.with_transaction(move |txn| {
                add_room_v2(
                    &this,
                    txn,
                    room_a_id,
                    Some(room_a_name),
                    Some(room_a_creator),
                    Some(room_a_create_sender),
                )?;
                add_room_v2(&this, txn, room_b_id, None, None, None)?;
                add_room_v2(&this, txn, room_c_id, None, None, Some(room_c_create_sender))?;

                Result::<_, Error>::Ok(())
            })
            .await
            .unwrap();
        }

        let store = SqliteStateStore::open(path, Some(SECRET)).await.unwrap();

        let room_infos = store.get_room_infos(&RoomLoadSettings::default()).await.unwrap();
        assert_eq!(room_infos.len(), 3);

        // The create event's sender, not the legacy `creator`, becomes the creator.
        let room_a = room_infos.iter().find(|r| r.room_id() == room_a_id).unwrap();
        assert_eq!(room_a.name(), Some(room_a_name));
        assert_eq!(room_a.creators(), Some(vec![room_a_create_sender.to_owned()]));

        let room_b = room_infos.iter().find(|r| r.room_id() == room_b_id).unwrap();
        assert_eq!(room_b.name(), None);
        assert_eq!(room_b.creators(), None);

        let room_c = room_infos.iter().find(|r| r.room_id() == room_c_id).unwrap();
        assert_eq!(room_c.name(), None);
        assert_eq!(room_c.creators(), Some(vec![room_c_create_sender.to_owned()]));
    }

    #[async_test]
    pub async fn test_migrating_v7_to_v9() {
        let path = new_path();

        let room_id = room_id!("!room_a:dummy.local");
        let wedged_event_transaction_id = TransactionId::new();
        let local_event_transaction_id = TransactionId::new();

        // Seed a v7 database with one wedged request, one local request, and one
        // dependent request on the local one.
        {
            let db = create_fake_db(&path, 7).await.unwrap();
            let conn = db.read().await.unwrap();

            let wedge_tx = wedged_event_transaction_id.clone();
            let local_tx = local_event_transaction_id.clone();

            conn.with_transaction(move |txn| {
                add_dependent_send_queue_event_v7(
                    &db,
                    txn,
                    room_id,
                    &local_tx,
                    ChildTransactionId::new(),
                    DependentQueuedRequestKind::RedactEvent,
                )?;
                add_send_queue_event_v7(&db, txn, &wedge_tx, room_id, true)?;
                add_send_queue_event_v7(&db, txn, &local_tx, room_id, false)?;
                Result::<_, Error>::Ok(())
            })
            .await
            .unwrap();
        }

        // Reopening migrates to the latest version, which is expected to clear
        // both the requests and the dependent requests.
        let store = SqliteStateStore::open(path, Some(SECRET)).await.unwrap();

        let requests = store.load_send_queue_requests(room_id).await.unwrap();
        assert!(requests.is_empty());

        let dependent_requests = store.load_dependent_queued_requests(room_id).await.unwrap();
        assert!(dependent_requests.is_empty());
    }

    /// Inserts a request into the v7 `send_queue_events` table.
    fn add_send_queue_event_v7(
        this: &SqliteStateStore,
        txn: &Transaction<'_>,
        transaction_id: &TransactionId,
        room_id: &RoomId,
        is_wedged: bool,
    ) -> Result<(), Error> {
        let content =
            SerializableEventContent::new(&RoomMessageEventContent::text_plain("Hello").into())?;

        let room_id_key = this.encode_key(keys::SEND_QUEUE, room_id);
        let room_id_value = this.serialize_value(&room_id.to_owned())?;

        let content = this.serialize_json(&content)?;

        txn.prepare_cached("INSERT INTO send_queue_events (room_id, room_id_val, transaction_id, content, wedged) VALUES (?, ?, ?, ?, ?)")?
            .execute((room_id_key, room_id_value, transaction_id.to_string(), content, is_wedged))?;

        Ok(())
    }

    /// Inserts a dependent request into the v7 `dependent_send_queue_events`
    /// table.
    ///
    /// The `room_id` column uses the same encoded key that
    /// `load_dependent_queued_requests` looks rows up by. Writing a serialized
    /// value there instead would make the row invisible to that lookup, and the
    /// "migration cleared it" assertion would pass vacuously.
    fn add_dependent_send_queue_event_v7(
        this: &SqliteStateStore,
        txn: &Transaction<'_>,
        room_id: &RoomId,
        parent_txn_id: &TransactionId,
        own_txn_id: ChildTransactionId,
        content: DependentQueuedRequestKind,
    ) -> Result<(), Error> {
        let room_id_key = this.encode_key(keys::DEPENDENTS_SEND_QUEUE, room_id);

        let parent_txn_id = parent_txn_id.to_string();
        let own_txn_id = own_txn_id.to_string();
        let content = this.serialize_json(&content)?;

        txn.prepare_cached(
            "INSERT INTO dependent_send_queue_events
                         (room_id, parent_transaction_id, own_transaction_id, content)
                         VALUES (?, ?, ?, ?)",
        )?
        .execute((room_id_key, parent_txn_id, own_txn_id, content))?;

        Ok(())
    }

    /// Legacy shape of a dependent request, from before the
    /// `UploadFileWithThumbnail` variant was renamed to `UploadFileOrThumbnail`.
    ///
    /// Variant and field names are part of the serialized format; do not rename
    /// them.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    pub enum LegacyDependentQueuedRequestKind {
        UploadFileWithThumbnail {
            content_type: String,
            cache_key: MediaRequestParameters,
            related_to: OwnedTransactionId,
        },
    }

    #[async_test]
    pub async fn test_dependent_queued_request_variant_renaming() {
        let path = new_path();
        let db = create_fake_db(&path, 7).await.unwrap();

        let cache_key = MediaRequestParameters {
            format: MediaFormat::File,
            source: MediaSource::Plain("https://server.local/foobar".into()),
        };
        let related_to = TransactionId::new();
        let request = LegacyDependentQueuedRequestKind::UploadFileWithThumbnail {
            content_type: "image/png".to_owned(),
            cache_key,
            related_to: related_to.clone(),
        };

        let data = db
            .serialize_json(&request)
            .expect("should be able to serialize legacy dependent request");
        let deserialized: DependentQueuedRequestKind = db.deserialize_json(&data).expect(
            "should be able to deserialize dependent request from legacy dependent request",
        );

        // `as_variant!` yields `None` on a variant mismatch. Unwrap it explicitly;
        // otherwise a wrong variant would silently skip the assertion.
        let de_related_to = as_variant!(
            deserialized,
            DependentQueuedRequestKind::UploadFileOrThumbnail { related_to: de_related_to, .. } => de_related_to
        )
        .expect("legacy `UploadFileWithThumbnail` should deserialize as `UploadFileOrThumbnail`");
        assert_eq!(de_related_to, related_to);
    }
}
2956
#[cfg(test)]
mod close_reopen_tests {
    use std::{
        sync::{
            LazyLock,
            atomic::{AtomicU32, Ordering::SeqCst},
        },
        time::{Duration, Instant},
    };

    use matrix_sdk_base::StateStore;
    use matrix_sdk_test::async_test;
    use tempfile::{TempDir, tempdir};

    use super::SqliteStateStore;

    static TMP_DIR: LazyLock<TempDir> = LazyLock::new(|| tempdir().unwrap());
    static NUM: AtomicU32 = AtomicU32::new(0);

    /// Opens a fresh, unencrypted store in its own numbered sub-directory of
    /// [`TMP_DIR`].
    async fn new_store() -> SqliteStateStore {
        let dir = TMP_DIR.path().join(NUM.fetch_add(1, SeqCst).to_string());
        let dir = dir.to_str().unwrap();
        SqliteStateStore::open(dir, None).await.unwrap()
    }

    #[async_test]
    async fn test_close_completes_without_timeout() {
        let store = new_store().await;

        let clock = Instant::now();
        store.close().await.unwrap();
        let elapsed = clock.elapsed();

        assert!(
            elapsed < Duration::from_secs(2),
            "close() took {elapsed:?}, expected < 2s (no timeout)"
        );

        let connections = store.connections.lock().await;
        assert!(connections.is_none(), "connections should be None after close");
    }

    #[async_test]
    async fn test_reopen_restores_connections() {
        let store = new_store().await;

        store.close().await.unwrap();
        let closed = store.connections.lock().await;
        assert!(closed.is_none());
        // Release the lock before `reopen` needs it.
        drop(closed);

        store.reopen().await.unwrap();
        let reopened = store.connections.lock().await;
        assert!(reopened.is_some(), "connections should be Some after reopen");
    }

    #[async_test]
    async fn test_close_is_idempotent() {
        let store = new_store().await;

        for _ in 0..2 {
            store.close().await.unwrap();
        }

        assert!(store.connections.lock().await.is_none());
    }

    #[async_test]
    async fn test_reopen_is_idempotent() {
        let store = new_store().await;

        // Reopening a store that was never closed must still succeed.
        store.reopen().await.unwrap();

        assert!(store.connections.lock().await.is_some());
    }

    #[async_test]
    async fn test_read_fails_when_closed() {
        let store = new_store().await;
        store.close().await.unwrap();

        let outcome = store.get_custom_value(b"some_key").await;
        assert!(outcome.is_err(), "read should fail when closed");

        let err_msg = outcome.unwrap_err().to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_write_fails_when_closed() {
        let store = new_store().await;
        store.close().await.unwrap();

        let outcome = store.set_custom_value(b"key", b"value".to_vec()).await;
        assert!(outcome.is_err(), "write should fail when closed");

        let err_msg = outcome.unwrap_err().to_string();
        assert!(err_msg.contains("closed"), "error should mention 'closed', got: {err_msg}");
    }

    #[async_test]
    async fn test_data_persists_across_close_reopen() {
        let store = new_store().await;
        let expected = Some(b"test_value".as_slice());

        store.set_custom_value(b"test_key", b"test_value".to_vec()).await.unwrap();
        let before = store.get_custom_value(b"test_key").await.unwrap();
        assert_eq!(before.as_deref(), expected);

        store.close().await.unwrap();
        store.reopen().await.unwrap();

        let after = store.get_custom_value(b"test_key").await.unwrap();
        assert_eq!(after.as_deref(), expected, "data should persist across close/reopen");
    }

    #[async_test]
    async fn test_multiple_close_reopen_cycles() {
        let store = new_store().await;

        for i in 0..3 {
            store
                .set_custom_value(format!("key_{i}").as_bytes(), format!("value_{i}").into_bytes())
                .await
                .unwrap();

            store.close().await.unwrap();
            store.reopen().await.unwrap();

            // Every key written so far must survive, not only the latest one.
            for j in 0..=i {
                let wanted = format!("value_{j}");
                let retrieved =
                    store.get_custom_value(format!("key_{j}").as_bytes()).await.unwrap();
                assert_eq!(
                    retrieved.as_deref(),
                    Some(wanted.as_bytes()),
                    "data for key_{j} should persist after cycle {i}"
                );
            }
        }
    }

    #[async_test]
    async fn test_pool_is_fully_drained_after_close() {
        let store = new_store().await;

        // Exercise the read path a couple of times before closing.
        for key in [b"key1", b"key2"] {
            let _ = store.get_custom_value(key).await;
        }

        store.close().await.unwrap();

        let connections = store.connections.lock().await;
        assert!(connections.is_none(), "all connections should be released after close");
    }

    #[async_test]
    async fn test_operations_work_immediately_after_reopen() {
        let store = new_store().await;

        store.close().await.unwrap();
        store.reopen().await.unwrap();

        store.set_custom_value(b"after_reopen", b"works".to_vec()).await.unwrap();
        let value = store.get_custom_value(b"after_reopen").await.unwrap();
        assert_eq!(value.as_deref(), Some(b"works".as_slice()));
    }

    #[async_test]
    async fn test_close_waits_for_held_read_connection_to_drain() {
        let store = new_store().await;

        // Keep one read connection checked out while `close` runs.
        let held_conn = store.read().await.unwrap();

        let closer = store.clone();
        let close_handle = tokio::spawn(async move { closer.close().await.unwrap() });

        // Give the spawned task time to start closing and block on the held
        // connection.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!close_handle.is_finished(), "close should be waiting for the held connection");

        drop(held_conn);

        let joined = tokio::time::timeout(Duration::from_secs(3), close_handle).await;
        assert!(joined.is_ok(), "close should complete after the held connection is released");
        joined.unwrap().unwrap();

        let connections = store.connections.lock().await;
        assert!(connections.is_none(), "connections should be None after close");
    }
}