about summary refs log tree commit diff
path: root/tower-watchdog/src/lib.rs
blob: 9a5c6093cb08e87af4fe12f4f3c2591077056a53 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#[deny(missing_docs)]
mod watchdog;
pub use watchdog::Watchdog;

/// A Tower layer to send a signal if there haven't been new requests
/// in a while.
///
/// Every request passing through a [`WatchdogService`] created by this
/// layer pets the watchdog (resetting its timer) before the request is
/// forwarded to the inner service. Wait on the paired [`Watchdog`] to
/// begin: if no new requests have arrived within the configured timeout,
/// its [wait][Watchdog::wait] future resolves. This is a signal to
/// gracefully shut down a server.
///
/// Cloning the layer clones its `Pet` handle, exactly as
/// [`Layer::layer`](tower_layer::Layer::layer) does for each service it
/// builds.
#[derive(Clone)]
pub struct WatchdogLayer {
    /// Handle used to reset the paired [`Watchdog`]'s timer.
    pet: watchdog::Pet,
}

impl WatchdogLayer {
    /// Creates a new layer together with its paired [`Watchdog`].
    ///
    /// `timeout` is the period of request inactivity after which the
    /// watchdog's [wait][Watchdog::wait] future resolves.
    pub fn new(timeout: std::time::Duration) -> (watchdog::Watchdog, WatchdogLayer) {
        let (watchdog, pet) = watchdog::watchdog(timeout);
        (watchdog, WatchdogLayer { pet })
    }
}

impl<S> tower_layer::Layer<S> for WatchdogLayer {
    type Service = WatchdogService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Self::Service {
            pet: self.pet.clone(),
            inner
        }
    }
}

/// A Tower service that pets a [`Watchdog`] on every request.
///
/// Built by [`WatchdogLayer`]. It wraps an inner service `S` and forwards
/// each request to it after resetting the watchdog's timer. The service is
/// `Clone` whenever `S` is, so it can be used in stacks that clone services.
#[derive(Clone)]
pub struct WatchdogService<S> {
    /// Handle used to reset the paired watchdog's timer.
    pet: watchdog::Pet,
    /// The wrapped service that actually handles requests.
    inner: S,
}

impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::Debug + 'static> tower_service::Service<Request> for WatchdogService<S> {
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<Box<futures::future::Then<std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), tokio::sync::mpsc::error::SendError<()>>> + Send>>, std::pin::Pin<Box<S::Future>>, Box<dyn FnOnce(Result<(), tokio::sync::mpsc::error::SendError<()>>) -> std::pin::Pin<Box<S::Future>>>>>>;

    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        use futures::FutureExt;
        // We need to get the service that we just polled. For this,
        // we clone the service, leave in the clone and take the
        // original.
        //
        // Don't ask me why this is needed.
        let mut inner = self.inner.clone();
        std::mem::swap(&mut self.inner, &mut inner);

        let pet = self.pet.clone();
        Box::pin(pet.pet_owned().boxed().then(Box::new(move |_| Box::pin(inner.call(request)))))
    }
}

#[cfg(test)]
mod tests {
    use futures::FutureExt;

    /// Requests arriving less than a timeout apart keep the watchdog quiet;
    /// a gap longer than the timeout lets it fire.
    #[tokio::test(start_paused = true)]
    async fn test_watchdog_layer() {
        use std::time::Duration;

        let (watchdog, layer) = super::WatchdogLayer::new(Duration::from_secs(1));
        let (mut mock, mut handle) = tower_test::mock::spawn_layer::<(), (), _>(layer);

        // One request per delay in 100..=1000 ms, i.e. 901 requests.
        let request_budget = (100..=1_000).count() as u64;
        handle.allow(request_budget);

        // We don't actually care what the service itself does.
        let responder = tokio::task::spawn(async move {
            while let Some(((), reply)) = handle.next_request().await {
                reply.send_response(())
            }
        });

        let mut watchdog_future = Box::pin(watchdog.wait().fuse());

        // Every gap here is shorter than the 1s timeout, so the request must
        // always complete before the watchdog resolves.
        for delay_ms in 100..1_000 {
            assert!(mock.poll_ready().is_ready());
            let request = Box::pin(
                tokio::time::sleep(Duration::from_millis(delay_ms)).then(|()| mock.call(())),
            );
            tokio::select! {
                _ = &mut watchdog_future => panic!("Watchdog called earlier than response!"),
                _ = request => {},
            };
        }

        // Final gap exceeds the timeout, so the watchdog must win. The extra
        // millisecond avoids a data race when a request lands in the same
        // tick as the watchdog's deadline.
        assert!(mock.poll_ready().is_ready());
        let late_request = Box::pin(
            tokio::time::sleep(Duration::from_millis(1_000 + 1)).then(|()| mock.call(())),
        );
        tokio::select! {
            _ = &mut watchdog_future => {},
            _ = late_request => panic!("Watchdog didn't fire!")
        };

        responder.abort();
    }
}