about summary refs log tree commit diff
path: root/src/database/redis
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/redis')
-rw-r--r--src/database/redis/mod.rs293
1 files changed, 210 insertions, 83 deletions
diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs
index 352cece..5a6b70d 100644
--- a/src/database/redis/mod.rs
+++ b/src/database/redis/mod.rs
@@ -1,20 +1,20 @@
 use async_trait::async_trait;
+use futures::stream;
 use futures_util::FutureExt;
 use futures_util::StreamExt;
-use futures::stream;
 use lazy_static::lazy_static;
 use log::error;
+use mobc::Pool;
 use mobc_redis::redis;
 use mobc_redis::redis::AsyncCommands;
-use serde_json::json;
-use mobc::Pool;
 use mobc_redis::RedisConnectionManager;
+use serde_json::json;
 
-use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel};
+use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError};
 use crate::indieauth::User;
 
 struct RedisScripts {
-    edit_post: redis::Script
+    edit_post: redis::Script,
 }
 
 impl From<mobc_redis::redis::RedisError> for StorageError {
@@ -22,7 +22,7 @@ impl From<mobc_redis::redis::RedisError> for StorageError {
         Self {
             msg: format!("{}", err),
             source: Some(Box::new(err)),
-            kind: ErrorKind::Backend
+            kind: ErrorKind::Backend,
         }
     }
 }
@@ -31,7 +31,7 @@ impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError {
         Self {
             msg: format!("{}", err),
             source: Some(Box::new(err)),
-            kind: ErrorKind::Backend
+            kind: ErrorKind::Backend,
         }
     }
 }
@@ -64,17 +64,40 @@ fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option<
         }));
     }
     let empty_vec: Vec<serde_json::Value> = vec![];
-    let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string());
-    let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public");
-    let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()));
-    if (visibility == "private" && !audience.any(|i| Some(i) == *user)) || (visibility == "protected" && user.is_none()) {
-        return None
+    let author = post["properties"]["author"]
+        .as_array()
+        .unwrap_or(&empty_vec)
+        .iter()
+        .map(|i| i.as_str().unwrap().to_string());
+    let visibility = post["properties"]["visibility"][0]
+        .as_str()
+        .unwrap_or("public");
+    let mut audience = author.chain(
+        post["properties"]["audience"]
+            .as_array()
+            .unwrap_or(&empty_vec)
+            .iter()
+            .map(|i| i.as_str().unwrap().to_string()),
+    );
+    if (visibility == "private" && !audience.any(|i| Some(i) == *user))
+        || (visibility == "protected" && user.is_none())
+    {
+        return None;
     }
     if post["properties"]["location"].is_array() {
-        let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private");
-        let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string());
+        let location_visibility = post["properties"]["location-visibility"][0]
+            .as_str()
+            .unwrap_or("private");
+        let mut author = post["properties"]["author"]
+            .as_array()
+            .unwrap_or(&empty_vec)
+            .iter()
+            .map(|i| i.as_str().unwrap().to_string());
         if location_visibility == "private" && !author.any(|i| Some(i) == *user) {
-            post["properties"].as_object_mut().unwrap().remove("location");
+            post["properties"]
+                .as_object_mut()
+                .unwrap()
+                .remove("location");
         }
     }
     Some(post)
@@ -84,12 +107,16 @@ fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option<
 impl Storage for RedisStorage {
     async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> {
         let mut conn = self.redis.get().await?;
-        Ok(conn.hget::<String, &str, String>(format!("settings_{}", user), setting).await?)
+        Ok(conn
+            .hget::<String, &str, String>(format!("settings_{}", user), setting)
+            .await?)
     }
 
     async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> {
         let mut conn = self.redis.get().await?;
-        Ok(conn.hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value).await?)
+        Ok(conn
+            .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value)
+            .await?)
     }
 
     async fn delete_post<'a>(&self, url: &'a str) -> Result<()> {
@@ -101,41 +128,63 @@ impl Storage for RedisStorage {
         let mut conn = self.redis.get().await?;
         Ok(conn.hexists::<&str, &str, bool>(&"posts", url).await?)
     }
-    
+
     async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
         let mut conn = self.redis.get().await?;
-        match conn.hget::<&str, &str, Option<String>>(&"posts", url).await? {
+        match conn
+            .hget::<&str, &str, Option<String>>(&"posts", url)
+            .await?
+        {
             Some(val) => {
                 let parsed = serde_json::from_str::<serde_json::Value>(&val)?;
                 if let Some(new_url) = parsed["see_other"].as_str() {
-                    match conn.hget::<&str, &str, Option<String>>(&"posts", new_url).await? {
+                    match conn
+                        .hget::<&str, &str, Option<String>>(&"posts", new_url)
+                        .await?
+                    {
                         Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)),
-                        None => Ok(None)
+                        None => Ok(None),
                     }
                 } else {
                     Ok(Some(parsed))
                 }
-            },
-            None => Ok(None)
+            }
+            None => Ok(None),
         }
     }
 
     async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> {
         let mut conn = self.redis.get().await?;
-        let channels = conn.smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()).await?;
+        let channels = conn
+            .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str())
+            .await?;
         // TODO: use streams here instead of this weird thing... how did I even write this?!
-        Ok(futures_util::future::join_all(channels.iter()
-            .map(|channel| self.get_post(channel)
-                .map(|result| result.unwrap())
-                .map(|post: Option<serde_json::Value>| {
-                    if let Some(post) = post {
-                        Some(MicropubChannel {
-                            uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
-                            name: post["properties"]["name"][0].as_str().unwrap().to_string()
-                        })
-                    } else { None }
+        Ok(futures_util::future::join_all(
+            channels
+                .iter()
+                .map(|channel| {
+                    self.get_post(channel).map(|result| result.unwrap()).map(
+                        |post: Option<serde_json::Value>| {
+                            if let Some(post) = post {
+                                Some(MicropubChannel {
+                                    uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
+                                    name: post["properties"]["name"][0]
+                                        .as_str()
+                                        .unwrap()
+                                        .to_string(),
+                                })
+                            } else {
+                                None
+                            }
+                        },
+                    )
                 })
-            ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>())
+                .collect::<Vec<_>>(),
+        )
+        .await
+        .into_iter()
+        .filter_map(|chan| chan)
+        .collect::<Vec<_>>())
     }
 
     async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> {
@@ -143,72 +192,122 @@ impl Storage for RedisStorage {
         let key: &str;
         match post["properties"]["uid"][0].as_str() {
             Some(uid) => key = uid,
-            None => return Err(StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID"))
+            None => {
+                return Err(StorageError::new(
+                    ErrorKind::BadRequest,
+                    "post doesn't have a UID",
+                ))
+            }
         }
-        conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await?;
+        conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string())
+            .await?;
         if post["properties"]["url"].is_array() {
-            for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) {
+            for url in post["properties"]["url"]
+                .as_array()
+                .unwrap()
+                .iter()
+                .map(|i| i.as_str().unwrap().to_string())
+            {
                 if url != key {
-                    conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await?;
+                    conn.hset::<&str, &str, String, ()>(
+                        &"posts",
+                        &url,
+                        json!({ "see_other": key }).to_string(),
+                    )
+                    .await?;
                 }
             }
         }
-        if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") {
+        if post["type"]
+            .as_array()
+            .unwrap()
+            .iter()
+            .any(|i| i == "h-feed")
+        {
             // This is a feed. Add it to the channels array if it's not already there.
-            conn.sadd::<String, &str, ()>("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await?
+            conn.sadd::<String, &str, ()>(
+                "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(),
+                key,
+            )
+            .await?
         }
         Ok(())
     }
 
-    async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option<String>, limit: usize, user: &'a Option<String>) -> Result<Option<serde_json::Value>> {
+    async fn read_feed_with_limit<'a>(
+        &self,
+        url: &'a str,
+        after: &'a Option<String>,
+        limit: usize,
+        user: &'a Option<String>,
+    ) -> Result<Option<serde_json::Value>> {
         let mut conn = self.redis.get().await?;
         let mut feed;
-        match conn.hget::<&str, &str, Option<String>>(&"posts", url).await? {
+        match conn
+            .hget::<&str, &str, Option<String>>(&"posts", url)
+            .await?
+        {
             Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
-            None => return Ok(None)
+            None => return Ok(None),
         }
         if feed["see_other"].is_string() {
-            match conn.hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap()).await? {
+            match conn
+                .hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap())
+                .await?
+            {
                 Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
-                None => return Ok(None)
+                None => return Ok(None),
             }
         }
         if let Some(post) = filter_post(feed, user) {
             feed = post
         } else {
-            return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post"))
+            return Err(StorageError::new(
+                ErrorKind::PermissionDenied,
+                "specified user cannot access this post",
+            ));
         }
         if feed["children"].is_array() {
             let children = feed["children"].as_array().unwrap();
             let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send>;
             // TODO: refactor this to apply the skip on the &mut iterator
             if let Some(after) = after {
-                posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1));
+                posts_iter = Box::new(
+                    children
+                        .iter()
+                        .map(|i| i.as_str().unwrap().to_string())
+                        .skip_while(move |i| i != after)
+                        .skip(1),
+                );
             } else {
                 posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()));
             }
             let posts = stream::iter(posts_iter)
                 .map(|url| async move {
                     match self.redis.get().await {
-                        Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
-                            Ok(post) => match post {
-                                Some(post) => match serde_json::from_str::<serde_json::Value>(&post) {
-                                    Ok(post) => Some(post),
-                                    Err(err) => {
-                                        let err = StorageError::from(err);
-                                        error!("{}", err);
-                                        panic!("{}", err)
+                        Ok(mut conn) => {
+                            match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
+                                Ok(post) => match post {
+                                    Some(post) => {
+                                        match serde_json::from_str::<serde_json::Value>(&post) {
+                                            Ok(post) => Some(post),
+                                            Err(err) => {
+                                                let err = StorageError::from(err);
+                                                error!("{}", err);
+                                                panic!("{}", err)
+                                            }
+                                        }
                                     }
+                                    // Happens because of a broken link (result of an improper deletion?)
+                                    None => None,
                                 },
-                                // Happens because of a broken link (result of an improper deletion?)
-                                None => None,
-                            },
-                            Err(err) => {
-                                let err = StorageError::from(err);
-                                error!("{}", err);
-                                panic!("{}", err)
+                                Err(err) => {
+                                    let err = StorageError::from(err);
+                                    error!("{}", err);
+                                    panic!("{}", err)
+                                }
                             }
-                        },
+                        }
                         // TODO: Instead of causing a panic, investigate how can you fail the whole stream
                         // Somehow fuse it maybe?
                         Err(err) => {
@@ -227,14 +326,20 @@ impl Storage for RedisStorage {
                 // Hack to unwrap the Option and sieve out broken links
                 // Broken links return None, and Stream::filter_map skips all Nones.
                 .filter_map(|post: Option<serde_json::Value>| async move { post })
-                .filter_map(|post| async move {
-                    filter_post(post, user)
-                })
+                .filter_map(|post| async move { filter_post(post, user) })
                 .take(limit);
             // TODO: Instead of catching panics, find a way to make the whole stream fail with Result<Vec<serde_json::Value>>
-            match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>()).catch_unwind().await {
+            match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>())
+                .catch_unwind()
+                .await
+            {
                 Ok(posts) => feed["children"] = json!(posts),
-                Err(_) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info"))
+                Err(_) => {
+                    return Err(StorageError::new(
+                        ErrorKind::Other,
+                        "Unknown error encountered while assembling feed, see logs for more info",
+                    ))
+                }
             }
         }
         return Ok(Some(feed));
@@ -242,39 +347,56 @@ impl Storage for RedisStorage {
 
     async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> {
         let mut conn = self.redis.get().await?;
-        if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() {
-            return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post"))
+        if !conn
+            .hexists::<&str, &str, bool>("posts", url)
+            .await
+            .unwrap()
+        {
+            return Err(StorageError::new(
+                ErrorKind::NotFound,
+                "can't edit a non-existent post",
+            ));
         }
-        let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?;
+        let post: serde_json::Value =
+            serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?;
         if let Some(new_url) = post["see_other"].as_str() {
             url = new_url
         }
-        Ok(SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection).await?)
+        Ok(SCRIPTS
+            .edit_post
+            .key("posts")
+            .arg(url)
+            .arg(update.to_string())
+            .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection)
+            .await?)
     }
 }
 
-
 impl RedisStorage {
     /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data.
     pub async fn new(redis_uri: String) -> Result<Self> {
         match redis::Client::open(redis_uri) {
-            Ok(client) => Ok(Self { redis: Pool::builder().max_open(20).build(RedisConnectionManager::new(client)) }),
-            Err(e) => Err(e.into())
+            Ok(client) => Ok(Self {
+                redis: Pool::builder()
+                    .max_open(20)
+                    .build(RedisConnectionManager::new(client)),
+            }),
+            Err(e) => Err(e.into()),
         }
     }
 }
 
 #[cfg(test)]
 pub mod tests {
+    use mobc_redis::redis;
     use std::process;
     use std::time::Duration;
-    use mobc_redis::redis;
 
     pub struct RedisInstance {
         // We just need to hold on to it so it won't get dropped and remove the socket
         _tempdir: tempdir::TempDir,
         uri: String,
-        child: std::process::Child
+        child: std::process::Child,
     }
     impl Drop for RedisInstance {
         fn drop(&mut self) {
@@ -292,11 +414,14 @@ pub mod tests {
         let socket = tempdir.path().join("redis.sock");
         let redis_child = process::Command::new("redis-server")
             .current_dir(&tempdir)
-            .arg("--port").arg("0")
-            .arg("--unixsocket").arg(&socket)
+            .arg("--port")
+            .arg("0")
+            .arg("--unixsocket")
+            .arg(&socket)
             .stdout(process::Stdio::null())
             .stderr(process::Stdio::null())
-            .spawn().expect("Failed to spawn Redis");
+            .spawn()
+            .expect("Failed to spawn Redis");
         println!("redis+unix:///{}", socket.to_str().unwrap());
         let uri = format!("redis+unix:///{}", socket.to_str().unwrap());
         // There should be a slight delay, we need to wait for Redis to spin up
@@ -317,7 +442,9 @@ pub mod tests {
         }
 
         return RedisInstance {
-            uri, child: redis_child, _tempdir: tempdir
-        }
+            uri,
+            child: redis_child,
+            _tempdir: tempdir,
+        };
     }
 }