diff options
author | Vika <vika@fireburn.ru> | 2023-07-29 21:59:56 +0300 |
---|---|---|
committer | Vika <vika@fireburn.ru> | 2023-07-29 21:59:56 +0300 |
commit | 0617663b249f9ca488e5de652108b17d67fbaf45 (patch) | |
tree | 11564b6c8fa37bf9203a0a4cc1c4e9cc088cb1a5 /src/webmentions/queue.rs | |
parent | 26c2b79f6a6380ae3224e9309b9f3352f5717bd7 (diff) | |
download | kittybox-0617663b249f9ca488e5de652108b17d67fbaf45.tar.zst |
Moved the entire Kittybox tree into the root
Diffstat (limited to 'src/webmentions/queue.rs')
-rw-r--r-- | src/webmentions/queue.rs | 303 |
1 files changed, 303 insertions, 0 deletions
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs new file mode 100644 index 0000000..b811e71 --- /dev/null +++ b/src/webmentions/queue.rs @@ -0,0 +1,303 @@ +use std::{pin::Pin, str::FromStr}; + +use futures_util::{Stream, StreamExt}; +use sqlx::{postgres::PgListener, Executor}; +use uuid::Uuid; + +use super::Webmention; + +static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention"); + +pub use kittybox_util::queue::{JobQueue, JobItem, Job}; + +pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> { + const DATABASE_NAME: &'static str; + const NOTIFICATION_CHANNEL: &'static str; +} + +#[derive(sqlx::FromRow)] +struct PostgresJobRow<T: PostgresJobItem> { + id: Uuid, + #[sqlx(flatten)] + job: T +} + +#[derive(Debug)] +pub struct PostgresJob<T: PostgresJobItem> { + id: Uuid, + job: T, + // This will normally always be Some, except on drop + txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>, + runtime_handle: tokio::runtime::Handle, +} + + +impl<T: PostgresJobItem> Drop for PostgresJob<T> { + // This is an emulation of "async drop" — the struct retains a + // runtime handle, which it uses to block on a future that does + // the actual cleanup. + // + // Of course, this is not portable between runtimes, but I don't + // care about that, since Kittybox is designed to work within the + // Tokio ecosystem. + fn drop(&mut self) { + tracing::error!("Job {:?} failed, incrementing attempts...", &self); + if let Some(mut txn) = self.txn.take() { + let id = self.id; + self.runtime_handle.spawn(async move { + tracing::debug!("Constructing query to increment attempts for job {}...", id); + // UPDATE "T::DATABASE_NAME" WHERE id = $1 SET attempts = attempts + 1 + sqlx::query_builder::QueryBuilder::new("UPDATE ") + // This is safe from a SQL injection standpoint, since it is a constant. + .push(T::DATABASE_NAME) + .push(" SET attempts = attempts + 1") + .push(" WHERE id = ") + .push_bind(id) + .build() + .execute(&mut *txn) + .await + .unwrap(); + sqlx::query_builder::QueryBuilder::new("NOTIFY ") + .push(T::NOTIFICATION_CHANNEL) + .build() + .execute(&mut *txn) + .await + .unwrap(); + txn.commit().await.unwrap(); + }); + } + } +} + +#[cfg(test)] +impl<T: PostgresJobItem> PostgresJob<T> { + async fn attempts(&mut self) -> Result<usize, sqlx::Error> { + sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ") + .push(T::DATABASE_NAME) + .push(" WHERE id = ") + .push_bind(self.id) + .build_query_as::<(i32,)>() + // It's safe to unwrap here, because we "take" the txn only on drop or commit, + // where it's passed by value, not by reference. + .fetch_one(self.txn.as_deref_mut().unwrap()) + .await + .map(|(i,)| i as usize) + } +} + +#[async_trait::async_trait] +impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { + fn job(&self) -> &Webmention { + &self.job + } + async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { + tracing::debug!("Deleting {} from the job queue", self.id); + sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1") + .bind(self.id) + .execute(self.txn.as_deref_mut().unwrap()) + .await?; + + self.txn.take().unwrap().commit().await + } +} + +pub struct PostgresJobQueue<T> { + db: sqlx::PgPool, + _phantom: std::marker::PhantomData<T> +} +impl<T> Clone for PostgresJobQueue<T> { + fn clone(&self) -> Self { + Self { + db: self.db.clone(), + _phantom: std::marker::PhantomData + } + } +} + +impl PostgresJobQueue<Webmention> { + pub async fn new(uri: &str) -> Result<Self, sqlx::Error> { + let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? + .options([("search_path", "kittybox_webmention")]); + if let Ok(password_file) = std::env::var("PGPASS_FILE") { + let password = tokio::fs::read_to_string(password_file).await.unwrap(); + options = options.password(&password); + } else if let Ok(password) = std::env::var("PGPASS") { + options = options.password(&password) + } + Self::from_pool( + sqlx::postgres::PgPoolOptions::new() + .max_connections(50) + .connect_with(options) + .await? + ).await + + } + + pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> { + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?; + MIGRATOR.run(&db).await?; + Ok(Self { db, _phantom: std::marker::PhantomData }) + } +} + +#[async_trait::async_trait] +impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { + type Job = PostgresJob<Webmention>; + type Error = sqlx::Error; + + async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> { + let mut txn = self.db.begin().await?; + + match sqlx::query_as::<_, PostgresJobRow<Webmention>>( + "SELECT id, source, target FROM kittybox_webmention.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1" + ) + .fetch_optional(&mut *txn) + .await? + { + Some(job_row) => { + return Ok(Some(Self::Job { + id: job_row.id, + job: job_row.job, + txn: Some(txn), + runtime_handle: tokio::runtime::Handle::current(), + })) + }, + None => Ok(None) + } + } + + async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> { + sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox_webmention.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id") + .bind(item.source.as_str()) + .bind(item.target.as_str()) + .fetch_one(&self.db) + .await + } + + async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> { + let mut listener = PgListener::connect_with(&self.db).await?; + listener.listen("incoming_webmention").await?; + + let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), { + let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener)); + move |_| { + let queue = self.clone(); + let listener = listener.clone(); + async move { + loop { + match queue.get_one().await? { + Some(item) => return Ok(Some((item, ()))), + None => { + listener.lock().await.recv().await?; + continue + } + } + } + } + } + }).boxed(); + + Ok(stream) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR}; + use futures_util::StreamExt; + + #[sqlx::test(migrator = "MIGRATOR")] + #[tracing_test::traced_test] + async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { + let test_webmention = Webmention { + source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + }; + + let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; + tracing::debug!("Putting webmention into queue"); + queue.put(&test_webmention).await?; + { + let mut job_description = queue.get_one().await?.unwrap(); + assert_eq!(job_description.job(), &test_webmention); + assert_eq!(job_description.attempts().await?, 0); + } + tracing::debug!("Creating a stream"); + let mut stream = queue.clone().into_stream().await?; + + { + let mut guard = stream.next().await.transpose()?.unwrap(); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 1); + if let Some(item) = queue.get_one().await? { + panic!("Unexpected item {:?} returned from job queue!", item) + }; + } + + { + let mut guard = stream.next().await.transpose()?.unwrap(); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 2); + guard.done().await?; + } + + match queue.get_one().await? { + Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), + None => Ok(()) + } + } + + #[sqlx::test(migrator = "MIGRATOR")] + #[tracing_test::traced_test] + async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { + let test_webmention = Webmention { + source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + }; + + let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; + tracing::debug!("Putting webmention into queue"); + queue.put(&test_webmention).await?; + tracing::debug!("Creating a stream"); + let mut stream = queue.clone().into_stream().await?; + + // Synchronisation barrier that will be useful later + let barrier = Arc::new(tokio::sync::Barrier::new(2)); + { + // Get one job guard from a queue + let mut guard = stream.next().await.transpose()?.unwrap(); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 0); + + tokio::task::spawn({ + let barrier = barrier.clone(); + async move { + // Wait for the signal to drop the guard! + barrier.wait().await; + + drop(guard) + } + }); + } + tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); + + let future = tokio::task::spawn( + tokio::time::timeout( + std::time::Duration::from_secs(10), async move { + stream.next().await.unwrap().unwrap() + } + ) + ); + // Let the other task drop the guard it is holding + barrier.wait().await; + let mut guard = future.await + .expect("Timeout on fetching item") + .expect("Job queue error"); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 1); + + Ok(()) + } +} |