use axum::{extract::State, response::IntoResponse, Json};
use axum_extra::typed_header::TypedHeader;
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_data_model::{
CompatSession, CompatSsoLoginState, Device, SiteConfig, TokenType, User, UserAgent,
};
use mas_matrix::BoxHomeserverConnection;
use mas_storage::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
user::{UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use thiserror::Error;
use zeroize::Zeroizing;
use super::MatrixError;
use crate::{
impl_from_error_for_route, passwords::PasswordManager, rate_limit::PasswordCheckLimitedError,
BoundActivityTracker, Limiter, RequesterFingerprint,
};
/// One login flow advertised by `GET /_matrix/client/v3/login`.
///
/// Serialized as an internally-tagged object: the Matrix login type identifier goes in the
/// `type` field, and any variant fields sit next to it.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum LoginType {
    /// `m.login.password`: username and password.
    #[serde(rename = "m.login.password")]
    Password,

    /// `m.login.sso`: browser-based single sign-on.
    #[serde(rename = "m.login.sso")]
    Sso {
        /// Identity providers to offer; the key is omitted from the JSON when the list is empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        identity_providers: Vec<SsoIdentityProvider>,

        /// Emitted under the `org.matrix.msc3824.delegated_oidc_compatibility` key.
        #[serde(rename = "org.matrix.msc3824.delegated_oidc_compatibility")]
        delegated_oidc_compatibility: bool,
    },

    /// `m.login.token`: exchange of a login token.
    #[serde(rename = "m.login.token")]
    Token,
}
/// An identity provider entry that can be listed inside an `m.login.sso` flow.
///
/// NOTE(review): `get` currently always sends an empty list, so this type is never
/// serialized in practice from this module.
#[derive(Debug, Serialize)]
struct SsoIdentityProvider {
    /// Provider identifier.
    id: &'static str,
    /// Provider name.
    name: &'static str,
}
/// Response body of `GET /_matrix/client/v3/login`.
#[derive(Debug, Serialize)]
struct LoginTypes {
    /// Supported login flows, serialized under the `flows` key.
    flows: Vec<LoginType>,
}
/// Handles `GET /_matrix/client/v3/login` by listing the login flows this server accepts.
///
/// SSO and token logins are always advertised. The password flow is listed first, and
/// only when the password manager is enabled.
#[tracing::instrument(name = "handlers.compat.login.get", skip_all)]
pub(crate) async fn get(State(password_manager): State<PasswordManager>) -> impl IntoResponse {
    let mut flows = Vec::with_capacity(3);

    if password_manager.is_enabled() {
        flows.push(LoginType::Password);
    }

    flows.push(LoginType::Sso {
        identity_providers: Vec::new(),
        delegated_oidc_compatibility: true,
    });
    flows.push(LoginType::Token);

    Json(LoginTypes { flows })
}
/// Request body of `POST /_matrix/client/v3/login`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RequestBody {
    /// Login credentials. Flattened, so the credential `type` tag and its fields live at
    /// the top level of the request body.
    #[serde(flatten)]
    credentials: Credentials,

    /// Whether the client asked for a refresh token; `false` when the field is absent.
    #[serde(default)]
    refresh_token: bool,
}
/// Credentials supplied to `POST /_matrix/client/v3/login`, discriminated by the `type` field.
///
/// `Debug` is implemented by hand so that the plaintext password and the login token are
/// never written out if a request body is ever debug-logged.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Credentials {
    /// `m.login.password`: a user identifier plus a plaintext password.
    #[serde(rename = "m.login.password")]
    Password {
        identifier: Identifier,
        password: String,
    },

    /// `m.login.token`: a login token to exchange for an existing compat session.
    #[serde(rename = "m.login.token")]
    Token { token: String },

    /// Any other `type` value; the handler rejects it as an unsupported login.
    #[serde(other)]
    Unsupported,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a fixed placeholder; everything else is shown as derived.
        const REDACTED: &str = "[REDACTED]";
        match self {
            Self::Password { identifier, .. } => f
                .debug_struct("Password")
                .field("identifier", identifier)
                .field("password", &REDACTED)
                .finish(),
            Self::Token { .. } => f.debug_struct("Token").field("token", &REDACTED).finish(),
            Self::Unsupported => f.write_str("Unsupported"),
        }
    }
}
/// User identifier inside `m.login.password` credentials, discriminated by the `type` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Identifier {
    /// `m.id.user`: a username, looked up with `find_by_username` during password login.
    #[serde(rename = "m.id.user")]
    User { user: String },
    /// Any other identifier `type`; `post` answers such logins with `RouteError::Unsupported`.
    #[serde(other)]
    Unsupported,
}
/// Successful response body of `POST /_matrix/client/v3/login`.
///
/// Fields that are `None` are left out of the serialized JSON.
#[skip_serializing_none]
#[serde_as]
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseBody {
    /// The newly issued compat access token.
    access_token: String,
    /// Device attached to the compat session.
    device_id: Device,
    /// Matrix user ID (as produced by the homeserver connection's `mxid`).
    user_id: String,
    /// Refresh token; `post` only sets it when the client requested one.
    refresh_token: Option<String>,
    /// Access-token lifetime, serialized as an integer number of milliseconds; `post` only
    /// sets it when the client requested a refresh token.
    #[serde_as(as = "Option<DurationMilliSeconds<i64>>")]
    expires_in_ms: Option<Duration>,
}
/// Everything that can make `POST /_matrix/client/v3/login` fail.
///
/// The `IntoResponse` implementation turns each variant into a Matrix error response.
#[derive(Debug, Error)]
pub enum RouteError {
    // ----- Server-side failures -----
    /// Any unexpected error, e.g. from the storage layer.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// A fulfilled SSO login points at a session that could not be loaded.
    #[error("session not found")]
    SessionNotFound,

    /// The homeserver refused to create the device for a new session.
    #[error("failed to provision device")]
    ProvisionDeviceFailed(#[source] anyhow::Error),

    // ----- Request-level rejections -----
    /// The credential type (or identifier type) is not accepted.
    #[error("unsupported login method")]
    Unsupported,

    /// Too many password attempts.
    #[error("request rate limited")]
    RateLimited(#[from] PasswordCheckLimitedError),

    // ----- Credential failures -----
    /// No valid user matched the login.
    #[error("user not found")]
    UserNotFound,

    /// The user has no active password.
    #[error("user has no password")]
    NoPassword,

    /// The supplied password did not verify.
    #[error("password verification failed")]
    PasswordVerificationFailed(#[source] anyhow::Error),

    /// The login token was presented too late after the SSO login was fulfilled.
    #[error("login took too long")]
    LoginTookTooLong,

    /// The login token is unknown, not yet usable, or already used.
    #[error("invalid login token")]
    InvalidLoginToken,
}

impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
    /// Reports the error to Sentry, then renders it as a Matrix error tagged with the Sentry
    /// event ID.
    fn into_response(self) -> axum::response::Response {
        let event_id = sentry::capture_error(&self);

        let (errcode, error, status) = match &self {
            Self::Internal(_) | Self::SessionNotFound | Self::ProvisionDeviceFailed(_) => (
                "M_UNKNOWN",
                "Internal server error",
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
            Self::RateLimited(_) => (
                "M_LIMIT_EXCEEDED",
                "Too many login attempts",
                StatusCode::TOO_MANY_REQUESTS,
            ),
            Self::Unsupported => (
                "M_UNRECOGNIZED",
                "Invalid login type",
                StatusCode::BAD_REQUEST,
            ),
            // All three share one response, so the body does not reveal whether the
            // user exists or has a password.
            Self::UserNotFound | Self::NoPassword | Self::PasswordVerificationFailed(_) => (
                "M_FORBIDDEN",
                "Invalid username/password",
                StatusCode::FORBIDDEN,
            ),
            Self::LoginTookTooLong => (
                "M_FORBIDDEN",
                "Login token expired",
                StatusCode::FORBIDDEN,
            ),
            Self::InvalidLoginToken => (
                "M_FORBIDDEN",
                "Invalid login token",
                StatusCode::FORBIDDEN,
            ),
        };

        let matrix_error = MatrixError {
            errcode,
            error,
            status,
        };
        (SentryEventID::from(event_id), matrix_error).into_response()
    }
}
#[tracing::instrument(name = "handlers.compat.login.post", skip_all, err)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
State(homeserver): State<BoxHomeserverConnection>,
State(site_config): State<SiteConfig>,
State(limiter): State<Limiter>,
requester: RequesterFingerprint,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
let (mut session, user) = match (password_manager.is_enabled(), input.credentials) {
(
true,
Credentials::Password {
identifier: Identifier::User { user },
password,
},
) => {
user_password_login(
&mut rng,
&clock,
&password_manager,
&limiter,
requester,
&mut repo,
&homeserver,
user,
password,
)
.await?
}
(_, Credentials::Token { token }) => token_login(&mut repo, &clock, &token).await?,
_ => {
return Err(RouteError::Unsupported);
}
};
if let Some(user_agent) = user_agent {
session = repo
.compat_session()
.record_user_agent(session, user_agent)
.await?;
}
let user_id = homeserver.mxid(&user.username);
let expires_in = if input.refresh_token {
Some(site_config.compat_token_ttl)
} else {
None
};
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, access_token, expires_in)
.await?;
let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = repo
.compat_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token)
.await?;
Some(refresh_token.token)
} else {
None
};
repo.save().await?;
activity_tracker
.record_compat_session(&clock, &session)
.await;
Ok(Json(ResponseBody {
access_token: access_token.token,
device_id: session.device,
user_id,
refresh_token,
expires_in_ms: expires_in,
}))
}
/// Exchanges an `m.login.token` login token for the compat session it was fulfilled with.
///
/// The token must belong to an SSO login in the `Fulfilled` state and be presented no more
/// than 30 seconds after fulfilment. On success the login is marked as exchanged; any later
/// attempt with the same token hits the `Exchanged` branch and is rejected.
///
/// # Errors
///
/// - [`RouteError::InvalidLoginToken`] if the token is unknown, the login is still pending,
///   or the token was already exchanged.
/// - [`RouteError::LoginTookTooLong`] if more than 30 seconds passed since fulfilment.
/// - [`RouteError::SessionNotFound`] if the session referenced by the login cannot be found.
/// - [`RouteError::UserNotFound`] if the session's user is missing or not valid.
/// - Repository errors are propagated through `?`.
async fn token_login(
    repo: &mut BoxRepository,
    clock: &dyn Clock,
    token: &str,
) -> Result<(CompatSession, User), RouteError> {
    // Window (30 seconds) used both for accepting a fulfilled token and for deciding
    // whether a repeated exchange is worth logging. Keeping one value keeps the two
    // checks in sync.
    let exchange_window = Duration::microseconds(30 * 1000 * 1000);

    let login = repo
        .compat_sso_login()
        .find_by_token(token)
        .await?
        .ok_or(RouteError::InvalidLoginToken)?;

    let now = clock.now();
    let session_id = match login.state {
        CompatSsoLoginState::Pending => {
            tracing::error!(
                compat_sso_login.id = %login.id,
                "Exchanged a token for a login that was not fulfilled yet"
            );
            return Err(RouteError::InvalidLoginToken);
        }
        CompatSsoLoginState::Fulfilled {
            fulfilled_at,
            session_id,
            ..
        } => {
            if now > fulfilled_at + exchange_window {
                return Err(RouteError::LoginTookTooLong);
            }
            session_id
        }
        CompatSsoLoginState::Exchanged {
            exchanged_at,
            session_id,
            ..
        } => {
            // Tokens are single-use. A replay shortly after the first exchange is rejected
            // without logging; a replay after the window is also logged.
            if now > exchanged_at + exchange_window {
                tracing::error!(
                    compat_sso_login.id = %login.id,
                    compat_session.id = %session_id,
                    "Login token exchanged a second time more than 30s after"
                );
            }
            return Err(RouteError::InvalidLoginToken);
        }
    };

    let session = repo
        .compat_session()
        .lookup(session_id)
        .await?
        .ok_or(RouteError::SessionNotFound)?;

    let user = repo
        .user()
        .lookup(session.user_id)
        .await?
        .filter(mas_data_model::User::is_valid)
        .ok_or(RouteError::UserNotFound)?;

    repo.compat_sso_login().exchange(clock, login).await?;

    Ok((session, user))
}
/// Authenticates `username` with `password` and opens a new compat session for them.
///
/// The steps run in this order:
/// 1. Look up a valid user.
/// 2. Apply the password rate limit for this requester and user.
/// 3. Verify the password against the user's active password hash, storing a new hash
///    when the password manager returns one.
/// 4. Take the user's sync lock.
/// 5. Create a freshly generated device on the homeserver.
/// 6. Record the compat session.
///
/// # Errors
///
/// - [`RouteError::UserNotFound`] if no valid user has this username.
/// - [`RouteError::RateLimited`] if the rate limiter rejects the attempt.
/// - [`RouteError::NoPassword`] if the user has no active password.
/// - [`RouteError::PasswordVerificationFailed`] if the password does not verify.
/// - [`RouteError::ProvisionDeviceFailed`] if the homeserver rejects the new device.
/// - Repository errors are propagated through `?`.
async fn user_password_login(
    mut rng: &mut (impl RngCore + CryptoRng + Send),
    clock: &impl Clock,
    password_manager: &PasswordManager,
    limiter: &Limiter,
    requester: RequesterFingerprint,
    repo: &mut BoxRepository,
    homeserver: &BoxHomeserverConnection,
    username: String,
    password: String,
) -> Result<(CompatSession, User), RouteError> {
    let found = repo.user().find_by_username(&username).await?;
    let user = found
        .filter(mas_data_model::User::is_valid)
        .ok_or(RouteError::UserNotFound)?;

    limiter.check_password(requester, &user)?;

    let active_password = repo
        .user_password()
        .active(&user)
        .await?
        .ok_or(RouteError::NoPassword)?;

    // The plaintext is moved into a `Zeroizing` buffer before being handed over.
    let upgraded = password_manager
        .verify_and_upgrade(
            &mut rng,
            active_password.version,
            Zeroizing::new(password.into_bytes()),
            active_password.hashed_password.clone(),
        )
        .await
        .map_err(RouteError::PasswordVerificationFailed)?;

    if let Some((new_version, new_hash)) = upgraded {
        // Store the re-hashed password, referencing the entry it replaces.
        repo.user_password()
            .add(
                &mut rng,
                clock,
                &user,
                new_version,
                new_hash,
                Some(&active_password),
            )
            .await?;
    }

    repo.user().acquire_lock_for_sync(&user).await?;

    let new_device = Device::generate(&mut rng);
    homeserver
        .create_device(&homeserver.mxid(&user.username), new_device.as_str())
        .await
        .map_err(RouteError::ProvisionDeviceFailed)?;

    let session = repo
        .compat_session()
        .add(&mut rng, clock, &user, new_device, None, false)
        .await?;

    Ok((session, user))
}
// Integration tests for the compat login endpoints, run against a fresh database
// (`sqlx::test`) through the full router (`TestState::request`).
#[cfg(test)]
mod tests {
    use hyper::Request;
    use mas_matrix::{HomeserverConnection, ProvisionRequest};
    use rand::distributions::{Alphanumeric, DistString};
    use sqlx::PgPool;

    use super::*;
    use crate::test_utils::{setup, RequestBuilderExt, ResponseExt, TestState};

    /// With the default test state, `GET /login` advertises password, SSO and token flows.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_get_login(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let request = Request::get("/_matrix/client/v3/login").empty();
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let body: serde_json::Value = response.json();
        // `identity_providers` is empty, so the SSO flow carries no such key.
        assert_eq!(
            body,
            serde_json::json!({
                "flows": [
                    {
                        "type": "m.login.password",
                    },
                    {
                        "type": "m.login.sso",
                        "org.matrix.msc3824.delegated_oidc_compatibility": true,
                    },
                    {
                        "type": "m.login.token",
                    }
                ],
            })
        );
    }

    /// With the password manager disabled, the password flow is not advertised and a
    /// password login is rejected with 400 `M_UNRECOGNIZED`.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_password_disabled(pool: PgPool) {
        setup();
        let state = {
            let mut state = TestState::from_pool(pool).await.unwrap();
            state.password_manager = PasswordManager::disabled();
            state
        };
        let request = Request::get("/_matrix/client/v3/login").empty();
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let body: serde_json::Value = response.json();
        assert_eq!(
            body,
            serde_json::json!({
                "flows": [
                    {
                        "type": "m.login.sso",
                        "org.matrix.msc3824.delegated_oidc_compatibility": true,
                    },
                    {
                        "type": "m.login.token",
                    }
                ],
            })
        );
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.password",
            "identifier": {
                "type": "m.id.user",
                "user": "alice",
            },
            "password": "password",
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_UNRECOGNIZED");
    }

    /// Password login end to end:
    /// - a login without `refresh_token` gets no refresh token and no expiry;
    /// - a login with `refresh_token` gets both;
    /// - a wrong password gets 403;
    /// - an unknown user gets exactly the same error body as a wrong password.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_user_password_login(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        // Create and provision "alice" with the password "password".
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let mxid = state.homeserver_connection.mxid(&user.username);
        state
            .homeserver_connection
            .provision_user(&ProvisionRequest::new(mxid, &user.sub))
            .await
            .unwrap();
        let (version, hashed_password) = state
            .password_manager
            .hash(
                &mut state.rng(),
                Zeroizing::new("password".to_owned().into_bytes()),
            )
            .await
            .unwrap();
        repo.user_password()
            .add(
                &mut state.rng(),
                &state.clock,
                &user,
                version,
                hashed_password,
                None,
            )
            .await
            .unwrap();
        repo.save().await.unwrap();
        // Plain login: no refresh token requested.
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.password",
            "identifier": {
                "type": "m.id.user",
                "user": "alice",
            },
            "password": "password",
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let body: ResponseBody = response.json();
        assert!(!body.access_token.is_empty());
        assert_eq!(body.device_id.as_str().len(), 10);
        assert_eq!(body.user_id, "@alice:example.com");
        assert_eq!(body.refresh_token, None);
        assert_eq!(body.expires_in_ms, None);
        // Login requesting a refresh token: both refresh token and expiry are returned.
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.password",
            "identifier": {
                "type": "m.id.user",
                "user": "alice",
            },
            "password": "password",
            "refresh_token": true,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let body: ResponseBody = response.json();
        assert!(!body.access_token.is_empty());
        assert_eq!(body.device_id.as_str().len(), 10);
        assert_eq!(body.user_id, "@alice:example.com");
        assert!(body.refresh_token.is_some());
        assert!(body.expires_in_ms.is_some());
        // Wrong password for an existing user.
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.password",
            "identifier": {
                "type": "m.id.user",
                "user": "alice",
            },
            "password": "wrongpassword",
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_FORBIDDEN");
        // Unknown user: the body must equal the wrong-password body, so the response
        // does not reveal whether the account exists.
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.password",
            "identifier": {
                "type": "m.id.user",
                "user": "bob",
            },
            "password": "wrongpassword",
        }));
        let old_body = body;
        let response = state.request(request).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let body: serde_json::Value = response.json();
        assert_eq!(body, old_body);
    }

    /// Repeated password attempts against the same user end up rate limited.
    ///
    /// The user has no password, so the first three attempts fail with 403 (`NoPassword`)
    /// and the fourth is rejected with 429 `M_LIMIT_EXCEEDED`.
    /// NOTE(review): the threshold of three comes from the limiter configuration used by
    /// `TestState`; confirm if that configuration changes.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_password_login_rate_limit(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let mxid = state.homeserver_connection.mxid(&user.username);
        state
            .homeserver_connection
            .provision_user(&ProvisionRequest::new(mxid, &user.sub))
            .await
            .unwrap();
        repo.save().await.unwrap();
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.password",
            "identifier": {
                "type": "m.id.user",
                "user": "alice",
            },
            "password": "password",
        }));
        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::TOO_MANY_REQUESTS);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_LIMIT_EXCEEDED");
        assert_eq!(body["error"], "Too many login attempts");
    }

    /// An unknown login `type` is rejected with 400 `M_UNRECOGNIZED`.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_unsupported_login(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.unsupported",
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_UNRECOGNIZED");
    }

    /// Token login end to end:
    /// - an unknown token gets 403;
    /// - a fresh fulfilled token logs in on the pre-created device;
    /// - reusing the same token gets 403;
    /// - a token presented 60 seconds after fulfilment (past the 30-second window) gets 403.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_login_token_login(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();
        repo.save().await.unwrap();
        let mxid = state.homeserver_connection.mxid(&user.username);
        state
            .homeserver_connection
            .provision_user(&ProvisionRequest::new(mxid, &user.sub))
            .await
            .unwrap();
        // Unknown token.
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.token",
            "token": "someinvalidtoken",
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_FORBIDDEN");
        // Valid, freshly fulfilled token.
        let (device, token) = get_login_token(&state, &user).await;
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.token",
            "token": token,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let body: ResponseBody = response.json();
        assert!(!body.access_token.is_empty());
        assert_eq!(body.device_id, device);
        assert_eq!(body.user_id, "@alice:example.com");
        assert_eq!(body.refresh_token, None);
        assert_eq!(body.expires_in_ms, None);
        // The same token cannot be exchanged twice.
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.token",
            "token": token,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_FORBIDDEN");
        // A new token, presented 60 seconds after it was fulfilled.
        let (_device, token) = get_login_token(&state, &user).await;
        state
            .clock
            .advance(Duration::microseconds(60 * 1000 * 1000));
        let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
            "type": "m.login.token",
            "token": token,
        }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::FORBIDDEN);
        let body: serde_json::Value = response.json();
        assert_eq!(body["errcode"], "M_FORBIDDEN");
    }

    /// Creates an SSO login with a random 32-character token and a compat session on a new
    /// device for `user`, then fulfills the login with that session.
    /// Returns the device and the token.
    async fn get_login_token(state: &TestState, user: &User) -> (Device, String) {
        let mut repo = state.repository().await.unwrap();
        let token = Alphanumeric.sample_string(&mut state.rng(), 32);
        let device = Device::generate(&mut state.rng());
        let login = repo
            .compat_sso_login()
            .add(
                &mut state.rng(),
                &state.clock,
                token.clone(),
                "http://example.com/".parse().unwrap(),
            )
            .await
            .unwrap();
        let compat_session = repo
            .compat_session()
            .add(
                &mut state.rng(),
                &state.clock,
                user,
                device.clone(),
                None,
                false,
            )
            .await
            .unwrap();
        repo.compat_sso_login()
            .fulfill(&state.clock, login, &compat_session)
            .await
            .unwrap();
        repo.save().await.unwrap();
        (device, token)
    }
}