diff options
-rw-r--r-- | src/database/memory.rs | 196 | ||||
-rw-r--r-- | src/database/mod.rs | 2 |
2 files changed, 198 insertions, 0 deletions
diff --git a/src/database/memory.rs b/src/database/memory.rs new file mode 100644 index 0000000..c83bc8c --- /dev/null +++ b/src/database/memory.rs @@ -0,0 +1,196 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use futures_util::FutureExt; +use serde_json::json; + +use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel}; +use crate::indieauth::User; + +#[derive(Clone, Debug)] +pub struct MemoryStorage { + pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>, + pub channels: Arc<RwLock<HashMap<String, Vec<String>>>> +} + +#[async_trait] +impl Storage for MemoryStorage { + async fn post_exists(&self, url: &str) -> Result<bool> { + return Ok(self.mapping.read().await.contains_key(url)) + } + + async fn get_post(&self, url: &str) ->Result<Option<serde_json::Value>> { + let mapping = self.mapping.read().await; + match mapping.get(url) { + Some(val) => { + if let Some(new_url) = val["see_other"].as_str() { + match mapping.get(new_url) { + Some(val) => Ok(Some(val.clone())), + None => { + drop(mapping); + self.mapping.write().await.remove(url); + Ok(None) + } + } + } else { + Ok(Some(val.clone())) + } + }, + _ => Ok(None) + } + } + + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + let mapping = &mut self.mapping.write().await; + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + } + mapping.insert(key.to_string(), post.clone()); + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if &url != key { + mapping.insert(url, json!({"see_other": key})); + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + println!("{:#}", post); + self.channels.write().await.entry(post["properties"]["author"][0].as_str().unwrap().to_string()).or_insert(vec![]).push(key.to_string()) + } + Ok(()) + } + + async fn update_post(&self, url: &'_ str, update: serde_json::Value) -> Result<()> { + let mut add_keys: HashMap<String, serde_json::Value> = HashMap::new(); + let mut remove_keys: Vec<String> = vec![]; + let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + + if let Some(delete) = update["delete"].as_array() { + remove_keys.extend(delete.iter().filter_map(|v| v.as_str()).map(|v| v.to_string())); + } else if let Some(delete) = update["delete"].as_object() { + for (k, v) in delete { + if let Some(v) = v.as_array() { + remove_values.entry(k.to_string()).or_default().extend(v.clone()); + } else { + return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object")); + } + } + } + if let Some(add) = update["add"].as_object() { + for (k, v) in add { + if v.is_array() { + add_keys.insert(k.to_string(), v.clone()); + } else { + return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object")); + } + } + } + if let Some(replace) = update["replace"].as_object() { + for (k, v) in replace { + remove_keys.push(k.to_string()); + add_keys.insert(k.to_string(), v.clone()); + } + } + let mut mapping = self.mapping.write().await; + if let Some(mut post) = mapping.get(url) { + if let Some(url) = post["see_other"].as_str() { + if let Some(new_post) = mapping.get(url) { + post = new_post + } else { + return Err(StorageError::new(ErrorKind::NotFound, "The post you have requested is not found in the database.")); + } + } + let mut post = post.clone(); + for k in remove_keys { + post["properties"].as_object_mut().unwrap().remove(&k); + } + for (k, v) in remove_values { + let k = &k; + let props; + if k == "children" { + props = &mut post; + } else { + props = &mut post["properties"]; + } + v.iter().for_each(|v| { + if let Some(vec) = props[k].as_array_mut() { + if let Some(index) = vec.iter().position(|w| w == v) { + vec.remove(index); + } + } + }); + } + for (k, v) in add_keys { + let props; + if k == "children" { + props = &mut post; + } else { + props = &mut post["properties"]; + } + let k = &k; + if let Some(prop) = props[k].as_array_mut() { + if k == "children" { + v.as_array().unwrap().iter().cloned().rev().for_each(|v| prop.insert(0, v)); + } else { + prop.extend(v.as_array().unwrap().iter().cloned()); + } + } else { + post["properties"][k] = v + } + } + mapping.insert(post["properties"]["uid"][0].as_str().unwrap().to_string(), post); + } else { + return Err(StorageError::new(ErrorKind::NotFound, "The designated post wasn't found in the database.")); + } + Ok(()) + } + + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { + match self.channels.read().await.get(user) { + Some(channels) => Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option<serde_json::Value>| { + if let Some(post) = post { + Some(MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + } else { None } + }) + ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>()), + None => Ok(vec![]) + } + + } + + async fn read_feed_with_limit(&self, url: &'_ str, after: &'_ Option<String>, limit: usize, user: &'_ Option<String>) -> Result<Option<serde_json::Value>> { + todo!() + } + + async fn delete_post(&self, url: &'_ str) -> Result<()> { + self.mapping.write().await.remove(url); + Ok(()) + } + + async fn get_setting(&self, setting: &'_ str, user: &'_ str) -> Result<String> { + todo!() + } + + async fn set_setting(&self, setting: &'_ str, user: &'_ str, value: &'_ str) -> Result<()> { + todo!() + } +} + +impl MemoryStorage { + pub fn new() -> Self { + Self { + mapping: Arc::new(RwLock::new(HashMap::new())), + channels: Arc::new(RwLock::new(HashMap::new())) + } + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 6fdb9b1..836d6c3 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize}; mod file; pub use crate::database::file::FileStorage; +mod memory; +pub(crate) use crate::database::memory::MemoryStorage; /// Data structure representing a Micropub channel in the ?q=channels output. #[derive(Serialize, Deserialize, PartialEq, Debug)] |