use std::{pin::Pin, str::FromStr};
use futures_util::{Stream, StreamExt};
use sqlx::postgres::PgListener;
use uuid::Uuid;
use super::Webmention;
/// Schema migrations embedded into the binary at compile time by
/// `sqlx::migrate!()` (with no argument, sqlx reads the crate's `migrations`
/// directory). They are applied by `PostgresJobQueue::from_pool` before a
/// queue is handed out.
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
pub use kittybox_util::queue::{JobQueue, JobItem, Job};
/// A job payload that can be stored in, and decoded from, a PostgreSQL table.
///
/// Implementors name the table their jobs live in and the channel that is
/// `NOTIFY`-ed when a failed job has been released back into the queue.
pub trait PostgresJobItem
where
    Self: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow>,
{
    /// Name of the table holding queued jobs of this type.
    ///
    /// It is interpolated verbatim into SQL statements, so it must be a
    /// trusted constant and never user input.
    const DATABASE_NAME: &'static str;

    /// Channel that receives a `NOTIFY` after a failed attempt is recorded.
    ///
    /// Like [`Self::DATABASE_NAME`], it is interpolated verbatim into SQL.
    const NOTIFICATION_CHANNEL: &'static str;
}
/// One row fetched from a job table: the row's `id` plus the job payload,
/// which is decoded from the remaining selected columns by `T`'s own
/// `FromRow` implementation (`#[sqlx(flatten)]`).
#[derive(sqlx::FromRow)]
struct PostgresJobRow<T: PostgresJobItem> {
    // Key used later to update (`attempts`) or delete the row.
    id: Uuid,
    // Payload columns are decoded by `T` itself.
    #[sqlx(flatten)]
    job: T
}
/// A job claimed from a PostgreSQL-backed queue.
///
/// The job's row was selected `FOR UPDATE` inside `txn`, so it stays locked
/// for as long as that transaction is open. Completing the job consumes the
/// transaction; dropping the job while it still owns the transaction records
/// a failed attempt instead.
#[derive(Debug)]
pub struct PostgresJob<T>
where
    T: PostgresJobItem,
{
    /// Primary key of the job's row.
    id: Uuid,
    /// The job payload.
    job: T,
    /// Transaction holding the row lock. Always `Some` while the job is alive;
    /// it is only taken when the job is completed or dropped.
    txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
    /// Runtime the job was claimed on, used by `Drop` to spawn async cleanup.
    runtime_handle: tokio::runtime::Handle,
}
impl<T: PostgresJobItem> Drop for PostgresJob<T> {
// This is an emulation of "async drop" — the struct retains a
// runtime handle, which it uses to block on a future that does
// the actual cleanup.
//
// Of course, this is not portable between runtimes, but I don't
// care about that, since Kittybox is designed to work within the
// Tokio ecosystem.
fn drop(&mut self) {
tracing::error!("Job {:?} failed, incrementing attempts...", &self);
if let Some(mut txn) = self.txn.take() {
let id = self.id;
self.runtime_handle.spawn(async move {
tracing::debug!("Constructing query to increment attempts for job {}...", id);
// UPDATE "T::DATABASE_NAME" WHERE id = $1 SET attempts = attempts + 1
sqlx::query_builder::QueryBuilder::new("UPDATE ")
// This is safe from a SQL injection standpoint, since it is a constant.
.push(T::DATABASE_NAME)
.push(" SET attempts = attempts + 1")
.push(" WHERE id = ")
.push_bind(id)
.build()
.execute(&mut txn)
.await
.unwrap();
sqlx::query_builder::QueryBuilder::new("NOTIFY ")
.push(T::NOTIFICATION_CHANNEL)
.build()
.execute(&mut txn)
.await
.unwrap();
txn.commit().await.unwrap();
});
}
}
}
#[cfg(test)]
impl<T: PostgresJobItem> PostgresJob<T> {
    /// Test helper: read this job's current `attempts` counter through the
    /// job's own transaction.
    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
        let mut query =
            sqlx::query_builder::QueryBuilder::<sqlx::Postgres>::new("SELECT attempts FROM ");
        query
            .push(T::DATABASE_NAME)
            .push(" WHERE id = ")
            .push_bind(self.id);
        // The transaction is only taken on drop or on `done()`, both of which
        // receive the job by value, so it is always present behind `&mut self`.
        let txn = self.txn.as_mut().unwrap();
        let (count,) = query
            .build_query_as::<(i32,)>()
            .fetch_one(txn)
            .await?;
        Ok(count as usize)
    }
}
#[async_trait::async_trait]
impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
    /// Borrow the webmention this job carries.
    fn job(&self) -> &Webmention {
        &self.job
    }

    /// Complete the job: delete its row and commit the transaction that has
    /// been holding the row lock.
    async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
        let id = self.id;
        tracing::debug!("Deleting {} from the job queue", id);
        let delete = sqlx::query("DELETE FROM kittybox.incoming_webmention_queue WHERE id = $1").bind(id);
        // The DELETE runs while the transaction is still attached to `self`,
        // so if it fails, dropping `self` goes through the `Drop` cleanup.
        delete.execute(self.txn.as_mut().unwrap()).await?;
        let txn = self.txn.take().unwrap();
        txn.commit().await
    }
}
/// A job queue for jobs of type `T`, stored in a PostgreSQL database.
///
/// Clones share the same connection pool.
pub struct PostgresJobQueue<T> {
    /// Connection pool used for every queue operation.
    db: sqlx::PgPool,
    /// Ties the queue to its job type without storing any `T`.
    _phantom: std::marker::PhantomData<T>,
}
// Written by hand instead of `#[derive(Clone)]`: the derive would require
// `T: Clone`, although no `T` is actually stored in the queue.
impl<T> Clone for PostgresJobQueue<T> {
    fn clone(&self) -> Self {
        let db = self.db.clone();
        Self { db, _phantom: Default::default() }
    }
}
impl PostgresJobQueue<Webmention> {
    /// Connect to the database at `uri` and return a ready-to-use queue,
    /// applying pending migrations first.
    ///
    /// The password may be supplied out-of-band through the environment:
    /// `PGPASS_FILE` names a file containing it (this takes precedence);
    /// otherwise `PGPASS` holds it directly. A trailing line ending in the
    /// password file is not treated as part of the password.
    ///
    /// # Errors
    ///
    /// Fails if `uri` cannot be parsed, the password file cannot be read, the
    /// connection cannot be established, or the migrations fail.
    pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?;
        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
            // A missing/unreadable file is an ordinary error, not a panic;
            // `sqlx::Error` converts from `std::io::Error`.
            let password = tokio::fs::read_to_string(password_file).await?;
            // Secret files are usually written with a final newline, which
            // would otherwise be sent to the server as part of the password.
            let password = password.trim_end_matches(|c: char| c == '\n' || c == '\r');
            options = options.password(password);
        } else if let Ok(password) = std::env::var("PGPASS") {
            options = options.password(&password);
        }
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(50)
            .connect_with(options)
            .await?;
        Ok(Self::from_pool(pool).await?)
    }

    /// Wrap an existing pool into a queue, running the embedded migrations first.
    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::migrate::MigrateError> {
        MIGRATOR.run(&db).await?;
        Ok(Self { db, _phantom: std::marker::PhantomData })
    }
}
#[async_trait::async_trait]
impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
type Job = PostgresJob<Webmention>;
type Error = sqlx::Error;
async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> {
let mut txn = self.db.begin().await?;
match sqlx::query_as::<_, PostgresJobRow<Webmention>>(
"SELECT id, source, target FROM kittybox.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1"
)
.fetch_optional(&mut txn)
.await?
{
Some(job_row) => {
return Ok(Some(Self::Job {
id: job_row.id,
job: job_row.job,
txn: Some(txn),
runtime_handle: tokio::runtime::Handle::current(),
}))
},
None => Ok(None)
}
}
async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> {
sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id")
.bind(item.source.as_str())
.bind(item.target.as_str())
.fetch_one(&self.db)
.await
}
async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> {
let mut listener = PgListener::connect_with(&self.db).await?;
listener.listen("incoming_webmention").await?;
let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), {
let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener));
move |_| {
let queue = self.clone();
let listener = listener.clone();
async move {
loop {
match queue.get_one().await? {
Some(item) => return Ok(Some((item, ()))),
None => {
listener.lock().await.recv().await?;
continue
}
}
}
}
}
}).boxed();
Ok(stream)
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use super::{Webmention, PostgresJobQueue, Job, JobQueue};
    use futures_util::StreamExt;

    /// Walks one job through its lifecycle: claiming it, dropping it without
    /// completing it (which must count as a failed attempt), receiving it again
    /// through the stream, and finally completing it with `done()`.
    #[sqlx::test]
    #[tracing_test::traced_test]
    async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
        let test_webmention = Webmention {
            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
        };

        let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
        tracing::debug!("Putting webmention into queue");
        queue.put(&test_webmention).await?;
        {
            // A freshly queued job has not been attempted yet.
            let mut job_description = queue.get_one().await?.unwrap();
            assert_eq!(job_description.job(), &test_webmention);
            assert_eq!(job_description.attempts().await?, 0);
            // Dropped here without `done()`: recorded as a failed attempt.
        }
        tracing::debug!("Creating a stream");
        let mut stream = queue.clone().into_stream().await?;

        {
            // Until the drop above commits, the row stays locked and is
            // skipped; its NOTIFY then wakes the stream to claim the job.
            let mut guard = stream.next().await.transpose()?.unwrap();
            assert_eq!(guard.job(), &test_webmention);
            assert_eq!(guard.attempts().await?, 1);
            // While `guard` holds the row lock, nobody else may claim the job.
            if let Some(item) = queue.get_one().await? {
                panic!("Unexpected item {:?} returned from job queue!", item)
            };
            // Dropped again without `done()`.
        }
        {
            let mut guard = stream.next().await.transpose()?.unwrap();
            assert_eq!(guard.job(), &test_webmention);
            assert_eq!(guard.attempts().await?, 2);
            // Completing the job deletes it from the queue.
            guard.done().await?;
        }

        match queue.get_one().await? {
            Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
            None => Ok(())
        }
    }

    /// A stream waiting on an empty (fully locked) queue must wake up once a
    /// job held elsewhere is dropped, instead of hanging forever.
    #[sqlx::test]
    #[tracing_test::traced_test]
    async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
        let test_webmention = Webmention {
            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
        };

        let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?;
        tracing::debug!("Putting webmention into queue");
        queue.put(&test_webmention).await?;
        tracing::debug!("Creating a stream");
        let mut stream = queue.clone().into_stream().await?;
        // Synchronisation barrier that will be useful later
        let barrier = Arc::new(tokio::sync::Barrier::new(2));
        {
            // Get one job guard from a queue
            let mut guard = stream.next().await.transpose()?.unwrap();
            assert_eq!(guard.job(), &test_webmention);
            assert_eq!(guard.attempts().await?, 0);
            // Hand the guard to another task, which keeps the row locked until
            // the barrier is released.
            tokio::task::spawn({
                let barrier = barrier.clone();
                async move {
                    // Wait for the signal to drop the guard!
                    barrier.wait().await;
                    drop(guard)
                }
            });
        }
        // The only job is locked, so the stream must not produce anything yet.
        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err();
        let future = tokio::task::spawn(
            tokio::time::timeout(
                std::time::Duration::from_secs(10), async move {
                    stream.next().await.unwrap().unwrap()
                }
            )
        );
        // Let the other task drop the guard it is holding
        barrier.wait().await;
        // The outer `Result` is the task's `JoinError` (a queue error makes the
        // inner `unwrap` panic); the inner one is the timeout.
        let mut guard = future.await
            .expect("Job queue error")
            .expect("Timeout on fetching item");
        assert_eq!(guard.job(), &test_webmention);
        assert_eq!(guard.attempts().await?, 1);

        Ok(())
    }
}