use async_trait::async_trait; use futures_util::FutureExt; use futures_util::StreamExt; use futures::stream; use lazy_static::lazy_static; use log::error; use mobc_redis::redis; use mobc_redis::redis::AsyncCommands; use serde_json::json; use mobc::Pool; use mobc_redis::RedisConnectionManager; use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; use crate::indieauth::User; struct RedisScripts { edit_post: redis::Script } impl From for StorageError { fn from(err: mobc_redis::redis::RedisError) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend } } } impl From> for StorageError { fn from(err: mobc::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend } } } lazy_static! { static ref SCRIPTS: RedisScripts = RedisScripts { edit_post: redis::Script::new(include_str!("./edit_post.lua")) }; } #[derive(Clone)] pub struct RedisStorage { // note to future Vika: // mobc::Pool is actually a fancy name for an Arc // around a shared connection pool with a manager // which makes it safe to implement [`Clone`] and // not worry about new pools being suddenly made // // stop worrying and start coding, you dum-dum redis: mobc::Pool, } fn filter_post(mut post: serde_json::Value, user: &'_ Option) -> Option { if post["properties"]["deleted"][0].is_string() { return Some(json!({ "type": post["type"], "properties": { "deleted": post["properties"]["deleted"] } })); } let empty_vec: Vec = vec![]; let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); if (visibility == "private" && !audience.any(|i| Some(i) == *user)) || (visibility == "protected" && user.is_none()) { return None } if post["properties"]["location"].is_array() { let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); if location_visibility == "private" && !author.any(|i| Some(i) == *user) { post["properties"].as_object_mut().unwrap().remove("location"); } } Some(post) } #[async_trait] impl Storage for RedisStorage { async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { let mut conn = self.redis.get().await?; Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) } async fn post_exists(&self, url: &str) -> Result { let mut conn = self.redis.get().await?; Ok(conn.hexists::<&str, &str, bool>(&"posts", url).await?) } async fn get_post(&self, url: &str) -> Result> { let mut conn = self.redis.get().await?; match conn.hget::<&str, &str, Option>(&"posts", url).await? { Some(val) => { let parsed = serde_json::from_str::(&val)?; if let Some(new_url) = parsed["see_other"].as_str() { match conn.hget::<&str, &str, Option>(&"posts", new_url).await? { Some(val) => Ok(Some(serde_json::from_str::(&val)?)), None => Ok(None) } } else { Ok(Some(parsed)) } }, None => Ok(None) } } async fn get_channels(&self, user: &User) -> Result> { let mut conn = self.redis.get().await?; let channels = conn.smembers::>("channels_".to_string() + user.me.as_str()).await?; // TODO: use streams here instead of this weird thing... how did I even write this?! Ok(futures_util::future::join_all(channels.iter() .map(|channel| self.get_post(channel) .map(|result| result.unwrap()) .map(|post: Option| { if let Some(post) = post { Some(MicropubChannel { uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), name: post["properties"]["name"][0].as_str().unwrap().to_string() }) } else { None } }) ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()) } async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { let mut conn = self.redis.get().await?; let key: &str; match post["properties"]["uid"][0].as_str() { Some(uid) => key = uid, None => return Err(StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID")) } conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await?; if post["properties"]["url"].is_array() { for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { if url != key { conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await?; } } } if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { // This is a feed. Add it to the channels array if it's not already there. conn.sadd::("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await? } Ok(()) } async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { let mut conn = self.redis.get().await?; let mut feed; match conn.hget::<&str, &str, Option>(&"posts", url).await? { Some(post) => feed = serde_json::from_str::(&post)?, None => return Ok(None) } if feed["see_other"].is_string() { match conn.hget::<&str, &str, Option>(&"posts", feed["see_other"].as_str().unwrap()).await? { Some(post) => feed = serde_json::from_str::(&post)?, None => return Ok(None) } } if let Some(post) = filter_post(feed, user) { feed = post } else { return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) } if feed["children"].is_array() { let children = feed["children"].as_array().unwrap(); let posts_iter: Box + Send>; // TODO: refactor this to apply the skip on the &mut iterator if let Some(after) = after { posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); } else { posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); } let posts = stream::iter(posts_iter) .map(|url| async move { match self.redis.get().await { Ok(mut conn) => match conn.hget::<&str, &str, Option>("posts", &url).await { Ok(post) => match post { Some(post) => match serde_json::from_str::(&post) { Ok(post) => Some(post), Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } }, // Happens because of a broken link (result of an improper deletion?) None => None, }, Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } }, // TODO: Instead of causing a panic, investigate how can you fail the whole stream // Somehow fuse it maybe? Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } } }) // TODO: determine the optimal value for this buffer // It will probably depend on how often can you encounter a private post on the page // It shouldn't be too large, or we'll start fetching too many posts from the database // It MUST NOT be larger than the typical page size // It MUST NOT be a significant amount of the connection pool size .buffered(std::cmp::min(3, limit)) // Hack to unwrap the Option and sieve out broken links // Broken links return None, and Stream::filter_map skips all Nones. .filter_map(|post: Option| async move { post }) .filter_map(|post| async move { filter_post(post, user) }) .take(limit); // TODO: Instead of catching panics, find a way to make the whole stream fail with Result> match std::panic::AssertUnwindSafe(posts.collect::>()).catch_unwind().await { Ok(posts) => feed["children"] = json!(posts), Err(_) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) } } return Ok(Some(feed)); } async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { let mut conn = self.redis.get().await?; if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) } let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; if let Some(new_url) = post["see_other"].as_str() { url = new_url } Ok(SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection).await?) } } impl RedisStorage { /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. pub async fn new(redis_uri: String) -> Result { match redis::Client::open(redis_uri) { Ok(client) => Ok(Self { redis: Pool::builder().max_open(20).build(RedisConnectionManager::new(client)) }), Err(e) => Err(e.into()) } } } #[cfg(test)] pub mod tests { use std::process; use std::time::Duration; use mobc_redis::redis; pub struct RedisInstance { // We just need to hold on to it so it won't get dropped and remove the socket _tempdir: tempdir::TempDir, uri: String, child: std::process::Child } impl Drop for RedisInstance { fn drop(&mut self) { self.child.kill().expect("Failed to kill the child!"); } } impl RedisInstance { pub fn uri(&self) -> &str { &self.uri } } pub async fn get_redis_instance() -> RedisInstance { let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); let socket = tempdir.path().join("redis.sock"); let redis_child = process::Command::new("redis-server") .current_dir(&tempdir) .arg("--port").arg("0") .arg("--unixsocket").arg(&socket) .stdout(process::Stdio::null()) .stderr(process::Stdio::null()) .spawn().expect("Failed to spawn Redis"); println!("redis+unix:///{}", socket.to_str().unwrap()); let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); // There should be a slight delay, we need to wait for Redis to spin up let client = redis::Client::open(uri.clone()).unwrap(); let millisecond = Duration::from_millis(1); let mut retries: usize = 0; const MAX_RETRIES: usize = 60 * 1000/*ms*/; while let Err(err) = client.get_connection() { if err.is_connection_refusal() { async_std::task::sleep(millisecond).await; retries += 1; if retries > MAX_RETRIES { panic!("Timeout waiting for Redis, last error: {}", err); } } else { panic!("Could not connect: {}", err); } } return RedisInstance { uri, child: redis_child, _tempdir: tempdir } } }