use async_trait::async_trait; use futures_util::FutureExt; use futures_util::StreamExt; use futures::stream; use lazy_static::lazy_static; use log::error; use redis::AsyncCommands; use serde_json::json; use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; use crate::indieauth::User; struct RedisScripts { edit_post: redis::Script } impl From for StorageError { fn from(err: redis::RedisError) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend } } } lazy_static! { static ref SCRIPTS: RedisScripts = RedisScripts { edit_post: redis::Script::new(include_str!("./edit_post.lua")) }; } #[derive(Clone)] pub struct RedisStorage { // TODO: use mobc crate to create a connection pool and reuse connections for efficiency redis: redis::Client, } fn filter_post(mut post: serde_json::Value, user: &'_ Option) -> Option { if post["properties"]["deleted"][0].is_string() { return Some(json!({ "type": post["type"], "properties": { "deleted": post["properties"]["deleted"] } })); } let empty_vec: Vec = vec![]; let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); if (visibility == "private" && !audience.any(|i| Some(i) == *user)) || (visibility == "protected" && user.is_none()) { return None } if post["properties"]["location"].is_array() { let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); if location_visibility == "private" && !author.any(|i| Some(i) == *user) { post["properties"].as_object_mut().unwrap().remove("location"); } } Some(post) } #[async_trait] impl Storage for RedisStorage { async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { match self.redis.get_async_std_connection().await { Ok(mut conn) => if let Err(err) = conn.hdel::<&str, &str, bool>("posts", url).await { return Err(err.into()); }, Err(err) => return Err(err.into()) } Ok(()) } async fn post_exists(&self, url: &str) -> Result { match self.redis.get_async_std_connection().await { Ok(mut conn) => match conn.hexists::<&str, &str, bool>(&"posts", url).await { Ok(val) => Ok(val), Err(err) => Err(err.into()) }, Err(err) => Err(err.into()) } } async fn get_post(&self, url: &str) -> Result> { match self.redis.get_async_std_connection().await { Ok(mut conn) => match conn.hget::<&str, &str, Option>(&"posts", url).await { Ok(val) => match val { Some(val) => match serde_json::from_str::(&val) { Ok(parsed) => if let Some(new_url) = parsed["see_other"].as_str() { match conn.hget::<&str, &str, Option>(&"posts", new_url).await { Ok(val) => match val { Some(val) => match serde_json::from_str::(&val) { Ok(parsed) => Ok(Some(parsed)), Err(err) => Err(err.into()) }, None => Ok(None) } Err(err) => { Err(err.into()) } } } else { Ok(Some(parsed)) }, Err(err) => Err(err.into()) }, None => Ok(None) }, Err(err) => Err(err.into()) }, Err(err) => Err(err.into()) } } async fn get_channels(&self, user: &User) -> Result> { match self.redis.get_async_std_connection().await { Ok(mut conn) => match conn.smembers::>("channels_".to_string() + user.me.as_str()).await { Ok(channels) => { Ok(futures_util::future::join_all(channels.iter() .map(|channel| self.get_post(channel) .map(|result| result.unwrap()) .map(|post: Option| { if let Some(post) = post { Some(MicropubChannel { uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), name: post["properties"]["name"][0].as_str().unwrap().to_string() }) } else { None } }) ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()) }, Err(err) => Err(err.into()) }, Err(err) => Err(err.into()) } } async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { match self.redis.get_async_std_connection().await { Ok(mut conn) => { let key: &str; match post["properties"]["uid"][0].as_str() { Some(uid) => key = uid, None => return Err(StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID")) } if let Err(err) = conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await { return Err(err.into()) } if post["properties"]["url"].is_array() { for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { if url != key { if let Err(err) = conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await { return Err(err.into()) } } } } if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { // This is a feed. Add it to the channels array if it's not already there. if let Err(err) = conn.sadd::("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await { return Err(err.into()) } } Ok(()) }, Err(err) => Err(err.into()) } } async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { match self.redis.get_async_std_connection().await { Ok(mut conn) => { let mut feed; match conn.hget::<&str, &str, Option>(&"posts", url).await { Ok(post) => { match post { Some(post) => match serde_json::from_str::(&post) { Ok(post) => feed = post, Err(err) => return Err(err.into()) }, None => return Ok(None) } }, Err(err) => return Err(err.into()) } if feed["see_other"].is_string() { match conn.hget::<&str, &str, Option>(&"posts", feed["see_other"].as_str().unwrap()).await { Ok(post) => { match post { Some(post) => match serde_json::from_str::(&post) { Ok(post) => feed = post, Err(err) => return Err(err.into()) }, None => return Ok(None) } }, Err(err) => return Err(err.into()) } } if let Some(post) = filter_post(feed, user) { feed = post } else { return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) } if feed["children"].is_array() { let children = feed["children"].as_array().unwrap(); let posts_iter: Box + Send>; if let Some(after) = after { posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); } else { posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); } let posts = stream::iter(posts_iter) .map(|url| async move { // Is it rational to use a new connection for every post fetched? match self.redis.get_async_std_connection().await { Ok(mut conn) => match conn.hget::<&str, &str, Option>("posts", &url).await { Ok(post) => match post { Some(post) => match serde_json::from_str::(&post) { Ok(post) => Some(post), Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } }, // Happens because of a broken link (result of an improper deletion?) None => None, }, Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } }, Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } } }) // TODO: determine the optimal value for this buffer // It will probably depend on how often can you encounter a private post on the page // It shouldn't be too large, or we'll start fetching too many posts from the database // It MUST NOT be larger than the typical page size .buffered(std::cmp::min(3, limit)) // Hack to unwrap the Option and sieve out broken links // Broken links return None, and Stream::filter_map skips all Nones. .filter_map(|post: Option| async move { post }) .filter_map(|post| async move { filter_post(post, user) }) .take(limit); match std::panic::AssertUnwindSafe(posts.collect::>()).catch_unwind().await { Ok(posts) => feed["children"] = json!(posts), Err(_) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) } } return Ok(Some(feed)); } Err(err) => Err(err.into()) } } async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { match self.redis.get_async_std_connection().await { Ok(mut conn) => { if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) } let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await.unwrap()).unwrap(); if let Some(new_url) = post["see_other"].as_str() { url = new_url } if let Err(err) = SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await { return Err(err.into()) } }, Err(err) => return Err(err.into()) } Ok(()) } } impl RedisStorage { /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. pub async fn new(redis_uri: String) -> Result { match redis::Client::open(redis_uri) { Ok(client) => Ok(Self { redis: client }), Err(e) => Err(e.into()) } } }