matrix_sdk/authentication/
mod.rs1use std::{fmt, sync::Arc};
18
19use matrix_sdk_base::{locks::Mutex, SessionMeta};
20use serde::{Deserialize, Serialize};
21use tokio::sync::{broadcast, Mutex as AsyncMutex, OnceCell};
22
23pub mod matrix;
24pub mod oauth;
25
26use self::{
27 matrix::MatrixAuth,
28 oauth::{OAuth, OAuthAuthData, OAuthCtx},
29};
30use crate::{Client, RefreshTokenError, SessionChange};
31
32#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
34#[allow(missing_debug_implementations)]
35pub struct SessionTokens {
36 pub access_token: String,
38
39 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub refresh_token: Option<String>,
42}
43
44#[cfg(not(tarpaulin_include))]
45impl fmt::Debug for SessionTokens {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_struct("SessionTokens").finish_non_exhaustive()
48 }
49}
50
51pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
52
53#[cfg(not(target_family = "wasm"))]
54pub(crate) type SaveSessionCallback =
55 dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync;
56#[cfg(target_family = "wasm")]
57pub(crate) type SaveSessionCallback = dyn Fn(Client) -> Result<(), SessionCallbackError>;
58
59#[cfg(not(target_family = "wasm"))]
60pub(crate) type ReloadSessionCallback =
61 dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError> + Send + Sync;
62#[cfg(target_family = "wasm")]
63pub(crate) type ReloadSessionCallback =
64 dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError>;
65
/// Authentication state: per-API context and data, the current session
/// tokens, token-refresh coordination, and the session callbacks.
///
/// NOTE(review): presumably one instance per [`Client`] — confirm against the
/// construction site.
pub(crate) struct AuthCtx {
    /// Context for the OAuth API (see [`OAuthCtx`]).
    pub(crate) oauth: OAuthCtx,

    /// Whether refresh tokens should be handled.
    // NOTE(review): the exact behavior this flag toggles lives in the callers
    // — confirm there.
    pub(crate) handle_refresh_tokens: bool,

    /// Shared async lock around a `Result<(), RefreshTokenError>`.
    // NOTE(review): presumably held while a token refresh is in flight, with
    // the guarded value recording the outcome of the last refresh — confirm.
    pub(crate) refresh_token_lock: Arc<AsyncMutex<Result<(), RefreshTokenError>>>,

    /// Sending half of a broadcast channel carrying [`SessionChange`]
    /// notifications.
    pub(crate) session_change_sender: broadcast::Sender<SessionChange>,

    /// Per-API authentication data; the cell is empty until it is set.
    pub(crate) auth_data: OnceCell<AuthData>,

    /// The current session tokens. The cell is initialized by the first
    /// `set_session_tokens` call; later calls overwrite the value behind the
    /// inner mutex.
    pub(crate) tokens: OnceCell<Mutex<SessionTokens>>,

    /// Callback returning [`SessionTokens`] for the client; the cell is
    /// empty when no callback has been set.
    pub(crate) reload_session_callback: OnceCell<Box<ReloadSessionCallback>>,

    /// Callback invoked with the client to save its session; the cell is
    /// empty when no callback has been set.
    pub(crate) save_session_callback: OnceCell<Box<SaveSessionCallback>>,
}
104
105impl AuthCtx {
106 pub(crate) fn session_tokens(&self) -> Option<SessionTokens> {
108 Some(self.tokens.get()?.lock().clone())
109 }
110
111 pub(crate) fn access_token(&self) -> Option<String> {
113 Some(self.tokens.get()?.lock().access_token.clone())
114 }
115
116 pub(crate) fn set_session_tokens(&self, session_tokens: SessionTokens) {
118 if let Some(tokens) = self.tokens.get() {
119 *tokens.lock() = session_tokens;
120 } else {
121 let _ = self.tokens.set(Mutex::new(session_tokens));
122 }
123 }
124}
125
/// A handle to one of the authentication APIs in this module.
///
/// This enum is `#[non_exhaustive]`: matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AuthApi {
    /// The API of the `matrix` submodule, see [`MatrixAuth`].
    Matrix(MatrixAuth),

    /// The API of the `oauth` submodule, see [`OAuth`].
    OAuth(OAuth),
}
136
/// A user session belonging to one of the authentication APIs in this module.
///
/// This enum is `#[non_exhaustive]`: matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AuthSession {
    /// A session of the `matrix` authentication API.
    Matrix(matrix::MatrixSession),

    /// A session of the OAuth authentication API.
    // NOTE(review): presumably boxed to keep `AuthSession` small (clippy
    // `large_enum_variant`) — confirm.
    OAuth(Box<oauth::OAuthSession>),
}
147
148impl AuthSession {
149 pub fn meta(&self) -> &SessionMeta {
151 match self {
152 AuthSession::Matrix(session) => &session.meta,
153 AuthSession::OAuth(session) => &session.user.meta,
154 }
155 }
156
157 pub fn into_meta(self) -> SessionMeta {
159 match self {
160 AuthSession::Matrix(session) => session.meta,
161 AuthSession::OAuth(session) => session.user.meta,
162 }
163 }
164
165 pub fn access_token(&self) -> &str {
167 match self {
168 AuthSession::Matrix(session) => &session.tokens.access_token,
169 AuthSession::OAuth(session) => &session.user.tokens.access_token,
170 }
171 }
172
173 pub fn get_refresh_token(&self) -> Option<&str> {
175 match self {
176 AuthSession::Matrix(session) => session.tokens.refresh_token.as_deref(),
177 AuthSession::OAuth(session) => session.user.tokens.refresh_token.as_deref(),
178 }
179 }
180}
181
182impl From<matrix::MatrixSession> for AuthSession {
183 fn from(session: matrix::MatrixSession) -> Self {
184 Self::Matrix(session)
185 }
186}
187
188impl From<oauth::OAuthSession> for AuthSession {
189 fn from(session: oauth::OAuthSession) -> Self {
190 Self::OAuth(session.into())
191 }
192}
193
/// Per-API authentication data, as held by `AuthCtx::auth_data`.
#[derive(Debug)]
pub(crate) enum AuthData {
    /// The `matrix` API; this variant carries no additional data.
    Matrix,
    /// The OAuth API, together with its [`OAuthAuthData`].
    OAuth(OAuthAuthData),
}