1// Copyright 2023 Kévin Commaille
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
1415//! Types and functions related to authentication in Matrix.
1617use std::{fmt, sync::Arc};
1819use matrix_sdk_base::{locks::Mutex, SessionMeta};
20use serde::{Deserialize, Serialize};
21use tokio::sync::{broadcast, Mutex as AsyncMutex, OnceCell};
2223pub mod matrix;
24pub mod oauth;
2526use self::{
27 matrix::MatrixAuth,
28 oauth::{OAuth, OAuthAuthData, OAuthCtx},
29};
30use crate::{Client, RefreshTokenError, SessionChange};
3132/// The tokens for a user session.
33#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
34#[allow(missing_debug_implementations)]
35pub struct SessionTokens {
36/// The access token used for this session.
37pub access_token: String,
3839/// The token used for refreshing the access token, if any.
40#[serde(default, skip_serializing_if = "Option::is_none")]
41pub refresh_token: Option<String>,
42}
4344#[cfg(not(tarpaulin_include))]
45impl fmt::Debug for SessionTokens {
46fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_struct("SessionTokens").finish_non_exhaustive()
48 }
49}
5051pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
52pub(crate) type SaveSessionCallback =
53dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync;
54pub(crate) type ReloadSessionCallback =
55dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError> + Send + Sync;
5657/// All the data relative to authentication, and that must be shared between a
58/// client and all its children.
59pub(crate) struct AuthCtx {
60pub(crate) oauth: OAuthCtx,
6162/// Whether to try to refresh the access token automatically when an
63 /// `M_UNKNOWN_TOKEN` error is encountered.
64pub(crate) handle_refresh_tokens: bool,
6566/// Lock making sure we're only doing one token refresh at a time.
67pub(crate) refresh_token_lock: Arc<AsyncMutex<Result<(), RefreshTokenError>>>,
6869/// Session change publisher. Allows the subscriber to handle changes to the
70 /// session such as logging out when the access token is invalid or
71 /// persisting updates to the access/refresh tokens.
72pub(crate) session_change_sender: broadcast::Sender<SessionChange>,
7374/// Authentication data to keep in memory.
75pub(crate) auth_data: OnceCell<AuthData>,
7677/// The current session tokens.
78pub(crate) tokens: OnceCell<Mutex<SessionTokens>>,
7980/// A callback called whenever we need an absolute source of truth for the
81 /// current session tokens.
82 ///
83 /// This is required only in multiple processes setups.
84pub(crate) reload_session_callback: OnceCell<Box<ReloadSessionCallback>>,
8586/// A callback to save a session back into the app's secure storage.
87 ///
88 /// This is always called, independently of the presence of a cross-process
89 /// lock.
90 ///
91 /// Internal invariant: this must be called only after `set_session_tokens`
92 /// has been called, not before.
93pub(crate) save_session_callback: OnceCell<Box<SaveSessionCallback>>,
94}
9596impl AuthCtx {
97/// The current session tokens.
98pub(crate) fn session_tokens(&self) -> Option<SessionTokens> {
99Some(self.tokens.get()?.lock().clone())
100 }
101102/// The current access token.
103pub(crate) fn access_token(&self) -> Option<String> {
104Some(self.tokens.get()?.lock().access_token.clone())
105 }
106107/// Set the current session tokens.
108pub(crate) fn set_session_tokens(&self, session_tokens: SessionTokens) {
109if let Some(tokens) = self.tokens.get() {
110*tokens.lock() = session_tokens;
111 } else {
112let _ = self.tokens.set(Mutex::new(session_tokens));
113 }
114 }
115}
116117/// An enum over all the possible authentication APIs.
118#[derive(Debug, Clone)]
119#[non_exhaustive]
120pub enum AuthApi {
121/// The native Matrix authentication API.
122Matrix(MatrixAuth),
123124/// The OAuth 2.0 API.
125OAuth(OAuth),
126}
127128/// A user session using one of the available authentication APIs.
129#[derive(Debug, Clone)]
130#[non_exhaustive]
131pub enum AuthSession {
132/// A session using the native Matrix authentication API.
133Matrix(matrix::MatrixSession),
134135/// A session using the OAuth 2.0 API.
136OAuth(Box<oauth::OAuthSession>),
137}
138139impl AuthSession {
140/// Get the matrix user information of this session.
141pub fn meta(&self) -> &SessionMeta {
142match self {
143 AuthSession::Matrix(session) => &session.meta,
144 AuthSession::OAuth(session) => &session.user.meta,
145 }
146 }
147148/// Take the matrix user information of this session.
149pub fn into_meta(self) -> SessionMeta {
150match self {
151 AuthSession::Matrix(session) => session.meta,
152 AuthSession::OAuth(session) => session.user.meta,
153 }
154 }
155156/// Get the access token of this session.
157pub fn access_token(&self) -> &str {
158match self {
159 AuthSession::Matrix(session) => &session.tokens.access_token,
160 AuthSession::OAuth(session) => &session.user.tokens.access_token,
161 }
162 }
163164/// Get the refresh token of this session.
165pub fn get_refresh_token(&self) -> Option<&str> {
166match self {
167 AuthSession::Matrix(session) => session.tokens.refresh_token.as_deref(),
168 AuthSession::OAuth(session) => session.user.tokens.refresh_token.as_deref(),
169 }
170 }
171}
172173impl From<matrix::MatrixSession> for AuthSession {
174fn from(session: matrix::MatrixSession) -> Self {
175Self::Matrix(session)
176 }
177}
178179impl From<oauth::OAuthSession> for AuthSession {
180fn from(session: oauth::OAuthSession) -> Self {
181Self::OAuth(session.into())
182 }
183}
184185/// Data for an authentication API.
186#[derive(Debug)]
187pub(crate) enum AuthData {
188/// Data for the native Matrix authentication API.
189Matrix,
190/// Data for the OAuth 2.0 API.
191OAuth(OAuthAuthData),
192}