about summary refs log tree commit diff
path: root/kittybox-rs/src/webmentions/queue.rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/webmentions/queue.rs')
-rw-r--r--kittybox-rs/src/webmentions/queue.rs179
1 files changed, 179 insertions, 0 deletions
diff --git a/kittybox-rs/src/webmentions/queue.rs b/kittybox-rs/src/webmentions/queue.rs
new file mode 100644
index 0000000..77ad4ea
--- /dev/null
+++ b/kittybox-rs/src/webmentions/queue.rs
@@ -0,0 +1,179 @@
+use std::{pin::Pin, str::FromStr};
+
+use futures_util::{Stream, StreamExt};
+use sqlx::postgres::PgListener;
+use uuid::Uuid;
+
+use super::Webmention;
+
+static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
+
+#[async_trait::async_trait]
+pub trait JobQueue<T: Send + Sync + Sized>: Send + Sync + Sized + Clone + 'static {
+    type Job: Job<T, Self>;
+    type Error: std::error::Error + Send + Sync + Sized;
+
+    async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error>;
+    async fn put(&self, item: &T) -> Result<Uuid, Self::Error>;
+
+    async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error>;
+}
+
+#[async_trait::async_trait]
+pub trait Job<T: Send + Sync + Sized, Q: JobQueue<T>>: Send + Sync + Sized {
+    fn job(&self) -> &T;
+    async fn done(self) -> Result<(), Q::Error>;
+}
+
+#[derive(Debug)]
+pub struct PostgresJobItem<'c, T> {
+    id: Uuid,
+    job: T,
+    txn: sqlx::Transaction<'c, sqlx::Postgres>
+}
+
+#[async_trait::async_trait]
+impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJobItem<'_, Webmention> {
+    fn job(&self) -> &Webmention {
+        &self.job
+    }
+    async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
+        println!("Deleting {} from the job queue", self.id);
+        sqlx::query("DELETE FROM incoming_webmention_queue WHERE id = $1")
+            .bind(self.id)
+            .execute(&mut self.txn)
+            .await?;
+
+        self.txn.commit().await
+    }
+}
+
+pub struct PostgresJobQueue<T> {
+    db: sqlx::PgPool,
+    _phantom: std::marker::PhantomData<T>
+}
+impl<T> Clone for PostgresJobQueue<T> {
+    fn clone(&self) -> Self {
+        Self {
+            db: self.db.clone(),
+            _phantom: std::marker::PhantomData
+        }
+    }
+}
+
+impl PostgresJobQueue<Webmention> {
+    pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
+        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?;
+        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
+            let password = tokio::fs::read_to_string(password_file).await.unwrap();
+            options = options.password(&password);
+        } else if let Ok(password) = std::env::var("PGPASS") {
+            options = options.password(&password)
+        }
+        Ok(Self::from_pool(
+            sqlx::postgres::PgPoolOptions::new()
+                .max_connections(50)
+                .connect_with(options)
+                .await?
+        ).await?)
+
+    }
+
+    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::migrate::MigrateError> {
+        MIGRATOR.run(&db).await?;
+        Ok(Self { db, _phantom: std::marker::PhantomData })
+    }
+}
+
+#[async_trait::async_trait]
+impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
+    type Job = PostgresJobItem<'static, Webmention>;
+    type Error = sqlx::Error;
+
+    async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> {
+        let mut txn = self.db.begin().await?;
+
+        match sqlx::query_as::<_, (Uuid, String, String)>(
+            "SELECT id, source, target FROM incoming_webmention_queue FOR UPDATE SKIP LOCKED LIMIT 1"
+        )
+            .fetch_optional(&mut txn)
+            .await?
+            .map(|(id, source, target)| (id, Webmention { source, target })) {
+                Some((id, webmention)) => Ok(Some(Self::Job {
+                    id,
+                    job: webmention,
+                    txn
+                })),
+                None => Ok(None)
+            }
+    }
+
+    async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> {
+        sqlx::query_scalar::<_, Uuid>("INSERT INTO incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id")
+            .bind(item.source.as_str())
+            .bind(item.target.as_str())
+            .fetch_one(&self.db)
+            .await
+    }
+
+    async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> {
+        let mut listener = PgListener::connect_with(&self.db).await?;
+        listener.listen("incoming_webmention").await?;
+
+        let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), {
+            let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener));
+            move |_| {
+                let queue = self.clone();
+                let listener = listener.clone();
+                async move {
+                    loop {
+                        match queue.get_one().await? {
+                            Some(item) => return Ok(Some((item, ()))),
+                            None => {
+                                listener.lock().await.recv().await?;
+                                continue
+                            }
+                        }
+                    }
+                }
+            }
+        }).boxed();
+
+        Ok(stream)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{Webmention, PostgresJobQueue, Job, JobQueue};
+    use futures_util::StreamExt;
+    #[sqlx::test]
+    async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
+        println!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        assert_eq!(queue.get_one().await?.as_ref().map(|j| j.job()), Some(&test_webmention));
+        println!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        let future = stream.next();
+        let guard = future.await.transpose()?.unwrap();
+        assert_eq!(guard.job(), &test_webmention);
+        if let Some(item) = queue.get_one().await? {
+            panic!("Unexpected item {:?} returned from job queue!", item)
+        };
+        drop(guard);
+        let guard = stream.next().await.transpose()?.unwrap();
+        assert_eq!(guard.job(), &test_webmention);
+        guard.done().await?;
+        match queue.get_one().await? {
+            Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
+            None => Ok(())
+        }
+    }
+}