about summary refs log tree commit diff
path: root/kittybox-rs/src/database/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/database/mod.rs')
-rw-r--r--kittybox-rs/src/database/mod.rs539
1 files changed, 539 insertions, 0 deletions
diff --git a/kittybox-rs/src/database/mod.rs b/kittybox-rs/src/database/mod.rs
new file mode 100644
index 0000000..6bf5409
--- /dev/null
+++ b/kittybox-rs/src/database/mod.rs
@@ -0,0 +1,539 @@
+#![warn(missing_docs)]
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+
+mod file;
+pub use crate::database::file::FileStorage;
+#[cfg(test)]
+mod memory;
+#[cfg(test)]
+pub use crate::database::memory::MemoryStorage;
+
+pub use kittybox_util::MicropubChannel;
+
+/// Enum representing different errors that might occur during the database query.
+#[derive(Debug, Clone, Copy)]
+pub enum ErrorKind {
+    /// Backend error (e.g. database connection error)
+    Backend,
+    /// Error due to insufficient contextual permissions for the query
+    PermissionDenied,
+    /// Error due to the database being unable to parse JSON returned from the backing storage.
+    /// Usually indicative of someone fiddling with the database manually instead of using proper tools.
+    JsonParsing,
+    ///  - ErrorKind::NotFound - equivalent to a 404 error. Note, some requests return an Option,
+    ///    in which case None is also equivalent to a 404.
+    NotFound,
+    /// The user's query or request to the database was malformed. Used whenever the database processes
+    /// the user's query directly, such as when editing posts inside of the database (e.g. Redis backend)
+    BadRequest,
+    /// the user's query collided with an in-flight request and needs to be retried
+    Conflict,
+    ///  - ErrorKind::Other - when something so weird happens that it becomes undescribable.
+    Other,
+}
+
+/// Enum representing settings that might be stored in the site's database.
+#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
+#[serde(rename_all = "snake_case")]
+pub enum Settings {
+    /// The name of the website -- displayed in the header and the browser titlebar.
+    SiteName,
+}
+
+impl std::string::ToString for Settings {
+    fn to_string(&self) -> String {
+        serde_variant::to_variant_name(self).unwrap().to_string()
+    }
+}
+
+/// Error signalled from the database.
+#[derive(Debug)]
+pub struct StorageError {
+    msg: String,
+    source: Option<Box<dyn std::error::Error + Send + Sync>>,
+    kind: ErrorKind,
+}
+
+impl warp::reject::Reject for StorageError {}
+
+impl std::error::Error for StorageError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.source
+            .as_ref()
+            .map(|e| e.as_ref() as &dyn std::error::Error)
+    }
+}
+impl From<serde_json::Error> for StorageError {
+    fn from(err: serde_json::Error) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::JsonParsing,
+        }
+    }
+}
+impl std::fmt::Display for StorageError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match match self.kind {
+            ErrorKind::Backend => write!(f, "backend error: "),
+            ErrorKind::JsonParsing => write!(f, "error while parsing JSON: "),
+            ErrorKind::PermissionDenied => write!(f, "permission denied: "),
+            ErrorKind::NotFound => write!(f, "not found: "),
+            ErrorKind::BadRequest => write!(f, "bad request: "),
+            ErrorKind::Conflict => write!(f, "conflict with an in-flight request or existing data: "),
+            ErrorKind::Other => write!(f, "generic storage layer error: "),
+        } {
+            Ok(_) => write!(f, "{}", self.msg),
+            Err(err) => Err(err),
+        }
+    }
+}
+impl serde::Serialize for StorageError {
+    fn serialize<S: serde::Serializer>(
+        &self,
+        serializer: S,
+    ) -> std::result::Result<S::Ok, S::Error> {
+        serializer.serialize_str(&self.to_string())
+    }
+}
+impl StorageError {
+    /// Create a new StorageError of an ErrorKind with a message.
+    fn new(kind: ErrorKind, msg: &str) -> Self {
+        Self {
+            msg: msg.to_string(),
+            source: None,
+            kind,
+        }
+    }
+    /// Create a StorageError using another arbitrary Error as a source.
+    fn with_source(
+        kind: ErrorKind,
+        msg: &str,
+        source: Box<dyn std::error::Error + Send + Sync>,
+    ) -> Self {
+        Self {
+            msg: msg.to_string(),
+            source: Some(source),
+            kind,
+        }
+    }
+    /// Get the kind of an error.
+    pub fn kind(&self) -> ErrorKind {
+        self.kind
+    }
+    /// Get the message as a string slice.
+    pub fn msg(&self) -> &str {
+        &self.msg
+    }
+}
+
+/// A special Result type for the Micropub backing storage.
+pub type Result<T> = std::result::Result<T, StorageError>;
+
+/// Filter the post according to the value of `user`.
+///
+/// Anonymous users cannot view private posts and protected locations;
+/// Logged-in users can only view private posts targeted at them;
+/// Logged-in users can't view private location data
+pub fn filter_post(
+    mut post: serde_json::Value,
+    user: &'_ Option<String>,
+) -> Option<serde_json::Value> {
+    if post["properties"]["deleted"][0].is_string() {
+        return Some(serde_json::json!({
+            "type": post["type"],
+            "properties": {
+                "deleted": post["properties"]["deleted"]
+            }
+        }));
+    }
+    let empty_vec: Vec<serde_json::Value> = vec![];
+    let author = post["properties"]["author"]
+        .as_array()
+        .unwrap_or(&empty_vec)
+        .iter()
+        .map(|i| i.as_str().unwrap().to_string());
+    let visibility = post["properties"]["visibility"][0]
+        .as_str()
+        .unwrap_or("public");
+    let mut audience = author.chain(
+        post["properties"]["audience"]
+            .as_array()
+            .unwrap_or(&empty_vec)
+            .iter()
+            .map(|i| i.as_str().unwrap().to_string()),
+    );
+    if (visibility == "private" && !audience.any(|i| Some(i) == *user))
+        || (visibility == "protected" && user.is_none())
+    {
+        return None;
+    }
+    if post["properties"]["location"].is_array() {
+        let location_visibility = post["properties"]["location-visibility"][0]
+            .as_str()
+            .unwrap_or("private");
+        let mut author = post["properties"]["author"]
+            .as_array()
+            .unwrap_or(&empty_vec)
+            .iter()
+            .map(|i| i.as_str().unwrap().to_string());
+        if (location_visibility == "private" && !author.any(|i| Some(i) == *user))
+            || (location_visibility == "protected" && user.is_none())
+        {
+            post["properties"]
+                .as_object_mut()
+                .unwrap()
+                .remove("location");
+        }
+    }
+    Some(post)
+}
+
+/// A storage backend for the Micropub server.
+///
+/// Implementations should note that all methods listed on this trait MUST be fully atomic
+/// or lock the database so that write conflicts or reading half-written data should not occur.
+#[async_trait]
+pub trait Storage: std::fmt::Debug + Clone + Send + Sync {
+    /// Check if a post exists in the database.
+    async fn post_exists(&self, url: &str) -> Result<bool>;
+
+    /// Load a post from the database in MF2-JSON format, deserialized from JSON.
+    async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>>;
+
+    /// Save a post to the database as an MF2-JSON structure.
+    ///
+    /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined.
+    async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()>;
+
+    /// Modify a post using an update object as defined in the Micropub spec.
+    ///
+    /// Note to implementors: the update operation MUST be atomic and
+    /// SHOULD lock the database to prevent two clients overwriting
+    /// each other's changes or simply corrupting something. Rejecting
+    /// is allowed in case of concurrent updates if waiting for a lock
+    /// cannot be done.
+    async fn update_post(&self, url: &'_ str, update: serde_json::Value) -> Result<()>;
+
+    /// Get a list of channels available for the user represented by the URL `user` to write to.
+    async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>>;
+
+    /// Fetch a feed at `url` and return a an h-feed object containing
+    /// `limit` posts after a post by url `after`, filtering the content
+    /// in context of a user specified by `user` (or an anonymous user).
+    ///
+    /// Specifically, private posts that don't include the user in the audience
+    /// will be elided from the feed, and the posts containing location and not
+    /// specifying post["properties"]["location-visibility"][0] == "public"
+    /// will have their location data (but not check-in data) stripped.
+    ///
+    /// This function is used as an optimization so the client, whatever it is,
+    /// doesn't have to fetch posts, then realize some of them are private, and
+    /// fetch more posts.
+    ///
+    /// Note for implementors: if you use streams to fetch posts in parallel
+    /// from the database, preferably make this method use a connection pool
+    /// to reduce overhead of creating a database connection per post for
+    /// parallel fetching.
+    async fn read_feed_with_limit(
+        &self,
+        url: &'_ str,
+        after: &'_ Option<String>,
+        limit: usize,
+        user: &'_ Option<String>,
+    ) -> Result<Option<serde_json::Value>>;
+
+    /// Deletes a post from the database irreversibly. 'nuff said. Must be idempotent.
+    async fn delete_post(&self, url: &'_ str) -> Result<()>;
+
+    /// Gets a setting from the setting store and passes the result.
+    async fn get_setting(&self, setting: Settings, user: &'_ str) -> Result<String>;
+
+    /// Commits a setting to the setting store.
+    async fn set_setting(&self, setting: Settings, user: &'_ str, value: &'_ str) -> Result<()>;
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{MicropubChannel, Storage};
+    use serde_json::json;
+
+    async fn test_basic_operations<Backend: Storage>(backend: Backend) {
+        let post: serde_json::Value = json!({
+            "type": ["h-entry"],
+            "properties": {
+                "content": ["Test content"],
+                "author": ["https://fireburn.ru/"],
+                "uid": ["https://fireburn.ru/posts/hello"],
+                "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"]
+            }
+        });
+        let key = post["properties"]["uid"][0].as_str().unwrap().to_string();
+        let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string();
+
+        // Reading and writing
+        backend
+            .put_post(&post, "https://fireburn.ru/")
+            .await
+            .unwrap();
+        if let Some(returned_post) = backend.get_post(&key).await.unwrap() {
+            assert!(returned_post.is_object());
+            assert_eq!(
+                returned_post["type"].as_array().unwrap().len(),
+                post["type"].as_array().unwrap().len()
+            );
+            assert_eq!(
+                returned_post["type"].as_array().unwrap(),
+                post["type"].as_array().unwrap()
+            );
+            let props: &serde_json::Map<String, serde_json::Value> =
+                post["properties"].as_object().unwrap();
+            for key in props.keys() {
+                assert_eq!(
+                    returned_post["properties"][key].as_array().unwrap(),
+                    post["properties"][key].as_array().unwrap()
+                )
+            }
+        } else {
+            panic!("For some reason the backend did not return the post.")
+        }
+        // Check the alternative URL - it should return the same post
+        if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await {
+            assert!(returned_post.is_object());
+            assert_eq!(
+                returned_post["type"].as_array().unwrap().len(),
+                post["type"].as_array().unwrap().len()
+            );
+            assert_eq!(
+                returned_post["type"].as_array().unwrap(),
+                post["type"].as_array().unwrap()
+            );
+            let props: &serde_json::Map<String, serde_json::Value> =
+                post["properties"].as_object().unwrap();
+            for key in props.keys() {
+                assert_eq!(
+                    returned_post["properties"][key].as_array().unwrap(),
+                    post["properties"][key].as_array().unwrap()
+                )
+            }
+        } else {
+            panic!("For some reason the backend did not return the post.")
+        }
+    }
+
+    /// Note: this is merely a smoke check and is in no way comprehensive.
+    // TODO updates for feeds must update children using special logic
+    async fn test_update<Backend: Storage>(backend: Backend) {
+        let post: serde_json::Value = json!({
+            "type": ["h-entry"],
+            "properties": {
+                "content": ["Test content"],
+                "author": ["https://fireburn.ru/"],
+                "uid": ["https://fireburn.ru/posts/hello"],
+                "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"]
+            }
+        });
+        let key = post["properties"]["uid"][0].as_str().unwrap().to_string();
+
+        // Reading and writing
+        backend
+            .put_post(&post, "https://fireburn.ru/")
+            .await
+            .unwrap();
+
+        backend
+            .update_post(
+                &key,
+                json!({
+                    "url": &key,
+                    "add": {
+                        "category": ["testing"],
+                    },
+                    "replace": {
+                        "content": ["Different test content"]
+                    }
+                }),
+            )
+            .await
+            .unwrap();
+
+        match backend.get_post(&key).await {
+            Ok(Some(returned_post)) => {
+                assert!(returned_post.is_object());
+                assert_eq!(
+                    returned_post["type"].as_array().unwrap().len(),
+                    post["type"].as_array().unwrap().len()
+                );
+                assert_eq!(
+                    returned_post["type"].as_array().unwrap(),
+                    post["type"].as_array().unwrap()
+                );
+                assert_eq!(
+                    returned_post["properties"]["content"][0].as_str().unwrap(),
+                    "Different test content"
+                );
+                assert_eq!(
+                    returned_post["properties"]["category"].as_array().unwrap(),
+                    &vec![json!("testing")]
+                );
+            },
+            something_else => {
+                something_else.expect("Shouldn't error").expect("Should have the post");
+            }
+        }
+    }
+
+    async fn test_get_channel_list<Backend: Storage>(backend: Backend) {
+        let feed = json!({
+            "type": ["h-feed"],
+            "properties": {
+                "name": ["Main Page"],
+                "author": ["https://fireburn.ru/"],
+                "uid": ["https://fireburn.ru/feeds/main"]
+            },
+            "children": []
+        });
+        backend
+            .put_post(&feed, "https://fireburn.ru/")
+            .await
+            .unwrap();
+        let chans = backend.get_channels("https://fireburn.ru/").await.unwrap();
+        assert_eq!(chans.len(), 1);
+        assert_eq!(
+            chans[0],
+            MicropubChannel {
+                uid: "https://fireburn.ru/feeds/main".to_string(),
+                name: "Main Page".to_string()
+            }
+        );
+    }
+
+    async fn test_settings<Backend: Storage>(backend: Backend) {
+        backend
+            .set_setting(crate::database::Settings::SiteName, "https://fireburn.ru/", "Vika's Hideout")
+            .await
+            .unwrap();
+        assert_eq!(
+            backend
+                .get_setting(crate::database::Settings::SiteName, "https://fireburn.ru/")
+                .await
+                .unwrap(),
+            "Vika's Hideout"
+        );
+    }
+
+    fn gen_random_post(domain: &str) -> serde_json::Value {
+        use faker_rand::lorem::{Paragraphs, Word};
+
+        let uid = format!(
+            "https://{domain}/posts/{}-{}-{}",
+            rand::random::<Word>(), rand::random::<Word>(), rand::random::<Word>()
+        );
+
+        let post = json!({
+            "type": ["h-entry"],
+            "properties": {
+                "content": [rand::random::<Paragraphs>().to_string()],
+                "uid": [&uid],
+                "url": [&uid]
+            }
+        });
+
+        post
+    }
+
+    async fn test_feed_pagination<Backend: Storage>(backend: Backend) {
+        let posts = std::iter::from_fn(|| Some(gen_random_post("fireburn.ru")))
+            .take(20)
+            .collect::<Vec<serde_json::Value>>();
+
+        let feed = json!({
+            "type": ["h-feed"],
+            "properties": {
+                "name": ["Main Page"],
+                "author": ["https://fireburn.ru/"],
+                "uid": ["https://fireburn.ru/feeds/main"]
+            },
+            "children": posts.iter()
+                .filter_map(|json| json["properties"]["uid"][0].as_str())
+                .collect::<Vec<&str>>()
+        });
+        let key = feed["properties"]["uid"][0].as_str().unwrap();
+
+        backend
+            .put_post(&feed, "https://fireburn.ru/")
+            .await
+            .unwrap();
+        println!("---");
+        for (i, post) in posts.iter().enumerate() {
+            backend.put_post(post, "https://fireburn.ru/").await.unwrap();
+            println!("posts[{}] = {}", i, post["properties"]["uid"][0]);
+        }
+        println!("---");
+        let limit: usize = 10;
+        let result = backend.read_feed_with_limit(key, &None, limit, &None)
+            .await
+            .unwrap()
+            .unwrap();
+        for (i, post) in result["children"].as_array().unwrap().iter().enumerate() {
+            println!("feed[0][{}] = {}", i, post["properties"]["uid"][0]);
+        }
+        println!("---");
+        assert_eq!(result["children"].as_array().unwrap()[0..10], posts[0..10]);
+
+        let result2 = backend.read_feed_with_limit(
+            key,
+            &result["children"]
+                .as_array()
+                .unwrap()
+                .last()
+                .unwrap()
+                ["properties"]["uid"][0]
+                .as_str()
+                .map(|i| i.to_owned()),
+            limit, &None
+        ).await.unwrap().unwrap();
+        for (i, post) in result2["children"].as_array().unwrap().iter().enumerate() {
+            println!("feed[1][{}] = {}", i, post["properties"]["uid"][0]);
+        }
+        println!("---");
+        assert_eq!(result2["children"].as_array().unwrap()[0..10], posts[10..20]);
+
+        // Regression test for #4
+        let nonsense_after = Some("1010101010".to_owned());
+        let result3 = tokio::time::timeout(tokio::time::Duration::from_secs(10), async move {
+            backend.read_feed_with_limit(
+                key, &nonsense_after, limit, &None
+            ).await.unwrap().unwrap()
+        }).await.expect("Operation should not hang: see https://gitlab.com/kittybox/kittybox/-/issues/4");
+        assert!(result3["children"].as_array().unwrap().is_empty());
+    }
+
+    /// Automatically generates a test suite for
+    macro_rules! test_all {
+        ($func_name:ident, $mod_name:ident) => {
+            mod $mod_name {
+                $func_name!(test_basic_operations);
+                $func_name!(test_get_channel_list);
+                $func_name!(test_settings);
+                $func_name!(test_update);
+                $func_name!(test_feed_pagination);
+            }
+        }
+    }
+    macro_rules! file_test {
+        ($func_name:ident) => {
+            #[tokio::test]
+            async fn $func_name () {
+                test_logger::ensure_env_logger_initialized();
+                let tempdir = tempdir::TempDir::new("file").expect("Failed to create tempdir");
+                let backend = super::super::FileStorage::new(tempdir.into_path()).await.unwrap();
+                super::$func_name(backend).await
+            }
+        };
+    }
+
+    test_all!(file_test, file);
+
+}