1use std::time::Duration;
16
17use http::{
18 header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
19 HeaderMap, HeaderName, Method, StatusCode,
20};
21use ruma::api::{
22 error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError, MatrixError},
23 EndpointError,
24};
25use tracing::{debug, instrument, trace};
26use url::Url;
27
28use crate::{http_client::HttpClient, HttpError, RumaApiError};
29
/// Content type attached to messages we PUT on the rendezvous channel, and the
/// content type a received message must carry to be accepted.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";

/// Delay between two polls of the rendezvous session while the server keeps
/// answering `304 Not Modified`. Shortened under `cfg(test)` so tests don't
/// stall.
const POLL_TIMEOUT: Duration =
    if cfg!(test) { Duration::from_millis(10) } else { Duration::from_millis(1000) };

/// An HTTP entity tag (the value of an `ETag` header) as handed out by the
/// rendezvous server.
type Etag = String;
37
38fn get_header(
40 header_map: &HeaderMap,
41 header_name: &HeaderName,
42) -> Result<String, FromHttpResponseError<RumaApiError>> {
43 let header = header_map
44 .get(header_name)
45 .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))?;
46
47 let header = header.to_str()?.to_owned();
48
49 Ok(header)
50}
51
52pub(super) struct InboundChannelCreationResult {
54 pub channel: RendezvousChannel,
56 #[allow(dead_code)]
61 pub initial_message: Vec<u8>,
62}
63
64struct RendezvousGetResponse {
65 pub status_code: StatusCode,
66 pub etag: String,
67 #[allow(dead_code)]
71 pub expires: String,
72 #[allow(dead_code)]
73 pub last_modified: String,
74 pub content_type: Option<String>,
75 pub body: Vec<u8>,
76}
77
78struct RendezvousMessage {
79 pub status_code: StatusCode,
80 pub body: Vec<u8>,
81 pub content_type: String,
82}
83
84pub(super) struct RendezvousChannel {
85 client: HttpClient,
86 rendezvous_url: Url,
87 etag: Etag,
88}
89
90fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
91 match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
92 Ok(response) => {
93 let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::Other(
94 MatrixError::from_http_response(response),
95 ));
96
97 error.into()
98 }
99 Err(e) => HttpError::IntoHttp(e),
100 }
101}
102
103impl RendezvousChannel {
104 #[cfg(test)]
110 pub(super) async fn create_outbound(
111 client: HttpClient,
112 rendezvous_server: &Url,
113 ) -> Result<Self, HttpError> {
114 use ruma::api::client::rendezvous::create_rendezvous_session;
115
116 let request = create_rendezvous_session::unstable::Request::default();
117 let response = client
118 .send(request, None, rendezvous_server.to_string(), None, &[], Default::default())
119 .await?;
120
121 let rendezvous_url = response.url;
122 let etag = response.etag;
123
124 Ok(Self { client, rendezvous_url, etag })
125 }
126
127 pub(super) async fn create_inbound(
132 client: HttpClient,
133 rendezvous_url: &Url,
134 ) -> Result<InboundChannelCreationResult, HttpError> {
135 let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
138
139 let etag = response.etag.clone();
140
141 let initial_message = RendezvousMessage {
142 status_code: response.status_code,
143 body: response.body,
144 content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
145 };
146
147 let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
148
149 Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
150 }
151
152 pub(super) fn rendezvous_url(&self) -> &Url {
155 &self.rendezvous_url
156 }
157
158 #[instrument(skip_all)]
163 pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
164 let etag = self.etag.clone();
165
166 let request = self
167 .client
168 .inner
169 .request(Method::PUT, self.rendezvous_url().to_owned())
170 .body(message)
171 .header(IF_MATCH, etag)
172 .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
173
174 debug!("Sending a request to the rendezvous channel {request:?}");
175
176 let response = request.send().await?;
177 let status = response.status();
178
179 debug!("Response for the rendezvous sending request {response:?}");
180
181 if status.is_success() {
182 let etag = get_header(response.headers(), &ETAG)?;
185 self.etag = etag;
186
187 Ok(())
188 } else {
189 let body = response.bytes().await?;
190 let error = response_to_error(status, body.to_vec());
191
192 return Err(error);
193 }
194 }
195
196 pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
205 loop {
206 let message = self.receive_single_message().await?;
207
208 trace!(
209 status_code = %message.status_code,
210 "Received data from the rendezvous channel"
211 );
212
213 if message.status_code == StatusCode::OK
214 && message.content_type == TEXT_PLAIN_CONTENT_TYPE
215 && !message.body.is_empty()
216 {
217 return Ok(message.body);
218 } else if message.status_code == StatusCode::NOT_MODIFIED {
219 tokio::time::sleep(POLL_TIMEOUT).await;
220 continue;
221 } else {
222 let error = response_to_error(message.status_code, message.body);
223
224 return Err(error);
225 }
226 }
227 }
228
229 #[instrument]
230 async fn receive_message_impl(
231 client: &reqwest::Client,
232 etag: Option<String>,
233 rendezvous_url: &Url,
234 ) -> Result<RendezvousGetResponse, HttpError> {
235 let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
236
237 if let Some(etag) = etag {
238 builder = builder.header(IF_NONE_MATCH, etag);
239 }
240
241 let response = builder.send().await?;
242
243 debug!("Received data from the rendezvous channel {response:?}");
244
245 let status_code = response.status();
246 let headers = response.headers();
247
248 let etag = get_header(headers, &ETAG)?;
249 let expires = get_header(headers, &EXPIRES)?;
250 let last_modified = get_header(headers, &LAST_MODIFIED)?;
251 let content_type = response
252 .headers()
253 .get(CONTENT_TYPE)
254 .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
255 .transpose()?
256 .map(ToOwned::to_owned);
257
258 let body = response.bytes().await?.to_vec();
259
260 let response =
261 RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
262
263 Ok(response)
264 }
265
266 async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
267 let etag = Some(self.etag.clone());
268
269 let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
270 Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
271
272 self.etag = etag;
274
275 let message = RendezvousMessage {
276 status_code,
277 body,
278 content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
279 };
280
281 Ok(message)
282 }
283}
284
#[cfg(test)]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        matchers::{header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::config::RequestConfig;

    /// Mount a mock for the session-creation endpoint that hands out
    /// `rendezvous_url` with an initial ETag of `"1"`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("X-Max-Bytes", "10240")
                            .append_header("ETag", "1")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    /// Full round trip: Alice creates a session, Bob joins it, Alice sees
    /// "not modified", Bob sends, Alice receives the message.
    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        ResponseTemplate::new(200)
                            .append_header("Content-Type", "text/plain")
                            .append_header("ETag", "2")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                RendezvousChannel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(304)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "1")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "3")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    /// `receive()` keeps polling through a 304 and follows the new ETag until
    /// a real message arrives.
    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        ResponseTemplate::new(304)
                            .append_header("ETag", "2")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("ETag", "3")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    /// Error statuses from the server surface as errors from `receive()`.
    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(404)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(504)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_json(json!({
                                    "errcode": "M_NOT_FOUND",
                                    "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }

    /// A bare error response without any of the ETag/Expires/Last-Modified
    /// headers (as an intermediate proxy might produce) must still make
    /// `receive()` fail rather than hang or succeed.
    #[async_test]
    async fn test_receive_error_without_headers() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        let _scope = server
            .register_as_scoped(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(ResponseTemplate::new(504).set_body_string("Gateway Timeout"))
                    .expect(1),
            )
            .await;

        alice
            .receive()
            .await
            .expect_err("We should return an error if we receive a bare gateway timeout");
    }
}