about summary refs log blame commit diff
path: root/src/database/postgres/mod.rs
blob: 1a1b98d61c64a65e2b9deca8a585b9bd5b6c950c (plain) (tree)
1
2
3
4
5
                     
 
                                 
                                                                       
                                             

























                                                                    
                                                     




                            
                                                                
                                                                     




                                                                               
                                  





                                                                      

                                                                          
                                                    










                                                                                   
 
     








                                                                                                                                                                        
                                      













                                                                      



















                                                                                                                                      
                                                                                          

                                                                                                                   
                                   



























                                                                                                           
                                      










                                                                                                                                                                                   
                                                                              
                                                    
                                            



                                                

                                                        


















                                                                              
 


                                                                              
                                                 
                                                                                                                                                                                   
                      
                                      




                                                                  
                                            
                                                                      
                                                         
                         

                

                                                                           
                            
                               
                    
 

                                

                                      
                                                                                   
                                                                                                                                                                                   
                                   







                                      

                                 
                                            
                                                                          






                                      
                               

                                                                                 
                                           

                                                                         
                                                                                                          
                      
                                      




                                   







                                                                                                












                                                                       




                                                      







                                                                            
                                             
                         
                                 
















                                                                                   
                                                                              
                                                                                              
                                   








                                                               
                                                                                            
                                                              
                                   






                                                               
use std::borrow::Cow;

use futures::{Stream, StreamExt};
use kittybox_util::{micropub::Channel as MicropubChannel, MentionType};
use sqlx::{ConnectOptions, Executor, PgPool};
use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion};

use super::settings::Setting;
use super::{Storage, Result, StorageError, ErrorKind};

/// Embedded database migrations, applied by [`PostgresStorage::from_pool`].
///
/// NOTE(review): `sqlx::migrate!()` without arguments presumably embeds the
/// crate's default `migrations` directory at compile time — confirm.
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();

impl From<sqlx::Error> for StorageError {
    /// Wrap an [`sqlx::Error`] as a backend-kind [`StorageError`].
    ///
    /// The rendered error text becomes the message and the original error
    /// is kept as the source.
    fn from(value: sqlx::Error) -> Self {
        // Render the message first: `value` is moved into the box below.
        let message = format!("sqlx error: {}", value);
        let source = Box::new(value);

        Self::with_source(super::ErrorKind::Backend, Cow::Owned(message), source)
    }
}

impl From<sqlx::migrate::MigrateError> for StorageError {
    /// Wrap a migration failure as a backend-kind [`StorageError`].
    ///
    /// The rendered error text becomes the message and the original error
    /// is kept as the source.
    fn from(value: sqlx::migrate::MigrateError) -> Self {
        // Render the message first: `value` is moved into the box below.
        let message = format!("sqlx migration error: {}", value);
        let source = Box::new(value);

        Self::with_source(super::ErrorKind::Backend, Cow::Owned(message), source)
    }
}

/// Micropub storage that uses a PostgreSQL database.
///
/// Cloning is derived, so a clone carries a clone of the pool handle.
#[derive(Debug, Clone)]
pub struct PostgresStorage {
    /// Connection pool that the queries in this backend run against.
    db: PgPool
}

impl PostgresStorage {
    /// Build a [`PostgresStorage`] on top of an existing [`sqlx::PgPool`].
    ///
    /// Creates the `kittybox` schema if it does not exist yet and then
    /// runs the embedded migrations before wrapping the pool.
    ///
    /// # Errors
    ///
    /// Fails if creating the schema or running the migrations fails.
    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self> {
        sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")
            .execute(&db)
            .await?;
        MIGRATOR.run(&db).await?;

        Ok(Self { db })
    }
}

impl Storage for PostgresStorage {
    /// Construct a new [`PostgresStorage`] from an URI string and run
    /// migrations on the database.
    ///
    /// If `PGPASS_FILE` environment variable is defined, read the
    /// password from the file at the specified path. If, instead,
    /// the `PGPASS` environment variable is present, read the
    /// password from it.
    async fn new(url: &'_ url::Url) -> Result<Self> {
        tracing::debug!("Postgres URL: {url}");
        let mut options = sqlx::postgres::PgConnectOptions::from_url(url)?
            .options([("search_path", "kittybox")]);
        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
            let password = tokio::fs::read_to_string(password_file).await.unwrap();
            options = options.password(&password);
        } else if let Ok(password) = std::env::var("PGPASS") {
            options = options.password(&password)
        }
        Self::from_pool(
            sqlx::postgres::PgPoolOptions::new()
                .max_connections(50)
                .connect_with(options)
                .await?
        ).await

    }

    async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> {
        let authority = user.authority().to_owned();
        Ok(
            sqlx::query_scalar::<_, serde_json::Value>("SELECT mf2 FROM kittybox.mf2_json WHERE owner = $1 ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC")
                .bind(authority)
                .fetch(&self.db)
                .filter_map(|f| std::future::ready(f.ok()))
        )
    }

    /// List the categories used by posts whose `uid` starts with `url`,
    /// most frequently used first.
    ///
    /// Only posts whose `category` property is a JSON array are
    /// considered.
    ///
    /// # Errors
    ///
    /// Fails if the query cannot be executed or a row cannot be decoded.
    #[tracing::instrument(skip(self))]
    async fn categories(&self, url: &str) -> Result<Vec<String>> {
        // `starts_with` does a literal prefix match: PostgreSQL has no `+`
        // operator for text, and `LIKE` would treat `_` and `%` inside the
        // URL as wildcards.
        //
        // `jsonb_array_elements_text` yields `text`, which is what decodes
        // into `String`; the plain variant yields `jsonb`.
        sqlx::query_scalar::<_, String>("
SELECT jsonb_array_elements_text(mf2['properties']['category']) AS category
FROM kittybox.mf2_json
WHERE
    jsonb_typeof(mf2['properties']['category']) = 'array'
    AND starts_with(uid, $1)
    GROUP BY category ORDER BY count(*) DESC
")
            .bind(url)
            .fetch_all(&self.db)
            .await
            .map_err(Into::into)
    }
    /// Check whether a post is stored under `url`.
    ///
    /// A post matches when `url` equals its `uid` or appears among its
    /// `url` properties.
    #[tracing::instrument(skip(self))]
    async fn post_exists(&self, url: &str) -> Result<bool> {
        let (found,) = sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)")
            .bind(url)
            .fetch_one(&self.db)
            .await?;

        Ok(found)
    }

    /// Fetch the MF2-JSON of the post stored under `url`, if any.
    ///
    /// A post matches when `url` equals its `uid` or appears among its
    /// `url` properties. Returns `Ok(None)` when no row matches.
    #[tracing::instrument(skip(self))]
    async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
        let row = sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1")
            .bind(url)
            .fetch_optional(&self.db)
            .await?;

        Ok(row.map(|(mf2,)| mf2))
    }

    /// Insert a new post owned by `user`.
    ///
    /// The row's `uid` is taken from the first `uid` property of `post`,
    /// and the owner is the authority part of `user`.
    #[tracing::instrument(skip(self))]
    async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()> {
        tracing::debug!("New post: {}", post);

        let owner = user.authority();
        sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)")
            .bind(post)
            .bind(owner)
            .execute(&self.db)
            .await?;

        Ok(())
    }

    /// Record `post` as a child of `feed`.
    ///
    /// The insert uses `ON CONFLICT DO NOTHING`, so a conflicting row is
    /// left as it is and no error is raised for it.
    #[tracing::instrument(skip(self))]
    async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> {
        tracing::debug!("Inserting {} into {}", post, feed);

        sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING")
            .bind(feed)
            .bind(post)
            .execute(&self.db)
            .await?;

        Ok(())
    }

    /// Remove the parent/child link between `feed` and `post`.
    ///
    /// Succeeds whether or not a matching row was deleted.
    #[tracing::instrument(skip(self))]
    async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> {
        sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2")
            .bind(feed)
            .bind(post)
            .execute(&self.db)
            .await?;

        Ok(())
    }

    /// Attach `mention` to the post stored under `target`, replacing a
    /// previously stored mention with the same `uid`.
    ///
    /// The mention is stored under the property matching `mention_type`
    /// (`comment`, `like`, `repost`, `bookmark` or `mention`). The post is
    /// locked with `FOR UPDATE` and rewritten inside one transaction.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::NotFound`] error when no post matches
    /// `target`, and a backend error when any query fails.
    #[tracing::instrument(skip(self))]
    async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> {
        let mut txn = self.db.begin().await?;

        let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
            .bind(target)
            .fetch_optional(&mut *txn)
            .await?
            .ok_or_else(|| StorageError::from_static(
                ErrorKind::NotFound,
                "The specified post wasn't found in the database."
            ))?;

        tracing::debug!("Loaded post for target {} with uid {}", target, uid);

        let key: &'static str = match mention_type {
            MentionType::Reply => "comment",
            MentionType::Like => "like",
            MentionType::Repost => "repost",
            MentionType::Bookmark => "bookmark",
            MentionType::Mention => "mention",
        };

        tracing::debug!("Mention type -> key: {}", key);

        // Cloned because `mention` itself is moved into the post below.
        let mention_uid = mention["properties"]["uid"][0].clone();
        if let Some(values) = post["properties"][key].as_array_mut() {
            let existing = values
                .iter()
                .position(|value| value["properties"]["uid"][0] == mention_uid);
            match existing {
                // Same source seen before: replace the stored copy.
                Some(index) => values[index] = mention,
                // First mention from this source: append it. Without this
                // arm every mention after the first one would be lost.
                None => values.push(mention),
            }
        } else {
            post["properties"][key] = serde_json::Value::Array(vec![mention]);
        }

        sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1")
            .bind(uid)
            .bind(post)
            .execute(&mut *txn)
            .await?;

        txn.commit().await.map_err(Into::into)
    }

    /// Apply `f` to the post stored under `url` and persist the result.
    ///
    /// The post is locked with `FOR UPDATE`, a copy is handed to `f` for
    /// in-place modification, and the modified copy is written back in the
    /// same transaction.
    ///
    /// Returns the post as it was before and after the update, in that
    /// order.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::NotFound`] error when no post matches
    /// `url`, and a backend error when any query fails.
    #[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))]
    async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
        &self, url: &str, f: F
    ) -> Result<(serde_json::Value, serde_json::Value)> {
        tracing::debug!("Updating post {}", url);
        let mut txn = self.db.begin().await?;

        let locked_row = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
            .bind(url)
            .fetch_optional(&mut *txn)
            .await?;
        let Some((uid, old_post)) = locked_row else {
            return Err(StorageError::from_static(
                ErrorKind::NotFound,
                "The specified post wasn't found in the database."
            ));
        };

        // `f` edits a copy so the untouched original can be returned too.
        let mut new_post = old_post.clone();
        #[cfg(not(test))] // Tests use the current-thread runtime.
        tokio::task::block_in_place(|| f(&mut new_post));
        #[cfg(test)]
        f(&mut new_post);

        sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1")
            .bind(uid)
            .bind(&new_post)
            .execute(&mut *txn)
            .await?;
        txn.commit().await?;

        Ok((old_post, new_post))
    }

    /// List the channels owned by `user`.
    ///
    /// Selects rows owned by the authority part of `user` whose `type`
    /// is contained in `["h-feed"]`, returning the first `name` property
    /// and the `uid` of each.
    #[tracing::instrument(skip(self))]
    async fn get_channels(&self, user: &url::Url) -> Result<Vec<MicropubChannel>> {
        let owner = user.authority();
        let channels = sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#)
            .bind(owner)
            .fetch_all(&self.db)
            .await?;

        Ok(channels)
    }

    /// Deprecated feed reader; this backend does not implement it.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`, whatever the arguments are.
    #[tracing::instrument(skip(self))]
    async fn read_feed_with_limit(
        &self,
        url: &'_ str,
        _after: Option<&str>,
        _limit: usize,
        _user: Option<&url::Url>,
    ) -> Result<Option<serde_json::Value>> {
        unimplemented!("read_feed_with_limit is insecure and deprecated");
    }

    /// Read the post at `url` and, if it is an `h-feed`, one page of its
    /// children.
    ///
    /// Both queries run in a single `REPEATABLE READ, READ ONLY`
    /// transaction. Children are ordered by their first `published`
    /// property, newest first. When `cursor` is given, only children
    /// published strictly before it are returned. At most `limit`
    /// children are returned.
    ///
    /// A child is included when it is public (or has no `visibility`
    /// property), or when `user` is given and the child has no
    /// `audience`, lists `user` in its `audience`, or lists `user` as an
    /// `author`.
    ///
    /// Returns `Ok(None)` when no post matches `url`. Otherwise returns
    /// the post together with the cursor for the next page: the
    /// `published` value of the last returned child. The cursor is `None`
    /// for non-feeds, for empty pages, and when the last child has no
    /// string `published` value.
    ///
    /// # Errors
    ///
    /// Fails when any of the queries or the commit fails.
    #[tracing::instrument(skip(self))]
    async fn read_feed_with_cursor(
        &self,
        url: &'_ str,
        cursor: Option<&'_ str>,
        limit: usize,
        user: Option<&url::Url>
    ) -> Result<Option<(serde_json::Value, Option<String>)>> {
        let mut txn = self.db.begin().await?;
        sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY")
            .execute(&mut *txn)
            .await?;
        tracing::debug!("Started txn: {:?}", txn);
        let mut feed = match sqlx::query_scalar::<_, serde_json::Value>("
SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1
")
            .bind(url)
            .fetch_optional(&mut *txn)
            .await?
        {
            Some(feed) => feed,
            None => return Ok(None)
        };

        // Don't query for children if this isn't a feed.
        //
        // The second query is very long and will probably be extremely
        // expensive. It's best to skip it on types where it doesn't make sense
        // (Kittybox doesn't support rendering children on non-feeds)
        //
        // A missing or malformed `type` is treated as "not a feed" rather
        // than panicking on stored data.
        let is_feed = feed["type"]
            .as_array()
            .is_some_and(|types| types.iter().any(|t| t.as_str() == Some("h-feed")));
        if !is_feed {
            return Ok(Some((feed, None)));
        }

        // Saturate instead of wrapping if `limit` does not fit into the
        // signed integer that is bound to `LIMIT`.
        let limit = i64::try_from(limit).unwrap_or(i64::MAX);

        feed["children"] = sqlx::query_scalar::<_, serde_json::Value>("
SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json
INNER JOIN kittybox.children
ON mf2_json.uid = children.child
WHERE
    children.parent = $1
    AND (
        (
          (mf2 #>> '{properties,visibility,0}') = 'public'
          OR
          NOT (mf2['properties'] ? 'visibility')
        )
        OR
        (
            $3 IS NOT NULL AND (
                (NOT (mf2['properties'] ? 'audience'))
                OR
                (mf2['properties']['audience'] ? $3)
                OR
                (mf2['properties']['author'] ? $3)
            )
        )
    )
    AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4))
ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC
LIMIT $2"
        )
            .bind(url)
            .bind(limit)
            .bind(user.map(url::Url::as_str))
            .bind(cursor)
            .fetch_all(&mut *txn)
            .await
            .map(serde_json::Value::Array)?;

        // A child without a string `published` value yields no cursor
        // instead of a panic; without a cursor filter such children can
        // be returned by the query above.
        let new_cursor = feed["children"]
            .as_array()
            .and_then(|children| children.last())
            .and_then(|last| last["properties"]["published"][0].as_str())
            .map(str::to_owned);

        txn.commit().await?;

        Ok(Some((feed, new_cursor)))
    }

    /// Delete the post stored under `url`. Not implemented yet for this
    /// backend.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()`.
    #[tracing::instrument(skip(self))]
    async fn delete_post(&self, url: &'_ str) -> Result<()> {
        todo!()
    }

    /// Load the setting `S` for `user`.
    ///
    /// The value is fetched through the `kittybox.get_setting` database
    /// function, keyed by the authority part of `user` and `S::ID`, and
    /// then deserialized into `S`.
    ///
    /// # Errors
    ///
    /// Fails when the query fails or the stored value does not
    /// deserialize into `S`.
    #[tracing::instrument(skip(self))]
    async fn get_setting<S: Setting>(&'_ self, user: &url::Url) -> Result<S> {
        let (stored,) = sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)")
            .bind(user.authority())
            .bind(S::ID)
            .fetch_one(&self.db)
            .await?;

        let setting = serde_json::from_value(stored)?;
        Ok(setting)
    }

    /// Store `value` as the setting `S` for `user`.
    ///
    /// The value is wrapped with `S::new`, serialized to JSON and passed
    /// to the `kittybox.set_setting` database function, keyed by the
    /// authority part of `user` and `S::ID`.
    ///
    /// # Errors
    ///
    /// Fails when the setting cannot be serialized to JSON or when the
    /// query fails.
    #[tracing::instrument(skip(self))]
    async fn set_setting<S: Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> {
        // Report a serialization failure to the caller instead of
        // panicking.
        let serialized = serde_json::to_value(S::new(value))?;

        sqlx::query("SELECT kittybox.set_setting($1, $2, $3)")
            .bind(user.authority())
            .bind(S::ID)
            .bind(serialized)
            .execute(&self.db)
            .await?;

        Ok(())
    }
}