From 7024cbefb27e1c9649bff57df32b316484de4104 Mon Sep 17 00:00:00 2001 From: Vika Date: Tue, 4 May 2021 19:24:51 +0300 Subject: Refactored the database module and its tests --- Cargo.lock | 123 ++++++-- Cargo.toml | 2 +- flake.nix | 3 +- src/database.rs | 625 --------------------------------------- src/database/memory.rs | 107 +++++++ src/database/mod.rs | 263 ++++++++++++++++ src/database/redis/edit_post.lua | 93 ++++++ src/database/redis/mod.rs | 304 +++++++++++++++++++ src/edit_post.lua | 93 ------ src/lib.rs | 100 +++---- src/micropub/post.rs | 1 - 11 files changed, 907 insertions(+), 807 deletions(-) delete mode 100644 src/database.rs create mode 100644 src/database/memory.rs create mode 100644 src/database/mod.rs create mode 100644 src/database/redis/edit_post.lua create mode 100644 src/database/redis/mod.rs delete mode 100644 src/edit_post.lua diff --git a/Cargo.lock b/Cargo.lock index 207bf6f..bd3dbb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,9 +56,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "0.7.15" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7404febffaa47dac81aa44dba71523c9d069b1bdc50a77db41195149e17f68e5" +checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" dependencies = [ "memchr", ] @@ -573,9 +573,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7e9d99fa91428effe99c5c6d4634cdeba32b8cf784fc428a2a687f61a952c49" +checksum = "4feb231f0d4d6af81aed15928e58ecf5816aa62a2393e2c82f46973e92a9a278" dependencies = [ "autocfg", "cfg-if 1.0.0", @@ -834,6 +834,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fuchsia-cprng" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" + [[package]] name = "futf" version = "0.1.4" @@ -1102,9 +1108,9 @@ dependencies = [ [[package]] name = "http-types" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f600cccfb9d96c45550bac47b592bc88191a0dd965e9d55848880c2c5a45f" +checksum = "ad077d89137cd3debdce53c66714dc536525ef43fe075d41ddc0a8ac11f85957" dependencies = [ "anyhow", "async-channel", @@ -1222,6 +1228,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "surf", + "tempdir", "tide", "tide-testing", "url", @@ -1256,9 +1263,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.93" +version = "0.2.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9385f66bf6105b241aa65a61cb923ef20efc665cb9f9bb50ac2f0c4b7f378d41" +checksum = "18794a8ad5b29321f790b55d93dfba91e125cb1a9edbd4f8e3150acc771c1a5e" [[package]] name = "libnghttp2-sys" @@ -1272,9 +1279,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "602113192b08db8f38796c4e85c39e960c145965140e918018bcde1952429655" +checksum = "de5435b8549c16d423ed0c03dbaafe57cf6c3344744f1242520d59c9d8ecec66" dependencies = [ "cc", "libc", @@ -1284,9 +1291,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3c91c24eae6777794bb1997ad98bbb87daf92890acab859f7eaa4320333176" +checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" dependencies = [ "scopeguard", ] @@ -1340,9 +1347,9 @@ checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" [[package]] name = "memchr" -version = "2.3.4" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" +checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" [[package]] name = "mime" @@ -1501,9 +1508,9 @@ checksum = "77af24da69f9d9341038eba93a073b1fdaaa1b788221b00a69bce9e762cb32de" [[package]] name = "openssl-sys" -version = "0.9.61" +version = "0.9.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "313752393519e876837e09e1fa183ddef0be7735868dced3196f4472d536277f" +checksum = "fa52160d45fa2e7608d504b7c3a3355afed615e6d8b627a74458634ba21b69bd" dependencies = [ "autocfg", "cc", @@ -1730,6 +1737,19 @@ dependencies = [ "scheduled-thread-pool", ] +[[package]] +name = "rand" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293" +dependencies = [ + "fuchsia-cprng", + "libc", + "rand_core 0.3.1", + "rdrand", + "winapi", +] + [[package]] name = "rand" version = "0.7.3" @@ -1776,6 +1796,21 @@ dependencies = [ "rand_core 0.6.2", ] +[[package]] +name = "rand_core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b" +dependencies = [ + "rand_core 0.4.2", +] + +[[package]] +name = "rand_core" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" + [[package]] name = "rand_core" version = "0.5.1" @@ -1821,6 +1856,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rdrand" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2" +dependencies = [ + "rand_core 0.3.1", +] + [[package]] name = "redis" version = "0.19.0" @@ -1865,18 +1909,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8270314b5ccceb518e7e578952f0b72b88222d02e8f77f5ecf7abbb673539041" +checksum = "742739e41cd49414de871ea5e549afb7e2a3ac77b589bcbebe8c82fab37147fc" dependencies = [ "bitflags", ] [[package]] name = "regex" -version = "1.4.6" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a26af418b574bd56588335b3a3659a65725d4e636eb1016c2f9e3b38c7cc759" +checksum = "ce5f1ceb7f74abbce32601642fcf8e8508a8a8991e0621c7d750295b9095702b" dependencies = [ "aho-corasick", "memchr", @@ -1885,9 +1929,18 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.23" +version = "0.6.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" + +[[package]] +name = "remove_dir_all" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5f089152e60f62d28b835fbff2cd2e8dc0baf1ac13343bef92ab7eed84548" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi", +] [[package]] name = "route-recognizer" @@ -2255,15 +2308,25 @@ checksum = "45f6ee7c7b87caf59549e9fe45d6a69c75c8019e79e212a835c5da0e92f0ba08" [[package]] name = "syn" -version = "1.0.70" +version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9505f307c872bab8eb46f77ae357c8eba1fdacead58ee5a850116b1d7f82883" +checksum = "a1e8cdbefb79a9a5a65e0db8b47b723ee907b7c7f8496c76a1770b5c310bab82" dependencies = [ "proc-macro2", "quote", "unicode-xid", ] +[[package]] +name = "tempdir" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f2b5fb00ccdf689e0149d1b1b3c03fead81c2b37735d812fa8bddbbf41b6d8" +dependencies = [ + "rand 0.4.6", + "remove_dir_all", +] + [[package]] name = "tendril" version = "0.4.2" @@ -2440,9 +2503,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ebdc2bb4498ab1ab5f5b73c5803825e60199229ccba0698170e3be0e7f959f" +checksum = "09adeb8c97449311ccd28a427f96fb563e7fd31aabf994189879d9da2394b89d" dependencies = [ "cfg-if 1.0.0", "log", @@ -2464,9 +2527,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f50de3927f93d202783f4513cda820ab47ef17f624b03c096e86ef00c67e6b5f" +checksum = "a9ff14f98b1a4b289c6248a023c1c2fa1491062964e9fed67ab29c4e4da4a052" dependencies = [ "lazy_static", ] @@ -2516,9 +2579,9 @@ dependencies = [ [[package]] name = "unicode-xid" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" +checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" [[package]] name = "universal-hash" diff --git a/Cargo.toml b/Cargo.toml index 66d3526..e38cffd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ edition = "2018" [dev-dependencies] tide-testing = "0.1.3" # tide testing helper mockito = "0.30.0" # HTTP mocking for Rust. -serde_urlencoded = "0.7.0" +tempdir = "0.3.7" # A library for managing a temporary directory and deleting all contents when it's dropped. [dependencies] # Redis driver for Rust. diff --git a/flake.nix b/flake.nix index 9eb296e..007218a 100644 --- a/flake.nix +++ b/flake.nix @@ -28,7 +28,7 @@ src = ./.; #cargoSha256 = nixpkgs.lib.fakeSha256; - cargoHash = "sha256-DnPwAuzXWLJdyDf0dvfXUic0oQPVCpsnCyFdLSlsSs0="; + cargoHash = "sha256-c862J7qXkrCAA6+gFn/bX5NXrgu/1HtCH1DqOpQePgk="; buildInputs = [ openssl ]; nativeBuildInputs = [ pkg-config ]; @@ -60,6 +60,7 @@ pkg-config lld rust-bin.default rust-bin.rls + redis ]; }; }); diff --git a/src/database.rs b/src/database.rs deleted file mode 100644 index 3a9ac04..0000000 --- a/src/database.rs +++ /dev/null @@ -1,625 +0,0 @@ -#![allow(unused_variables)] -use async_trait::async_trait; -use lazy_static::lazy_static; -#[cfg(test)] -use std::collections::HashMap; -#[cfg(test)] -use std::sync::Arc; -#[cfg(test)] -use async_std::sync::RwLock; -use log::error; -use futures::stream; -use futures_util::FutureExt; -use futures_util::StreamExt; -use serde::{Serialize,Deserialize}; -use serde_json::json; -use redis; -use redis::AsyncCommands; - -use crate::indieauth::User; - -#[derive(Serialize, Deserialize, PartialEq, Debug)] -pub struct MicropubChannel { - pub uid: String, - pub name: String -} - -#[derive(Debug, Clone, Copy)] -pub enum ErrorKind { - Backend, - PermissionDenied, - JSONParsing, - NotFound, - Other -} - -// TODO get rid of your own errors and use a crate? -#[derive(Debug)] -pub struct StorageError { - msg: String, - source: Option>, - kind: ErrorKind -} -unsafe impl Send for StorageError {} -unsafe impl Sync for StorageError {} -impl std::error::Error for StorageError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref()) - } -} -impl From for StorageError { - fn from(err: redis::RedisError) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::Backend - } - } -} -impl From for StorageError { - fn from(err: serde_json::Error) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::JSONParsing - } - } -} -impl std::fmt::Display for StorageError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match match self.kind { - ErrorKind::Backend => write!(f, "backend error: "), - ErrorKind::JSONParsing => write!(f, "error while parsing JSON: "), - ErrorKind::PermissionDenied => write!(f, "permission denied: "), - ErrorKind::NotFound => write!(f, "not found: "), - ErrorKind::Other => write!(f, "generic storage layer error: ") - } { - Ok(_) => write!(f, "{}", self.msg), - Err(err) => Err(err) - } - } -} -impl serde::Serialize for StorageError { - fn serialize(&self, serializer: S) -> std::result::Result { - serializer.serialize_str(&self.to_string()) - } -} -impl StorageError { - fn new(kind: ErrorKind, msg: &str) -> Self { - return StorageError { - msg: msg.to_string(), - source: None, - kind - } - } - pub fn kind(&self) -> ErrorKind { - self.kind - } -} - -pub type Result = std::result::Result; - -/// A storage backend for the Micropub server. -/// -/// Implementations should note that all methods listed on this trait MUST be fully atomic -/// or lock the database so that write conflicts or reading half-written data should not occur. -#[async_trait] -pub trait Storage: Clone + Send + Sync { - /// Check if a post exists in the database. - async fn post_exists(&self, url: &str) -> Result; - - /// Load a post from the database in MF2-JSON format, deserialized from JSON. - async fn get_post(&self, url: &str) -> Result>; - - /// Save a post to the database as an MF2-JSON structure. - /// - /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. - async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()>; - - /*/// Save a post and add it to the relevant feeds listed in `post["properties"]["channel"]`. - /// - /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined - /// and `post["properties"]["channel"]` defined, even if it's empty. - async fn put_and_index_post<'a>(&mut self, post: &'a serde_json::Value) -> Result<()>;*/ - - /// Modify a post using an update object as defined in the Micropub spec. - /// - /// Note to implementors: the update operation MUST be atomic OR MUST lock the database - /// to prevent two clients overwriting each other's changes. - /// - /// You can assume concurrent updates will never contradict each other, since that will be dumb. - /// The last update always wins. - async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()>; - - /// Get a list of channels available for the user represented by the `user` object to write. - async fn get_channels(&self, user: &User) -> Result>; - - /// Fetch a feed at `url` and return a an h-feed object containing - /// `limit` posts after a post by url `after`, filtering the content - /// in context of a user specified by `user` (or an anonymous user). - /// - /// Specifically, private posts that don't include the user in the audience - /// will be elided from the feed, and the posts containing location and not - /// specifying post["properties"]["location-visibility"][0] == "public" - /// will have their location data (but not check-in data) stripped. - /// - /// This function is used as an optimization so the client, whatever it is, - /// doesn't have to fetch posts, then realize some of them are private, and - /// fetch more posts. - /// - /// Note for implementors: if you use streams to fetch posts in parallel - /// from the database, preferably make this method use a connection pool - /// to reduce overhead of creating a database connection per post for - /// parallel fetching. - async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result>; - - /// Deletes a post from the database irreversibly. 'nuff said. Must be idempotent. - async fn delete_post<'a>(&self, url: &'a str) -> Result<()>; -} - -#[cfg(test)] -#[derive(Clone)] -pub struct MemoryStorage { - pub mapping: Arc>>, - pub channels: Arc>>> -} - -#[cfg(test)] -#[async_trait] -impl Storage for MemoryStorage { - async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { - todo!() - } - - async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()> { - todo!() - } - - async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { - self.mapping.write().await.remove(url); - Ok(()) - } - - async fn post_exists(&self, url: &str) -> Result { - return Ok(self.mapping.read().await.contains_key(url)) - } - - async fn get_post(&self, url: &str) ->Result> { - let mapping = self.mapping.read().await; - match mapping.get(url) { - Some(val) => { - if let Some(new_url) = val["see_other"].as_str() { - match mapping.get(new_url) { - Some(val) => Ok(Some(val.clone())), - None => { - drop(mapping); - self.mapping.write().await.remove(url); - Ok(None) - } - } - } else { - Ok(Some(val.clone())) - } - }, - _ => Ok(None) - } - } - - async fn get_channels(&self, user: &User) -> Result> { - match self.channels.read().await.get(&user.me.to_string()) { - Some(channels) => Ok(futures_util::future::join_all(channels.iter() - .map(|channel| self.get_post(channel) - .map(|result| result.unwrap()) - .map(|post: Option| { - if let Some(post) = post { - Some(MicropubChannel { - uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), - name: post["properties"]["name"][0].as_str().unwrap().to_string() - }) - } else { None } - }) - ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()), - None => Ok(vec![]) - } - - } - - async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { - let mapping = &mut self.mapping.write().await; - let key: &str; - match post["properties"]["uid"][0].as_str() { - Some(uid) => key = uid, - None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) - } - mapping.insert(key.to_string(), post.clone()); - if post["properties"]["url"].is_array() { - for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { - if &url != key { - mapping.insert(url, json!({"see_other": key})); - } - } - } - if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { - // This is a feed. Add it to the channels array if it's not already there. - println!("{:#}", post); - self.channels.write().await.entry(post["properties"]["author"][0].as_str().unwrap().to_string()).or_insert(vec![]).push(key.to_string()) - } - Ok(()) - } -} -#[cfg(test)] -impl MemoryStorage { - pub fn new() -> Self { - Self { - mapping: Arc::new(RwLock::new(HashMap::new())), - channels: Arc::new(RwLock::new(HashMap::new())) - } - } -} - -struct RedisScripts { - edit_post: redis::Script -} - -lazy_static! { - static ref SCRIPTS: RedisScripts = RedisScripts { - edit_post: redis::Script::new(include_str!("./edit_post.lua")) - }; -} - -#[derive(Clone)] -pub struct RedisStorage { - // TODO: use mobc crate to create a connection pool and reuse connections for efficiency - redis: redis::Client, -} - -fn filter_post<'a>(mut post: serde_json::Value, user: &'a Option) -> Option { - if post["properties"]["deleted"][0].is_string() { - return Some(json!({ - "type": post["type"], - "properties": { - "deleted": post["properties"]["deleted"] - } - })); - } - let empty_vec: Vec = vec![]; - let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); - let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); - let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); - if visibility == "private" { - if !audience.any(|i| Some(i) == *user) { - return None - } - } - if post["properties"]["location"].is_array() { - let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); - let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); - if location_visibility == "private" && !author.any(|i| Some(i) == *user) { - post["properties"].as_object_mut().unwrap().remove("location"); - } - } - Some(post) -} - -#[async_trait] -impl Storage for RedisStorage { - async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => if let Err(err) = conn.hdel::<&str, &str, bool>("posts", url).await { - return Err(err.into()); - }, - Err(err) => return Err(err.into()) - } - Ok(()) - } - - async fn post_exists(&self, url: &str) -> Result { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => match conn.hexists::<&str, &str, bool>(&"posts", url).await { - Ok(val) => Ok(val), - Err(err) => Err(err.into()) - }, - Err(err) => Err(err.into()) - } - } - - async fn get_post(&self, url: &str) -> Result> { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => match conn.hget::<&str, &str, Option>(&"posts", url).await { - Ok(val) => match val { - Some(val) => match serde_json::from_str::(&val) { - Ok(parsed) => if let Some(new_url) = parsed["see_other"].as_str() { - match conn.hget::<&str, &str, Option>(&"posts", new_url).await { - Ok(val) => match val { - Some(val) => match serde_json::from_str::(&val) { - Ok(parsed) => Ok(Some(parsed)), - Err(err) => Err(err.into()) - }, - None => Ok(None) - } - Err(err) => { - Ok(None) - } - } - } else { - Ok(Some(parsed)) - }, - Err(err) => Err(err.into()) - }, - None => Ok(None) - }, - Err(err) => Err(err.into()) - }, - Err(err) => Err(err.into()) - } - } - - async fn get_channels(&self, user: &User) -> Result> { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => match conn.smembers::>("channels_".to_string() + user.me.as_str()).await { - Ok(channels) => { - Ok(futures_util::future::join_all(channels.iter() - .map(|channel| self.get_post(channel) - .map(|result| result.unwrap()) - .map(|post: Option| { - if let Some(post) = post { - Some(MicropubChannel { - uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), - name: post["properties"]["name"][0].as_str().unwrap().to_string() - }) - } else { None } - }) - ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()) - }, - Err(err) => Err(err.into()) - }, - Err(err) => Err(err.into()) - } - } - - async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => { - let key: &str; - match post["properties"]["uid"][0].as_str() { - Some(uid) => key = uid, - None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) - } - match conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await { - Err(err) => return Err(err.into()), - _ => {} - } - if post["properties"]["url"].is_array() { - for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { - if &url != key { - match conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await { - Err(err) => return Err(err.into()), - _ => {} - } - } - } - } - if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { - // This is a feed. Add it to the channels array if it's not already there. - match conn.sadd::("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await { - Err(err) => return Err(err.into()), - _ => {}, - } - } - Ok(()) - }, - Err(err) => Err(err.into()) - } - } - - async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => { - let mut feed; - match conn.hget::<&str, &str, Option>(&"posts", url).await { - Ok(post) => { - match post { - Some(post) => match serde_json::from_str::(&post) { - Ok(post) => feed = post, - Err(err) => return Err(err.into()) - }, - None => return Ok(None) - } - }, - Err(err) => return Err(err.into()) - } - if feed["see_other"].is_string() { - match conn.hget::<&str, &str, Option>(&"posts", feed["see_other"].as_str().unwrap()).await { - Ok(post) => { - match post { - Some(post) => match serde_json::from_str::(&post) { - Ok(post) => feed = post, - Err(err) => return Err(err.into()) - }, - None => return Ok(None) - } - }, - Err(err) => return Err(err.into()) - } - } - if let Some(post) = filter_post(feed, user) { - feed = post - } else { - return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) - } - if feed["children"].is_array() { - let children = feed["children"].as_array().unwrap(); - let posts_iter: Box + Send>; - if let Some(after) = after { - posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); - } else { - posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); - } - let posts = stream::iter(posts_iter) - .map(|url| async move { - // Is it rational to use a new connection for every post fetched? - match self.redis.get_async_std_connection().await { - Ok(mut conn) => match conn.hget::<&str, &str, Option>("posts", &url).await { - Ok(post) => match post { - Some(post) => match serde_json::from_str::(&post) { - Ok(post) => Some(post), - Err(err) => { - let err = StorageError::from(err); - error!("{}", err); - panic!("{}", err) - } - }, - // Happens because of a broken link (result of an improper deletion?) - None => None, - }, - Err(err) => { - let err = StorageError::from(err); - error!("{}", err); - panic!("{}", err) - } - }, - Err(err) => { - let err = StorageError::from(err); - error!("{}", err); - panic!("{}", err) - } - } - }) - // TODO: determine the optimal value for this buffer - // It will probably depend on how often can you encounter a private post on the page - // It shouldn't be too large, or we'll start fetching too many posts from the database - // It MUST NOT be larger than the typical page size - .buffered(std::cmp::min(3, limit)) - // Hack to unwrap the Option and sieve out broken links - // Broken links return None, and Stream::filter_map skips all Nones. - .filter_map(|post: Option| async move { post }) - .filter_map(|post| async move { - return filter_post(post, user) - }) - .take(limit); - match std::panic::AssertUnwindSafe(posts.collect::>()).catch_unwind().await { - Ok(posts) => feed["children"] = json!(posts), - Err(err) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) - } - } - return Ok(Some(feed)); - } - Err(err) => Err(err.into()) - } - } - - async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { - match self.redis.get_async_std_connection().await { - Ok(mut conn) => { - if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { - return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) - } - let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await.unwrap()).unwrap(); - if let Some(new_url) = post["see_other"].as_str() { - url = new_url - } - if let Err(err) = SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await { - return Err(err.into()) - } - }, - Err(err) => return Err(err.into()) - } - Ok(()) - } -} - - -impl RedisStorage { - /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. - pub async fn new(redis_uri: String) -> Result { - match redis::Client::open(redis_uri) { - Ok(client) => Ok(Self { redis: client }), - Err(e) => Err(e.into()) - } - } -} - -#[cfg(test)] -mod tests { - use super::{Storage, MicropubChannel}; - use serde_json::json; - - async fn test_backend_basic_operations(backend: Backend) { - let post: serde_json::Value = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Test content"], - "author": ["https://fireburn.ru/"], - "uid": ["https://fireburn.ru/posts/hello"], - "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] - } - }); - let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); - let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); - - // Reading and writing - backend.put_post(&post).await.unwrap(); - if let Ok(Some(returned_post)) = backend.get_post(&key).await { - assert!(returned_post.is_object()); - assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); - assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); - let props: &serde_json::Map = post["properties"].as_object().unwrap(); - for key in props.keys() { - assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) - } - } else { panic!("For some reason the backend did not return the post.") } - // Check the alternative URL - it should return the same post - if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { - assert!(returned_post.is_object()); - assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); - assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); - let props: &serde_json::Map = post["properties"].as_object().unwrap(); - for key in props.keys() { - assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) - } - } else { panic!("For some reason the backend did not return the post.") } - } - - async fn test_backend_channel_support(backend: Backend) { - let feed = json!({ - "type": ["h-feed"], - "properties": { - "name": ["Main Page"], - "author": ["https://fireburn.ru/"], - "uid": ["https://fireburn.ru/feeds/main"] - }, - "children": [] - }); - let key = feed["properties"]["uid"][0].as_str().unwrap().to_string(); - backend.put_post(&feed).await.unwrap(); - let chans = backend.get_channels(&crate::indieauth::User::new("https://fireburn.ru/", "https://quill.p3k.io/", "create update media")).await.unwrap(); - assert_eq!(chans.len(), 1); - assert_eq!(chans[0], MicropubChannel { uid: "https://fireburn.ru/feeds/main".to_string(), name: "Main Page".to_string() }); - } - - #[async_std::test] - async fn test_memory_storage_basic_operations() { - let backend = super::MemoryStorage::new(); - test_backend_basic_operations(backend).await - } - #[async_std::test] - async fn test_memory_storage_channel_support() { - let backend = super::MemoryStorage::new(); - test_backend_channel_support(backend).await - } - - #[async_std::test] - #[ignore] - async fn test_redis_storage_basic_operations() { - let backend = super::RedisStorage::new(std::env::var("REDIS_URI").unwrap_or("redis://127.0.0.1:6379/0".to_string())).await.unwrap(); - redis::cmd("FLUSHDB").query_async::<_, ()>(&mut backend.redis.get_async_std_connection().await.unwrap()).await.unwrap(); - test_backend_basic_operations(backend).await - } - #[async_std::test] - #[ignore] - async fn test_redis_storage_channel_support() { - let backend = super::RedisStorage::new(std::env::var("REDIS_URI").unwrap_or("redis://127.0.0.1:6379/1".to_string())).await.unwrap(); - redis::cmd("FLUSHDB").query_async::<_, ()>(&mut backend.redis.get_async_std_connection().await.unwrap()).await.unwrap(); - test_backend_channel_support(backend).await - } -} diff --git a/src/database/memory.rs b/src/database/memory.rs new file mode 100644 index 0000000..a4cf5a9 --- /dev/null +++ b/src/database/memory.rs @@ -0,0 +1,107 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use async_std::sync::RwLock; +use futures_util::FutureExt; +use serde_json::json; + +use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; +use crate::indieauth::User; + +#[derive(Clone)] +pub struct MemoryStorage { + pub mapping: Arc>>, + pub channels: Arc>>> +} + +#[async_trait] +impl Storage for MemoryStorage { + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { + todo!() + } + + async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()> { + todo!() + } + + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + self.mapping.write().await.remove(url); + Ok(()) + } + + async fn post_exists(&self, url: &str) -> Result { + return Ok(self.mapping.read().await.contains_key(url)) + } + + async fn get_post(&self, url: &str) ->Result> { + let mapping = self.mapping.read().await; + match mapping.get(url) { + Some(val) => { + if let Some(new_url) = val["see_other"].as_str() { + match mapping.get(new_url) { + Some(val) => Ok(Some(val.clone())), + None => { + drop(mapping); + self.mapping.write().await.remove(url); + Ok(None) + } + } + } else { + Ok(Some(val.clone())) + } + }, + _ => Ok(None) + } + } + + async fn get_channels(&self, user: &User) -> Result> { + match self.channels.read().await.get(&user.me.to_string()) { + Some(channels) => Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + } else { None } + }) + ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()), + None => Ok(vec![]) + } + + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { + let mapping = &mut self.mapping.write().await; + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + } + mapping.insert(key.to_string(), post.clone()); + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if &url != key { + mapping.insert(url, json!({"see_other": key})); + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + println!("{:#}", post); + self.channels.write().await.entry(post["properties"]["author"][0].as_str().unwrap().to_string()).or_insert(vec![]).push(key.to_string()) + } + Ok(()) + } +} + +impl MemoryStorage { + pub fn new() -> Self { + Self { + mapping: Arc::new(RwLock::new(HashMap::new())), + channels: Arc::new(RwLock::new(HashMap::new())) + } + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..6abe72c --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,263 @@ +#![warn(missing_docs)] +use async_trait::async_trait; +use serde::{Serialize,Deserialize}; + +#[cfg(test)] +mod memory; +#[cfg(test)] +pub(crate) use crate::database::memory::MemoryStorage; + +use crate::indieauth::User; + +mod redis; +pub use crate::database::redis::RedisStorage; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct MicropubChannel { + pub uid: String, + pub name: String +} + +#[derive(Debug, Clone, Copy)] +pub enum ErrorKind { + Backend, + PermissionDenied, + JSONParsing, + NotFound, + Other +} + +#[derive(Debug)] +pub struct StorageError { + pub msg: String, + source: Option>, + pub kind: ErrorKind +} +unsafe impl Send for StorageError {} +unsafe impl Sync for StorageError {} +impl std::error::Error for StorageError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref()) + } +} +impl From for StorageError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::JSONParsing + } + } +} +impl std::fmt::Display for StorageError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match match self.kind { + ErrorKind::Backend => write!(f, "backend error: "), + ErrorKind::JSONParsing => write!(f, "error while parsing JSON: "), + ErrorKind::PermissionDenied => write!(f, "permission denied: "), + ErrorKind::NotFound => write!(f, "not found: "), + ErrorKind::Other => write!(f, "generic storage layer error: ") + } { + Ok(_) => write!(f, "{}", self.msg), + Err(err) => Err(err) + } + } +} +impl serde::Serialize for StorageError { + fn serialize(&self, serializer: S) -> std::result::Result { + serializer.serialize_str(&self.to_string()) + } +} +impl StorageError { + /// Create a new StorageError of an ErrorKind with a message. + fn new(kind: ErrorKind, msg: &str) -> Self { + return StorageError { + msg: msg.to_string(), + source: None, + kind + } + } + /// Get the kind of an error. + pub fn kind(&self) -> ErrorKind { + self.kind + } +} + + +/// A special Result type for the Micropub backing storage. +pub type Result = std::result::Result; + +/// A storage backend for the Micropub server. +/// +/// Implementations should note that all methods listed on this trait MUST be fully atomic +/// or lock the database so that write conflicts or reading half-written data should not occur. +#[async_trait] +pub trait Storage: Clone + Send + Sync { + /// Check if a post exists in the database. + async fn post_exists(&self, url: &str) -> Result; + + /// Load a post from the database in MF2-JSON format, deserialized from JSON. + async fn get_post(&self, url: &str) -> Result>; + + /// Save a post to the database as an MF2-JSON structure. + /// + /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()>; + + /*/// Save a post and add it to the relevant feeds listed in `post["properties"]["channel"]`. + /// + /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined + /// and `post["properties"]["channel"]` defined, even if it's empty. + async fn put_and_index_post<'a>(&mut self, post: &'a serde_json::Value) -> Result<()>;*/ + + /// Modify a post using an update object as defined in the Micropub spec. + /// + /// Note to implementors: the update operation MUST be atomic OR MUST lock the database + /// to prevent two clients overwriting each other's changes. + /// + /// You can assume concurrent updates will never contradict each other, since that will be dumb. + /// The last update always wins. + async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()>; + + /// Get a list of channels available for the user represented by the `user` object to write. + async fn get_channels(&self, user: &User) -> Result>; + + /// Fetch a feed at `url` and return a an h-feed object containing + /// `limit` posts after a post by url `after`, filtering the content + /// in context of a user specified by `user` (or an anonymous user). + /// + /// Specifically, private posts that don't include the user in the audience + /// will be elided from the feed, and the posts containing location and not + /// specifying post["properties"]["location-visibility"][0] == "public" + /// will have their location data (but not check-in data) stripped. + /// + /// This function is used as an optimization so the client, whatever it is, + /// doesn't have to fetch posts, then realize some of them are private, and + /// fetch more posts. + /// + /// Note for implementors: if you use streams to fetch posts in parallel + /// from the database, preferably make this method use a connection pool + /// to reduce overhead of creating a database connection per post for + /// parallel fetching. + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result>; + + /// Deletes a post from the database irreversibly. 'nuff said. Must be idempotent. + async fn delete_post<'a>(&self, url: &'a str) -> Result<()>; +} + +#[cfg(test)] +mod tests { + use super::{Storage, MicropubChannel}; + use std::{process}; + use std::time::Duration; + use serde_json::json; + + async fn test_backend_basic_operations(backend: Backend) { + let post: serde_json::Value = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Test content"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] + } + }); + let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); + let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); + + // Reading and writing + backend.put_post(&post).await.unwrap(); + if let Ok(Some(returned_post)) = backend.get_post(&key).await { + assert!(returned_post.is_object()); + assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); + assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); + let props: &serde_json::Map = post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) + } + } else { panic!("For some reason the backend did not return the post.") } + // Check the alternative URL - it should return the same post + if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { + assert!(returned_post.is_object()); + assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); + assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); + let props: &serde_json::Map = post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) + } + } else { panic!("For some reason the backend did not return the post.") } + } + + async fn test_backend_get_channel_list(backend: Backend) { + let feed = json!({ + "type": ["h-feed"], + "properties": { + "name": ["Main Page"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/feeds/main"] + }, + "children": [] + }); + backend.put_post(&feed).await.unwrap(); + let chans = backend.get_channels(&crate::indieauth::User::new("https://fireburn.ru/", "https://quill.p3k.io/", "create update media")).await.unwrap(); + assert_eq!(chans.len(), 1); + assert_eq!(chans[0], MicropubChannel { uid: "https://fireburn.ru/feeds/main".to_string(), name: "Main Page".to_string() }); + } + + #[async_std::test] + async fn test_memory_storage_basic_operations() { + let backend = super::MemoryStorage::new(); + test_backend_basic_operations(backend).await + } + #[async_std::test] + async fn test_memory_storage_channel_support() { + let backend = super::MemoryStorage::new(); + test_backend_get_channel_list(backend).await + } + + async fn get_redis_backend() -> (tempdir::TempDir, process::Child, super::RedisStorage) { + let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); + let socket = tempdir.path().join("redis.sock"); + let redis_child = process::Command::new("redis-server") + .current_dir(&tempdir) + .arg("--port").arg("0") + .arg("--unixsocket").arg(&socket) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn().expect("Failed to spawn Redis"); + println!("redis+unix:///{}", socket.to_str().unwrap()); + let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); + // There should be a slight delay, we need to wait for Redis to spin up + let client = redis::Client::open(uri.clone()).unwrap(); + let millisecond = Duration::from_millis(1); + let mut retries: usize = 0; + const MAX_RETRIES: usize = 10 * 1000/*ms*/; + while let Err(err) = client.get_connection() { + if err.is_connection_refusal() { + async_std::task::sleep(millisecond).await; + retries += 1; + if retries > MAX_RETRIES { + panic!("Timeout waiting for Redis, last error: {}", err); + } + } else { + panic!("Could not connect: {}", err); + } + } + let backend = super::RedisStorage::new(uri).await.unwrap(); + + return (tempdir, redis_child, backend) + } + + #[async_std::test] + async fn test_redis_storage_basic_operations() { + let (_, mut redis, backend) = get_redis_backend().await; + test_backend_basic_operations(backend).await; + redis.kill().expect("Redis wasn't running"); + } + #[async_std::test] + async fn test_redis_storage_channel_support() { + let (_, mut redis, backend) = get_redis_backend().await; + test_backend_get_channel_list(backend).await; + redis.kill().expect("Redis wasn't running"); + } +} diff --git a/src/database/redis/edit_post.lua b/src/database/redis/edit_post.lua new file mode 100644 index 0000000..a398f8d --- /dev/null +++ b/src/database/redis/edit_post.lua @@ -0,0 +1,93 @@ +local posts = KEYS[1] +local update_desc = cjson.decode(ARGV[2]) +local post = cjson.decode(redis.call("HGET", posts, ARGV[1])) + +local delete_keys = {} +local delete_kvs = {} +local add_keys = {} + +if update_desc.replace ~= nil then + for k, v in pairs(update_desc.replace) do + table.insert(delete_keys, k) + add_keys[k] = v + end +end +if update_desc.delete ~= nil then + if update_desc.delete[0] == nil then + -- Table has string keys. Probably! + for k, v in pairs(update_desc.delete) do + delete_kvs[k] = v + end + else + -- Table has numeric keys. Probably! + for i, v in ipairs(update_desc.delete) do + table.insert(delete_keys, v) + end + end +end +if update_desc.add ~= nil then + for k, v in pairs(update_desc.add) do + add_keys[k] = v + end +end + +for i, v in ipairs(delete_keys) do + post["properties"][v] = nil + -- TODO delete URL links +end + +for k, v in pairs(delete_kvs) do + local index = -1 + if k == "children" then + for j, w in ipairs(post[k]) do + if w == v then + index = j + break + end + end + if index > -1 then + table.remove(post[k], index) + end + else + for j, w in ipairs(post["properties"][k]) do + if w == v then + index = j + break + end + end + if index > -1 then + table.remove(post["properties"][k], index) + -- TODO delete URL links + end + end +end + +for k, v in pairs(add_keys) do + if k == "children" then + if post["children"] == nil then + post["children"] = {} + end + for i, w in ipairs(v) do + table.insert(post["children"], 1, w) + end + else + if post["properties"][k] == nil then + post["properties"][k] = {} + end + for i, w in ipairs(v) do + table.insert(post["properties"][k], w) + end + if k == "url" then + redis.call("HSET", posts, v, cjson.encode({ see_other = post["properties"]["uid"][1] })) + elseif k == "channel" then + local feed = cjson.decode(redis.call("HGET", posts, v)) + table.insert(feed["children"], 1, post["properties"]["uid"][1]) + redis.call("HSET", posts, v, cjson.encode(feed)) + end + end +end + +local encoded = cjson.encode(post) +redis.call("SET", "debug", encoded) +redis.call("HSET", posts, post["properties"]["uid"][1], encoded) +return \ No newline at end of file diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs new file mode 100644 index 0000000..2377fac --- /dev/null +++ b/src/database/redis/mod.rs @@ -0,0 +1,304 @@ +use async_trait::async_trait; +use futures_util::FutureExt; +use futures_util::StreamExt; +use futures::stream; +use lazy_static::lazy_static; +use log::error; +use redis; +use redis::AsyncCommands; +use serde_json::json; + +use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; +use crate::indieauth::User; + +struct RedisScripts { + edit_post: redis::Script +} + +impl From for StorageError { + fn from(err: redis::RedisError) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend + } + } +} + +lazy_static! { + static ref SCRIPTS: RedisScripts = RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + }; +} + +#[derive(Clone)] +pub struct RedisStorage { + // TODO: use mobc crate to create a connection pool and reuse connections for efficiency + redis: redis::Client, +} + +fn filter_post<'a>(mut post: serde_json::Value, user: &'a Option) -> Option { + if post["properties"]["deleted"][0].is_string() { + return Some(json!({ + "type": post["type"], + "properties": { + "deleted": post["properties"]["deleted"] + } + })); + } + let empty_vec: Vec = vec![]; + let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); + let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); + if visibility == "private" { + if !audience.any(|i| Some(i) == *user) { + return None + } + } + if post["properties"]["location"].is_array() { + let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); + let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + if location_visibility == "private" && !author.any(|i| Some(i) == *user) { + post["properties"].as_object_mut().unwrap().remove("location"); + } + } + Some(post) +} + +#[async_trait] +impl Storage for RedisStorage { + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => if let Err(err) = conn.hdel::<&str, &str, bool>("posts", url).await { + return Err(err.into()); + }, + Err(err) => return Err(err.into()) + } + Ok(()) + } + + async fn post_exists(&self, url: &str) -> Result { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hexists::<&str, &str, bool>(&"posts", url).await { + Ok(val) => Ok(val), + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn get_post(&self, url: &str) -> Result> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hget::<&str, &str, Option>(&"posts", url).await { + Ok(val) => match val { + Some(val) => match serde_json::from_str::(&val) { + Ok(parsed) => if let Some(new_url) = parsed["see_other"].as_str() { + match conn.hget::<&str, &str, Option>(&"posts", new_url).await { + Ok(val) => match val { + Some(val) => match serde_json::from_str::(&val) { + Ok(parsed) => Ok(Some(parsed)), + Err(err) => Err(err.into()) + }, + None => Ok(None) + } + Err(err) => { + Err(err.into()) + } + } + } else { + Ok(Some(parsed)) + }, + Err(err) => Err(err.into()) + }, + None => Ok(None) + }, + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn get_channels(&self, user: &User) -> Result> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.smembers::>("channels_".to_string() + user.me.as_str()).await { + Ok(channels) => { + Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + } else { None } + }) + ).collect::>()).await.into_iter().filter_map(|chan| chan).collect::>()) + }, + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + } + match conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await { + Err(err) => return Err(err.into()), + _ => {} + } + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if &url != key { + match conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await { + Err(err) => return Err(err.into()), + _ => {} + } + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + match conn.sadd::("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await { + Err(err) => return Err(err.into()), + _ => {}, + } + } + Ok(()) + }, + Err(err) => Err(err.into()) + } + } + + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + let mut feed; + match conn.hget::<&str, &str, Option>(&"posts", url).await { + Ok(post) => { + match post { + Some(post) => match serde_json::from_str::(&post) { + Ok(post) => feed = post, + Err(err) => return Err(err.into()) + }, + None => return Ok(None) + } + }, + Err(err) => return Err(err.into()) + } + if feed["see_other"].is_string() { + match conn.hget::<&str, &str, Option>(&"posts", feed["see_other"].as_str().unwrap()).await { + Ok(post) => { + match post { + Some(post) => match serde_json::from_str::(&post) { + Ok(post) => feed = post, + Err(err) => return Err(err.into()) + }, + None => return Ok(None) + } + }, + Err(err) => return Err(err.into()) + } + } + if let Some(post) = filter_post(feed, user) { + feed = post + } else { + return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) + } + if feed["children"].is_array() { + let children = feed["children"].as_array().unwrap(); + let posts_iter: Box + Send>; + if let Some(after) = after { + posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); + } else { + posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); + } + let posts = stream::iter(posts_iter) + .map(|url| async move { + // Is it rational to use a new connection for every post fetched? + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hget::<&str, &str, Option>("posts", &url).await { + Ok(post) => match post { + Some(post) => match serde_json::from_str::(&post) { + Ok(post) => Some(post), + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + }, + // Happens because of a broken link (result of an improper deletion?) + None => None, + }, + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + }, + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + } + }) + // TODO: determine the optimal value for this buffer + // It will probably depend on how often can you encounter a private post on the page + // It shouldn't be too large, or we'll start fetching too many posts from the database + // It MUST NOT be larger than the typical page size + .buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips all Nones. + .filter_map(|post: Option| async move { post }) + .filter_map(|post| async move { + return filter_post(post, user) + }) + .take(limit); + match std::panic::AssertUnwindSafe(posts.collect::>()).catch_unwind().await { + Ok(posts) => feed["children"] = json!(posts), + Err(err) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) + } + } + return Ok(Some(feed)); + } + Err(err) => Err(err.into()) + } + } + + async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { + return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) + } + let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await.unwrap()).unwrap(); + if let Some(new_url) = post["see_other"].as_str() { + url = new_url + } + if let Err(err) = SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await { + return Err(err.into()) + } + }, + Err(err) => return Err(err.into()) + } + Ok(()) + } +} + + +impl RedisStorage { + /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. + pub async fn new(redis_uri: String) -> Result { + match redis::Client::open(redis_uri) { + Ok(client) => Ok(Self { redis: client }), + Err(e) => Err(e.into()) + } + } +} \ No newline at end of file diff --git a/src/edit_post.lua b/src/edit_post.lua deleted file mode 100644 index a398f8d..0000000 --- a/src/edit_post.lua +++ /dev/null @@ -1,93 +0,0 @@ -local posts = KEYS[1] -local update_desc = cjson.decode(ARGV[2]) -local post = cjson.decode(redis.call("HGET", posts, ARGV[1])) - -local delete_keys = {} -local delete_kvs = {} -local add_keys = {} - -if update_desc.replace ~= nil then - for k, v in pairs(update_desc.replace) do - table.insert(delete_keys, k) - add_keys[k] = v - end -end -if update_desc.delete ~= nil then - if update_desc.delete[0] == nil then - -- Table has string keys. Probably! - for k, v in pairs(update_desc.delete) do - delete_kvs[k] = v - end - else - -- Table has numeric keys. Probably! - for i, v in ipairs(update_desc.delete) do - table.insert(delete_keys, v) - end - end -end -if update_desc.add ~= nil then - for k, v in pairs(update_desc.add) do - add_keys[k] = v - end -end - -for i, v in ipairs(delete_keys) do - post["properties"][v] = nil - -- TODO delete URL links -end - -for k, v in pairs(delete_kvs) do - local index = -1 - if k == "children" then - for j, w in ipairs(post[k]) do - if w == v then - index = j - break - end - end - if index > -1 then - table.remove(post[k], index) - end - else - for j, w in ipairs(post["properties"][k]) do - if w == v then - index = j - break - end - end - if index > -1 then - table.remove(post["properties"][k], index) - -- TODO delete URL links - end - end -end - -for k, v in pairs(add_keys) do - if k == "children" then - if post["children"] == nil then - post["children"] = {} - end - for i, w in ipairs(v) do - table.insert(post["children"], 1, w) - end - else - if post["properties"][k] == nil then - post["properties"][k] = {} - end - for i, w in ipairs(v) do - table.insert(post["properties"][k], w) - end - if k == "url" then - redis.call("HSET", posts, v, cjson.encode({ see_other = post["properties"]["uid"][1] })) - elseif k == "channel" then - local feed = cjson.decode(redis.call("HGET", posts, v)) - table.insert(feed["children"], 1, post["properties"]["uid"][1]) - redis.call("HSET", posts, v, cjson.encode(feed)) - end - end -end - -local encoded = cjson.encode(post) -redis.call("SET", "debug", encoded) -redis.call("HSET", posts, post["properties"]["uid"][1], encoded) -return \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 459ad23..219aec6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -102,9 +102,19 @@ mod tests { use crate::database::Storage; use mockito::mock; + // Helpers async fn create_app() -> (database::MemoryStorage, App) { get_app_with_memory_for_testing(surf::Url::parse(&*mockito::server_url()).unwrap()).await } + + async fn post_json(app: &App, json: serde_json::Value) -> surf::Response { + let request = app.post("/micropub") + .header("Authorization", "Bearer test") + .header("Content-Type", "application/json") + .body(json); + return request.send().await.unwrap(); + } + #[async_std::test] async fn test_no_posting_to_others_websites() { let _m = mock("GET", "/") @@ -115,43 +125,31 @@ mod tests { let (db, app) = create_app().await; - let request: surf::RequestBuilder = app.post("/micropub") - .header("Authorization", "Bearer test") - .header("Content-Type", "application/json") - .body(json!({ - "type": ["h-entry"], - "properties": { - "content": ["Fake news about Aaron Parecki!"], - "uid": ["https://aaronparecki.com/posts/fake-news"] - } - })); - let response = request.send().await.unwrap(); + let response = post_json(&app, json!({ + "type": ["h-entry"], + "properties": { + "content": ["Fake news about Aaron Parecki!"], + "uid": ["https://aaronparecki.com/posts/fake-news"] + } + })).await; assert_eq!(response.status(), 403); - let request: surf::RequestBuilder = app.post("/micropub") - .header("Authorization", "Bearer test") - .header("Content-Type", "application/json") - .body(json!({ - "type": ["h-entry"], - "properties": { - "content": ["More fake news about Aaron Parecki!"], - "url": ["https://aaronparecki.com/posts/more-fake-news"] - } - })); - let response = request.send().await.unwrap(); + let response = post_json(&app, json!({ + "type": ["h-entry"], + "properties": { + "content": ["More fake news about Aaron Parecki!"], + "url": ["https://aaronparecki.com/posts/more-fake-news"] + } + })).await; assert_eq!(response.status(), 403); - let request: surf::RequestBuilder = app.post("/micropub") - .header("Authorization", "Bearer test") - .header("Content-Type", "application/json") - .body(json!({ - "type": ["h-entry"], - "properties": { - "content": ["Sneaky advertisement designed to creep into someone else's feed! Buy whatever I'm promoting!"], - "channel": ["https://aaronparecki.com/feeds/main"] - } - })); - let response = request.send().await.unwrap(); + let response = post_json(&app, json!({ + "type": ["h-entry"], + "properties": { + "content": ["Sneaky advertisement designed to creep into someone else's feed! Buy whatever I'm promoting!"], + "channel": ["https://aaronparecki.com/feeds/main"] + } + })).await; assert_eq!(response.status(), 403); } @@ -229,16 +227,12 @@ mod tests { let (storage, app) = create_app().await; - let request: surf::RequestBuilder = app.post("/micropub") - .header("Authorization", "Bearer test") - .header("Content-Type", "application/json") - .body(json!({ - "type": ["h-entry"], - "properties": { - "content": ["This is content!"] - } - })); - let mut response: surf::Response = request.send().await.unwrap(); + let mut response = post_json(&app, json!({ + "type": ["h-entry"], + "properties": { + "content": ["This is content!"] + } + })).await; println!("{:#}", response.body_json::().await.unwrap()); assert!(response.status() == 201 || response.status() == 202); let uid = response.header("Location").unwrap().last().to_string(); @@ -248,20 +242,14 @@ mod tests { let feed = storage.get_post("https://fireburn.ru/feeds/main").await.unwrap().unwrap(); assert_eq!(feed["children"].as_array().unwrap().len(), 1); assert_eq!(feed["children"][0].as_str().unwrap(), uid); - - let request: surf::RequestBuilder = app.post("/micropub") - .header("Authorization", "Bearer test") - .header("Content-Type", "application/json") - .body(json!({ - "type": ["h-entry"], - "properties": { - "content": ["#moar content for you!"] - } - })); - let first_uid = uid; - - let mut response: surf::Response = request.send().await.unwrap(); + // Test creation of a second post + let mut response = post_json(&app, json!({ + "type": ["h-entry"], + "properties": { + "content": ["#moar content for you!"] + } + })).await; println!("{:#}", response.body_json::().await.unwrap()); assert!(response.status() == 201 || response.status() == 202); let uid = response.header("Location").unwrap().last().to_string(); diff --git a/src/micropub/post.rs b/src/micropub/post.rs index 38b205b..680f242 100644 --- a/src/micropub/post.rs +++ b/src/micropub/post.rs @@ -1,7 +1,6 @@ use core::iter::Iterator; use std::str::FromStr; use std::convert::TryInto; -use async_std::sync::RwLockUpgradableReadGuard; use chrono::prelude::*; use http_types::Mime; use tide::prelude::json; -- cgit 1.4.1