about summary refs log tree commit diff
path: root/src/webmentions/queue.rs
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2023-07-29 21:59:56 +0300
committerVika <vika@fireburn.ru>2023-07-29 21:59:56 +0300
commit0617663b249f9ca488e5de652108b17d67fbaf45 (patch)
tree11564b6c8fa37bf9203a0a4cc1c4e9cc088cb1a5 /src/webmentions/queue.rs
parent26c2b79f6a6380ae3224e9309b9f3352f5717bd7 (diff)
downloadkittybox-0617663b249f9ca488e5de652108b17d67fbaf45.tar.zst
Moved the entire Kittybox tree into the root
Diffstat (limited to 'src/webmentions/queue.rs')
-rw-r--r--src/webmentions/queue.rs303
1 files changed, 303 insertions, 0 deletions
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs
new file mode 100644
index 0000000..b811e71
--- /dev/null
+++ b/src/webmentions/queue.rs
@@ -0,0 +1,303 @@
+use std::{pin::Pin, str::FromStr};
+
+use futures_util::{Stream, StreamExt};
+use sqlx::{postgres::PgListener, Executor};
+use uuid::Uuid;
+
+use super::Webmention;
+
+static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention");
+
+pub use kittybox_util::queue::{JobQueue, JobItem, Job};
+
+pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
+    const DATABASE_NAME: &'static str;
+    const NOTIFICATION_CHANNEL: &'static str;
+}
+
+#[derive(sqlx::FromRow)]
+struct PostgresJobRow<T: PostgresJobItem> {
+    id: Uuid,
+    #[sqlx(flatten)]
+    job: T
+}
+
+#[derive(Debug)]
+pub struct PostgresJob<T: PostgresJobItem> {
+    id: Uuid,
+    job: T,
+    // This will normally always be Some, except on drop
+    txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
+    runtime_handle: tokio::runtime::Handle,
+}
+
+
+impl<T: PostgresJobItem> Drop for PostgresJob<T> {
+    // This is an emulation of "async drop" — the struct retains a
+    // runtime handle, which it uses to block on a future that does
+    // the actual cleanup.
+    //
+    // Of course, this is not portable between runtimes, but I don't
+    // care about that, since Kittybox is designed to work within the
+    // Tokio ecosystem.
+    fn drop(&mut self) {
+        tracing::error!("Job {:?} failed, incrementing attempts...", &self);
+        if let Some(mut txn) = self.txn.take() {
+            let id = self.id;
+            self.runtime_handle.spawn(async move {
+                tracing::debug!("Constructing query to increment attempts for job {}...", id);
+                // UPDATE "T::DATABASE_NAME" WHERE id = $1 SET attempts = attempts + 1
+                sqlx::query_builder::QueryBuilder::new("UPDATE ")
+                    // This is safe from a SQL injection standpoint, since it is a constant.
+                    .push(T::DATABASE_NAME)
+                    .push(" SET attempts = attempts + 1")
+                    .push(" WHERE id = ")
+                    .push_bind(id)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                sqlx::query_builder::QueryBuilder::new("NOTIFY ")
+                    .push(T::NOTIFICATION_CHANNEL)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                txn.commit().await.unwrap();
+            });
+        }
+    }
+}
+
+#[cfg(test)]
+impl<T: PostgresJobItem> PostgresJob<T> {
+    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
+        sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ")
+            .push(T::DATABASE_NAME)
+            .push(" WHERE id = ")
+            .push_bind(self.id)
+            .build_query_as::<(i32,)>()
+            // It's safe to unwrap here, because we "take" the txn only on drop or commit,
+            // where it's passed by value, not by reference.
+            .fetch_one(self.txn.as_deref_mut().unwrap())
+            .await
+            .map(|(i,)| i as usize)
+    }
+}
+
+#[async_trait::async_trait]
+impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
+    fn job(&self) -> &Webmention {
+        &self.job
+    }
+    async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
+        tracing::debug!("Deleting {} from the job queue", self.id);
+        sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1")
+            .bind(self.id)
+            .execute(self.txn.as_deref_mut().unwrap())
+            .await?;
+
+        self.txn.take().unwrap().commit().await
+    }
+}
+
+pub struct PostgresJobQueue<T> {
+    db: sqlx::PgPool,
+    _phantom: std::marker::PhantomData<T>
+}
+impl<T> Clone for PostgresJobQueue<T> {
+    fn clone(&self) -> Self {
+        Self {
+            db: self.db.clone(),
+            _phantom: std::marker::PhantomData
+        }
+    }
+}
+
+impl PostgresJobQueue<Webmention> {
+    pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
+        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
+            .options([("search_path", "kittybox_webmention")]);
+        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
+            let password = tokio::fs::read_to_string(password_file).await.unwrap();
+            options = options.password(&password);
+        } else if let Ok(password) = std::env::var("PGPASS") {
+            options = options.password(&password)
+        }
+        Self::from_pool(
+            sqlx::postgres::PgPoolOptions::new()
+                .max_connections(50)
+                .connect_with(options)
+                .await?
+        ).await
+
+    }
+
+    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> {
+        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?;
+        MIGRATOR.run(&db).await?;
+        Ok(Self { db, _phantom: std::marker::PhantomData })
+    }
+}
+
+#[async_trait::async_trait]
+impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
+    type Job = PostgresJob<Webmention>;
+    type Error = sqlx::Error;
+
+    async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> {
+        let mut txn = self.db.begin().await?;
+
+        match sqlx::query_as::<_, PostgresJobRow<Webmention>>(
+            "SELECT id, source, target FROM kittybox_webmention.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1"
+        )
+            .fetch_optional(&mut *txn)
+            .await?
+        {
+            Some(job_row) => {
+                return Ok(Some(Self::Job {
+                    id: job_row.id,
+                    job: job_row.job,
+                    txn: Some(txn),
+                    runtime_handle: tokio::runtime::Handle::current(),
+                }))
+            },
+            None => Ok(None)
+        }
+    }
+
+    async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> {
+        sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox_webmention.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id")
+            .bind(item.source.as_str())
+            .bind(item.target.as_str())
+            .fetch_one(&self.db)
+            .await
+    }
+
+    async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> {
+        let mut listener = PgListener::connect_with(&self.db).await?;
+        listener.listen("incoming_webmention").await?;
+
+        let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), {
+            let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener));
+            move |_| {
+                let queue = self.clone();
+                let listener = listener.clone();
+                async move {
+                    loop {
+                        match queue.get_one().await? {
+                            Some(item) => return Ok(Some((item, ()))),
+                            None => {
+                                listener.lock().await.recv().await?;
+                                continue
+                            }
+                        }
+                    }
+                }
+            }
+        }).boxed();
+
+        Ok(stream)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR};
+    use futures_util::StreamExt;
+
+    #[sqlx::test(migrator = "MIGRATOR")]
+    #[tracing_test::traced_test]
+    async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        {
+            let mut job_description = queue.get_one().await?.unwrap();
+            assert_eq!(job_description.job(), &test_webmention);
+            assert_eq!(job_description.attempts().await?, 0);
+        }
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 1);
+            if let Some(item) = queue.get_one().await? {
+                panic!("Unexpected item {:?} returned from job queue!", item)
+            };
+        }
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 2);
+            guard.done().await?;
+        }
+
+        match queue.get_one().await? {
+            Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
+            None => Ok(())
+        }
+    }
+
+    #[sqlx::test(migrator = "MIGRATOR")]
+    #[tracing_test::traced_test]
+    async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        // Synchronisation barrier that will be useful later
+        let barrier = Arc::new(tokio::sync::Barrier::new(2));
+        {
+            // Get one job guard from a queue
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 0);
+
+            tokio::task::spawn({
+                let barrier = barrier.clone();
+                async move {
+                    // Wait for the signal to drop the guard!
+                    barrier.wait().await;
+
+                    drop(guard)
+                }
+            });
+        }
+        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err();
+
+        let future = tokio::task::spawn(
+            tokio::time::timeout(
+                std::time::Duration::from_secs(10), async move {
+                    stream.next().await.unwrap().unwrap()
+                }
+            )
+        );
+        // Let the other task drop the guard it is holding
+        barrier.wait().await;
+        let mut guard = future.await
+            .expect("Timeout on fetching item")
+            .expect("Job queue error");
+        assert_eq!(guard.job(), &test_webmention);
+        assert_eq!(guard.attempts().await?, 1);
+
+        Ok(())
+    }
+}