From 94ebc5e653191fcaacfa91dddebf88dca7e7b7fe Mon Sep 17 00:00:00 2001 From: Vika Date: Mon, 17 Jul 2023 01:52:09 +0300 Subject: Put Micropub background processing tasks in a JoinSet MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows using tree-structured concurrency to keep background tasks in check and allow them to finish running before shutting down — a necessary prerequisite for shutdown-on-idle. (A background task may take a bit too long to complete, and we may need to wait for it.) --- kittybox-rs/src/frontend/onboarding.rs | 18 +++++++++--- kittybox-rs/src/main.rs | 54 ++++++++++++++++++++++------------ kittybox-rs/src/micropub/mod.rs | 27 +++++++++++++---- 3 files changed, 71 insertions(+), 28 deletions(-) (limited to 'kittybox-rs/src') diff --git a/kittybox-rs/src/frontend/onboarding.rs b/kittybox-rs/src/frontend/onboarding.rs index d5cde02..e44e866 100644 --- a/kittybox-rs/src/frontend/onboarding.rs +++ b/kittybox-rs/src/frontend/onboarding.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::database::{settings, Storage}; use axum::{ extract::{Extension, Host}, @@ -7,6 +9,7 @@ use axum::{ }; use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template}; use serde::Deserialize; +use tokio::{task::JoinSet, sync::Mutex}; use tracing::{debug, error}; use super::FrontendError; @@ -51,6 +54,7 @@ async fn onboard( user_uid: url::Url, data: OnboardingData, http: reqwest::Client, + jobset: Arc>>, ) -> Result<(), FrontendError> { // Create a user to pass to the backend // At this point the site belongs to nobody, so it is safe to do @@ -115,7 +119,7 @@ async fn onboard( } let (uid, post) = crate::micropub::normalize_mf2(data.first_post, &user); tracing::debug!("Posting first post {}...", uid); - crate::micropub::_post(&user, uid, post, db, http) + crate::micropub::_post(&user, uid, post, db, http, jobset) .await .map_err(|e| FrontendError { msg: "Error while posting the first post".to_string(), @@ -130,6 +134,7 @@ pub async fn post( Extension(db): Extension, Host(host): Host, Extension(http): Extension, + Extension(jobset): Extension>>>, Json(data): Json, ) -> axum::response::Response { let user_uid = format!("https://{}/", host.as_str()); @@ -137,7 +142,7 @@ pub async fn post( if db.post_exists(&user_uid).await.unwrap() { IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])) } else { - match onboard(db, user_uid.parse().unwrap(), data, http).await { + match onboard(db, user_uid.parse().unwrap(), data, http, jobset).await { Ok(()) => IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])), Err(err) => { error!("Onboarding error: {}", err); @@ -163,9 +168,14 @@ pub async fn post( } } -pub fn router(database: S, http: reqwest::Client) -> axum::routing::MethodRouter { +pub fn router( + database: S, + http: reqwest::Client, + jobset: Arc>>, +) -> axum::routing::MethodRouter { axum::routing::get(get) .post(post::) .layer::<_, _, std::convert::Infallible>(axum::Extension(database)) - .layer(axum::Extension(http)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) + .layer(axum::Extension(jobset)) } diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs index cc11f3a..d96a8fb 100644 --- a/kittybox-rs/src/main.rs +++ b/kittybox-rs/src/main.rs @@ -1,6 +1,6 @@ use kittybox::database::FileStorage; -use std::{env, time::Duration}; -use tracing::{debug, error, info}; +use std::{env, time::Duration, sync::Arc}; +use tracing::error; fn init_media(auth_backend: A, blobstore_uri: &str) -> axum::Router { match blobstore_uri.split_once(':').unwrap().0 { @@ -20,7 +20,8 @@ async fn compose_kittybox_with_auth( http: reqwest::Client, auth_backend: A, backend_uri: &str, - blobstore_uri: &str + blobstore_uri: &str, + jobset: &Arc>> ) -> axum::Router where A: kittybox::indieauth::backend::AuthBackend { @@ -58,10 +59,11 @@ where A: kittybox::indieauth::backend::AuthBackend let micropub = kittybox::micropub::router( database.clone(), http.clone(), - auth_backend.clone() + auth_backend.clone(), + Arc::clone(jobset) ); let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone() + database.clone(), http.clone(), Arc::clone(jobset) ); axum::Router::new() @@ -111,13 +113,14 @@ where A: kittybox::indieauth::backend::AuthBackend let micropub = kittybox::micropub::router( database.clone(), http.clone(), - auth_backend.clone() + auth_backend.clone(), + Arc::clone(jobset) ); let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone() + database.clone(), http.clone(), Arc::clone(jobset) ); - axum::Router::new() + let router = axum::Router::new() .route("/", homepage) .fallback(fallback) .route("/.kittybox/micropub", micropub) @@ -128,7 +131,9 @@ where A: kittybox::indieauth::backend::AuthBackend "/.kittybox/health", axum::routing::get(health_check::) .layer(axum::Extension(database)) - ) + ); + + router }, other => unimplemented!("Unsupported backend: {other}") } @@ -138,6 +143,7 @@ async fn compose_kittybox( backend_uri: &str, blobstore_uri: &str, authstore_uri: &str, + jobset: &Arc>>, cancellation_token: &tokio_util::sync::CancellationToken ) -> axum::Router { let http: reqwest::Client = { @@ -162,7 +168,7 @@ async fn compose_kittybox( kittybox::indieauth::backend::fs::FileBackend::new(folder) }; - compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri).await + compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, jobset).await } other => unimplemented!("Unsupported backend: {other}") }; @@ -214,7 +220,7 @@ async fn main() { .init(); //let _ = tracing_log::LogTracer::init(); - info!("Starting the kittybox server..."); + tracing::info!("Starting the kittybox server..."); let backend_uri: String = env::var("BACKEND_URI") .unwrap_or_else(|_| { @@ -237,11 +243,13 @@ async fn main() { }); let cancellation_token = tokio_util::sync::CancellationToken::new(); + let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new())); let router = compose_kittybox( backend_uri.as_str(), blobstore_uri.as_str(), authstore_uri.as_str(), + &jobset, &cancellation_token ).await; @@ -315,7 +323,9 @@ async fn main() { std::net::TcpListener::bind(listen_addr).unwrap() })) } - + // Drop the remaining copy of the router + // to get rid of an extra reference to `jobset` + drop(router); // Polling streams mutates them let mut servers_futures = Box::pin(servers.into_iter() .map( @@ -356,7 +366,7 @@ async fn main() { }; use futures_util::stream::StreamExt; - tokio::select! { + let exitcode: i32 = tokio::select! { // Poll the servers stream for errors. // If any error out, shut down the entire operation // @@ -370,17 +380,25 @@ async fn main() { servers_futures.iter_mut().collect::>() )).await; - std::process::exit(1); + 1 } _ = shutdown_signal => { - info!("Shutdown requested by signal."); + tracing::info!("Shutdown requested by signal."); cancellation_token.cancel(); let _ = Box::pin(futures_util::future::join_all( servers_futures.iter_mut().collect::>() )).await; - info!("Shutdown complete, exiting."); - std::process::exit(0); + 0 } - } + }; + + tracing::info!("Waiting for unfinished background tasks..."); + let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset) + .expect("Dangling jobset references present") + .into_inner(); + while (jobset.join_next().await).is_some() {} + tracing::info!("Shutdown complete, exiting."); + std::process::exit(exitcode); + } diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs index 04bf0a5..02eee6e 100644 --- a/kittybox-rs/src/micropub/mod.rs +++ b/kittybox-rs/src/micropub/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; @@ -11,6 +12,8 @@ use axum::TypedHeader; use axum::{http::StatusCode, Extension}; use serde::{Deserialize, Serialize}; use serde_json::json; +use tokio::sync::Mutex; +use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; use kittybox_indieauth::{Scope, TokenData}; use kittybox_util::{MicropubError, ErrorType}; @@ -228,6 +231,7 @@ pub(crate) async fn _post( mf2: serde_json::Value, db: D, http: reqwest::Client, + jobset: Arc>>, ) -> Result { // Here, we have the following guarantees: // - The MF2-JSON document is normalized (guaranteed by normalize_mf2) @@ -313,7 +317,12 @@ pub(crate) async fn _post( let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); - tokio::task::spawn(background_processing(db, mf2, http)); + #[cfg(not(tokio_unstable))] + jobset.lock().await.spawn(background_processing(db, mf2, http)); + #[cfg(tokio_unstable)] + jobset.lock().await.build_task() + .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) + .spawn(background_processing(db, mf2, http)); Ok(reply) } @@ -497,6 +506,7 @@ async fn dispatch_body( pub(crate) async fn post( Extension(db): Extension, Extension(http): Extension, + Extension(jobset): Extension>>>, TypedHeader(content_type): TypedHeader, user: User, body: BodyStream, @@ -508,7 +518,7 @@ pub(crate) async fn post( }, Ok(PostBody::MF2(mf2)) => { let (uid, mf2) = normalize_mf2(mf2, &user); - match _post(&user, uid, mf2, db, http).await { + match _post(&user, uid, mf2, db, http, jobset).await { Ok(response) => response, Err(err) => err.into_response(), } @@ -631,7 +641,8 @@ pub(crate) async fn query( pub fn router( storage: S, http: reqwest::Client, - auth: A + auth: A, + jobset: Arc>> ) -> axum::routing::MethodRouter where S: Storage + 'static, @@ -648,6 +659,7 @@ where .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) } #[cfg(test)] @@ -670,9 +682,12 @@ impl MicropubQuery { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{database::Storage, micropub::MicropubError}; use hyper::body::HttpBody; use serde_json::json; + use tokio::sync::Mutex; use super::FetchedPostContext; use kittybox_indieauth::{Scopes, Scope, TokenData}; @@ -733,7 +748,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -763,7 +778,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -791,7 +806,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap(); -- cgit 1.4.1