From 94ebc5e653191fcaacfa91dddebf88dca7e7b7fe Mon Sep 17 00:00:00 2001 From: Vika Date: Mon, 17 Jul 2023 01:52:09 +0300 Subject: Put Micropub background processing tasks in a JoinSet MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows using tree-structured concurrency to keep background tasks in check and allow them to finish running before shutting down — a necessary prerequisite for shutdown-on-idle. (A background task may take a bit too long to complete, and we may need to wait for it.) --- kittybox-rs/src/micropub/mod.rs | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'kittybox-rs/src/micropub/mod.rs') diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs index 04bf0a5..02eee6e 100644 --- a/kittybox-rs/src/micropub/mod.rs +++ b/kittybox-rs/src/micropub/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; @@ -11,6 +12,8 @@ use axum::TypedHeader; use axum::{http::StatusCode, Extension}; use serde::{Deserialize, Serialize}; use serde_json::json; +use tokio::sync::Mutex; +use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; use kittybox_indieauth::{Scope, TokenData}; use kittybox_util::{MicropubError, ErrorType}; @@ -228,6 +231,7 @@ pub(crate) async fn _post( mf2: serde_json::Value, db: D, http: reqwest::Client, + jobset: Arc>>, ) -> Result { // Here, we have the following guarantees: // - The MF2-JSON document is normalized (guaranteed by normalize_mf2) @@ -313,7 +317,12 @@ pub(crate) async fn _post( let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); - tokio::task::spawn(background_processing(db, mf2, http)); + #[cfg(not(tokio_unstable))] + jobset.lock().await.spawn(background_processing(db, mf2, http)); + #[cfg(tokio_unstable)] + jobset.lock().await.build_task() + .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) + .spawn(background_processing(db, mf2, http)); Ok(reply) } @@ -497,6 +506,7 @@ async fn dispatch_body( pub(crate) async fn post( Extension(db): Extension, Extension(http): Extension, + Extension(jobset): Extension>>>, TypedHeader(content_type): TypedHeader, user: User, body: BodyStream, @@ -508,7 +518,7 @@ pub(crate) async fn post( }, Ok(PostBody::MF2(mf2)) => { let (uid, mf2) = normalize_mf2(mf2, &user); - match _post(&user, uid, mf2, db, http).await { + match _post(&user, uid, mf2, db, http, jobset).await { Ok(response) => response, Err(err) => err.into_response(), } @@ -631,7 +641,8 @@ pub(crate) async fn query( pub fn router( storage: S, http: reqwest::Client, - auth: A + auth: A, + jobset: Arc>> ) -> axum::routing::MethodRouter where S: Storage + 'static, @@ -648,6 +659,7 @@ where .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) } #[cfg(test)] @@ -670,9 +682,12 @@ impl MicropubQuery { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{database::Storage, micropub::MicropubError}; use hyper::body::HttpBody; use serde_json::json; + use tokio::sync::Mutex; use super::FetchedPostContext; use kittybox_indieauth::{Scopes, Scope, TokenData}; @@ -733,7 +748,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -763,7 +778,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -791,7 +806,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap(); -- cgit 1.4.1