matrix_sdk_base/room/
knock.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16
17use eyeball::{AsyncLock, ObservableWriteGuard};
18use ruma::{
19    events::{
20        room::member::{MembershipState, RoomMemberEventContent},
21        StateEventType, SyncStateEvent,
22    },
23    OwnedEventId, OwnedUserId,
24};
25use tracing::warn;
26
27use super::Room;
28use crate::{
29    deserialized_responses::{MemberEvent, RawMemberEvent, SyncOrStrippedState},
30    store::{Result as StoreResult, StateStoreExt},
31    StateStoreDataKey, StateStoreDataValue, StoreError,
32};
33
34impl Room {
35    /// Mark a list of requests to join the room as seen, given their state
36    /// event ids.
37    pub async fn mark_knock_requests_as_seen(&self, user_ids: &[OwnedUserId]) -> StoreResult<()> {
38        let raw_user_ids: Vec<&str> = user_ids.iter().map(|id| id.as_str()).collect();
39        let member_raw_events = self
40            .store
41            .get_state_events_for_keys(self.room_id(), StateEventType::RoomMember, &raw_user_ids)
42            .await?;
43        let mut event_to_user_ids = Vec::with_capacity(member_raw_events.len());
44
45        // Map the list of events ids to their user ids, if they are event ids for knock
46        // membership events. Log an error and continue otherwise.
47        for raw_event in member_raw_events {
48            let event = raw_event.cast::<RoomMemberEventContent>().deserialize()?;
49            match event {
50                SyncOrStrippedState::Sync(SyncStateEvent::Original(event)) => {
51                    if event.content.membership == MembershipState::Knock {
52                        event_to_user_ids.push((event.event_id, event.state_key))
53                    } else {
54                        warn!(
55                            "Could not mark knock event as seen: event {} for user {} \
56                             is not in Knock membership state.",
57                            event.event_id, event.state_key
58                        );
59                    }
60                }
61                _ => warn!(
62                    "Could not mark knock event as seen: event for user {} is not valid.",
63                    event.state_key()
64                ),
65            }
66        }
67
68        let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?;
69        let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default();
70
71        current_seen_events.extend(event_to_user_ids);
72
73        self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?;
74
75        Ok(())
76    }
77
78    /// Removes the seen knock request ids that are no longer valid given the
79    /// current room members.
80    pub async fn remove_outdated_seen_knock_requests_ids(&self) -> StoreResult<()> {
81        let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?;
82        let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default();
83
84        // Get and deserialize the member events for the seen knock requests
85        let keys: Vec<OwnedUserId> = current_seen_events.values().map(|id| id.to_owned()).collect();
86        let raw_member_events: Vec<RawMemberEvent> =
87            self.store.get_state_events_for_keys_static(self.room_id(), &keys).await?;
88        let member_events = raw_member_events
89            .into_iter()
90            .map(|raw| raw.deserialize())
91            .collect::<Result<Vec<MemberEvent>, _>>()?;
92
93        let mut ids_to_remove = Vec::new();
94
95        for (event_id, user_id) in current_seen_events.iter() {
96            // Check the seen knock request ids against the current room member events for
97            // the room members associated to them
98            let matching_member = member_events.iter().find(|event| event.user_id() == user_id);
99
100            if let Some(member) = matching_member {
101                let member_event_id = member.event_id();
102                // If the member event is not a knock or it's different knock, it's outdated
103                if *member.membership() != MembershipState::Knock
104                    || member_event_id.is_some_and(|id| id != event_id)
105                {
106                    ids_to_remove.push(event_id.to_owned());
107                }
108            } else {
109                ids_to_remove.push(event_id.to_owned());
110            }
111        }
112
113        // If there are no ids to remove, do nothing
114        if ids_to_remove.is_empty() {
115            return Ok(());
116        }
117
118        for event_id in ids_to_remove {
119            current_seen_events.remove(&event_id);
120        }
121
122        self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?;
123
124        Ok(())
125    }
126
127    /// Get the list of seen knock request event ids in this room.
128    pub async fn get_seen_knock_request_ids(
129        &self,
130    ) -> Result<BTreeMap<OwnedEventId, OwnedUserId>, StoreError> {
131        Ok(self.get_write_guarded_current_knock_request_ids().await?.clone().unwrap_or_default())
132    }
133
134    async fn get_write_guarded_current_knock_request_ids(
135        &self,
136    ) -> StoreResult<ObservableWriteGuard<'_, Option<BTreeMap<OwnedEventId, OwnedUserId>>, AsyncLock>>
137    {
138        let mut guard = self.seen_knock_request_ids_map.write().await;
139        // If there are no loaded request ids yet
140        if guard.is_none() {
141            // Load the values from the store and update the shared observable contents
142            let updated_seen_ids = self
143                .store
144                .get_kv_data(StateStoreDataKey::SeenKnockRequests(self.room_id()))
145                .await?
146                .and_then(|v| v.into_seen_knock_requests())
147                .unwrap_or_default();
148
149            ObservableWriteGuard::set(&mut guard, Some(updated_seen_ids));
150        }
151        Ok(guard)
152    }
153
154    async fn update_seen_knock_request_ids(
155        &self,
156        mut guard: ObservableWriteGuard<'_, Option<BTreeMap<OwnedEventId, OwnedUserId>>, AsyncLock>,
157        new_value: BTreeMap<OwnedEventId, OwnedUserId>,
158    ) -> StoreResult<()> {
159        // Save the new values to the shared observable
160        ObservableWriteGuard::set(&mut guard, Some(new_value.clone()));
161
162        // Save them into the store too
163        self.store
164            .set_kv_data(
165                StateStoreDataKey::SeenKnockRequests(self.room_id()),
166                StateStoreDataValue::SeenKnockRequests(new_value),
167            )
168            .await?;
169
170        Ok(())
171    }
172}