diff options
Diffstat (limited to 'kittybox-rs/src/micropub/mod.rs')
-rw-r--r-- | kittybox-rs/src/micropub/mod.rs | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs index 04bf0a5..02eee6e 100644 --- a/kittybox-rs/src/micropub/mod.rs +++ b/kittybox-rs/src/micropub/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; @@ -11,6 +12,8 @@ use axum::TypedHeader; use axum::{http::StatusCode, Extension}; use serde::{Deserialize, Serialize}; use serde_json::json; +use tokio::sync::Mutex; +use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; use kittybox_indieauth::{Scope, TokenData}; use kittybox_util::{MicropubError, ErrorType}; @@ -228,6 +231,7 @@ pub(crate) async fn _post<D: 'static + Storage>( mf2: serde_json::Value, db: D, http: reqwest::Client, + jobset: Arc<Mutex<JoinSet<()>>>, ) -> Result<Response, MicropubError> { // Here, we have the following guarantees: // - The MF2-JSON document is normalized (guaranteed by normalize_mf2) @@ -313,7 +317,12 @@ pub(crate) async fn _post<D: 'static + Storage>( let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); - tokio::task::spawn(background_processing(db, mf2, http)); + #[cfg(not(tokio_unstable))] + jobset.lock().await.spawn(background_processing(db, mf2, http)); + #[cfg(tokio_unstable)] + jobset.lock().await.build_task() + .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) + .spawn(background_processing(db, mf2, http)); Ok(reply) } @@ -497,6 +506,7 @@ async fn dispatch_body( pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( Extension(db): Extension<D>, Extension(http): Extension<reqwest::Client>, + Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, TypedHeader(content_type): TypedHeader<ContentType>, user: User<A>, body: BodyStream, @@ -508,7 +518,7 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( }, Ok(PostBody::MF2(mf2)) => { let (uid, mf2) = normalize_mf2(mf2, &user); - match _post(&user, uid, mf2, db, http).await { + match _post(&user, uid, mf2, db, http, jobset).await { Ok(response) => response, Err(err) => err.into_response(), } @@ -631,7 +641,8 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( pub fn router<S, A>( storage: S, http: reqwest::Client, - auth: A + auth: A, + jobset: Arc<Mutex<JoinSet<()>>> ) -> axum::routing::MethodRouter where S: Storage + 'static, @@ -648,6 +659,7 @@ where .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) } #[cfg(test)] @@ -670,9 +682,12 @@ impl MicropubQuery { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{database::Storage, micropub::MicropubError}; use hyper::body::HttpBody; use serde_json::json; + use tokio::sync::Mutex; use super::FetchedPostContext; use kittybox_indieauth::{Scopes, Scope, TokenData}; @@ -733,7 +748,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -763,7 +778,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -791,7 +806,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap(); |