about summary refs log tree commit diff
path: root/kittybox-rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs')
-rw-r--r--kittybox-rs/src/main.rs163
1 files changed, 117 insertions, 46 deletions
diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs
index 7131200..c588ad8 100644
--- a/kittybox-rs/src/main.rs
+++ b/kittybox-rs/src/main.rs
@@ -134,7 +134,12 @@ where A: kittybox::indieauth::backend::AuthBackend
     }
 }
 
-async fn compose_kittybox(backend_uri: &str, blobstore_uri: &str, authstore_uri: &str) -> axum::Router {
+async fn compose_kittybox(
+    backend_uri: &str,
+    blobstore_uri: &str,
+    authstore_uri: &str,
+    cancellation_token: &tokio_util::sync::CancellationToken
+) -> axum::Router {
     let http: reqwest::Client = {
         #[allow(unused_mut)]
         let mut builder = reqwest::Client::builder().user_agent(concat!(
@@ -169,9 +174,7 @@ async fn compose_kittybox(backend_uri: &str, blobstore_uri: &str, authstore_uri:
         )
         .route("/.kittybox/coffee", teapot_route())
         .nest("/.kittybox/micropub/client", kittybox::companion::router())
-        .layer(tower::ServiceBuilder::new()
-            .layer(tower_http::trace::TraceLayer::new_for_http())
-            .into_inner())
+        .layer(tower_http::trace::TraceLayer::new_for_http())
         .layer(tower_http::catch_panic::CatchPanicLayer::new())
 }
 
@@ -230,63 +233,131 @@ async fn main() {
             std::process::exit(1);
         });
 
-    let listen_addr = env::var("SERVE_AT")
-        .ok()
-        .unwrap_or_else(|| "[::]:8080".to_string())
-        .parse::<std::net::SocketAddr>()
-        .unwrap_or_else(|e| {
-            error!("Cannot parse SERVE_AT: {}", e);
             std::process::exit(1);
         });
 
+    let cancellation_token = tokio_util::sync::CancellationToken::new();
+
     let router = compose_kittybox(
         backend_uri.as_str(),
         blobstore_uri.as_str(),
-        authstore_uri.as_str()
+        authstore_uri.as_str(),
+        &cancellation_token
     ).await;
 
-    // A little dance to turn a potential file descriptor into
-    // a guaranteed async network socket
-    let tcp_listener: std::net::TcpListener = {
-        let mut listenfd = listenfd::ListenFd::from_env();
+    let mut servers: Vec<hyper::server::Server<hyper::server::conn::AddrIncoming, _>> = vec![];
 
-        let tcp_listener = if let Ok(Some(listener)) = listenfd.take_tcp_listener(0) {
-            listener
-        } else {
-            std::net::TcpListener::bind(listen_addr).unwrap()
-        };
+    let build_hyper = |tcp: std::net::TcpListener| {
+        tracing::info!("Listening on {}", tcp.local_addr().unwrap());
         // Set the socket to non-blocking so tokio can poll it
         // properly -- this is the async magic!
-        tcp_listener.set_nonblocking(true).unwrap();
+        tcp.set_nonblocking(true).unwrap();
 
-        tcp_listener
+        hyper::server::Server::from_tcp(tcp).unwrap()
+            // Otherwise Chrome keeps connections open for too long
+            .tcp_keepalive(Some(Duration::from_secs(30 * 60)))
+            .serve(router.clone().into_make_service())
     };
-    info!("Listening on {}", tcp_listener.local_addr().unwrap());
-
-    let server = hyper::server::Server::from_tcp(tcp_listener)
-        .unwrap()
-        // Otherwise Chrome keeps connections open for too long
-        .tcp_keepalive(Some(Duration::from_secs(30 * 60)))
-        .serve(router.into_make_service())
-        .with_graceful_shutdown(async move {
-            // Defer to C-c handler whenever we're not on Unix
-            // TODO consider using a diverging future here
-            #[cfg(not(unix))]
-            return tokio::signal::ctrl_c().await.unwrap();
-            #[cfg(unix)]
-            {
-                use tokio::signal::unix::{signal, SignalKind};
-
-                signal(SignalKind::terminate())
-                    .unwrap()
-                    .recv()
-                    .await
+
+    let mut listenfd = listenfd::ListenFd::from_env();
+    for i in 0..(listenfd.len()) {
+        match listenfd.take_tcp_listener(i) {
+            Ok(Some(tcp)) => servers.push(build_hyper(tcp)),
+            Ok(None) => {},
+            Err(err) => {
+                tracing::error!("Error binding to socket in fd {}: {}", i, err);
+            }
+        }
+    }
+    if servers.is_empty() {
+        servers.push(build_hyper({
+            let listen_addr = env::var("SERVE_AT")
+                .ok()
+                .unwrap_or_else(|| "[::]:8080".to_string())
+                .parse::<std::net::SocketAddr>()
+                .unwrap_or_else(|e| {
+                    error!("Cannot parse SERVE_AT: {}", e);
+                    std::process::exit(1);
+                });
+
+            std::net::TcpListener::bind(listen_addr).unwrap()
+        }))
+    }
+
+    // Polling streams mutates them
+    let mut servers_futures = Box::pin(servers.into_iter()
+        .map(
+            #[cfg(not(tokio_unstable))] |server| tokio::task::spawn(
+                server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned())
+            ),
+            #[cfg(tokio_unstable)] |server| {
+                tokio::task::Builder::new()
+                    // We leak the String here. It is acceptable, as the string
+                    // is reasonably small and needs to live forever.
+                    .name({
+                        let name = format!("Kittybox HTTP acceptor: {}", server.local_addr());
+
+                        // Polyfill for unstablized [`String::leak`]
+                        //
+                        // SAFETY: the bytes come from a [`String`], which is valid UTF-8.
+                        unsafe { std::str::from_utf8_unchecked(name.into_bytes().leak()) }
+                    })
+                    .spawn(
+                        server.with_graceful_shutdown(
+                            cancellation_token.clone().cancelled_owned()
+                        )
+                    )
                     .unwrap()
             }
-        });
+        )
+        .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), hyper::Error>>>>()
+    );
+
+    #[cfg(not(unix))]
+    let shutdown_signal = tokio::signal::ctrl_c();
+    #[cfg(unix)]
+    let shutdown_signal = {
+        use tokio::signal::unix::{signal, SignalKind};
+
+        async move {
+            let mut interrupt = signal(SignalKind::interrupt())
+                .expect("Failed to set up SIGINT handler");
+            let mut terminate = signal(SignalKind::terminate())
+                .expect("Failed to setup SIGTERM handler");
 
-    if let Err(err) = server.await {
-        error!("Error serving requests: {}", err);
-        std::process::exit(1);
+            tokio::select! {
+                _ = terminate.recv() => {},
+                _ = interrupt.recv() => {},
+            }
+        }
+    };
+    use futures_util::stream::StreamExt;
+
+    tokio::select! {
+        // Poll the servers stream for errors.
+        // If any error out, shut down the entire operation
+        //
+        // We do this because there might not be a good way
+        // to recover from some errors without external help
+        Some(Err(e)) = servers_futures.next() => {
+            tracing::error!("Error in HTTP server: {}", e);
+            tracing::error!("Shutting down because of error.");
+            cancellation_token.cancel();
+            let _ = Box::pin(futures_util::future::join_all(
+                servers_futures.iter_mut().collect::<Vec<_>>()
+            )).await;
+
+            std::process::exit(1);
+        }
+        _ = shutdown_signal => {
+            info!("Shutdown requested by signal.");
+            cancellation_token.cancel();
+            let _ = Box::pin(futures_util::future::join_all(
+                servers_futures.iter_mut().collect::<Vec<_>>()
+            )).await;
+
+            info!("Shutdown complete, exiting.");
+            std::process::exit(0);
+        }
     }
 }