use std::borrow::Cow;
use futures::{Stream, StreamExt};
use kittybox_util::{micropub::Channel as MicropubChannel, MentionType};
use sqlx::{ConnectOptions, Executor, PgPool};
use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion};
use super::settings::Setting;
use super::{Storage, Result, StorageError, ErrorKind};
/// Database migrations embedded by [`sqlx::migrate!`]; they are applied to
/// the pool in [`PostgresStorage::from_pool`].
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
impl From<sqlx::Error> for StorageError {
    /// Wrap a [`sqlx::Error`] into a backend-kind [`StorageError`], keeping
    /// the original error attached as the source.
    fn from(value: sqlx::Error) -> Self {
        let message = format!("sqlx error: {}", value);
        let kind = super::ErrorKind::Backend;
        Self::with_source(kind, Cow::Owned(message), Box::new(value))
    }
}
impl From<sqlx::migrate::MigrateError> for StorageError {
    /// Wrap a failed migration into a backend-kind [`StorageError`], keeping
    /// the original migration error attached as the source.
    fn from(value: sqlx::migrate::MigrateError) -> Self {
        let message = format!("sqlx migration error: {}", value);
        let kind = super::ErrorKind::Backend;
        Self::with_source(kind, Cow::Owned(message), Box::new(value))
    }
}
/// Micropub storage backend that keeps posts, feeds and settings in a
/// PostgreSQL database (under the `kittybox` schema).
#[derive(Debug, Clone)]
pub struct PostgresStorage {
    /// Connection pool every query of this storage is issued against.
    db: PgPool,
}
impl PostgresStorage {
    /// Build a [`PostgresStorage`] on top of an existing [`sqlx::PgPool`].
    ///
    /// Creates the `kittybox` schema if it is missing, then applies the
    /// embedded migrations before handing out the storage.
    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self> {
        let create_schema = sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox");
        db.execute(create_schema).await?;
        MIGRATOR.run(&db).await?;
        let storage = Self { db };
        Ok(storage)
    }
}
impl Storage for PostgresStorage {
/// Construct a new [`PostgresStorage`] from an URI string and run
/// migrations on the database.
///
/// If `PGPASS_FILE` environment variable is defined, read the
/// password from the file at the specified path. If, instead,
/// the `PGPASS` environment variable is present, read the
/// password from it.
async fn new(url: &'_ url::Url) -> Result<Self> {
tracing::debug!("Postgres URL: {url}");
let mut options = sqlx::postgres::PgConnectOptions::from_url(url)?
.options([("search_path", "kittybox")]);
if let Ok(password_file) = std::env::var("PGPASS_FILE") {
let password = tokio::fs::read_to_string(password_file).await.unwrap();
options = options.password(&password);
} else if let Ok(password) = std::env::var("PGPASS") {
options = options.password(&password)
}
Self::from_pool(
sqlx::postgres::PgPoolOptions::new()
.max_connections(50)
.connect_with(options)
.await?
).await
}
async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> {
let authority = user.authority().to_owned();
Ok(
sqlx::query_scalar::<_, serde_json::Value>("SELECT mf2 FROM kittybox.mf2_json WHERE owner = $1 ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC")
.bind(authority)
.fetch(&self.db)
.filter_map(|f| std::future::ready(f.ok()))
)
}
#[tracing::instrument(skip(self))]
async fn categories(&self, url: &str) -> Result<Vec<String>> {
sqlx::query_scalar::<_, String>("
SELECT jsonb_array_elements(mf2['properties']['category']) AS category
FROM kittybox.mf2_json
WHERE
jsonb_typeof(mf2['properties']['category']) = 'array'
AND uid LIKE ($1 + '%')
GROUP BY category ORDER BY count(*) DESC
")
.bind(url)
.fetch_all(&self.db)
.await
.map_err(|err| err.into())
}
#[tracing::instrument(skip(self))]
async fn post_exists(&self, url: &str) -> Result<bool> {
sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)")
.bind(url)
.fetch_one(&self.db)
.await
.map(|v| v.0)
.map_err(|err| err.into())
}
#[tracing::instrument(skip(self))]
async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1")
.bind(url)
.fetch_optional(&self.db)
.await
.map(|v| v.map(|v| v.0))
.map_err(|err| err.into())
}
#[tracing::instrument(skip(self))]
async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()> {
tracing::debug!("New post: {}", post);
sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)")
.bind(post)
.bind(user.authority())
.execute(&self.db)
.await
.map(|_| ())
.map_err(Into::into)
}
#[tracing::instrument(skip(self))]
async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> {
tracing::debug!("Inserting {} into {}", post, feed);
sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING")
.bind(feed)
.bind(post)
.execute(&self.db)
.await
.map(|_| ())
.map_err(Into::into)
}
#[tracing::instrument(skip(self))]
async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> {
sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2")
.bind(feed)
.bind(post)
.execute(&self.db)
.await
.map_err(Into::into)
.map(|_| ())
}
#[tracing::instrument(skip(self))]
async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> {
let mut txn = self.db.begin().await?;
let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
.bind(target)
.fetch_optional(&mut *txn)
.await?
.ok_or(StorageError::from_static(
ErrorKind::NotFound,
"The specified post wasn't found in the database."
))?;
tracing::debug!("Loaded post for target {} with uid {}", target, uid);
let key: &'static str = match mention_type {
MentionType::Reply => "comment",
MentionType::Like => "like",
MentionType::Repost => "repost",
MentionType::Bookmark => "bookmark",
MentionType::Mention => "mention",
};
tracing::debug!("Mention type -> key: {}", key);
let mention_uid = mention["properties"]["uid"][0].clone();
if let Some(values) = post["properties"][key].as_array_mut() {
for value in values.iter_mut() {
if value["properties"]["uid"][0] == mention_uid {
*value = mention;
break;
}
}
} else {
post["properties"][key] = serde_json::Value::Array(vec![mention]);
}
sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1")
.bind(uid)
.bind(post)
.execute(&mut *txn)
.await?;
txn.commit().await.map_err(Into::into)
}
#[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))]
async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
&self, url: &str, f: F
) -> Result<(serde_json::Value, serde_json::Value)> {
tracing::debug!("Updating post {}", url);
let mut txn = self.db.begin().await?;
let (uid, old_post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
.bind(url)
.fetch_optional(&mut *txn)
.await?
.ok_or(StorageError::from_static(
ErrorKind::NotFound,
"The specified post wasn't found in the database."
))?;
let new_post = {
let mut post = old_post.clone();
#[cfg(not(test))] // Tests use the current-thread runtime.
tokio::task::block_in_place(|| f(&mut post));
#[cfg(test)]
f(&mut post);
post
};
sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1")
.bind(uid)
.bind(&new_post)
.execute(&mut *txn)
.await?;
txn.commit().await?;
Ok((old_post, new_post))
}
#[tracing::instrument(skip(self))]
async fn get_channels(&self, user: &url::Url) -> Result<Vec<MicropubChannel>> {
sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#)
.bind(user.authority())
.fetch_all(&self.db)
.await
.map_err(|err| err.into())
}
#[tracing::instrument(skip(self))]
async fn read_feed_with_limit(
&self,
url: &'_ str,
_after: Option<&str>,
_limit: usize,
_user: Option<&url::Url>,
) -> Result<Option<serde_json::Value>> {
unimplemented!("read_feed_with_limit is insecure and deprecated");
}
#[tracing::instrument(skip(self))]
async fn read_feed_with_cursor(
&self,
url: &'_ str,
cursor: Option<&'_ str>,
limit: usize,
user: Option<&url::Url>
) -> Result<Option<(serde_json::Value, Option<String>)>> {
let mut txn = self.db.begin().await?;
sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY")
.execute(&mut *txn)
.await?;
tracing::debug!("Started txn: {:?}", txn);
let mut feed = match sqlx::query_scalar::<_, serde_json::Value>("
SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1
")
.bind(url)
.fetch_optional(&mut *txn)
.await?
{
Some(feed) => feed,
None => return Ok(None)
};
// Don't query for children if this isn't a feed.
//
// The second query is very long and will probably be extremely
// expensive. It's best to skip it on types where it doesn't make sense
// (Kittybox doesn't support rendering children on non-feeds)
if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) {
return Ok(Some((feed, None)));
}
feed["children"] = sqlx::query_scalar::<_, serde_json::Value>("
SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json
INNER JOIN kittybox.children
ON mf2_json.uid = children.child
WHERE
children.parent = $1
AND (
(
(mf2 #>> '{properties,visibility,0}') = 'public'
OR
NOT (mf2['properties'] ? 'visibility')
)
OR
(
$3 IS NOT NULL AND (
(NOT (mf2['properties'] ? 'audience'))
OR
(mf2['properties']['audience'] ? $3)
OR
(mf2['properties']['author'] ? $3)
)
)
)
AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4))
ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC
LIMIT $2"
)
.bind(url)
.bind(limit as i64)
.bind(user.map(url::Url::as_str))
.bind(cursor)
.fetch_all(&mut *txn)
.await
.map(serde_json::Value::Array)?;
let new_cursor = feed["children"].as_array().unwrap()
.last()
.map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned());
txn.commit().await?;
Ok(Some((feed, new_cursor)))
}
#[tracing::instrument(skip(self))]
async fn delete_post(&self, url: &'_ str) -> Result<()> {
todo!()
}
#[tracing::instrument(skip(self))]
async fn get_setting<S: Setting>(&'_ self, user: &url::Url) -> Result<S> {
match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)")
.bind(user.authority())
.bind(S::ID)
.fetch_one(&self.db)
.await
{
Ok((value,)) => Ok(serde_json::from_value(value)?),
Err(err) => Err(err.into())
}
}
#[tracing::instrument(skip(self))]
async fn set_setting<S: Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> {
sqlx::query("SELECT kittybox.set_setting($1, $2, $3)")
.bind(user.authority())
.bind(S::ID)
.bind(serde_json::to_value(S::new(value)).unwrap())
.execute(&self.db)
.await
.map_err(Into::into)
.map(|_| ())
}
}