about summary refs log tree commit diff
path: root/kittybox-rs/src/webmentions
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/webmentions')
-rw-r--r--kittybox-rs/src/webmentions/mod.rs4
-rw-r--r--kittybox-rs/src/webmentions/queue.rs171
2 files changed, 128 insertions, 47 deletions
diff --git a/kittybox-rs/src/webmentions/mod.rs b/kittybox-rs/src/webmentions/mod.rs
index cffd064..630a1a6 100644
--- a/kittybox-rs/src/webmentions/mod.rs
+++ b/kittybox-rs/src/webmentions/mod.rs
@@ -11,8 +11,10 @@ pub struct Webmention {
     target: String,
 }
 
-impl queue::JobItem for Webmention {
+impl queue::JobItem for Webmention {}
+impl queue::PostgresJobItem for Webmention {
     const DATABASE_NAME: &'static str = "kittybox.incoming_webmention_queue";
+    const NOTIFICATION_CHANNEL: &'static str = "incoming_webmention";
 }
 
 async fn accept_webmention<Q: JobQueue<Webmention>>(
diff --git a/kittybox-rs/src/webmentions/queue.rs b/kittybox-rs/src/webmentions/queue.rs
index b585f58..0b11a4e 100644
--- a/kittybox-rs/src/webmentions/queue.rs
+++ b/kittybox-rs/src/webmentions/queue.rs
@@ -8,29 +8,22 @@ use super::Webmention;
 
 static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
 
-#[async_trait::async_trait]
-pub trait JobQueue<T: JobItem>: Send + Sync + Sized + Clone + 'static {
-    type Job: Job<T, Self>;
-    type Error: std::error::Error + Send + Sync + Sized;
-
-    async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error>;
-    async fn put(&self, item: &T) -> Result<Uuid, Self::Error>;
-
-    async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error>;
-}
+pub use kittybox_util::queue::{JobQueue, JobItem, Job};
 
-#[async_trait::async_trait]
-pub trait Job<T: JobItem, Q: JobQueue<T>>: Send + Sync + Sized {
-    fn job(&self) -> &T;
-    async fn done(self) -> Result<(), Q::Error>;
+pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
+    const DATABASE_NAME: &'static str;
+    const NOTIFICATION_CHANNEL: &'static str;
 }
 
-pub trait JobItem: Send + Sync + Sized + std::fmt::Debug {
-    const DATABASE_NAME: &'static str;
+#[derive(sqlx::FromRow)]
+struct PostgresJobRow<T: PostgresJobItem> {
+    id: Uuid,
+    #[sqlx(flatten)]
+    job: T
 }
 
 #[derive(Debug)]
-pub struct PostgresJobItem<T: JobItem> {
+pub struct PostgresJob<T: PostgresJobItem> {
     id: Uuid,
     job: T,
     // This will normally always be Some, except on drop
@@ -39,7 +32,7 @@ pub struct PostgresJobItem<T: JobItem> {
 }
 
 
-impl<T: JobItem> Drop for PostgresJobItem<T> {
+impl<T: PostgresJobItem> Drop for PostgresJob<T> {
     // This is an emulation of "async drop" — the struct retains a
     // runtime handle, which it uses to block on a future that does
     // the actual cleanup.
@@ -64,15 +57,36 @@ impl<T: JobItem> Drop for PostgresJobItem<T> {
                     .execute(&mut txn)
                     .await
                     .unwrap();
-
+                sqlx::query_builder::QueryBuilder::new("NOTIFY ")
+                    .push(T::NOTIFICATION_CHANNEL)
+                    .build()
+                    .execute(&mut txn)
+                    .await
+                    .unwrap();
                 txn.commit().await.unwrap();
             });
         }
     }
 }
 
+#[cfg(test)]
+impl<T: PostgresJobItem> PostgresJob<T> {
+    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
+        sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ")
+            .push(T::DATABASE_NAME)
+            .push(" WHERE id = ")
+            .push_bind(self.id)
+            .build_query_as::<(i32,)>()
+            // It's safe to unwrap here, because we "take" the txn only on drop or commit,
+            // where it's passed by value, not by reference.
+            .fetch_one(self.txn.as_mut().unwrap())
+            .await
+            .map(|(i,)| i as usize)
+    }
+}
+
 #[async_trait::async_trait]
-impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJobItem<Webmention> {
+impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
     fn job(&self) -> &Webmention {
         &self.job
     }
@@ -126,28 +140,28 @@ impl PostgresJobQueue<Webmention> {
 
 #[async_trait::async_trait]
 impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
-    type Job = PostgresJobItem<Webmention>;
+    type Job = PostgresJob<Webmention>;
     type Error = sqlx::Error;
 
     async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> {
         let mut txn = self.db.begin().await?;
 
-        match sqlx::query_as::<_, (Uuid, String, String)>(
+        match sqlx::query_as::<_, PostgresJobRow<Webmention>>(
             "SELECT id, source, target FROM kittybox.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1"
         )
             .fetch_optional(&mut txn)
             .await?
-            .map(|(id, source, target)| (id, Webmention { source, target })) {
-                Some((id, webmention)) => {
-                    return Ok(Some(Self::Job {
-                        id,
-                        job: webmention,
-                        txn: Some(txn),
-                        runtime_handle: tokio::runtime::Handle::current(),
-                    }))
-                },
-                None => Ok(None)
-            }
+        {
+            Some(job_row) => {
+                return Ok(Some(Self::Job {
+                    id: job_row.id,
+                    job: job_row.job,
+                    txn: Some(txn),
+                    runtime_handle: tokio::runtime::Handle::current(),
+                }))
+            },
+            None => Ok(None)
+        }
     }
 
     async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> {
@@ -187,9 +201,12 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
     use super::{Webmention, PostgresJobQueue, Job, JobQueue};
     use futures_util::StreamExt;
     #[sqlx::test]
+    #[tracing_test::traced_test]
     async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
         let test_webmention = Webmention {
             source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
@@ -197,25 +214,87 @@ mod tests {
         };
 
         let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
-        println!("Putting webmention into queue");
+        tracing::debug!("Putting webmention into queue");
         queue.put(&test_webmention).await?;
-        assert_eq!(queue.get_one().await?.as_ref().map(|j| j.job()), Some(&test_webmention));
-        println!("Creating a stream");
+        {
+            let mut job_description = queue.get_one().await?.unwrap();
+            assert_eq!(job_description.job(), &test_webmention);
+            assert_eq!(job_description.attempts().await?, 0);
+        }
+        tracing::debug!("Creating a stream");
         let mut stream = queue.clone().into_stream().await?;
 
-        let future = stream.next();
-        let guard = future.await.transpose()?.unwrap();
-        assert_eq!(guard.job(), &test_webmention);
-        if let Some(item) = queue.get_one().await? {
-            panic!("Unexpected item {:?} returned from job queue!", item)
-        };
-        drop(guard);
-        let guard = stream.next().await.transpose()?.unwrap();
-        assert_eq!(guard.job(), &test_webmention);
-        guard.done().await?;
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 1);
+            if let Some(item) = queue.get_one().await? {
+                panic!("Unexpected item {:?} returned from job queue!", item)
+            };
+        }
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 2);
+            guard.done().await?;
+        }
+
         match queue.get_one().await? {
             Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
             None => Ok(())
         }
     }
+
+    #[sqlx::test]
+    #[tracing_test::traced_test]
+    async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        // Synchronisation barrier that will be useful later
+        let barrier = Arc::new(tokio::sync::Barrier::new(2));
+        {
+            // Get one job guard from a queue
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 0);
+
+            tokio::task::spawn({
+                let barrier = barrier.clone();
+                async move {
+                    // Wait for the signal to drop the guard!
+                    barrier.wait().await;
+
+                    drop(guard)
+                }
+            });
+        }
+        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err();
+
+        let future = tokio::task::spawn(
+            tokio::time::timeout(
+                std::time::Duration::from_secs(10), async move {
+                    stream.next().await.unwrap().unwrap()
+                }
+            )
+        );
+        // Let the other task drop the guard it is holding
+        barrier.wait().await;
+        let mut guard = future.await
+            .expect("Timeout on fetching item")
+            .expect("Job queue error");
+        assert_eq!(guard.job(), &test_webmention);
+        assert_eq!(guard.attempts().await?, 1);
+
+        Ok(())
+    }
 }