about summary refs log tree commit diff
path: root/kittybox-rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs')
-rw-r--r--kittybox-rs/src/webmentions/mod.rs67
1 files changed, 54 insertions, 13 deletions
diff --git a/kittybox-rs/src/webmentions/mod.rs b/kittybox-rs/src/webmentions/mod.rs
index d798c50..cffd064 100644
--- a/kittybox-rs/src/webmentions/mod.rs
+++ b/kittybox-rs/src/webmentions/mod.rs
@@ -36,25 +36,66 @@ async fn accept_webmention<Q: JobQueue<Webmention>>(
     }
 }
 
-pub fn router<Q: JobQueue<Webmention>>(queue: Q) -> axum::Router {
+pub fn router<Q: JobQueue<Webmention>>(queue: Q, cancellation_token: tokio_util::sync::CancellationToken) -> (axum::Router, SupervisedTask) {
     // Automatically spawn a background task to handle webmentions
-    tokio::task::spawn(supervised_webmentions_task(queue.clone()));
+    let bgtask_handle = supervised_webmentions_task(queue.clone(), cancellation_token);
 
-    axum::Router::new()
+    let router = axum::Router::new()
         .route("/.kittybox/webmention",
             axum::routing::post(accept_webmention::<Q>)
         )
-        .layer(Extension(queue))
+        .layer(Extension(queue));
+
+    (router, bgtask_handle)
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum SupervisorError {
+    #[error("the task was explicitly cancelled")]
+    Cancelled
 }
 
-async fn supervisor<E: std::error::Error + std::fmt::Debug + Send + 'static, A: futures_util::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static, F: FnMut() -> A>(mut f: F) -> std::convert::Infallible {
-    loop {
-        match tokio::task::spawn(f()).await {
-            Err(e) => tracing::error!("background task exited unexpectedly: {}", e),
-            Ok(Err(e)) => tracing::error!("background task returned error: {}", e),
-            Ok(Ok(_)) => unreachable!("task's Ok is Infallible")
+pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>;
+
+pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask
+where
+    E: std::error::Error + std::fmt::Debug + Send + 'static,
+    A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static,
+    F: FnMut() -> A + Send + 'static
+{
+
+    let supervisor_future = async move {
+        loop {
+            // Don't spawn the task if we are already cancelled, but
+            // have somehow missed it (probably because the task
+            // crashed and we immediately received a cancellation
+            // request after noticing the crashed task)
+            if cancellation_token.is_cancelled() {
+                return Err(SupervisorError::Cancelled)
+            }
+            let task = tokio::task::spawn(f());
+            tokio::select! {
+                _ = cancellation_token.cancelled() => {
+                    tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>());
+                    return Err(SupervisorError::Cancelled)
+                }
+                task_result = task => match task_result {
+                    Err(e) => tracing::error!("background task {:?} exited unexpectedly: {}", std::any::type_name::<A>(), e),
+                    Ok(Err(e)) => tracing::error!("background task {:?} returned error: {}", std::any::type_name::<A>(), e),
+                    Ok(Ok(_)) => unreachable!("task's Ok is Infallible")
+                }
+            }
+            tracing::debug!("Sleeping for a little while to back-off...");
+            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
         }
-    }
+    };
+    #[cfg(not(tokio_unstable))]
+    return tokio::task::spawn(supervisor_future);
+    #[cfg(tokio_unstable)]
+    return tokio::task::Builder::new()
+        .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str())
+        .spawn(supervisor_future)
+        .unwrap();
 }
 
 async fn process_webmentions_from_queue<Q: JobQueue<Webmention>>(queue: Q) -> Result<std::convert::Infallible, Q::Error> {
@@ -74,6 +115,6 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>>(queue: Q) -> Re
     unreachable!()
 }
 
-async fn supervised_webmentions_task<Q: JobQueue<Webmention>>(queue: Q) {
-    supervisor::<Q::Error, _, _>(|| process_webmentions_from_queue(queue.clone())).await;
+fn supervised_webmentions_task<Q: JobQueue<Webmention>>(queue: Q, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask {
+    supervisor::<Q::Error, _, _>(move || process_webmentions_from_queue(queue.clone()), cancellation_token)
 }