about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/database/file/mod.rs103
-rw-r--r--src/database/mod.rs54
2 files changed, 154 insertions, 3 deletions
diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs
index 36300fc..ff5ad13 100644
--- a/src/database/file/mod.rs
+++ b/src/database/file/mod.rs
@@ -1,4 +1,3 @@
-//pub mod async_file_ext;
 use async_std::fs::{File, OpenOptions};
 use async_std::io::{ErrorKind as IOErrorKind, BufReader};
 use async_std::io::prelude::*;
@@ -8,6 +7,7 @@ use crate::database::{ErrorKind, Result, Storage, StorageError};
 use fd_lock::RwLock;
 use log::debug;
 use std::path::{Path, PathBuf};
+use std::collections::HashMap;
 
 impl From<std::io::Error> for StorageError {
     fn from(source: std::io::Error) -> Self {
@@ -35,6 +35,79 @@ fn url_to_path(root: &Path, url: &str) -> PathBuf {
     path
 }
 
+fn modify_post(post: &serde_json::Value, update: &serde_json::Value) -> Result<serde_json::Value> {
+    let mut add_keys: HashMap<String, serde_json::Value> = HashMap::new();
+    let mut remove_keys: Vec<String> = vec![];
+    let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new();
+    let mut post = post.clone();
+
+    if let Some(delete) = update["delete"].as_array() {
+        remove_keys.extend(delete.iter().filter_map(|v| v.as_str()).map(|v| v.to_string()));
+    } else if let Some(delete) = update["delete"].as_object() {
+        for (k, v) in delete {
+            if let Some(v) = v.as_array() {
+                remove_values.entry(k.to_string()).or_default().extend(v.clone());
+            } else {
+                return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object"));
+            }
+        }
+    }
+    if let Some(add) = update["add"].as_object() {
+        for (k, v) in add {
+            if v.is_array() {
+                add_keys.insert(k.to_string(), v.clone());
+            } else {
+                return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object"));
+            }
+        }
+    }
+    if let Some(replace) = update["replace"].as_object() {
+        for (k, v) in replace {
+            remove_keys.push(k.to_string());
+            add_keys.insert(k.to_string(), v.clone());
+        }
+    }
+
+    for k in remove_keys {
+        post["properties"].as_object_mut().unwrap().remove(&k);
+    }
+    for (k, v) in remove_values {
+        let k = &k;
+        let props;
+        if k == "children" {
+            props = &mut post;
+        } else {
+            props = &mut post["properties"];
+        }
+        v.iter().for_each(|v| {
+            if let Some(vec) = props[k].as_array_mut() {
+                if let Some(index) = vec.iter().position(|w| w == v) {
+                    vec.remove(index);
+                }
+            }
+        });
+    }
+    for (k, v) in add_keys {
+        let props;
+        if k == "children" {
+            props = &mut post;
+        } else {
+            props = &mut post["properties"];
+        }
+        let k = &k;
+        if let Some(prop) = props[k].as_array_mut() {
+            if k == "children" {
+                v.as_array().unwrap().iter().cloned().rev().for_each(|v| prop.insert(0, v));
+            } else {
+                prop.extend(v.as_array().unwrap().iter().cloned());
+            }
+        } else {
+            post["properties"][k] = v
+        }
+    }
+   Ok(post)
+}
+
 #[derive(Clone)]
 pub struct FileStorage {
     root_dir: PathBuf,
@@ -113,7 +186,8 @@ impl Storage for FileStorage {
         let mut lock = get_lockable_file(f).await;
         let mut guard = lock.write()?;
 
-        (*guard).write(post.to_string().as_bytes()).await?;
+        (*guard).write_all(post.to_string().as_bytes()).await?;
+        (*guard).flush().await?;
         drop(guard);
 
         if post["properties"]["url"].is_array() {
@@ -139,7 +213,30 @@ impl Storage for FileStorage {
     }
 
     async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()> {
-        todo!()
+        let path = url_to_path(&self.root_dir, url);
+        let f = OpenOptions::new()
+            .write(true)
+            .read(true)
+            .truncate(false)
+            .open(&path)
+            .await?;
+
+        let mut lock = get_lockable_file(f).await;
+        let mut guard = lock.write()?;
+
+        let mut content = String::new();
+        guard.read_to_string(&mut content).await?;
+        let json: serde_json::Value = serde_json::from_str(&content)?;
+        // Apply the editing algorithms
+        let new_json = modify_post(&json, &update)?;
+
+        (*guard).set_len(0).await?;
+        (*guard).seek(std::io::SeekFrom::Start(0)).await?;
+        (*guard).write_all(new_json.to_string().as_bytes()).await?;
+        (*guard).flush().await?;
+        drop(guard);
+        // TODO check if URLs changed between old and new JSON
+        Ok(())
     }
 
     async fn get_channels(
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 58f0a35..e6873b0 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -252,6 +252,58 @@ mod tests {
         }
     }
 
+    /// Note: this is merely a smoke check and is in no way comprehensive.
+    async fn test_backend_update<Backend: Storage>(backend: Backend) {
+        let post: serde_json::Value = json!({
+            "type": ["h-entry"],
+            "properties": {
+                "content": ["Test content"],
+                "author": ["https://fireburn.ru/"],
+                "uid": ["https://fireburn.ru/posts/hello"],
+                "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"]
+            }
+        });
+        let key = post["properties"]["uid"][0].as_str().unwrap().to_string();
+
+        // Reading and writing
+        backend
+            .put_post(&post, "https://fireburn.ru/")
+            .await
+            .unwrap();
+
+        backend.update_post(&key, json!({
+            "url": &key,
+            "add": {
+                "category": ["testing"],
+            },
+            "replace": {
+                "content": ["Different test content"]
+            }
+        })).await.unwrap();
+
+        if let Some(returned_post) = backend.get_post(&key).await.unwrap() {
+            assert!(returned_post.is_object());
+            assert_eq!(
+                returned_post["type"].as_array().unwrap().len(),
+                post["type"].as_array().unwrap().len()
+            );
+            assert_eq!(
+                returned_post["type"].as_array().unwrap(),
+                post["type"].as_array().unwrap()
+            );
+            assert_eq!(
+                returned_post["properties"]["content"][0].as_str().unwrap(),
+                "Different test content"
+            );
+            assert_eq!(
+                returned_post["properties"]["category"].as_array().unwrap(),
+                &vec![json!("testing")]
+            );
+        } else {
+            panic!("For some reason the backend did not return the post.")
+        }
+    }
+
     async fn test_backend_get_channel_list<Backend: Storage>(backend: Backend) {
         let feed = json!({
             "type": ["h-feed"],
@@ -330,7 +382,9 @@ mod tests {
     redis_test!(test_backend_basic_operations);
     redis_test!(test_backend_get_channel_list);
     redis_test!(test_backend_settings);
+    redis_test!(test_backend_update);
     file_test!(test_backend_basic_operations);
     file_test!(test_backend_get_channel_list);
     file_test!(test_backend_settings);
+    file_test!(test_backend_update);
 }