diff options
Diffstat (limited to 'src/database/redis')
| -rw-r--r-- | src/database/redis/edit_post.lua | 93 | ||||
| -rw-r--r-- | src/database/redis/mod.rs | 398 | 
2 files changed, 491 insertions, 0 deletions
| diff --git a/src/database/redis/edit_post.lua b/src/database/redis/edit_post.lua new file mode 100644 index 0000000..a398f8d --- /dev/null +++ b/src/database/redis/edit_post.lua @@ -0,0 +1,93 @@ +local posts = KEYS[1] +local update_desc = cjson.decode(ARGV[2]) +local post = cjson.decode(redis.call("HGET", posts, ARGV[1])) + +local delete_keys = {} +local delete_kvs = {} +local add_keys = {} + +if update_desc.replace ~= nil then + for k, v in pairs(update_desc.replace) do + table.insert(delete_keys, k) + add_keys[k] = v + end +end +if update_desc.delete ~= nil then + if update_desc.delete[0] == nil then + -- Table has string keys. Probably! + for k, v in pairs(update_desc.delete) do + delete_kvs[k] = v + end + else + -- Table has numeric keys. Probably! + for i, v in ipairs(update_desc.delete) do + table.insert(delete_keys, v) + end + end +end +if update_desc.add ~= nil then + for k, v in pairs(update_desc.add) do + add_keys[k] = v + end +end + +for i, v in ipairs(delete_keys) do + post["properties"][v] = nil + -- TODO delete URL links +end + +for k, v in pairs(delete_kvs) do + local index = -1 + if k == "children" then + for j, w in ipairs(post[k]) do + if w == v then + index = j + break + end + end + if index > -1 then + table.remove(post[k], index) + end + else + for j, w in ipairs(post["properties"][k]) do + if w == v then + index = j + break + end + end + if index > -1 then + table.remove(post["properties"][k], index) + -- TODO delete URL links + end + end +end + +for k, v in pairs(add_keys) do + if k == "children" then + if post["children"] == nil then + post["children"] = {} + end + for i, w in ipairs(v) do + table.insert(post["children"], 1, w) + end + else + if post["properties"][k] == nil then + post["properties"][k] = {} + end + for i, w in ipairs(v) do + table.insert(post["properties"][k], w) + end + if k == "url" then + redis.call("HSET", posts, v, cjson.encode({ see_other = post["properties"]["uid"][1] })) + elseif k == "channel" then + local feed = cjson.decode(redis.call("HGET", posts, v)) + table.insert(feed["children"], 1, post["properties"]["uid"][1]) + redis.call("HSET", posts, v, cjson.encode(feed)) + end + end +end + +local encoded = cjson.encode(post) +redis.call("SET", "debug", encoded) +redis.call("HSET", posts, post["properties"]["uid"][1], encoded) +return \ No newline at end of file diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs new file mode 100644 index 0000000..39ee852 --- /dev/null +++ b/src/database/redis/mod.rs @@ -0,0 +1,398 @@ +use async_trait::async_trait; +use futures::stream; +use futures_util::FutureExt; +use futures_util::StreamExt; +use futures_util::TryStream; +use futures_util::TryStreamExt; +use lazy_static::lazy_static; +use log::error; +use mobc::Pool; +use mobc_redis::redis; +use mobc_redis::redis::AsyncCommands; +use mobc_redis::RedisConnectionManager; +use serde_json::json; +use std::time::Duration; + +use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError, filter_post}; +use crate::indieauth::User; + +struct RedisScripts { + edit_post: redis::Script, +} + +impl From<mobc_redis::redis::RedisError> for StorageError { + fn from(err: mobc_redis::redis::RedisError) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend, + } + } +} +impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError { + fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend, + } + } +} + +lazy_static! { + static ref SCRIPTS: RedisScripts = RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + }; +} +/*#[cfg(feature(lazy_cell))] +static SCRIPTS_CELL: std::cell::LazyCell<RedisScripts> = std::cell::LazyCell::new(|| { + RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + } +});*/ + +#[derive(Clone)] +pub struct RedisStorage { + // note to future Vika: + // mobc::Pool is actually a fancy name for an Arc + // around a shared connection pool with a manager + // which makes it safe to implement [`Clone`] and + // not worry about new pools being suddenly made + // + // stop worrying and start coding, you dum-dum + redis: mobc::Pool<RedisConnectionManager>, +} + +#[async_trait] +impl Storage for RedisStorage { + async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn + .hget::<String, &str, String>(format!("settings_{}", user), setting) + .await?) + } + + async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn + .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value) + .await?) + } + + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) + } + + async fn post_exists(&self, url: &str) -> Result<bool> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn.hexists::<&str, &str, bool>("posts", url).await?) + } + + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + match conn + .hget::<&str, &str, Option<String>>("posts", url) + .await? + { + Some(val) => { + let parsed = serde_json::from_str::<serde_json::Value>(&val)?; + if let Some(new_url) = parsed["see_other"].as_str() { + match conn + .hget::<&str, &str, Option<String>>("posts", new_url) + .await? + { + Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)), + None => Ok(None), + } + } else { + Ok(Some(parsed)) + } + } + None => Ok(None), + } + } + + async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + let channels = conn + .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()) + .await?; + // TODO: use streams here instead of this weird thing... how did I even write this?! + Ok(futures_util::future::join_all( + channels + .iter() + .map(|channel| { + self.get_post(channel).map(|result| result.unwrap()).map( + |post: Option<serde_json::Value>| { + post.map(|post| MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string(), + }) + }, + ) + }) + .collect::<Vec<_>>(), + ) + .await + .into_iter() + .flatten() + .collect::<Vec<_>>()) + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => { + return Err(StorageError::new( + ErrorKind::BadRequest, + "post doesn't have a UID", + )) + } + } + conn.hset::<&str, &str, String, ()>("posts", key, post.to_string()) + .await?; + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap().to_string()) + { + if url != key && url.starts_with(user) { + conn.hset::<&str, &str, String, ()>( + "posts", + &url, + json!({ "see_other": key }).to_string(), + ) + .await?; + } + } + } + if post["type"] + .as_array() + .unwrap() + .iter() + .any(|i| i == "h-feed") + { + // This is a feed. Add it to the channels array if it's not already there. + conn.sadd::<String, &str, ()>( + "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), + key, + ) + .await? + } + Ok(()) + } + + async fn read_feed_with_limit<'a>( + &self, + url: &'a str, + after: &'a Option<String>, + limit: usize, + user: &'a Option<String>, + ) -> Result<Option<serde_json::Value>> { + let mut conn = self.redis.get().await?; + let mut feed; + match conn + .hget::<&str, &str, Option<String>>("posts", url) + .await + .map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))? + { + Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, + None => return Ok(None), + } + if feed["see_other"].is_string() { + match conn + .hget::<&str, &str, Option<String>>("posts", feed["see_other"].as_str().unwrap()) + .await? + { + Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, + None => return Ok(None), + } + } + if let Some(post) = filter_post(feed, user) { + feed = post + } else { + return Err(StorageError::new( + ErrorKind::PermissionDenied, + "specified user cannot access this post", + )); + } + if feed["children"].is_array() { + let children = feed["children"].as_array().unwrap(); + let mut posts_iter = children.iter().map(|i| i.as_str().unwrap().to_string()); + if after.is_some() { + loop { + let i = posts_iter.next(); + if &i == after { + break; + } + } + } + async fn fetch_post_for_feed(url: String) -> Option<serde_json::Value> { + return Some(serde_json::json!({})); + } + let posts = stream::iter(posts_iter) + .map(|url: String| async move { + return Ok(fetch_post_for_feed(url).await); + /*match self.redis.get().await { + Ok(mut conn) => { + match conn.hget::<&str, &str, Option<String>>("posts", &url).await { + Ok(post) => match post { + Some(post) => { + Ok(Some(serde_json::from_str(&post)?)) + } + // Happens because of a broken link (result of an improper deletion?) + None => Ok(None), + }, + Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error executing a Redis command", Box::new(err))) + } + } + Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(err))) + }*/ + }) + // TODO: determine the optimal value for this buffer + // It will probably depend on how often can you encounter a private post on the page + // It shouldn't be too large, or we'll start fetching too many posts from the database + // It MUST NOT be larger than the typical page size + // It MUST NOT be a significant amount of the connection pool size + //.buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips all Nones. + // I wonder if one can use try_flatten() here somehow akin to iters + .try_filter_map(|post| async move { Ok(post) }) + .try_filter_map(|post| async move { + Ok(filter_post(post, user)) + }) + .take(limit); + match posts.try_collect::<Vec<serde_json::Value>>().await { + Ok(posts) => feed["children"] = json!(posts), + Err(err) => { + let e = StorageError::with_source( + ErrorKind::Other, + "An error was encountered while processing the feed", + Box::new(err) + ); + error!("Error while assembling feed: {}", e); + return Err(e); + } + } + } + return Ok(Some(feed)); + } + + async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + if !conn + .hexists::<&str, &str, bool>("posts", url) + .await + .unwrap() + { + return Err(StorageError::new( + ErrorKind::NotFound, + "can't edit a non-existent post", + )); + } + let post: serde_json::Value = + serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; + if let Some(new_url) = post["see_other"].as_str() { + url = new_url + } + Ok(SCRIPTS + .edit_post + .key("posts") + .arg(url) + .arg(update.to_string()) + .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) + .await?) + } +} + +impl RedisStorage { + /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. + pub async fn new(redis_uri: String) -> Result<Self> { + match redis::Client::open(redis_uri) { + Ok(client) => Ok(Self { + redis: Pool::builder() + .max_open(20) + .max_idle(5) + .get_timeout(Some(Duration::from_secs(3))) + .max_lifetime(Some(Duration::from_secs(120))) + .build(RedisConnectionManager::new(client)), + }), + Err(e) => Err(e.into()), + } + } + + pub async fn conn(&self) -> Result<mobc::Connection<mobc_redis::RedisConnectionManager>> { + self.redis.get().await.map_err(|e| StorageError::with_source( + ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e) + )) + } +} + +#[cfg(test)] +pub mod tests { + use mobc_redis::redis; + use std::process; + use std::time::Duration; + + pub struct RedisInstance { + // We just need to hold on to it so it won't get dropped and remove the socket + _tempdir: tempdir::TempDir, + uri: String, + child: std::process::Child, + } + impl Drop for RedisInstance { + fn drop(&mut self) { + self.child.kill().expect("Failed to kill the child!"); + } + } + impl RedisInstance { + pub fn uri(&self) -> &str { + &self.uri + } + } + + pub async fn get_redis_instance() -> RedisInstance { + let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); + let socket = tempdir.path().join("redis.sock"); + let redis_child = process::Command::new("redis-server") + .current_dir(&tempdir) + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(&socket) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("Failed to spawn Redis"); + println!("redis+unix:///{}", socket.to_str().unwrap()); + let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); + // There should be a slight delay, we need to wait for Redis to spin up + let client = redis::Client::open(uri.clone()).unwrap(); + let millisecond = Duration::from_millis(1); + let mut retries: usize = 0; + const MAX_RETRIES: usize = 60 * 1000/*ms*/; + while let Err(err) = client.get_connection() { + if err.is_connection_refusal() { + async_std::task::sleep(millisecond).await; + retries += 1; + if retries > MAX_RETRIES { + panic!("Timeout waiting for Redis, last error: {}", err); + } + } else { + panic!("Could not connect: {}", err); + } + } + + RedisInstance { + uri, + child: redis_child, + _tempdir: tempdir, + } + } +} | 
