about summary refs log tree commit diff
path: root/src/database/redis/mod.rs
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2021-05-05 14:57:11 +0300
committerVika <vika@fireburn.ru>2021-05-05 14:57:11 +0300
commit488285e460dba2f5fb0901c77226d582d31e2be4 (patch)
tree64a55cd79284e121a372074abe781c2c99f7c12b /src/database/redis/mod.rs
parent445873540edd1c4b21dc1c5039a489666cac1f30 (diff)
downloadkittybox-488285e460dba2f5fb0901c77226d582d31e2be4.tar.zst
Refactored error handling in RedisStorage using the ? operator
Diffstat (limited to 'src/database/redis/mod.rs')
-rw-r--r--src/database/redis/mod.rs308
1 files changed, 124 insertions, 184 deletions
diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs
index 5ab93af..b709125 100644
--- a/src/database/redis/mod.rs
+++ b/src/database/redis/mod.rs
@@ -65,224 +65,164 @@ fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option<
 #[async_trait]
 impl Storage for RedisStorage {
     async fn delete_post<'a>(&self, url: &'a str) -> Result<()> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => if let Err(err) = conn.hdel::<&str, &str, bool>("posts", url).await {
-                return Err(err.into());
-            },
-            Err(err) => return Err(err.into())
-        }
-        Ok(())
+        let mut conn = self.redis.get_async_std_connection().await?;
+        Ok(conn.hdel::<&str, &str, ()>("posts", url).await?)
     }
 
     async fn post_exists(&self, url: &str) -> Result<bool> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => match conn.hexists::<&str, &str, bool>(&"posts", url).await {
-                Ok(val) => Ok(val),
-                Err(err) => Err(err.into())
-            },
-            Err(err) => Err(err.into())
-        }
+        let mut conn = self.redis.get_async_std_connection().await?;
+        Ok(conn.hexists::<&str, &str, bool>(&"posts", url).await?)
     }
     
     async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>(&"posts", url).await {
-                Ok(val) => match val {
-                    Some(val) => match serde_json::from_str::<serde_json::Value>(&val) {
-                        Ok(parsed) => if let Some(new_url) = parsed["see_other"].as_str() {
-                            match conn.hget::<&str, &str, Option<String>>(&"posts", new_url).await {
-                                Ok(val) => match val {
-                                    Some(val) => match serde_json::from_str::<serde_json::Value>(&val) {
-                                        Ok(parsed) => Ok(Some(parsed)),
-                                        Err(err) => Err(err.into())
-                                    },
-                                    None => Ok(None)
-                                }
-                                Err(err) => {
-                                    Err(err.into())
-                                }
-                            }
-                        } else {
-                            Ok(Some(parsed))
-                        },
-                        Err(err) => Err(err.into())
-                    },
-                    None => Ok(None)
-                },
-                Err(err) => Err(err.into())
+        let mut conn = self.redis.get_async_std_connection().await?;
+        match conn.hget::<&str, &str, Option<String>>(&"posts", url).await? {
+            Some(val) => {
+                let parsed = serde_json::from_str::<serde_json::Value>(&val)?;
+                if let Some(new_url) = parsed["see_other"].as_str() {
+                    match conn.hget::<&str, &str, Option<String>>(&"posts", new_url).await? {
+                        Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)),
+                        None => Ok(None)
+                    }
+                } else {
+                    Ok(Some(parsed))
+                }
             },
-            Err(err) => Err(err.into())
+            None => Ok(None)
         }
     }
 
     async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => match conn.smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()).await {
-                Ok(channels) => {
-                    Ok(futures_util::future::join_all(channels.iter()
-                        .map(|channel| self.get_post(channel)
-                            .map(|result| result.unwrap())
-                            .map(|post: Option<serde_json::Value>| {
-                                if let Some(post) = post {
-                                    Some(MicropubChannel {
-                                        uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
-                                        name: post["properties"]["name"][0].as_str().unwrap().to_string()
-                                    })
-                                } else { None }
-                            })
-                        ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>())
-                },
-                Err(err) => Err(err.into())
-            },
-            Err(err) => Err(err.into())
-        }
+        let mut conn = self.redis.get_async_std_connection().await?;
+        let channels = conn.smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()).await?;
+        // TODO: use streams here instead of this weird thing... how did I even write this?!
+        Ok(futures_util::future::join_all(channels.iter()
+            .map(|channel| self.get_post(channel)
+                .map(|result| result.unwrap())
+                .map(|post: Option<serde_json::Value>| {
+                    if let Some(post) = post {
+                        Some(MicropubChannel {
+                            uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
+                            name: post["properties"]["name"][0].as_str().unwrap().to_string()
+                        })
+                    } else { None }
+                })
+            ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>())
     }
 
     async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => {
-                let key: &str;
-                match post["properties"]["uid"][0].as_str() {
-                    Some(uid) => key = uid,
-                    None => return Err(StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID"))
-                }        
-                if let Err(err) = conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await {
-                    return Err(err.into())
-                }
-                if post["properties"]["url"].is_array() {
-                    for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) {
-                        if url != key {
-                            if let Err(err) = conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await {
-                                return Err(err.into())
-                            }
-                        }
-                    }
-                }
-                if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") {
-                    // This is a feed. Add it to the channels array if it's not already there.
-                    if let Err(err) = conn.sadd::<String, &str, ()>("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await {
-                        return Err(err.into())
-                    }
+        let mut conn = self.redis.get_async_std_connection().await?;
+        let key: &str;
+        match post["properties"]["uid"][0].as_str() {
+            Some(uid) => key = uid,
+            None => return Err(StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID"))
+        }
+        conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await?;
+        if post["properties"]["url"].is_array() {
+            for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) {
+                if url != key {
+                    conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await?;
                 }
-                Ok(())
-            },
-            Err(err) => Err(err.into())
+            }
+        }
+        if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") {
+            // This is a feed. Add it to the channels array if it's not already there.
+            conn.sadd::<String, &str, ()>("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await?
         }
+        Ok(())
     }
 
     async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option<String>, limit: usize, user: &'a Option<String>) -> Result<Option<serde_json::Value>> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => {
-                let mut feed;
-                match conn.hget::<&str, &str, Option<String>>(&"posts", url).await {
-                    Ok(post) => {
-                        match post {
-                            Some(post) => match serde_json::from_str::<serde_json::Value>(&post) {
-                                Ok(post) => feed = post,
-                                Err(err) => return Err(err.into())
-                            },
-                            None => return Ok(None)
-                        }
-                    },
-                    Err(err) => return Err(err.into())
-                }
-                if feed["see_other"].is_string() {
-                    match conn.hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap()).await {
-                        Ok(post) => {
-                            match post {
+        let mut conn = self.redis.get_async_std_connection().await?;
+        let mut feed;
+        match conn.hget::<&str, &str, Option<String>>(&"posts", url).await? {
+            Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
+            None => return Ok(None)
+        }
+        if feed["see_other"].is_string() {
+            match conn.hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap()).await? {
+                Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
+                None => return Ok(None)
+            }
+        }
+        if let Some(post) = filter_post(feed, user) {
+            feed = post
+        } else {
+            return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post"))
+        }
+        if feed["children"].is_array() {
+            let children = feed["children"].as_array().unwrap();
+            let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send>;
+            // TODO: refactor this to apply the skip on the &mut iterator
+            if let Some(after) = after {
+                posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1));
+            } else {
+                posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()));
+            }
+            let posts = stream::iter(posts_iter)
+                .map(|url| async move {
+                    // Is it rational to use a new connection for every post fetched?
+                    // TODO: Use a connection pool here
+                    match self.redis.get_async_std_connection().await {
+                        Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
+                            Ok(post) => match post {
                                 Some(post) => match serde_json::from_str::<serde_json::Value>(&post) {
-                                    Ok(post) => feed = post,
-                                    Err(err) => return Err(err.into())
-                                },
-                                None => return Ok(None)
-                            }
-                        },
-                        Err(err) => return Err(err.into())
-                    }
-                }
-                if let Some(post) = filter_post(feed, user) {
-                    feed = post
-                } else {
-                    return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post"))
-                }
-                if feed["children"].is_array() {
-                    let children = feed["children"].as_array().unwrap();
-                    let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send>;
-                    if let Some(after) = after {
-                        posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1));
-                    } else {
-                        posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()));
-                    }
-                    let posts = stream::iter(posts_iter)
-                        .map(|url| async move {
-                            // Is it rational to use a new connection for every post fetched?
-                            match self.redis.get_async_std_connection().await {
-                                Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
-                                    Ok(post) => match post {
-                                        Some(post) => match serde_json::from_str::<serde_json::Value>(&post) {
-                                            Ok(post) => Some(post),
-                                            Err(err) => {
-                                                let err = StorageError::from(err);
-                                                error!("{}", err);
-                                                panic!("{}", err)
-                                            }
-                                        },
-                                        // Happens because of a broken link (result of an improper deletion?)
-                                        None => None,
-                                    },
+                                    Ok(post) => Some(post),
                                     Err(err) => {
                                         let err = StorageError::from(err);
                                         error!("{}", err);
                                         panic!("{}", err)
                                     }
                                 },
-                                Err(err) => {
-                                    let err = StorageError::from(err);
-                                    error!("{}", err);
-                                    panic!("{}", err)
-                                }
+                                // Happens because of a broken link (result of an improper deletion?)
+                                None => None,
+                            },
+                            Err(err) => {
+                                let err = StorageError::from(err);
+                                error!("{}", err);
+                                panic!("{}", err)
                             }
-                        })
-                        // TODO: determine the optimal value for this buffer
-                        // It will probably depend on how often can you encounter a private post on the page
-                        // It shouldn't be too large, or we'll start fetching too many posts from the database
-                        // It MUST NOT be larger than the typical page size
-                        .buffered(std::cmp::min(3, limit))
-                        // Hack to unwrap the Option and sieve out broken links
-                        // Broken links return None, and Stream::filter_map skips all Nones.
-                        .filter_map(|post: Option<serde_json::Value>| async move { post })
-                        .filter_map(|post| async move {
-                            filter_post(post, user)
-                        })
-                        .take(limit);
-                    match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>()).catch_unwind().await {
-                        Ok(posts) => feed["children"] = json!(posts),
-                        Err(_) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info"))
+                        },
+                        // TODO: Instead of causing a panic, investigate how can you fail the whole stream
+                        // Somehow fuse it maybe?
+                        Err(err) => {
+                            let err = StorageError::from(err);
+                            error!("{}", err);
+                            panic!("{}", err)
+                        }
                     }
-                }
-                return Ok(Some(feed));
+                })
+                // TODO: determine the optimal value for this buffer
+                // It will probably depend on how often can you encounter a private post on the page
+                // It shouldn't be too large, or we'll start fetching too many posts from the database
+                // It MUST NOT be larger than the typical page size
+                .buffered(std::cmp::min(3, limit))
+                // Hack to unwrap the Option and sieve out broken links
+                // Broken links return None, and Stream::filter_map skips all Nones.
+                .filter_map(|post: Option<serde_json::Value>| async move { post })
+                .filter_map(|post| async move {
+                    filter_post(post, user)
+                })
+                .take(limit);
+            // TODO: Instead of catching panics, find a way to make the whole stream fail with Result<Vec<serde_json::Value>>
+            match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>()).catch_unwind().await {
+                Ok(posts) => feed["children"] = json!(posts),
+                Err(_) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info"))
             }
-            Err(err) => Err(err.into())
         }
+        return Ok(Some(feed));
     }
 
     async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> {
-        match self.redis.get_async_std_connection().await {
-            Ok(mut conn) => {
-                if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() {
-                    return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post"))
-                }
-                let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await.unwrap()).unwrap();
-                if let Some(new_url) = post["see_other"].as_str() {
-                    url = new_url
-                }
-                if let Err(err) = SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await {
-                    return Err(err.into())
-                }
-            },
-            Err(err) => return Err(err.into())
+        let mut conn = self.redis.get_async_std_connection().await?;
+        if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() {
+            return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post"))
         }
-        Ok(())
+        let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?;
+        if let Some(new_url) = post["see_other"].as_str() {
+            url = new_url
+        }
+        Ok(SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await?)
     }
 }