diff options
Diffstat (limited to 'src/database/postgres/mod.rs')
-rw-r--r-- | src/database/postgres/mod.rs | 416 |
1 files changed, 416 insertions, 0 deletions
diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs new file mode 100644 index 0000000..9176d12 --- /dev/null +++ b/src/database/postgres/mod.rs @@ -0,0 +1,416 @@ +#![allow(unused_variables)] +use std::borrow::Cow; +use std::str::FromStr; + +use kittybox_util::{MicropubChannel, MentionType}; +use sqlx::{PgPool, Executor}; +use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; + +use super::settings::Setting; +use super::{Storage, Result, StorageError, ErrorKind}; + +static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); + +impl From<sqlx::Error> for StorageError { + fn from(value: sqlx::Error) -> Self { + Self::with_source( + super::ErrorKind::Backend, + Cow::Owned(format!("sqlx error: {}", &value)), + Box::new(value) + ) + } +} + +impl From<sqlx::migrate::MigrateError> for StorageError { + fn from(value: sqlx::migrate::MigrateError) -> Self { + Self::with_source( + super::ErrorKind::Backend, + Cow::Owned(format!("sqlx migration error: {}", &value)), + Box::new(value) + ) + } +} + +#[derive(Debug, Clone)] +pub struct PostgresStorage { + db: PgPool +} + +impl PostgresStorage { + /// Construct a new [`PostgresStorage`] from an URI string and run + /// migrations on the database. + /// + /// If `PGPASS_FILE` environment variable is defined, read the + /// password from the file at the specified path. If, instead, + /// the `PGPASS` environment variable is present, read the + /// password from it. + pub async fn new(uri: &str) -> Result<Self> { + tracing::debug!("Postgres URL: {uri}"); + let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? + .options([("search_path", "kittybox")]); + if let Ok(password_file) = std::env::var("PGPASS_FILE") { + let password = tokio::fs::read_to_string(password_file).await.unwrap(); + options = options.password(&password); + } else if let Ok(password) = std::env::var("PGPASS") { + options = options.password(&password) + } + Self::from_pool( + sqlx::postgres::PgPoolOptions::new() + .max_connections(50) + .connect_with(options) + .await? + ).await + + } + + /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], + /// running appropriate migrations. + pub async fn from_pool(db: sqlx::PgPool) -> Result<Self> { + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; + MIGRATOR.run(&db).await?; + Ok(Self { db }) + } +} + +#[async_trait::async_trait] +impl Storage for PostgresStorage { + #[tracing::instrument(skip(self))] + async fn post_exists(&self, url: &str) -> Result<bool> { + sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)") + .bind(url) + .fetch_one(&self.db) + .await + .map(|v| v.0) + .map_err(|err| err.into()) + } + + #[tracing::instrument(skip(self))] + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1") + .bind(url) + .fetch_optional(&self.db) + .await + .map(|v| v.map(|v| v.0)) + .map_err(|err| err.into()) + + } + + #[tracing::instrument(skip(self))] + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + tracing::debug!("New post: {}", post); + sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)") + .bind(post) + .bind(user) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) + } + + #[tracing::instrument(skip(self))] + async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { + tracing::debug!("Inserting {} into {}", post, feed); + sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) + } + + #[tracing::instrument(skip(self))] + async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { + sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2") + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map_err(Into::into) + .map(|_| ()) + } + + #[tracing::instrument(skip(self))] + async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + let mut txn = self.db.begin().await?; + + let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") + .bind(target) + .fetch_optional(&mut *txn) + .await? + .ok_or(StorageError::from_static( + ErrorKind::NotFound, + "The specified post wasn't found in the database." + ))?; + + tracing::debug!("Loaded post for target {} with uid {}", target, uid); + + let key: &'static str = match mention_type { + MentionType::Reply => "comment", + MentionType::Like => "like", + MentionType::Repost => "repost", + MentionType::Bookmark => "bookmark", + MentionType::Mention => "mention", + }; + + tracing::debug!("Mention type -> key: {}", key); + + let mention_uid = mention["properties"]["uid"][0].clone(); + if let Some(values) = post["properties"][key].as_array_mut() { + for value in values.iter_mut() { + if value["properties"]["uid"][0] == mention_uid { + *value = mention; + break; + } + } + } else { + post["properties"][key] = serde_json::Value::Array(vec![mention]); + } + + sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") + .bind(uid) + .bind(post) + .execute(&mut *txn) + .await?; + + txn.commit().await.map_err(Into::into) + } + #[tracing::instrument(skip(self))] + async fn update_post(&self, url: &'_ str, update: MicropubUpdate) -> Result<()> { + tracing::debug!("Updating post {}", url); + let mut txn = self.db.begin().await?; + let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") + .bind(url) + .fetch_optional(&mut *txn) + .await? + .ok_or(StorageError::from_static( + ErrorKind::NotFound, + "The specified post wasn't found in the database." + ))?; + + if let Some(MicropubPropertyDeletion::Properties(ref delete)) = update.delete { + if let Some(props) = post["properties"].as_object_mut() { + for key in delete { + props.remove(key); + } + } + } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = update.delete { + if let Some(props) = post["properties"].as_object_mut() { + for (key, values) in delete { + if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { + prop.retain(|v| { values.iter().all(|i| i != v) }) + } + } + } + } + if let Some(replace) = update.replace { + if let Some(props) = post["properties"].as_object_mut() { + for (key, value) in replace { + props.insert(key, serde_json::Value::Array(value)); + } + } + } + if let Some(add) = update.add { + if let Some(props) = post["properties"].as_object_mut() { + for (key, value) in add { + if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { + prop.extend_from_slice(value.as_slice()); + } else { + props.insert(key, serde_json::Value::Array(value)); + } + } + } + } + + sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") + .bind(uid) + .bind(post) + .execute(&mut *txn) + .await?; + + txn.commit().await.map_err(Into::into) + } + + #[tracing::instrument(skip(self))] + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { + /*sqlx::query_as::<_, MicropubChannel>("SELECT name, uid FROM kittybox.channels WHERE owner = $1") + .bind(user) + .fetch_all(&self.db) + .await + .map_err(|err| err.into())*/ + sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#) + .bind(user) + .fetch_all(&self.db) + .await + .map_err(|err| err.into()) + } + + #[tracing::instrument(skip(self))] + async fn read_feed_with_limit( + &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>> { + let mut feed = match sqlx::query_as::<_, (serde_json::Value,)>(" +SELECT jsonb_set( + mf2, + '{properties,author,0}', + (SELECT mf2 FROM kittybox.mf2_json + WHERE uid = mf2 #>> '{properties,author,0}') +) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 +") + .bind(url) + .fetch_optional(&self.db) + .await? + .map(|v| v.0) + { + Some(feed) => feed, + None => return Ok(None) + }; + + let posts: Vec<String> = { + let mut posts_iter = feed["children"] + .as_array() + .cloned() + .unwrap_or_default() + .into_iter() + .map(|s| s.as_str().unwrap().to_string()); + if let Some(after) = after { + for s in posts_iter.by_ref() { + if &s == after { + break; + } + } + }; + + posts_iter.take(limit).collect::<Vec<_>>() + }; + feed["children"] = serde_json::Value::Array( + sqlx::query_as::<_, (serde_json::Value,)>(" +SELECT jsonb_set( + mf2, + '{properties,author,0}', + (SELECT mf2 FROM kittybox.mf2_json + WHERE uid = mf2 #>> '{properties,author,0}') +) FROM kittybox.mf2_json +WHERE uid = ANY($1) +ORDER BY mf2 #>> '{properties,published,0}' DESC +") + .bind(&posts[..]) + .fetch_all(&self.db) + .await? + .into_iter() + .map(|v| v.0) + .collect::<Vec<_>>() + ); + + Ok(Some(feed)) + + } + + #[tracing::instrument(skip(self))] + async fn read_feed_with_cursor( + &self, + url: &'_ str, + cursor: Option<&'_ str>, + limit: usize, + user: Option<&'_ str> + ) -> Result<Option<(serde_json::Value, Option<String>)>> { + let mut txn = self.db.begin().await?; + sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") + .execute(&mut *txn) + .await?; + tracing::debug!("Started txn: {:?}", txn); + let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" +SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 +") + .bind(url) + .fetch_optional(&mut *txn) + .await? + { + Some(feed) => feed, + None => return Ok(None) + }; + + // Don't query for children if this isn't a feed. + // + // The second query is very long and will probably be extremely + // expensive. It's best to skip it on types where it doesn't make sense + // (Kittybox doesn't support rendering children on non-feeds) + if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { + return Ok(Some((feed, None))); + } + + feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" +SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json +INNER JOIN kittybox.children +ON mf2_json.uid = children.child +WHERE + children.parent = $1 + AND ( + ( + (mf2 #>> '{properties,visibility,0}') = 'public' + OR + NOT (mf2['properties'] ? 'visibility') + ) + OR + ( + $3 != null AND ( + mf2['properties']['audience'] ? $3 + OR mf2['properties']['author'] ? $3 + ) + ) + ) + AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) +ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC +LIMIT $2" + ) + .bind(url) + .bind(limit as i64) + .bind(user) + .bind(cursor) + .fetch_all(&mut *txn) + .await + .map(serde_json::Value::Array)?; + + let new_cursor = feed["children"].as_array().unwrap() + .last() + .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); + + txn.commit().await?; + + Ok(Some((feed, new_cursor))) + } + + #[tracing::instrument(skip(self))] + async fn delete_post(&self, url: &'_ str) -> Result<()> { + todo!() + } + + #[tracing::instrument(skip(self))] + async fn get_setting<S: Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S> { + match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)") + .bind(user) + .bind(S::ID) + .fetch_one(&self.db) + .await + { + Ok((value,)) => Ok(serde_json::from_value(value)?), + Err(err) => Err(err.into()) + } + } + + #[tracing::instrument(skip(self))] + async fn set_setting<S: Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { + sqlx::query("SELECT kittybox.set_setting($1, $2, $3)") + .bind(user) + .bind(S::ID) + .bind(serde_json::to_value(S::new(value)).unwrap()) + .execute(&self.db) + .await + .map_err(Into::into) + .map(|_| ()) + } +} |