about summary refs log tree commit diff
path: root/kittybox-rs/src
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2023-07-17 01:52:09 +0300
committerVika <vika@fireburn.ru>2023-07-17 01:53:42 +0300
commit94ebc5e653191fcaacfa91dddebf88dca7e7b7fe (patch)
tree7cc58973f3a809c14acda21001349114f13d7e40 /kittybox-rs/src
parentb13b4fc66dd069e6d5263a8f6a9cc9a6da798e27 (diff)
downloadkittybox-94ebc5e653191fcaacfa91dddebf88dca7e7b7fe.tar.zst
Put Micropub background processing tasks in a JoinSet
This allows using tree-structured concurrency to keep background tasks
in check and allow them to finish running before shutting down — a
necessary prerequisite for shutdown-on-idle. (A background task may
take a bit too long to complete, and we may need to wait for it.)
Diffstat (limited to 'kittybox-rs/src')
-rw-r--r--kittybox-rs/src/frontend/onboarding.rs18
-rw-r--r--kittybox-rs/src/main.rs54
-rw-r--r--kittybox-rs/src/micropub/mod.rs27
3 files changed, 71 insertions, 28 deletions
diff --git a/kittybox-rs/src/frontend/onboarding.rs b/kittybox-rs/src/frontend/onboarding.rs
index d5cde02..e44e866 100644
--- a/kittybox-rs/src/frontend/onboarding.rs
+++ b/kittybox-rs/src/frontend/onboarding.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use crate::database::{settings, Storage};
 use axum::{
     extract::{Extension, Host},
@@ -7,6 +9,7 @@ use axum::{
 };
 use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template};
 use serde::Deserialize;
+use tokio::{task::JoinSet, sync::Mutex};
 use tracing::{debug, error};
 
 use super::FrontendError;
@@ -51,6 +54,7 @@ async fn onboard<D: Storage + 'static>(
     user_uid: url::Url,
     data: OnboardingData,
     http: reqwest::Client,
+    jobset: Arc<Mutex<JoinSet<()>>>,
 ) -> Result<(), FrontendError> {
     // Create a user to pass to the backend
     // At this point the site belongs to nobody, so it is safe to do
@@ -115,7 +119,7 @@ async fn onboard<D: Storage + 'static>(
     }
     let (uid, post) = crate::micropub::normalize_mf2(data.first_post, &user);
     tracing::debug!("Posting first post {}...", uid);
-    crate::micropub::_post(&user, uid, post, db, http)
+    crate::micropub::_post(&user, uid, post, db, http, jobset)
         .await
         .map_err(|e| FrontendError {
             msg: "Error while posting the first post".to_string(),
@@ -130,6 +134,7 @@ pub async fn post<D: Storage + 'static>(
     Extension(db): Extension<D>,
     Host(host): Host,
     Extension(http): Extension<reqwest::Client>,
+    Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>,
     Json(data): Json<OnboardingData>,
 ) -> axum::response::Response {
     let user_uid = format!("https://{}/", host.as_str());
@@ -137,7 +142,7 @@ pub async fn post<D: Storage + 'static>(
     if db.post_exists(&user_uid).await.unwrap() {
         IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")]))
     } else {
-        match onboard(db, user_uid.parse().unwrap(), data, http).await {
+        match onboard(db, user_uid.parse().unwrap(), data, http, jobset).await {
             Ok(()) => IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])),
             Err(err) => {
                 error!("Onboarding error: {}", err);
@@ -163,9 +168,14 @@ pub async fn post<D: Storage + 'static>(
     }
 }
 
-pub fn router<S: Storage + 'static>(database: S, http: reqwest::Client) -> axum::routing::MethodRouter {
+pub fn router<S: Storage + 'static>(
+    database: S,
+    http: reqwest::Client,
+    jobset: Arc<Mutex<JoinSet<()>>>,
+) -> axum::routing::MethodRouter {
     axum::routing::get(get)
         .post(post::<S>)
         .layer::<_, _, std::convert::Infallible>(axum::Extension(database))
-        .layer(axum::Extension(http))
+        .layer::<_, _, std::convert::Infallible>(axum::Extension(http))
+        .layer(axum::Extension(jobset))
 }
diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs
index cc11f3a..d96a8fb 100644
--- a/kittybox-rs/src/main.rs
+++ b/kittybox-rs/src/main.rs
@@ -1,6 +1,6 @@
 use kittybox::database::FileStorage;
-use std::{env, time::Duration};
-use tracing::{debug, error, info};
+use std::{env, time::Duration, sync::Arc};
+use tracing::error;
 
 fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router {
     match blobstore_uri.split_once(':').unwrap().0 {
@@ -20,7 +20,8 @@ async fn compose_kittybox_with_auth<A>(
     http: reqwest::Client,
     auth_backend: A,
     backend_uri: &str,
-    blobstore_uri: &str
+    blobstore_uri: &str,
+    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>
 ) -> axum::Router
 where A: kittybox::indieauth::backend::AuthBackend
 {
@@ -58,10 +59,11 @@ where A: kittybox::indieauth::backend::AuthBackend
             let micropub = kittybox::micropub::router(
                 database.clone(),
                 http.clone(),
-                auth_backend.clone()
+                auth_backend.clone(),
+                Arc::clone(jobset)
             );
             let onboarding = kittybox::frontend::onboarding::router(
-                database.clone(), http.clone()
+                database.clone(), http.clone(), Arc::clone(jobset)
             );
 
             axum::Router::new()
@@ -111,13 +113,14 @@ where A: kittybox::indieauth::backend::AuthBackend
             let micropub = kittybox::micropub::router(
                 database.clone(),
                 http.clone(),
-                auth_backend.clone()
+                auth_backend.clone(),
+                Arc::clone(jobset)
             );
             let onboarding = kittybox::frontend::onboarding::router(
-                database.clone(), http.clone()
+                database.clone(), http.clone(), Arc::clone(jobset)
             );
 
-            axum::Router::new()
+            let router = axum::Router::new()
                 .route("/", homepage)
                 .fallback(fallback)
                 .route("/.kittybox/micropub", micropub)
@@ -128,7 +131,9 @@ where A: kittybox::indieauth::backend::AuthBackend
                     "/.kittybox/health",
                     axum::routing::get(health_check::<kittybox::database::PostgresStorage>)
                         .layer(axum::Extension(database))
-                )
+                );
+
+            router
         },
         other => unimplemented!("Unsupported backend: {other}")
     }
@@ -138,6 +143,7 @@ async fn compose_kittybox(
     backend_uri: &str,
     blobstore_uri: &str,
     authstore_uri: &str,
+    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>,
     cancellation_token: &tokio_util::sync::CancellationToken
 ) -> axum::Router {
     let http: reqwest::Client = {
@@ -162,7 +168,7 @@ async fn compose_kittybox(
                 kittybox::indieauth::backend::fs::FileBackend::new(folder)
             };
 
-            compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri).await
+            compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, jobset).await
         }
         other => unimplemented!("Unsupported backend: {other}")
     };
@@ -214,7 +220,7 @@ async fn main() {
         .init();
     //let _ = tracing_log::LogTracer::init();
 
-    info!("Starting the kittybox server...");
+    tracing::info!("Starting the kittybox server...");
 
     let backend_uri: String = env::var("BACKEND_URI")
         .unwrap_or_else(|_| {
@@ -237,11 +243,13 @@ async fn main() {
         });
 
     let cancellation_token = tokio_util::sync::CancellationToken::new();
+    let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new()));
 
     let router = compose_kittybox(
         backend_uri.as_str(),
         blobstore_uri.as_str(),
         authstore_uri.as_str(),
+        &jobset,
         &cancellation_token
     ).await;
 
@@ -315,7 +323,9 @@ async fn main() {
             std::net::TcpListener::bind(listen_addr).unwrap()
         }))
     }
-
+    // Drop the remaining copy of the router
+    // to get rid of an extra reference to `jobset`
+    drop(router);
     // Polling streams mutates them
     let mut servers_futures = Box::pin(servers.into_iter()
         .map(
@@ -356,7 +366,7 @@ async fn main() {
     };
     use futures_util::stream::StreamExt;
 
-    tokio::select! {
+    let exitcode: i32 = tokio::select! {
         // Poll the servers stream for errors.
         // If any error out, shut down the entire operation
         //
@@ -370,17 +380,25 @@ async fn main() {
                 servers_futures.iter_mut().collect::<Vec<_>>()
             )).await;
 
-            std::process::exit(1);
+            1
         }
         _ = shutdown_signal => {
-            info!("Shutdown requested by signal.");
+            tracing::info!("Shutdown requested by signal.");
             cancellation_token.cancel();
             let _ = Box::pin(futures_util::future::join_all(
                 servers_futures.iter_mut().collect::<Vec<_>>()
             )).await;
 
-            info!("Shutdown complete, exiting.");
-            std::process::exit(0);
+            0
         }
-    }
+    };
+
+    tracing::info!("Waiting for unfinished background tasks...");
+    let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset)
+        .expect("Dangling jobset references present")
+        .into_inner();
+    while (jobset.join_next().await).is_some() {}
+    tracing::info!("Shutdown complete, exiting.");
+    std::process::exit(exitcode);
+
 }
diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs
index 04bf0a5..02eee6e 100644
--- a/kittybox-rs/src/micropub/mod.rs
+++ b/kittybox-rs/src/micropub/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use crate::database::{MicropubChannel, Storage, StorageError};
 use crate::indieauth::backend::AuthBackend;
@@ -11,6 +12,8 @@ use axum::TypedHeader;
 use axum::{http::StatusCode, Extension};
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use tokio::sync::Mutex;
+use tokio::task::JoinSet;
 use tracing::{debug, error, info, warn};
 use kittybox_indieauth::{Scope, TokenData};
 use kittybox_util::{MicropubError, ErrorType};
@@ -228,6 +231,7 @@ pub(crate) async fn _post<D: 'static + Storage>(
     mf2: serde_json::Value,
     db: D,
     http: reqwest::Client,
+    jobset: Arc<Mutex<JoinSet<()>>>,
 ) -> Result<Response, MicropubError> {
     // Here, we have the following guarantees:
     // - The MF2-JSON document is normalized (guaranteed by normalize_mf2)
@@ -313,7 +317,12 @@ pub(crate) async fn _post<D: 'static + Storage>(
     let reply =
         IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())]));
 
-    tokio::task::spawn(background_processing(db, mf2, http));
+    #[cfg(not(tokio_unstable))]
+    jobset.lock().await.spawn(background_processing(db, mf2, http));
+    #[cfg(tokio_unstable)]
+    jobset.lock().await.build_task()
+        .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str())
+        .spawn(background_processing(db, mf2, http));
 
     Ok(reply)
 }
@@ -497,6 +506,7 @@ async fn dispatch_body(
 pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>(
     Extension(db): Extension<D>,
     Extension(http): Extension<reqwest::Client>,
+    Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>,
     TypedHeader(content_type): TypedHeader<ContentType>,
     user: User<A>,
     body: BodyStream,
@@ -508,7 +518,7 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>(
         },
         Ok(PostBody::MF2(mf2)) => {
             let (uid, mf2) = normalize_mf2(mf2, &user);
-            match _post(&user, uid, mf2, db, http).await {
+            match _post(&user, uid, mf2, db, http, jobset).await {
                 Ok(response) => response,
                 Err(err) => err.into_response(),
             }
@@ -631,7 +641,8 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>(
 pub fn router<S, A>(
     storage: S,
     http: reqwest::Client,
-    auth: A
+    auth: A,
+    jobset: Arc<Mutex<JoinSet<()>>>
 ) -> axum::routing::MethodRouter
 where
     S: Storage + 'static,
@@ -648,6 +659,7 @@ where
         .layer::<_, _, std::convert::Infallible>(axum::Extension(storage))
         .layer::<_, _, std::convert::Infallible>(axum::Extension(http))
         .layer::<_, _, std::convert::Infallible>(axum::Extension(auth))
+        .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset))
 }
 
 #[cfg(test)]
@@ -670,9 +682,12 @@ impl MicropubQuery {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
     use crate::{database::Storage, micropub::MicropubError};
     use hyper::body::HttpBody;
     use serde_json::json;
+    use tokio::sync::Mutex;
 
     use super::FetchedPostContext;
     use kittybox_indieauth::{Scopes, Scope, TokenData};
@@ -733,7 +748,7 @@ mod tests {
         };
         let (uid, mf2) = super::normalize_mf2(post, &user);
 
-        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new())
+        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new())))
             .await
             .unwrap_err();
 
@@ -763,7 +778,7 @@ mod tests {
         };
         let (uid, mf2) = super::normalize_mf2(post, &user);
 
-        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new())
+        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new())))
             .await
             .unwrap_err();
 
@@ -791,7 +806,7 @@ mod tests {
         };
         let (uid, mf2) = super::normalize_mf2(post, &user);
 
-        let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new())
+        let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new())))
             .await
             .unwrap();