// path: src/database/redis/mod.rs

use async_trait::async_trait;
use futures::stream;
use futures_util::FutureExt;
use futures_util::StreamExt;
use lazy_static::lazy_static;
use log::error;
use mobc::Pool;
use mobc_redis::redis;
use mobc_redis::redis::AsyncCommands;
use mobc_redis::RedisConnectionManager;
use serde_json::json;

use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError};
use crate::indieauth::User;

struct RedisScripts {
    edit_post: redis::Script,
}

impl From<mobc_redis::redis::RedisError> for StorageError {
    fn from(err: mobc_redis::redis::RedisError) -> Self {
        Self {
            msg: format!("{}", err),
            source: Some(Box::new(err)),
            kind: ErrorKind::Backend,
        }
    }
}
impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError {
    fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self {
        Self {
            msg: format!("{}", err),
            source: Some(Box::new(err)),
            kind: ErrorKind::Backend,
        }
    }
}

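// Lua scripts loaded once and reused for the lifetime of the process.
// `edit_post.lua` is invoked from [`RedisStorage::update_post`] with the
// `posts` hash as its key and the post URL plus the serialized update
// document as arguments, so the edit is applied atomically inside Redis.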
lazy_static! {
    static ref SCRIPTS: RedisScripts = RedisScripts {
        edit_post: redis::Script::new(include_str!("./edit_post.lua"))
    };
}

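/// A [`Storage`] backend that keeps posts and settings in Redis,
/// sharing a mobc connection pool that is cheap to clone.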
#[derive(Clone)]
pub struct RedisStorage {
    // note to future Vika:
    // mobc::Pool is actually a fancy name for an Arc
    // around a shared connection pool with a manager
    // which makes it safe to implement [`Clone`] and
    // not worry about new pools being suddenly made
    //
    // stop worrying and start coding, you dum-dum
    redis: mobc::Pool<RedisConnectionManager>,
}

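/// Filter a post before showing it to `user`, according to its visibility properties.
///
/// Deleted posts collapse into a tombstone carrying only their type and `deleted`
/// property. Posts marked `private` are only returned to their author or to someone
/// listed in `audience`; `protected` posts require any logged-in user. A `location`
/// whose `location-visibility` is `private` (the default) is stripped unless `user`
/// is the author. Returns `None` when the post must be hidden entirely.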
fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option<serde_json::Value> {
    if post["properties"]["deleted"][0].is_string() {
        return Some(json!({
            "type": post["type"],
            "properties": {
                "deleted": post["properties"]["deleted"]
            }
        }));
    }
    let empty_vec: Vec<serde_json::Value> = vec![];
    let author = post["properties"]["author"]
        .as_array()
        .unwrap_or(&empty_vec)
        .iter()
        .map(|i| i.as_str().unwrap().to_string());
    let visibility = post["properties"]["visibility"][0]
        .as_str()
        .unwrap_or("public");
    let mut audience = author.chain(
        post["properties"]["audience"]
            .as_array()
            .unwrap_or(&empty_vec)
            .iter()
            .map(|i| i.as_str().unwrap().to_string()),
    );
    if (visibility == "private" && !audience.any(|i| Some(i) == *user))
        || (visibility == "protected" && user.is_none())
    {
        return None;
    }
    if post["properties"]["location"].is_array() {
        let location_visibility = post["properties"]["location-visibility"][0]
            .as_str()
            .unwrap_or("private");
        let mut author = post["properties"]["author"]
            .as_array()
            .unwrap_or(&empty_vec)
            .iter()
            .map(|i| i.as_str().unwrap().to_string());
        if location_visibility == "private" && !author.any(|i| Some(i) == *user) {
            post["properties"]
                .as_object_mut()
                .unwrap()
                .remove("location");
        }
    }
    Some(post)
}

#[async_trait]
impl Storage for RedisStorage {
    async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> {
        let mut conn = self.redis.get().await?;
        Ok(conn
            .hget::<String, &str, String>(format!("settings_{}", user), setting)
            .await?)
    }

    async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> {
        let mut conn = self.redis.get().await?;
        Ok(conn
            .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value)
            .await?)
    }

    async fn delete_post<'a>(&self, url: &'a str) -> Result<()> {
        let mut conn = self.redis.get().await?;
        Ok(conn.hdel::<&str, &str, ()>("posts", url).await?)
    }

    async fn post_exists(&self, url: &str) -> Result<bool> {
        let mut conn = self.redis.get().await?;
        Ok(conn.hexists::<&str, &str, bool>(&"posts", url).await?)
    }

    async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
        let mut conn = self.redis.get().await?;
        match conn
            .hget::<&str, &str, Option<String>>(&"posts", url)
            .await?
        {
            Some(val) => {
                let parsed = serde_json::from_str::<serde_json::Value>(&val)?;
                if let Some(new_url) = parsed["see_other"].as_str() {
                    match conn
                        .hget::<&str, &str, Option<String>>(&"posts", new_url)
                        .await?
                    {
                        Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)),
                        None => Ok(None),
                    }
                } else {
                    Ok(Some(parsed))
                }
            }
            None => Ok(None),
        }
    }

    async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> {
        let mut conn = self.redis.get().await?;
        let channels = conn
            .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str())
            .await?;
        // TODO: use streams here instead of this weird thing... how did I even write this?!
        Ok(futures_util::future::join_all(
            channels
                .iter()
                .map(|channel| {
                    self.get_post(channel).map(|result| result.unwrap()).map(
                        |post: Option<serde_json::Value>| {
                            post.map(|post| MicropubChannel {
                                uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
                                name: post["properties"]["name"][0].as_str().unwrap().to_string(),
                            })
                        },
                    )
                })
                .collect::<Vec<_>>(),
        )
        .await
        .into_iter()
        .flatten()
        .collect::<Vec<_>>())
    }

    async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> {
        let mut conn = self.redis.get().await?;
        let key = match post["properties"]["uid"][0].as_str() {
            Some(uid) => uid,
            None => {
                return Err(StorageError::new(
                    ErrorKind::BadRequest,
                    "post doesn't have a UID",
                ))
            }
        };
        conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string())
            .await?;
        if post["properties"]["url"].is_array() {
            for url in post["properties"]["url"]
                .as_array()
                .unwrap()
                .iter()
                .map(|i| i.as_str().unwrap().to_string())
            {
                if url != key && url.starts_with(user) {
                    conn.hset::<&str, &str, String, ()>(
                        &"posts",
                        &url,
                        json!({ "see_other": key }).to_string(),
                    )
                    .await?;
                }
            }
        }
        if post["type"]
            .as_array()
            .unwrap()
            .iter()
            .any(|i| i == "h-feed")
        {
            // This is a feed. Add it to the channels array if it's not already there.
            conn.sadd::<String, &str, ()>(
                "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(),
                key,
            )
            .await?
        }
        Ok(())
    }

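    // Feeds are stored as regular posts whose `children` array lists the URLs of
    // their entries. Pagination skips everything up to and including the `after`
    // cursor, then streams at most `limit` posts the requesting `user` may see.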
    async fn read_feed_with_limit<'a>(
        &self,
        url: &'a str,
        after: &'a Option<String>,
        limit: usize,
        user: &'a Option<String>,
    ) -> Result<Option<serde_json::Value>> {
        let mut conn = self.redis.get().await?;
        let mut feed = match conn
            .hget::<&str, &str, Option<String>>(&"posts", url)
            .await?
        {
            Some(post) => serde_json::from_str::<serde_json::Value>(&post)?,
            None => return Ok(None),
        };
        if feed["see_other"].is_string() {
            match conn
                .hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap())
                .await?
            {
                Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
                None => return Ok(None),
            }
        }
        if let Some(post) = filter_post(feed, user) {
            feed = post
        } else {
            return Err(StorageError::new(
                ErrorKind::PermissionDenied,
                "specified user cannot access this post",
            ));
        }
        if feed["children"].is_array() {
            let children = feed["children"].as_array().unwrap();
            // TODO: refactor this to apply the skip on the &mut iterator
            let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send> =
                if let Some(after) = after {
                    Box::new(
                        children
                            .iter()
                            .map(|i| i.as_str().unwrap().to_string())
                            .skip_while(move |i| i != after)
                            .skip(1),
                    )
                } else {
                    Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()))
                };
            let posts = stream::iter(posts_iter)
                .map(|url| async move {
                    match self.redis.get().await {
                        Ok(mut conn) => {
                            match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
                                Ok(post) => match post {
                                    Some(post) => {
                                        match serde_json::from_str::<serde_json::Value>(&post) {
                                            Ok(post) => Some(post),
                                            Err(err) => {
                                                let err = StorageError::from(err);
                                                error!("{}", err);
                                                panic!("{}", err)
                                            }
                                        }
                                    }
                                    // Happens because of a broken link (result of an improper deletion?)
                                    None => None,
                                },
                                Err(err) => {
                                    let err = StorageError::from(err);
                                    error!("{}", err);
                                    panic!("{}", err)
                                }
                            }
                        }
                        // TODO: Instead of causing a panic, investigate how can you fail the whole stream
                        // Somehow fuse it maybe?
                        Err(err) => {
                            let err = StorageError::from(err);
                            error!("{}", err);
                            panic!("{}", err)
                        }
                    }
                })
                // TODO: determine the optimal value for this buffer
                // It will probably depend on how often can you encounter a private post on the page
                // It shouldn't be too large, or we'll start fetching too many posts from the database
                // It MUST NOT be larger than the typical page size
                // It MUST NOT be a significant amount of the connection pool size
                .buffered(std::cmp::min(3, limit))
                // Hack to unwrap the Option and sieve out broken links
                // Broken links return None, and Stream::filter_map skips all Nones.
                .filter_map(|post: Option<serde_json::Value>| async move { post })
                .filter_map(|post| async move { filter_post(post, user) })
                .take(limit);
            // TODO: Instead of catching panics, find a way to make the whole stream fail with Result<Vec<serde_json::Value>>
            match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>())
                .catch_unwind()
                .await
            {
                Ok(posts) => feed["children"] = json!(posts),
                Err(_) => {
                    return Err(StorageError::new(
                        ErrorKind::Other,
                        "Unknown error encountered while assembling feed, see logs for more info",
                    ))
                }
            }
        }
        Ok(Some(feed))
    }

    async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> {
        let mut conn = self.redis.get().await?;
        if !conn
            .hexists::<&str, &str, bool>("posts", url)
            .await?
        {
            return Err(StorageError::new(
                ErrorKind::NotFound,
                "can't edit a non-existent post",
            ));
        }
        let post: serde_json::Value =
            serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?;
        if let Some(new_url) = post["see_other"].as_str() {
            url = new_url
        }
        Ok(SCRIPTS
            .edit_post
            .key("posts")
            .arg(url)
            .arg(update.to_string())
            .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection)
            .await?)
    }
}

impl RedisStorage {
    /// Create a new `RedisStorage` that will connect to Redis at `redis_uri` to store data.
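    ///
    /// The connection pool is capped at 20 open connections. A usage sketch
    /// (the URI below is only an illustration):
    ///
    /// ```ignore
    /// let storage = RedisStorage::new("redis://127.0.0.1:6379".to_string()).await?;
    /// ```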
    pub async fn new(redis_uri: String) -> Result<Self> {
        match redis::Client::open(redis_uri) {
            Ok(client) => Ok(Self {
                redis: Pool::builder()
                    .max_open(20)
                    .build(RedisConnectionManager::new(client)),
            }),
            Err(e) => Err(e.into()),
        }
    }
}

#[cfg(test)]
pub mod tests {
    use mobc_redis::redis;
    use std::process;
    use std::time::Duration;

    pub struct RedisInstance {
        // We just need to hold on to it so it won't get dropped and remove the socket
        _tempdir: tempdir::TempDir,
        uri: String,
        child: std::process::Child,
    }
    impl Drop for RedisInstance {
        fn drop(&mut self) {
            self.child.kill().expect("Failed to kill the child!");
        }
    }
    impl RedisInstance {
        pub fn uri(&self) -> &str {
            &self.uri
        }
    }

    pub async fn get_redis_instance() -> RedisInstance {
        let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir");
        let socket = tempdir.path().join("redis.sock");
        let redis_child = process::Command::new("redis-server")
            .current_dir(&tempdir)
            .arg("--port")
            .arg("0")
            .arg("--unixsocket")
            .arg(&socket)
            .stdout(process::Stdio::null())
            .stderr(process::Stdio::null())
            .spawn()
            .expect("Failed to spawn Redis");
        let uri = format!("redis+unix:///{}", socket.to_str().unwrap());
        println!("{}", uri);
        // There may be a slight delay before Redis starts accepting connections, so poll until it's ready
        let client = redis::Client::open(uri.clone()).unwrap();
        let millisecond = Duration::from_millis(1);
        let mut retries: usize = 0;
        const MAX_RETRIES: usize = 60 * 1000/*ms*/;
        while let Err(err) = client.get_connection() {
            if err.is_connection_refusal() {
                async_std::task::sleep(millisecond).await;
                retries += 1;
                if retries > MAX_RETRIES {
                    panic!("Timeout waiting for Redis, last error: {}", err);
                }
            } else {
                panic!("Could not connect: {}", err);
            }
        }

        RedisInstance {
            uri,
            child: redis_child,
            _tempdir: tempdir,
        }
    }
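
    // A minimal round-trip sketch: spin up a disposable Redis with the helper
    // above, store a post, and read it back. The post shape below (an h-entry
    // carrying `uid` and `url`) is only an assumption about the smallest
    // document `put_post` accepts.
    #[test]
    fn test_post_roundtrip_against_disposable_redis() {
        use super::RedisStorage;
        use crate::database::Storage;
        use serde_json::json;

        async_std::task::block_on(async {
            let redis = get_redis_instance().await;
            let storage = RedisStorage::new(redis.uri().to_string())
                .await
                .expect("Failed to connect to the disposable Redis instance");

            let post = json!({
                "type": ["h-entry"],
                "properties": {
                    "uid": ["https://example.com/posts/hello-world"],
                    "url": ["https://example.com/posts/hello-world"],
                    "content": [{"html": "<p>Hello world!</p>"}]
                }
            });
            storage
                .put_post(&post, "https://example.com/")
                .await
                .expect("put_post failed");

            let fetched = storage
                .get_post("https://example.com/posts/hello-world")
                .await
                .expect("get_post failed")
                .expect("the post should be readable back");
            assert_eq!(fetched["properties"]["uid"], post["properties"]["uid"]);
        });
    }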
}