diff options
Diffstat (limited to 'src/database/redis/mod.rs')
-rw-r--r-- | src/database/redis/mod.rs | 293 |
1 files changed, 210 insertions, 83 deletions
diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs index 352cece..5a6b70d 100644 --- a/src/database/redis/mod.rs +++ b/src/database/redis/mod.rs @@ -1,20 +1,20 @@ use async_trait::async_trait; +use futures::stream; use futures_util::FutureExt; use futures_util::StreamExt; -use futures::stream; use lazy_static::lazy_static; use log::error; +use mobc::Pool; use mobc_redis::redis; use mobc_redis::redis::AsyncCommands; -use serde_json::json; -use mobc::Pool; use mobc_redis::RedisConnectionManager; +use serde_json::json; -use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; +use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError}; use crate::indieauth::User; struct RedisScripts { - edit_post: redis::Script + edit_post: redis::Script, } impl From<mobc_redis::redis::RedisError> for StorageError { @@ -22,7 +22,7 @@ impl From<mobc_redis::redis::RedisError> for StorageError { Self { msg: format!("{}", err), source: Some(Box::new(err)), - kind: ErrorKind::Backend + kind: ErrorKind::Backend, } } } @@ -31,7 +31,7 @@ impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError { Self { msg: format!("{}", err), source: Some(Box::new(err)), - kind: ErrorKind::Backend + kind: ErrorKind::Backend, } } } @@ -64,17 +64,40 @@ fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option< })); } let empty_vec: Vec<serde_json::Value> = vec![]; - let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); - let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); - let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); - if (visibility == "private" && !audience.any(|i| Some(i) == *user)) || (visibility == "protected" && user.is_none()) { - return None + let author = post["properties"]["author"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().to_string()); + let visibility = post["properties"]["visibility"][0] + .as_str() + .unwrap_or("public"); + let mut audience = author.chain( + post["properties"]["audience"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().to_string()), + ); + if (visibility == "private" && !audience.any(|i| Some(i) == *user)) + || (visibility == "protected" && user.is_none()) + { + return None; } if post["properties"]["location"].is_array() { - let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); - let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + let location_visibility = post["properties"]["location-visibility"][0] + .as_str() + .unwrap_or("private"); + let mut author = post["properties"]["author"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().to_string()); if location_visibility == "private" && !author.any(|i| Some(i) == *user) { - post["properties"].as_object_mut().unwrap().remove("location"); + post["properties"] + .as_object_mut() + .unwrap() + .remove("location"); } } Some(post) @@ -84,12 +107,16 @@ fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option< impl Storage for RedisStorage { async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> { let mut conn = self.redis.get().await?; - Ok(conn.hget::<String, &str, String>(format!("settings_{}", user), setting).await?) + Ok(conn + .hget::<String, &str, String>(format!("settings_{}", user), setting) + .await?) } async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { let mut conn = self.redis.get().await?; - Ok(conn.hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value).await?) + Ok(conn + .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value) + .await?) } async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { @@ -101,41 +128,63 @@ impl Storage for RedisStorage { let mut conn = self.redis.get().await?; Ok(conn.hexists::<&str, &str, bool>(&"posts", url).await?) } - + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { let mut conn = self.redis.get().await?; - match conn.hget::<&str, &str, Option<String>>(&"posts", url).await? { + match conn + .hget::<&str, &str, Option<String>>(&"posts", url) + .await? + { Some(val) => { let parsed = serde_json::from_str::<serde_json::Value>(&val)?; if let Some(new_url) = parsed["see_other"].as_str() { - match conn.hget::<&str, &str, Option<String>>(&"posts", new_url).await? { + match conn + .hget::<&str, &str, Option<String>>(&"posts", new_url) + .await? + { Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)), - None => Ok(None) + None => Ok(None), } } else { Ok(Some(parsed)) } - }, - None => Ok(None) + } + None => Ok(None), } } async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { let mut conn = self.redis.get().await?; - let channels = conn.smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()).await?; + let channels = conn + .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()) + .await?; // TODO: use streams here instead of this weird thing... how did I even write this?! - Ok(futures_util::future::join_all(channels.iter() - .map(|channel| self.get_post(channel) - .map(|result| result.unwrap()) - .map(|post: Option<serde_json::Value>| { - if let Some(post) = post { - Some(MicropubChannel { - uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), - name: post["properties"]["name"][0].as_str().unwrap().to_string() - }) - } else { None } + Ok(futures_util::future::join_all( + channels + .iter() + .map(|channel| { + self.get_post(channel).map(|result| result.unwrap()).map( + |post: Option<serde_json::Value>| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0] + .as_str() + .unwrap() + .to_string(), + }) + } else { + None + } + }, + ) }) - ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>()) + .collect::<Vec<_>>(), + ) + .await + .into_iter() + .filter_map(|chan| chan) + .collect::<Vec<_>>()) } async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { @@ -143,72 +192,122 @@ impl Storage for RedisStorage { let key: &str; match post["properties"]["uid"][0].as_str() { Some(uid) => key = uid, - None => return Err(StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID")) + None => { + return Err(StorageError::new( + ErrorKind::BadRequest, + "post doesn't have a UID", + )) + } } - conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await?; + conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()) + .await?; if post["properties"]["url"].is_array() { - for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + for url in post["properties"]["url"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap().to_string()) + { if url != key { - conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await?; + conn.hset::<&str, &str, String, ()>( + &"posts", + &url, + json!({ "see_other": key }).to_string(), + ) + .await?; } } } - if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + if post["type"] + .as_array() + .unwrap() + .iter() + .any(|i| i == "h-feed") + { // This is a feed. Add it to the channels array if it's not already there. - conn.sadd::<String, &str, ()>("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await? + conn.sadd::<String, &str, ()>( + "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), + key, + ) + .await? } Ok(()) } - async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option<String>, limit: usize, user: &'a Option<String>) -> Result<Option<serde_json::Value>> { + async fn read_feed_with_limit<'a>( + &self, + url: &'a str, + after: &'a Option<String>, + limit: usize, + user: &'a Option<String>, + ) -> Result<Option<serde_json::Value>> { let mut conn = self.redis.get().await?; let mut feed; - match conn.hget::<&str, &str, Option<String>>(&"posts", url).await? { + match conn + .hget::<&str, &str, Option<String>>(&"posts", url) + .await? + { Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, - None => return Ok(None) + None => return Ok(None), } if feed["see_other"].is_string() { - match conn.hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap()).await? { + match conn + .hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap()) + .await? + { Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, - None => return Ok(None) + None => return Ok(None), } } if let Some(post) = filter_post(feed, user) { feed = post } else { - return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) + return Err(StorageError::new( + ErrorKind::PermissionDenied, + "specified user cannot access this post", + )); } if feed["children"].is_array() { let children = feed["children"].as_array().unwrap(); let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send>; // TODO: refactor this to apply the skip on the &mut iterator if let Some(after) = after { - posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); + posts_iter = Box::new( + children + .iter() + .map(|i| i.as_str().unwrap().to_string()) + .skip_while(move |i| i != after) + .skip(1), + ); } else { posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); } let posts = stream::iter(posts_iter) .map(|url| async move { match self.redis.get().await { - Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>("posts", &url).await { - Ok(post) => match post { - Some(post) => match serde_json::from_str::<serde_json::Value>(&post) { - Ok(post) => Some(post), - Err(err) => { - let err = StorageError::from(err); - error!("{}", err); - panic!("{}", err) + Ok(mut conn) => { + match conn.hget::<&str, &str, Option<String>>("posts", &url).await { + Ok(post) => match post { + Some(post) => { + match serde_json::from_str::<serde_json::Value>(&post) { + Ok(post) => Some(post), + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + } } + // Happens because of a broken link (result of an improper deletion?) + None => None, }, - // Happens because of a broken link (result of an improper deletion?) - None => None, - }, - Err(err) => { - let err = StorageError::from(err); - error!("{}", err); - panic!("{}", err) + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } } - }, + } // TODO: Instead of causing a panic, investigate how can you fail the whole stream // Somehow fuse it maybe? Err(err) => { @@ -227,14 +326,20 @@ impl Storage for RedisStorage { // Hack to unwrap the Option and sieve out broken links // Broken links return None, and Stream::filter_map skips all Nones. .filter_map(|post: Option<serde_json::Value>| async move { post }) - .filter_map(|post| async move { - filter_post(post, user) - }) + .filter_map(|post| async move { filter_post(post, user) }) .take(limit); // TODO: Instead of catching panics, find a way to make the whole stream fail with Result<Vec<serde_json::Value>> - match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>()).catch_unwind().await { + match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>()) + .catch_unwind() + .await + { Ok(posts) => feed["children"] = json!(posts), - Err(_) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) + Err(_) => { + return Err(StorageError::new( + ErrorKind::Other, + "Unknown error encountered while assembling feed, see logs for more info", + )) + } } } return Ok(Some(feed)); @@ -242,39 +347,56 @@ impl Storage for RedisStorage { async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { let mut conn = self.redis.get().await?; - if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { - return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) + if !conn + .hexists::<&str, &str, bool>("posts", url) + .await + .unwrap() + { + return Err(StorageError::new( + ErrorKind::NotFound, + "can't edit a non-existent post", + )); } - let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; + let post: serde_json::Value = + serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; if let Some(new_url) = post["see_other"].as_str() { url = new_url } - Ok(SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection).await?) + Ok(SCRIPTS + .edit_post + .key("posts") + .arg(url) + .arg(update.to_string()) + .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) + .await?) } } - impl RedisStorage { /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. pub async fn new(redis_uri: String) -> Result<Self> { match redis::Client::open(redis_uri) { - Ok(client) => Ok(Self { redis: Pool::builder().max_open(20).build(RedisConnectionManager::new(client)) }), - Err(e) => Err(e.into()) + Ok(client) => Ok(Self { + redis: Pool::builder() + .max_open(20) + .build(RedisConnectionManager::new(client)), + }), + Err(e) => Err(e.into()), } } } #[cfg(test)] pub mod tests { + use mobc_redis::redis; use std::process; use std::time::Duration; - use mobc_redis::redis; pub struct RedisInstance { // We just need to hold on to it so it won't get dropped and remove the socket _tempdir: tempdir::TempDir, uri: String, - child: std::process::Child + child: std::process::Child, } impl Drop for RedisInstance { fn drop(&mut self) { @@ -292,11 +414,14 @@ pub mod tests { let socket = tempdir.path().join("redis.sock"); let redis_child = process::Command::new("redis-server") .current_dir(&tempdir) - .arg("--port").arg("0") - .arg("--unixsocket").arg(&socket) + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(&socket) .stdout(process::Stdio::null()) .stderr(process::Stdio::null()) - .spawn().expect("Failed to spawn Redis"); + .spawn() + .expect("Failed to spawn Redis"); println!("redis+unix:///{}", socket.to_str().unwrap()); let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); // There should be a slight delay, we need to wait for Redis to spin up @@ -317,7 +442,9 @@ pub mod tests { } return RedisInstance { - uri, child: redis_child, _tempdir: tempdir - } + uri, + child: redis_child, + _tempdir: tempdir, + }; } } |