1use std::collections::{BTreeSet, HashMap};
20
21pub use language_tags;
22use language_tags::LanguageTag;
23use matrix_sdk_base::deserialized_responses::PrivOwnedStr;
24use oauth2::{AsyncHttpClient, ClientId, HttpClientError, RequestTokenError};
25use ruma::{
26 api::client::discovery::get_authorization_server_metadata::msc2965::{GrantType, ResponseType},
27 serde::{PartialEqAsRefStr, Raw, StringEnum},
28 SecondsSinceUnixEpoch,
29};
30use serde::{ser::SerializeMap, Deserialize, Serialize};
31use url::Url;
32
33use super::{
34 error::OAuthClientRegistrationError,
35 http_client::{check_http_response_json_content_type, check_http_response_status_code},
36 OAuthHttpClient,
37};
38
39#[tracing::instrument(skip_all, fields(registration_endpoint))]
56pub(super) async fn register_client(
57 http_client: &OAuthHttpClient,
58 registration_endpoint: &Url,
59 client_metadata: &Raw<ClientMetadata>,
60) -> Result<ClientRegistrationResponse, OAuthClientRegistrationError> {
61 tracing::debug!("Registering client...");
62
63 let body =
64 serde_json::to_vec(client_metadata).map_err(OAuthClientRegistrationError::IntoJson)?;
65 let request = http::Request::post(registration_endpoint.as_str())
66 .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.to_string())
67 .body(body)
68 .map_err(|err| RequestTokenError::Request(HttpClientError::Http(err)))?;
69
70 let response = http_client.call(request).await.map_err(RequestTokenError::Request)?;
71
72 check_http_response_status_code(&response)?;
73 check_http_response_json_content_type(&response)?;
74
75 let response = serde_json::from_slice(&response.into_body())
76 .map_err(OAuthClientRegistrationError::FromJson)?;
77
78 Ok(response)
79}
80
81#[derive(Debug, Clone, Deserialize)]
85pub struct ClientRegistrationResponse {
86 pub client_id: ClientId,
88
89 pub client_id_issued_at: Option<SecondsSinceUnixEpoch>,
91}
92
93#[derive(Debug, Clone, Serialize)]
106#[serde(into = "ClientMetadataSerializeHelper")]
107pub struct ClientMetadata {
108 pub application_type: ApplicationType,
110
111 pub grant_types: Vec<OAuthGrantType>,
115
116 pub client_uri: Localized<Url>,
118
119 pub client_name: Option<Localized<String>>,
121
122 pub logo_uri: Option<Localized<Url>>,
124
125 pub policy_uri: Option<Localized<Url>>,
128
129 pub tos_uri: Option<Localized<Url>>,
132}
133
134impl ClientMetadata {
135 pub fn new(
137 application_type: ApplicationType,
138 grant_types: Vec<OAuthGrantType>,
139 client_uri: Localized<Url>,
140 ) -> Self {
141 Self {
142 application_type,
143 grant_types,
144 client_uri,
145 client_name: None,
146 logo_uri: None,
147 policy_uri: None,
148 tos_uri: None,
149 }
150 }
151}
152
153#[derive(Debug, Clone)]
159#[non_exhaustive]
160pub enum OAuthGrantType {
161 AuthorizationCode {
168 redirect_uris: Vec<Url>,
170 },
171
172 DeviceCode,
179}
180
181#[derive(Clone, StringEnum, PartialEqAsRefStr, Eq)]
183#[ruma_enum(rename_all = "lowercase")]
184#[non_exhaustive]
185pub enum ApplicationType {
186 Web,
191
192 Native,
196
197 #[doc(hidden)]
198 _Custom(PrivOwnedStr),
199}
200
201#[derive(Debug, Clone, PartialEq, Eq)]
205pub struct Localized<T> {
206 non_localized: T,
207 localized: HashMap<LanguageTag, T>,
208}
209
210impl<T> Localized<T> {
211 pub fn new(non_localized: T, localized: impl IntoIterator<Item = (LanguageTag, T)>) -> Self {
214 Self { non_localized, localized: localized.into_iter().collect() }
215 }
216
217 pub fn non_localized(&self) -> &T {
219 &self.non_localized
220 }
221
222 pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> {
224 match language {
225 Some(lang) => self.localized.get(lang),
226 None => Some(&self.non_localized),
227 }
228 }
229}
230
231impl<T> From<(T, HashMap<LanguageTag, T>)> for Localized<T> {
232 fn from(t: (T, HashMap<LanguageTag, T>)) -> Self {
233 Localized { non_localized: t.0, localized: t.1 }
234 }
235}
236
237#[derive(Serialize)]
238struct ClientMetadataSerializeHelper {
239 #[serde(skip_serializing_if = "Vec::is_empty")]
240 redirect_uris: Vec<Url>,
241 token_endpoint_auth_method: &'static str,
242 grant_types: BTreeSet<GrantType>,
243 #[serde(skip_serializing_if = "Vec::is_empty")]
244 response_types: Vec<ResponseType>,
245 application_type: ApplicationType,
246 #[serde(flatten)]
247 localized: ClientMetadataLocalizedFields,
248}
249
250impl From<ClientMetadata> for ClientMetadataSerializeHelper {
251 fn from(value: ClientMetadata) -> Self {
252 let ClientMetadata {
253 application_type,
254 grant_types: oauth_grant_types,
255 client_uri,
256 client_name,
257 logo_uri,
258 policy_uri,
259 tos_uri,
260 } = value;
261
262 let mut redirect_uris = None;
263 let mut response_types = None;
264 let mut grant_types = BTreeSet::new();
265
266 grant_types.insert(GrantType::RefreshToken);
268
269 for oauth_grant_type in oauth_grant_types {
270 match oauth_grant_type {
271 OAuthGrantType::AuthorizationCode { redirect_uris: uris } => {
272 redirect_uris = Some(uris);
273 response_types = Some(vec![ResponseType::Code]);
274 grant_types.insert(GrantType::AuthorizationCode);
275 }
276 OAuthGrantType::DeviceCode => {
277 grant_types.insert(GrantType::DeviceCode);
278 }
279 }
280 }
281
282 ClientMetadataSerializeHelper {
283 redirect_uris: redirect_uris.unwrap_or_default(),
284 token_endpoint_auth_method: "none",
286 grant_types,
287 response_types: response_types.unwrap_or_default(),
288 application_type,
289 localized: ClientMetadataLocalizedFields {
290 client_uri,
291 client_name,
292 logo_uri,
293 policy_uri,
294 tos_uri,
295 },
296 }
297 }
298}
299
300struct ClientMetadataLocalizedFields {
305 client_uri: Localized<Url>,
306 client_name: Option<Localized<String>>,
307 logo_uri: Option<Localized<Url>>,
308 policy_uri: Option<Localized<Url>>,
309 tos_uri: Option<Localized<Url>>,
310}
311
312impl Serialize for ClientMetadataLocalizedFields {
313 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
314 where
315 S: serde::Serializer,
316 {
317 fn serialize_localized_into_map<M: SerializeMap, T: Serialize>(
318 map: &mut M,
319 field_name: &str,
320 value: &Localized<T>,
321 ) -> Result<(), M::Error> {
322 map.serialize_entry(field_name, &value.non_localized)?;
323
324 for (lang, localized) in &value.localized {
325 map.serialize_entry(&format!("{field_name}#{lang}"), localized)?;
326 }
327
328 Ok(())
329 }
330
331 let mut map = serializer.serialize_map(None)?;
332
333 serialize_localized_into_map(&mut map, "client_uri", &self.client_uri)?;
334
335 if let Some(client_name) = &self.client_name {
336 serialize_localized_into_map(&mut map, "client_name", client_name)?;
337 }
338
339 if let Some(logo_uri) = &self.logo_uri {
340 serialize_localized_into_map(&mut map, "logo_uri", logo_uri)?;
341 }
342
343 if let Some(policy_uri) = &self.policy_uri {
344 serialize_localized_into_map(&mut map, "policy_uri", policy_uri)?;
345 }
346
347 if let Some(tos_uri) = &self.tos_uri {
348 serialize_localized_into_map(&mut map, "tos_uri", tos_uri)?;
349 }
350
351 map.end()
352 }
353}
354
#[cfg(test)]
mod tests {
    use language_tags::LanguageTag;
    use serde_json::json;
    use url::Url;

    use super::{ApplicationType, ClientMetadata, Localized, OAuthGrantType};

    /// Parses a URL literal, panicking if it is invalid.
    fn url(raw: &str) -> Url {
        Url::parse(raw).unwrap()
    }

    /// Only the mandatory fields are set: no localized variants and no
    /// optional metadata may show up in the output.
    #[test]
    fn test_serialize_minimal_client_metadata() {
        let redirect_uris = vec![url("http://127.0.0.1/")];
        let client_uri = Localized::new(url("https://github.com/matrix-org/matrix-rust-sdk"), []);

        let metadata = ClientMetadata::new(
            ApplicationType::Native,
            vec![OAuthGrantType::AuthorizationCode { redirect_uris }],
            client_uri,
        );

        let expected = json!({
            "application_type": "native",
            "grant_types": ["authorization_code", "refresh_token"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",
            "redirect_uris": ["http://127.0.0.1/"],
            "client_uri": "https://github.com/matrix-org/matrix-rust-sdk",
        });

        assert_eq!(serde_json::to_value(metadata).unwrap(), expected);
    }

    /// Every field is set, with a mix of localized and non-localized values.
    #[test]
    fn test_serialize_full_client_metadata() {
        let fr = LanguageTag::parse("fr").unwrap();
        let mas = LanguageTag::parse("mas").unwrap();

        let grant_types = vec![
            OAuthGrantType::AuthorizationCode {
                redirect_uris: vec![url("http://127.0.0.1/"), url("http://[::1]/")],
            },
            OAuthGrantType::DeviceCode,
        ];
        let client_uri = Localized::new(
            url("https://example.org/matrix-client"),
            [
                (fr.clone(), url("https://example.org/fr/matrix-client")),
                (mas.clone(), url("https://example.org/mas/matrix-client")),
            ],
        );

        let mut metadata = ClientMetadata::new(ApplicationType::Web, grant_types, client_uri);

        metadata.client_name = Some(Localized::new(
            "My Matrix client".to_owned(),
            [(fr.clone(), "Mon client Matrix".to_owned())],
        ));
        metadata.logo_uri = Some(Localized::new(url("https://example.org/logo.svg"), []));
        metadata.policy_uri = Some(Localized::new(
            url("https://example.org/policy"),
            [
                (fr.clone(), url("https://example.org/fr/policy")),
                (mas.clone(), url("https://example.org/mas/policy")),
            ],
        ));
        metadata.tos_uri = Some(Localized::new(
            url("https://example.org/tos"),
            [
                (fr, url("https://example.org/fr/tos")),
                (mas, url("https://example.org/mas/tos")),
            ],
        ));

        let expected = json!({
            "application_type": "web",
            "grant_types": [
                "authorization_code",
                "refresh_token",
                "urn:ietf:params:oauth:grant-type:device_code",
            ],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",
            "redirect_uris": ["http://127.0.0.1/", "http://[::1]/"],
            "client_uri": "https://example.org/matrix-client",
            "client_uri#fr": "https://example.org/fr/matrix-client",
            "client_uri#mas": "https://example.org/mas/matrix-client",
            "client_name": "My Matrix client",
            "client_name#fr": "Mon client Matrix",
            "logo_uri": "https://example.org/logo.svg",
            "policy_uri": "https://example.org/policy",
            "policy_uri#fr": "https://example.org/fr/policy",
            "policy_uri#mas": "https://example.org/mas/policy",
            "tos_uri": "https://example.org/tos",
            "tos_uri#fr": "https://example.org/fr/tos",
            "tos_uri#mas": "https://example.org/mas/tos",
        });

        assert_eq!(serde_json::to_value(metadata).unwrap(), expected);
    }
}