about summary refs log tree commit diff
path: root/tower-watchdog/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tower-watchdog/src/lib.rs')
-rw-r--r--tower-watchdog/src/lib.rs108
1 files changed, 108 insertions, 0 deletions
diff --git a/tower-watchdog/src/lib.rs b/tower-watchdog/src/lib.rs
new file mode 100644
index 0000000..9a5c609
--- /dev/null
+++ b/tower-watchdog/src/lib.rs
@@ -0,0 +1,108 @@
+#[deny(missing_docs)]
+mod watchdog;
+pub use watchdog::Watchdog;
+
+pub struct WatchdogLayer {
+    pet: watchdog::Pet,
+}
+
+/// A Tower layer to send a signal if there hasn't been new requests
+/// in a while.
+///
+/// It resets a timer at the beginning of every single incoming
+/// request. Wait on the watchdog to begin. If no new requests haven't
+/// arrived in a while, the corresponding paired [Watchdog]'s
+/// [wait][Watchdog::wait] future will resolve. This is a signal to
+/// gracefully shutdown a server.
+impl WatchdogLayer {
+    pub fn new(timeout: std::time::Duration) -> (watchdog::Watchdog, WatchdogLayer) {
+        let (watchdog, pet) = watchdog::watchdog(timeout);
+        (watchdog, WatchdogLayer { pet })
+    }
+}
+
+impl<S> tower_layer::Layer<S> for WatchdogLayer {
+    type Service = WatchdogService<S>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        Self::Service {
+            pet: self.pet.clone(),
+            inner
+        }
+    }
+}
+
+pub struct WatchdogService<S> {
+    pet: watchdog::Pet,
+    inner: S
+}
+
+impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::Debug + 'static> tower_service::Service<Request> for WatchdogService<S> {
+    type Response = S::Response;
+    type Error = S::Error;
+    type Future = std::pin::Pin<Box<futures::future::Then<std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), tokio::sync::mpsc::error::SendError<()>>> + Send>>, std::pin::Pin<Box<S::Future>>, Box<dyn FnOnce(Result<(), tokio::sync::mpsc::error::SendError<()>>) -> std::pin::Pin<Box<S::Future>>>>>>;
+
+    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, request: Request) -> Self::Future {
+        use futures::FutureExt;
+        // We need to get the service that we just polled. For this,
+        // we clone the service, leave in the clone and take the
+        // original.
+        //
+        // Don't ask me why this is needed.
+        let mut inner = self.inner.clone();
+        std::mem::swap(&mut self.inner, &mut inner);
+
+        let pet = self.pet.clone();
+        Box::pin(pet.pet_owned().boxed().then(Box::new(move |_| Box::pin(inner.call(request)))))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use futures::FutureExt;
+
+    #[tokio::test(start_paused = true)]
+    async fn test_watchdog_layer() {
+        use std::time::Duration;
+
+        let (watchdog, layer) = super::WatchdogLayer::new(Duration::from_secs(1));
+        let (mut mock, mut handle) = tower_test::mock::spawn_layer::<(), (), _>(layer);
+        handle.allow((100..1000).count() as u64 + 1);
+        // We don't actually care what the service itself does.
+        let responder = tokio::task::spawn(async move {
+            while let Some(((), res)) = handle.next_request().await {
+                res.send_response(())
+            }
+        });
+
+        let mut watchdog_future = Box::pin(watchdog.wait().fuse());
+
+        for i in 100..=1_000 {
+            if i != 1000 {
+                assert!(mock.poll_ready().is_ready());
+                let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i)).then(|()| mock.call(())));
+                tokio::select! {
+                    _ = &mut watchdog_future => panic!("Watchdog called earlier than response!"),
+                    _ = request => {},
+                };
+            } else {
+                assert!(mock.poll_ready().is_ready());
+                // We use `+ 1` here, because the watchdog behavior is
+                // subject to a data race if a request arrives in the
+                // same tick.
+                let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i + 1)).then(|()| mock.call(())));
+                tokio::select! {
+                    _ = &mut watchdog_future => {
+                    },
+                    _ = request => panic!("Watchdog didn't fire!")
+                };
+            }
+        }
+
+        responder.abort();
+    }
+}