use async_trait::async_trait; use futures::stream; use futures_util::FutureExt; use futures_util::StreamExt; use futures_util::TryStream; use futures_util::TryStreamExt; use lazy_static::lazy_static; use log::error; use mobc::Pool; use mobc_redis::redis; use mobc_redis::redis::AsyncCommands; use mobc_redis::RedisConnectionManager; use serde_json::json; use std::time::Duration; use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError, filter_post}; use crate::indieauth::User; struct RedisScripts { edit_post: redis::Script, } impl From for StorageError { fn from(err: mobc_redis::redis::RedisError) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend, } } } impl From> for StorageError { fn from(err: mobc::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Backend, } } } lazy_static! { static ref SCRIPTS: RedisScripts = RedisScripts { edit_post: redis::Script::new(include_str!("./edit_post.lua")) }; } #[derive(Clone)] pub struct RedisStorage { // note to future Vika: // mobc::Pool is actually a fancy name for an Arc // around a shared connection pool with a manager // which makes it safe to implement [`Clone`] and // not worry about new pools being suddenly made // // stop worrying and start coding, you dum-dum redis: mobc::Pool, } #[async_trait] impl Storage for RedisStorage { async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; Ok(conn .hget::(format!("settings_{}", user), setting) .await?) } async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; Ok(conn .hset::(format!("settings_{}", user), setting, value) .await?) } async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) } async fn post_exists(&self, url: &str) -> Result { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; Ok(conn.hexists::<&str, &str, bool>("posts", url).await?) } async fn get_post(&self, url: &str) -> Result> { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; match conn .hget::<&str, &str, Option>("posts", url) .await? { Some(val) => { let parsed = serde_json::from_str::(&val)?; if let Some(new_url) = parsed["see_other"].as_str() { match conn .hget::<&str, &str, Option>("posts", new_url) .await? { Some(val) => Ok(Some(serde_json::from_str::(&val)?)), None => Ok(None), } } else { Ok(Some(parsed)) } } None => Ok(None), } } async fn get_channels(&self, user: &User) -> Result> { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; let channels = conn .smembers::>("channels_".to_string() + user.me.as_str()) .await?; // TODO: use streams here instead of this weird thing... how did I even write this?! Ok(futures_util::future::join_all( channels .iter() .map(|channel| { self.get_post(channel).map(|result| result.unwrap()).map( |post: Option| { post.map(|post| MicropubChannel { uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), name: post["properties"]["name"][0].as_str().unwrap().to_string(), }) }, ) }) .collect::>(), ) .await .into_iter() .flatten() .collect::>()) } async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; let key: &str; match post["properties"]["uid"][0].as_str() { Some(uid) => key = uid, None => { return Err(StorageError::new( ErrorKind::BadRequest, "post doesn't have a UID", )) } } conn.hset::<&str, &str, String, ()>("posts", key, post.to_string()) .await?; if post["properties"]["url"].is_array() { for url in post["properties"]["url"] .as_array() .unwrap() .iter() .map(|i| i.as_str().unwrap().to_string()) { if url != key && url.starts_with(user) { conn.hset::<&str, &str, String, ()>( "posts", &url, json!({ "see_other": key }).to_string(), ) .await?; } } } if post["type"] .as_array() .unwrap() .iter() .any(|i| i == "h-feed") { // This is a feed. Add it to the channels array if it's not already there. conn.sadd::( "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key, ) .await? } Ok(()) } async fn read_feed_with_limit<'a>( &self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option, ) -> Result> { let mut conn = self.redis.get().await?; let mut feed; match conn .hget::<&str, &str, Option>("posts", url) .await .map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))? { Some(post) => feed = serde_json::from_str::(&post)?, None => return Ok(None), } if feed["see_other"].is_string() { match conn .hget::<&str, &str, Option>("posts", feed["see_other"].as_str().unwrap()) .await? { Some(post) => feed = serde_json::from_str::(&post)?, None => return Ok(None), } } if let Some(post) = filter_post(feed, user) { feed = post } else { return Err(StorageError::new( ErrorKind::PermissionDenied, "specified user cannot access this post", )); } if feed["children"].is_array() { let children = feed["children"].as_array().unwrap(); let mut posts_iter = children.iter().map(|i| i.as_str().unwrap().to_string()); if after.is_some() { loop { let i = posts_iter.next(); if &i == after { break; } } } async fn fetch_post_for_feed(url: String) -> Option { return Some(serde_json::json!({})); } let posts = stream::iter(posts_iter) .map(|url: String| async move { return Ok(fetch_post_for_feed(url).await); /*match self.redis.get().await { Ok(mut conn) => { match conn.hget::<&str, &str, Option>("posts", &url).await { Ok(post) => match post { Some(post) => { Ok(Some(serde_json::from_str(&post)?)) } // Happens because of a broken link (result of an improper deletion?) None => Ok(None), }, Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error executing a Redis command", Box::new(err))) } } Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(err))) }*/ }) // TODO: determine the optimal value for this buffer // It will probably depend on how often can you encounter a private post on the page // It shouldn't be too large, or we'll start fetching too many posts from the database // It MUST NOT be larger than the typical page size // It MUST NOT be a significant amount of the connection pool size //.buffered(std::cmp::min(3, limit)) // Hack to unwrap the Option and sieve out broken links // Broken links return None, and Stream::filter_map skips all Nones. // I wonder if one can use try_flatten() here somehow akin to iters .try_filter_map(|post| async move { Ok(post) }) .try_filter_map(|post| async move { Ok(filter_post(post, user)) }) .take(limit); match posts.try_collect::>().await { Ok(posts) => feed["children"] = json!(posts), Err(err) => { let e = StorageError::with_source( ErrorKind::Other, "An error was encountered while processing the feed", Box::new(err) ); error!("Error while assembling feed: {}", e); return Err(e); } } } return Ok(Some(feed)); } async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; if !conn .hexists::<&str, &str, bool>("posts", url) .await .unwrap() { return Err(StorageError::new( ErrorKind::NotFound, "can't edit a non-existent post", )); } let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; if let Some(new_url) = post["see_other"].as_str() { url = new_url } Ok(SCRIPTS .edit_post .key("posts") .arg(url) .arg(update.to_string()) .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) .await?) } } impl RedisStorage { /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. pub async fn new(redis_uri: String) -> Result { match redis::Client::open(redis_uri) { Ok(client) => Ok(Self { redis: Pool::builder() .max_open(20) .max_idle(5) .get_timeout(Some(Duration::from_secs(3))) .max_lifetime(Some(Duration::from_secs(120))) .build(RedisConnectionManager::new(client)), }), Err(e) => Err(e.into()), } } pub async fn conn(&self) -> Result> { self.redis.get().await.map_err(|e| StorageError::with_source( ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e) )) } } #[cfg(test)] pub mod tests { use mobc_redis::redis; use std::process; use std::time::Duration; pub struct RedisInstance { // We just need to hold on to it so it won't get dropped and remove the socket _tempdir: tempdir::TempDir, uri: String, child: std::process::Child, } impl Drop for RedisInstance { fn drop(&mut self) { self.child.kill().expect("Failed to kill the child!"); } } impl RedisInstance { pub fn uri(&self) -> &str { &self.uri } } pub async fn get_redis_instance() -> RedisInstance { let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); let socket = tempdir.path().join("redis.sock"); let redis_child = process::Command::new("redis-server") .current_dir(&tempdir) .arg("--port") .arg("0") .arg("--unixsocket") .arg(&socket) .stdout(process::Stdio::null()) .stderr(process::Stdio::null()) .spawn() .expect("Failed to spawn Redis"); println!("redis+unix:///{}", socket.to_str().unwrap()); let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); // There should be a slight delay, we need to wait for Redis to spin up let client = redis::Client::open(uri.clone()).unwrap(); let millisecond = Duration::from_millis(1); let mut retries: usize = 0; const MAX_RETRIES: usize = 60 * 1000/*ms*/; while let Err(err) = client.get_connection() { if err.is_connection_refusal() { async_std::task::sleep(millisecond).await; retries += 1; if retries > MAX_RETRIES { panic!("Timeout waiting for Redis, last error: {}", err); } } else { panic!("Could not connect: {}", err); } } RedisInstance { uri, child: redis_child, _tempdir: tempdir, } } }