diff options
author | Vika <vika@fireburn.ru> | 2022-05-24 17:18:30 +0300 |
---|---|---|
committer | Vika <vika@fireburn.ru> | 2022-05-24 17:18:30 +0300 |
commit | 5610a5f0bf1a9df02bd3d5b55e2cdebef2440360 (patch) | |
tree | 8394bcf1dcc204043d7adeb8dde2e2746977606e /kittybox-rs/src/database | |
parent | 2f93873122b47e42f7ee1c38f1f04d052a63599c (diff) | |
download | kittybox-5610a5f0bf1a9df02bd3d5b55e2cdebef2440360.tar.zst |
flake.nix: reorganize
- Kittybox's source code is moved to a subfolder - This improves build caching by Nix since it doesn't take changes to other files into account - Package and test definitions were spun into separate files - This makes my flake.nix much easier to navigate - This also makes it somewhat possible to use without flakes (but it is still not easy, so use flakes!) - Some attributes were moved in compliance with Nix 2.8's changes to flake schema
Diffstat (limited to 'kittybox-rs/src/database')
-rw-r--r-- | kittybox-rs/src/database/file/mod.rs | 619 | ||||
-rw-r--r-- | kittybox-rs/src/database/memory.rs | 200 | ||||
-rw-r--r-- | kittybox-rs/src/database/mod.rs | 539 | ||||
-rw-r--r-- | kittybox-rs/src/database/redis/edit_post.lua | 93 | ||||
-rw-r--r-- | kittybox-rs/src/database/redis/mod.rs | 392 |
5 files changed, 1843 insertions, 0 deletions
diff --git a/kittybox-rs/src/database/file/mod.rs b/kittybox-rs/src/database/file/mod.rs new file mode 100644 index 0000000..1e7aa96 --- /dev/null +++ b/kittybox-rs/src/database/file/mod.rs @@ -0,0 +1,619 @@ +//#![warn(clippy::unwrap_used)] +use crate::database::{filter_post, ErrorKind, Result, Storage, StorageError, Settings}; +use std::io::ErrorKind as IOErrorKind; +use tokio::fs::{File, OpenOptions}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::task::spawn_blocking; +use async_trait::async_trait; +use futures::{stream, StreamExt, TryStreamExt}; +use log::debug; +use serde_json::json; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +impl From<std::io::Error> for StorageError { + fn from(source: std::io::Error) -> Self { + Self::with_source( + match source.kind() { + IOErrorKind::NotFound => ErrorKind::NotFound, + IOErrorKind::AlreadyExists => ErrorKind::Conflict, + _ => ErrorKind::Backend, + }, + "file I/O error", + Box::new(source), + ) + } +} + +impl From<tokio::time::error::Elapsed> for StorageError { + fn from(source: tokio::time::error::Elapsed) -> Self { + Self::with_source( + ErrorKind::Backend, + "timeout on I/O operation", + Box::new(source) + ) + } +} + +// Copied from https://stackoverflow.com/questions/39340924 +// This routine is adapted from the *old* Path's `path_relative_from` +// function, which works differently from the new `relative_from` function. +// In particular, this handles the case on unix where both paths are +// absolute but with only the root as the common directory. +fn path_relative_from(path: &Path, base: &Path) -> Option<PathBuf> { + use std::path::Component; + + if path.is_absolute() != base.is_absolute() { + if path.is_absolute() { + Some(PathBuf::from(path)) + } else { + None + } + } else { + let mut ita = path.components(); + let mut itb = base.components(); + let mut comps: Vec<Component> = vec![]; + loop { + match (ita.next(), itb.next()) { + (None, None) => break, + (Some(a), None) => { + comps.push(a); + comps.extend(ita.by_ref()); + break; + } + (None, _) => comps.push(Component::ParentDir), + (Some(a), Some(b)) if comps.is_empty() && a == b => (), + (Some(a), Some(b)) if b == Component::CurDir => comps.push(a), + (Some(_), Some(b)) if b == Component::ParentDir => return None, + (Some(a), Some(_)) => { + comps.push(Component::ParentDir); + for _ in itb { + comps.push(Component::ParentDir); + } + comps.push(a); + comps.extend(ita.by_ref()); + break; + } + } + } + Some(comps.iter().map(|c| c.as_os_str()).collect()) + } +} + +#[allow(clippy::unwrap_used, clippy::expect_used)] +#[cfg(test)] +mod tests { + #[test] + fn test_relative_path_resolving() { + let path1 = std::path::Path::new("/home/vika/Projects/kittybox"); + let path2 = std::path::Path::new("/home/vika/Projects/nixpkgs"); + let relative_path = super::path_relative_from(path2, path1).unwrap(); + + assert_eq!(relative_path, std::path::Path::new("../nixpkgs")) + } +} + +// TODO: Check that the path ACTUALLY IS INSIDE THE ROOT FOLDER +// This could be checked by completely resolving the path +// and checking if it has a common prefix +fn url_to_path(root: &Path, url: &str) -> PathBuf { + let path = url_to_relative_path(url).to_logical_path(root); + if !path.starts_with(root) { + // TODO: handle more gracefully + panic!("Security error: {:?} is not a prefix of {:?}", path, root) + } else { + path + } +} + +fn url_to_relative_path(url: &str) -> relative_path::RelativePathBuf { + let url = warp::http::Uri::try_from(url).expect("Couldn't parse a URL"); + let mut path = relative_path::RelativePathBuf::new(); + path.push(url.authority().unwrap().to_string() + url.path() + ".json"); + + path +} + +fn modify_post(post: &serde_json::Value, update: &serde_json::Value) -> Result<serde_json::Value> { + let mut add_keys: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + let mut remove_keys: Vec<String> = vec![]; + let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + let mut post = post.clone(); + + if let Some(delete) = update["delete"].as_array() { + remove_keys.extend( + delete + .iter() + .filter_map(|v| v.as_str()) + .map(|v| v.to_string()), + ); + } else if let Some(delete) = update["delete"].as_object() { + for (k, v) in delete { + if let Some(v) = v.as_array() { + remove_values + .entry(k.to_string()) + .or_default() + .extend(v.clone()); + } else { + return Err(StorageError::new( + ErrorKind::BadRequest, + "Malformed update object", + )); + } + } + } + if let Some(add) = update["add"].as_object() { + for (k, v) in add { + if let Some(v) = v.as_array() { + add_keys.insert(k.to_string(), v.clone()); + } else { + return Err(StorageError::new( + ErrorKind::BadRequest, + "Malformed update object", + )); + } + } + } + if let Some(replace) = update["replace"].as_object() { + for (k, v) in replace { + remove_keys.push(k.to_string()); + if let Some(v) = v.as_array() { + add_keys.insert(k.to_string(), v.clone()); + } else { + return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object")); + } + } + } + + if let Some(props) = post["properties"].as_object_mut() { + for k in remove_keys { + props.remove(&k); + } + } + for (k, v) in remove_values { + let k = &k; + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + v.iter().for_each(|v| { + if let Some(vec) = props[k].as_array_mut() { + if let Some(index) = vec.iter().position(|w| w == v) { + vec.remove(index); + } + } + }); + } + for (k, v) in add_keys { + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + let k = &k; + if let Some(prop) = props[k].as_array_mut() { + if k == "children" { + v.into_iter() + .rev() + .for_each(|v| prop.insert(0, v)); + } else { + prop.extend(v.into_iter()); + } + } else { + post["properties"][k] = serde_json::Value::Array(v) + } + } + Ok(post) +} + +#[derive(Clone, Debug)] +/// A backend using a folder with JSON files as a backing store. +/// Uses symbolic links to represent a many-to-one mapping of URLs to a post. +pub struct FileStorage { + root_dir: PathBuf, +} + +impl FileStorage { + /// Create a new storage wrapping a folder specified by root_dir. + pub async fn new(root_dir: PathBuf) -> Result<Self> { + // TODO check if the dir is writable + Ok(Self { root_dir }) + } +} + +async fn hydrate_author<S: Storage>( + feed: &mut serde_json::Value, + user: &'_ Option<String>, + storage: &S, +) { + let url = feed["properties"]["uid"][0] + .as_str() + .expect("MF2 value should have a UID set! Check if you used normalize_mf2 before recording the post!"); + if let Some(author) = feed["properties"]["author"].as_array().cloned() { + if !feed["type"] + .as_array() + .expect("MF2 value should have a type set!") + .iter() + .any(|i| i == "h-card") + { + let author_list: Vec<serde_json::Value> = stream::iter(author.iter()) + .then(|i| async move { + if let Some(i) = i.as_str() { + match storage.get_post(i).await { + Ok(post) => match post { + Some(post) => match filter_post(post, user) { + Some(author) => author, + None => json!(i), + }, + None => json!(i), + }, + Err(e) => { + log::error!("Error while hydrating post {}: {}", url, e); + json!(i) + } + } + } else { + i.clone() + } + }) + .collect::<Vec<_>>() + .await; + if let Some(props) = feed["properties"].as_object_mut() { + props["author"] = json!(author_list); + } else { + feed["properties"] = json!({"author": author_list}); + } + } + } +} + +#[async_trait] +impl Storage for FileStorage { + async fn post_exists(&self, url: &str) -> Result<bool> { + let path = url_to_path(&self.root_dir, url); + debug!("Checking if {:?} exists...", path); + /*let result = match tokio::fs::metadata(path).await { + Ok(metadata) => { + Ok(true) + }, + Err(err) => { + if err.kind() == IOErrorKind::NotFound { + Ok(false) + } else { + Err(err.into()) + } + } + };*/ + #[allow(clippy::unwrap_used)] // JoinHandle captures panics, this closure shouldn't panic + Ok(spawn_blocking(move || path.is_file()).await.unwrap()) + } + + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + let path = url_to_path(&self.root_dir, url); + // TODO: check that the path actually belongs to the dir of user who requested it + // it's not like you CAN access someone else's private posts with it + // so it's not exactly a security issue, but it's still not good + debug!("Opening {:?}", path); + + match File::open(&path).await { + Ok(mut file) => { + let mut content = String::new(); + // Typechecks because OS magic acts on references + // to FDs as if they were behind a mutex + AsyncReadExt::read_to_string(&mut file, &mut content).await?; + debug!("Read {} bytes successfully from {:?}", content.as_bytes().len(), &path); + Ok(Some(serde_json::from_str(&content)?)) + }, + Err(err) => { + if err.kind() == IOErrorKind::NotFound { + Ok(None) + } else { + Err(err.into()) + } + } + } + } + + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + let key = post["properties"]["uid"][0] + .as_str() + .expect("Tried to save a post without UID"); + let path = url_to_path(&self.root_dir, key); + let tempfile = (&path).with_extension("tmp"); + debug!("Creating {:?}", path); + + let parent = path.parent().expect("Parent for this directory should always exist").to_owned(); + if !parent.is_dir() { + tokio::fs::create_dir_all(parent).await?; + } + + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfile).await?; + + file.write_all(post.to_string().as_bytes()).await?; + file.flush().await?; + drop(file); + tokio::fs::rename(&tempfile, &path).await?; + + if let Some(urls) = post["properties"]["url"].as_array() { + for url in urls + .iter() + .map(|i| i.as_str().unwrap()) + { + if url != key && url.starts_with(user) { + let link = url_to_path(&self.root_dir, url); + debug!("Creating a symlink at {:?}", link); + let orig = path.clone(); + // We're supposed to have a parent here. + let basedir = link.parent().ok_or_else(|| { + StorageError::new( + ErrorKind::Backend, + "Failed to calculate parent directory when creating a symlink", + ) + })?; + let relative = path_relative_from(&orig, basedir).unwrap(); + println!("{:?} - {:?} = {:?}", &orig, &basedir, &relative); + tokio::fs::symlink(relative, link).await?; + } + } + } + + if post["type"] + .as_array() + .unwrap() + .iter() + .any(|s| s.as_str() == Some("h-feed")) + { + println!("Adding to channel list..."); + // Add the h-feed to the channel list + let mut path = relative_path::RelativePathBuf::new(); + path.push(warp::http::Uri::try_from(user.to_string()).unwrap().authority().unwrap().to_string()); + path.push("channels"); + + let path = path.to_path(&self.root_dir); + let tempfilename = (&path).with_extension("tmp"); + let channel_name = post["properties"]["name"][0] + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(String::default); + let key = key.to_string(); + + let mut tempfile = OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfilename).await?; + let mut file = OpenOptions::new() + .read(true) + .write(true) + .truncate(false) + .create(true) + .open(&path).await?; + + let mut content = String::new(); + file.read_to_string(&mut content).await?; + drop(file); + let mut channels: Vec<super::MicropubChannel> = if !content.is_empty() { + serde_json::from_str(&content)? + } else { + Vec::default() + }; + + channels.push(super::MicropubChannel { + uid: key.to_string(), + name: channel_name, + }); + + tempfile.write_all(serde_json::to_string(&channels)?.as_bytes()).await?; + tempfile.flush().await?; + drop(tempfile); + tokio::fs::rename(tempfilename, path).await?; + } + Ok(()) + } + + async fn update_post(&self, url: &'_ str, update: serde_json::Value) -> Result<()> { + let path = url_to_path(&self.root_dir, url); + let tempfilename = path.with_extension("tmp"); + #[allow(unused_variables)] + let (old_json, new_json) = { + let mut temp = OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfilename) + .await?; + let mut file = OpenOptions::new() + .read(true) + .open(&path) + .await?; + + let mut content = String::new(); + file.read_to_string(&mut content).await?; + let json: serde_json::Value = serde_json::from_str(&content)?; + drop(file); + // Apply the editing algorithms + let new_json = modify_post(&json, &update)?; + + temp.write_all(new_json.to_string().as_bytes()).await?; + temp.flush().await?; + drop(temp); + tokio::fs::rename(tempfilename, path).await?; + + (json, new_json) + }; + // TODO check if URLs changed between old and new JSON + Ok(()) + } + + async fn get_channels(&self, user: &'_ str) -> Result<Vec<super::MicropubChannel>> { + let mut path = relative_path::RelativePathBuf::new(); + path.push(warp::http::Uri::try_from(user.to_string()).unwrap().authority().unwrap().to_string()); + path.push("channels"); + + let path = path.to_path(&self.root_dir); + match File::open(&path).await { + Ok(mut f) => { + let mut content = String::new(); + f.read_to_string(&mut content).await?; + // This should not happen, but if it does, handle it gracefully + if content.is_empty() { + return Ok(vec![]); + } + let channels: Vec<super::MicropubChannel> = serde_json::from_str(&content)?; + Ok(channels) + } + Err(e) => { + if e.kind() == IOErrorKind::NotFound { + Ok(vec![]) + } else { + Err(e.into()) + } + } + } + } + + async fn read_feed_with_limit( + &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>> { + if let Some(feed) = self.get_post(url).await? { + if let Some(mut feed) = filter_post(feed, user) { + if feed["children"].is_array() { + // This code contains several clones. It looks + // like the borrow checker thinks it is preventing + // me from doing something incredibly stupid. The + // borrow checker may or may not be right. + let children = feed["children"].as_array().unwrap().clone(); + let mut posts_iter = children + .into_iter() + .map(|s: serde_json::Value| s.as_str().unwrap().to_string()); + // Note: we can't actually use skip_while here because we end up emitting `after`. + // This imperative snippet consumes after instead of emitting it, allowing the + // stream of posts to return only those items that truly come *after* + if let Some(after) = after { + for s in posts_iter.by_ref() { + if &s == after { + break + } + } + }; + let posts = stream::iter(posts_iter) + .map(|url: String| async move { self.get_post(&url).await }) + .buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips Nones. + .try_filter_map(|post: Option<serde_json::Value>| async move { Ok(post) }) + .try_filter_map(|post| async move { Ok(filter_post(post, user)) }) + .and_then(|mut post| async move { + hydrate_author(&mut post, user, self).await; + Ok(post) + }) + .take(limit); + + match posts.try_collect::<Vec<serde_json::Value>>().await { + Ok(posts) => feed["children"] = serde_json::json!(posts), + Err(err) => { + return Err(StorageError::with_source( + ErrorKind::Other, + "Feed assembly error", + Box::new(err), + )); + } + } + } + hydrate_author(&mut feed, user, self).await; + Ok(Some(feed)) + } else { + Err(StorageError::new( + ErrorKind::PermissionDenied, + "specified user cannot access this post", + )) + } + } else { + Ok(None) + } + } + + async fn delete_post(&self, url: &'_ str) -> Result<()> { + let path = url_to_path(&self.root_dir, url); + if let Err(e) = tokio::fs::remove_file(path).await { + Err(e.into()) + } else { + // TODO check for dangling references in the channel list + Ok(()) + } + } + + async fn get_setting(&self, setting: Settings, user: &'_ str) -> Result<String> { + log::debug!("User for getting settings: {}", user); + let url = warp::http::Uri::try_from(user).expect("Couldn't parse a URL"); + let mut path = relative_path::RelativePathBuf::new(); + path.push(url.authority().unwrap().to_string()); + path.push("settings"); + + let path = path.to_path(&self.root_dir); + log::debug!("Getting settings from {:?}", &path); + let setting = setting.to_string(); + let mut file = File::open(path).await?; + let mut content = String::new(); + file.read_to_string(&mut content).await?; + + let settings: HashMap<String, String> = serde_json::from_str(&content)?; + // XXX consider returning string slices instead of cloning a string every time + // it might come with a performance hit and/or memory usage inflation + settings + .get(&setting) + .cloned() + .ok_or_else(|| StorageError::new(ErrorKind::Backend, "Setting not set")) + } + + async fn set_setting(&self, setting: Settings, user: &'_ str, value: &'_ str) -> Result<()> { + let url = warp::http::Uri::try_from(user).expect("Couldn't parse a URL"); + let mut path = relative_path::RelativePathBuf::new(); + path.push(url.authority().unwrap().to_string()); + path.push("settings"); + + let path = path.to_path(&self.root_dir); + let temppath = path.with_extension("tmp"); + + let parent = path.parent().unwrap().to_owned(); + if !spawn_blocking(move || parent.is_dir()).await.unwrap() { + tokio::fs::create_dir_all(path.parent().unwrap()).await?; + } + + let (setting, value) = (setting.to_string(), value.to_string()); + + let mut tempfile = OpenOptions::new() + .write(true) + .create_new(true) + .open(&temppath) + .await?; + + let mut settings: HashMap<String, String> = match File::open(&path).await { + Ok(mut f) => { + let mut content = String::new(); + f.read_to_string(&mut content).await?; + if content.is_empty() { + HashMap::default() + } else { + serde_json::from_str(&content)? + } + } + Err(err) => if err.kind() == IOErrorKind::NotFound { + HashMap::default() + } else { + return Err(err.into()) + } + }; + settings.insert(setting, value); + tempfile.write_all(serde_json::to_string(&settings)?.as_bytes()).await?; + drop(tempfile); + tokio::fs::rename(temppath, path).await?; + Ok(()) + } +} diff --git a/kittybox-rs/src/database/memory.rs b/kittybox-rs/src/database/memory.rs new file mode 100644 index 0000000..786466c --- /dev/null +++ b/kittybox-rs/src/database/memory.rs @@ -0,0 +1,200 @@ +#![allow(clippy::todo)] +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use futures_util::FutureExt; +use serde_json::json; + +use crate::database::{Storage, Result, StorageError, ErrorKind, MicropubChannel, Settings}; + +#[derive(Clone, Debug)] +pub struct MemoryStorage { + pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>, + pub channels: Arc<RwLock<HashMap<String, Vec<String>>>> +} + +#[async_trait] +impl Storage for MemoryStorage { + async fn post_exists(&self, url: &str) -> Result<bool> { + return Ok(self.mapping.read().await.contains_key(url)) + } + + async fn get_post(&self, url: &str) ->Result<Option<serde_json::Value>> { + let mapping = self.mapping.read().await; + match mapping.get(url) { + Some(val) => { + if let Some(new_url) = val["see_other"].as_str() { + match mapping.get(new_url) { + Some(val) => Ok(Some(val.clone())), + None => { + drop(mapping); + self.mapping.write().await.remove(url); + Ok(None) + } + } + } else { + Ok(Some(val.clone())) + } + }, + _ => Ok(None) + } + } + + async fn put_post(&self, post: &'_ serde_json::Value, _user: &'_ str) -> Result<()> { + let mapping = &mut self.mapping.write().await; + let key: &str = match post["properties"]["uid"][0].as_str() { + Some(uid) => uid, + None => return Err(StorageError::new(ErrorKind::Other, "post doesn't have a UID")) + }; + mapping.insert(key.to_string(), post.clone()); + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"].as_array().unwrap().iter().map(|i| i.as_str().unwrap().to_string()) { + if url != key { + mapping.insert(url, json!({"see_other": key})); + } + } + } + if post["type"].as_array().unwrap().iter().any(|i| i == "h-feed") { + // This is a feed. Add it to the channels array if it's not already there. + println!("{:#}", post); + self.channels.write().await.entry(post["properties"]["author"][0].as_str().unwrap().to_string()).or_insert_with(Vec::new).push(key.to_string()) + } + Ok(()) + } + + async fn update_post(&self, url: &'_ str, update: serde_json::Value) -> Result<()> { + let mut add_keys: HashMap<String, serde_json::Value> = HashMap::new(); + let mut remove_keys: Vec<String> = vec![]; + let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + + if let Some(delete) = update["delete"].as_array() { + remove_keys.extend(delete.iter().filter_map(|v| v.as_str()).map(|v| v.to_string())); + } else if let Some(delete) = update["delete"].as_object() { + for (k, v) in delete { + if let Some(v) = v.as_array() { + remove_values.entry(k.to_string()).or_default().extend(v.clone()); + } else { + return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object")); + } + } + } + if let Some(add) = update["add"].as_object() { + for (k, v) in add { + if v.is_array() { + add_keys.insert(k.to_string(), v.clone()); + } else { + return Err(StorageError::new(ErrorKind::BadRequest, "Malformed update object")); + } + } + } + if let Some(replace) = update["replace"].as_object() { + for (k, v) in replace { + remove_keys.push(k.to_string()); + add_keys.insert(k.to_string(), v.clone()); + } + } + let mut mapping = self.mapping.write().await; + if let Some(mut post) = mapping.get(url) { + if let Some(url) = post["see_other"].as_str() { + if let Some(new_post) = mapping.get(url) { + post = new_post + } else { + return Err(StorageError::new(ErrorKind::NotFound, "The post you have requested is not found in the database.")); + } + } + let mut post = post.clone(); + for k in remove_keys { + post["properties"].as_object_mut().unwrap().remove(&k); + } + for (k, v) in remove_values { + let k = &k; + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + v.iter().for_each(|v| { + if let Some(vec) = props[k].as_array_mut() { + if let Some(index) = vec.iter().position(|w| w == v) { + vec.remove(index); + } + } + }); + } + for (k, v) in add_keys { + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + let k = &k; + if let Some(prop) = props[k].as_array_mut() { + if k == "children" { + v.as_array().unwrap().iter().cloned().rev().for_each(|v| prop.insert(0, v)); + } else { + prop.extend(v.as_array().unwrap().iter().cloned()); + } + } else { + post["properties"][k] = v + } + } + mapping.insert(post["properties"]["uid"][0].as_str().unwrap().to_string(), post); + } else { + return Err(StorageError::new(ErrorKind::NotFound, "The designated post wasn't found in the database.")); + } + Ok(()) + } + + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { + match self.channels.read().await.get(user) { + Some(channels) => Ok(futures_util::future::join_all(channels.iter() + .map(|channel| self.get_post(channel) + .map(|result| result.unwrap()) + .map(|post: Option<serde_json::Value>| { + post.map(|post| MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string() + }) + }) + ).collect::<Vec<_>>()).await.into_iter().flatten().collect::<Vec<_>>()), + None => Ok(vec![]) + } + + } + + #[allow(unused_variables)] + async fn read_feed_with_limit(&self, url: &'_ str, after: &'_ Option<String>, limit: usize, user: &'_ Option<String>) -> Result<Option<serde_json::Value>> { + todo!() + } + + async fn delete_post(&self, url: &'_ str) -> Result<()> { + self.mapping.write().await.remove(url); + Ok(()) + } + + #[allow(unused_variables)] + async fn get_setting(&self, setting: Settings, user: &'_ str) -> Result<String> { + todo!() + } + + #[allow(unused_variables)] + async fn set_setting(&self, setting: Settings, user: &'_ str, value: &'_ str) -> Result<()> { + todo!() + } +} + +impl Default for MemoryStorage { + fn default() -> Self { + Self::new() + } +} + +impl MemoryStorage { + pub fn new() -> Self { + Self { + mapping: Arc::new(RwLock::new(HashMap::new())), + channels: Arc::new(RwLock::new(HashMap::new())) + } + } +} diff --git a/kittybox-rs/src/database/mod.rs b/kittybox-rs/src/database/mod.rs new file mode 100644 index 0000000..6bf5409 --- /dev/null +++ b/kittybox-rs/src/database/mod.rs @@ -0,0 +1,539 @@ +#![warn(missing_docs)] +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +mod file; +pub use crate::database::file::FileStorage; +#[cfg(test)] +mod memory; +#[cfg(test)] +pub use crate::database::memory::MemoryStorage; + +pub use kittybox_util::MicropubChannel; + +/// Enum representing different errors that might occur during the database query. +#[derive(Debug, Clone, Copy)] +pub enum ErrorKind { + /// Backend error (e.g. database connection error) + Backend, + /// Error due to insufficient contextual permissions for the query + PermissionDenied, + /// Error due to the database being unable to parse JSON returned from the backing storage. + /// Usually indicative of someone fiddling with the database manually instead of using proper tools. + JsonParsing, + /// - ErrorKind::NotFound - equivalent to a 404 error. Note, some requests return an Option, + /// in which case None is also equivalent to a 404. + NotFound, + /// The user's query or request to the database was malformed. Used whenever the database processes + /// the user's query directly, such as when editing posts inside of the database (e.g. Redis backend) + BadRequest, + /// the user's query collided with an in-flight request and needs to be retried + Conflict, + /// - ErrorKind::Other - when something so weird happens that it becomes undescribable. + Other, +} + +/// Enum representing settings that might be stored in the site's database. +#[derive(Deserialize, Serialize, Debug, Clone, Copy)] +#[serde(rename_all = "snake_case")] +pub enum Settings { + /// The name of the website -- displayed in the header and the browser titlebar. + SiteName, +} + +impl std::string::ToString for Settings { + fn to_string(&self) -> String { + serde_variant::to_variant_name(self).unwrap().to_string() + } +} + +/// Error signalled from the database. +#[derive(Debug)] +pub struct StorageError { + msg: String, + source: Option<Box<dyn std::error::Error + Send + Sync>>, + kind: ErrorKind, +} + +impl warp::reject::Reject for StorageError {} + +impl std::error::Error for StorageError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|e| e.as_ref() as &dyn std::error::Error) + } +} +impl From<serde_json::Error> for StorageError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::JsonParsing, + } + } +} +impl std::fmt::Display for StorageError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match match self.kind { + ErrorKind::Backend => write!(f, "backend error: "), + ErrorKind::JsonParsing => write!(f, "error while parsing JSON: "), + ErrorKind::PermissionDenied => write!(f, "permission denied: "), + ErrorKind::NotFound => write!(f, "not found: "), + ErrorKind::BadRequest => write!(f, "bad request: "), + ErrorKind::Conflict => write!(f, "conflict with an in-flight request or existing data: "), + ErrorKind::Other => write!(f, "generic storage layer error: "), + } { + Ok(_) => write!(f, "{}", self.msg), + Err(err) => Err(err), + } + } +} +impl serde::Serialize for StorageError { + fn serialize<S: serde::Serializer>( + &self, + serializer: S, + ) -> std::result::Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} +impl StorageError { + /// Create a new StorageError of an ErrorKind with a message. + fn new(kind: ErrorKind, msg: &str) -> Self { + Self { + msg: msg.to_string(), + source: None, + kind, + } + } + /// Create a StorageError using another arbitrary Error as a source. + fn with_source( + kind: ErrorKind, + msg: &str, + source: Box<dyn std::error::Error + Send + Sync>, + ) -> Self { + Self { + msg: msg.to_string(), + source: Some(source), + kind, + } + } + /// Get the kind of an error. + pub fn kind(&self) -> ErrorKind { + self.kind + } + /// Get the message as a string slice. + pub fn msg(&self) -> &str { + &self.msg + } +} + +/// A special Result type for the Micropub backing storage. +pub type Result<T> = std::result::Result<T, StorageError>; + +/// Filter the post according to the value of `user`. +/// +/// Anonymous users cannot view private posts and protected locations; +/// Logged-in users can only view private posts targeted at them; +/// Logged-in users can't view private location data +pub fn filter_post( + mut post: serde_json::Value, + user: &'_ Option<String>, +) -> Option<serde_json::Value> { + if post["properties"]["deleted"][0].is_string() { + return Some(serde_json::json!({ + "type": post["type"], + "properties": { + "deleted": post["properties"]["deleted"] + } + })); + } + let empty_vec: Vec<serde_json::Value> = vec![]; + let author = post["properties"]["author"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().to_string()); + let visibility = post["properties"]["visibility"][0] + .as_str() + .unwrap_or("public"); + let mut audience = author.chain( + post["properties"]["audience"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().to_string()), + ); + if (visibility == "private" && !audience.any(|i| Some(i) == *user)) + || (visibility == "protected" && user.is_none()) + { + return None; + } + if post["properties"]["location"].is_array() { + let location_visibility = post["properties"]["location-visibility"][0] + .as_str() + .unwrap_or("private"); + let mut author = post["properties"]["author"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().to_string()); + if (location_visibility == "private" && !author.any(|i| Some(i) == *user)) + || (location_visibility == "protected" && user.is_none()) + { + post["properties"] + .as_object_mut() + .unwrap() + .remove("location"); + } + } + Some(post) +} + +/// A storage backend for the Micropub server. +/// +/// Implementations should note that all methods listed on this trait MUST be fully atomic +/// or lock the database so that write conflicts or reading half-written data should not occur. +#[async_trait] +pub trait Storage: std::fmt::Debug + Clone + Send + Sync { + /// Check if a post exists in the database. + async fn post_exists(&self, url: &str) -> Result<bool>; + + /// Load a post from the database in MF2-JSON format, deserialized from JSON. + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>>; + + /// Save a post to the database as an MF2-JSON structure. + /// + /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()>; + + /// Modify a post using an update object as defined in the Micropub spec. + /// + /// Note to implementors: the update operation MUST be atomic and + /// SHOULD lock the database to prevent two clients overwriting + /// each other's changes or simply corrupting something. Rejecting + /// is allowed in case of concurrent updates if waiting for a lock + /// cannot be done. + async fn update_post(&self, url: &'_ str, update: serde_json::Value) -> Result<()>; + + /// Get a list of channels available for the user represented by the URL `user` to write to. + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>>; + + /// Fetch a feed at `url` and return a an h-feed object containing + /// `limit` posts after a post by url `after`, filtering the content + /// in context of a user specified by `user` (or an anonymous user). + /// + /// Specifically, private posts that don't include the user in the audience + /// will be elided from the feed, and the posts containing location and not + /// specifying post["properties"]["location-visibility"][0] == "public" + /// will have their location data (but not check-in data) stripped. + /// + /// This function is used as an optimization so the client, whatever it is, + /// doesn't have to fetch posts, then realize some of them are private, and + /// fetch more posts. + /// + /// Note for implementors: if you use streams to fetch posts in parallel + /// from the database, preferably make this method use a connection pool + /// to reduce overhead of creating a database connection per post for + /// parallel fetching. + async fn read_feed_with_limit( + &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>>; + + /// Deletes a post from the database irreversibly. 'nuff said. Must be idempotent. + async fn delete_post(&self, url: &'_ str) -> Result<()>; + + /// Gets a setting from the setting store and passes the result. + async fn get_setting(&self, setting: Settings, user: &'_ str) -> Result<String>; + + /// Commits a setting to the setting store. + async fn set_setting(&self, setting: Settings, user: &'_ str, value: &'_ str) -> Result<()>; +} + +#[cfg(test)] +mod tests { + use super::{MicropubChannel, Storage}; + use serde_json::json; + + async fn test_basic_operations<Backend: Storage>(backend: Backend) { + let post: serde_json::Value = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Test content"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] + } + }); + let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); + let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); + + // Reading and writing + backend + .put_post(&post, "https://fireburn.ru/") + .await + .unwrap(); + if let Some(returned_post) = backend.get_post(&key).await.unwrap() { + assert!(returned_post.is_object()); + assert_eq!( + returned_post["type"].as_array().unwrap().len(), + post["type"].as_array().unwrap().len() + ); + assert_eq!( + returned_post["type"].as_array().unwrap(), + post["type"].as_array().unwrap() + ); + let props: &serde_json::Map<String, serde_json::Value> = + post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!( + returned_post["properties"][key].as_array().unwrap(), + post["properties"][key].as_array().unwrap() + ) + } + } else { + panic!("For some reason the backend did not return the post.") + } + // Check the alternative URL - it should return the same post + if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { + assert!(returned_post.is_object()); + assert_eq!( + returned_post["type"].as_array().unwrap().len(), + post["type"].as_array().unwrap().len() + ); + assert_eq!( + returned_post["type"].as_array().unwrap(), + post["type"].as_array().unwrap() + ); + let props: &serde_json::Map<String, serde_json::Value> = + post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!( + returned_post["properties"][key].as_array().unwrap(), + post["properties"][key].as_array().unwrap() + ) + } + } else { + panic!("For some reason the backend did not return the post.") + } + } + + /// Note: this is merely a smoke check and is in no way comprehensive. + // TODO updates for feeds must update children using special logic + async fn test_update<Backend: Storage>(backend: Backend) { + let post: serde_json::Value = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Test content"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] + } + }); + let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); + + // Reading and writing + backend + .put_post(&post, "https://fireburn.ru/") + .await + .unwrap(); + + backend + .update_post( + &key, + json!({ + "url": &key, + "add": { + "category": ["testing"], + }, + "replace": { + "content": ["Different test content"] + } + }), + ) + .await + .unwrap(); + + match backend.get_post(&key).await { + Ok(Some(returned_post)) => { + assert!(returned_post.is_object()); + assert_eq!( + returned_post["type"].as_array().unwrap().len(), + post["type"].as_array().unwrap().len() + ); + assert_eq!( + returned_post["type"].as_array().unwrap(), + post["type"].as_array().unwrap() + ); + assert_eq!( + returned_post["properties"]["content"][0].as_str().unwrap(), + "Different test content" + ); + assert_eq!( + returned_post["properties"]["category"].as_array().unwrap(), + &vec![json!("testing")] + ); + }, + something_else => { + something_else.expect("Shouldn't error").expect("Should have the post"); + } + } + } + + async fn test_get_channel_list<Backend: Storage>(backend: Backend) { + let feed = json!({ + "type": ["h-feed"], + "properties": { + "name": ["Main Page"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/feeds/main"] + }, + "children": [] + }); + backend + .put_post(&feed, "https://fireburn.ru/") + .await + .unwrap(); + let chans = backend.get_channels("https://fireburn.ru/").await.unwrap(); + assert_eq!(chans.len(), 1); + assert_eq!( + chans[0], + MicropubChannel { + uid: "https://fireburn.ru/feeds/main".to_string(), + name: "Main Page".to_string() + } + ); + } + + async fn test_settings<Backend: Storage>(backend: Backend) { + backend + .set_setting(crate::database::Settings::SiteName, "https://fireburn.ru/", "Vika's Hideout") + .await + .unwrap(); + assert_eq!( + backend + .get_setting(crate::database::Settings::SiteName, "https://fireburn.ru/") + .await + .unwrap(), + "Vika's Hideout" + ); + } + + fn gen_random_post(domain: &str) -> serde_json::Value { + use faker_rand::lorem::{Paragraphs, Word}; + + let uid = format!( + "https://{domain}/posts/{}-{}-{}", + rand::random::<Word>(), rand::random::<Word>(), rand::random::<Word>() + ); + + let post = json!({ + "type": ["h-entry"], + "properties": { + "content": [rand::random::<Paragraphs>().to_string()], + "uid": [&uid], + "url": [&uid] + } + }); + + post + } + + async fn test_feed_pagination<Backend: Storage>(backend: Backend) { + let posts = std::iter::from_fn(|| Some(gen_random_post("fireburn.ru"))) + .take(20) + .collect::<Vec<serde_json::Value>>(); + + let feed = json!({ + "type": ["h-feed"], + "properties": { + "name": ["Main Page"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/feeds/main"] + }, + "children": posts.iter() + .filter_map(|json| json["properties"]["uid"][0].as_str()) + .collect::<Vec<&str>>() + }); + let key = feed["properties"]["uid"][0].as_str().unwrap(); + + backend + .put_post(&feed, "https://fireburn.ru/") + .await + .unwrap(); + println!("---"); + for (i, post) in posts.iter().enumerate() { + backend.put_post(post, "https://fireburn.ru/").await.unwrap(); + println!("posts[{}] = {}", i, post["properties"]["uid"][0]); + } + println!("---"); + let limit: usize = 10; + let result = backend.read_feed_with_limit(key, &None, limit, &None) + .await + .unwrap() + .unwrap(); + for (i, post) in result["children"].as_array().unwrap().iter().enumerate() { + println!("feed[0][{}] = {}", i, post["properties"]["uid"][0]); + } + println!("---"); + assert_eq!(result["children"].as_array().unwrap()[0..10], posts[0..10]); + + let result2 = backend.read_feed_with_limit( + key, + &result["children"] + .as_array() + .unwrap() + .last() + .unwrap() + ["properties"]["uid"][0] + .as_str() + .map(|i| i.to_owned()), + limit, &None + ).await.unwrap().unwrap(); + for (i, post) in result2["children"].as_array().unwrap().iter().enumerate() { + println!("feed[1][{}] = {}", i, post["properties"]["uid"][0]); + } + println!("---"); + assert_eq!(result2["children"].as_array().unwrap()[0..10], posts[10..20]); + + // Regression test for #4 + let nonsense_after = Some("1010101010".to_owned()); + let result3 = tokio::time::timeout(tokio::time::Duration::from_secs(10), async move { + backend.read_feed_with_limit( + key, &nonsense_after, limit, &None + ).await.unwrap().unwrap() + }).await.expect("Operation should not hang: see https://gitlab.com/kittybox/kittybox/-/issues/4"); + assert!(result3["children"].as_array().unwrap().is_empty()); + } + + /// Automatically generates a test suite for + macro_rules! test_all { + ($func_name:ident, $mod_name:ident) => { + mod $mod_name { + $func_name!(test_basic_operations); + $func_name!(test_get_channel_list); + $func_name!(test_settings); + $func_name!(test_update); + $func_name!(test_feed_pagination); + } + } + } + macro_rules! file_test { + ($func_name:ident) => { + #[tokio::test] + async fn $func_name () { + test_logger::ensure_env_logger_initialized(); + let tempdir = tempdir::TempDir::new("file").expect("Failed to create tempdir"); + let backend = super::super::FileStorage::new(tempdir.into_path()).await.unwrap(); + super::$func_name(backend).await + } + }; + } + + test_all!(file_test, file); + +} diff --git a/kittybox-rs/src/database/redis/edit_post.lua b/kittybox-rs/src/database/redis/edit_post.lua new file mode 100644 index 0000000..a398f8d --- /dev/null +++ b/kittybox-rs/src/database/redis/edit_post.lua @@ -0,0 +1,93 @@ +local posts = KEYS[1] +local update_desc = cjson.decode(ARGV[2]) +local post = cjson.decode(redis.call("HGET", posts, ARGV[1])) + +local delete_keys = {} +local delete_kvs = {} +local add_keys = {} + +if update_desc.replace ~= nil then + for k, v in pairs(update_desc.replace) do + table.insert(delete_keys, k) + add_keys[k] = v + end +end +if update_desc.delete ~= nil then + if update_desc.delete[0] == nil then + -- Table has string keys. Probably! + for k, v in pairs(update_desc.delete) do + delete_kvs[k] = v + end + else + -- Table has numeric keys. Probably! + for i, v in ipairs(update_desc.delete) do + table.insert(delete_keys, v) + end + end +end +if update_desc.add ~= nil then + for k, v in pairs(update_desc.add) do + add_keys[k] = v + end +end + +for i, v in ipairs(delete_keys) do + post["properties"][v] = nil + -- TODO delete URL links +end + +for k, v in pairs(delete_kvs) do + local index = -1 + if k == "children" then + for j, w in ipairs(post[k]) do + if w == v then + index = j + break + end + end + if index > -1 then + table.remove(post[k], index) + end + else + for j, w in ipairs(post["properties"][k]) do + if w == v then + index = j + break + end + end + if index > -1 then + table.remove(post["properties"][k], index) + -- TODO delete URL links + end + end +end + +for k, v in pairs(add_keys) do + if k == "children" then + if post["children"] == nil then + post["children"] = {} + end + for i, w in ipairs(v) do + table.insert(post["children"], 1, w) + end + else + if post["properties"][k] == nil then + post["properties"][k] = {} + end + for i, w in ipairs(v) do + table.insert(post["properties"][k], w) + end + if k == "url" then + redis.call("HSET", posts, v, cjson.encode({ see_other = post["properties"]["uid"][1] })) + elseif k == "channel" then + local feed = cjson.decode(redis.call("HGET", posts, v)) + table.insert(feed["children"], 1, post["properties"]["uid"][1]) + redis.call("HSET", posts, v, cjson.encode(feed)) + end + end +end + +local encoded = cjson.encode(post) +redis.call("SET", "debug", encoded) +redis.call("HSET", posts, post["properties"]["uid"][1], encoded) +return \ No newline at end of file diff --git a/kittybox-rs/src/database/redis/mod.rs b/kittybox-rs/src/database/redis/mod.rs new file mode 100644 index 0000000..eeaa3f2 --- /dev/null +++ b/kittybox-rs/src/database/redis/mod.rs @@ -0,0 +1,392 @@ +use async_trait::async_trait; +use futures::stream; +use futures_util::FutureExt; +use futures_util::StreamExt; +use futures_util::TryStream; +use futures_util::TryStreamExt; +use lazy_static::lazy_static; +use log::error; +use mobc::Pool; +use mobc_redis::redis; +use mobc_redis::redis::AsyncCommands; +use mobc_redis::RedisConnectionManager; +use serde_json::json; +use std::time::Duration; + +use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError, filter_post}; +use crate::indieauth::User; + +struct RedisScripts { + edit_post: redis::Script, +} + +impl From<mobc_redis::redis::RedisError> for StorageError { + fn from(err: mobc_redis::redis::RedisError) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend, + } + } +} +impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError { + fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Backend, + } + } +} + +lazy_static! { + static ref SCRIPTS: RedisScripts = RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + }; +} + +#[derive(Clone)] +pub struct RedisStorage { + // note to future Vika: + // mobc::Pool is actually a fancy name for an Arc + // around a shared connection pool with a manager + // which makes it safe to implement [`Clone`] and + // not worry about new pools being suddenly made + // + // stop worrying and start coding, you dum-dum + redis: mobc::Pool<RedisConnectionManager>, +} + +#[async_trait] +impl Storage for RedisStorage { + async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn + .hget::<String, &str, String>(format!("settings_{}", user), setting) + .await?) + } + + async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn + .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value) + .await?) + } + + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) + } + + async fn post_exists(&self, url: &str) -> Result<bool> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn.hexists::<&str, &str, bool>("posts", url).await?) + } + + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + match conn + .hget::<&str, &str, Option<String>>("posts", url) + .await? + { + Some(val) => { + let parsed = serde_json::from_str::<serde_json::Value>(&val)?; + if let Some(new_url) = parsed["see_other"].as_str() { + match conn + .hget::<&str, &str, Option<String>>("posts", new_url) + .await? + { + Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)), + None => Ok(None), + } + } else { + Ok(Some(parsed)) + } + } + None => Ok(None), + } + } + + async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + let channels = conn + .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()) + .await?; + // TODO: use streams here instead of this weird thing... how did I even write this?! + Ok(futures_util::future::join_all( + channels + .iter() + .map(|channel| { + self.get_post(channel).map(|result| result.unwrap()).map( + |post: Option<serde_json::Value>| { + post.map(|post| MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0].as_str().unwrap().to_string(), + }) + }, + ) + }) + .collect::<Vec<_>>(), + ) + .await + .into_iter() + .flatten() + .collect::<Vec<_>>()) + } + + async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + let key: &str; + match post["properties"]["uid"][0].as_str() { + Some(uid) => key = uid, + None => { + return Err(StorageError::new( + ErrorKind::BadRequest, + "post doesn't have a UID", + )) + } + } + conn.hset::<&str, &str, String, ()>("posts", key, post.to_string()) + .await?; + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap().to_string()) + { + if url != key && url.starts_with(user) { + conn.hset::<&str, &str, String, ()>( + "posts", + &url, + json!({ "see_other": key }).to_string(), + ) + .await?; + } + } + } + if post["type"] + .as_array() + .unwrap() + .iter() + .any(|i| i == "h-feed") + { + // This is a feed. Add it to the channels array if it's not already there. + conn.sadd::<String, &str, ()>( + "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), + key, + ) + .await? + } + Ok(()) + } + + async fn read_feed_with_limit<'a>( + &self, + url: &'a str, + after: &'a Option<String>, + limit: usize, + user: &'a Option<String>, + ) -> Result<Option<serde_json::Value>> { + let mut conn = self.redis.get().await?; + let mut feed; + match conn + .hget::<&str, &str, Option<String>>("posts", url) + .await + .map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))? + { + Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, + None => return Ok(None), + } + if feed["see_other"].is_string() { + match conn + .hget::<&str, &str, Option<String>>("posts", feed["see_other"].as_str().unwrap()) + .await? + { + Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, + None => return Ok(None), + } + } + if let Some(post) = filter_post(feed, user) { + feed = post + } else { + return Err(StorageError::new( + ErrorKind::PermissionDenied, + "specified user cannot access this post", + )); + } + if feed["children"].is_array() { + let children = feed["children"].as_array().unwrap(); + let mut posts_iter = children.iter().map(|i| i.as_str().unwrap().to_string()); + if after.is_some() { + loop { + let i = posts_iter.next(); + if &i == after { + break; + } + } + } + async fn fetch_post_for_feed(url: String) -> Option<serde_json::Value> { + return Some(serde_json::json!({})); + } + let posts = stream::iter(posts_iter) + .map(|url: String| async move { + return Ok(fetch_post_for_feed(url).await); + /*match self.redis.get().await { + Ok(mut conn) => { + match conn.hget::<&str, &str, Option<String>>("posts", &url).await { + Ok(post) => match post { + Some(post) => { + Ok(Some(serde_json::from_str(&post)?)) + } + // Happens because of a broken link (result of an improper deletion?) + None => Ok(None), + }, + Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error executing a Redis command", Box::new(err))) + } + } + Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(err))) + }*/ + }) + // TODO: determine the optimal value for this buffer + // It will probably depend on how often can you encounter a private post on the page + // It shouldn't be too large, or we'll start fetching too many posts from the database + // It MUST NOT be larger than the typical page size + // It MUST NOT be a significant amount of the connection pool size + //.buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips all Nones. + // I wonder if one can use try_flatten() here somehow akin to iters + .try_filter_map(|post| async move { Ok(post) }) + .try_filter_map(|post| async move { + Ok(filter_post(post, user)) + }) + .take(limit); + match posts.try_collect::<Vec<serde_json::Value>>().await { + Ok(posts) => feed["children"] = json!(posts), + Err(err) => { + let e = StorageError::with_source( + ErrorKind::Other, + "An error was encountered while processing the feed", + Box::new(err) + ); + error!("Error while assembling feed: {}", e); + return Err(e); + } + } + } + return Ok(Some(feed)); + } + + async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + if !conn + .hexists::<&str, &str, bool>("posts", url) + .await + .unwrap() + { + return Err(StorageError::new( + ErrorKind::NotFound, + "can't edit a non-existent post", + )); + } + let post: serde_json::Value = + serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; + if let Some(new_url) = post["see_other"].as_str() { + url = new_url + } + Ok(SCRIPTS + .edit_post + .key("posts") + .arg(url) + .arg(update.to_string()) + .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) + .await?) + } +} + +impl RedisStorage { + /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. + pub async fn new(redis_uri: String) -> Result<Self> { + match redis::Client::open(redis_uri) { + Ok(client) => Ok(Self { + redis: Pool::builder() + .max_open(20) + .max_idle(5) + .get_timeout(Some(Duration::from_secs(3))) + .max_lifetime(Some(Duration::from_secs(120))) + .build(RedisConnectionManager::new(client)), + }), + Err(e) => Err(e.into()), + } + } + + pub async fn conn(&self) -> Result<mobc::Connection<mobc_redis::RedisConnectionManager>> { + self.redis.get().await.map_err(|e| StorageError::with_source( + ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e) + )) + } +} + +#[cfg(test)] +pub mod tests { + use mobc_redis::redis; + use std::process; + use std::time::Duration; + + pub struct RedisInstance { + // We just need to hold on to it so it won't get dropped and remove the socket + _tempdir: tempdir::TempDir, + uri: String, + child: std::process::Child, + } + impl Drop for RedisInstance { + fn drop(&mut self) { + self.child.kill().expect("Failed to kill the child!"); + } + } + impl RedisInstance { + pub fn uri(&self) -> &str { + &self.uri + } + } + + pub async fn get_redis_instance() -> RedisInstance { + let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); + let socket = tempdir.path().join("redis.sock"); + let redis_child = process::Command::new("redis-server") + .current_dir(&tempdir) + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(&socket) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("Failed to spawn Redis"); + println!("redis+unix:///{}", socket.to_str().unwrap()); + let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); + // There should be a slight delay, we need to wait for Redis to spin up + let client = redis::Client::open(uri.clone()).unwrap(); + let millisecond = Duration::from_millis(1); + let mut retries: usize = 0; + const MAX_RETRIES: usize = 60 * 1000/*ms*/; + while let Err(err) = client.get_connection() { + if err.is_connection_refusal() { + async_std::task::sleep(millisecond).await; + retries += 1; + if retries > MAX_RETRIES { + panic!("Timeout waiting for Redis, last error: {}", err); + } + } else { + panic!("Could not connect: {}", err); + } + } + + RedisInstance { + uri, + child: redis_child, + _tempdir: tempdir, + } + } +} |