From 7024cbefb27e1c9649bff57df32b316484de4104 Mon Sep 17 00:00:00 2001 From: Vika Date: Tue, 4 May 2021 19:24:51 +0300 Subject: Refactored the database module and its tests --- src/database/redis/mod.rs | 304 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 src/database/redis/mod.rs (limited to 'src/database/redis/mod.rs') diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs new file mode 100644 index 0000000..2377fac --- /dev/null +++ b/src/database/redis/mod.rs @@ -0,0 +1,304 @@ +use async_trait::async_trait; +use futures_util::FutureExt; +use futures_util::StreamExt; +use futures::stream; +use lazy_static::lazy_static; +use log::error; +use redis; +use redis::AsyncCommands; +use serde_json::json; + +use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; +use crate::indieauth::User; + +struct RedisScripts { + edit_post: redis::Script +} + +impl From for StorageError { + fn from(err: redis::RedisError) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend + } + } +} + +lazy_static! { + static ref SCRIPTS: RedisScripts = RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + }; +} + +#[derive(Clone)] +pub struct RedisStorage { + // TODO: use mobc crate to create a connection pool and reuse connections for efficiency + redis: redis::Client, +} + +fn filter_post<'a>(mut post: serde_json::Value, user: &'a Option) -> Option { + if post["properties"]["deleted"][0].is_string() { + return Some(json!({ + "type": post["type"], + "properties": { + "deleted": post["properties"]["deleted"] + } + })); + } + let empty_vec: Vec = vec![]; + let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); + let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); + if visibility == "private" { + if !audience.any(|i| Some(i) == *user) { + return None + } + } + if post["properties"]["location"].is_array() { + let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); + let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + if location_visibility == "private" && !author.any(|i| Some(i) == *user) { + post["properties"].as_object_mut().unwrap().remove("location"); + } + } + Some(post) +} + +#[async_trait] +impl Storage for RedisStorage { + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => if let Err(err) = conn.hdel::<&str, &str, bool>("posts", url).await { + return Err(err.into()); + }, + Err(err) => return Err(err.into()) + } + Ok(()) + } + + async fn post_exists(&self, url: &str) -> Result { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hexists::<&str, &str, bool>(&"posts", url).await { + Ok(val) => Ok(val), + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn get_post(&self, url: &str) -> Result> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hget::<&str, &str, Option>(&"posts", url).await { + Ok(val) => match val { + Some(val) => match serde_json::from_str::(&val) { + Ok(parsed) => if let Some(new_url) = parsed["see_other"].as_str() { + match conn.hget::<&str, &str, Option>(&"posts", new_url).await { + Ok(val) => match val { + Some(val) => match serde_json::from_str::(&val) { + Ok(parsed) => Ok(Some(parsed)), + Err(err) => Err(err.into()) + }, + None => Ok(None) + } + Err(err) => { + Err(err.into()) + } + } + } else { + Ok(Some(parsed)) + }, + Err(err) => Err(err.into()) + }, + None => Ok(None) + }, + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn get_channels(&self, user: &User) -> Result> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.smembers::>("channels_".to_string() + user.me.as_str()).await { + Ok(channels) => { + Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + } else { None } + }) + ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()) + }, + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + } + match conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await { + Err(err) => return Err(err.into()), + _ => {} + } + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if &url != key { + match conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await { + Err(err) => return Err(err.into()), + _ => {} + } + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + match conn.sadd::("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await { + Err(err) => return Err(err.into()), + _ => {}, + } + } + Ok(()) + }, + Err(err) => Err(err.into()) + } + } + + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + let mut feed; + match conn.hget::<&str, &str, Option>(&"posts", url).await { + Ok(post) => { + match post { + Some(post) => match serde_json::from_str::(&post) { + Ok(post) => feed = post, + Err(err) => return Err(err.into()) + }, + None => return Ok(None) + } + }, + Err(err) => return Err(err.into()) + } + if feed["see_other"].is_string() { + match conn.hget::<&str, &str, Option>(&"posts", feed["see_other"].as_str().unwrap()).await { + Ok(post) => { + match post { + Some(post) => match serde_json::from_str::(&post) { + Ok(post) => feed = post, + Err(err) => return Err(err.into()) + }, + None => return Ok(None) + } + }, + Err(err) => return Err(err.into()) + } + } + if let Some(post) = filter_post(feed, user) { + feed = post + } else { + return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) + } + if feed["children"].is_array() { + let children = feed["children"].as_array().unwrap(); + let posts_iter: Box + Send>; + if let Some(after) = after { + posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); + } else { + posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); + } + let posts = stream::iter(posts_iter) + .map(|url| async move { + // Is it rational to use a new connection for every post fetched? + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hget::<&str, &str, Option>("posts", &url).await { + Ok(post) => match post { + Some(post) => match serde_json::from_str::(&post) { + Ok(post) => Some(post), + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + }, + // Happens because of a broken link (result of an improper deletion?) + None => None, + }, + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + }, + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + } + }) + // TODO: determine the optimal value for this buffer + // It will probably depend on how often can you encounter a private post on the page + // It shouldn't be too large, or we'll start fetching too many posts from the database + // It MUST NOT be larger than the typical page size + .buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips all Nones. + .filter_map(|post: Option| async move { post }) + .filter_map(|post| async move { + return filter_post(post, user) + }) + .take(limit); + match std::panic::AssertUnwindSafe(posts.collect::>()).catch_unwind().await { + Ok(posts) => feed["children"] = json!(posts), + Err(err) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) + } + } + return Ok(Some(feed)); + } + Err(err) => Err(err.into()) + } + } + + async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { + return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) + } + let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await.unwrap()).unwrap(); + if let Some(new_url) = post["see_other"].as_str() { + url = new_url + } + if let Err(err) = SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await { + return Err(err.into()) + } + }, + Err(err) => return Err(err.into()) + } + Ok(()) + } +} + + +impl RedisStorage { + /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. + pub async fn new(redis_uri: String) -> Result { + match redis::Client::open(redis_uri) { + Ok(client) => Ok(Self { redis: client }), + Err(e) => Err(e.into()) + } + } +} \ No newline at end of file -- cgit 1.4.1