1use std::{
16 collections::HashMap,
17 num::NonZeroUsize,
18 sync::{Arc, RwLock as StdRwLock},
19};
20
21use async_trait::async_trait;
22use matrix_sdk_common::{
23 linked_chunk::{
24 ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
25 RawChunk, Update, relational::RelationalLinkedChunk,
26 },
27 ring_buffer::RingBuffer,
28 store_locks::memory_store_helper::try_take_leased_lock,
29};
30use ruma::{
31 EventId, MxcUri, OwnedEventId, OwnedMxcUri, RoomId,
32 events::relation::RelationType,
33 time::{Instant, SystemTime},
34};
35use tracing::error;
36
37use super::{
38 EventCacheStore, EventCacheStoreError, Result, compute_filters_string, extract_event_relation,
39 media::{EventCacheStoreMedia, IgnoreMediaRetentionPolicy, MediaRetentionPolicy, MediaService},
40};
41use crate::{
42 event_cache::{Event, Gap},
43 media::{MediaRequestParameters, UniqueKey as _},
44};
45
46#[derive(Debug, Clone)]
50pub struct MemoryStore {
51 inner: Arc<StdRwLock<MemoryStoreInner>>,
52 media_service: MediaService,
53}
54
55#[derive(Debug)]
56struct MemoryStoreInner {
57 media: RingBuffer<MediaContent>,
58 leases: HashMap<String, (String, Instant)>,
59 events: RelationalLinkedChunk<OwnedEventId, Event, Gap>,
60 media_retention_policy: Option<MediaRetentionPolicy>,
61 last_media_cleanup_time: SystemTime,
62}
63
64#[derive(Debug)]
66struct MediaContent {
67 uri: OwnedMxcUri,
69
70 key: String,
72
73 data: Vec<u8>,
75
76 ignore_policy: bool,
78
79 last_access: SystemTime,
81}
82
/// Size given to the media ring buffer when a store is created, and the
/// capacity reserved for the removal list during cache clean-up.
const NUMBER_OF_MEDIAS: NonZeroUsize = match NonZeroUsize::new(20) {
    Some(capacity) => capacity,
    // Evaluated at compile time: 20 is non-zero, so this arm can never be taken.
    None => panic!("NUMBER_OF_MEDIAS must be non-zero"),
};
84
85impl Default for MemoryStore {
86 fn default() -> Self {
87 let last_media_cleanup_time = SystemTime::now();
89 let media_service = MediaService::new();
90 media_service.restore(None, Some(last_media_cleanup_time));
91
92 Self {
93 inner: Arc::new(StdRwLock::new(MemoryStoreInner {
94 media: RingBuffer::new(NUMBER_OF_MEDIAS),
95 leases: Default::default(),
96 events: RelationalLinkedChunk::new(),
97 media_retention_policy: None,
98 last_media_cleanup_time,
99 })),
100 media_service,
101 }
102 }
103}
104
105impl MemoryStore {
106 pub fn new() -> Self {
108 Self::default()
109 }
110}
111
112#[cfg_attr(target_family = "wasm", async_trait(?Send))]
113#[cfg_attr(not(target_family = "wasm"), async_trait)]
114impl EventCacheStore for MemoryStore {
115 type Error = EventCacheStoreError;
116
117 async fn try_take_leased_lock(
118 &self,
119 lease_duration_ms: u32,
120 key: &str,
121 holder: &str,
122 ) -> Result<bool, Self::Error> {
123 let mut inner = self.inner.write().unwrap();
124
125 Ok(try_take_leased_lock(&mut inner.leases, lease_duration_ms, key, holder))
126 }
127
128 async fn handle_linked_chunk_updates(
129 &self,
130 linked_chunk_id: LinkedChunkId<'_>,
131 updates: Vec<Update<Event, Gap>>,
132 ) -> Result<(), Self::Error> {
133 let mut inner = self.inner.write().unwrap();
134 inner.events.apply_updates(linked_chunk_id, updates);
135
136 Ok(())
137 }
138
139 async fn load_all_chunks(
140 &self,
141 linked_chunk_id: LinkedChunkId<'_>,
142 ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
143 let inner = self.inner.read().unwrap();
144 inner
145 .events
146 .load_all_chunks(linked_chunk_id)
147 .map_err(|err| EventCacheStoreError::InvalidData { details: err })
148 }
149
150 async fn load_all_chunks_metadata(
151 &self,
152 linked_chunk_id: LinkedChunkId<'_>,
153 ) -> Result<Vec<ChunkMetadata>, Self::Error> {
154 let inner = self.inner.read().unwrap();
155 inner
156 .events
157 .load_all_chunks_metadata(linked_chunk_id)
158 .map_err(|err| EventCacheStoreError::InvalidData { details: err })
159 }
160
161 async fn load_last_chunk(
162 &self,
163 linked_chunk_id: LinkedChunkId<'_>,
164 ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
165 let inner = self.inner.read().unwrap();
166 inner
167 .events
168 .load_last_chunk(linked_chunk_id)
169 .map_err(|err| EventCacheStoreError::InvalidData { details: err })
170 }
171
172 async fn load_previous_chunk(
173 &self,
174 linked_chunk_id: LinkedChunkId<'_>,
175 before_chunk_identifier: ChunkIdentifier,
176 ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
177 let inner = self.inner.read().unwrap();
178 inner
179 .events
180 .load_previous_chunk(linked_chunk_id, before_chunk_identifier)
181 .map_err(|err| EventCacheStoreError::InvalidData { details: err })
182 }
183
184 async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
185 self.inner.write().unwrap().events.clear();
186 Ok(())
187 }
188
189 async fn filter_duplicated_events(
190 &self,
191 linked_chunk_id: LinkedChunkId<'_>,
192 mut events: Vec<OwnedEventId>,
193 ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
194 if events.is_empty() {
195 return Ok(Vec::new());
196 }
197
198 let inner = self.inner.read().unwrap();
199
200 let mut duplicated_events = Vec::new();
201
202 for (event, position) in
203 inner.events.unordered_linked_chunk_items(&linked_chunk_id.to_owned())
204 {
205 if let Some(known_event_id) = event.event_id() {
206 if let Some(index) =
208 events.iter().position(|new_event_id| &known_event_id == new_event_id)
209 {
210 duplicated_events.push((events.remove(index), position));
211 }
212 }
213 }
214
215 Ok(duplicated_events)
216 }
217
218 async fn find_event(
219 &self,
220 room_id: &RoomId,
221 event_id: &EventId,
222 ) -> Result<Option<Event>, Self::Error> {
223 let inner = self.inner.read().unwrap();
224
225 let event = inner
226 .events
227 .items(room_id)
228 .find_map(|(event, _pos)| (event.event_id()? == event_id).then_some(event.clone()));
229
230 Ok(event)
231 }
232
233 async fn find_event_relations(
234 &self,
235 room_id: &RoomId,
236 event_id: &EventId,
237 filters: Option<&[RelationType]>,
238 ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
239 let inner = self.inner.read().unwrap();
240
241 let filters = compute_filters_string(filters);
242
243 let related_events = inner
244 .events
245 .items(room_id)
246 .filter_map(|(event, pos)| {
247 let (related_to, rel_type) = extract_event_relation(event.raw())?;
249
250 if related_to != event_id {
252 return None;
253 }
254
255 if let Some(filters) = &filters {
257 filters.contains(&rel_type).then_some((event.clone(), pos))
258 } else {
259 Some((event.clone(), pos))
260 }
261 })
262 .collect();
263
264 Ok(related_events)
265 }
266
267 async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
268 if event.event_id().is_none() {
269 error!(%room_id, "Trying to save an event with no ID");
270 return Ok(());
271 }
272 self.inner.write().unwrap().events.save_item(room_id.to_owned(), event);
273 Ok(())
274 }
275
276 async fn add_media_content(
277 &self,
278 request: &MediaRequestParameters,
279 data: Vec<u8>,
280 ignore_policy: IgnoreMediaRetentionPolicy,
281 ) -> Result<()> {
282 self.media_service.add_media_content(self, request, data, ignore_policy).await
283 }
284
285 async fn replace_media_key(
286 &self,
287 from: &MediaRequestParameters,
288 to: &MediaRequestParameters,
289 ) -> Result<(), Self::Error> {
290 let expected_key = from.unique_key();
291
292 let mut inner = self.inner.write().unwrap();
293
294 if let Some(media_content) =
295 inner.media.iter_mut().find(|media_content| media_content.key == expected_key)
296 {
297 media_content.uri = to.uri().to_owned();
298 media_content.key = to.unique_key();
299 }
300
301 Ok(())
302 }
303
304 async fn get_media_content(&self, request: &MediaRequestParameters) -> Result<Option<Vec<u8>>> {
305 self.media_service.get_media_content(self, request).await
306 }
307
308 async fn remove_media_content(&self, request: &MediaRequestParameters) -> Result<()> {
309 let expected_key = request.unique_key();
310
311 let mut inner = self.inner.write().unwrap();
312
313 let Some(index) =
314 inner.media.iter().position(|media_content| media_content.key == expected_key)
315 else {
316 return Ok(());
317 };
318
319 inner.media.remove(index);
320
321 Ok(())
322 }
323
324 async fn get_media_content_for_uri(
325 &self,
326 uri: &MxcUri,
327 ) -> Result<Option<Vec<u8>>, Self::Error> {
328 self.media_service.get_media_content_for_uri(self, uri).await
329 }
330
331 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
332 let mut inner = self.inner.write().unwrap();
333
334 let positions = inner
335 .media
336 .iter()
337 .enumerate()
338 .filter_map(|(position, media_content)| (media_content.uri == uri).then_some(position))
339 .collect::<Vec<_>>();
340
341 for position in positions.into_iter().rev() {
343 inner.media.remove(position);
344 }
345
346 Ok(())
347 }
348
349 async fn set_media_retention_policy(
350 &self,
351 policy: MediaRetentionPolicy,
352 ) -> Result<(), Self::Error> {
353 self.media_service.set_media_retention_policy(self, policy).await
354 }
355
356 fn media_retention_policy(&self) -> MediaRetentionPolicy {
357 self.media_service.media_retention_policy()
358 }
359
360 async fn set_ignore_media_retention_policy(
361 &self,
362 request: &MediaRequestParameters,
363 ignore_policy: IgnoreMediaRetentionPolicy,
364 ) -> Result<(), Self::Error> {
365 self.media_service.set_ignore_media_retention_policy(self, request, ignore_policy).await
366 }
367
368 async fn clean_up_media_cache(&self) -> Result<(), Self::Error> {
369 self.media_service.clean_up_media_cache(self).await
370 }
371}
372
373#[cfg_attr(target_family = "wasm", async_trait(?Send))]
374#[cfg_attr(not(target_family = "wasm"), async_trait)]
375impl EventCacheStoreMedia for MemoryStore {
376 type Error = EventCacheStoreError;
377
378 async fn media_retention_policy_inner(
379 &self,
380 ) -> Result<Option<MediaRetentionPolicy>, Self::Error> {
381 Ok(self.inner.read().unwrap().media_retention_policy)
382 }
383
384 async fn set_media_retention_policy_inner(
385 &self,
386 policy: MediaRetentionPolicy,
387 ) -> Result<(), Self::Error> {
388 self.inner.write().unwrap().media_retention_policy = Some(policy);
389 Ok(())
390 }
391
392 async fn add_media_content_inner(
393 &self,
394 request: &MediaRequestParameters,
395 data: Vec<u8>,
396 last_access: SystemTime,
397 policy: MediaRetentionPolicy,
398 ignore_policy: IgnoreMediaRetentionPolicy,
399 ) -> Result<(), Self::Error> {
400 self.remove_media_content(request).await?;
402
403 let ignore_policy = ignore_policy.is_yes();
404
405 if !ignore_policy && policy.exceeds_max_file_size(data.len() as u64) {
406 return Ok(());
408 }
409
410 let mut inner = self.inner.write().unwrap();
412 inner.media.push(MediaContent {
413 uri: request.uri().to_owned(),
414 key: request.unique_key(),
415 data,
416 ignore_policy,
417 last_access,
418 });
419
420 Ok(())
421 }
422
423 async fn set_ignore_media_retention_policy_inner(
424 &self,
425 request: &MediaRequestParameters,
426 ignore_policy: IgnoreMediaRetentionPolicy,
427 ) -> Result<(), Self::Error> {
428 let mut inner = self.inner.write().unwrap();
429 let expected_key = request.unique_key();
430
431 if let Some(media_content) = inner.media.iter_mut().find(|media| media.key == expected_key)
432 {
433 media_content.ignore_policy = ignore_policy.is_yes();
434 }
435
436 Ok(())
437 }
438
439 async fn get_media_content_inner(
440 &self,
441 request: &MediaRequestParameters,
442 current_time: SystemTime,
443 ) -> Result<Option<Vec<u8>>, Self::Error> {
444 let mut inner = self.inner.write().unwrap();
445 let expected_key = request.unique_key();
446
447 let Some(index) = inner.media.iter().position(|media| media.key == expected_key) else {
450 return Ok(None);
451 };
452 let Some(mut content) = inner.media.remove(index) else {
453 return Ok(None);
454 };
455
456 let data = content.data.clone();
458
459 content.last_access = current_time;
461
462 inner.media.push(content);
464
465 Ok(Some(data))
466 }
467
468 async fn get_media_content_for_uri_inner(
469 &self,
470 expected_uri: &MxcUri,
471 current_time: SystemTime,
472 ) -> Result<Option<Vec<u8>>, Self::Error> {
473 let mut inner = self.inner.write().unwrap();
474
475 let Some(index) = inner.media.iter().position(|media| media.uri == expected_uri) else {
478 return Ok(None);
479 };
480 let Some(mut content) = inner.media.remove(index) else {
481 return Ok(None);
482 };
483
484 let data = content.data.clone();
486
487 content.last_access = current_time;
489
490 inner.media.push(content);
492
493 Ok(Some(data))
494 }
495
496 async fn clean_up_media_cache_inner(
497 &self,
498 policy: MediaRetentionPolicy,
499 current_time: SystemTime,
500 ) -> Result<(), Self::Error> {
501 if !policy.has_limitations() {
502 return Ok(());
504 }
505
506 let mut inner = self.inner.write().unwrap();
507
508 if policy.computed_max_file_size().is_some() {
510 inner.media.retain(|content| {
511 content.ignore_policy || !policy.exceeds_max_file_size(content.data.len() as u64)
512 });
513 }
514
515 if policy.last_access_expiry.is_some() {
517 inner.media.retain(|content| {
518 content.ignore_policy
519 || !policy.has_content_expired(current_time, content.last_access)
520 });
521 }
522
523 if let Some(max_cache_size) = policy.max_cache_size {
525 let (_, items_to_remove) = inner.media.iter().enumerate().rev().fold(
529 (0u64, Vec::with_capacity(NUMBER_OF_MEDIAS.into())),
530 |(mut cache_size, mut items_to_remove), (index, content)| {
531 if content.ignore_policy {
532 return (cache_size, items_to_remove);
534 }
535
536 let remove_item = if items_to_remove.is_empty() {
537 if let Some(sum) = cache_size.checked_add(content.data.len() as u64) {
539 cache_size = sum;
540 cache_size > max_cache_size
542 } else {
543 true
547 }
548 } else {
549 true
551 };
552
553 if remove_item {
554 items_to_remove.push(index);
555 }
556
557 (cache_size, items_to_remove)
558 },
559 );
560
561 for index in items_to_remove {
564 inner.media.remove(index);
565 }
566 }
567
568 inner.last_media_cleanup_time = current_time;
569
570 Ok(())
571 }
572
573 async fn last_media_cleanup_time_inner(&self) -> Result<Option<SystemTime>, Self::Error> {
574 Ok(Some(self.inner.read().unwrap().last_media_cleanup_time))
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result};
    use crate::event_cache_store_media_integration_tests;

    /// Builds the store under test.
    ///
    /// NOTE(review): presumably looked up by name from the integration-test
    /// macros invoked below; keep the name and signature as they are.
    async fn get_event_cache_store() -> Result<MemoryStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    // Instantiate the shared integration test suites against `MemoryStore`.
    event_cache_store_integration_tests!();
    event_cache_store_integration_tests_time!();
    event_cache_store_media_integration_tests!(with_media_size_tests);
}