about summary refs log tree commit diff
path: root/src/database/redis/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/redis/mod.rs')
-rw-r--r--src/database/redis/mod.rs398
1 files changed, 398 insertions, 0 deletions
diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs
new file mode 100644
index 0000000..39ee852
--- /dev/null
+++ b/src/database/redis/mod.rs
@@ -0,0 +1,398 @@
+use async_trait::async_trait;
+use futures::stream;
+use futures_util::FutureExt;
+use futures_util::StreamExt;
+use futures_util::TryStream;
+use futures_util::TryStreamExt;
+use lazy_static::lazy_static;
+use log::error;
+use mobc::Pool;
+use mobc_redis::redis;
+use mobc_redis::redis::AsyncCommands;
+use mobc_redis::RedisConnectionManager;
+use serde_json::json;
+use std::time::Duration;
+
+use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError, filter_post};
+use crate::indieauth::User;
+
+struct RedisScripts {
+    edit_post: redis::Script,
+}
+
+impl From<mobc_redis::redis::RedisError> for StorageError {
+    fn from(err: mobc_redis::redis::RedisError) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::Backend,
+        }
+    }
+}
+impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError {
+    fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::Backend,
+        }
+    }
+}
+
+lazy_static! {
+    static ref SCRIPTS: RedisScripts = RedisScripts {
+        edit_post: redis::Script::new(include_str!("./edit_post.lua"))
+    };
+}
+/*#[cfg(feature(lazy_cell))]
+static SCRIPTS_CELL: std::cell::LazyCell<RedisScripts> = std::cell::LazyCell::new(|| {
+    RedisScripts {
+        edit_post: redis::Script::new(include_str!("./edit_post.lua"))
+    }
+});*/
+
+#[derive(Clone)]
+pub struct RedisStorage {
+    // note to future Vika:
+    // mobc::Pool is actually a fancy name for an Arc
+    // around a shared connection pool with a manager
+    // which makes it safe to implement [`Clone`] and
+    // not worry about new pools being suddenly made
+    //
+    // stop worrying and start coding, you dum-dum
+    redis: mobc::Pool<RedisConnectionManager>,
+}
+
+#[async_trait]
+impl Storage for RedisStorage {
+    async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        Ok(conn
+            .hget::<String, &str, String>(format!("settings_{}", user), setting)
+            .await?)
+    }
+
+    async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        Ok(conn
+            .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value)
+            .await?)
+    }
+
+    async fn delete_post<'a>(&self, url: &'a str) -> Result<()> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        Ok(conn.hdel::<&str, &str, ()>("posts", url).await?)
+    }
+
+    async fn post_exists(&self, url: &str) -> Result<bool> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        Ok(conn.hexists::<&str, &str, bool>("posts", url).await?)
+    }
+
+    async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        match conn
+            .hget::<&str, &str, Option<String>>("posts", url)
+            .await?
+        {
+            Some(val) => {
+                let parsed = serde_json::from_str::<serde_json::Value>(&val)?;
+                if let Some(new_url) = parsed["see_other"].as_str() {
+                    match conn
+                        .hget::<&str, &str, Option<String>>("posts", new_url)
+                        .await?
+                    {
+                        Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)),
+                        None => Ok(None),
+                    }
+                } else {
+                    Ok(Some(parsed))
+                }
+            }
+            None => Ok(None),
+        }
+    }
+
+    async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        let channels = conn
+            .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str())
+            .await?;
+        // TODO: use streams here instead of this weird thing... how did I even write this?!
+        Ok(futures_util::future::join_all(
+            channels
+                .iter()
+                .map(|channel| {
+                    self.get_post(channel).map(|result| result.unwrap()).map(
+                        |post: Option<serde_json::Value>| {
+                            post.map(|post| MicropubChannel {
+                                uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
+                                name: post["properties"]["name"][0].as_str().unwrap().to_string(),
+                            })
+                        },
+                    )
+                })
+                .collect::<Vec<_>>(),
+        )
+        .await
+        .into_iter()
+        .flatten()
+        .collect::<Vec<_>>())
+    }
+
+    async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        let key: &str;
+        match post["properties"]["uid"][0].as_str() {
+            Some(uid) => key = uid,
+            None => {
+                return Err(StorageError::new(
+                    ErrorKind::BadRequest,
+                    "post doesn't have a UID",
+                ))
+            }
+        }
+        conn.hset::<&str, &str, String, ()>("posts", key, post.to_string())
+            .await?;
+        if post["properties"]["url"].is_array() {
+            for url in post["properties"]["url"]
+                .as_array()
+                .unwrap()
+                .iter()
+                .map(|i| i.as_str().unwrap().to_string())
+            {
+                if url != key && url.starts_with(user) {
+                    conn.hset::<&str, &str, String, ()>(
+                        "posts",
+                        &url,
+                        json!({ "see_other": key }).to_string(),
+                    )
+                    .await?;
+                }
+            }
+        }
+        if post["type"]
+            .as_array()
+            .unwrap()
+            .iter()
+            .any(|i| i == "h-feed")
+        {
+            // This is a feed. Add it to the channels array if it's not already there.
+            conn.sadd::<String, &str, ()>(
+                "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(),
+                key,
+            )
+            .await?
+        }
+        Ok(())
+    }
+
+    async fn read_feed_with_limit<'a>(
+        &self,
+        url: &'a str,
+        after: &'a Option<String>,
+        limit: usize,
+        user: &'a Option<String>,
+    ) -> Result<Option<serde_json::Value>> {
+        let mut conn = self.redis.get().await?;
+        let mut feed;
+        match conn
+            .hget::<&str, &str, Option<String>>("posts", url)
+            .await
+            .map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?
+        {
+            Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
+            None => return Ok(None),
+        }
+        if feed["see_other"].is_string() {
+            match conn
+                .hget::<&str, &str, Option<String>>("posts", feed["see_other"].as_str().unwrap())
+                .await?
+            {
+                Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
+                None => return Ok(None),
+            }
+        }
+        if let Some(post) = filter_post(feed, user) {
+            feed = post
+        } else {
+            return Err(StorageError::new(
+                ErrorKind::PermissionDenied,
+                "specified user cannot access this post",
+            ));
+        }
+        if feed["children"].is_array() {
+            let children = feed["children"].as_array().unwrap();
+            let mut posts_iter = children.iter().map(|i| i.as_str().unwrap().to_string());
+            if after.is_some() {
+                loop {
+                    let i = posts_iter.next();
+                    if &i == after {
+                        break;
+                    }
+                }
+            }
+            async fn fetch_post_for_feed(url: String) -> Option<serde_json::Value> {
+                return Some(serde_json::json!({}));
+            }
+            let posts = stream::iter(posts_iter)
+                .map(|url: String| async move {
+                    return Ok(fetch_post_for_feed(url).await);
+                    /*match self.redis.get().await {
+                        Ok(mut conn) => {
+                            match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
+                                Ok(post) => match post {
+                                    Some(post) => {
+                                        Ok(Some(serde_json::from_str(&post)?))
+                                    }
+                                    // Happens because of a broken link (result of an improper deletion?)
+                                    None => Ok(None),
+                                },
+                                Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error executing a Redis command", Box::new(err)))
+                            }
+                        }
+                        Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(err)))
+                    }*/
+                })
+                // TODO: determine the optimal value for this buffer
+                // It will probably depend on how often can you encounter a private post on the page
+                // It shouldn't be too large, or we'll start fetching too many posts from the database
+                // It MUST NOT be larger than the typical page size
+                // It MUST NOT be a significant amount of the connection pool size
+                //.buffered(std::cmp::min(3, limit))
+                // Hack to unwrap the Option and sieve out broken links
+                // Broken links return None, and Stream::filter_map skips all Nones.
+                // I wonder if one can use try_flatten() here somehow akin to iters
+                .try_filter_map(|post| async move { Ok(post) })
+                .try_filter_map(|post| async move {
+                    Ok(filter_post(post, user))
+                })
+                .take(limit);
+            match posts.try_collect::<Vec<serde_json::Value>>().await {
+                Ok(posts) => feed["children"] = json!(posts),
+                Err(err) => {
+                    let e = StorageError::with_source(
+                        ErrorKind::Other,
+                        "An error was encountered while processing the feed",
+                        Box::new(err)
+                    );
+                    error!("Error while assembling feed: {}", e);
+                    return Err(e);
+                }
+            }
+        }
+        return Ok(Some(feed));
+    }
+
+    async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        if !conn
+            .hexists::<&str, &str, bool>("posts", url)
+            .await
+            .unwrap()
+        {
+            return Err(StorageError::new(
+                ErrorKind::NotFound,
+                "can't edit a non-existent post",
+            ));
+        }
+        let post: serde_json::Value =
+            serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?;
+        if let Some(new_url) = post["see_other"].as_str() {
+            url = new_url
+        }
+        Ok(SCRIPTS
+            .edit_post
+            .key("posts")
+            .arg(url)
+            .arg(update.to_string())
+            .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection)
+            .await?)
+    }
+}
+
+impl RedisStorage {
+    /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data.
+    pub async fn new(redis_uri: String) -> Result<Self> {
+        match redis::Client::open(redis_uri) {
+            Ok(client) => Ok(Self {
+                redis: Pool::builder()
+                    .max_open(20)
+                    .max_idle(5)
+                    .get_timeout(Some(Duration::from_secs(3)))
+                    .max_lifetime(Some(Duration::from_secs(120)))
+                    .build(RedisConnectionManager::new(client)),
+            }),
+            Err(e) => Err(e.into()),
+        }
+    }
+
+    pub async fn conn(&self) -> Result<mobc::Connection<mobc_redis::RedisConnectionManager>> {
+        self.redis.get().await.map_err(|e| StorageError::with_source(
+            ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)
+        ))
+    }
+}
+
+#[cfg(test)]
+pub mod tests {
+    use mobc_redis::redis;
+    use std::process;
+    use std::time::Duration;
+
+    pub struct RedisInstance {
+        // We just need to hold on to it so it won't get dropped and remove the socket
+        _tempdir: tempdir::TempDir,
+        uri: String,
+        child: std::process::Child,
+    }
+    impl Drop for RedisInstance {
+        fn drop(&mut self) {
+            self.child.kill().expect("Failed to kill the child!");
+        }
+    }
+    impl RedisInstance {
+        pub fn uri(&self) -> &str {
+            &self.uri
+        }
+    }
+
+    pub async fn get_redis_instance() -> RedisInstance {
+        let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir");
+        let socket = tempdir.path().join("redis.sock");
+        let redis_child = process::Command::new("redis-server")
+            .current_dir(&tempdir)
+            .arg("--port")
+            .arg("0")
+            .arg("--unixsocket")
+            .arg(&socket)
+            .stdout(process::Stdio::null())
+            .stderr(process::Stdio::null())
+            .spawn()
+            .expect("Failed to spawn Redis");
+        println!("redis+unix:///{}", socket.to_str().unwrap());
+        let uri = format!("redis+unix:///{}", socket.to_str().unwrap());
+        // There should be a slight delay, we need to wait for Redis to spin up
+        let client = redis::Client::open(uri.clone()).unwrap();
+        let millisecond = Duration::from_millis(1);
+        let mut retries: usize = 0;
+        const MAX_RETRIES: usize = 60 * 1000/*ms*/;
+        while let Err(err) = client.get_connection() {
+            if err.is_connection_refusal() {
+                async_std::task::sleep(millisecond).await;
+                retries += 1;
+                if retries > MAX_RETRIES {
+                    panic!("Timeout waiting for Redis, last error: {}", err);
+                }
+            } else {
+                panic!("Could not connect: {}", err);
+            }
+        }
+
+        RedisInstance {
+            uri,
+            child: redis_child,
+            _tempdir: tempdir,
+        }
+    }
+}