use chrono::{DateTime, Duration, Utc};
use mas_axum_utils::cookies::CookieJar;
use mas_router::PostAuthAction;
use mas_storage::Clock;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use ulid::Ulid;
/// Name of the cookie in which the upstream sessions are stored.
static COOKIE_NAME: &str = "upstream-oauth2-sessions";
/// Maximum age of a session before it is considered expired: 10 minutes,
/// expressed in microseconds.
static SESSION_MAX_TIME: Duration = Duration::microseconds(10 * 60 * 1000 * 1000);
/// A single entry of the upstream sessions cookie.
#[derive(Serialize, Deserialize, Debug)]
pub struct Payload {
/// ID of the session; its embedded timestamp is used to decide expiry.
session: Ulid,
/// ID of the provider the session was started with.
provider: Ulid,
/// State value the session is matched against when it is looked up.
state: String,
/// ID of the link attached to the session, `None` until one is added.
link: Option<Ulid>,
/// Action stored alongside the session and handed back on lookup, if any.
post_auth_action: Option<PostAuthAction>,
}
impl Payload {
    /// Tell whether the session is older than `SESSION_MAX_TIME` at `now`.
    ///
    /// The creation time is read from the timestamp embedded in the session
    /// ID. A session whose timestamp cannot be converted into a valid date is
    /// reported as expired.
    fn expired(&self, now: DateTime<Utc>) -> bool {
        let created_at = i64::try_from(self.session.timestamp_ms())
            .ok()
            .and_then(DateTime::<Utc>::from_timestamp_millis);

        match created_at {
            Some(created_at) => now - created_at > SESSION_MAX_TIME,
            // No usable creation time: err on the side of dropping the session
            None => true,
        }
    }
}
/// The list of upstream session payloads, saved together in a single cookie.
#[derive(Serialize, Deserialize, Default, Debug)]
pub struct UpstreamSessions(Vec<Payload>);
/// Error returned when no stored session matches the requested criteria.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("upstream session not found")]
pub struct UpstreamSessionNotFound;
impl UpstreamSessions {
    /// Load the upstream sessions from the cookie jar.
    ///
    /// Returns an empty list when the cookie is missing. When the cookie
    /// cannot be loaded, a warning is logged and an empty list is returned
    /// instead of failing.
    pub fn load(cookie_jar: &CookieJar) -> Self {
        match cookie_jar.load(COOKIE_NAME) {
            Ok(Some(sessions)) => sessions,
            Ok(None) => Self::default(),
            Err(e) => {
                tracing::warn!("Invalid upstream sessions cookie: {}", e);
                Self::default()
            }
        }
    }

    /// Save the upstream sessions to the cookie jar.
    ///
    /// Sessions which are expired according to the current time of `clock`
    /// are dropped before saving.
    pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar
    where
        C: Clock,
    {
        let this = self.expire(clock.now());
        cookie_jar.save(COOKIE_NAME, &this, false)
    }

    /// Drop every session which is expired at `now`.
    fn expire(mut self, now: DateTime<Utc>) -> Self {
        self.0.retain(|p| !p.expired(now));
        self
    }

    /// Add a new session, with no link attached yet.
    pub fn add(
        mut self,
        session: Ulid,
        provider: Ulid,
        state: String,
        post_auth_action: Option<PostAuthAction>,
    ) -> Self {
        self.0.push(Payload {
            session,
            provider,
            state,
            link: None,
            post_auth_action,
        });
        self
    }

    /// Find the session ID and the post auth action matching a provider and
    /// a state.
    ///
    /// Only sessions which have no link attached are considered.
    ///
    /// # Errors
    ///
    /// Returns [`UpstreamSessionNotFound`] if no such session is stored.
    pub fn find_session(
        &self,
        provider: Ulid,
        state: &str,
    ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
        self.0
            .iter()
            .find(|p| p.provider == provider && p.state == state && p.link.is_none())
            .map(|p| (p.session, p.post_auth_action.as_ref()))
            .ok_or(UpstreamSessionNotFound)
    }

    /// Attach a link to a session which has none yet.
    ///
    /// # Errors
    ///
    /// Returns [`UpstreamSessionNotFound`] if there is no session with this
    /// ID, or if it already has a link attached.
    pub fn add_link_to_session(
        mut self,
        session: Ulid,
        link: Ulid,
    ) -> Result<Self, UpstreamSessionNotFound> {
        let payload = self
            .0
            .iter_mut()
            .find(|p| p.session == session && p.link.is_none())
            .ok_or(UpstreamSessionNotFound)?;

        payload.link = Some(link);
        Ok(self)
    }

    /// Find the session ID and the post auth action attached to a link.
    ///
    /// Several sessions may carry the same link, in which case the one with
    /// the highest session ID is returned.
    ///
    /// # Errors
    ///
    /// Returns [`UpstreamSessionNotFound`] if no session carries this link.
    pub fn lookup_link(
        &self,
        link_id: Ulid,
    ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
        self.0
            .iter()
            .filter(|p| p.link == Some(link_id))
            // Keep the session with the highest ID; ULIDs order by their
            // timestamp first, so this is the most recent one
            .reduce(|a, b| if a.session > b.session { a } else { b })
            .map(|p| (p.session, p.post_auth_action.as_ref()))
            .ok_or(UpstreamSessionNotFound)
    }

    /// Mark a link as consumed, so that it cannot be looked up again.
    ///
    /// Every session carrying the link is removed, not just the first one:
    /// leaving any of them behind would keep the link usable through
    /// [`Self::lookup_link`] after it was consumed.
    ///
    /// # Errors
    ///
    /// Returns [`UpstreamSessionNotFound`] if no session carries this link.
    pub fn consume_link(mut self, link_id: Ulid) -> Result<Self, UpstreamSessionNotFound> {
        let count_before = self.0.len();
        self.0.retain(|p| p.link != Some(link_id));

        // Nothing was removed, meaning the link was not there to begin with
        if self.0.len() == count_before {
            return Err(UpstreamSessionNotFound);
        }

        Ok(self)
    }
}
#[cfg(test)]
mod tests {
    use chrono::TimeZone;
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;

    use super::*;

    /// Walks two sessions through creation, lookup, expiry, linking and link
    /// consumption.
    #[test]
    fn test_session_cookie() {
        let t0 = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();

        // Deterministic ULIDs: the seed is fixed and IDs are drawn in a fixed order
        let mut rng = ChaChaRng::seed_from_u64(42);
        let mut ulid_at = |at: DateTime<Utc>| Ulid::from_datetime_with_source(at.into(), &mut rng);

        let provider_a = ulid_at(t0);
        let provider_b = ulid_at(t0);

        // First session, started at t0 with provider A
        let first_session = ulid_at(t0);
        let first_state = "first-state";
        let cookie =
            UpstreamSessions::default().add(first_session, provider_a, first_state.into(), None);

        // Second session, started five minutes later with provider B
        let t5 = t0 + Duration::microseconds(5 * 60 * 1000 * 1000);
        let second_session = ulid_at(t5);
        let second_state = "second-state";
        let cookie = cookie.add(second_session, provider_b, second_state.into(), None);

        // Five minutes in, nothing has expired yet and both sessions are found
        let cookie = cookie.expire(t5);
        let (found, _) = cookie.find_session(provider_a, first_state).unwrap();
        assert_eq!(found, first_session);
        let (found, _) = cookie.find_session(provider_b, second_state).unwrap();
        assert_eq!(found, second_session);

        // A provider and a state from different sessions must not match
        assert!(cookie.find_session(provider_b, first_state).is_err());
        assert!(cookie.find_session(provider_a, second_state).is_err());

        // Eleven minutes in, the first session is gone but not the second one
        let t11 = t5 + Duration::microseconds(6 * 60 * 1000 * 1000);
        let cookie = cookie.expire(t11);
        assert!(cookie.find_session(provider_a, first_state).is_err());
        let (found, _) = cookie.find_session(provider_b, second_state).unwrap();
        assert_eq!(found, second_session);

        // Once linked, the session is only reachable through its link
        let second_link = ulid_at(t11);
        let cookie = cookie
            .add_link_to_session(second_session, second_link)
            .unwrap();
        assert!(cookie.find_session(provider_b, second_state).is_err());
        let (found, _) = cookie.lookup_link(second_link).unwrap();
        assert_eq!(found, second_session);

        // A link can be consumed only once
        let cookie = cookie.consume_link(second_link).unwrap();
        assert!(cookie.consume_link(second_link).is_err());
    }
}