#![allow(unused_variables)] use std::borrow::Cow; use std::str::FromStr; use kittybox_util::{MicropubChannel, MentionType}; use sqlx::{PgPool, Executor}; use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; use super::settings::Setting; use super::{Storage, Result, StorageError, ErrorKind}; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); impl From for StorageError { fn from(value: sqlx::Error) -> Self { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx error: {}", &value)), Box::new(value) ) } } impl From for StorageError { fn from(value: sqlx::migrate::MigrateError) -> Self { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx migration error: {}", &value)), Box::new(value) ) } } #[derive(Debug, Clone)] pub struct PostgresStorage { db: PgPool } impl PostgresStorage { /// Construct a new [`PostgresStorage`] from an URI string and run /// migrations on the database. /// /// If `PGPASS_FILE` environment variable is defined, read the /// password from the file at the specified path. If, instead, /// the `PGPASS` environment variable is present, read the /// password from it. pub async fn new(uri: &str) -> Result { tracing::debug!("Postgres URL: {uri}"); let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? .options([("search_path", "kittybox")]); if let Ok(password_file) = std::env::var("PGPASS_FILE") { let password = tokio::fs::read_to_string(password_file).await.unwrap(); options = options.password(&password); } else if let Ok(password) = std::env::var("PGPASS") { options = options.password(&password) } Self::from_pool( sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) .await? ).await } /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], /// running appropriate migrations. pub async fn from_pool(db: sqlx::PgPool) -> Result { db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; MIGRATOR.run(&db).await?; Ok(Self { db }) } } #[async_trait::async_trait] impl Storage for PostgresStorage { #[tracing::instrument(skip(self))] async fn post_exists(&self, url: &str) -> Result { sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)") .bind(url) .fetch_one(&self.db) .await .map(|v| v.0) .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn get_post(&self, url: &str) -> Result> { sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1") .bind(url) .fetch_optional(&self.db) .await .map(|v| v.map(|v| v.0)) .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { tracing::debug!("New post: {}", post); sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)") .bind(post) .bind(user) .execute(&self.db) .await .map(|_| ()) .map_err(Into::into) } #[tracing::instrument(skip(self))] async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { tracing::debug!("Inserting {} into {}", post, feed); sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") .bind(feed) .bind(post) .execute(&self.db) .await .map(|_| ()) .map_err(Into::into) } #[tracing::instrument(skip(self))] async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2") .bind(feed) .bind(post) .execute(&self.db) .await .map_err(Into::into) .map(|_| ()) } #[tracing::instrument(skip(self))] async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { let mut txn = self.db.begin().await?; let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") .bind(target) .fetch_optional(&mut *txn) .await? .ok_or(StorageError::from_static( ErrorKind::NotFound, "The specified post wasn't found in the database." ))?; tracing::debug!("Loaded post for target {} with uid {}", target, uid); let key: &'static str = match mention_type { MentionType::Reply => "comment", MentionType::Like => "like", MentionType::Repost => "repost", MentionType::Bookmark => "bookmark", MentionType::Mention => "mention", }; tracing::debug!("Mention type -> key: {}", key); let mention_uid = mention["properties"]["uid"][0].clone(); if let Some(values) = post["properties"][key].as_array_mut() { for value in values.iter_mut() { if value["properties"]["uid"][0] == mention_uid { *value = mention; break; } } } else { post["properties"][key] = serde_json::Value::Array(vec![mention]); } sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") .bind(uid) .bind(post) .execute(&mut *txn) .await?; txn.commit().await.map_err(Into::into) } #[tracing::instrument(skip(self))] async fn update_post(&self, url: &'_ str, update: MicropubUpdate) -> Result<()> { tracing::debug!("Updating post {}", url); let mut txn = self.db.begin().await?; let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") .bind(url) .fetch_optional(&mut *txn) .await? .ok_or(StorageError::from_static( ErrorKind::NotFound, "The specified post wasn't found in the database." ))?; if let Some(MicropubPropertyDeletion::Properties(ref delete)) = update.delete { if let Some(props) = post["properties"].as_object_mut() { for key in delete { props.remove(key); } } } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = update.delete { if let Some(props) = post["properties"].as_object_mut() { for (key, values) in delete { if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { prop.retain(|v| { values.iter().all(|i| i != v) }) } } } } if let Some(replace) = update.replace { if let Some(props) = post["properties"].as_object_mut() { for (key, value) in replace { props.insert(key, serde_json::Value::Array(value)); } } } if let Some(add) = update.add { if let Some(props) = post["properties"].as_object_mut() { for (key, value) in add { if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { prop.extend_from_slice(value.as_slice()); } else { props.insert(key, serde_json::Value::Array(value)); } } } } sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") .bind(uid) .bind(post) .execute(&mut *txn) .await?; txn.commit().await.map_err(Into::into) } #[tracing::instrument(skip(self))] async fn get_channels(&self, user: &'_ str) -> Result> { /*sqlx::query_as::<_, MicropubChannel>("SELECT name, uid FROM kittybox.channels WHERE owner = $1") .bind(user) .fetch_all(&self.db) .await .map_err(|err| err.into())*/ sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#) .bind(user) .fetch_all(&self.db) .await .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn read_feed_with_limit( &self, url: &'_ str, after: &'_ Option, limit: usize, user: &'_ Option, ) -> Result> { let mut feed = match sqlx::query_as::<_, (serde_json::Value,)>(" SELECT jsonb_set( mf2, '{properties,author,0}', (SELECT mf2 FROM kittybox.mf2_json WHERE uid = mf2 #>> '{properties,author,0}') ) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 ") .bind(url) .fetch_optional(&self.db) .await? .map(|v| v.0) { Some(feed) => feed, None => return Ok(None) }; let posts: Vec = { let mut posts_iter = feed["children"] .as_array() .cloned() .unwrap_or_default() .into_iter() .map(|s| s.as_str().unwrap().to_string()); if let Some(after) = after { for s in posts_iter.by_ref() { if &s == after { break; } } }; posts_iter.take(limit).collect::>() }; feed["children"] = serde_json::Value::Array( sqlx::query_as::<_, (serde_json::Value,)>(" SELECT jsonb_set( mf2, '{properties,author,0}', (SELECT mf2 FROM kittybox.mf2_json WHERE uid = mf2 #>> '{properties,author,0}') ) FROM kittybox.mf2_json WHERE uid = ANY($1) ORDER BY mf2 #>> '{properties,published,0}' DESC ") .bind(&posts[..]) .fetch_all(&self.db) .await? .into_iter() .map(|v| v.0) .collect::>() ); Ok(Some(feed)) } #[tracing::instrument(skip(self))] async fn read_feed_with_cursor( &self, url: &'_ str, cursor: Option<&'_ str>, limit: usize, user: Option<&'_ str> ) -> Result)>> { let mut txn = self.db.begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") .execute(&mut *txn) .await?; tracing::debug!("Started txn: {:?}", txn); let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 ") .bind(url) .fetch_optional(&mut *txn) .await? { Some(feed) => feed, None => return Ok(None) }; // Don't query for children if this isn't a feed. // // The second query is very long and will probably be extremely // expensive. It's best to skip it on types where it doesn't make sense // (Kittybox doesn't support rendering children on non-feeds) if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { return Ok(Some((feed, None))); } feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json INNER JOIN kittybox.children ON mf2_json.uid = children.child WHERE children.parent = $1 AND ( ( (mf2 #>> '{properties,visibility,0}') = 'public' OR NOT (mf2['properties'] ? 'visibility') ) OR ( $3 != null AND ( mf2['properties']['audience'] ? $3 OR mf2['properties']['author'] ? $3 ) ) ) AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC LIMIT $2" ) .bind(url) .bind(limit as i64) .bind(user) .bind(cursor) .fetch_all(&mut *txn) .await .map(serde_json::Value::Array)?; let new_cursor = feed["children"].as_array().unwrap() .last() .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); txn.commit().await?; Ok(Some((feed, new_cursor))) } #[tracing::instrument(skip(self))] async fn delete_post(&self, url: &'_ str) -> Result<()> { todo!() } #[tracing::instrument(skip(self))] async fn get_setting, 'a>(&'_ self, user: &'_ str) -> Result { match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)") .bind(user) .bind(S::ID) .fetch_one(&self.db) .await { Ok((value,)) => Ok(serde_json::from_value(value)?), Err(err) => Err(err.into()) } } #[tracing::instrument(skip(self))] async fn set_setting + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { sqlx::query("SELECT kittybox.set_setting($1, $2, $3)") .bind(user) .bind(S::ID) .bind(serde_json::to_value(S::new(value)).unwrap()) .execute(&self.db) .await .map_err(Into::into) .map(|_| ()) } }