about summary refs log tree commit diff
path: root/src/database
diff options
context:
space:
mode:
Diffstat (limited to 'src/database')
-rw-r--r--src/database/file/mod.rs173
-rw-r--r--src/database/mod.rs67
2 files changed, 210 insertions, 30 deletions
diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs
new file mode 100644
index 0000000..36300fc
--- /dev/null
+++ b/src/database/file/mod.rs
@@ -0,0 +1,173 @@
+//pub mod async_file_ext;
+use async_std::fs::{File, OpenOptions};
+use async_std::io::{ErrorKind as IOErrorKind, BufReader};
+use async_std::io::prelude::*;
+use async_std::task::spawn_blocking;
+use async_trait::async_trait;
+use crate::database::{ErrorKind, Result, Storage, StorageError};
+use fd_lock::RwLock;
+use log::debug;
+use std::path::{Path, PathBuf};
+
+impl From<std::io::Error> for StorageError {
+    fn from(source: std::io::Error) -> Self {
+        Self::with_source(
+            match source.kind() {
+                IOErrorKind::NotFound => ErrorKind::NotFound,
+                _ => ErrorKind::Backend,
+            },
+            "file I/O error",
+            Box::new(source),
+        )
+    }
+}
+
+async fn get_lockable_file(file: File) -> RwLock<File> {
+    debug!("Trying to create a file lock");
+    spawn_blocking(move || RwLock::new(file)).await
+}
+
+fn url_to_path(root: &Path, url: &str) -> PathBuf {
+    let url = http_types::Url::parse(url).expect("Couldn't parse a URL");
+    let mut path: PathBuf = root.to_owned();
+    path.push(url.origin().ascii_serialization() + &url.path().to_string() + ".json");
+
+    path
+}
+
+#[derive(Clone)]
+pub struct FileStorage {
+    root_dir: PathBuf,
+}
+
+impl FileStorage {
+    pub async fn new(root_dir: PathBuf) -> Result<Self> {
+        // TODO check if the dir is writable
+        Ok(Self { root_dir })
+    }
+}
+
+#[async_trait]
+impl Storage for FileStorage {
+    async fn post_exists(&self, url: &str) -> Result<bool> {
+        let path = url_to_path(&self.root_dir, url);
+        debug!("Checking if {:?} exists...", path);
+        Ok(spawn_blocking(move || path.is_file()).await)
+    }
+
+    async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
+        let path = url_to_path(&self.root_dir, url);
+        debug!("Opening {:?}", path);
+        // We have to special-case in here because the function should return Ok(None) on 404
+        match File::open(path).await {
+            Ok(f) => {
+                let lock = get_lockable_file(f).await;
+                let guard = lock.read()?;
+
+                // HOW DOES THIS TYPECHECK?!!!!!!!!
+                // Read::read(&mut self) requires a mutable reference
+                // yet Read is implemented for &File
+                // We can't get a &mut File from &File, can we?
+                // And we need a &mut File to use Read::read_to_string()
+                // Yet if we pass it to a BufReader it works?!!
+                //
+                // I hate magic
+                //
+                // TODO find a way to get rid of BufReader here
+                let mut content = String::new();
+                let mut reader = BufReader::new(&*guard);
+                reader.read_to_string(&mut content).await?;
+                drop(reader);
+                drop(guard);
+                Ok(Some(serde_json::from_str(&content)?))
+            }
+            Err(err) => {
+                if err.kind() == IOErrorKind::NotFound {
+                    Ok(None)
+                } else {
+                    Err(err.into())
+                }
+            }
+        }
+    }
+
+    async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> {
+        let key = post["properties"]["uid"][0]
+        .as_str()
+        .expect("Tried to save a post without UID");
+        let path = url_to_path(&self.root_dir, key);
+
+        debug!("Creating {:?}", path);
+
+        let parent = path.parent().unwrap().to_owned();
+        if !spawn_blocking(move || parent.is_dir()).await {
+            async_std::fs::create_dir_all(path.parent().unwrap()).await?;
+        }
+
+        let f = OpenOptions::new()
+            .write(true)
+            .create_new(true)
+            .open(&path)
+            .await?;
+        
+        let mut lock = get_lockable_file(f).await;
+        let mut guard = lock.write()?;
+
+        (*guard).write(post.to_string().as_bytes()).await?;
+        drop(guard);
+
+        if post["properties"]["url"].is_array() {
+            for url in post["properties"]["url"]
+                .as_array()
+                .unwrap()
+                .iter()
+                .map(|i| i.as_str().unwrap())
+            {
+                // TODO consider using the symlink crate
+                // to promote cross-platform compat on Windows
+                // do we even need to support Windows?...
+                if url != key && url.starts_with(user) {
+                    let link = url_to_path(&self.root_dir, url);
+                    debug!("Creating a symlink at {:?}", link);
+                    let orig = path.clone();
+                    spawn_blocking(move || { std::os::unix::fs::symlink(orig, link) }).await?;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()> {
+        todo!()
+    }
+
+    async fn get_channels(
+        &self,
+        user: &crate::indieauth::User,
+    ) -> Result<Vec<super::MicropubChannel>> {
+        todo!()
+    }
+
+    async fn read_feed_with_limit<'a>(
+        &self,
+        url: &'a str,
+        after: &'a Option<String>,
+        limit: usize,
+        user: &'a Option<String>,
+    ) -> Result<Option<serde_json::Value>> {
+        todo!()
+    }
+
+    async fn delete_post<'a>(&self, url: &'a str) -> Result<()> {
+        todo!()
+    }
+
+    async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> {
+        todo!()
+    }
+
+    async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> {
+        todo!()
+    }
+}
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 7b144f8..58f0a35 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -7,6 +7,8 @@ mod redis;
 pub use crate::database::redis::RedisStorage;
 #[cfg(test)]
 pub use redis::tests::{get_redis_instance, RedisInstance};
+mod file;
+pub use crate::database::file::FileStorage;
 
 #[derive(Serialize, Deserialize, PartialEq, Debug)]
 pub struct MicropubChannel {
@@ -133,12 +135,6 @@ pub trait Storage: Clone + Send + Sync {
     /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined.
     async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()>;
 
-    /*/// Save a post and add it to the relevant feeds listed in `post["properties"]["channel"]`.
-    ///
-    /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined
-    /// and `post["properties"]["channel"]` defined, even if it's empty.
-    async fn put_and_index_post<'a>(&mut self, post: &'a serde_json::Value) -> Result<()>;*/
-
     /// Modify a post using an update object as defined in the Micropub spec.
     ///
     /// Note to implementors: the update operation MUST be atomic OR MUST lock the database
@@ -191,6 +187,7 @@ mod tests {
     use super::redis::tests::get_redis_instance;
     use super::{MicropubChannel, Storage};
     use serde_json::json;
+    use paste::paste;
 
     async fn test_backend_basic_operations<Backend: Storage>(backend: Backend) {
         let post: serde_json::Value = json!({
@@ -210,7 +207,7 @@ mod tests {
             .put_post(&post, "https://fireburn.ru/")
             .await
             .unwrap();
-        if let Ok(Some(returned_post)) = backend.get_post(&key).await {
+        if let Some(returned_post) = backend.get_post(&key).await.unwrap() {
             assert!(returned_post.is_object());
             assert_eq!(
                 returned_post["type"].as_array().unwrap().len(),
@@ -300,30 +297,40 @@ mod tests {
             "Vika's Hideout"
         );
     }
-
-    #[async_std::test]
-    async fn test_redis_storage_basic_operations() {
-        let redis_instance = get_redis_instance().await;
-        let backend = super::RedisStorage::new(redis_instance.uri().to_string())
-            .await
-            .unwrap();
-        test_backend_basic_operations(backend).await;
-    }
-    #[async_std::test]
-    async fn test_redis_storage_channel_list() {
-        let redis_instance = get_redis_instance().await;
-        let backend = super::RedisStorage::new(redis_instance.uri().to_string())
-            .await
-            .unwrap();
-        test_backend_get_channel_list(backend).await;
+    macro_rules! redis_test {
+        ($func_name:expr) => {
+            paste! {
+                #[async_std::test]
+                async fn [<redis_ $func_name>] () {
+                    test_logger::ensure_env_logger_initialized();
+                    let redis_instance = get_redis_instance().await;
+                    let backend = super::RedisStorage::new(redis_instance.uri().to_string())
+                        .await
+                        .unwrap();
+                    $func_name(backend).await
+                }
+            }
+        }
     }
 
-    #[async_std::test]
-    async fn test_redis_settings() {
-        let redis_instance = get_redis_instance().await;
-        let backend = super::RedisStorage::new(redis_instance.uri().to_string())
-            .await
-            .unwrap();
-        test_backend_settings(backend).await;
+    macro_rules! file_test {
+        ($func_name:expr) => {
+            paste! {
+                #[async_std::test]
+                async fn [<file_ $func_name>] () {
+                    test_logger::ensure_env_logger_initialized();
+                    let tempdir = tempdir::TempDir::new("file").expect("Failed to create tempdir");
+                    let backend = super::FileStorage::new(tempdir.into_path()).await.unwrap();
+                    $func_name(backend).await
+                }
+            }
+        }
     }
+
+    redis_test!(test_backend_basic_operations);
+    redis_test!(test_backend_get_channel_list);
+    redis_test!(test_backend_settings);
+    file_test!(test_backend_basic_operations);
+    file_test!(test_backend_get_channel_list);
+    file_test!(test_backend_settings);
 }