Skip to main content

matrix_sdk_common/
ttl.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Types to implement TTL caches which can be used to persist data for a fixed
16//! duration.
17
18use ruma::time::SystemTime;
19use serde::{Deserialize, Serialize};
20
21/// A value that expires after some time.
22///
23/// This value is (de)serializable so it can be persisted in a store.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct TtlValue<T> {
26    /// The data of the item.
27    #[serde(flatten)]
28    data: T,
29
30    /// Last time we fetched this data from the server, in milliseconds since
31    /// UNIX epoch.
32    ///
33    /// When this field is missing during deserialization, it defaults to `0.0`,
34    /// which means that the data is always expired. This allows to be
35    /// compatible with data that was persisted before deciding to add an
36    /// expiration time.
37    #[serde(default = "default_timestamp")]
38    last_fetch_ts: Option<f64>,
39}
40
41impl<T> TtlValue<T> {
42    /// The number of milliseconds after which the data is considered stale.
43    ///
44    /// This matches 1 day.
45    pub const STALE_THRESHOLD: f64 = (1000 * 60 * 60 * 24) as _;
46
47    /// Construct a new `TtlValue` with the given data.
48    pub fn new(data: T) -> Self {
49        Self { data, last_fetch_ts: Some(now_timestamp_ms()) }
50    }
51
52    /// Construct a new `TtlValue` with the given data that never expires.
53    pub fn without_expiry(data: T) -> Self {
54        Self { data, last_fetch_ts: None }
55    }
56
57    /// Converts from `&TtlValue<T>` to `TtlValue<&T>`.
58    pub fn as_ref(&self) -> TtlValue<&T> {
59        TtlValue { data: &self.data, last_fetch_ts: self.last_fetch_ts }
60    }
61
62    /// Transform the data of this `TtlValue` with the given function.
63    pub fn map<U, F>(self, f: F) -> TtlValue<U>
64    where
65        F: FnOnce(T) -> U,
66    {
67        TtlValue { data: f(self.data), last_fetch_ts: self.last_fetch_ts }
68    }
69
70    /// Whether this value has expired.
71    pub fn has_expired(&self) -> bool {
72        self.last_fetch_ts.is_some_and(|ts| now_timestamp_ms() - ts >= Self::STALE_THRESHOLD)
73    }
74
75    /// Mark this value has expired.
76    pub fn expire(&mut self) {
77        // We assume that the system time is always correct and we are far from the UNIX
78        // epoch so a timestamp of 0 should always be expired.
79        self.last_fetch_ts = Some(0.0)
80    }
81
82    /// Get a reference to the data of this value.
83    pub fn data(&self) -> &T {
84        &self.data
85    }
86
87    /// Get the data of this value.
88    pub fn into_data(self) -> T {
89        self.data
90    }
91}
92
/// Returns the current wall-clock time as the number of milliseconds elapsed
/// since the UNIX epoch, with sub-millisecond precision.
///
/// # Panics
///
/// Panics if the system clock reports a time before the UNIX epoch.
fn now_timestamp_ms() -> f64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("System clock was before 1970.");
    let seconds = since_epoch.as_secs_f64();
    seconds * 1000.0
}
101
/// Provides the value of `last_fetch_ts` when that field is absent from
/// serialized data.
///
/// Serialized values are expected to always carry an expiry timestamp, so a
/// missing one is treated as if the data was fetched at the UNIX epoch
/// (`Some(0.0)`), which makes it expired rather than never-expiring (`None`).
fn default_timestamp() -> Option<f64> {
    const UNIX_EPOCH_MS: f64 = 0.0;
    Some(UNIX_EPOCH_MS)
}
109
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use super::{TtlValue, now_timestamp_ms};

    /// Tolerance used when comparing timestamps that went through a
    /// serialization round-trip.
    const EPSILON: f64 = 0.0001;

    #[test]
    fn test_ttl_value_expiry() {
        // Definitely stale.
        let ttl_value = TtlValue {
            data: (),
            last_fetch_ts: Some(now_timestamp_ms() - TtlValue::<()>::STALE_THRESHOLD - 1.0),
        };
        assert!(ttl_value.has_expired());

        // Definitely not stale.
        let ttl_value = TtlValue::new(());
        assert!(!ttl_value.has_expired());

        // Cannot be stale.
        let ttl_value = TtlValue::without_expiry(());
        assert!(!ttl_value.has_expired());
    }

    #[test]
    fn test_ttl_value_expire() {
        // A fresh value becomes expired.
        let mut ttl_value = TtlValue::new(());
        ttl_value.expire();
        assert!(ttl_value.has_expired());

        // A value without expiry becomes expired too.
        let mut ttl_value = TtlValue::without_expiry(());
        ttl_value.expire();
        assert!(ttl_value.has_expired());
    }

    #[test]
    fn test_ttl_value_serialize_roundtrip() {
        #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct Data {
            foo: String,
        }

        let data = Data { foo: "bar".to_owned() };

        // With timestamp.
        let ttl_value = TtlValue { data: data.clone(), last_fetch_ts: Some(1000.0) };
        let json = json!({
            "foo": "bar",
            "last_fetch_ts": 1000.0,
        });
        assert_eq!(serde_json::to_value(&ttl_value).unwrap(), json);

        let deserialized = serde_json::from_value::<TtlValue<Data>>(json).unwrap();
        assert_eq!(deserialized.data, data);
        // `abs()` is required: without it any smaller value would pass.
        assert!(
            (deserialized.last_fetch_ts.unwrap() - ttl_value.last_fetch_ts.unwrap()).abs()
                < EPSILON
        );

        // Without expiry, `None` is written as `null` and read back as `None`.
        let ttl_value = TtlValue::without_expiry(data.clone());
        let json = json!({
            "foo": "bar",
            "last_fetch_ts": null,
        });
        assert_eq!(serde_json::to_value(&ttl_value).unwrap(), json);

        let deserialized = serde_json::from_value::<TtlValue<Data>>(json).unwrap();
        assert_eq!(deserialized.data, data);
        assert_eq!(deserialized.last_fetch_ts, None);
        assert!(!deserialized.has_expired());

        // Without timestamp the value is always expired in theory.
        let json = json!({
            "foo": "bar",
        });
        let deserialized = serde_json::from_value::<TtlValue<Data>>(json).unwrap();
        assert_eq!(deserialized.data, data);
        assert!((deserialized.last_fetch_ts.unwrap() - 0.0).abs() < EPSILON);
        assert!(deserialized.has_expired());
    }
}