about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/database/memory.rs196
-rw-r--r--src/database/mod.rs2
2 files changed, 198 insertions, 0 deletions
diff --git a/src/database/memory.rs b/src/database/memory.rs
new file mode 100644
index 0000000..c83bc8c
--- /dev/null
+++ b/src/database/memory.rs
@@ -0,0 +1,196 @@
+use async_trait::async_trait;
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+use futures_util::FutureExt;
+use serde_json::json;
+
+use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel};
+use crate::indieauth::User;
+
+#[derive(Clone, Debug)]
+pub struct MemoryStorage {
+    pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>,
+    pub channels: Arc<RwLock<HashMap<String, Vec<String>>>>
+}
+
+#[async_trait]
+impl Storage for MemoryStorage {
+    async fn post_exists(&self, url: &str) -> Result<bool> {
+        return Ok(self.mapping.read().await.contains_key(url))
+    }
+
+    async fn get_post(&self, url: &str) ->Result<Option<serde_json::Value>> {
+        let mapping = self.mapping.read().await;
+        match mapping.get(url) {
+            Some(val) => {
+                if let Some(new_url) = val["see_other"].as_str() {
+                    match mapping.get(new_url) {
+                        Some(val) => Ok(Some(val.clone())),
+                        None => {
+                            drop(mapping);
+                            self.mapping.write().await.remove(url);
+                            Ok(None)
+                        }
+                    }
+                } else {
+                    Ok(Some(val.clone()))
+                }
+            },
+            _ => Ok(None)
+        }
+    }
+
+    async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> {
+        let mapping = &mut self.mapping.write().await;
+        let key: &str;
+        match post["properties"]["uid"][0].as_str() {
+            Some(uid) => key = uid,
+            None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID"))
+        }
+        mapping.insert(key.to_string(), post.clone());
+        if post["properties"]["url"].is_array() {
+            for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) {
+                if &url != key {
+                    mapping.insert(url, json!({"see_other": key}));
+                }
+            }
+        }
+        if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") {
+            // This is a feed. Add it to the channels array if it's not already there.
+            println!("{:#}", post);
+            self.channels.write().await.entry(post["properties"]["author"][0].as_str().unwrap().to_string()).or_insert(vec![]).push(key.to_string())
+        }
+        Ok(())
+    }
+
+    async fn update_post(&self, url: &'_ str, update: serde_json::Value) -> Result<()> {
+        let mut add_keys: HashMap<String, serde_json::Value> = HashMap::new();
+        let mut remove_keys: Vec<String> = vec![];
+        let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new();
+
+        if let Some(delete) = update["delete"].as_array() {
+            remove_keys.extend(delete.iter().filter_map(|v| v.as_str()).map(|v| v.to_string()));
+        } else if let Some(delete) = update["delete"].as_object() {
+            for (k, v) in delete {
+                if let Some(v) = v.as_array() {
+                    remove_values.entry(k.to_string()).or_default().extend(v.clone());
+                } else {
+                    return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object"));
+                }
+            }
+        }
+        if let Some(add) = update["add"].as_object() {
+            for (k, v) in add {
+                if v.is_array() {
+                    add_keys.insert(k.to_string(), v.clone());
+                } else {
+                    return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object"));
+                }
+            }
+        }
+        if let Some(replace) = update["replace"].as_object() {
+            for (k, v) in replace {
+                remove_keys.push(k.to_string());
+                add_keys.insert(k.to_string(), v.clone());
+            }
+        }
+        let mut mapping = self.mapping.write().await;
+        if let Some(mut post) = mapping.get(url) {
+            if let Some(url) = post["see_other"].as_str() {
+                if let Some(new_post) = mapping.get(url) {
+                    post = new_post
+                } else {
+                    return Err(StorageError::new(ErrorKind::NotFound, "The post you have requested is not found in the database."));
+                }
+            }
+            let mut post = post.clone();
+            for k in remove_keys {
+                post["properties"].as_object_mut().unwrap().remove(&k);
+            }
+            for (k, v) in remove_values {
+                let k = &k;
+                let props;
+                if k == "children" {
+                    props = &mut post;
+                } else {
+                    props = &mut post["properties"];
+                }
+                v.iter().for_each(|v| {
+                    if let Some(vec) = props[k].as_array_mut() {
+                        if let Some(index) = vec.iter().position(|w| w == v) {
+                            vec.remove(index);
+                        }
+                    }
+                });
+            }
+            for (k, v) in add_keys {
+                let props;
+                if k == "children" {
+                    props = &mut post;
+                } else {
+                    props = &mut post["properties"];
+                }
+                let k = &k;
+                if let Some(prop) = props[k].as_array_mut() {
+                    if k == "children" {
+                        v.as_array().unwrap().iter().cloned().rev().for_each(|v| prop.insert(0, v));
+                    } else {
+                        prop.extend(v.as_array().unwrap().iter().cloned());
+                    }
+                } else {
+                    post["properties"][k] = v
+                }
+            }
+            mapping.insert(post["properties"]["uid"][0].as_str().unwrap().to_string(), post);
+        } else {
+            return Err(StorageError::new(ErrorKind::NotFound, "The designated post wasn't found in the database."));
+        }
+        Ok(())
+    }
+
+    async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> {
+        match self.channels.read().await.get(user) {
+            Some(channels) => Ok(futures_util::future::join_all(channels.iter()
+                .map(|channel| self.get_post(channel)
+                    .map(|result| result.unwrap())
+                    .map(|post: Option<serde_json::Value>| {
+                        if let Some(post) = post {
+                            Some(MicropubChannel {
+                                uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
+                                name: post["properties"]["name"][0].as_str().unwrap().to_string()
+                            })
+                        } else { None }
+                    })
+                ).collect::<Vec<_>>()).await.into_iter().filter_map(|chan| chan).collect::<Vec<_>>()),
+            None => Ok(vec![])
+        }
+        
+    }
+
+    async fn read_feed_with_limit(&self, url: &'_ str, after: &'_ Option<String>, limit: usize, user: &'_ Option<String>) -> Result<Option<serde_json::Value>> {
+        todo!()
+    }
+
+    async fn delete_post(&self, url: &'_ str) -> Result<()> {
+        self.mapping.write().await.remove(url);
+        Ok(())
+    }
+
+    async fn get_setting(&self, setting: &'_ str, user: &'_ str) -> Result<String> {
+        todo!()
+    }
+
+    async fn set_setting(&self, setting: &'_ str, user: &'_ str, value: &'_ str) -> Result<()> {
+        todo!()
+    }
+}
+
+impl MemoryStorage {
+    pub fn new() -> Self {
+        Self {
+            mapping: Arc::new(RwLock::new(HashMap::new())),
+            channels: Arc::new(RwLock::new(HashMap::new()))
+        }
+    }
+}
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 6fdb9b1..836d6c3 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize};
 
 mod file;
 pub use crate::database::file::FileStorage;
+mod memory;
+pub(crate) use crate::database::memory::MemoryStorage;
 
 /// Data structure representing a Micropub channel in the ?q=channels output.
 #[derive(Serialize, Deserialize, PartialEq, Debug)]