use std::borrow::Cow; use kittybox_util::{micropub::Channel as MicropubChannel, MentionType}; use sqlx::{ConnectOptions, Executor, PgPool}; use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; use super::settings::Setting; use super::{Storage, Result, StorageError, ErrorKind}; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); impl From for StorageError { fn from(value: sqlx::Error) -> Self { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx error: {}", &value)), Box::new(value) ) } } impl From for StorageError { fn from(value: sqlx::migrate::MigrateError) -> Self { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx migration error: {}", &value)), Box::new(value) ) } } /// Micropub storage that uses a PostgreSQL database. #[derive(Debug, Clone)] pub struct PostgresStorage { db: PgPool } impl PostgresStorage { /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], /// running appropriate migrations. pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result { db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; MIGRATOR.run(&db).await?; Ok(Self { db }) } } impl Storage for PostgresStorage { /// Construct a new [`PostgresStorage`] from an URI string and run /// migrations on the database. /// /// If `PGPASS_FILE` environment variable is defined, read the /// password from the file at the specified path. If, instead, /// the `PGPASS` environment variable is present, read the /// password from it. async fn new(url: &'_ url::Url) -> Result { tracing::debug!("Postgres URL: {url}"); let mut options = sqlx::postgres::PgConnectOptions::from_url(url)? .options([("search_path", "kittybox")]); if let Ok(password_file) = std::env::var("PGPASS_FILE") { let password = tokio::fs::read_to_string(password_file).await.unwrap(); options = options.password(&password); } else if let Ok(password) = std::env::var("PGPASS") { options = options.password(&password) } Self::from_pool( sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) .await? ).await } #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result> { sqlx::query_scalar::<_, String>(" SELECT jsonb_array_elements(mf2['properties']['category']) AS category FROM kittybox.mf2_json WHERE jsonb_typeof(mf2['properties']['category']) = 'array' AND uid LIKE ($1 + '%') GROUP BY category ORDER BY count(*) DESC ") .bind(url) .fetch_all(&self.db) .await .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn post_exists(&self, url: &str) -> Result { sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)") .bind(url) .fetch_one(&self.db) .await .map(|v| v.0) .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn get_post(&self, url: &str) -> Result> { sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1") .bind(url) .fetch_optional(&self.db) .await .map(|v| v.map(|v| v.0)) .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()> { tracing::debug!("New post: {}", post); sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)") .bind(post) .bind(user.authority()) .execute(&self.db) .await .map(|_| ()) .map_err(Into::into) } #[tracing::instrument(skip(self))] async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { tracing::debug!("Inserting {} into {}", post, feed); sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") .bind(feed) .bind(post) .execute(&self.db) .await .map(|_| ()) .map_err(Into::into) } #[tracing::instrument(skip(self))] async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2") .bind(feed) .bind(post) .execute(&self.db) .await .map_err(Into::into) .map(|_| ()) } #[tracing::instrument(skip(self))] async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { let mut txn = self.db.begin().await?; let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") .bind(target) .fetch_optional(&mut *txn) .await? .ok_or(StorageError::from_static( ErrorKind::NotFound, "The specified post wasn't found in the database." ))?; tracing::debug!("Loaded post for target {} with uid {}", target, uid); let key: &'static str = match mention_type { MentionType::Reply => "comment", MentionType::Like => "like", MentionType::Repost => "repost", MentionType::Bookmark => "bookmark", MentionType::Mention => "mention", }; tracing::debug!("Mention type -> key: {}", key); let mention_uid = mention["properties"]["uid"][0].clone(); if let Some(values) = post["properties"][key].as_array_mut() { for value in values.iter_mut() { if value["properties"]["uid"][0] == mention_uid { *value = mention; break; } } } else { post["properties"][key] = serde_json::Value::Array(vec![mention]); } sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") .bind(uid) .bind(post) .execute(&mut *txn) .await?; txn.commit().await.map_err(Into::into) } #[tracing::instrument(skip(self), fields(f = std::any::type_name::()))] async fn update_with( &self, url: &str, f: F ) -> Result<(serde_json::Value, serde_json::Value)> { tracing::debug!("Updating post {}", url); let mut txn = self.db.begin().await?; let (uid, old_post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") .bind(url) .fetch_optional(&mut *txn) .await? .ok_or(StorageError::from_static( ErrorKind::NotFound, "The specified post wasn't found in the database." ))?; let new_post = { let mut post = old_post.clone(); #[cfg(not(test))] // Tests use the current-thread runtime. tokio::task::block_in_place(|| f(&mut post)); #[cfg(test)] f(&mut post); post }; sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") .bind(uid) .bind(&new_post) .execute(&mut *txn) .await?; txn.commit().await?; Ok((old_post, new_post)) } #[tracing::instrument(skip(self))] async fn get_channels(&self, user: &url::Url) -> Result> { sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#) .bind(user.authority()) .fetch_all(&self.db) .await .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn read_feed_with_limit( &self, url: &'_ str, _after: Option<&str>, _limit: usize, _user: Option<&url::Url>, ) -> Result> { unimplemented!("read_feed_with_limit is insecure and deprecated"); } #[tracing::instrument(skip(self))] async fn read_feed_with_cursor( &self, url: &'_ str, cursor: Option<&'_ str>, limit: usize, user: Option<&url::Url> ) -> Result)>> { let mut txn = self.db.begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") .execute(&mut *txn) .await?; tracing::debug!("Started txn: {:?}", txn); let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 ") .bind(url) .fetch_optional(&mut *txn) .await? { Some(feed) => feed, None => return Ok(None) }; // Don't query for children if this isn't a feed. // // The second query is very long and will probably be extremely // expensive. It's best to skip it on types where it doesn't make sense // (Kittybox doesn't support rendering children on non-feeds) if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { return Ok(Some((feed, None))); } feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json INNER JOIN kittybox.children ON mf2_json.uid = children.child WHERE children.parent = $1 AND ( ( (mf2 #>> '{properties,visibility,0}') = 'public' OR NOT (mf2['properties'] ? 'visibility') ) OR ( $3 IS NOT NULL AND ( (NOT (mf2['properties'] ? 'audience')) OR (mf2['properties']['audience'] ? $3) OR (mf2['properties']['author'] ? $3) ) ) ) AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC LIMIT $2" ) .bind(url) .bind(limit as i64) .bind(user.map(url::Url::as_str)) .bind(cursor) .fetch_all(&mut *txn) .await .map(serde_json::Value::Array)?; let new_cursor = feed["children"].as_array().unwrap() .last() .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); txn.commit().await?; Ok(Some((feed, new_cursor))) } #[tracing::instrument(skip(self))] async fn delete_post(&self, url: &'_ str) -> Result<()> { todo!() } #[tracing::instrument(skip(self))] async fn get_setting(&'_ self, user: &url::Url) -> Result { match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)") .bind(user.authority()) .bind(S::ID) .fetch_one(&self.db) .await { Ok((value,)) => Ok(serde_json::from_value(value)?), Err(err) => Err(err.into()) } } #[tracing::instrument(skip(self))] async fn set_setting(&self, user: &url::Url, value: S::Data) -> Result<()> { sqlx::query("SELECT kittybox.set_setting($1, $2, $3)") .bind(user.authority()) .bind(S::ID) .bind(serde_json::to_value(S::new(value)).unwrap()) .execute(&self.db) .await .map_err(Into::into) .map(|_| ()) } }