about summary refs log tree commit diff
path: root/kittybox-rs/src/micropub
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/micropub')
-rw-r--r--kittybox-rs/src/micropub/mod.rs27
1 files changed, 21 insertions, 6 deletions
diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs
index 04bf0a5..02eee6e 100644
--- a/kittybox-rs/src/micropub/mod.rs
+++ b/kittybox-rs/src/micropub/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use crate::database::{MicropubChannel, Storage, StorageError};
 use crate::indieauth::backend::AuthBackend;
@@ -11,6 +12,8 @@ use axum::TypedHeader;
 use axum::{http::StatusCode, Extension};
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use tokio::sync::Mutex;
+use tokio::task::JoinSet;
 use tracing::{debug, error, info, warn};
 use kittybox_indieauth::{Scope, TokenData};
 use kittybox_util::{MicropubError, ErrorType};
@@ -228,6 +231,7 @@ pub(crate) async fn _post<D: 'static + Storage>(
     mf2: serde_json::Value,
     db: D,
     http: reqwest::Client,
+    jobset: Arc<Mutex<JoinSet<()>>>,
 ) -> Result<Response, MicropubError> {
     // Here, we have the following guarantees:
     // - The MF2-JSON document is normalized (guaranteed by normalize_mf2)
@@ -313,7 +317,12 @@ pub(crate) async fn _post<D: 'static + Storage>(
     let reply =
         IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())]));
 
-    tokio::task::spawn(background_processing(db, mf2, http));
+    #[cfg(not(tokio_unstable))]
+    jobset.lock().await.spawn(background_processing(db, mf2, http));
+    #[cfg(tokio_unstable)]
+    jobset.lock().await.build_task()
+        .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str())
+        .spawn(background_processing(db, mf2, http));
 
     Ok(reply)
 }
@@ -497,6 +506,7 @@ async fn dispatch_body(
 pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>(
     Extension(db): Extension<D>,
     Extension(http): Extension<reqwest::Client>,
+    Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>,
     TypedHeader(content_type): TypedHeader<ContentType>,
     user: User<A>,
     body: BodyStream,
@@ -508,7 +518,7 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>(
         },
         Ok(PostBody::MF2(mf2)) => {
             let (uid, mf2) = normalize_mf2(mf2, &user);
-            match _post(&user, uid, mf2, db, http).await {
+            match _post(&user, uid, mf2, db, http, jobset).await {
                 Ok(response) => response,
                 Err(err) => err.into_response(),
             }
@@ -631,7 +641,8 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>(
 pub fn router<S, A>(
     storage: S,
     http: reqwest::Client,
-    auth: A
+    auth: A,
+    jobset: Arc<Mutex<JoinSet<()>>>
 ) -> axum::routing::MethodRouter
 where
     S: Storage + 'static,
@@ -648,6 +659,7 @@ where
         .layer::<_, _, std::convert::Infallible>(axum::Extension(storage))
         .layer::<_, _, std::convert::Infallible>(axum::Extension(http))
         .layer::<_, _, std::convert::Infallible>(axum::Extension(auth))
+        .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset))
 }
 
 #[cfg(test)]
@@ -670,9 +682,12 @@ impl MicropubQuery {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
     use crate::{database::Storage, micropub::MicropubError};
     use hyper::body::HttpBody;
     use serde_json::json;
+    use tokio::sync::Mutex;
 
     use super::FetchedPostContext;
     use kittybox_indieauth::{Scopes, Scope, TokenData};
@@ -733,7 +748,7 @@ mod tests {
         };
         let (uid, mf2) = super::normalize_mf2(post, &user);
 
-        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new())
+        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new())))
             .await
             .unwrap_err();
 
@@ -763,7 +778,7 @@ mod tests {
         };
         let (uid, mf2) = super::normalize_mf2(post, &user);
 
-        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new())
+        let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new())))
             .await
             .unwrap_err();
 
@@ -791,7 +806,7 @@ mod tests {
         };
         let (uid, mf2) = super::normalize_mf2(post, &user);
 
-        let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new())
+        let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new())))
             .await
             .unwrap();