use std::{
future::Future,
pin::Pin,
sync::{atomic::AtomicBool, Arc},
task::{Context, Poll},
time::Duration,
};
use event_listener::{Event, EventListener};
use futures_util::{stream::SelectAll, Stream, StreamExt};
use hyper::{Request, Response};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Connection,
service::TowerToHyperService,
};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
use tower::Service;
use tower_http::add_extension::AddExtension;
use tracing::Instrument;
use crate::{
maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo},
proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
rewind::Rewind,
unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
ConnectionInfo,
};
/// Maximum time `accept` allows for the proxy-protocol and TLS handshakes of a
/// newly accepted connection; exceeding it yields `AcceptError::HandshakeTimeout`.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);
/// A listening socket paired with the service that handles its connections.
///
/// Created with [`Server::new`] or [`Server::try_new`], optionally configured
/// through [`Server::with_proxy`] and [`Server::with_tls`], and driven by
/// [`Server::run`] or [`run_servers`].
pub struct Server<S> {
/// TLS configuration handed to `MaybeTlsAcceptor::new`; `None` unless
/// [`Server::with_tls`] was called.
tls: Option<Arc<ServerConfig>>,
/// Flag handed to `MaybeProxyAcceptor::new`; `false` unless
/// [`Server::with_proxy`] was called.
proxy: bool,
/// The Unix or TCP listener new connections are accepted from.
listener: UnixOrTcpListener,
/// The service cloned for every accepted connection.
service: S,
}
impl<S> Server<S> {
    /// Builds a server from any value that can be fallibly converted into a
    /// [`UnixOrTcpListener`]. TLS and the proxy flag start out disabled.
    ///
    /// # Errors
    ///
    /// Returns the conversion error when `listener` cannot be turned into a
    /// [`UnixOrTcpListener`].
    pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
    where
        L: TryInto<UnixOrTcpListener>,
    {
        let listener: UnixOrTcpListener = listener.try_into()?;
        Ok(Self {
            tls: None,
            proxy: false,
            listener,
            service,
        })
    }

    /// Builds a server from any value that converts infallibly into a
    /// [`UnixOrTcpListener`]. TLS and the proxy flag start out disabled.
    #[must_use]
    pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
        let listener: UnixOrTcpListener = listener.into();
        Self {
            tls: None,
            proxy: false,
            listener,
            service,
        }
    }

    /// Turns on the proxy flag, which is later passed to
    /// `MaybeProxyAcceptor::new` when the server runs.
    #[must_use]
    pub const fn with_proxy(mut self) -> Self {
        // Assigning in place (instead of rebuilding the struct) keeps this
        // usable as a `const fn`.
        self.proxy = true;
        self
    }

    /// Sets the TLS configuration, which is later passed to
    /// `MaybeTlsAcceptor::new` when the server runs.
    #[must_use]
    pub fn with_tls(self, config: Arc<ServerConfig>) -> Self {
        Self {
            tls: Some(config),
            ..self
        }
    }

    /// Runs this single server until `shutdown` yields; see [`run_servers`]
    /// for the exact shutdown behaviour.
    pub async fn run<B, SD>(self, shutdown: SD)
    where
        S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
        S::Future: Send + 'static,
        S::Error: std::error::Error + Send + Sync + 'static,
        B: http_body::Body + Send + 'static,
        B::Data: Send,
        B::Error: std::error::Error + Send + Sync + 'static,
        SD: Stream + Unpin,
        SD::Item: std::fmt::Display,
    {
        // A one-element array is the simplest `IntoIterator` over `Server<S>`.
        run_servers([self], shutdown).await;
    }
}
/// Reasons an incoming connection could not be turned into a served HTTP
/// connection.
#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
/// The listener failed to hand over a new socket.
#[error("failed to accept connection from the underlying socket")]
Socket {
#[source]
source: std::io::Error,
},
/// The TLS acceptor returned an error for this connection.
#[error("failed to complete the TLS handshake")]
TlsHandshake {
#[source]
source: std::io::Error,
},
/// The proxy-protocol acceptor returned an error for this connection.
#[error("failed to complete the proxy protocol handshake")]
ProxyHandshake {
#[source]
source: ProxyAcceptError,
},
/// The handshakes did not finish within `HANDSHAKE_TIMEOUT`.
#[error("connection handshake timed out")]
HandshakeTimeout {
#[source]
source: tokio::time::error::Elapsed,
},
}
impl AcceptError {
fn socket(source: std::io::Error) -> Self {
Self::Socket { source }
}
fn tls_handshake(source: std::io::Error) -> Self {
Self::TlsHandshake { source }
}
fn proxy_handshake(source: ProxyAcceptError) -> Self {
Self::ProxyHandshake { source }
}
fn handshake_timeout(source: tokio::time::error::Elapsed) -> Self {
Self::HandshakeTimeout { source }
}
}
/// Runs the connection-level handshakes on a freshly accepted socket and
/// prepares a hyper connection serving `service` over it.
///
/// The peer address is recorded on the current tracing span. The stream is
/// then passed through the proxy acceptor and the TLS acceptor, and the
/// resulting [`ConnectionInfo`] is attached to every request through
/// [`AddExtension`]. If TLS negotiated `h2` via ALPN the connection is
/// restricted to HTTP/2.
///
/// # Errors
///
/// Returns an [`AcceptError`] when either handshake fails or when the whole
/// sequence takes longer than `HANDSHAKE_TIMEOUT`.
#[allow(clippy::type_complexity)]
#[tracing::instrument(
    name = "accept",
    skip_all,
    fields(
        network.protocol.name = "http",
        network.peer.address,
        network.peer.port,
    ),
    err,
)]
async fn accept<S, B>(
    maybe_proxy_acceptor: &MaybeProxyAcceptor,
    maybe_tls_acceptor: &MaybeTlsAcceptor,
    peer_addr: SocketAddr,
    stream: UnixOrTcpConnection,
    service: S,
) -> Result<
    Connection<
        'static,
        TokioIo<MaybeTlsStream<Rewind<UnixOrTcpConnection>>>,
        TowerToHyperService<AddExtension<S, ConnectionInfo>>,
        TokioExecutor,
    >,
    AcceptError,
>
where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let span = tracing::Span::current();

    // Fill in the span fields declared (but left empty) in the attribute above.
    match &peer_addr {
        SocketAddr::Net(net_addr) => {
            span.record(
                "network.peer.address",
                tracing::field::display(net_addr.ip()),
            );
            span.record("network.peer.port", net_addr.port());
        }
        SocketAddr::Unix(unix_addr) => {
            span.record("network.peer.address", tracing::field::debug(unix_addr));
        }
    }

    // Everything that talks to the peer lives in this future so one timeout
    // covers the proxy and TLS handshakes together.
    let handshake = async move {
        let (proxy, stream) = match maybe_proxy_acceptor.accept(stream).await {
            Ok(accepted) => accepted,
            Err(source) => return Err(AcceptError::ProxyHandshake { source }),
        };

        let stream = match maybe_tls_acceptor.accept(stream).await {
            Ok(stream) => stream,
            Err(source) => return Err(AcceptError::TlsHandshake { source }),
        };

        let tls = stream.tls_info();
        let is_h2 = match tls.as_ref() {
            Some(tls_info) => TlsStreamInfo::is_alpn_h2(tls_info),
            None => false,
        };

        let info = ConnectionInfo {
            tls,
            proxy,
            net_peer_addr: peer_addr.into_net(),
        };

        let executor = TokioExecutor::new();
        let mut builder = hyper_util::server::conn::auto::Builder::new(executor);
        if is_h2 {
            builder = builder.http2_only();
        }
        // `http1()` borrows the builder mutably, so this configures it in place.
        builder.http1().keep_alive(true);

        let service = TowerToHyperService::new(AddExtension::new(service, info));
        let io = TokioIo::new(stream);
        Ok(builder.serve_connection(io, service).into_owned())
    };

    match tokio::time::timeout(HANDSHAKE_TIMEOUT, handshake)
        .instrument(span)
        .await
    {
        Ok(outcome) => outcome,
        Err(source) => Err(AcceptError::HandshakeTimeout { source }),
    }
}
pin_project! {
// Wraps a connection future so that it can be switched to a graceful
// shutdown once the server starts shutting down.
struct AbortableConnection<C> {
// The wrapped connection future.
#[pin]
connection: C,
// Listener on the shared shutdown event; polling it registers the task's
// waker so the task is woken up when the event is notified.
#[pin]
shutdown_listener: EventListener,
// Flag shared with `run_servers`, set to `true` once shutdown has begun.
shutdown_in_progress: Arc<AtomicBool>,
// Whether `graceful_shutdown` has already been called on `connection`.
did_start_shutdown: bool,
}
}
impl<C> AbortableConnection<C> {
    /// Wraps `connection`, subscribing it to `event` and sharing the
    /// `shutdown_in_progress` flag with the caller.
    fn new(connection: C, shutdown_in_progress: &Arc<AtomicBool>, event: &Arc<Event>) -> Self {
        // Subscribe first so a notification sent from now on reaches this listener.
        let shutdown_listener = event.listen();
        let shutdown_in_progress = Arc::clone(shutdown_in_progress);

        Self {
            connection,
            shutdown_listener,
            shutdown_in_progress,
            did_start_shutdown: false,
        }
    }
}
/// Drives the wrapped hyper connection, switching it to a graceful shutdown
/// the first time a shutdown is observed.
impl<T, S, B> Future
    for AbortableConnection<Connection<'static, T, TowerToHyperService<S>, TokioExecutor>>
where
    Connection<'static, T, TowerToHyperService<S>, TokioExecutor>: Future,
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync,
    T: hyper::rt::Read + hyper::rt::Write + Unpin,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    type Output = <Connection<'static, T, TowerToHyperService<S>, TokioExecutor> as Future>::Output;

    /// Polls the inner connection and returns its result.
    ///
    /// Until shutdown has started, the shutdown listener is polled so this
    /// task's waker is registered with the shutdown event. Once the listener
    /// fires or the shared flag is set, `graceful_shutdown` is called exactly
    /// once on the connection.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // Only touch the listener while shutdown has not started. After it
        // returns `Ready` it is a completed future, and the `Future` contract
        // forbids polling it again (doing so may panic). A connection that is
        // still draining in-flight requests will be polled several more times.
        if !*this.did_start_shutdown {
            let notified = this.shutdown_listener.poll(cx).is_ready();

            // The flag also covers connections created after the event was
            // notified: their listener never fires, but the flag is already set.
            let flagged = this
                .shutdown_in_progress
                .load(std::sync::atomic::Ordering::Relaxed);

            if notified || flagged {
                *this.did_start_shutdown = true;
                this.connection.as_mut().graceful_shutdown();
            }
        }

        this.connection.poll(cx)
    }
}
/// Runs several servers concurrently until a shutdown signal arrives, then
/// drains the connections that are still in flight.
///
/// Each accepted socket goes through [`accept`] in its own task, and each
/// resulting connection is served in its own task, wrapped in an
/// `AbortableConnection`.
///
/// The first item yielded by `shutdown` (or the end of that stream) stops the
/// accept loop and asks every connection to shut down gracefully. While the
/// drain is in progress, a further item from `shutdown` stops the waiting;
/// whatever tasks remain are then shut down through `JoinSet::shutdown`.
#[allow(clippy::too_many_lines)]
pub async fn run_servers<S, B, SD>(listeners: impl IntoIterator<Item = Server<S>>, mut shutdown: SD)
where
S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: std::error::Error + Send + Sync + 'static,
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
SD: Stream + Unpin,
SD::Item: std::fmt::Display,
{
// Turn every server into a stream of accepted sockets, each item bundled
// with the acceptors and a clone of the service needed to handle it, and
// merge all of those streams into one.
let mut accept_stream: SelectAll<_> = listeners
.into_iter()
.map(|server| {
let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
// This stream never ends: it always yields `Some`, wrapping either an
// accepted socket or the accept error.
futures_util::stream::poll_fn(move |cx| {
let res =
std::task::ready!(server.listener.poll_accept(cx)).map(|(addr, stream)| {
(
// Handed out by value on every call without `.clone()`;
// presumably `MaybeProxyAcceptor` is `Copy` (confirm).
maybe_proxy_acceptor,
maybe_tls_acceptor.clone(),
server.service.clone(),
addr,
stream,
)
});
Poll::Ready(Some(res))
})
})
.collect();
// Tasks running the handshakes of freshly accepted sockets.
let mut accept_tasks = tokio::task::JoinSet::new();
// Tasks serving connections whose handshakes succeeded.
let mut connection_tasks = tokio::task::JoinSet::new();
// Shared with every connection: the flag says shutdown has begun, the event
// wakes connection tasks up so they notice it.
let shutdown_in_progress = Arc::new(AtomicBool::new(false));
let shutdown_event = Arc::new(Event::new());
// Main loop: runs until the shutdown stream yields (or ends).
loop {
tokio::select! {
// `biased` polls the branches in the order written, so the shutdown
// signal is checked first and new sockets are accepted last.
biased;
res = shutdown.next() => {
// A stream that ended (`None`) is treated as a shutdown with an unknown reason.
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
tracing::info!("Received shutdown signal ({why})");
break;
},
res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
match res {
Some(Ok(Ok(connection))) => {
tracing::trace!("Accepted connection");
let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event);
connection_tasks.spawn(conn);
},
// Handshake failed. Nothing is logged here: `accept` is instrumented
// with `err`, so tracing already recorded the error.
Some(Ok(Err(_e))) => { },
Some(Err(e)) => tracing::error!("Join error: {e}"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
Some(Err(e)) => tracing::error!("Join error: {e}"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
res = accept_stream.next(), if !accept_stream.is_empty() => {
// The guard disables this branch when there are no listeners, and the
// per-listener streams always yield `Some`, so `None` cannot occur.
let Some(res) = res else { unreachable!() };
// Run the handshake in its own task so a slow peer does not delay
// accepting the next socket.
accept_tasks.spawn(async move {
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res
.map_err(AcceptError::socket)?;
accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
});
},
};
}
// Set the flag before notifying, so a connection woken by the event sees it.
shutdown_in_progress.store(true, std::sync::atomic::Ordering::Relaxed);
shutdown_event.notify(usize::MAX);
if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
tracing::info!(
"There are {active} active connections ({pending} pending), performing a graceful shutdown. Send the shutdown signal again to force.",
active = connection_tasks.len(),
pending = accept_tasks.len(),
);
// Drain loop: no new sockets are accepted, but handshakes already in
// progress may still complete and are served until they finish.
while !accept_tasks.is_empty() || !connection_tasks.is_empty() {
tokio::select! {
// Here the shutdown branch comes last, so finished tasks are reaped
// before a second signal is considered.
biased;
res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
match res {
Some(Ok(Ok(connection))) => {
tracing::trace!("Accepted connection");
// The flag is already set, so this connection starts its graceful
// shutdown on its first poll.
let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event);
connection_tasks.spawn(conn);
}
Some(Ok(Err(_e))) => { },
Some(Err(e)) => tracing::error!("Join error: {e}"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
Some(Err(e)) => tracing::error!("Join error: {e}"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
// NOTE(review): if the main loop exited because `shutdown` ended (`None`),
// the stream is polled again here after completion. Most streams would
// return `None` again, which forces the shutdown immediately. Confirm
// this is intended, or fuse/guard the stream.
res = shutdown.next() => {
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
tracing::warn!(
"Received shutdown signal again ({why}), forcing shutdown ({active} active connections, {pending} pending connections)",
active = connection_tasks.len(),
pending = accept_tasks.len(),
);
break;
},
}
}
}
// Stop whatever is left (only non-empty after a forced shutdown) and wait
// for those tasks to finish.
accept_tasks.shutdown().await;
connection_tasks.shutdown().await;
tracing::info!("Shutdown complete");
}