use async_trait::async_trait; use futures::stream; use futures_util::FutureExt; use futures_util::StreamExt; use lazy_static::lazy_static; use log::error; use mobc::Pool; use mobc_redis::redis; use mobc_redis::redis::AsyncCommands; use mobc_redis::RedisConnectionManager; use serde_json::json; use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError}; use crate::indieauth::User; struct RedisScripts { edit_post: redis::Script, } impl From for StorageError { fn from(err: mobc_redis::redis::RedisError) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend, } } } impl From> for StorageError { fn from(err: mobc::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend, } } } lazy_static! { static ref SCRIPTS: RedisScripts = RedisScripts { edit_post: redis::Script::new(include_str!("./edit_post.lua")) }; } #[derive(Clone)] pub struct RedisStorage { // note to future Vika: // mobc::Pool is actually a fancy name for an Arc // around a shared connection pool with a manager // which makes it safe to implement [`Clone`] and // not worry about new pools being suddenly made // // stop worrying and start coding, you dum-dum redis: mobc::Pool, } fn filter_post(mut post: serde_json::Value, user: &'_ Option) -> Option { if post["properties"]["deleted"][0].is_string() { return Some(json!({ "type": post["type"], "properties": { "deleted": post["properties"]["deleted"] } })); } let empty_vec: Vec = vec![]; let author = post["properties"]["author"] .as_array() .unwrap_or(&empty_vec) .iter() .map(|i| i.as_str().unwrap().to_string()); let visibility = post["properties"]["visibility"][0] .as_str() .unwrap_or("public"); let mut audience = author.chain( post["properties"]["audience"] .as_array() .unwrap_or(&empty_vec) .iter() .map(|i| i.as_str().unwrap().to_string()), ); if (visibility == "private" && !audience.any(|i| Some(i) == *user)) || (visibility == "protected" && user.is_none()) { return None; } if post["properties"]["location"].is_array() { let location_visibility = post["properties"]["location-visibility"][0] .as_str() .unwrap_or("private"); let mut author = post["properties"]["author"] .as_array() .unwrap_or(&empty_vec) .iter() .map(|i| i.as_str().unwrap().to_string()); if location_visibility == "private" && !author.any(|i| Some(i) == *user) { post["properties"] .as_object_mut() .unwrap() .remove("location"); } } Some(post) } #[async_trait] impl Storage for RedisStorage { async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result { let mut conn = self.redis.get().await?; Ok(conn .hget::(format!("settings_{}", user), setting) .await?) } async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { let mut conn = self.redis.get().await?; Ok(conn .hset::(format!("settings_{}", user), setting, value) .await?) } async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { let mut conn = self.redis.get().await?; Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) } async fn post_exists(&self, url: &str) -> Result { let mut conn = self.redis.get().await?; Ok(conn.hexists::<&str, &str, bool>(&"posts", url).await?) } async fn get_post(&self, url: &str) -> Result> { let mut conn = self.redis.get().await?; match conn .hget::<&str, &str, Option>(&"posts", url) .await? { Some(val) => { let parsed = serde_json::from_str::(&val)?; if let Some(new_url) = parsed["see_other"].as_str() { match conn .hget::<&str, &str, Option>(&"posts", new_url) .await? { Some(val) => Ok(Some(serde_json::from_str::(&val)?)), None => Ok(None), } } else { Ok(Some(parsed)) } } None => Ok(None), } } async fn get_channels(&self, user: &User) -> Result> { let mut conn = self.redis.get().await?; let channels = conn .smembers::>("channels_".to_string() + user.me.as_str()) .await?; // TODO: use streams here instead of this weird thing... how did I even write this?! Ok(futures_util::future::join_all( channels .iter() .map(|channel| { self.get_post(channel).map(|result| result.unwrap()).map( |post: Option| { post.map(|post| MicropubChannel { uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), name: post["properties"]["name"][0] .as_str() .unwrap() .to_string(), }) }, ) }) .collect::>(), ) .await .into_iter() .flatten() .collect::>()) } async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { let mut conn = self.redis.get().await?; let key: &str; match post["properties"]["uid"][0].as_str() { Some(uid) => key = uid, None => { return Err(StorageError::new( ErrorKind::BadRequest, "post doesn't have a UID", )) } } conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()) .await?; if post["properties"]["url"].is_array() { for url in post["properties"]["url"] .as_array() .unwrap() .iter() .map(|i| i.as_str().unwrap().to_string()) { if url != key { conn.hset::<&str, &str, String, ()>( &"posts", &url, json!({ "see_other": key }).to_string(), ) .await?; } } } if post["type"] .as_array() .unwrap() .iter() .any(|i| i == "h-feed") { // This is a feed. Add it to the channels array if it's not already there. conn.sadd::( "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key, ) .await? } Ok(()) } async fn read_feed_with_limit<'a>( &self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option, ) -> Result> { let mut conn = self.redis.get().await?; let mut feed; match conn .hget::<&str, &str, Option>(&"posts", url) .await? { Some(post) => feed = serde_json::from_str::(&post)?, None => return Ok(None), } if feed["see_other"].is_string() { match conn .hget::<&str, &str, Option>(&"posts", feed["see_other"].as_str().unwrap()) .await? { Some(post) => feed = serde_json::from_str::(&post)?, None => return Ok(None), } } if let Some(post) = filter_post(feed, user) { feed = post } else { return Err(StorageError::new( ErrorKind::PermissionDenied, "specified user cannot access this post", )); } if feed["children"].is_array() { let children = feed["children"].as_array().unwrap(); let posts_iter: Box + Send>; // TODO: refactor this to apply the skip on the &mut iterator if let Some(after) = after { posts_iter = Box::new( children .iter() .map(|i| i.as_str().unwrap().to_string()) .skip_while(move |i| i != after) .skip(1), ); } else { posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); } let posts = stream::iter(posts_iter) .map(|url| async move { match self.redis.get().await { Ok(mut conn) => { match conn.hget::<&str, &str, Option>("posts", &url).await { Ok(post) => match post { Some(post) => { match serde_json::from_str::(&post) { Ok(post) => Some(post), Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } } } // Happens because of a broken link (result of an improper deletion?) None => None, }, Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } } } // TODO: Instead of causing a panic, investigate how can you fail the whole stream // Somehow fuse it maybe? Err(err) => { let err = StorageError::from(err); error!("{}", err); panic!("{}", err) } } }) // TODO: determine the optimal value for this buffer // It will probably depend on how often can you encounter a private post on the page // It shouldn't be too large, or we'll start fetching too many posts from the database // It MUST NOT be larger than the typical page size // It MUST NOT be a significant amount of the connection pool size .buffered(std::cmp::min(3, limit)) // Hack to unwrap the Option and sieve out broken links // Broken links return None, and Stream::filter_map skips all Nones. .filter_map(|post: Option| async move { post }) .filter_map(|post| async move { filter_post(post, user) }) .take(limit); // TODO: Instead of catching panics, find a way to make the whole stream fail with Result> match std::panic::AssertUnwindSafe(posts.collect::>()) .catch_unwind() .await { Ok(posts) => feed["children"] = json!(posts), Err(_) => { return Err(StorageError::new( ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info", )) } } } return Ok(Some(feed)); } async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { let mut conn = self.redis.get().await?; if !conn .hexists::<&str, &str, bool>("posts", url) .await .unwrap() { return Err(StorageError::new( ErrorKind::NotFound, "can't edit a non-existent post", )); } let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; if let Some(new_url) = post["see_other"].as_str() { url = new_url } Ok(SCRIPTS .edit_post .key("posts") .arg(url) .arg(update.to_string()) .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) .await?) } } impl RedisStorage { /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. pub async fn new(redis_uri: String) -> Result { match redis::Client::open(redis_uri) { Ok(client) => Ok(Self { redis: Pool::builder() .max_open(20) .build(RedisConnectionManager::new(client)), }), Err(e) => Err(e.into()), } } } #[cfg(test)] pub mod tests { use mobc_redis::redis; use std::process; use std::time::Duration; pub struct RedisInstance { // We just need to hold on to it so it won't get dropped and remove the socket _tempdir: tempdir::TempDir, uri: String, child: std::process::Child, } impl Drop for RedisInstance { fn drop(&mut self) { self.child.kill().expect("Failed to kill the child!"); } } impl RedisInstance { pub fn uri(&self) -> &str { &self.uri } } pub async fn get_redis_instance() -> RedisInstance { let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); let socket = tempdir.path().join("redis.sock"); let redis_child = process::Command::new("redis-server") .current_dir(&tempdir) .arg("--port") .arg("0") .arg("--unixsocket") .arg(&socket) .stdout(process::Stdio::null()) .stderr(process::Stdio::null()) .spawn() .expect("Failed to spawn Redis"); println!("redis+unix:///{}", socket.to_str().unwrap()); let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); // There should be a slight delay, we need to wait for Redis to spin up let client = redis::Client::open(uri.clone()).unwrap(); let millisecond = Duration::from_millis(1); let mut retries: usize = 0; const MAX_RETRIES: usize = 60 * 1000/*ms*/; while let Err(err) = client.get_connection() { if err.is_connection_refusal() { async_std::task::sleep(millisecond).await; retries += 1; if retries > MAX_RETRIES { panic!("Timeout waiting for Redis, last error: {}", err); } } else { panic!("Could not connect: {}", err); } } return RedisInstance { uri, child: redis_child, _tempdir: tempdir, }; } }