#![warn(missing_docs)] use async_trait::async_trait; use serde::{Serialize,Deserialize}; #[cfg(test)] mod memory; #[cfg(test)] pub(crate) use crate::database::memory::MemoryStorage; use crate::indieauth::User; mod redis; pub use crate::database::redis::RedisStorage; #[derive(Serialize, Deserialize, PartialEq, Debug)] pub struct MicropubChannel { pub uid: String, pub name: String } #[derive(Debug, Clone, Copy)] pub enum ErrorKind { Backend, PermissionDenied, JSONParsing, NotFound, BadRequest, Other } #[derive(Debug)] pub struct StorageError { pub msg: String, source: Option>, pub kind: ErrorKind } unsafe impl Send for StorageError {} unsafe impl Sync for StorageError {} impl std::error::Error for StorageError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { self.source.as_ref().map(|e| e.as_ref()) } } impl From for StorageError { fn from(err: serde_json::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::JSONParsing } } } impl std::fmt::Display for StorageError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match match self.kind { ErrorKind::Backend => write!(f, "backend error: "), ErrorKind::JSONParsing => write!(f, "error while parsing JSON: "), ErrorKind::PermissionDenied => write!(f, "permission denied: "), ErrorKind::NotFound => write!(f, "not found: "), ErrorKind::BadRequest => write!(f, "bad request: "), ErrorKind::Other => write!(f, "generic storage layer error: ") } { Ok(_) => write!(f, "{}", self.msg), Err(err) => Err(err) } } } impl serde::Serialize for StorageError { fn serialize(&self, serializer: S) -> std::result::Result { serializer.serialize_str(&self.to_string()) } } impl StorageError { /// Create a new StorageError of an ErrorKind with a message. fn new(kind: ErrorKind, msg: &str) -> Self { return StorageError { msg: msg.to_string(), source: None, kind } } /// Get the kind of an error. pub fn kind(&self) -> ErrorKind { self.kind } } /// A special Result type for the Micropub backing storage. pub type Result = std::result::Result; /// A storage backend for the Micropub server. /// /// Implementations should note that all methods listed on this trait MUST be fully atomic /// or lock the database so that write conflicts or reading half-written data should not occur. #[async_trait] pub trait Storage: Clone + Send + Sync { /// Check if a post exists in the database. async fn post_exists(&self, url: &str) -> Result; /// Load a post from the database in MF2-JSON format, deserialized from JSON. async fn get_post(&self, url: &str) -> Result>; /// Save a post to the database as an MF2-JSON structure. /// /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. async fn put_post<'a>(&self, post: &'a serde_json::Value) -> Result<()>; /*/// Save a post and add it to the relevant feeds listed in `post["properties"]["channel"]`. /// /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined /// and `post["properties"]["channel"]` defined, even if it's empty. async fn put_and_index_post<'a>(&mut self, post: &'a serde_json::Value) -> Result<()>;*/ /// Modify a post using an update object as defined in the Micropub spec. /// /// Note to implementors: the update operation MUST be atomic OR MUST lock the database /// to prevent two clients overwriting each other's changes. /// /// You can assume concurrent updates will never contradict each other, since that will be dumb. /// The last update always wins. async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()>; /// Get a list of channels available for the user represented by the `user` object to write. async fn get_channels(&self, user: &User) -> Result>; /// Fetch a feed at `url` and return a an h-feed object containing /// `limit` posts after a post by url `after`, filtering the content /// in context of a user specified by `user` (or an anonymous user). /// /// Specifically, private posts that don't include the user in the audience /// will be elided from the feed, and the posts containing location and not /// specifying post["properties"]["location-visibility"][0] == "public" /// will have their location data (but not check-in data) stripped. /// /// This function is used as an optimization so the client, whatever it is, /// doesn't have to fetch posts, then realize some of them are private, and /// fetch more posts. /// /// Note for implementors: if you use streams to fetch posts in parallel /// from the database, preferably make this method use a connection pool /// to reduce overhead of creating a database connection per post for /// parallel fetching. async fn read_feed_with_limit<'a>(&self, url: &'a str, after: &'a Option, limit: usize, user: &'a Option) -> Result>; /// Deletes a post from the database irreversibly. 'nuff said. Must be idempotent. async fn delete_post<'a>(&self, url: &'a str) -> Result<()>; } #[cfg(test)] mod tests { use super::{Storage, MicropubChannel}; use std::{process}; use std::time::Duration; use serde_json::json; async fn test_backend_basic_operations(backend: Backend) { let post: serde_json::Value = json!({ "type": ["h-entry"], "properties": { "content": ["Test content"], "author": ["https://fireburn.ru/"], "uid": ["https://fireburn.ru/posts/hello"], "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] } }); let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); // Reading and writing backend.put_post(&post).await.unwrap(); if let Ok(Some(returned_post)) = backend.get_post(&key).await { assert!(returned_post.is_object()); assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); let props: &serde_json::Map = post["properties"].as_object().unwrap(); for key in props.keys() { assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) } } else { panic!("For some reason the backend did not return the post.") } // Check the alternative URL - it should return the same post if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { assert!(returned_post.is_object()); assert_eq!(returned_post["type"].as_array().unwrap().len(), post["type"].as_array().unwrap().len()); assert_eq!(returned_post["type"].as_array().unwrap(), post["type"].as_array().unwrap()); let props: &serde_json::Map = post["properties"].as_object().unwrap(); for key in props.keys() { assert_eq!(returned_post["properties"][key].as_array().unwrap(), post["properties"][key].as_array().unwrap()) } } else { panic!("For some reason the backend did not return the post.") } } async fn test_backend_get_channel_list(backend: Backend) { let feed = json!({ "type": ["h-feed"], "properties": { "name": ["Main Page"], "author": ["https://fireburn.ru/"], "uid": ["https://fireburn.ru/feeds/main"] }, "children": [] }); backend.put_post(&feed).await.unwrap(); let chans = backend.get_channels(&crate::indieauth::User::new("https://fireburn.ru/", "https://quill.p3k.io/", "create update media")).await.unwrap(); assert_eq!(chans.len(), 1); assert_eq!(chans[0], MicropubChannel { uid: "https://fireburn.ru/feeds/main".to_string(), name: "Main Page".to_string() }); } #[async_std::test] async fn test_memory_storage_basic_operations() { let backend = super::MemoryStorage::new(); test_backend_basic_operations(backend).await } #[async_std::test] async fn test_memory_storage_channel_support() { let backend = super::MemoryStorage::new(); test_backend_get_channel_list(backend).await } async fn get_redis_backend() -> (tempdir::TempDir, process::Child, super::RedisStorage) { let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); let socket = tempdir.path().join("redis.sock"); let redis_child = process::Command::new("redis-server") .current_dir(&tempdir) .arg("--port").arg("0") .arg("--unixsocket").arg(&socket) .stdout(process::Stdio::null()) .stderr(process::Stdio::null()) .spawn().expect("Failed to spawn Redis"); println!("redis+unix:///{}", socket.to_str().unwrap()); let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); // There should be a slight delay, we need to wait for Redis to spin up let client = redis::Client::open(uri.clone()).unwrap(); let millisecond = Duration::from_millis(1); let mut retries: usize = 0; const MAX_RETRIES: usize = 10 * 1000/*ms*/; while let Err(err) = client.get_connection() { if err.is_connection_refusal() { async_std::task::sleep(millisecond).await; retries += 1; if retries > MAX_RETRIES { panic!("Timeout waiting for Redis, last error: {}", err); } } else { panic!("Could not connect: {}", err); } } let backend = super::RedisStorage::new(uri).await.unwrap(); return (tempdir, redis_child, backend) } #[async_std::test] async fn test_redis_storage_basic_operations() { let (_, mut redis, backend) = get_redis_backend().await; test_backend_basic_operations(backend).await; redis.kill().expect("Redis wasn't running"); } #[async_std::test] async fn test_redis_storage_channel_support() { let (_, mut redis, backend) = get_redis_backend().await; test_backend_get_channel_list(backend).await; redis.kill().expect("Redis wasn't running"); } }