matrix_sdk/authentication/qrcode/rendezvous_channel.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use http::{
18    header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
19    HeaderMap, HeaderName, Method, StatusCode,
20};
21use ruma::api::{
22    error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError, MatrixError},
23    EndpointError,
24};
25use tracing::{debug, instrument, trace};
26use url::Url;
27
28use crate::{http_client::HttpClient, HttpError, RumaApiError};
29
/// Content type used for every message sent through the rendezvous channel,
/// and assumed for received messages that don't declare a content type.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";

/// How long to wait between two polls of the rendezvous session when the
/// server tells us nothing has changed. Kept very short in tests so they run
/// quickly.
#[cfg(not(test))]
const POLL_TIMEOUT: Duration = Duration::from_millis(1_000);
#[cfg(test)]
const POLL_TIMEOUT: Duration = Duration::from_millis(10);

/// The value of an `ETag` header handed out by the rendezvous server.
type Etag = String;
37
38/// Get a header from a [`HeaderMap`] and parse it as a UTF-8 string.
39fn get_header(
40    header_map: &HeaderMap,
41    header_name: &HeaderName,
42) -> Result<String, FromHttpResponseError<RumaApiError>> {
43    let header = header_map
44        .get(header_name)
45        .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))?;
46
47    let header = header.to_str()?.to_owned();
48
49    Ok(header)
50}
51
52/// The result of the [`RendezvousChannel::create_inbound()`] method.
53pub(super) struct InboundChannelCreationResult {
54    /// The connected [`RendezvousChannel`].
55    pub channel: RendezvousChannel,
56    /// The initial message we received when we connected to the
57    /// [`RendezvousChannel`].
58    ///
59    /// This is currently unused, but left in for completeness sake.
60    #[allow(dead_code)]
61    pub initial_message: Vec<u8>,
62}
63
64struct RendezvousGetResponse {
65    pub status_code: StatusCode,
66    pub etag: String,
67    // TODO: This is currently unused, but will be required once we implement the reciprocation of
68    // a login. Left here so we don't forget about it. We should put this into the
69    // [`RendezvousChannel`] struct, once we parse it into a [`SystemTime`].
70    #[allow(dead_code)]
71    pub expires: String,
72    #[allow(dead_code)]
73    pub last_modified: String,
74    pub content_type: Option<String>,
75    pub body: Vec<u8>,
76}
77
78struct RendezvousMessage {
79    pub status_code: StatusCode,
80    pub body: Vec<u8>,
81    pub content_type: String,
82}
83
84pub(super) struct RendezvousChannel {
85    client: HttpClient,
86    rendezvous_url: Url,
87    etag: Etag,
88}
89
90fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
91    match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
92        Ok(response) => {
93            let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::Other(
94                MatrixError::from_http_response(response),
95            ));
96
97            error.into()
98        }
99        Err(e) => HttpError::IntoHttp(e),
100    }
101}
102
103impl RendezvousChannel {
104    /// Create a new outbound [`RendezvousChannel`].
105    ///
106    /// By outbound we mean that we're going to tell the Matrix server to create
107    /// a new rendezvous session. We're going to send an initial empty message
108    /// through the channel.
109    #[cfg(test)]
110    pub(super) async fn create_outbound(
111        client: HttpClient,
112        rendezvous_server: &Url,
113    ) -> Result<Self, HttpError> {
114        use ruma::api::client::rendezvous::create_rendezvous_session;
115
116        let request = create_rendezvous_session::unstable::Request::default();
117        let response = client
118            .send(request, None, rendezvous_server.to_string(), None, &[], Default::default())
119            .await?;
120
121        let rendezvous_url = response.url;
122        let etag = response.etag;
123
124        Ok(Self { client, rendezvous_url, etag })
125    }
126
127    /// Create a new inbound [`RendezvousChannel`].
128    ///
129    /// By inbound we mean that we're going to attempt to read an initial
130    /// message from the rendezvous session on the given [`rendezvous_url`].
131    pub(super) async fn create_inbound(
132        client: HttpClient,
133        rendezvous_url: &Url,
134    ) -> Result<InboundChannelCreationResult, HttpError> {
135        // Receive the initial message, which should be empty. But we need the ETAG to
136        // fully establish the rendezvous channel.
137        let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
138
139        let etag = response.etag.clone();
140
141        let initial_message = RendezvousMessage {
142            status_code: response.status_code,
143            body: response.body,
144            content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
145        };
146
147        let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
148
149        Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
150    }
151
152    /// Get the URL of the rendezvous session we're using to exchange messages
153    /// through the channel.
154    pub(super) fn rendezvous_url(&self) -> &Url {
155        &self.rendezvous_url
156    }
157
158    /// Send the given `message` through the [`RendezvousChannel`] to the other
159    /// device.
160    ///
161    /// The message must be of the `text/plain` content type.
162    #[instrument(skip_all)]
163    pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
164        let etag = self.etag.clone();
165
166        let request = self
167            .client
168            .inner
169            .request(Method::PUT, self.rendezvous_url().to_owned())
170            .body(message)
171            .header(IF_MATCH, etag)
172            .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
173
174        debug!("Sending a request to the rendezvous channel {request:?}");
175
176        let response = request.send().await?;
177        let status = response.status();
178
179        debug!("Response for the rendezvous sending request {response:?}");
180
181        if status.is_success() {
182            // We successfully send out a message, get the ETAG and update our internal copy
183            // of the ETAG.
184            let etag = get_header(response.headers(), &ETAG)?;
185            self.etag = etag;
186
187            Ok(())
188        } else {
189            let body = response.bytes().await?;
190            let error = response_to_error(status, body.to_vec());
191
192            return Err(error);
193        }
194    }
195
196    /// Attempt to receive a message from the [`RendezvousChannel`] from the
197    /// other device.
198    ///
199    /// The content should be of the `text/plain` content type but the parsing
200    /// and verification of this fact is left up to the caller.
201    ///
202    /// This method will wait in a loop for the channel to give us a new
203    /// message.
204    pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
205        loop {
206            let message = self.receive_single_message().await?;
207
208            trace!(
209                status_code = %message.status_code,
210                "Received data from the rendezvous channel"
211            );
212
213            if message.status_code == StatusCode::OK
214                && message.content_type == TEXT_PLAIN_CONTENT_TYPE
215                && !message.body.is_empty()
216            {
217                return Ok(message.body);
218            } else if message.status_code == StatusCode::NOT_MODIFIED {
219                tokio::time::sleep(POLL_TIMEOUT).await;
220                continue;
221            } else {
222                let error = response_to_error(message.status_code, message.body);
223
224                return Err(error);
225            }
226        }
227    }
228
229    #[instrument]
230    async fn receive_message_impl(
231        client: &reqwest::Client,
232        etag: Option<String>,
233        rendezvous_url: &Url,
234    ) -> Result<RendezvousGetResponse, HttpError> {
235        let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
236
237        if let Some(etag) = etag {
238            builder = builder.header(IF_NONE_MATCH, etag);
239        }
240
241        let response = builder.send().await?;
242
243        debug!("Received data from the rendezvous channel {response:?}");
244
245        let status_code = response.status();
246        let headers = response.headers();
247
248        let etag = get_header(headers, &ETAG)?;
249        let expires = get_header(headers, &EXPIRES)?;
250        let last_modified = get_header(headers, &LAST_MODIFIED)?;
251        let content_type = response
252            .headers()
253            .get(CONTENT_TYPE)
254            .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
255            .transpose()?
256            .map(ToOwned::to_owned);
257
258        let body = response.bytes().await?.to_vec();
259
260        let response =
261            RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
262
263        Ok(response)
264    }
265
266    async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
267        let etag = Some(self.etag.clone());
268
269        let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
270            Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
271
272        // We received a response with an ETAG, put it into the copy of our etag.
273        self.etag = etag;
274
275        let message = RendezvousMessage {
276            status_code,
277            body,
278            content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
279        };
280
281        Ok(message)
282    }
283}
284
#[cfg(test)]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        matchers::{header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::config::RequestConfig;

    /// Mock the endpoint that creates a new rendezvous session, answering with
    /// the given `rendezvous_url` and an initial `ETag` of `1`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("X-Max-Bytes", "10240")
                            .append_header("ETag", "1")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        ResponseTemplate::new(200)
                            .append_header("Content-Type", "text/plain")
                            .append_header("ETag", "2")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                RendezvousChannel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(304)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "1")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "3")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        ResponseTemplate::new(304)
                            .append_header("ETag", "2")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("ETag", "3")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(404)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(504)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_json(json!({
                                  "errcode": "M_NOT_FOUND",
                                  "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }

    /// Connecting to a rendezvous session that the server doesn't know about
    /// must fail, even though the error response carries none of the
    /// rendezvous headers.
    #[async_test]
    async fn test_create_inbound_error_status() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .respond_with(ResponseTemplate::new(404).set_body_json(json!({
                        "errcode": "M_NOT_FOUND",
                        "error": "No resource was found for this request.",
                    })))
                    .expect(1),
            )
            .await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let result = RendezvousChannel::create_inbound(client, &rendezvous_url).await;

        assert!(
            result.is_err(),
            "We should return an error if the rendezvous session doesn't exist"
        );
    }
}