use std::{pin::Pin, str::FromStr}; use futures_util::{Stream, StreamExt}; use sqlx::postgres::PgListener; use uuid::Uuid; use super::Webmention; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); pub use kittybox_util::queue::{JobQueue, JobItem, Job}; pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> { const DATABASE_NAME: &'static str; const NOTIFICATION_CHANNEL: &'static str; } #[derive(sqlx::FromRow)] struct PostgresJobRow<T: PostgresJobItem> { id: Uuid, #[sqlx(flatten)] job: T } #[derive(Debug)] pub struct PostgresJob<T: PostgresJobItem> { id: Uuid, job: T, // This will normally always be Some, except on drop txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>, runtime_handle: tokio::runtime::Handle, } impl<T: PostgresJobItem> Drop for PostgresJob<T> { // This is an emulation of "async drop" — the struct retains a // runtime handle, which it uses to block on a future that does // the actual cleanup. // // Of course, this is not portable between runtimes, but I don't // care about that, since Kittybox is designed to work within the // Tokio ecosystem. fn drop(&mut self) { tracing::error!("Job {:?} failed, incrementing attempts...", &self); if let Some(mut txn) = self.txn.take() { let id = self.id; self.runtime_handle.spawn(async move { tracing::debug!("Constructing query to increment attempts for job {}...", id); // UPDATE "T::DATABASE_NAME" WHERE id = $1 SET attempts = attempts + 1 sqlx::query_builder::QueryBuilder::new("UPDATE ") // This is safe from a SQL injection standpoint, since it is a constant. .push(T::DATABASE_NAME) .push(" SET attempts = attempts + 1") .push(" WHERE id = ") .push_bind(id) .build() .execute(&mut *txn) .await .unwrap(); sqlx::query_builder::QueryBuilder::new("NOTIFY ") .push(T::NOTIFICATION_CHANNEL) .build() .execute(&mut *txn) .await .unwrap(); txn.commit().await.unwrap(); }); } } } #[cfg(test)] impl<T: PostgresJobItem> PostgresJob<T> { async fn attempts(&mut self) -> Result<usize, sqlx::Error> { sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ") .push(T::DATABASE_NAME) .push(" WHERE id = ") .push_bind(self.id) .build_query_as::<(i32,)>() // It's safe to unwrap here, because we "take" the txn only on drop or commit, // where it's passed by value, not by reference. .fetch_one(self.txn.as_deref_mut().unwrap()) .await .map(|(i,)| i as usize) } } #[async_trait::async_trait] impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { fn job(&self) -> &Webmention { &self.job } async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { tracing::debug!("Deleting {} from the job queue", self.id); sqlx::query("DELETE FROM kittybox.incoming_webmention_queue WHERE id = $1") .bind(self.id) .execute(self.txn.as_deref_mut().unwrap()) .await?; self.txn.take().unwrap().commit().await } } pub struct PostgresJobQueue<T> { db: sqlx::PgPool, _phantom: std::marker::PhantomData<T> } impl<T> Clone for PostgresJobQueue<T> { fn clone(&self) -> Self { Self { db: self.db.clone(), _phantom: std::marker::PhantomData } } } impl PostgresJobQueue<Webmention> { pub async fn new(uri: &str) -> Result<Self, sqlx::Error> { let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?; if let Ok(password_file) = std::env::var("PGPASS_FILE") { let password = tokio::fs::read_to_string(password_file).await.unwrap(); options = options.password(&password); } else if let Ok(password) = std::env::var("PGPASS") { options = options.password(&password) } Ok(Self::from_pool( sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) .await? ).await?) } pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::migrate::MigrateError> { MIGRATOR.run(&db).await?; Ok(Self { db, _phantom: std::marker::PhantomData }) } } #[async_trait::async_trait] impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { type Job = PostgresJob<Webmention>; type Error = sqlx::Error; async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> { let mut txn = self.db.begin().await?; match sqlx::query_as::<_, PostgresJobRow<Webmention>>( "SELECT id, source, target FROM kittybox.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1" ) .fetch_optional(&mut *txn) .await? { Some(job_row) => { return Ok(Some(Self::Job { id: job_row.id, job: job_row.job, txn: Some(txn), runtime_handle: tokio::runtime::Handle::current(), })) }, None => Ok(None) } } async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> { sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id") .bind(item.source.as_str()) .bind(item.target.as_str()) .fetch_one(&self.db) .await } async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> { let mut listener = PgListener::connect_with(&self.db).await?; listener.listen("incoming_webmention").await?; let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), { let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener)); move |_| { let queue = self.clone(); let listener = listener.clone(); async move { loop { match queue.get_one().await? { Some(item) => return Ok(Some((item, ()))), None => { listener.lock().await.recv().await?; continue } } } } } }).boxed(); Ok(stream) } } #[cfg(test)] mod tests { use std::sync::Arc; use super::{Webmention, PostgresJobQueue, Job, JobQueue}; use futures_util::StreamExt; #[sqlx::test] #[tracing_test::traced_test] async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; tracing::debug!("Putting webmention into queue"); queue.put(&test_webmention).await?; { let mut job_description = queue.get_one().await?.unwrap(); assert_eq!(job_description.job(), &test_webmention); assert_eq!(job_description.attempts().await?, 0); } tracing::debug!("Creating a stream"); let mut stream = queue.clone().into_stream().await?; { let mut guard = stream.next().await.transpose()?.unwrap(); assert_eq!(guard.job(), &test_webmention); assert_eq!(guard.attempts().await?, 1); if let Some(item) = queue.get_one().await? { panic!("Unexpected item {:?} returned from job queue!", item) }; } { let mut guard = stream.next().await.transpose()?.unwrap(); assert_eq!(guard.job(), &test_webmention); assert_eq!(guard.attempts().await?, 2); guard.done().await?; } match queue.get_one().await? { Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), None => Ok(()) } } #[sqlx::test] #[tracing_test::traced_test] async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; tracing::debug!("Putting webmention into queue"); queue.put(&test_webmention).await?; tracing::debug!("Creating a stream"); let mut stream = queue.clone().into_stream().await?; // Synchronisation barrier that will be useful later let barrier = Arc::new(tokio::sync::Barrier::new(2)); { // Get one job guard from a queue let mut guard = stream.next().await.transpose()?.unwrap(); assert_eq!(guard.job(), &test_webmention); assert_eq!(guard.attempts().await?, 0); tokio::task::spawn({ let barrier = barrier.clone(); async move { // Wait for the signal to drop the guard! barrier.wait().await; drop(guard) } }); } tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); let future = tokio::task::spawn( tokio::time::timeout( std::time::Duration::from_secs(10), async move { stream.next().await.unwrap().unwrap() } ) ); // Let the other task drop the guard it is holding barrier.wait().await; let mut guard = future.await .expect("Timeout on fetching item") .expect("Job queue error"); assert_eq!(guard.job(), &test_webmention); assert_eq!(guard.attempts().await?, 1); Ok(()) } }