diff options
Diffstat (limited to 'src/database.rs')
-rw-r--r-- | src/database.rs | 625 |
1 files changed, 625 insertions, 0 deletions
diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..3a9ac04 --- /dev/null +++ b/src/database.rs @@ -0,0 +1,625 @@ +#![allow(unused_variables)] +use async_trait::async_trait; +use lazy_static::lazy_static; +#[cfg(test)] +use std::collections::HashMap; +#[cfg(test)] +use std::sync::Arc; +#[cfg(test)] +use async_std::sync::RwLock; +use log::error; +use futures::stream; +use futures_util::FutureExt; +use futures_util::StreamExt; +use serde::{Serialize,Deserialize}; +use serde_json::json; +use redis; +use redis::AsyncCommands; + +use crate::indieauth::User; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct MicropubChannel { + pub uid: String, + pub name: String +} + +#[derive(Debug, Clone, Copy)] +pub enum ErrorKind { + Backend, + PermissionDenied, + JSONParsing, + NotFound, + Other +} + +// TODO get rid of your own errors and use a crate? +#[derive(Debug)] +pub struct StorageError { + msg: String, + source: Option<Box<dyn std::error::Error>>, + kind: ErrorKind +} +unsafe impl Send for StorageError {} +unsafe impl Sync for StorageError {} +impl std::error::Error for StorageError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref()) + } +} +impl From<redis::RedisError> for StorageError { + fn from(err: redis::RedisError) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend + } + } +} +impl From<serde_json::Error> for StorageError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::JSONParsing + } + } +} +impl std::fmt::Display for StorageError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match match self.kind { + ErrorKind::Backend => write!(f, "backend error: "), + ErrorKind::JSONParsing => write!(f, "error while parsing JSON: "), + ErrorKind::PermissionDenied => write!(f, "permission denied: "), + ErrorKind::NotFound => write!(f, "not found: "), + ErrorKind::Other => write!(f, "generic storage layer error: ") + } { + Ok(_) => write!(f, "{}", self.msg), + Err(err) => Err(err) + } + } +} +impl serde::Serialize for StorageError { + fn serialize<S: serde::Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} +impl StorageError { + fn new(kind: ErrorKind, msg: &str) -> Self { + return StorageError { + msg: msg.to_string(), + source: None, + kind + } + } + pub fn kind(&self) -> ErrorKind { + self.kind + } +} + +pub type Result<T> = std::result::Result<T, StorageError>; + +/// A storage backend for the Micropub server. +/// +/// Implementations should note that all methods listed on this trait MUST be fully atomic +/// or lock the database so that write conflicts or reading half-written data should not occur. +#[async_trait] +pub trait Storage: Clone + Send + Sync { + /// Check if a post exists in the database. + async fn post_exists(&self, url: &str) -> Result<bool>; + + /// Load a post from the database in MF2-JSON format, deserialized from JSON. + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>>; + + /// Save a post to the database as an MF2-JSON structure. + /// + /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()>; + + /*/// Save a post and add it to the relevant feeds listed in `post["properties"]["channel"]`. + /// + /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined + /// and `post["properties"]["channel"]` defined, even if it's empty. + async fn put_and_index_post<'a>(&mut self, post: &'a serde_json::Value) -> Result<()>;*/ + + /// Modify a post using an update object as defined in the Micropub spec. + /// + /// Note to implementors: the update operation MUST be atomic OR MUST lock the database + /// to prevent two clients overwriting each other's changes. + /// + /// You can assume concurrent updates will never contradict each other, since that will be dumb. + /// The last update always wins. + async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()>; + + /// Get a list of channels available for the user represented by the `user` object to write. + async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>>; + + /// Fetch a feed at `url` and return a an h-feed object containing + /// `limit` posts after a post by url `after`, filtering the content + /// in context of a user specified by `user` (or an anonymous user). + /// + /// Specifically, private posts that don't include the user in the audience + /// will be elided from the feed, and the posts containing location and not + /// specifying post["properties"]["location-visibility"][0] == "public" + /// will have their location data (but not check-in data) stripped. + /// + /// This function is used as an optimization so the client, whatever it is, + /// doesn't have to fetch posts, then realize some of them are private, and + /// fetch more posts. + /// + /// Note for implementors: if you use streams to fetch posts in parallel + /// from the database, preferably make this method use a connection pool + /// to reduce overhead of creating a database connection per post for + /// parallel fetching. + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option<String>, limit: usize, user: &'a Option<String>) -> Result<Option<serde_json::Value>>; + + /// Deletes a post from the database irreversibly. 'nuff said. Must be idempotent. + async fn delete_post<'a>(&self, url: &'a str) -> Result<()>; +} + +#[cfg(test)] +#[derive(Clone)] +pub struct MemoryStorage { + pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>, + pub channels: Arc<RwLock<HashMap<String, Vec<String>>>> +} + +#[cfg(test)] +#[async_trait] +impl Storage for MemoryStorage { + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option<String>, limit: usize, user: &'a Option<String>) -> Result<Option<serde_json::Value>> { + todo!() + } + + async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()> { + todo!() + } + + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + self.mapping.write().await.remove(url); + Ok(()) + } + + async fn post_exists(&self, url: &str) -> Result<bool> { + return Ok(self.mapping.read().await.contains_key(url)) + } + + async fn get_post(&self, url: &str) ->Result<Option<serde_json::Value>> { + let mapping = self.mapping.read().await; + match mapping.get(url) { + Some(val) => { + if let Some(new_url) = val["see_other"].as_str() { + match mapping.get(new_url) { + Some(val) => Ok(Some(val.clone())), + None => { + drop(mapping); + self.mapping.write().await.remove(url); + Ok(None) + } + } + } else { + Ok(Some(val.clone())) + } + }, + _ => Ok(None) + } + } + + async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { + match self.channels.read().await.get(&user.me.to_string()) { + Some(channels) => Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option<serde_json::Value>| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + } else { None } + }) + ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>()), + None => Ok(vec![]) + } + + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { + let mapping = &mut self.mapping.write().await; + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + } + mapping.insert(key.to_string(), post.clone()); + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if &url != key { + mapping.insert(url, json!({"see_other": key})); + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + println!("{:#}", post); + self.channels.write().await.entry(post["properties"]["author"][0].as_str().unwrap().to_string()).or_insert(vec![]).push(key.to_string()) + } + Ok(()) + } +} +#[cfg(test)] +impl MemoryStorage { + pub fn new() -> Self { + Self { + mapping: Arc::new(RwLock::new(HashMap::new())), + channels: Arc::new(RwLock::new(HashMap::new())) + } + } +} + +struct RedisScripts { + edit_post: redis::Script +} + +lazy_static! { + static ref SCRIPTS: RedisScripts = RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + }; +} + +#[derive(Clone)] +pub struct RedisStorage { + // TODO: use mobc crate to create a connection pool and reuse connections for efficiency + redis: redis::Client, +} + +fn filter_post<'a>(mut post: serde_json::Value, user: &'a Option<String>) -> Option<serde_json::Value> { + if post["properties"]["deleted"][0].is_string() { + return Some(json!({ + "type": post["type"], + "properties": { + "deleted": post["properties"]["deleted"] + } + })); + } + let empty_vec: Vec<serde_json::Value> = vec![]; + let author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + let visibility = post["properties"]["visibility"][0].as_str().unwrap_or("public"); + let mut audience = author.chain(post["properties"]["audience"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string())); + if visibility == "private" { + if !audience.any(|i| Some(i) == *user) { + return None + } + } + if post["properties"]["location"].is_array() { + let location_visibility = post["properties"]["location-visibility"][0].as_str().unwrap_or("private"); + let mut author = post["properties"]["author"].as_array().unwrap_or(&empty_vec).iter().map(|i| i.as_str().unwrap().to_string()); + if location_visibility == "private" && !author.any(|i| Some(i) == *user) { + post["properties"].as_object_mut().unwrap().remove("location"); + } + } + Some(post) +} + +#[async_trait] +impl Storage for RedisStorage { + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => if let Err(err) = conn.hdel::<&str, &str, bool>("posts", url).await { + return Err(err.into()); + }, + Err(err) => return Err(err.into()) + } + Ok(()) + } + + async fn post_exists(&self, url: &str) -> Result<bool> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hexists::<&str, &str, bool>(&"posts", url).await { + Ok(val) => Ok(val), + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>(&"posts", url).await { + Ok(val) => match val { + Some(val) => match serde_json::from_str::<serde_json::Value>(&val) { + Ok(parsed) => if let Some(new_url) = parsed["see_other"].as_str() { + match conn.hget::<&str, &str, Option<String>>(&"posts", new_url).await { + Ok(val) => match val { + Some(val) => match serde_json::from_str::<serde_json::Value>(&val) { + Ok(parsed) => Ok(Some(parsed)), + Err(err) => Err(err.into()) + }, + None => Ok(None) + } + Err(err) => { + Ok(None) + } + } + } else { + Ok(Some(parsed)) + }, + Err(err) => Err(err.into()) + }, + None => Ok(None) + }, + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()).await { + Ok(channels) => { + Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option<serde_json::Value>| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + } else { None } + }) + ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>()) + }, + Err(err) => Err(err.into()) + }, + Err(err) => Err(err.into()) + } + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + } + match conn.hset::<&str, &str, String, ()>(&"posts", key, post.to_string()).await { + Err(err) => return Err(err.into()), + _ => {} + } + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if &url != key { + match conn.hset::<&str, &str, String, ()>(&"posts", &url, json!({"see_other": key}).to_string()).await { + Err(err) => return Err(err.into()), + _ => {} + } + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + match conn.sadd::<String, &str, ()>("channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), key).await { + Err(err) => return Err(err.into()), + _ => {}, + } + } + Ok(()) + }, + Err(err) => Err(err.into()) + } + } + + async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option<String>, limit: usize, user: &'a Option<String>) -> Result<Option<serde_json::Value>> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + let mut feed; + match conn.hget::<&str, &str, Option<String>>(&"posts", url).await { + Ok(post) => { + match post { + Some(post) => match serde_json::from_str::<serde_json::Value>(&post) { + Ok(post) => feed = post, + Err(err) => return Err(err.into()) + }, + None => return Ok(None) + } + }, + Err(err) => return Err(err.into()) + } + if feed["see_other"].is_string() { + match conn.hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap()).await { + Ok(post) => { + match post { + Some(post) => match serde_json::from_str::<serde_json::Value>(&post) { + Ok(post) => feed = post, + Err(err) => return Err(err.into()) + }, + None => return Ok(None) + } + }, + Err(err) => return Err(err.into()) + } + } + if let Some(post) = filter_post(feed, user) { + feed = post + } else { + return Err(StorageError::new(ErrorKind::PermissionDenied, "specified user cannot access this post")) + } + if feed["children"].is_array() { + let children = feed["children"].as_array().unwrap(); + let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send>; + if let Some(after) = after { + posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()).skip_while(move |i| i != after).skip(1)); + } else { + posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string())); + } + let posts = stream::iter(posts_iter) + .map(|url| async move { + // Is it rational to use a new connection for every post fetched? + match self.redis.get_async_std_connection().await { + Ok(mut conn) => match conn.hget::<&str, &str, Option<String>>("posts", &url).await { + Ok(post) => match post { + Some(post) => match serde_json::from_str::<serde_json::Value>(&post) { + Ok(post) => Some(post), + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + }, + // Happens because of a broken link (result of an improper deletion?) + None => None, + }, + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + }, + Err(err) => { + let err = StorageError::from(err); + error!("{}", err); + panic!("{}", err) + } + } + }) + // TODO: determine the optimal value for this buffer + // It will probably depend on how often can you encounter a private post on the page + // It shouldn't be too large, or we'll start fetching too many posts from the database + // It MUST NOT be larger than the typical page size + .buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips all Nones. + .filter_map(|post: Option<serde_json::Value>| async move { post }) + .filter_map(|post| async move { + return filter_post(post, user) + }) + .take(limit); + match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>()).catch_unwind().await { + Ok(posts) => feed["children"] = json!(posts), + Err(err) => return Err(StorageError::new(ErrorKind::Other, "Unknown error encountered while assembling feed, see logs for more info")) + } + } + return Ok(Some(feed)); + } + Err(err) => Err(err.into()) + } + } + + async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { + match self.redis.get_async_std_connection().await { + Ok(mut conn) => { + if !conn.hexists::<&str, &str, bool>("posts", url).await.unwrap() { + return Err(StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")) + } + let post: serde_json::Value = serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await.unwrap()).unwrap(); + if let Some(new_url) = post["see_other"].as_str() { + url = new_url + } + if let Err(err) = SCRIPTS.edit_post.key("posts").arg(url).arg(update.to_string()).invoke_async::<_, ()>(&mut conn).await { + return Err(err.into()) + } + }, + Err(err) => return Err(err.into()) + } + Ok(()) + } +} + + +impl RedisStorage { + /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. + pub async fn new(redis_uri: String) -> Result<Self> { + match redis::Client::open(redis_uri) { + Ok(client) => Ok(Self { redis: client }), + Err(e) => Err(e.into()) + } + } +} + +#[cfg(test)] +mod tests { + use super::{Storage, MicropubChannel}; + use serde_json::json; + + async fn test_backend_basic_operations<Backend: Storage>(backend: Backend) { + let post: serde_json::Value = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Test content"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] + } + }); + let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); + let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); + + // Reading and writing + backend.put_post(&post).await.unwrap(); + if let Ok(Some(returned_post)) = backend.get_post(&key).await { + assert!(returned_post.is_object()); + assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); + assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); + let props: &serde_json::Map<String, serde_json::Value> = post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) + } + } else { panic!("For some reason the backend did not return the post.") } + // Check the alternative URL - it should return the same post + if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { + assert!(returned_post.is_object()); + assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); + assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); + let props: &serde_json::Map<String, serde_json::Value> = post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) + } + } else { panic!("For some reason the backend did not return the post.") } + } + + async fn test_backend_channel_support<Backend: Storage>(backend: Backend) { + let feed = json!({ + "type": ["h-feed"], + "properties": { + "name": ["Main Page"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/feeds/main"] + }, + "children": [] + }); + let key = feed["properties"]["uid"][0].as_str().unwrap().to_string(); + backend.put_post(&feed).await.unwrap(); + let chans = backend.get_channels(&crate::indieauth::User::new("https://fireburn.ru/", "https://quill.p3k.io/", "create update media")).await.unwrap(); + assert_eq!(chans.len(), 1); + assert_eq!(chans[0], MicropubChannel { uid: "https://fireburn.ru/feeds/main".to_string(), name: "Main Page".to_string() }); + } + + #[async_std::test] + async fn test_memory_storage_basic_operations() { + let backend = super::MemoryStorage::new(); + test_backend_basic_operations(backend).await + } + #[async_std::test] + async fn test_memory_storage_channel_support() { + let backend = super::MemoryStorage::new(); + test_backend_channel_support(backend).await + } + + #[async_std::test] + #[ignore] + async fn test_redis_storage_basic_operations() { + let backend = super::RedisStorage::new(std::env::var("REDIS_URI").unwrap_or("redis://127.0.0.1:6379/0".to_string())).await.unwrap(); + redis::cmd("FLUSHDB").query_async::<_, ()>(&mut backend.redis.get_async_std_connection().await.unwrap()).await.unwrap(); + test_backend_basic_operations(backend).await + } + #[async_std::test] + #[ignore] + async fn test_redis_storage_channel_support() { + let backend = super::RedisStorage::new(std::env::var("REDIS_URI").unwrap_or("redis://127.0.0.1:6379/1".to_string())).await.unwrap(); + redis::cmd("FLUSHDB").query_async::<_, ()>(&mut backend.redis.get_async_std_connection().await.unwrap()).await.unwrap(); + test_backend_channel_support(backend).await + } +} |