From 4ec0638847afc63b4b5166317a8bde1c27915503 Mon Sep 17 00:00:00 2001 From: Vika Date: Sat, 15 Jul 2023 22:28:40 +0300 Subject: Allow listening on several TCP sockets I would also love to be able to listen on Unix stream sockets, but that would require some additional support that can thankfully be just introduced later. (It also requires a second loop over the file descriptor array) --- kittybox-rs/src/main.rs | 163 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 117 insertions(+), 46 deletions(-) diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs index 7131200..c588ad8 100644 --- a/kittybox-rs/src/main.rs +++ b/kittybox-rs/src/main.rs @@ -134,7 +134,12 @@ where A: kittybox::indieauth::backend::AuthBackend } } -async fn compose_kittybox(backend_uri: &str, blobstore_uri: &str, authstore_uri: &str) -> axum::Router { +async fn compose_kittybox( + backend_uri: &str, + blobstore_uri: &str, + authstore_uri: &str, + cancellation_token: &tokio_util::sync::CancellationToken +) -> axum::Router { let http: reqwest::Client = { #[allow(unused_mut)] let mut builder = reqwest::Client::builder().user_agent(concat!( @@ -169,9 +174,7 @@ async fn compose_kittybox(backend_uri: &str, blobstore_uri: &str, authstore_uri: ) .route("/.kittybox/coffee", teapot_route()) .nest("/.kittybox/micropub/client", kittybox::companion::router()) - .layer(tower::ServiceBuilder::new() - .layer(tower_http::trace::TraceLayer::new_for_http()) - .into_inner()) + .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::catch_panic::CatchPanicLayer::new()) } @@ -230,63 +233,131 @@ async fn main() { std::process::exit(1); }); - let listen_addr = env::var("SERVE_AT") - .ok() - .unwrap_or_else(|| "[::]:8080".to_string()) - .parse::() - .unwrap_or_else(|e| { - error!("Cannot parse SERVE_AT: {}", e); std::process::exit(1); }); + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let router = compose_kittybox( backend_uri.as_str(), blobstore_uri.as_str(), - authstore_uri.as_str() + authstore_uri.as_str(), + &cancellation_token ).await; - // A little dance to turn a potential file descriptor into - // a guaranteed async network socket - let tcp_listener: std::net::TcpListener = { - let mut listenfd = listenfd::ListenFd::from_env(); + let mut servers: Vec> = vec![]; - let tcp_listener = if let Ok(Some(listener)) = listenfd.take_tcp_listener(0) { - listener - } else { - std::net::TcpListener::bind(listen_addr).unwrap() - }; + let build_hyper = |tcp: std::net::TcpListener| { + tracing::info!("Listening on {}", tcp.local_addr().unwrap()); // Set the socket to non-blocking so tokio can poll it // properly -- this is the async magic! - tcp_listener.set_nonblocking(true).unwrap(); + tcp.set_nonblocking(true).unwrap(); - tcp_listener + hyper::server::Server::from_tcp(tcp).unwrap() + // Otherwise Chrome keeps connections open for too long + .tcp_keepalive(Some(Duration::from_secs(30 * 60))) + .serve(router.clone().into_make_service()) }; - info!("Listening on {}", tcp_listener.local_addr().unwrap()); - - let server = hyper::server::Server::from_tcp(tcp_listener) - .unwrap() - // Otherwise Chrome keeps connections open for too long - .tcp_keepalive(Some(Duration::from_secs(30 * 60))) - .serve(router.into_make_service()) - .with_graceful_shutdown(async move { - // Defer to C-c handler whenever we're not on Unix - // TODO consider using a diverging future here - #[cfg(not(unix))] - return tokio::signal::ctrl_c().await.unwrap(); - #[cfg(unix)] - { - use tokio::signal::unix::{signal, SignalKind}; - - signal(SignalKind::terminate()) - .unwrap() - .recv() - .await + + let mut listenfd = listenfd::ListenFd::from_env(); + for i in 0..(listenfd.len()) { + match listenfd.take_tcp_listener(i) { + Ok(Some(tcp)) => servers.push(build_hyper(tcp)), + Ok(None) => {}, + Err(err) => { + tracing::error!("Error binding to socket in fd {}: {}", i, err); + } + } + } + if servers.is_empty() { + servers.push(build_hyper({ + let listen_addr = env::var("SERVE_AT") + .ok() + .unwrap_or_else(|| "[::]:8080".to_string()) + .parse::() + .unwrap_or_else(|e| { + error!("Cannot parse SERVE_AT: {}", e); + std::process::exit(1); + }); + + std::net::TcpListener::bind(listen_addr).unwrap() + })) + } + + // Polling streams mutates them + let mut servers_futures = Box::pin(servers.into_iter() + .map( + #[cfg(not(tokio_unstable))] |server| tokio::task::spawn( + server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) + ), + #[cfg(tokio_unstable)] |server| { + tokio::task::Builder::new() + // We leak the String here. It is acceptable, as the string + // is reasonably small and needs to live forever. + .name({ + let name = format!("Kittybox HTTP acceptor: {}", server.local_addr()); + + // Polyfill for unstablized [`String::leak`] + // + // SAFETY: the bytes come from a [`String`], which is valid UTF-8. + unsafe { std::str::from_utf8_unchecked(name.into_bytes().leak()) } + }) + .spawn( + server.with_graceful_shutdown( + cancellation_token.clone().cancelled_owned() + ) + ) .unwrap() } - }); + ) + .collect::>>>() + ); + + #[cfg(not(unix))] + let shutdown_signal = tokio::signal::ctrl_c(); + #[cfg(unix)] + let shutdown_signal = { + use tokio::signal::unix::{signal, SignalKind}; + + async move { + let mut interrupt = signal(SignalKind::interrupt()) + .expect("Failed to set up SIGINT handler"); + let mut terminate = signal(SignalKind::terminate()) + .expect("Failed to setup SIGTERM handler"); - if let Err(err) = server.await { - error!("Error serving requests: {}", err); - std::process::exit(1); + tokio::select! { + _ = terminate.recv() => {}, + _ = interrupt.recv() => {}, + } + } + }; + use futures_util::stream::StreamExt; + + tokio::select! { + // Poll the servers stream for errors. + // If any error out, shut down the entire operation + // + // We do this because there might not be a good way + // to recover from some errors without external help + Some(Err(e)) = servers_futures.next() => { + tracing::error!("Error in HTTP server: {}", e); + tracing::error!("Shutting down because of error."); + cancellation_token.cancel(); + let _ = Box::pin(futures_util::future::join_all( + servers_futures.iter_mut().collect::>() + )).await; + + std::process::exit(1); + } + _ = shutdown_signal => { + info!("Shutdown requested by signal."); + cancellation_token.cancel(); + let _ = Box::pin(futures_util::future::join_all( + servers_futures.iter_mut().collect::>() + )).await; + + info!("Shutdown complete, exiting."); + std::process::exit(0); + } } } -- cgit 1.4.1