1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::convert::Infallible;

use aide::OperationIo;
use axum::{
    extract::FromRequestParts,
    response::{IntoResponse, Response},
    Json,
};
use axum_extra::TypedHeader;
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::{Session, User};
use mas_storage::{BoxClock, BoxRepository, RepositoryError};
use ulid::Ulid;

use super::response::ErrorResponse;
use crate::BoundActivityTracker;

/// Reasons for which extracting a `CallContext` from a request can fail.
///
/// The `Display` text of each variant comes from its `#[error]` attribute.
/// The HTTP status code used when a rejection is turned into a response is
/// decided by `Rejection::status_code`.
#[derive(Debug, thiserror::Error)]
pub enum Rejection {
    /// The authorization header is missing
    ///
    /// Produced when the typed `Authorization` header extractor reports that
    /// the header is missing.
    #[error("Missing authorization header")]
    MissingAuthorizationHeader,

    /// The authorization header is invalid
    ///
    /// Produced when extracting the bearer `Authorization` header fails for
    /// any reason other than the header being missing.
    #[error("Invalid authorization header")]
    InvalidAuthorizationHeader,

    /// Couldn't load the database repository
    ///
    /// Wraps the rejection returned by the `BoxRepository` extractor, which is
    /// exposed as the error source.
    #[error("Couldn't load the database repository")]
    RepositorySetup(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// A database operation failed
    ///
    /// Converted automatically from `RepositoryError` (via `#[from]`), so
    /// repository calls can be propagated with `?`.
    #[error("Invalid repository operation")]
    Repository(#[from] RepositoryError),

    /// The access token could not be found in the database
    #[error("Unknown access token")]
    UnknownAccessToken,

    /// The access token provided expired
    ///
    /// Produced when the token's `is_valid` check fails against the current
    /// time given by the clock.
    #[error("Access token expired")]
    TokenExpired,

    /// The session associated with the access token was revoked
    ///
    /// Produced when the session's `is_valid` check fails.
    #[error("Access token revoked")]
    SessionRevoked,

    /// The user associated with the session is locked
    ///
    /// Produced when the session has a user and that user's `is_valid` check
    /// fails.
    #[error("User locked")]
    UserLocked,

    /// Failed to load the session
    ///
    /// The access token references a session ID (carried in this variant) for
    /// which the session lookup returned nothing.
    #[error("Failed to load session {0}")]
    LoadSession(Ulid),

    /// Failed to load the user
    ///
    /// The session references a user ID (carried in this variant) for which
    /// the user lookup returned nothing.
    #[error("Failed to load user {0}")]
    LoadUser(Ulid),

    /// The session does not have the `urn:mas:admin` scope
    #[error("Missing urn:mas:admin scope")]
    MissingScope,
}

impl Rejection {
    /// Returns the HTTP status code to use when responding with this
    /// rejection.
    ///
    /// The match is written out exhaustively on purpose: adding a new variant
    /// to [`Rejection`] will fail to compile here until a status code is
    /// picked for it.
    fn status_code(&self) -> StatusCode {
        match self {
            // The `Authorization` header itself is unusable
            Self::MissingAuthorizationHeader | Self::InvalidAuthorizationHeader => {
                StatusCode::BAD_REQUEST
            }

            // The credentials were read, but do not grant access
            Self::UnknownAccessToken
            | Self::TokenExpired
            | Self::SessionRevoked
            | Self::UserLocked
            | Self::MissingScope => StatusCode::UNAUTHORIZED,

            // Everything else is a failure on our side
            Self::RepositorySetup(_)
            | Self::Repository(_)
            | Self::LoadSession(_)
            | Self::LoadUser(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

impl IntoResponse for Rejection {
    /// Renders the rejection as a JSON error body, with the status code given
    /// by `Rejection::status_code`.
    fn into_response(self) -> Response {
        let status = self.status_code();
        let body = Json(ErrorResponse::from_error(&self));
        (status, body).into_response()
    }
}

/// An extractor which authorizes the request
///
/// Because we need to load the database repository and the clock, we keep them
/// in the context to avoid creating two instances for each request.
///
/// Extraction looks up the bearer access token from the `Authorization`
/// header, loads the associated session (and user, if any), and only succeeds
/// if the session has the `urn:mas:admin` scope.
#[non_exhaustive]
#[derive(OperationIo)]
#[aide(input)]
pub struct CallContext {
    /// The database repository extracted for this request, handed over so the
    /// handler can reuse it
    pub repo: BoxRepository,
    /// The clock extracted for this request; also used during extraction to
    /// check the access token validity
    pub clock: BoxClock,
    /// The user attached to the session, or `None` if the session has no user
    /// ID
    pub user: Option<User>,
    /// The OAuth 2.0 session the access token belongs to
    pub session: Session,
}

#[async_trait::async_trait]
impl<S> FromRequestParts<S> for CallContext
where
    S: Send + Sync,
    BoundActivityTracker: FromRequestParts<S, Rejection = Infallible>,
    BoxRepository: FromRequestParts<S>,
    BoxClock: FromRequestParts<S, Rejection = Infallible>,
    <BoxRepository as FromRequestParts<S>>::Rejection:
        Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
    type Rejection = Rejection;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let activity_tracker = BoundActivityTracker::from_request_parts(parts, state).await;
        let activity_tracker = match activity_tracker {
            Ok(t) => t,
            Err(e) => match e {},
        };

        let clock = BoxClock::from_request_parts(parts, state).await;
        let clock = match clock {
            Ok(c) => c,
            Err(e) => match e {},
        };

        // Load the database repository
        let mut repo = BoxRepository::from_request_parts(parts, state)
            .await
            .map_err(Into::into)
            .map_err(Rejection::RepositorySetup)?;

        // Extract the access token from the authorization header
        let token = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
            .await
            .map_err(|e| {
                // We map to two differentsson of errors depending on whether the header is
                // missing or invalid
                if e.is_missing() {
                    Rejection::MissingAuthorizationHeader
                } else {
                    Rejection::InvalidAuthorizationHeader
                }
            })?;

        let token = token.token();

        // Look for the access token in the database
        let token = repo
            .oauth2_access_token()
            .find_by_token(token)
            .await?
            .ok_or(Rejection::UnknownAccessToken)?;

        // Look for the associated session in the database
        let session = repo
            .oauth2_session()
            .lookup(token.session_id)
            .await?
            .ok_or_else(|| Rejection::LoadSession(token.session_id))?;

        // Record the activity on the session
        activity_tracker
            .record_oauth2_session(&clock, &session)
            .await;

        // Load the user if there is one
        let user = if let Some(user_id) = session.user_id {
            let user = repo
                .user()
                .lookup(user_id)
                .await?
                .ok_or_else(|| Rejection::LoadUser(user_id))?;
            Some(user)
        } else {
            None
        };

        // If there is a user for this session, check that it is not locked
        if let Some(user) = &user {
            if !user.is_valid() {
                return Err(Rejection::UserLocked);
            }
        }

        if !session.is_valid() {
            return Err(Rejection::SessionRevoked);
        }

        if !token.is_valid(clock.now()) {
            return Err(Rejection::TokenExpired);
        }

        // For now, we only check that the session has the admin scope
        // Later we might want to check other route-specific scopes
        if !session.scope.contains("urn:mas:admin") {
            return Err(Rejection::MissingScope);
        }

        Ok(Self {
            repo,
            clock,
            user,
            session,
        })
    }
}