about summary refs log tree commit diff
path: root/kittybox-rs/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/main.rs')
-rw-r--r--kittybox-rs/src/main.rs54
1 files changed, 36 insertions, 18 deletions
diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs
index cc11f3a..d96a8fb 100644
--- a/kittybox-rs/src/main.rs
+++ b/kittybox-rs/src/main.rs
@@ -1,6 +1,6 @@
 use kittybox::database::FileStorage;
-use std::{env, time::Duration};
-use tracing::{debug, error, info};
+use std::{env, time::Duration, sync::Arc};
+use tracing::error;
 
 fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router {
     match blobstore_uri.split_once(':').unwrap().0 {
@@ -20,7 +20,8 @@ async fn compose_kittybox_with_auth<A>(
     http: reqwest::Client,
     auth_backend: A,
     backend_uri: &str,
-    blobstore_uri: &str
+    blobstore_uri: &str,
+    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>
 ) -> axum::Router
 where A: kittybox::indieauth::backend::AuthBackend
 {
@@ -58,10 +59,11 @@ where A: kittybox::indieauth::backend::AuthBackend
             let micropub = kittybox::micropub::router(
                 database.clone(),
                 http.clone(),
-                auth_backend.clone()
+                auth_backend.clone(),
+                Arc::clone(jobset)
             );
             let onboarding = kittybox::frontend::onboarding::router(
-                database.clone(), http.clone()
+                database.clone(), http.clone(), Arc::clone(jobset)
             );
 
             axum::Router::new()
@@ -111,13 +113,14 @@ where A: kittybox::indieauth::backend::AuthBackend
             let micropub = kittybox::micropub::router(
                 database.clone(),
                 http.clone(),
-                auth_backend.clone()
+                auth_backend.clone(),
+                Arc::clone(jobset)
             );
             let onboarding = kittybox::frontend::onboarding::router(
-                database.clone(), http.clone()
+                database.clone(), http.clone(), Arc::clone(jobset)
             );
 
-            axum::Router::new()
+            let router = axum::Router::new()
                 .route("/", homepage)
                 .fallback(fallback)
                 .route("/.kittybox/micropub", micropub)
@@ -128,7 +131,9 @@ where A: kittybox::indieauth::backend::AuthBackend
                     "/.kittybox/health",
                     axum::routing::get(health_check::<kittybox::database::PostgresStorage>)
                         .layer(axum::Extension(database))
-                )
+                );
+
+            router
         },
         other => unimplemented!("Unsupported backend: {other}")
     }
@@ -138,6 +143,7 @@ async fn compose_kittybox(
     backend_uri: &str,
     blobstore_uri: &str,
     authstore_uri: &str,
+    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>,
     cancellation_token: &tokio_util::sync::CancellationToken
 ) -> axum::Router {
     let http: reqwest::Client = {
@@ -162,7 +168,7 @@ async fn compose_kittybox(
                 kittybox::indieauth::backend::fs::FileBackend::new(folder)
             };
 
-            compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri).await
+            compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, jobset).await
         }
         other => unimplemented!("Unsupported backend: {other}")
     };
@@ -214,7 +220,7 @@ async fn main() {
         .init();
     //let _ = tracing_log::LogTracer::init();
 
-    info!("Starting the kittybox server...");
+    tracing::info!("Starting the kittybox server...");
 
     let backend_uri: String = env::var("BACKEND_URI")
         .unwrap_or_else(|_| {
@@ -237,11 +243,13 @@ async fn main() {
         });
 
     let cancellation_token = tokio_util::sync::CancellationToken::new();
+    let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new()));
 
     let router = compose_kittybox(
         backend_uri.as_str(),
         blobstore_uri.as_str(),
         authstore_uri.as_str(),
+        &jobset,
         &cancellation_token
     ).await;
 
@@ -315,7 +323,9 @@ async fn main() {
             std::net::TcpListener::bind(listen_addr).unwrap()
         }))
     }
-
+    // Drop the remaining copy of the router
+    // to get rid of an extra reference to `jobset`
+    drop(router);
     // Polling streams mutates them
     let mut servers_futures = Box::pin(servers.into_iter()
         .map(
@@ -356,7 +366,7 @@ async fn main() {
     };
     use futures_util::stream::StreamExt;
 
-    tokio::select! {
+    let exitcode: i32 = tokio::select! {
         // Poll the servers stream for errors.
         // If any error out, shut down the entire operation
         //
@@ -370,17 +380,25 @@ async fn main() {
                 servers_futures.iter_mut().collect::<Vec<_>>()
             )).await;
 
-            std::process::exit(1);
+            1
         }
         _ = shutdown_signal => {
-            info!("Shutdown requested by signal.");
+            tracing::info!("Shutdown requested by signal.");
             cancellation_token.cancel();
             let _ = Box::pin(futures_util::future::join_all(
                 servers_futures.iter_mut().collect::<Vec<_>>()
             )).await;
 
-            info!("Shutdown complete, exiting.");
-            std::process::exit(0);
+            0
         }
-    }
+    };
+
+    tracing::info!("Waiting for unfinished background tasks...");
+    let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset)
+        .expect("Dangling jobset references present")
+        .into_inner();
+    while (jobset.join_next().await).is_some() {}
+    tracing::info!("Shutdown complete, exiting.");
+    std::process::exit(exitcode);
+
 }