#[deny(missing_docs)] mod watchdog; pub use watchdog::Watchdog; pub struct WatchdogLayer { pet: watchdog::Pet, } /// A Tower layer to send a signal if there hasn't been new requests /// in a while. /// /// It resets a timer at the beginning of every single incoming /// request. Wait on the watchdog to begin. If no new requests haven't /// arrived in a while, the corresponding paired [Watchdog]'s /// [wait][Watchdog::wait] future will resolve. This is a signal to /// gracefully shutdown a server. impl WatchdogLayer { pub fn new(timeout: std::time::Duration) -> (watchdog::Watchdog, WatchdogLayer) { let (watchdog, pet) = watchdog::watchdog(timeout); (watchdog, WatchdogLayer { pet }) } } impl tower_layer::Layer for WatchdogLayer { type Service = WatchdogService; fn layer(&self, inner: S) -> Self::Service { Self::Service { pet: self.pet.clone(), inner } } } pub struct WatchdogService { pet: watchdog::Pet, inner: S } impl + Clone + 'static, Request: std::fmt::Debug + 'static> tower_service::Service for WatchdogService { type Response = S::Response; type Error = S::Error; type Future = std::pin::Pin>> + Send>>, std::pin::Pin>, Box>) -> std::pin::Pin>>>>>; fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll> { self.inner.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { use futures::FutureExt; // We need to get the service that we just polled. For this, // we clone the service, leave in the clone and take the // original. // // Don't ask me why this is needed. let mut inner = self.inner.clone(); std::mem::swap(&mut self.inner, &mut inner); let pet = self.pet.clone(); Box::pin(pet.pet_owned().boxed().then(Box::new(move |_| Box::pin(inner.call(request))))) } } #[cfg(test)] mod tests { use futures::FutureExt; #[tokio::test(start_paused = true)] async fn test_watchdog_layer() { use std::time::Duration; let (watchdog, layer) = super::WatchdogLayer::new(Duration::from_secs(1)); let (mut mock, mut handle) = tower_test::mock::spawn_layer::<(), (), _>(layer); handle.allow((100..1000).count() as u64 + 1); // We don't actually care what the service itself does. let responder = tokio::task::spawn(async move { while let Some(((), res)) = handle.next_request().await { res.send_response(()) } }); let mut watchdog_future = Box::pin(watchdog.wait().fuse()); for i in 100..=1_000 { if i != 1000 { assert!(mock.poll_ready().is_ready()); let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i)).then(|()| mock.call(()))); tokio::select! { _ = &mut watchdog_future => panic!("Watchdog called earlier than response!"), _ = request => {}, }; } else { assert!(mock.poll_ready().is_ready()); // We use `+ 1` here, because the watchdog behavior is // subject to a data race if a request arrives in the // same tick. let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i + 1)).then(|()| mock.call(()))); tokio::select! { _ = &mut watchdog_future => { }, _ = request => panic!("Watchdog didn't fire!") }; } } responder.abort(); } }