author | Vika <vika@fireburn.ru> | 2023-07-29 21:59:56 +0300
---|---|---
committer | Vika <vika@fireburn.ru> | 2023-07-29 21:59:56 +0300
commit | 0617663b249f9ca488e5de652108b17d67fbaf45 (patch) |
tree | 11564b6c8fa37bf9203a0a4cc1c4e9cc088cb1a5 /src |
parent | 26c2b79f6a6380ae3224e9309b9f3352f5717bd7 (diff) |
download | kittybox-0617663b249f9ca488e5de652108b17d67fbaf45.tar.zst |
Moved the entire Kittybox tree into the root
Diffstat (limited to 'src')
31 files changed, 9450 insertions, 0 deletions
diff --git a/src/bin/kittybox-check-webmention.rs b/src/bin/kittybox-check-webmention.rs new file mode 100644 index 0000000..f02032c --- /dev/null +++ b/src/bin/kittybox-check-webmention.rs @@ -0,0 +1,152 @@ +use std::cell::{RefCell, Ref}; +use std::rc::Rc; + +use clap::Parser; +use microformats::types::PropertyValue; +use microformats::html5ever; +use microformats::html5ever::tendril::TendrilSink; + +#[derive(thiserror::Error, Debug)] +enum Error { + #[error("http request error: {0}")] + Http(#[from] reqwest::Error), + #[error("microformats error: {0}")] + Microformats(#[from] microformats::Error), + #[error("json error: {0}")] + Json(#[from] serde_json::Error), + #[error("url parse error: {0}")] + UrlParse(#[from] url::ParseError), +} + +use kittybox_util::MentionType; + +fn check_mention(document: impl AsRef<str>, base_url: &url::Url, link: &url::Url) -> Result<Option<MentionType>, Error> { + // First, check the document for MF2 markup + let document = microformats::from_html(document.as_ref(), base_url.clone())?; + + // Get an iterator of all items + let items_iter = document.items.iter() + .map(AsRef::as_ref) + .map(RefCell::borrow); + + for item in items_iter { + let props = item.properties.borrow(); + for (prop, interaction_type) in [ + ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like), + ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost) + ] { + if let Some(propvals) = props.get(prop) { + for val in propvals { + if let PropertyValue::Url(url) = val { + if url == link { + return Ok(Some(interaction_type)) + } + } + } + } + } + // Process `content` + if let Some(PropertyValue::Fragment(content)) = props.get("content") + .map(Vec::as_slice) + .unwrap_or_default() + .first() + { + let root = html5ever::parse_document(html5ever::rcdom::RcDom::default(), Default::default()) + .from_utf8() + .one(content.html.to_owned().as_bytes()) + .document; + + // This is a trick to unwrap recursion into a loop + // + // A list of unprocessed node is made. Then, in each + // iteration, the list is "taken" and replaced with an + // empty list, which is populated with nodes for the next + // iteration of the loop. + // + // Empty list means all nodes were processed. + let mut unprocessed_nodes: Vec<Rc<html5ever::rcdom::Node>> = root.children.borrow().iter().cloned().collect(); + while unprocessed_nodes.len() > 0 { + // "Take" the list out of its memory slot, replace it with an empty list + let nodes = std::mem::take(&mut unprocessed_nodes); + 'nodes_loop: for node in nodes.into_iter() { + // Add children nodes to the list for the next iteration + unprocessed_nodes.extend(node.children.borrow().iter().cloned()); + + if let html5ever::rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data { + // If it's not `<a>`, skip it + if name.local != *"a" { continue; } + let mut is_mention: bool = false; + for attr in attrs.borrow().iter() { + if attr.name.local == *"rel" { + // Don't count `rel="nofollow"` links — a web crawler should ignore them + // and so for purposes of driving visitors they are useless + if attr.value + .as_ref() + .split([',', ' ']) + .any(|v| v == "nofollow") + { + // Skip the entire node. 
+ continue 'nodes_loop; + } + } + // if it's not `<a href="...">`, skip it + if attr.name.local != *"href" { continue; } + // Be forgiving in parsing URLs, and resolve them against the base URL + if let Ok(url) = base_url.join(attr.value.as_ref()) { + if &url == link { + is_mention = true; + } + } + } + if is_mention { + return Ok(Some(MentionType::Mention)); + } + } + } + } + + } + } + + Ok(None) +} + +#[derive(Parser, Debug)] +#[clap( + name = "kittybox-check-webmention", + author = "Vika <vika@fireburn.ru>", + version = env!("CARGO_PKG_VERSION"), + about = "Verify an incoming webmention" +)] +struct Args { + #[clap(value_parser)] + url: url::Url, + #[clap(value_parser)] + link: url::Url +} + +#[tokio::main] +async fn main() -> Result<(), self::Error> { + let args = Args::parse(); + + let http: reqwest::Client = { + #[allow(unused_mut)] + let mut builder = reqwest::Client::builder() + .user_agent(concat!( + env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") + )); + + builder.build().unwrap() + }; + + let response = http.get(args.url.clone()).send().await?; + let text = response.text().await?; + + if let Some(mention_type) = check_mention(text, &args.url, &args.link)? { + println!("{:?}", mention_type); + + Ok(()) + } else { + std::process::exit(1) + } +} diff --git a/src/bin/kittybox-indieauth-helper.rs b/src/bin/kittybox-indieauth-helper.rs new file mode 100644 index 0000000..3377ec3 --- /dev/null +++ b/src/bin/kittybox-indieauth-helper.rs @@ -0,0 +1,233 @@ +use kittybox_indieauth::{ + AuthorizationRequest, PKCEVerifier, + PKCEChallenge, PKCEMethod, GrantRequest, Scope, + AuthorizationResponse, TokenData, GrantResponse +}; +use clap::Parser; +use std::{borrow::Cow, io::Write}; + +const DEFAULT_CLIENT_ID: &str = "https://kittybox.fireburn.ru/indieauth-helper.html"; +const DEFAULT_REDIRECT_URI: &str = "http://localhost:60000/callback"; + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("i/o error: {0}")] + IO(#[from] std::io::Error), + #[error("http request error: {0}")] + HTTP(#[from] reqwest::Error), + #[error("urlencoded encoding error: {0}")] + UrlencodedEncoding(#[from] serde_urlencoded::ser::Error), + #[error("url parsing error: {0}")] + UrlParse(#[from] url::ParseError), + #[error("indieauth flow error: {0}")] + IndieAuth(Cow<'static, str>) +} + +#[derive(Parser, Debug)] +#[clap( + name = "kittybox-indieauth-helper", + author = "Vika <vika@fireburn.ru>", + version = env!("CARGO_PKG_VERSION"), + about = "Retrieve an IndieAuth token for debugging", + long_about = None +)] +struct Args { + /// Profile URL to use for initiating IndieAuth metadata discovery. + #[clap(value_parser)] + me: url::Url, + /// Scopes to request for the token. + /// + /// All IndieAuth scopes are supported, including arbitrary custom scopes. + #[clap(short, long)] + scope: Vec<Scope>, + /// Client ID to use when requesting a token. + #[clap(short, long, value_parser, default_value = DEFAULT_CLIENT_ID)] + client_id: url::Url, + /// Redirect URI to declare. Note: This will break the flow, use only for testing UI. 
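+ ///
+ /// The helper expects the authorization response to arrive at
+ /// `http://localhost:60000/callback` (see `DEFAULT_REDIRECT_URI`);
+ /// with a custom redirect URI it exits before listening, so no
+ /// token can be obtained.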
+ #[clap(long, value_parser)] + redirect_uri: Option<url::Url> +} + +fn append_query_string<T: serde::Serialize>( + url: &url::Url, + query: T +) -> Result<url::Url, Error> { + let mut new_url = url.clone(); + let mut query = serde_urlencoded::to_string(query)?; + if let Some(old_query) = url.query() { + query.push('&'); + query.push_str(old_query); + } + new_url.set_query(Some(&query)); + + Ok(new_url) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + let args = Args::parse(); + + let http: reqwest::Client = { + #[allow(unused_mut)] + let mut builder = reqwest::Client::builder() + .user_agent(concat!( + env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") + )); + + builder.build().unwrap() + }; + + let redirect_uri: url::Url = args.redirect_uri + .clone() + .unwrap_or_else(|| DEFAULT_REDIRECT_URI.parse().unwrap()); + + eprintln!("Checking .well-known for metadata..."); + let metadata = http.get(args.me.join("/.well-known/oauth-authorization-server")?) + .header("Accept", "application/json") + .send() + .await? + .json::<kittybox_indieauth::Metadata>() + .await?; + + let verifier = PKCEVerifier::new(); + + let authorization_request = AuthorizationRequest { + response_type: kittybox_indieauth::ResponseType::Code, + client_id: args.client_id.clone(), + redirect_uri: redirect_uri.clone(), + state: kittybox_indieauth::State::new(), + code_challenge: PKCEChallenge::new(&verifier, PKCEMethod::default()), + scope: Some(kittybox_indieauth::Scopes::new(args.scope)), + me: Some(args.me) + }; + + let indieauth_url = append_query_string( + &metadata.authorization_endpoint, + authorization_request + )?; + + eprintln!("Please visit the following URL in your browser:\n\n {}\n", indieauth_url.as_str()); + + if args.redirect_uri.is_some() { + eprintln!("Custom redirect URI specified, won't be able to catch authorization response."); + std::process::exit(0); + } + + // Prepare a callback + let (tx, rx) = tokio::sync::oneshot::channel::<AuthorizationResponse>(); + let server = { + use axum::{routing::get, extract::Query, response::IntoResponse}; + + let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx))); + + let router = axum::Router::new() + .route("/callback", axum::routing::get( + move |query: Option<Query<AuthorizationResponse>>| async move { + if let Some(Query(response)) = query { + if let Some(tx) = tx.lock_owned().await.take() { + tx.send(response).unwrap(); + + (axum::http::StatusCode::OK, + [("Content-Type", "text/plain")], + "Thank you! This window can now be closed.") + .into_response() + } else { + (axum::http::StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "Oops. The callback was already received. 
Did you click twice?") + .into_response() + } + } else { + axum::http::StatusCode::BAD_REQUEST.into_response() + } + } + )); + + use std::net::{SocketAddr, IpAddr, Ipv4Addr}; + + let server = hyper::server::Server::bind( + &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST),60000) + ) + .serve(router.into_make_service()); + + tokio::task::spawn(server) + }; + + let authorization_response = rx.await.unwrap(); + + // Clean up after the server + tokio::task::spawn(async move { + // Wait for the server to settle -- it might need to send its response + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + // Abort the future -- this should kill the server + server.abort(); + }); + + eprintln!("Got authorization response: {:#?}", authorization_response); + eprint!("Checking issuer field..."); + std::io::stderr().lock().flush()?; + + if dbg!(authorization_response.iss.as_str()) == dbg!(metadata.issuer.as_str()) { + eprintln!(" Done"); + } else { + eprintln!(" Failed"); + #[cfg(not(debug_assertions))] + std::process::exit(1); + } + let grant_response: GrantResponse = http.post(metadata.token_endpoint) + .form(&GrantRequest::AuthorizationCode { + code: authorization_response.code, + client_id: args.client_id, + redirect_uri, + code_verifier: verifier + }) + .header("Accept", "application/json") + .send() + .await? + .json() + .await?; + + if let GrantResponse::AccessToken { + me, + profile, + access_token, + expires_in, + refresh_token, + token_type, + scope + } = grant_response { + eprintln!("Congratulations, {}, access token is ready! {}", + me.as_str(), + if let Some(exp) = expires_in { + format!("It expires in {exp} seconds.") + } else { + format!("It seems to have unlimited duration.") + } + ); + println!("{}", access_token); + if let Some(refresh_token) = refresh_token { + eprintln!("Save this refresh token, it will come in handy:"); + println!("{}", refresh_token); + }; + + if let Some(profile) = profile { + eprintln!("\nThe token endpoint returned some profile information:"); + if let Some(name) = profile.name { + eprintln!(" - Name: {name}") + } + if let Some(url) = profile.url { + eprintln!(" - URL: {url}") + } + if let Some(photo) = profile.photo { + eprintln!(" - Photo: {photo}") + } + if let Some(email) = profile.email { + eprintln!(" - Email: {email}") + } + } + + Ok(()) + } else { + return Err(Error::IndieAuth(Cow::Borrowed("IndieAuth token endpoint did not return an access token grant."))); + } +} diff --git a/src/bin/kittybox-mf2.rs b/src/bin/kittybox-mf2.rs new file mode 100644 index 0000000..4366cb8 --- /dev/null +++ b/src/bin/kittybox-mf2.rs @@ -0,0 +1,49 @@ +use clap::Parser; + +#[derive(Parser, Debug)] +#[clap( + name = "kittybox-mf2", + author = "Vika <vika@fireburn.ru>", + version = env!("CARGO_PKG_VERSION"), + about = "Fetch HTML and turn it into MF2-JSON" +)] +struct Args { + #[clap(value_parser)] + url: url::Url, +} + +#[derive(thiserror::Error, Debug)] +enum Error { + #[error("http request error: {0}")] + Http(#[from] reqwest::Error), + #[error("microformats error: {0}")] + Microformats(#[from] microformats::Error), + #[error("json error: {0}")] + Json(#[from] serde_json::Error), + #[error("url parse error: {0}")] + UrlParse(#[from] url::ParseError), +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + let args = Args::parse(); + + let http: reqwest::Client = { + #[allow(unused_mut)] + let mut builder = reqwest::Client::builder() + .user_agent(concat!( + env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") + )); + + builder.build().unwrap() + }; + + 
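+ // Fetch the document over HTTP; `args.url` doubles as the base URL
+ // that the microformats parser uses to resolve relative links.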
let response = http.get(args.url.clone()).send().await?; + let text = response.text().await?; + + let mf2 = microformats::from_html(text.as_ref(), args.url)?; + + println!("{}", serde_json::to_string_pretty(&mf2)?); + + Ok(()) +} diff --git a/src/bin/kittybox_bulk_import.rs b/src/bin/kittybox_bulk_import.rs new file mode 100644 index 0000000..7e1f6af --- /dev/null +++ b/src/bin/kittybox_bulk_import.rs @@ -0,0 +1,66 @@ +use anyhow::{anyhow, bail, Context, Result}; +use std::fs::File; +use std::io; + +#[async_std::main] +async fn main() -> Result<()> { + let args = std::env::args().collect::<Vec<String>>(); + if args.iter().skip(1).any(|s| s == "--help") { + println!("Usage: {} <url> [file]", args[0]); + println!("\nIf launched with no arguments, reads from stdin."); + println!( + "\nUse KITTYBOX_AUTH_TOKEN environment variable to authorize to the Micropub endpoint." + ); + std::process::exit(0); + } + + let token = std::env::var("KITTYBOX_AUTH_TOKEN") + .map_err(|_| anyhow!("No auth token found! Use KITTYBOX_AUTH_TOKEN env variable."))?; + let data: Vec<serde_json::Value> = (if args.len() == 2 || (args.len() == 3 && args[2] == "-") { + serde_json::from_reader(io::stdin()) + } else if args.len() == 3 { + serde_json::from_reader(File::open(&args[2]).with_context(|| "Error opening input file")?) + } else { + bail!("See `{} --help` for usage.", args[0]); + }) + .with_context(|| "Error while loading the input file")?; + + let url = surf::Url::parse(&args[1])?; + let client = surf::Client::new(); + + let iter = data.into_iter(); + + for post in iter { + println!( + "Processing {}...", + post["properties"]["url"][0] + .as_str() + .or_else(|| post["properties"]["published"][0] + .as_str() + .or_else(|| post["properties"]["name"][0] + .as_str() + .or(Some("<unidentified post>")))) + .unwrap() + ); + match client + .post(&url) + .body(surf::http::Body::from_string(serde_json::to_string(&post)?)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", &token)) + .send() + .await + { + Ok(mut response) => { + if response.status() == 201 || response.status() == 202 { + println!("Posted at {}", response.header("location").unwrap().last()); + } else { + println!("Error: {:?}", response.body_string().await); + } + } + Err(err) => { + println!("{}", err); + } + } + } + Ok(()) +} diff --git a/src/bin/kittybox_database_converter.rs b/src/bin/kittybox_database_converter.rs new file mode 100644 index 0000000..bc355c9 --- /dev/null +++ b/src/bin/kittybox_database_converter.rs @@ -0,0 +1,106 @@ +use anyhow::{anyhow, Context}; +use kittybox::database::FileStorage; +use kittybox::database::Storage; +use redis::{self, AsyncCommands}; +use std::collections::HashMap; + +/// Convert from a Redis storage to a new storage new_storage. 
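+///
+/// A minimal, hypothetical usage sketch (the Redis URL and destination
+/// path are placeholders; `main` below does the same wiring from argv):
+///
+/// ```no_run
+/// # use kittybox::database::FileStorage;
+/// # async fn example() -> anyhow::Result<()> {
+/// let storage = FileStorage::new("/srv/kittybox".into()).await?;
+/// convert_from_redis("redis://127.0.0.1:6379/".to_string(), storage).await?;
+/// # Ok(())
+/// # }
+/// ```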
+async fn convert_from_redis<S: Storage>(from: String, new_storage: S) -> anyhow::Result<()> { + let db = redis::Client::open(from).context("Failed to open the Redis connection")?; + + let mut conn = db + .get_async_std_connection() + .await + .context("Failed to connect to Redis")?; + + // Rebinding to convince the borrow checker we're not smuggling stuff outta scope + let storage = &new_storage; + + let mut stream = conn.hscan::<_, String>("posts").await?; + + while let Some(key) = stream.next_item().await { + let value = serde_json::from_str::<serde_json::Value>( + &stream + .next_item() + .await + .ok_or(anyhow!("Failed to find a corresponding value for the key"))?, + )?; + + println!("{}, {:?}", key, value); + + if value["see_other"].is_string() { + continue; + } + + let user = &(url::Url::parse(value["properties"]["uid"][0].as_str().unwrap()) + .unwrap() + .origin() + .ascii_serialization() + .clone() + + "/"); + if let Err(err) = storage.clone().put_post(&value, user).await { + eprintln!("Error saving post: {}", err); + } + } + + let mut stream: redis::AsyncIter<String> = conn.scan_match("settings_*").await?; + while let Some(key) = stream.next_item().await { + let mut conn = db + .get_async_std_connection() + .await + .context("Failed to connect to Redis")?; + let user = key.strip_prefix("settings_").unwrap(); + match conn + .hgetall::<&str, HashMap<String, String>>(&key) + .await + .context(format!("Failed getting settings from key {}", key)) + { + Ok(settings) => { + for (k, v) in settings.iter() { + if let Err(e) = storage + .set_setting(k, user, v) + .await + .with_context(|| format!("Failed setting {} for {}", k, user)) + { + eprintln!("{}", e); + } + } + } + Err(e) => { + eprintln!("{}", e); + } + } + } + + Ok(()) +} + +#[async_std::main] +async fn main() -> anyhow::Result<()> { + let mut args = std::env::args(); + args.next(); // skip argv[0] + let old_uri = args + .next() + .ok_or_else(|| anyhow!("No import source is provided."))?; + let new_uri = args + .next() + .ok_or_else(|| anyhow!("No import destination is provided."))?; + + let storage = if new_uri.starts_with("file:") { + let folder = new_uri.strip_prefix("file://").unwrap(); + let path = std::path::PathBuf::from(folder); + Box::new( + FileStorage::new(path) + .await + .context("Failed to construct the file storage")?, + ) + } else { + anyhow::bail!("Cannot construct the storage abstraction for destination storage. Check the storage type?"); + }; + + if old_uri.starts_with("redis") { + convert_from_redis(old_uri, *storage).await? 
+ } + + Ok(()) +} diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs new file mode 100644 index 0000000..27d3da1 --- /dev/null +++ b/src/database/file/mod.rs @@ -0,0 +1,733 @@ +//#![warn(clippy::unwrap_used)] +use crate::database::{ErrorKind, Result, settings, Storage, StorageError}; +use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; +use async_trait::async_trait; +use futures::{stream, StreamExt, TryStreamExt}; +use kittybox_util::MentionType; +use serde_json::json; +use std::borrow::Cow; +use std::collections::HashMap; +use std::io::ErrorKind as IOErrorKind; +use std::path::{Path, PathBuf}; +use tokio::fs::{File, OpenOptions}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::task::spawn_blocking; +use tracing::{debug, error}; + +impl From<std::io::Error> for StorageError { + fn from(source: std::io::Error) -> Self { + Self::with_source( + match source.kind() { + IOErrorKind::NotFound => ErrorKind::NotFound, + IOErrorKind::AlreadyExists => ErrorKind::Conflict, + _ => ErrorKind::Backend, + }, + Cow::Owned(format!("file I/O error: {}", &source)), + Box::new(source), + ) + } +} + +impl From<tokio::time::error::Elapsed> for StorageError { + fn from(source: tokio::time::error::Elapsed) -> Self { + Self::with_source( + ErrorKind::Backend, + Cow::Borrowed("timeout on I/O operation"), + Box::new(source), + ) + } +} + +// Copied from https://stackoverflow.com/questions/39340924 +// This routine is adapted from the *old* Path's `path_relative_from` +// function, which works differently from the new `relative_from` function. +// In particular, this handles the case on unix where both paths are +// absolute but with only the root as the common directory. +fn path_relative_from(path: &Path, base: &Path) -> Option<PathBuf> { + use std::path::Component; + + if path.is_absolute() != base.is_absolute() { + if path.is_absolute() { + Some(PathBuf::from(path)) + } else { + None + } + } else { + let mut ita = path.components(); + let mut itb = base.components(); + let mut comps: Vec<Component> = vec![]; + loop { + match (ita.next(), itb.next()) { + (None, None) => break, + (Some(a), None) => { + comps.push(a); + comps.extend(ita.by_ref()); + break; + } + (None, _) => comps.push(Component::ParentDir), + (Some(a), Some(b)) if comps.is_empty() && a == b => (), + (Some(a), Some(b)) if b == Component::CurDir => comps.push(a), + (Some(_), Some(b)) if b == Component::ParentDir => return None, + (Some(a), Some(_)) => { + comps.push(Component::ParentDir); + for _ in itb { + comps.push(Component::ParentDir); + } + comps.push(a); + comps.extend(ita.by_ref()); + break; + } + } + } + Some(comps.iter().map(|c| c.as_os_str()).collect()) + } +} + +#[allow(clippy::unwrap_used, clippy::expect_used)] +#[cfg(test)] +mod tests { + #[test] + fn test_relative_path_resolving() { + let path1 = std::path::Path::new("/home/vika/Projects/kittybox"); + let path2 = std::path::Path::new("/home/vika/Projects/nixpkgs"); + let relative_path = super::path_relative_from(path2, path1).unwrap(); + + assert_eq!(relative_path, std::path::Path::new("../nixpkgs")) + } +} + +// TODO: Check that the path ACTUALLY IS INSIDE THE ROOT FOLDER +// This could be checked by completely resolving the path +// and checking if it has a common prefix +fn url_to_path(root: &Path, url: &str) -> PathBuf { + let path = url_to_relative_path(url).to_logical_path(root); + if !path.starts_with(root) { + // TODO: handle more gracefully + panic!("Security error: {:?} is not a prefix of {:?}", path, root) + } else { + path + } +} + +fn 
url_to_relative_path(url: &str) -> relative_path::RelativePathBuf { + let url = url::Url::try_from(url).expect("Couldn't parse a URL"); + let mut path = relative_path::RelativePathBuf::new(); + let user_domain = format!( + "{}{}", + url.host_str().unwrap(), + url.port() + .map(|port| format!(":{}", port)) + .unwrap_or_default() + ); + path.push(user_domain + url.path() + ".json"); + + path +} + +fn modify_post(post: &serde_json::Value, update: MicropubUpdate) -> Result<serde_json::Value> { + let mut post = post.clone(); + + let mut add_keys: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + let mut remove_keys: Vec<String> = vec![]; + let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + + if let Some(MicropubPropertyDeletion::Properties(delete)) = update.delete { + remove_keys.extend(delete.iter().cloned()); + } else if let Some(MicropubPropertyDeletion::Values(delete)) = update.delete { + for (k, v) in delete { + remove_values + .entry(k.to_string()) + .or_default() + .extend(v.clone()); + } + } + if let Some(add) = update.add { + for (k, v) in add { + add_keys.insert(k.to_string(), v.clone()); + } + } + if let Some(replace) = update.replace { + for (k, v) in replace { + remove_keys.push(k.to_string()); + add_keys.insert(k.to_string(), v.clone()); + } + } + + if let Some(props) = post["properties"].as_object_mut() { + for k in remove_keys { + props.remove(&k); + } + } + for (k, v) in remove_values { + let k = &k; + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + v.iter().for_each(|v| { + if let Some(vec) = props[k].as_array_mut() { + if let Some(index) = vec.iter().position(|w| w == v) { + vec.remove(index); + } + } + }); + } + for (k, v) in add_keys { + tracing::debug!("Adding k/v to post: {} => {:?}", k, v); + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + if let Some(prop) = props[&k].as_array_mut() { + if k == "children" { + v.into_iter().rev().for_each(|v| prop.insert(0, v)); + } else { + prop.extend(v.into_iter()); + } + } else { + props[&k] = serde_json::Value::Array(v) + } + } + Ok(post) +} + +#[derive(Clone, Debug)] +/// A backend using a folder with JSON files as a backing store. +/// Uses symbolic links to represent a many-to-one mapping of URLs to a post. +pub struct FileStorage { + root_dir: PathBuf, +} + +impl FileStorage { + /// Create a new storage wrapping a folder specified by root_dir. + pub async fn new(root_dir: PathBuf) -> Result<Self> { + // TODO check if the dir is writable + Ok(Self { root_dir }) + } +} + +async fn hydrate_author<S: Storage>( + feed: &mut serde_json::Value, + user: &'_ Option<String>, + storage: &S, +) { + let url = feed["properties"]["uid"][0] + .as_str() + .expect("MF2 value should have a UID set! 
Check if you used normalize_mf2 before recording the post!"); + if let Some(author) = feed["properties"]["author"].as_array().cloned() { + if !feed["type"] + .as_array() + .expect("MF2 value should have a type set!") + .iter() + .any(|i| i == "h-card") + { + let author_list: Vec<serde_json::Value> = stream::iter(author.iter()) + .then(|i| async move { + if let Some(i) = i.as_str() { + match storage.get_post(i).await { + Ok(post) => match post { + Some(post) => post, + None => json!(i), + }, + Err(e) => { + error!("Error while hydrating post {}: {}", url, e); + json!(i) + } + } + } else { + i.clone() + } + }) + .collect::<Vec<_>>() + .await; + if let Some(props) = feed["properties"].as_object_mut() { + props["author"] = json!(author_list); + } else { + feed["properties"] = json!({ "author": author_list }); + } + } + } +} + +#[async_trait] +impl Storage for FileStorage { + #[tracing::instrument(skip(self))] + async fn post_exists(&self, url: &str) -> Result<bool> { + let path = url_to_path(&self.root_dir, url); + debug!("Checking if {:?} exists...", path); + /*let result = match tokio::fs::metadata(path).await { + Ok(metadata) => { + Ok(true) + }, + Err(err) => { + if err.kind() == IOErrorKind::NotFound { + Ok(false) + } else { + Err(err.into()) + } + } + };*/ + #[allow(clippy::unwrap_used)] // JoinHandle captures panics, this closure shouldn't panic + Ok(spawn_blocking(move || path.is_file()).await.unwrap()) + } + + #[tracing::instrument(skip(self))] + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + let path = url_to_path(&self.root_dir, url); + // TODO: check that the path actually belongs to the dir of user who requested it + // it's not like you CAN access someone else's private posts with it + // so it's not exactly a security issue, but it's still not good + debug!("Opening {:?}", path); + + match File::open(&path).await { + Ok(mut file) => { + let mut content = String::new(); + // Typechecks because OS magic acts on references + // to FDs as if they were behind a mutex + AsyncReadExt::read_to_string(&mut file, &mut content).await?; + debug!( + "Read {} bytes successfully from {:?}", + content.as_bytes().len(), + &path + ); + Ok(Some(serde_json::from_str(&content)?)) + } + Err(err) => { + if err.kind() == IOErrorKind::NotFound { + Ok(None) + } else { + Err(err.into()) + } + } + } + } + + #[tracing::instrument(skip(self))] + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + let key = post["properties"]["uid"][0] + .as_str() + .expect("Tried to save a post without UID"); + let path = url_to_path(&self.root_dir, key); + let tempfile = (&path).with_extension("tmp"); + debug!("Creating {:?}", path); + + let parent = path + .parent() + .expect("Parent for this directory should always exist") + .to_owned(); + tokio::fs::create_dir_all(&parent).await?; + + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfile) + .await?; + + file.write_all(post.to_string().as_bytes()).await?; + file.flush().await?; + file.sync_all().await?; + drop(file); + tokio::fs::rename(&tempfile, &path).await?; + tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + + if let Some(urls) = post["properties"]["url"].as_array() { + for url in urls.iter().map(|i| i.as_str().unwrap()) { + let url_domain = { + let url = url::Url::parse(url).unwrap(); + format!( + "{}{}", + url.host_str().unwrap(), + url.port() + .map(|port| format!(":{}", port)) + .unwrap_or_default() + ) + }; + if url != key && 
url_domain == user { + let link = url_to_path(&self.root_dir, url); + debug!("Creating a symlink at {:?}", link); + let orig = path.clone(); + // We're supposed to have a parent here. + let basedir = link.parent().ok_or_else(|| { + StorageError::from_static( + ErrorKind::Backend, + "Failed to calculate parent directory when creating a symlink", + ) + })?; + let relative = path_relative_from(&orig, basedir).unwrap(); + println!("{:?} - {:?} = {:?}", &orig, &basedir, &relative); + tokio::fs::symlink(relative, link).await?; + } + } + } + + if post["type"] + .as_array() + .unwrap() + .iter() + .any(|s| s.as_str() == Some("h-feed")) + { + tracing::debug!("Adding to channel list..."); + // Add the h-feed to the channel list + let path = { + let mut path = relative_path::RelativePathBuf::new(); + path.push(user); + path.push("channels"); + + path.to_path(&self.root_dir) + }; + tokio::fs::create_dir_all(path.parent().unwrap()).await?; + tracing::debug!("Channels file path: {}", path.display()); + let tempfilename = path.with_extension("tmp"); + let channel_name = post["properties"]["name"][0] + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(String::default); + let key = key.to_string(); + tracing::debug!("Opening temporary file to modify chnanels..."); + let mut tempfile = OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfilename) + .await?; + tracing::debug!("Opening real channel file..."); + let mut channels: Vec<super::MicropubChannel> = { + match OpenOptions::new() + .read(true) + .write(false) + .truncate(false) + .create(false) + .open(&path) + .await + { + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + Vec::default() + } + Err(err) => { + // Propagate the error upwards + return Err(err.into()); + } + Ok(mut file) => { + let mut content = String::new(); + file.read_to_string(&mut content).await?; + drop(file); + + if !content.is_empty() { + serde_json::from_str(&content)? 
+ } else { + Vec::default() + } + } + } + }; + + channels.push(super::MicropubChannel { + uid: key.to_string(), + name: channel_name, + }); + + tempfile + .write_all(serde_json::to_string(&channels)?.as_bytes()) + .await?; + tempfile.flush().await?; + tempfile.sync_all().await?; + drop(tempfile); + tokio::fs::rename(tempfilename, &path).await?; + tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + } + Ok(()) + } + + #[tracing::instrument(skip(self))] + async fn update_post(&self, url: &str, update: MicropubUpdate) -> Result<()> { + let path = url_to_path(&self.root_dir, url); + let tempfilename = path.with_extension("tmp"); + #[allow(unused_variables)] + let (old_json, new_json) = { + let mut temp = OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfilename) + .await?; + let mut file = OpenOptions::new().read(true).open(&path).await?; + + let mut content = String::new(); + file.read_to_string(&mut content).await?; + let json: serde_json::Value = serde_json::from_str(&content)?; + drop(file); + // Apply the editing algorithms + let new_json = modify_post(&json, update)?; + + temp.write_all(new_json.to_string().as_bytes()).await?; + temp.flush().await?; + temp.sync_all().await?; + drop(temp); + tokio::fs::rename(tempfilename, &path).await?; + tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + + (json, new_json) + }; + // TODO check if URLs changed between old and new JSON + Ok(()) + } + + #[tracing::instrument(skip(self))] + async fn get_channels(&self, user: &'_ str) -> Result<Vec<super::MicropubChannel>> { + let mut path = relative_path::RelativePathBuf::new(); + path.push(user); + path.push("channels"); + + let path = path.to_path(&self.root_dir); + tracing::debug!("Channels file path: {}", path.display()); + + match File::open(&path).await { + Ok(mut f) => { + let mut content = String::new(); + f.read_to_string(&mut content).await?; + // This should not happen, but if it does, handle it gracefully + if content.is_empty() { + return Ok(vec![]); + } + let channels: Vec<super::MicropubChannel> = serde_json::from_str(&content)?; + Ok(channels) + } + Err(e) => { + if e.kind() == IOErrorKind::NotFound { + Ok(vec![]) + } else { + Err(e.into()) + } + } + } + } + + async fn read_feed_with_cursor( + &self, + url: &'_ str, + cursor: Option<&'_ str>, + limit: usize, + user: Option<&'_ str> + ) -> Result<Option<(serde_json::Value, Option<String>)>> { + Ok(self.read_feed_with_limit( + url, + &cursor.map(|v| v.to_owned()), + limit, + &user.map(|v| v.to_owned()) + ).await? + .map(|feed| { + tracing::debug!("Feed: {:#}", serde_json::Value::Array( + feed["children"] + .as_array() + .map(|v| v.as_slice()) + .unwrap_or_default() + .iter() + .map(|mf2| mf2["properties"]["uid"][0].clone()) + .collect::<Vec<_>>() + )); + let cursor: Option<String> = feed["children"] + .as_array() + .map(|v| v.as_slice()) + .unwrap_or_default() + .last() + .map(|v| v["properties"]["uid"][0].as_str().unwrap().to_owned()); + tracing::debug!("Extracted the cursor: {:?}", cursor); + (feed, cursor) + }) + ) + } + + #[tracing::instrument(skip(self))] + async fn read_feed_with_limit( + &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>> { + if let Some(mut feed) = self.get_post(url).await? 
{ + if feed["children"].is_array() { + // Take this out of the MF2-JSON document to save memory + // + // This uses a clever match with enum destructuring + // to extract the underlying Vec without cloning it + let children: Vec<serde_json::Value> = match feed["children"].take() { + serde_json::Value::Array(children) => children, + // We've already checked it's an array + _ => unreachable!() + }; + tracing::debug!("Full children array: {:#}", serde_json::Value::Array(children.clone())); + let mut posts_iter = children + .into_iter() + .map(|s: serde_json::Value| s.as_str().unwrap().to_string()); + // Note: we can't actually use `skip_while` here because we end up emitting `after`. + // This imperative snippet consumes after instead of emitting it, allowing the + // stream of posts to return only those items that truly come *after* that one. + // If I would implement an Iter combinator like this, I would call it `skip_until` + if let Some(after) = after { + for s in posts_iter.by_ref() { + if &s == after { + break; + } + } + }; + let posts = stream::iter(posts_iter) + .map(|url: String| async move { self.get_post(&url).await }) + .buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips Nones. + .try_filter_map(|post: Option<serde_json::Value>| async move { Ok(post) }) + .and_then(|mut post| async move { + hydrate_author(&mut post, user, self).await; + Ok(post) + }) + .take(limit); + + match posts.try_collect::<Vec<serde_json::Value>>().await { + Ok(posts) => feed["children"] = serde_json::json!(posts), + Err(err) => { + return Err(StorageError::with_source( + ErrorKind::Other, + Cow::Owned(format!("Feed assembly error: {}", &err)), + Box::new(err), + )); + } + } + } + hydrate_author(&mut feed, user, self).await; + Ok(Some(feed)) + } else { + Ok(None) + } + } + + #[tracing::instrument(skip(self))] + async fn delete_post(&self, url: &'_ str) -> Result<()> { + let path = url_to_path(&self.root_dir, url); + if let Err(e) = tokio::fs::remove_file(path).await { + Err(e.into()) + } else { + // TODO check for dangling references in the channel list + Ok(()) + } + } + + #[tracing::instrument(skip(self))] + async fn get_setting<S: settings::Setting<'a>, 'a>(&self, user: &'_ str) -> Result<S> { + debug!("User for getting settings: {}", user); + let mut path = relative_path::RelativePathBuf::new(); + path.push(user); + path.push("settings"); + + let path = path.to_path(&self.root_dir); + debug!("Getting settings from {:?}", &path); + + let mut file = File::open(path).await?; + let mut content = String::new(); + file.read_to_string(&mut content).await?; + + let settings: HashMap<&str, serde_json::Value> = serde_json::from_str(&content)?; + match settings.get(S::ID) { + Some(value) => Ok(serde_json::from_value::<S>(value.clone())?), + None => Err(StorageError::from_static(ErrorKind::Backend, "Setting not set")) + } + } + + #[tracing::instrument(skip(self))] + async fn set_setting<S: settings::Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { + let mut path = relative_path::RelativePathBuf::new(); + path.push(user); + path.push("settings"); + + let path = path.to_path(&self.root_dir); + let temppath = path.with_extension("tmp"); + + let parent = path.parent().unwrap().to_owned(); + tokio::fs::create_dir_all(&parent).await?; + + let mut tempfile = OpenOptions::new() + .write(true) + .create_new(true) + .open(&temppath) + .await?; + + let mut settings: HashMap<String, serde_json::Value> 
= match File::open(&path).await { + Ok(mut f) => { + let mut content = String::new(); + f.read_to_string(&mut content).await?; + if content.is_empty() { + Default::default() + } else { + serde_json::from_str(&content)? + } + } + Err(err) => { + if err.kind() == IOErrorKind::NotFound { + Default::default() + } else { + return Err(err.into()); + } + } + }; + settings.insert(S::ID.to_owned(), serde_json::to_value(S::new(value))?); + + tempfile + .write_all(serde_json::to_string(&settings)?.as_bytes()) + .await?; + tempfile.flush().await?; + tempfile.sync_all().await?; + drop(tempfile); + tokio::fs::rename(temppath, &path).await?; + tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + Ok(()) + } + + #[tracing::instrument(skip(self))] + async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + let path = url_to_path(&self.root_dir, target); + let tempfilename = path.with_extension("tmp"); + + let mut temp = OpenOptions::new() + .write(true) + .create_new(true) + .open(&tempfilename) + .await?; + let mut file = OpenOptions::new().read(true).open(&path).await?; + + let mut post: serde_json::Value = { + let mut content = String::new(); + file.read_to_string(&mut content).await?; + drop(file); + + serde_json::from_str(&content)? + }; + + let key: &'static str = match mention_type { + MentionType::Reply => "comment", + MentionType::Like => "like", + MentionType::Repost => "repost", + MentionType::Bookmark => "bookmark", + MentionType::Mention => "mention", + }; + let mention_uid = mention["properties"]["uid"][0].clone(); + if let Some(values) = post["properties"][key].as_array_mut() { + for value in values.iter_mut() { + if value["properties"]["uid"][0] == mention_uid { + *value = mention; + break; + } + } + } else { + post["properties"][key] = serde_json::Value::Array(vec![mention]); + } + + temp.write_all(post.to_string().as_bytes()).await?; + temp.flush().await?; + temp.sync_all().await?; + drop(temp); + tokio::fs::rename(tempfilename, &path).await?; + tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + + Ok(()) + } +} diff --git a/src/database/memory.rs b/src/database/memory.rs new file mode 100644 index 0000000..6339e7a --- /dev/null +++ b/src/database/memory.rs @@ -0,0 +1,249 @@ +#![allow(clippy::todo)] +use async_trait::async_trait; +use futures_util::FutureExt; +use serde_json::json; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError}; + +#[derive(Clone, Debug)] +pub struct MemoryStorage { + pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>, + pub channels: Arc<RwLock<HashMap<String, Vec<String>>>>, +} + +#[async_trait] +impl Storage for MemoryStorage { + async fn post_exists(&self, url: &str) -> Result<bool> { + return Ok(self.mapping.read().await.contains_key(url)); + } + + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + let mapping = self.mapping.read().await; + match mapping.get(url) { + Some(val) => { + if let Some(new_url) = val["see_other"].as_str() { + match mapping.get(new_url) { + Some(val) => Ok(Some(val.clone())), + None => { + drop(mapping); + self.mapping.write().await.remove(url); + Ok(None) + } + } + } else { + Ok(Some(val.clone())) + } + } + _ => Ok(None), + } + } + + async fn put_post(&self, post: &'_ serde_json::Value, _user: &'_ str) -> Result<()> { + let mapping = &mut self.mapping.write().await; + let 
key: &str = match post["properties"]["uid"][0].as_str() { + Some(uid) => uid, + None => { + return Err(StorageError::from_static( + ErrorKind::Other, + "post doesn't have a UID", + )) + } + }; + mapping.insert(key.to_string(), post.clone()); + if post["properties"]["url"].is_array() { + for url in post["properties"]["url"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap().to_string()) + { + if url != key { + mapping.insert(url, json!({ "see_other": key })); + } + } + } + if post["type"] + .as_array() + .unwrap() + .iter() + .any(|i| i == "h-feed") + { + // This is a feed. Add it to the channels array if it's not already there. + println!("{:#}", post); + self.channels + .write() + .await + .entry( + post["properties"]["author"][0] + .as_str() + .unwrap() + .to_string(), + ) + .or_insert_with(Vec::new) + .push(key.to_string()) + } + Ok(()) + } + + async fn update_post(&self, url: &'_ str, update: crate::micropub::MicropubUpdate) -> Result<()> { + let mut guard = self.mapping.write().await; + let mut post = guard.get_mut(url).ok_or(StorageError::from_static(ErrorKind::NotFound, "The specified post wasn't found in the database."))?; + + use crate::micropub::MicropubPropertyDeletion; + + let mut add_keys: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + let mut remove_keys: Vec<String> = vec![]; + let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); + + if let Some(MicropubPropertyDeletion::Properties(delete)) = update.delete { + remove_keys.extend(delete.iter().cloned()); + } else if let Some(MicropubPropertyDeletion::Values(delete)) = update.delete { + for (k, v) in delete { + remove_values + .entry(k.to_string()) + .or_default() + .extend(v.clone()); + } + } + if let Some(add) = update.add { + for (k, v) in add { + add_keys.insert(k.to_string(), v.clone()); + } + } + if let Some(replace) = update.replace { + for (k, v) in replace { + remove_keys.push(k.to_string()); + add_keys.insert(k.to_string(), v.clone()); + } + } + + if let Some(props) = post["properties"].as_object_mut() { + for k in remove_keys { + props.remove(&k); + } + } + for (k, v) in remove_values { + let k = &k; + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + v.iter().for_each(|v| { + if let Some(vec) = props[k].as_array_mut() { + if let Some(index) = vec.iter().position(|w| w == v) { + vec.remove(index); + } + } + }); + } + for (k, v) in add_keys { + tracing::debug!("Adding k/v to post: {} => {:?}", k, v); + let props = if k == "children" { + &mut post + } else { + &mut post["properties"] + }; + if let Some(prop) = props[&k].as_array_mut() { + if k == "children" { + v.into_iter().rev().for_each(|v| prop.insert(0, v)); + } else { + prop.extend(v.into_iter()); + } + } else { + props[&k] = serde_json::Value::Array(v) + } + } + + Ok(()) + } + + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { + match self.channels.read().await.get(user) { + Some(channels) => Ok(futures_util::future::join_all( + channels + .iter() + .map(|channel| { + self.get_post(channel).map(|result| result.unwrap()).map( + |post: Option<serde_json::Value>| { + post.map(|post| MicropubChannel { + uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), + name: post["properties"]["name"][0] + .as_str() + .unwrap() + .to_string(), + }) + }, + ) + }) + .collect::<Vec<_>>(), + ) + .await + .into_iter() + .flatten() + .collect::<Vec<_>>()), + None => Ok(vec![]), + } + } + + #[allow(unused_variables)] + async fn read_feed_with_limit( 
+ &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>> { + todo!() + } + + #[allow(unused_variables)] + async fn read_feed_with_cursor( + &self, + url: &'_ str, + cursor: Option<&'_ str>, + limit: usize, + user: Option<&'_ str> + ) -> Result<Option<(serde_json::Value, Option<String>)>> { + todo!() + } + + async fn delete_post(&self, url: &'_ str) -> Result<()> { + self.mapping.write().await.remove(url); + Ok(()) + } + + #[allow(unused_variables)] + async fn get_setting<S: settings::Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S> { + todo!() + } + + #[allow(unused_variables)] + async fn set_setting<S: settings::Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { + todo!() + } + + #[allow(unused_variables)] + async fn add_or_update_webmention(&self, target: &str, mention_type: kittybox_util::MentionType, mention: serde_json::Value) -> Result<()> { + todo!() + } + +} + +impl Default for MemoryStorage { + fn default() -> Self { + Self::new() + } +} + +impl MemoryStorage { + pub fn new() -> Self { + Self { + mapping: Arc::new(RwLock::new(HashMap::new())), + channels: Arc::new(RwLock::new(HashMap::new())), + } + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..b4b70b2 --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,793 @@ +#![warn(missing_docs)] +use std::borrow::Cow; + +use async_trait::async_trait; +use kittybox_util::MentionType; + +mod file; +pub use crate::database::file::FileStorage; +use crate::micropub::MicropubUpdate; +#[cfg(feature = "postgres")] +mod postgres; +#[cfg(feature = "postgres")] +pub use postgres::PostgresStorage; + +#[cfg(test)] +mod memory; +#[cfg(test)] +pub use crate::database::memory::MemoryStorage; + +pub use kittybox_util::MicropubChannel; + +use self::settings::Setting; + +/// Enum representing different errors that might occur during the database query. +#[derive(Debug, Clone, Copy)] +pub enum ErrorKind { + /// Backend error (e.g. database connection error) + Backend, + /// Error due to insufficient contextual permissions for the query + PermissionDenied, + /// Error due to the database being unable to parse JSON returned from the backing storage. + /// Usually indicative of someone fiddling with the database manually instead of using proper tools. + JsonParsing, + /// - ErrorKind::NotFound - equivalent to a 404 error. Note, some requests return an Option, + /// in which case None is also equivalent to a 404. + NotFound, + /// The user's query or request to the database was malformed. Used whenever the database processes + /// the user's query directly, such as when editing posts inside of the database (e.g. Redis backend) + BadRequest, + /// the user's query collided with an in-flight request and needs to be retried + Conflict, + /// - ErrorKind::Other - when something so weird happens that it becomes undescribable. + Other, +} + +/// Settings that can be stored in the database. +pub mod settings { + mod private { + pub trait Sealed {} + } + + /// A trait for various settings that should be contained here. + /// + /// **Note**: this trait is sealed to prevent external + /// implementations, as it wouldn't make sense to add new settings + /// that aren't used by Kittybox itself. 
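+ ///
+ /// A hedged usage sketch (assuming some `storage` handle implementing
+ /// the `Storage` trait; compare the `test_settings` case in the tests
+ /// below):
+ ///
+ /// ```no_run
+ /// # async fn example<S: kittybox::database::Storage>(storage: S) -> kittybox::database::Result<()> {
+ /// use kittybox::database::settings::SiteName;
+ ///
+ /// storage.set_setting::<SiteName>("fireburn.ru", "Vika's Hideout".to_owned()).await?;
+ /// let name: SiteName = storage.get_setting("fireburn.ru").await?;
+ /// assert_eq!(name.as_ref(), "Vika's Hideout");
+ /// # Ok(())
+ /// # }
+ /// ```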
+ pub trait Setting<'de>: private::Sealed + std::fmt::Debug + Default + Clone + serde::Serialize + serde::de::DeserializeOwned + /*From<Settings> +*/ Send + Sync { + type Data: std::fmt::Debug + Send + Sync; + const ID: &'static str; + + /// Unwrap the setting type, returning owned data contained within. + fn into_inner(self) -> Self::Data; + /// Create a new instance of this type containing certain data. + fn new(data: Self::Data) -> Self; + } + + /// A website's title, shown in the header. + #[derive(Debug, serde::Deserialize, serde::Serialize, Clone, PartialEq, Eq)] + pub struct SiteName(String); + impl Default for SiteName { + fn default() -> Self { + Self("Kittybox".to_string()) + } + } + impl AsRef<str> for SiteName { + fn as_ref(&self) -> &str { + self.0.as_str() + } + } + impl private::Sealed for SiteName {} + impl Setting<'_> for SiteName { + type Data = String; + const ID: &'static str = "site_name"; + + fn into_inner(self) -> String { + self.0 + } + fn new(data: Self::Data) -> Self { + Self(data) + } + } + impl SiteName { + fn from_str(data: &str) -> Self { + Self(data.to_owned()) + } + } + + /// Participation status in the IndieWeb Webring: https://🕸💍.ws/dashboard + #[derive(Debug, Default, serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq)] + pub struct Webring(bool); + impl private::Sealed for Webring {} + impl Setting<'_> for Webring { + type Data = bool; + const ID: &'static str = "webring"; + + fn into_inner(self) -> Self::Data { + self.0 + } + + fn new(data: Self::Data) -> Self { + Self(data) + } + } +} + +/// Error signalled from the database. +#[derive(Debug)] +pub struct StorageError { + msg: std::borrow::Cow<'static, str>, + source: Option<Box<dyn std::error::Error + Send + Sync>>, + kind: ErrorKind, +} + +impl std::error::Error for StorageError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|e| e.as_ref() as &dyn std::error::Error) + } +} +impl From<serde_json::Error> for StorageError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: std::borrow::Cow::Owned(format!("{}", err)), + source: Some(Box::new(err)), + kind: ErrorKind::JsonParsing, + } + } +} +impl std::fmt::Display for StorageError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {}", + match self.kind { + ErrorKind::Backend => "backend error", + ErrorKind::JsonParsing => "JSON parsing error", + ErrorKind::PermissionDenied => "permission denied", + ErrorKind::NotFound => "not found", + ErrorKind::BadRequest => "bad request", + ErrorKind::Conflict => "conflict with an in-flight request or existing data", + ErrorKind::Other => "generic storage layer error", + }, + self.msg + ) + } +} +impl serde::Serialize for StorageError { + fn serialize<S: serde::Serializer>( + &self, + serializer: S, + ) -> std::result::Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } +} +impl StorageError { + /// Create a new StorageError of an ErrorKind with a message. + pub fn new(kind: ErrorKind, msg: String) -> Self { + Self { + msg: Cow::Owned(msg), + source: None, + kind, + } + } + /// Create a new StorageError of an ErrorKind with a message from + /// a static string. + /// + /// This saves an allocation for a new string and is the preferred + /// way in case the error message doesn't change. 
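+ ///
+ /// For example (a doctest-style sketch):
+ ///
+ /// ```
+ /// # use kittybox::database::{ErrorKind, StorageError};
+ /// let err = StorageError::from_static(ErrorKind::NotFound, "post not found");
+ /// assert_eq!(err.msg(), "post not found");
+ /// ```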
+ pub fn from_static(kind: ErrorKind, msg: &'static str) -> Self { + Self { + msg: Cow::Borrowed(msg), + source: None, + kind + } + } + /// Create a StorageError using another arbitrary Error as a source. + pub fn with_source( + kind: ErrorKind, + msg: std::borrow::Cow<'static, str>, + source: Box<dyn std::error::Error + Send + Sync>, + ) -> Self { + Self { + msg, + source: Some(source), + kind, + } + } + /// Get the kind of an error. + pub fn kind(&self) -> ErrorKind { + self.kind + } + /// Get the message as a string slice. + pub fn msg(&self) -> &str { + &self.msg + } +} + +/// A special Result type for the Micropub backing storage. +pub type Result<T> = std::result::Result<T, StorageError>; + +/// A storage backend for the Micropub server. +/// +/// Implementations should note that all methods listed on this trait MUST be fully atomic +/// or lock the database so that write conflicts or reading half-written data should not occur. +#[async_trait] +pub trait Storage: std::fmt::Debug + Clone + Send + Sync { + /// Check if a post exists in the database. + async fn post_exists(&self, url: &str) -> Result<bool>; + + /// Load a post from the database in MF2-JSON format, deserialized from JSON. + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>>; + + /// Save a post to the database as an MF2-JSON structure. + /// + /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()>; + + /// Add post to feed. Some database implementations might have optimized ways to do this. + #[tracing::instrument(skip(self))] + async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { + tracing::debug!("Inserting {} into {} using `update_post`", post, feed); + self.update_post(feed, serde_json::from_value( + serde_json::json!({"add": {"children": [post]}})).unwrap() + ).await + } + /// Remove post from feed. Some database implementations might have optimized ways to do this. + #[tracing::instrument(skip(self))] + async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { + tracing::debug!("Removing {} into {} using `update_post`", post, feed); + self.update_post(feed, serde_json::from_value( + serde_json::json!({"delete": {"children": [post]}})).unwrap() + ).await + } + + /// Modify a post using an update object as defined in the + /// Micropub spec. + /// + /// Note to implementors: the update operation MUST be atomic and + /// SHOULD lock the database to prevent two clients overwriting + /// each other's changes or simply corrupting something. Rejecting + /// is allowed in case of concurrent updates if waiting for a lock + /// cannot be done. + async fn update_post(&self, url: &str, update: MicropubUpdate) -> Result<()>; + + /// Get a list of channels available for the user represented by + /// the `user` domain to write to. + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>>; + + /// Fetch a feed at `url` and return an h-feed object containing + /// `limit` posts after a post by url `after`, filtering the content + /// in context of a user specified by `user` (or an anonymous user). + /// + /// This method MUST hydrate the `author` property with an h-card + /// from the database by replacing URLs with corresponding h-cards. 
+ /// + /// When encountering posts which the `user` is not authorized to + /// access, this method MUST elide such posts (as an optimization + /// for the frontend) and not return them, but still return up to + /// `limit` posts (to not reveal the hidden posts' presence). + /// + /// Note for implementors: if you use streams to fetch posts in + /// parallel from the database, preferably make this method use a + /// connection pool to reduce overhead of creating a database + /// connection per post for parallel fetching. + async fn read_feed_with_limit( + &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>>; + + /// Fetch a feed at `url` and return an h-feed object containing + /// `limit` posts after a `cursor` (filtering the content in + /// context of a user specified by `user`, or an anonymous user), + /// as well as a new cursor to paginate with. + /// + /// This method MUST hydrate the `author` property with an h-card + /// from the database by replacing URLs with corresponding h-cards. + /// + /// When encountering posts which the `user` is not authorized to + /// access, this method MUST elide such posts (as an optimization + /// for the frontend) and not return them, but still return an + /// amount of posts as close to `limit` as possible (to avoid + /// revealing the existence of the hidden post). + /// + /// Note for implementors: if you use streams to fetch posts in + /// parallel from the database, preferably make this method use a + /// connection pool to reduce overhead of creating a database + /// connection per post for parallel fetching. + async fn read_feed_with_cursor( + &self, + url: &'_ str, + cursor: Option<&'_ str>, + limit: usize, + user: Option<&'_ str> + ) -> Result<Option<(serde_json::Value, Option<String>)>>; + + /// Deletes a post from the database irreversibly. Must be idempotent. + async fn delete_post(&self, url: &'_ str) -> Result<()>; + + /// Gets a setting from the setting store and passes the result. + async fn get_setting<S: Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S>; + + /// Commits a setting to the setting store. + async fn set_setting<S: Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()>; + + /// Add (or update) a webmention on a certian post. + /// + /// The MF2 object describing the webmention content will always + /// be of type `h-cite`, and the `uid` property on the object will + /// always be set. + /// + /// The rationale for this function is as follows: webmentions + /// might be duplicated, and we need to deduplicate them first. As + /// we lack support for transactions and locking posts on the + /// database, the only way is to implement the operation on the + /// database itself. + /// + /// Besides, it may even allow for nice tricks like storing the + /// webmentions separately and rehydrating them on feed reads. 
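+ ///
+ /// A hypothetical sketch of the argument shape (URLs are
+ /// placeholders):
+ ///
+ /// ```no_run
+ /// # async fn example<S: kittybox::database::Storage>(storage: S) -> kittybox::database::Result<()> {
+ /// let mention = serde_json::json!({
+ ///     "type": ["h-cite"],
+ ///     "properties": {
+ ///         "content": ["Great post!"],
+ ///         "uid": ["https://example.com/replies/1"],
+ ///         "url": ["https://example.com/replies/1"]
+ ///     }
+ /// });
+ /// storage.add_or_update_webmention(
+ ///     "https://fireburn.ru/posts/hello",
+ ///     kittybox_util::MentionType::Reply,
+ ///     mention,
+ /// ).await?;
+ /// # Ok(())
+ /// # }
+ /// ```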
+ async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()>; +} + +#[cfg(test)] +mod tests { + use super::settings; + + use super::{MicropubChannel, Storage}; + use kittybox_util::MentionType; + use serde_json::json; + + async fn test_basic_operations<Backend: Storage>(backend: Backend) { + let post: serde_json::Value = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Test content"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] + } + }); + let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); + let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); + + // Reading and writing + backend + .put_post(&post, "fireburn.ru") + .await + .unwrap(); + if let Some(returned_post) = backend.get_post(&key).await.unwrap() { + assert!(returned_post.is_object()); + assert_eq!( + returned_post["type"].as_array().unwrap().len(), + post["type"].as_array().unwrap().len() + ); + assert_eq!( + returned_post["type"].as_array().unwrap(), + post["type"].as_array().unwrap() + ); + let props: &serde_json::Map<String, serde_json::Value> = + post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!( + returned_post["properties"][key].as_array().unwrap(), + post["properties"][key].as_array().unwrap() + ) + } + } else { + panic!("For some reason the backend did not return the post.") + } + // Check the alternative URL - it should return the same post + if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { + assert!(returned_post.is_object()); + assert_eq!( + returned_post["type"].as_array().unwrap().len(), + post["type"].as_array().unwrap().len() + ); + assert_eq!( + returned_post["type"].as_array().unwrap(), + post["type"].as_array().unwrap() + ); + let props: &serde_json::Map<String, serde_json::Value> = + post["properties"].as_object().unwrap(); + for key in props.keys() { + assert_eq!( + returned_post["properties"][key].as_array().unwrap(), + post["properties"][key].as_array().unwrap() + ) + } + } else { + panic!("For some reason the backend did not return the post.") + } + } + + /// Note: this is merely a smoke check and is in no way comprehensive. 
+ // TODO updates for feeds must update children using special logic + async fn test_update<Backend: Storage>(backend: Backend) { + let post: serde_json::Value = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Test content"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] + } + }); + let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); + + // Reading and writing + backend + .put_post(&post, "fireburn.ru") + .await + .unwrap(); + + backend + .update_post( + &key, + serde_json::from_value(json!({ + "url": &key, + "add": { + "category": ["testing"], + }, + "replace": { + "content": ["Different test content"] + } + })).unwrap(), + ) + .await + .unwrap(); + + match backend.get_post(&key).await { + Ok(Some(returned_post)) => { + assert!(returned_post.is_object()); + assert_eq!( + returned_post["type"].as_array().unwrap().len(), + post["type"].as_array().unwrap().len() + ); + assert_eq!( + returned_post["type"].as_array().unwrap(), + post["type"].as_array().unwrap() + ); + assert_eq!( + returned_post["properties"]["content"][0].as_str().unwrap(), + "Different test content" + ); + assert_eq!( + returned_post["properties"]["category"].as_array().unwrap(), + &vec![json!("testing")] + ); + } + something_else => { + something_else + .expect("Shouldn't error") + .expect("Should have the post"); + } + } + } + + async fn test_get_channel_list<Backend: Storage>(backend: Backend) { + let feed = json!({ + "type": ["h-feed"], + "properties": { + "name": ["Main Page"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/feeds/main"] + }, + "children": [] + }); + backend + .put_post(&feed, "fireburn.ru") + .await + .unwrap(); + let chans = backend.get_channels("fireburn.ru").await.unwrap(); + assert_eq!(chans.len(), 1); + assert_eq!( + chans[0], + MicropubChannel { + uid: "https://fireburn.ru/feeds/main".to_string(), + name: "Main Page".to_string() + } + ); + } + + async fn test_settings<Backend: Storage>(backend: Backend) { + backend + .set_setting::<settings::SiteName>( + "https://fireburn.ru/", + "Vika's Hideout".to_owned() + ) + .await + .unwrap(); + assert_eq!( + backend + .get_setting::<settings::SiteName>("https://fireburn.ru/") + .await + .unwrap() + .as_ref(), + "Vika's Hideout" + ); + } + + fn gen_random_post(domain: &str) -> serde_json::Value { + use faker_rand::lorem::{Paragraphs, Word}; + + let uid = format!( + "https://{domain}/posts/{}-{}-{}", + rand::random::<Word>(), + rand::random::<Word>(), + rand::random::<Word>() + ); + + let time = chrono::Local::now().to_rfc3339(); + let post = json!({ + "type": ["h-entry"], + "properties": { + "content": [rand::random::<Paragraphs>().to_string()], + "uid": [&uid], + "url": [&uid], + "published": [&time] + } + }); + + post + } + + fn gen_random_mention(domain: &str, mention_type: MentionType, url: &str) -> serde_json::Value { + use faker_rand::lorem::{Paragraphs, Word}; + + let uid = format!( + "https://{domain}/posts/{}-{}-{}", + rand::random::<Word>(), + rand::random::<Word>(), + rand::random::<Word>() + ); + + let time = chrono::Local::now().to_rfc3339(); + let post = json!({ + "type": ["h-cite"], + "properties": { + "content": [rand::random::<Paragraphs>().to_string()], + "uid": [&uid], + "url": [&uid], + "published": [&time], + (match mention_type { + MentionType::Reply => "in-reply-to", + MentionType::Like => "like-of", + MentionType::Repost => "repost-of", + MentionType::Bookmark => 
"bookmark-of", + MentionType::Mention => unimplemented!(), + }): [url] + } + }); + + post + } + + async fn test_feed_pagination<Backend: Storage>(backend: Backend) { + let posts = { + let mut posts = std::iter::from_fn( + || Some(gen_random_post("fireburn.ru")) + ) + .take(40) + .collect::<Vec<serde_json::Value>>(); + + // Reverse the array so it's in reverse-chronological order + posts.reverse(); + + posts + }; + + let feed = json!({ + "type": ["h-feed"], + "properties": { + "name": ["Main Page"], + "author": ["https://fireburn.ru/"], + "uid": ["https://fireburn.ru/feeds/main"] + }, + }); + let key = feed["properties"]["uid"][0].as_str().unwrap(); + + backend + .put_post(&feed, "fireburn.ru") + .await + .unwrap(); + + for (i, post) in posts.iter().rev().enumerate() { + backend + .put_post(post, "fireburn.ru") + .await + .unwrap(); + backend.add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()).await.unwrap(); + } + + let limit: usize = 10; + + tracing::debug!("Starting feed reading..."); + let (result, cursor) = backend + .read_feed_with_cursor(key, None, limit, None) + .await + .unwrap() + .unwrap(); + + assert_eq!(result["children"].as_array().unwrap().len(), limit); + assert_eq!( + result["children"] + .as_array() + .unwrap() + .iter() + .map(|post| post["properties"]["uid"][0].as_str().unwrap()) + .collect::<Vec<_>>() + [0..10], + posts + .iter() + .map(|post| post["properties"]["uid"][0].as_str().unwrap()) + .collect::<Vec<_>>() + [0..10] + ); + + tracing::debug!("Continuing with cursor: {:?}", cursor); + let (result2, cursor2) = backend + .read_feed_with_cursor( + key, + cursor.as_deref(), + limit, + None, + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + result2["children"].as_array().unwrap()[0..10], + posts[10..20] + ); + + tracing::debug!("Continuing with cursor: {:?}", cursor); + let (result3, cursor3) = backend + .read_feed_with_cursor( + key, + cursor2.as_deref(), + limit, + None, + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + result3["children"].as_array().unwrap()[0..10], + posts[20..30] + ); + + tracing::debug!("Continuing with cursor: {:?}", cursor); + let (result4, _) = backend + .read_feed_with_cursor( + key, + cursor3.as_deref(), + limit, + None, + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + result4["children"].as_array().unwrap()[0..10], + posts[30..40] + ); + + // Regression test for #4 + // + // Results for a bogus cursor are undefined, so we aren't + // checking them. But the function at least shouldn't hang. 
+        let nonsense_after = Some("1010101010");
+        let _ = tokio::time::timeout(tokio::time::Duration::from_secs(10), async move {
+            backend
+                .read_feed_with_cursor(key, nonsense_after, limit, None)
+                .await
+        })
+            .await
+            .expect("Operation should not hang: see https://gitlab.com/kittybox/kittybox/-/issues/4");
+    }
+
+    async fn test_webmention_addition<Backend: Storage>(db: Backend) {
+        let post = gen_random_post("fireburn.ru");
+
+        db.put_post(&post, "fireburn.ru").await.unwrap();
+        const TYPE: MentionType = MentionType::Reply;
+
+        let target = post["properties"]["uid"][0].as_str().unwrap();
+        let mut reply = gen_random_mention("aaronparecki.com", TYPE, target);
+
+        let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap();
+        assert_eq!(post, read_post);
+
+        db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap();
+
+        let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap();
+        assert_eq!(read_post["properties"]["comment"][0], reply);
+
+        reply["properties"]["content"][0] = json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string());
+
+        db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap();
+        let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap();
+        assert_eq!(read_post["properties"]["comment"][0], reply);
+    }
+
+    async fn test_pretty_permalinks<Backend: Storage>(db: Backend) {
+        const PERMALINK: &str = "https://fireburn.ru/posts/pretty-permalink";
+
+        let post = {
+            let mut post = gen_random_post("fireburn.ru");
+            let urls = post["properties"]["url"].as_array_mut().unwrap();
+            urls.push(serde_json::Value::String(
+                PERMALINK.to_owned()
+            ));
+
+            post
+        };
+        db.put_post(&post, "fireburn.ru").await.unwrap();
+
+        for url in post["properties"]["url"].as_array().unwrap() {
+            let (read_post, _) = db.read_feed_with_cursor(url.as_str().unwrap(), None, 20, None).await.unwrap().unwrap();
+            assert_eq!(read_post, post);
+        }
+    }
+    /// Automatically generates a test suite for a storage backend,
+    /// given the name of a macro that wraps each test function in a
+    /// backend-specific test harness.
+    macro_rules! test_all {
+        ($func_name:ident, $mod_name:ident) => {
+            mod $mod_name {
+                $func_name!(test_basic_operations);
+                $func_name!(test_get_channel_list);
+                $func_name!(test_settings);
+                $func_name!(test_update);
+                $func_name!(test_feed_pagination);
+                $func_name!(test_webmention_addition);
+                $func_name!(test_pretty_permalinks);
+            }
+        };
+    }
+    macro_rules! file_test {
+        ($func_name:ident) => {
+            #[tokio::test]
+            #[tracing_test::traced_test]
+            async fn $func_name() {
+                let tempdir = tempfile::tempdir().expect("Failed to create tempdir");
+                let backend = super::super::FileStorage::new(
+                    tempdir.path().to_path_buf()
+                )
+                    .await
+                    .unwrap();
+                super::$func_name(backend).await
+            }
+        };
+    }
+
+    macro_rules! postgres_test {
+        ($func_name:ident) => {
+            #[cfg(feature = "sqlx")]
+            #[sqlx::test]
+            #[tracing_test::traced_test]
+            async fn $func_name(
+                pool_opts: sqlx::postgres::PgPoolOptions,
+                connect_opts: sqlx::postgres::PgConnectOptions
+            ) -> Result<(), sqlx::Error> {
+                let db = {
+                    //use sqlx::ConnectOptions;
+                    //connect_opts.log_statements(log::LevelFilter::Debug);
+
+                    pool_opts.connect_with(connect_opts).await?
+                };
+                let backend = super::super::PostgresStorage::from_pool(db).await.unwrap();
+
+                Ok(super::$func_name(backend).await)
+            }
+        };
+    }
+
+    test_all!(file_test, file);
+    test_all!(postgres_test, postgres);
+}
diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs
new file mode 100644
index 0000000..9176d12
--- /dev/null
+++ b/src/database/postgres/mod.rs
@@ -0,0 +1,416 @@
+#![allow(unused_variables)]
+use std::borrow::Cow;
+use std::str::FromStr;
+
+use kittybox_util::{MicropubChannel, MentionType};
+use sqlx::{PgPool, Executor};
+use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion};
+
+use super::settings::Setting;
+use super::{Storage, Result, StorageError, ErrorKind};
+
+static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
+
+impl From<sqlx::Error> for StorageError {
+    fn from(value: sqlx::Error) -> Self {
+        Self::with_source(
+            super::ErrorKind::Backend,
+            Cow::Owned(format!("sqlx error: {}", &value)),
+            Box::new(value)
+        )
+    }
+}
+
+impl From<sqlx::migrate::MigrateError> for StorageError {
+    fn from(value: sqlx::migrate::MigrateError) -> Self {
+        Self::with_source(
+            super::ErrorKind::Backend,
+            Cow::Owned(format!("sqlx migration error: {}", &value)),
+            Box::new(value)
+        )
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct PostgresStorage {
+    db: PgPool
+}
+
+impl PostgresStorage {
+    /// Construct a new [`PostgresStorage`] from a URI string and run
+    /// migrations on the database.
+    ///
+    /// If the `PGPASS_FILE` environment variable is set, read the
+    /// password from the file at the specified path. If, instead,
+    /// the `PGPASS` environment variable is present, read the
+    /// password from it.
+    pub async fn new(uri: &str) -> Result<Self> {
+        tracing::debug!("Postgres URL: {uri}");
+        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
+            .options([("search_path", "kittybox")]);
+        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
+            let password = tokio::fs::read_to_string(password_file).await.unwrap();
+            options = options.password(&password);
+        } else if let Ok(password) = std::env::var("PGPASS") {
+            options = options.password(&password)
+        }
+        Self::from_pool(
+            sqlx::postgres::PgPoolOptions::new()
+                .max_connections(50)
+                .connect_with(options)
+                .await?
+        ).await
+    }
+
+    /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`],
+    /// running appropriate migrations.
+    pub async fn from_pool(db: sqlx::PgPool) -> Result<Self> {
+        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?;
+        MIGRATOR.run(&db).await?;
+        Ok(Self { db })
+    }
+}
+
+#[async_trait::async_trait]
+impl Storage for PostgresStorage {
+    #[tracing::instrument(skip(self))]
+    async fn post_exists(&self, url: &str) -> Result<bool> {
+        sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)")
+            .bind(url)
+            .fetch_one(&self.db)
+            .await
+            .map(|v| v.0)
+            .map_err(|err| err.into())
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
+        sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1")
$1") + .bind(url) + .fetch_optional(&self.db) + .await + .map(|v| v.map(|v| v.0)) + .map_err(|err| err.into()) + + } + + #[tracing::instrument(skip(self))] + async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + tracing::debug!("New post: {}", post); + sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)") + .bind(post) + .bind(user) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) + } + + #[tracing::instrument(skip(self))] + async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { + tracing::debug!("Inserting {} into {}", post, feed); + sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) + } + + #[tracing::instrument(skip(self))] + async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { + sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2") + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map_err(Into::into) + .map(|_| ()) + } + + #[tracing::instrument(skip(self))] + async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + let mut txn = self.db.begin().await?; + + let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") + .bind(target) + .fetch_optional(&mut *txn) + .await? + .ok_or(StorageError::from_static( + ErrorKind::NotFound, + "The specified post wasn't found in the database." + ))?; + + tracing::debug!("Loaded post for target {} with uid {}", target, uid); + + let key: &'static str = match mention_type { + MentionType::Reply => "comment", + MentionType::Like => "like", + MentionType::Repost => "repost", + MentionType::Bookmark => "bookmark", + MentionType::Mention => "mention", + }; + + tracing::debug!("Mention type -> key: {}", key); + + let mention_uid = mention["properties"]["uid"][0].clone(); + if let Some(values) = post["properties"][key].as_array_mut() { + for value in values.iter_mut() { + if value["properties"]["uid"][0] == mention_uid { + *value = mention; + break; + } + } + } else { + post["properties"][key] = serde_json::Value::Array(vec![mention]); + } + + sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") + .bind(uid) + .bind(post) + .execute(&mut *txn) + .await?; + + txn.commit().await.map_err(Into::into) + } + #[tracing::instrument(skip(self))] + async fn update_post(&self, url: &'_ str, update: MicropubUpdate) -> Result<()> { + tracing::debug!("Updating post {}", url); + let mut txn = self.db.begin().await?; + let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") + .bind(url) + .fetch_optional(&mut *txn) + .await? + .ok_or(StorageError::from_static( + ErrorKind::NotFound, + "The specified post wasn't found in the database." 
+ ))?; + + if let Some(MicropubPropertyDeletion::Properties(ref delete)) = update.delete { + if let Some(props) = post["properties"].as_object_mut() { + for key in delete { + props.remove(key); + } + } + } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = update.delete { + if let Some(props) = post["properties"].as_object_mut() { + for (key, values) in delete { + if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { + prop.retain(|v| { values.iter().all(|i| i != v) }) + } + } + } + } + if let Some(replace) = update.replace { + if let Some(props) = post["properties"].as_object_mut() { + for (key, value) in replace { + props.insert(key, serde_json::Value::Array(value)); + } + } + } + if let Some(add) = update.add { + if let Some(props) = post["properties"].as_object_mut() { + for (key, value) in add { + if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { + prop.extend_from_slice(value.as_slice()); + } else { + props.insert(key, serde_json::Value::Array(value)); + } + } + } + } + + sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") + .bind(uid) + .bind(post) + .execute(&mut *txn) + .await?; + + txn.commit().await.map_err(Into::into) + } + + #[tracing::instrument(skip(self))] + async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { + /*sqlx::query_as::<_, MicropubChannel>("SELECT name, uid FROM kittybox.channels WHERE owner = $1") + .bind(user) + .fetch_all(&self.db) + .await + .map_err(|err| err.into())*/ + sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#) + .bind(user) + .fetch_all(&self.db) + .await + .map_err(|err| err.into()) + } + + #[tracing::instrument(skip(self))] + async fn read_feed_with_limit( + &self, + url: &'_ str, + after: &'_ Option<String>, + limit: usize, + user: &'_ Option<String>, + ) -> Result<Option<serde_json::Value>> { + let mut feed = match sqlx::query_as::<_, (serde_json::Value,)>(" +SELECT jsonb_set( + mf2, + '{properties,author,0}', + (SELECT mf2 FROM kittybox.mf2_json + WHERE uid = mf2 #>> '{properties,author,0}') +) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 +") + .bind(url) + .fetch_optional(&self.db) + .await? + .map(|v| v.0) + { + Some(feed) => feed, + None => return Ok(None) + }; + + let posts: Vec<String> = { + let mut posts_iter = feed["children"] + .as_array() + .cloned() + .unwrap_or_default() + .into_iter() + .map(|s| s.as_str().unwrap().to_string()); + if let Some(after) = after { + for s in posts_iter.by_ref() { + if &s == after { + break; + } + } + }; + + posts_iter.take(limit).collect::<Vec<_>>() + }; + feed["children"] = serde_json::Value::Array( + sqlx::query_as::<_, (serde_json::Value,)>(" +SELECT jsonb_set( + mf2, + '{properties,author,0}', + (SELECT mf2 FROM kittybox.mf2_json + WHERE uid = mf2 #>> '{properties,author,0}') +) FROM kittybox.mf2_json +WHERE uid = ANY($1) +ORDER BY mf2 #>> '{properties,published,0}' DESC +") + .bind(&posts[..]) + .fetch_all(&self.db) + .await? 
+            .into_iter()
+            .map(|v| v.0)
+            .collect::<Vec<_>>()
+        );
+
+        Ok(Some(feed))
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn read_feed_with_cursor(
+        &self,
+        url: &'_ str,
+        cursor: Option<&'_ str>,
+        limit: usize,
+        user: Option<&'_ str>
+    ) -> Result<Option<(serde_json::Value, Option<String>)>> {
+        let mut txn = self.db.begin().await?;
+        sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY")
+            .execute(&mut *txn)
+            .await?;
+        tracing::debug!("Started txn: {:?}", txn);
+        let mut feed = match sqlx::query_scalar::<_, serde_json::Value>("
+SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1
+")
+            .bind(url)
+            .fetch_optional(&mut *txn)
+            .await?
+        {
+            Some(feed) => feed,
+            None => return Ok(None)
+        };
+
+        // Don't query for children if this isn't a feed.
+        //
+        // The second query is very long and will probably be extremely
+        // expensive. It's best to skip it on types where it doesn't make sense
+        // (Kittybox doesn't support rendering children on non-feeds)
+        if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) {
+            return Ok(Some((feed, None)));
+        }
+
+        feed["children"] = sqlx::query_scalar::<_, serde_json::Value>("
+SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json
+INNER JOIN kittybox.children
+ON mf2_json.uid = children.child
+WHERE
+    children.parent = $1
+    AND (
+        (
+            (mf2 #>> '{properties,visibility,0}') = 'public'
+            OR
+            NOT (mf2['properties'] ? 'visibility')
+        )
+        OR
+        (
+            $3 IS NOT NULL AND (
+                mf2['properties']['audience'] ? $3
+                OR mf2['properties']['author'] ? $3
+            )
+        )
+    )
+    AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4))
+ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC
+LIMIT $2"
+        )
+            .bind(url)
+            .bind(limit as i64)
+            .bind(user)
+            .bind(cursor)
+            .fetch_all(&mut *txn)
+            .await
+            .map(serde_json::Value::Array)?;
+
+        let new_cursor = feed["children"].as_array().unwrap()
+            .last()
+            .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned());
+
+        txn.commit().await?;
+
+        Ok(Some((feed, new_cursor)))
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn delete_post(&self, url: &'_ str) -> Result<()> {
+        todo!()
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn get_setting<S: Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S> {
+        match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)")
+            .bind(user)
+            .bind(S::ID)
+            .fetch_one(&self.db)
+            .await
+        {
+            Ok((value,)) => Ok(serde_json::from_value(value)?),
+            Err(err) => Err(err.into())
+        }
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn set_setting<S: Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> {
+        sqlx::query("SELECT kittybox.set_setting($1, $2, $3)")
+            .bind(user)
+            .bind(S::ID)
+            .bind(serde_json::to_value(S::new(value)).unwrap())
+            .execute(&self.db)
+            .await
+            .map_err(Into::into)
+            .map(|_| ())
+    }
+}
diff --git a/src/database/redis/edit_post.lua b/src/database/redis/edit_post.lua
new file mode 100644
index 0000000..a398f8d
--- /dev/null
+++ b/src/database/redis/edit_post.lua
@@ -0,0 +1,93 @@
+local posts = KEYS[1]
+local update_desc = cjson.decode(ARGV[2])
+local post = cjson.decode(redis.call("HGET", posts, ARGV[1]))
+
+local delete_keys = {}
+local delete_kvs = {}
+local add_keys = {}
+
+if update_desc.replace ~= nil then
+  for k, v in pairs(update_desc.replace) do
+    table.insert(delete_keys, k)
+    add_keys[k] = v
+  end
+end
+if update_desc.delete ~= nil then
+  -- cjson decodes JSON arrays into 1-indexed tables, so probe index 1.
+  if update_desc.delete[1] == nil then
+    -- Table has string keys. Probably!
+    for k, v in pairs(update_desc.delete) do
+      delete_kvs[k] = v
+    end
+  else
+    -- Table has numeric keys. Probably!
+    for i, v in ipairs(update_desc.delete) do
+      table.insert(delete_keys, v)
+    end
+  end
+end
+if update_desc.add ~= nil then
+  for k, v in pairs(update_desc.add) do
+    add_keys[k] = v
+  end
+end
+
+for i, v in ipairs(delete_keys) do
+  post["properties"][v] = nil
+  -- TODO delete URL links
+end
+
+for k, v in pairs(delete_kvs) do
+  local index = -1
+  if k == "children" then
+    for j, w in ipairs(post[k]) do
+      if w == v then
+        index = j
+        break
+      end
+    end
+    if index > -1 then
+      table.remove(post[k], index)
+    end
+  else
+    for j, w in ipairs(post["properties"][k]) do
+      if w == v then
+        index = j
+        break
+      end
+    end
+    if index > -1 then
+      table.remove(post["properties"][k], index)
+      -- TODO delete URL links
+    end
+  end
+end
+
+for k, v in pairs(add_keys) do
+  if k == "children" then
+    if post["children"] == nil then
+      post["children"] = {}
+    end
+    for i, w in ipairs(v) do
+      table.insert(post["children"], 1, w)
+    end
+  else
+    if post["properties"][k] == nil then
+      post["properties"][k] = {}
+    end
+    for i, w in ipairs(v) do
+      table.insert(post["properties"][k], w)
+    end
+    if k == "url" then
+      redis.call("HSET", posts, v, cjson.encode({ see_other = post["properties"]["uid"][1] }))
+    elseif k == "channel" then
+      local feed = cjson.decode(redis.call("HGET", posts, v))
+      table.insert(feed["children"], 1, post["properties"]["uid"][1])
+      redis.call("HSET", posts, v, cjson.encode(feed))
+    end
+  end
+end
+
+local encoded = cjson.encode(post)
+redis.call("SET", "debug", encoded)
+redis.call("HSET", posts, post["properties"]["uid"][1], encoded)
+return
\ No newline at end of file
diff --git a/src/database/redis/mod.rs b/src/database/redis/mod.rs
new file mode 100644
index 0000000..39ee852
--- /dev/null
+++ b/src/database/redis/mod.rs
@@ -0,0 +1,398 @@
+use async_trait::async_trait;
+use futures::stream;
+use futures_util::FutureExt;
+use futures_util::StreamExt;
+use futures_util::TryStream;
+use futures_util::TryStreamExt;
+use lazy_static::lazy_static;
+use log::error;
+use mobc::Pool;
+use mobc_redis::redis;
+use mobc_redis::redis::AsyncCommands;
+use mobc_redis::RedisConnectionManager;
+use serde_json::json;
+use std::time::Duration;
+
+use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError, filter_post};
+use crate::indieauth::User;
+
+struct RedisScripts {
+    edit_post: redis::Script,
+}
+
+impl From<mobc_redis::redis::RedisError> for StorageError {
+    fn from(err: mobc_redis::redis::RedisError) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::Backend,
+        }
+    }
+}
+impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError {
+    fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::Backend,
+        }
+    }
+}
+
+lazy_static!
{ + static ref SCRIPTS: RedisScripts = RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + }; +} +/*#[cfg(feature(lazy_cell))] +static SCRIPTS_CELL: std::cell::LazyCell<RedisScripts> = std::cell::LazyCell::new(|| { + RedisScripts { + edit_post: redis::Script::new(include_str!("./edit_post.lua")) + } +});*/ + +#[derive(Clone)] +pub struct RedisStorage { + // note to future Vika: + // mobc::Pool is actually a fancy name for an Arc + // around a shared connection pool with a manager + // which makes it safe to implement [`Clone`] and + // not worry about new pools being suddenly made + // + // stop worrying and start coding, you dum-dum + redis: mobc::Pool<RedisConnectionManager>, +} + +#[async_trait] +impl Storage for RedisStorage { + async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn + .hget::<String, &str, String>(format!("settings_{}", user), setting) + .await?) + } + + async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn + .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value) + .await?) + } + + async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) + } + + async fn post_exists(&self, url: &str) -> Result<bool> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + Ok(conn.hexists::<&str, &str, bool>("posts", url).await?) + } + + async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + match conn + .hget::<&str, &str, Option<String>>("posts", url) + .await? + { + Some(val) => { + let parsed = serde_json::from_str::<serde_json::Value>(&val)?; + if let Some(new_url) = parsed["see_other"].as_str() { + match conn + .hget::<&str, &str, Option<String>>("posts", new_url) + .await? + { + Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)), + None => Ok(None), + } + } else { + Ok(Some(parsed)) + } + } + None => Ok(None), + } + } + + async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + let channels = conn + .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()) + .await?; + // TODO: use streams here instead of this weird thing... how did I even write this?! 
+        Ok(futures_util::future::join_all(
+            channels
+                .iter()
+                .map(|channel| {
+                    self.get_post(channel).map(|result| result.unwrap()).map(
+                        |post: Option<serde_json::Value>| {
+                            post.map(|post| MicropubChannel {
+                                uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
+                                name: post["properties"]["name"][0].as_str().unwrap().to_string(),
+                            })
+                        },
+                    )
+                })
+                .collect::<Vec<_>>(),
+        )
+        .await
+        .into_iter()
+        .flatten()
+        .collect::<Vec<_>>())
+    }
+
+    async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        let key: &str = match post["properties"]["uid"][0].as_str() {
+            Some(uid) => uid,
+            None => {
+                return Err(StorageError::new(
+                    ErrorKind::BadRequest,
+                    "post doesn't have a UID",
+                ))
+            }
+        };
+        conn.hset::<&str, &str, String, ()>("posts", key, post.to_string())
+            .await?;
+        if post["properties"]["url"].is_array() {
+            for url in post["properties"]["url"]
+                .as_array()
+                .unwrap()
+                .iter()
+                .map(|i| i.as_str().unwrap().to_string())
+            {
+                if url != key && url.starts_with(user) {
+                    conn.hset::<&str, &str, String, ()>(
+                        "posts",
+                        &url,
+                        json!({ "see_other": key }).to_string(),
+                    )
+                    .await?;
+                }
+            }
+        }
+        if post["type"]
+            .as_array()
+            .unwrap()
+            .iter()
+            .any(|i| i == "h-feed")
+        {
+            // This is a feed. Add it to the channels array if it's not already there.
+            conn.sadd::<String, &str, ()>(
+                "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(),
+                key,
+            )
+            .await?
+        }
+        Ok(())
+    }
+
+    async fn read_feed_with_limit<'a>(
+        &self,
+        url: &'a str,
+        after: &'a Option<String>,
+        limit: usize,
+        user: &'a Option<String>,
+    ) -> Result<Option<serde_json::Value>> {
+        let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?;
+        let mut feed;
+        match conn
+            .hget::<&str, &str, Option<String>>("posts", url)
+            .await?
+        {
+            Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
+            None => return Ok(None),
+        }
+        if feed["see_other"].is_string() {
+            match conn
+                .hget::<&str, &str, Option<String>>("posts", feed["see_other"].as_str().unwrap())
+                .await?
+            {
+                Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
+                None => return Ok(None),
+            }
+        }
+        if let Some(post) = filter_post(feed, user) {
+            feed = post
+        } else {
+            return Err(StorageError::new(
+                ErrorKind::PermissionDenied,
+                "specified user cannot access this post",
+            ));
+        }
+        if feed["children"].is_array() {
+            let children = feed["children"].as_array().unwrap();
+            let mut posts_iter = children.iter().map(|i| i.as_str().unwrap().to_string());
+            if let Some(after) = after {
+                // Skip until the cursor; if the cursor points at a
+                // non-existent post, the iterator simply runs out
+                // instead of spinning forever (see issue #4).
+                for i in posts_iter.by_ref() {
+                    if &i == after {
+                        break;
+                    }
+                }
+            }
+            async fn fetch_post_for_feed(url: String) -> Option<serde_json::Value> {
+                return Some(serde_json::json!({}));
+            }
+            let posts = stream::iter(posts_iter)
+                .map(|url: String| async move {
+                    return Ok(fetch_post_for_feed(url).await);
+                    /*match self.redis.get().await {
+                        Ok(mut conn) => {
+                            match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
+                                Ok(post) => match post {
+                                    Some(post) => {
+                                        Ok(Some(serde_json::from_str(&post)?))
+                                    }
+                                    // Happens because of a broken link (result of an improper deletion?)
+ None => Ok(None), + }, + Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error executing a Redis command", Box::new(err))) + } + } + Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(err))) + }*/ + }) + // TODO: determine the optimal value for this buffer + // It will probably depend on how often can you encounter a private post on the page + // It shouldn't be too large, or we'll start fetching too many posts from the database + // It MUST NOT be larger than the typical page size + // It MUST NOT be a significant amount of the connection pool size + //.buffered(std::cmp::min(3, limit)) + // Hack to unwrap the Option and sieve out broken links + // Broken links return None, and Stream::filter_map skips all Nones. + // I wonder if one can use try_flatten() here somehow akin to iters + .try_filter_map(|post| async move { Ok(post) }) + .try_filter_map(|post| async move { + Ok(filter_post(post, user)) + }) + .take(limit); + match posts.try_collect::<Vec<serde_json::Value>>().await { + Ok(posts) => feed["children"] = json!(posts), + Err(err) => { + let e = StorageError::with_source( + ErrorKind::Other, + "An error was encountered while processing the feed", + Box::new(err) + ); + error!("Error while assembling feed: {}", e); + return Err(e); + } + } + } + return Ok(Some(feed)); + } + + async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { + let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; + if !conn + .hexists::<&str, &str, bool>("posts", url) + .await + .unwrap() + { + return Err(StorageError::new( + ErrorKind::NotFound, + "can't edit a non-existent post", + )); + } + let post: serde_json::Value = + serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; + if let Some(new_url) = post["see_other"].as_str() { + url = new_url + } + Ok(SCRIPTS + .edit_post + .key("posts") + .arg(url) + .arg(update.to_string()) + .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) + .await?) + } +} + +impl RedisStorage { + /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. 
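+    ///
+    /// A usage sketch (the URI shown is an assumed example; the
+    /// underlying client also accepts `redis+unix://` socket URIs,
+    /// as used in the tests below):
+    ///
+    /// ```ignore
+    /// let storage = RedisStorage::new("redis://127.0.0.1:6379".to_string())
+    ///     .await
+    ///     .expect("Failed to set up Redis storage");
+    /// ```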
+ pub async fn new(redis_uri: String) -> Result<Self> { + match redis::Client::open(redis_uri) { + Ok(client) => Ok(Self { + redis: Pool::builder() + .max_open(20) + .max_idle(5) + .get_timeout(Some(Duration::from_secs(3))) + .max_lifetime(Some(Duration::from_secs(120))) + .build(RedisConnectionManager::new(client)), + }), + Err(e) => Err(e.into()), + } + } + + pub async fn conn(&self) -> Result<mobc::Connection<mobc_redis::RedisConnectionManager>> { + self.redis.get().await.map_err(|e| StorageError::with_source( + ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e) + )) + } +} + +#[cfg(test)] +pub mod tests { + use mobc_redis::redis; + use std::process; + use std::time::Duration; + + pub struct RedisInstance { + // We just need to hold on to it so it won't get dropped and remove the socket + _tempdir: tempdir::TempDir, + uri: String, + child: std::process::Child, + } + impl Drop for RedisInstance { + fn drop(&mut self) { + self.child.kill().expect("Failed to kill the child!"); + } + } + impl RedisInstance { + pub fn uri(&self) -> &str { + &self.uri + } + } + + pub async fn get_redis_instance() -> RedisInstance { + let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); + let socket = tempdir.path().join("redis.sock"); + let redis_child = process::Command::new("redis-server") + .current_dir(&tempdir) + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(&socket) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("Failed to spawn Redis"); + println!("redis+unix:///{}", socket.to_str().unwrap()); + let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); + // There should be a slight delay, we need to wait for Redis to spin up + let client = redis::Client::open(uri.clone()).unwrap(); + let millisecond = Duration::from_millis(1); + let mut retries: usize = 0; + const MAX_RETRIES: usize = 60 * 1000/*ms*/; + while let Err(err) = client.get_connection() { + if err.is_connection_refusal() { + async_std::task::sleep(millisecond).await; + retries += 1; + if retries > MAX_RETRIES { + panic!("Timeout waiting for Redis, last error: {}", err); + } + } else { + panic!("Could not connect: {}", err); + } + } + + RedisInstance { + uri, + child: redis_child, + _tempdir: tempdir, + } + } +} diff --git a/src/frontend/login.rs b/src/frontend/login.rs new file mode 100644 index 0000000..c693899 --- /dev/null +++ b/src/frontend/login.rs @@ -0,0 +1,333 @@ +use http_types::Mime; +use log::{debug, error}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::convert::TryInto; +use std::str::FromStr; + +use crate::frontend::templates::Template; +use crate::frontend::{FrontendError, IndiewebEndpoints}; +use crate::{database::Storage, ApplicationState}; +use kittybox_frontend_renderer::LoginPage; + +pub async fn form<S: Storage>(req: Request<ApplicationState<S>>) -> Result { + let owner = req.url().origin().ascii_serialization() + "/"; + let storage = &req.state().storage; + let authorization_endpoint = req.state().authorization_endpoint.to_string(); + let token_endpoint = req.state().token_endpoint.to_string(); + let blog_name = storage + .get_setting("site_name", &owner) + .await + .unwrap_or_else(|_| "Kitty Box!".to_string()); + let feeds = storage.get_channels(&owner).await.unwrap_or_default(); + + Ok(Response::builder(200) + .body( + Template { + title: "Sign in with IndieAuth", + blog_name: &blog_name, + endpoints: IndiewebEndpoints { + authorization_endpoint, + 
token_endpoint, + webmention: None, + microsub: None, + }, + feeds, + user: req.session().get("user"), + content: LoginPage {}.to_string(), + } + .to_string(), + ) + .content_type("text/html; charset=utf-8") + .build()) +} + +#[derive(Serialize, Deserialize)] +struct LoginForm { + url: String, +} + +#[derive(Serialize, Deserialize)] +struct IndieAuthClientState { + /// A random value to protect from CSRF attacks. + nonce: String, + /// The user's initial "me" value. + me: String, + /// Authorization endpoint used. + authorization_endpoint: String, +} + +#[derive(Serialize, Deserialize)] +struct IndieAuthRequestParams { + response_type: String, // can only have "code". TODO make an enum + client_id: String, // always a URL. TODO consider making a URL + redirect_uri: surf::Url, // callback URI for IndieAuth + state: String, // CSRF protection, should include randomness and be passed through + code_challenge: String, // base64-encoded PKCE challenge + code_challenge_method: String, // usually "S256". TODO make an enum + scope: Option<String>, // oAuth2 scopes to grant, + me: surf::Url, // User's entered profile URL +} + +/// Handle login requests. Find the IndieAuth authorization endpoint and redirect to it. +pub async fn handler<S: Storage>(mut req: Request<ApplicationState<S>>) -> Result { + let content_type = req.content_type(); + if content_type.is_none() { + return Err(FrontendError::with_code(400, "Use the login form, Luke.").into()); + } + if content_type.unwrap() != Mime::from_str("application/x-www-form-urlencoded").unwrap() { + return Err( + FrontendError::with_code(400, "Login form results must be a urlencoded form").into(), + ); + } + + let form = req.body_form::<LoginForm>().await?; // FIXME check if it returns 400 or 500 on error + let homepage_uri = surf::Url::parse(&form.url)?; + let http = &req.state().http_client; + + let mut fetch_response = http.get(&homepage_uri).send().await?; + if fetch_response.status() != 200 { + return Err(FrontendError::with_code( + 500, + "Error fetching your authorization endpoint. Check if your website's okay.", + ) + .into()); + } + + let mut authorization_endpoint: Option<surf::Url> = None; + if let Some(links) = fetch_response.header("Link") { + // NOTE: this is the same Link header parser used in src/micropub/post.rs:459. + // One should refactor it to a function to use independently and improve later + for link in links.iter().flat_map(|i| i.as_str().split(',')) { + debug!("Trying to match {} as authorization_endpoint", link); + let mut split_link = link.split(';'); + + match split_link.next() { + Some(uri) => { + if let Some(uri) = uri.strip_prefix('<').and_then(|uri| uri.strip_suffix('>')) { + debug!("uri: {}", uri); + for prop in split_link { + debug!("prop: {}", prop); + let lowercased = prop.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + if trimmed == "rel=\"authorization_endpoint\"" + || trimmed == "rel=authorization_endpoint" + { + if let Ok(endpoint) = homepage_uri.join(uri) { + debug!( + "Found authorization endpoint {} for user {}", + endpoint, + homepage_uri.as_str() + ); + authorization_endpoint = Some(endpoint); + break; + } + } + } + } + } + None => continue, + } + } + } + // If the authorization_endpoint is still not found after the Link parsing gauntlet, + // bring out the big guns and parse HTML to find it. 
+    if authorization_endpoint.is_none() {
+        let body = fetch_response.body_string().await?;
+        let pattern =
+            easy_scraper::Pattern::new(r#"<link rel="authorization_endpoint" href="{{url}}">"#)
+                .expect("Cannot parse the pattern for authorization_endpoint");
+        let matches = pattern.matches(&body);
+        debug!("Matches for authorization_endpoint in HTML: {:?}", matches);
+        if !matches.is_empty() {
+            if let Ok(endpoint) = homepage_uri.join(&matches[0]["url"]) {
+                debug!(
+                    "Found authorization endpoint {} for user {}",
+                    endpoint,
+                    homepage_uri.as_str()
+                );
+                authorization_endpoint = Some(endpoint)
+            }
+        }
+    };
+    // If even after this the authorization endpoint is still not found, bail out.
+    if authorization_endpoint.is_none() {
+        error!(
+            "Couldn't find authorization_endpoint for {}",
+            homepage_uri.as_str()
+        );
+        return Err(FrontendError::with_code(
+            400,
+            "Your website doesn't support the IndieAuth protocol.",
+        )
+        .into());
+    }
+    let mut authorization_endpoint: surf::Url = authorization_endpoint.unwrap();
+    let mut rng = rand::thread_rng();
+    let state: String = data_encoding::BASE64URL.encode(
+        serde_urlencoded::to_string(IndieAuthClientState {
+            nonce: (0..8)
+                .map(|_| {
+                    let idx = rng.gen_range(0..INDIEAUTH_PKCE_CHARSET.len());
+                    INDIEAUTH_PKCE_CHARSET[idx] as char
+                })
+                .collect(),
+            me: homepage_uri.to_string(),
+            authorization_endpoint: authorization_endpoint.to_string(),
+        })?
+        .as_bytes(),
+    );
+    // PKCE code generation
+    let code_verifier: String = (0..128)
+        .map(|_| {
+            let idx = rng.gen_range(0..INDIEAUTH_PKCE_CHARSET.len());
+            INDIEAUTH_PKCE_CHARSET[idx] as char
+        })
+        .collect();
+    let mut hasher = Sha256::new();
+    hasher.update(code_verifier.as_bytes());
+    let code_challenge: String = data_encoding::BASE64URL.encode(&hasher.finalize());
+
+    authorization_endpoint.set_query(Some(&serde_urlencoded::to_string(
+        IndieAuthRequestParams {
+            response_type: "code".to_string(),
+            client_id: req.url().origin().ascii_serialization(),
+            redirect_uri: req.url().join("login/callback")?,
+            state: state.clone(),
+            code_challenge,
+            code_challenge_method: "S256".to_string(),
+            scope: Some("profile".to_string()),
+            me: homepage_uri,
+        },
+    )?));
+
+    let cookies = vec![
+        format!(
+            r#"indieauth_state="{}"; SameSite=None; Secure; Max-Age=600"#,
+            state
+        ),
+        format!(
+            r#"indieauth_code_verifier="{}"; SameSite=None; Secure; Max-Age=600"#,
+            code_verifier
+        ),
+    ];
+
+    let cookie_header = cookies
+        .iter()
+        .map(|i| -> http_types::headers::HeaderValue { (i as &str).try_into().unwrap() })
+        .collect::<Vec<_>>();
+
+    Ok(Response::builder(302)
+        .header("Location", authorization_endpoint.to_string())
+        .header("Set-Cookie", &*cookie_header)
+        .build())
+}
+
+const INDIEAUTH_PKCE_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
+                                        abcdefghijklmnopqrstuvwxyz\
+                                        1234567890-._~";
+
+#[derive(Deserialize)]
+struct IndieAuthCallbackResponse {
+    code: Option<String>,
+    error: Option<String>,
+    error_description: Option<String>,
+    #[allow(dead_code)]
+    error_uri: Option<String>,
+    // This needs to be further decoded to receive state back and will always be present
+    state: String,
+}
+
+impl IndieAuthCallbackResponse {
+    fn is_successful(&self) -> bool {
+        self.code.is_some()
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+struct IndieAuthCodeRedeem {
+    grant_type: String,
+    code: String,
+    client_id: String,
+    redirect_uri: String,
+    code_verifier: String,
+}
+
+#[derive(Serialize, Deserialize)]
+struct IndieWebProfile {
+    name: Option<String>,
+    url: Option<String>,
+    email: Option<String>,
+    photo: Option<String>,
+}
+
+#[derive(Serialize, Deserialize)]
+struct IndieAuthResponse {
+    me: String,
+    scope: Option<String>,
+    access_token: Option<String>,
+    token_type: Option<String>,
+    profile: Option<IndieWebProfile>,
+}
+
+/// Handle IndieAuth parameters, fetch the final h-card and redirect the user to the homepage.
+pub async fn callback<S: Storage>(mut req: Request<ApplicationState<S>>) -> Result {
+    let params: IndieAuthCallbackResponse = req.query()?;
+    let http: &surf::Client = &req.state().http_client;
+    let origin = req.url().origin().ascii_serialization();
+
+    if req.cookie("indieauth_state").unwrap().value() != params.state {
+        return Err(FrontendError::with_code(400, "The state doesn't match. A possible CSRF attack was prevented. Please try again later.").into());
+    }
+    let state: IndieAuthClientState =
+        serde_urlencoded::from_bytes(&data_encoding::BASE64URL.decode(params.state.as_bytes())?)?;
+
+    if !params.is_successful() {
+        return Err(FrontendError::with_code(
+            400,
+            &format!(
+                "The authorization endpoint indicated the following error: {:?}: {:?}",
+                &params.error, &params.error_description
+            ),
+        )
+        .into());
+    }
+
+    let authorization_endpoint = surf::Url::parse(&state.authorization_endpoint).unwrap();
+    let mut code_response = http
+        .post(authorization_endpoint)
+        .body_string(serde_urlencoded::to_string(IndieAuthCodeRedeem {
+            grant_type: "authorization_code".to_string(),
+            code: params.code.unwrap().to_string(),
+            client_id: origin.to_string(),
+            redirect_uri: origin + "/login/callback",
+            code_verifier: req
+                .cookie("indieauth_code_verifier")
+                .unwrap()
+                .value()
+                .to_string(),
+        })?)
+        .header("Content-Type", "application/x-www-form-urlencoded")
+        .header("Accept", "application/json")
+        .send()
+        .await?;
+
+    if code_response.status() != 200 {
+        return Err(FrontendError::with_code(
+            code_response.status(),
+            &format!(
+                "Authorization endpoint returned an error when redeeming the code: {}",
+                code_response.body_string().await?
+ ), + ) + .into()); + } + + let json: IndieAuthResponse = code_response.body_json().await?; + let session = req.session_mut(); + session.insert("user", &json.me)?; + + // TODO redirect to the page user came from + Ok(Response::builder(302).header("Location", "/").build()) +} diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs new file mode 100644 index 0000000..7a43532 --- /dev/null +++ b/src/frontend/mod.rs @@ -0,0 +1,404 @@ +use crate::database::{Storage, StorageError}; +use axum::{ + extract::{Host, Path, Query}, + http::{StatusCode, Uri}, + response::IntoResponse, + Extension, +}; +use futures_util::FutureExt; +use serde::Deserialize; +use std::convert::TryInto; +use tracing::{debug, error}; +//pub mod login; +pub mod onboarding; + +use kittybox_frontend_renderer::{ + Entry, Feed, VCard, + ErrorPage, Template, MainPage, + POSTS_PER_PAGE +}; +pub use kittybox_frontend_renderer::assets::statics; + +#[derive(Debug, Deserialize)] +pub struct QueryParams { + after: Option<String>, +} + +#[derive(Debug)] +struct FrontendError { + msg: String, + source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>, + code: StatusCode, +} + +impl FrontendError { + pub fn with_code<C>(code: C, msg: &str) -> Self + where + C: TryInto<StatusCode>, + { + Self { + msg: msg.to_string(), + source: None, + code: code.try_into().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + } + } + pub fn msg(&self) -> &str { + &self.msg + } + pub fn code(&self) -> StatusCode { + self.code + } +} + +impl From<StorageError> for FrontendError { + fn from(err: StorageError) -> Self { + Self { + msg: "Database error".to_string(), + source: Some(Box::new(err)), + code: StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl std::error::Error for FrontendError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|e| e.as_ref() as &(dyn std::error::Error + 'static)) + } +} + +impl std::fmt::Display for FrontendError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.msg)?; + if let Some(err) = std::error::Error::source(&self) { + write!(f, ": {}", err)?; + } + + Ok(()) + } +} + +/// Filter the post according to the value of `user`. 
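+///
+/// A sketch of the intended behavior (URLs and values assumed):
+///
+/// ```ignore
+/// let post = serde_json::json!({
+///     "type": ["h-entry"],
+///     "properties": {
+///         "author": ["https://example.com/"],
+///         "visibility": ["private"],
+///         "content": ["secret"]
+///     }
+/// });
+/// // An anonymous user cannot see a private post...
+/// assert!(filter_post(post.clone(), None).is_none());
+/// // ...but its author can.
+/// assert!(filter_post(post, Some("https://example.com/")).is_some());
+/// ```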
+///
+/// Anonymous users cannot view private posts or protected locations;
+/// logged-in users can only view private posts addressed to them;
+/// only a post's author can view private location data.
+#[tracing::instrument(skip(post), fields(post = %post))]
+pub fn filter_post(
+    mut post: serde_json::Value,
+    user: Option<&str>,
+) -> Option<serde_json::Value> {
+    if post["properties"]["deleted"][0].is_string() {
+        tracing::debug!("Deleted post; returning tombstone instead");
+        return Some(serde_json::json!({
+            "type": post["type"],
+            "properties": {
+                "deleted": post["properties"]["deleted"]
+            }
+        }));
+    }
+    let empty_vec: Vec<serde_json::Value> = vec![];
+    let author_list = post["properties"]["author"]
+        .as_array()
+        .unwrap_or(&empty_vec)
+        .iter()
+        .map(|i| -> &str {
+            match i {
+                serde_json::Value::String(ref author) => author.as_str(),
+                mf2 => mf2["properties"]["uid"][0].as_str().unwrap()
+            }
+        }).collect::<Vec<&str>>();
+    let visibility = post["properties"]["visibility"][0]
+        .as_str()
+        .unwrap_or("public");
+    let audience = {
+        let mut audience = author_list.clone();
+        audience.extend(post["properties"]["audience"]
+            .as_array()
+            .unwrap_or(&empty_vec)
+            .iter()
+            .map(|i| i.as_str().unwrap()));
+
+        audience
+    };
+    tracing::debug!("post audience = {:?}", audience);
+    if (visibility == "private" && !audience.iter().any(|i| Some(*i) == user))
+        || (visibility == "protected" && user.is_none())
+    {
+        return None;
+    }
+    if post["properties"]["location"].is_array() {
+        let location_visibility = post["properties"]["location-visibility"][0]
+            .as_str()
+            .unwrap_or("private");
+        tracing::debug!("Post contains location, location privacy = {}", location_visibility);
+        // Reuse the author list computed above; unlike a fresh
+        // `as_str().unwrap()` pass, it already handles embedded h-cards.
+        if (location_visibility == "private" && !author_list.iter().any(|i| Some(*i) == user))
+            || (location_visibility == "protected" && user.is_none())
+        {
+            post["properties"]
+                .as_object_mut()
+                .unwrap()
+                .remove("location");
+        }
+    }
+
+    match post["properties"]["author"].take() {
+        serde_json::Value::Array(children) => {
+            post["properties"]["author"] = serde_json::Value::Array(
+                children
+                    .into_iter()
+                    .filter_map(|post| if post.is_string() {
+                        Some(post)
+                    } else {
+                        filter_post(post, user)
+                    })
+                    .collect::<Vec<serde_json::Value>>()
+            );
+        },
+        serde_json::Value::Null => {},
+        other => post["properties"]["author"] = other
+    }
+
+    match post["children"].take() {
+        serde_json::Value::Array(children) => {
+            post["children"] = serde_json::Value::Array(
+                children
+                    .into_iter()
+                    .filter_map(|post| filter_post(post, user))
+                    .collect::<Vec<serde_json::Value>>()
+            );
+        },
+        serde_json::Value::Null => {},
+        other => post["children"] = other
+    }
+    Some(post)
+}
+
+async fn get_post_from_database<S: Storage>(
+    db: &S,
+    url: &str,
+    after: Option<String>,
+    user: &Option<String>,
+) -> std::result::Result<(serde_json::Value, Option<String>), FrontendError> {
+    match db
+        .read_feed_with_cursor(url, after.as_deref(), POSTS_PER_PAGE, user.as_deref())
+        .await
+    {
+        Ok(result) => match result {
+            Some((post, cursor)) => match filter_post(post, user.as_deref()) {
+                Some(post) => Ok((post, cursor)),
+                None => {
+                    // TODO: Authentication
+                    if user.is_some() {
+                        Err(FrontendError::with_code(
+                            StatusCode::FORBIDDEN,
+                            "User authenticated AND forbidden to access this resource",
+                        ))
+                    } else {
+                        Err(FrontendError::with_code(
+                            StatusCode::UNAUTHORIZED,
+                            "User needs to authenticate themselves",
+                        ))
+                    }
+                }
+            }
None => Err(FrontendError::with_code( + StatusCode::NOT_FOUND, + "Post not found in the database", + )), + }, + Err(err) => match err.kind() { + crate::database::ErrorKind::PermissionDenied => { + // TODO: Authentication + if user.is_some() { + Err(FrontendError::with_code( + StatusCode::FORBIDDEN, + "User authenticated AND forbidden to access this resource", + )) + } else { + Err(FrontendError::with_code( + StatusCode::UNAUTHORIZED, + "User needs to authenticate themselves", + )) + } + } + _ => Err(err.into()), + }, + } +} + +#[tracing::instrument(skip(db))] +pub async fn homepage<D: Storage>( + Host(host): Host, + Query(query): Query<QueryParams>, + Extension(db): Extension<D>, +) -> impl IntoResponse { + let user = None; // TODO authentication + let path = format!("https://{}/", host); + let feed_path = format!("https://{}/feeds/main", host); + + match tokio::try_join!( + get_post_from_database(&db, &path, None, &user), + get_post_from_database(&db, &feed_path, query.after, &user) + ) { + Ok(((hcard, _), (hfeed, cursor))) => { + // Here, we know those operations can't really fail + // (or it'll be a transient failure that will show up on + // other requests anyway if it's serious...) + // + // btw is it more efficient to fetch these in parallel? + let (blogname, webring, channels) = tokio::join!( + db.get_setting::<crate::database::settings::SiteName>(&host) + .map(Result::unwrap_or_default), + + db.get_setting::<crate::database::settings::Webring>(&host) + .map(Result::unwrap_or_default), + + db.get_channels(&host).map(|i| i.unwrap_or_default()) + ); + // Render the homepage + ( + StatusCode::OK, + [( + axum::http::header::CONTENT_TYPE, + r#"text/html; charset="utf-8""#, + )], + Template { + title: blogname.as_ref(), + blog_name: blogname.as_ref(), + feeds: channels, + user, + content: MainPage { + feed: &hfeed, + card: &hcard, + cursor: cursor.as_deref(), + webring: crate::database::settings::Setting::into_inner(webring) + } + .to_string(), + } + .to_string(), + ) + } + Err(err) => { + if err.code == StatusCode::NOT_FOUND { + debug!("Transferring to onboarding..."); + // Transfer to onboarding + ( + StatusCode::FOUND, + [(axum::http::header::LOCATION, "/.kittybox/onboarding")], + String::default(), + ) + } else { + error!("Error while fetching h-card and/or h-feed: {}", err); + // Return the error + let (blogname, channels) = tokio::join!( + db.get_setting::<crate::database::settings::SiteName>(&host) + .map(Result::unwrap_or_default), + + db.get_channels(&host).map(|i| i.unwrap_or_default()) + ); + + ( + err.code(), + [( + axum::http::header::CONTENT_TYPE, + r#"text/html; charset="utf-8""#, + )], + Template { + title: blogname.as_ref(), + blog_name: blogname.as_ref(), + feeds: channels, + user, + content: ErrorPage { + code: err.code(), + msg: Some(err.msg().to_string()), + } + .to_string(), + } + .to_string(), + ) + } + } + } +} + +#[tracing::instrument(skip(db))] +pub async fn catchall<D: Storage>( + Extension(db): Extension<D>, + Host(host): Host, + Query(query): Query<QueryParams>, + uri: Uri, +) -> impl IntoResponse { + let user = None; // TODO authentication + let path = url::Url::parse(&format!("https://{}/", host)) + .unwrap() + .join(uri.path()) + .unwrap(); + + match get_post_from_database(&db, path.as_str(), query.after, &user).await { + Ok((post, cursor)) => { + let (blogname, channels) = tokio::join!( + db.get_setting::<crate::database::settings::SiteName>(&host) + .map(Result::unwrap_or_default), + + db.get_channels(&host).map(|i| i.unwrap_or_default()) + ); + // Render 
the requested post
+            (
+                StatusCode::OK,
+                [(
+                    axum::http::header::CONTENT_TYPE,
+                    r#"text/html; charset="utf-8""#,
+                )],
+                Template {
+                    title: blogname.as_ref(),
+                    blog_name: blogname.as_ref(),
+                    feeds: channels,
+                    user,
+                    content: match post.pointer("/type/0").and_then(|i| i.as_str()) {
+                        Some("h-entry") => Entry { post: &post }.to_string(),
+                        Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(),
+                        Some("h-card") => VCard { card: &post }.to_string(),
+                        unknown => {
+                            unimplemented!("Template for MF2-JSON type {:?}", unknown)
+                        }
+                    },
+                }
+                .to_string(),
+            )
+        }
+        Err(err) => {
+            let (blogname, channels) = tokio::join!(
+                db.get_setting::<crate::database::settings::SiteName>(&host)
+                    .map(Result::unwrap_or_default),
+
+                db.get_channels(&host).map(|i| i.unwrap_or_default())
+            );
+            (
+                err.code(),
+                [(
+                    axum::http::header::CONTENT_TYPE,
+                    r#"text/html; charset="utf-8""#,
+                )],
+                Template {
+                    title: blogname.as_ref(),
+                    blog_name: blogname.as_ref(),
+                    feeds: channels,
+                    user,
+                    content: ErrorPage {
+                        code: err.code(),
+                        msg: Some(err.msg().to_owned()),
+                    }
+                    .to_string(),
+                }
+                .to_string(),
+            )
+        }
+    }
+}
diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs
new file mode 100644
index 0000000..e44e866
--- /dev/null
+++ b/src/frontend/onboarding.rs
@@ -0,0 +1,181 @@
+use std::sync::Arc;
+
+use crate::database::{settings, Storage};
+use axum::{
+    extract::{Extension, Host},
+    http::StatusCode,
+    response::{Html, IntoResponse},
+    Json,
+};
+use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template};
+use serde::Deserialize;
+use tokio::{task::JoinSet, sync::Mutex};
+use tracing::{debug, error};
+
+use super::FrontendError;
+
+pub async fn get() -> Html<String> {
+    Html(
+        Template {
+            title: "Kittybox - Onboarding",
+            blog_name: "Kittybox",
+            feeds: vec![],
+            user: None,
+            content: OnboardingPage {}.to_string(),
+        }
+        .to_string(),
+    )
+}
+
+#[derive(Deserialize, Debug)]
+struct OnboardingFeed {
+    slug: String,
+    name: String,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OnboardingData {
+    user: serde_json::Value,
+    first_post: serde_json::Value,
+    #[serde(default = "OnboardingData::default_blog_name")]
+    blog_name: String,
+    feeds: Vec<OnboardingFeed>,
+}
+
+impl OnboardingData {
+    fn default_blog_name() -> String {
+        "Kitty Box!".to_owned()
+    }
+}
+
+#[tracing::instrument(skip(db, http))]
+async fn onboard<D: Storage + 'static>(
+    db: D,
+    user_uid: url::Url,
+    data: OnboardingData,
+    http: reqwest::Client,
+    jobset: Arc<Mutex<JoinSet<()>>>,
+) -> Result<(), FrontendError> {
+    // Create a user to pass to the backend
+    // At this point the site belongs to nobody, so it is safe to do
+    tracing::debug!("Creating user...");
+    let user = kittybox_indieauth::TokenData {
+        me: user_uid.clone(),
+        client_id: "https://kittybox.fireburn.ru/".parse().unwrap(),
+        scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]),
+        iat: None, exp: None
+    };
+    tracing::debug!("User data: {:?}", user);
+
+    if data.user["type"][0] != "h-card" || data.first_post["type"][0] != "h-entry" {
+        return Err(FrontendError::with_code(
+            StatusCode::BAD_REQUEST,
+            "user and first_post should be an h-card and an h-entry",
+        ));
+    }
+
+    tracing::debug!("Setting settings...");
+    let user_domain = format!(
+        "{}{}",
+        user.me.host_str().unwrap(),
+        user.me.port()
+            .map(|port| format!(":{}", port))
+            .unwrap_or_default()
+    );
+    db.set_setting::<settings::SiteName>(&user_domain, data.blog_name.to_owned())
+        .await
+        .map_err(FrontendError::from)?;
+
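// The settings API here is typed: a marker type such as settings::SiteName
// or settings::Webring fixes both the key and the value type, so a bool
// cannot accidentally be written where a site name belongs. Reading goes
// through the same marker and unwraps via the Setting trait, roughly
// (sketch based on the get_setting/into_inner calls elsewhere in this
// diff; error handling elided):
//
//     let name = db.get_setting::<settings::SiteName>(&user_domain).await?;
//     let name: String = settings::Setting::into_inner(name);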
+    db.set_setting::<settings::Webring>(&user_domain, false)
+        .await
+        .map_err(FrontendError::from)?;
+
+    let (_, hcard) = {
+        let mut hcard = data.user;
+        hcard["properties"]["uid"] = serde_json::json!([&user_uid]);
+        crate::micropub::normalize_mf2(hcard, &user)
+    };
+    db.put_post(&hcard, user_domain.as_str())
+        .await
+        .map_err(FrontendError::from)?;
+
+    debug!("Creating feeds...");
+    for feed in data.feeds {
+        if feed.name.is_empty() || feed.slug.is_empty() {
+            continue;
+        };
+        debug!("Creating feed {} with slug {}", &feed.name, &feed.slug);
+        let (_, feed) = crate::micropub::normalize_mf2(
+            serde_json::json!({
+                "type": ["h-feed"],
+                "properties": {"name": [feed.name], "mp-slug": [feed.slug]}
+            }),
+            &user,
+        );
+
+        db.put_post(&feed, user_uid.as_str())
+            .await
+            .map_err(FrontendError::from)?;
+    }
+    let (uid, post) = crate::micropub::normalize_mf2(data.first_post, &user);
+    tracing::debug!("Posting first post {}...", uid);
+    crate::micropub::_post(&user, uid, post, db, http, jobset)
+        .await
+        .map_err(|e| FrontendError {
+            msg: "Error while posting the first post".to_string(),
+            source: Some(Box::new(e)),
+            code: StatusCode::INTERNAL_SERVER_ERROR,
+        })?;
+
+    Ok(())
+}
+
+pub async fn post<D: Storage + 'static>(
+    Extension(db): Extension<D>,
+    Host(host): Host,
+    Extension(http): Extension<reqwest::Client>,
+    Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>,
+    Json(data): Json<OnboardingData>,
+) -> axum::response::Response {
+    let user_uid = format!("https://{}/", host.as_str());
+
+    if db.post_exists(&user_uid).await.unwrap() {
+        IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")]))
+    } else {
+        match onboard(db, user_uid.parse().unwrap(), data, http, jobset).await {
+            Ok(()) => IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])),
+            Err(err) => {
+                error!("Onboarding error: {}", err);
+                IntoResponse::into_response((
+                    err.code(),
+                    Html(
+                        Template {
+                            title: "Kittybox - Onboarding",
+                            blog_name: "Kittybox",
+                            feeds: vec![],
+                            user: None,
+                            content: ErrorPage {
+                                code: err.code(),
+                                msg: Some(err.msg().to_string()),
+                            }
+                            .to_string(),
+                        }
+                        .to_string(),
+                    ),
+                ))
+            }
+        }
+    }
+}
+
+pub fn router<S: Storage + 'static>(
+    database: S,
+    http: reqwest::Client,
+    jobset: Arc<Mutex<JoinSet<()>>>,
+) -> axum::routing::MethodRouter {
+    axum::routing::get(get)
+        .post(post::<S>)
+        .layer::<_, _, std::convert::Infallible>(axum::Extension(database))
+        .layer::<_, _, std::convert::Infallible>(axum::Extension(http))
+        .layer(axum::Extension(jobset))
+}
diff --git a/src/indieauth/backend.rs b/src/indieauth/backend.rs
new file mode 100644
index 0000000..534bcfb
--- /dev/null
+++ b/src/indieauth/backend.rs
@@ -0,0 +1,105 @@
+use std::collections::HashMap;
+use kittybox_indieauth::{
+    AuthorizationRequest, TokenData
+};
+pub use kittybox_util::auth::EnrolledCredential;
+
+type Result<T> = std::io::Result<T>;
+
+pub mod fs;
+pub use fs::FileBackend;
+
+#[async_trait::async_trait]
+pub trait AuthBackend: Clone + Send + Sync + 'static {
+    // Authorization code management.
+    /// Create a one-time OAuth2 authorization code for the passed
+    /// authorization request, and save it for later retrieval.
+    ///
+    /// Note for implementors: the [`AuthorizationRequest::me`] value
+    /// is guaranteed to be [`Some(url::Url)`][Option::Some] and can
+    /// be trusted to be correct and non-malicious.
+    async fn create_code(&self, data: AuthorizationRequest) -> Result<String>;
+    /// Retrieve an authorization request using the one-time
+    /// code.
Implementations must sanitize the `code` field to + /// prevent exploits, and must check if the code should still be + /// valid at this point in time (validity interval is left up to + /// the implementation, but is recommended to be no more than 10 + /// minutes). + async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>>; + // Token management. + async fn create_token(&self, data: TokenData) -> Result<String>; + async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>>; + async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>>; + async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()>; + // Refresh token management. + async fn create_refresh_token(&self, data: TokenData) -> Result<String>; + async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>>; + async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>>; + async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()>; + // Password management. + /// Verify a password. + #[must_use] + async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool>; + /// Enroll a password credential for a user. Only one password + /// credential must exist for a given user. + async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()>; + /// List currently enrolled credential types for a given user. + async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>>; + // WebAuthn credential management. + #[cfg(feature = "webauthn")] + /// Enroll a WebAuthn authenticator public key for this user. + /// Multiple public keys may be saved for one user, corresponding + /// to different authenticators used by them. + /// + /// This function can also be used to overwrite a passkey with an + /// updated version after using + /// [webauthn::prelude::Passkey::update_credential()]. + async fn enroll_webauthn(&self, website: &url::Url, credential: webauthn::prelude::Passkey) -> Result<()>; + #[cfg(feature = "webauthn")] + /// List currently enrolled WebAuthn authenticators for a given user. + async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<webauthn::prelude::Passkey>>; + #[cfg(feature = "webauthn")] + /// Persist registration challenge state for a little while so it + /// can be used later. + /// + /// Challenges saved in this manner MUST expire after a little + /// while. 10 minutes is recommended. + async fn persist_registration_challenge( + &self, + website: &url::Url, + state: webauthn::prelude::PasskeyRegistration + ) -> Result<String>; + #[cfg(feature = "webauthn")] + /// Retrieve a persisted registration challenge. + /// + /// The challenge should be deleted after retrieval. + async fn retrieve_registration_challenge( + &self, + website: &url::Url, + challenge_id: &str + ) -> Result<webauthn::prelude::PasskeyRegistration>; + #[cfg(feature = "webauthn")] + /// Persist authentication challenge state for a little while so + /// it can be used later. + /// + /// Challenges saved in this manner MUST expire after a little + /// while. 10 minutes is recommended. + /// + /// To support multiple authentication options, this can return an + /// opaque token that should be set as a cookie. 
+ async fn persist_authentication_challenge( + &self, + website: &url::Url, + state: webauthn::prelude::PasskeyAuthentication + ) -> Result<String>; + #[cfg(feature = "webauthn")] + /// Retrieve a persisted authentication challenge. + /// + /// The challenge should be deleted after retrieval. + async fn retrieve_authentication_challenge( + &self, + website: &url::Url, + challenge_id: &str + ) -> Result<webauthn::prelude::PasskeyAuthentication>; + +} diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs new file mode 100644 index 0000000..600e901 --- /dev/null +++ b/src/indieauth/backend/fs.rs @@ -0,0 +1,420 @@ +use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}}; + +use super::{AuthBackend, Result, EnrolledCredential}; +use async_trait::async_trait; +use kittybox_indieauth::{ + AuthorizationRequest, TokenData +}; +use serde::de::DeserializeOwned; +use tokio::{task::spawn_blocking, io::AsyncReadExt}; +#[cfg(feature = "webauthn")] +use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication}; + +const CODE_LENGTH: usize = 16; +const TOKEN_LENGTH: usize = 128; +const CODE_DURATION: std::time::Duration = std::time::Duration::from_secs(600); + +#[derive(Clone, Debug)] +pub struct FileBackend { + path: PathBuf, +} + +impl FileBackend { + pub fn new<T: Into<PathBuf>>(path: T) -> Self { + Self { + path: path.into() + } + } + + /// Sanitize a filename, leaving only alphanumeric characters. + /// + /// Doesn't allocate a new string unless non-alphanumeric + /// characters are encountered. + fn sanitize_for_path(filename: &'_ str) -> Cow<'_, str> { + if filename.chars().all(char::is_alphanumeric) { + Cow::Borrowed(filename) + } else { + let mut s = String::with_capacity(filename.len()); + + filename.chars() + .filter(|c| c.is_alphanumeric()) + .for_each(|c| s.push(c)); + + Cow::Owned(s) + } + } + + #[inline] + async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>( + &self, + dir: &str, + basename: B, + length: usize, + data: T + ) -> Result<String> { + let basename = basename.into(); + let has_ext = basename.is_some(); + let (filename, mut file) = kittybox_util::fs::mktemp( + self.path.join(dir), basename, length + ) + .await + .map(|(name, file)| (name, file.try_into_std().unwrap()))?; + + spawn_blocking(move || serde_json::to_writer(&mut file, &data)) + .await + .unwrap_or_else(|e| panic!( + "Panic while serializing {}: {}", + std::any::type_name::<T>(), + e + )) + .map(move |_| { + (if has_ext { + filename + .extension() + + } else { + filename + .file_name() + }) + .unwrap() + .to_str() + .unwrap() + .to_owned() + }) + .map_err(|err| err.into()) + } + + #[inline] + async fn deserialize_from_file<'filename, 'this: 'filename, T, B>( + &'this self, + dir: &'filename str, + basename: B, + filename: &'filename str, + ) -> Result<Option<(PathBuf, SystemTime, T)>> + where + T: serde::de::DeserializeOwned + Send, + B: Into<Option<&'static str>> + { + let basename = basename.into(); + let path = self.path + .join(dir) + .join(format!( + "{}{}{}", + basename.unwrap_or(""), + if basename.is_none() { "" } else { "." 
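// This format! call assembles the flat on-disk naming scheme the file
// backend uses throughout (sketch of the resulting paths, relative to
// the backend root):
//
//     codes/<code>                     one-time authorization codes
//     <host>/tokens/access.<token>     access tokens
//     <host>/tokens/refresh.<token>    refresh tokens
//
// where <code> and <token> are the random alphanumeric names generated
// by kittybox_util::fs::mktemp in serialize_to_file above.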
}, + FileBackend::sanitize_for_path(filename) + )); + + let data = match tokio::fs::File::open(&path).await { + Ok(mut file) => { + let mut buf = Vec::new(); + + file.read_to_end(&mut buf).await?; + + match serde_json::from_slice::<'_, T>(buf.as_slice()) { + Ok(data) => data, + Err(err) => return Err(err.into()) + } + }, + Err(err) => if err.kind() == std::io::ErrorKind::NotFound { + return Ok(None) + } else { + return Err(err) + } + }; + + let ctime = tokio::fs::metadata(&path).await?.created()?; + + Ok(Some((path, ctime, data))) + } + + #[inline] + fn url_to_dir(url: &url::Url) -> String { + let host = url.host_str().unwrap(); + let port = url.port() + .map(|port| Cow::Owned(format!(":{}", port))) + .unwrap_or(Cow::Borrowed("")); + + format!("{}{}", host, port) + } + + async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>( + &'this self, + dir: &'dir str, + prefix: &'static str + ) -> Result<HashMap<String, T>> { + let dir = self.path.join(dir); + + let mut hashmap = HashMap::new(); + let mut readdir = match tokio::fs::read_dir(dir).await { + Ok(readdir) => readdir, + Err(err) => if err.kind() == std::io::ErrorKind::NotFound { + // empty hashmap + return Ok(hashmap); + } else { + return Err(err); + } + }; + while let Some(entry) = readdir.next_entry().await? { + // safe to unwrap; filenames are alphanumeric + let filename = entry.file_name() + .into_string() + .expect("token filenames should be alphanumeric!"); + if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) { + match tokio::fs::File::open(entry.path()).await { + Ok(mut file) => { + let mut buf = Vec::new(); + + file.read_to_end(&mut buf).await?; + + match serde_json::from_slice::<'_, T>(buf.as_slice()) { + Ok(data) => hashmap.insert(token.to_string(), data), + Err(err) => { + tracing::error!( + "Error decoding token data from file {}: {}", + entry.path().display(), err + ); + continue; + } + }; + }, + Err(err) => if err.kind() == std::io::ErrorKind::NotFound { + continue + } else { + return Err(err) + } + } + } + } + + Ok(hashmap) + } +} + +#[async_trait] +impl AuthBackend for FileBackend { + // Authorization code management. + async fn create_code(&self, data: AuthorizationRequest) -> Result<String> { + self.serialize_to_file("codes", None, CODE_LENGTH, data).await + } + + async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> { + match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? { + Some((path, ctime, data)) => { + if let Err(err) = tokio::fs::remove_file(path).await { + tracing::error!("Failed to clean up authorization code: {}", err); + } + // Err on the safe side in case of clock drift + if ctime.elapsed().unwrap_or(Duration::ZERO) > CODE_DURATION { + Ok(None) + } else { + Ok(Some(data)) + } + }, + None => Ok(None) + } + } + + // Token management. + async fn create_token(&self, data: TokenData) -> Result<String> { + let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); + self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await + } + + async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { + let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); + match self.deserialize_from_file::<TokenData, _>( + &dir, "access", + FileBackend::sanitize_for_path(token).as_ref() + ).await? 
{ + Some((path, _, token)) => { + if token.expired() { + if let Err(err) = tokio::fs::remove_file(path).await { + tracing::error!("Failed to remove expired token: {}", err); + } + Ok(None) + } else { + Ok(Some(token)) + } + }, + None => Ok(None) + } + } + + async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> { + let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); + self.list_files(&dir, "access").await + } + + async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()> { + match tokio::fs::remove_file( + self.path + .join(FileBackend::url_to_dir(website)) + .join("tokens") + .join(format!("access.{}", FileBackend::sanitize_for_path(token))) + ).await { + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), + result => result + } + } + + // Refresh token management. + async fn create_refresh_token(&self, data: TokenData) -> Result<String> { + let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); + self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await + } + + async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { + let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); + match self.deserialize_from_file::<TokenData, _>( + &dir, "refresh", + FileBackend::sanitize_for_path(token).as_ref() + ).await? { + Some((path, _, token)) => { + if token.expired() { + if let Err(err) = tokio::fs::remove_file(path).await { + tracing::error!("Failed to remove expired token: {}", err); + } + Ok(None) + } else { + Ok(Some(token)) + } + }, + None => Ok(None) + } + } + + async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> { + let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); + self.list_files(&dir, "refresh").await + } + + async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()> { + match tokio::fs::remove_file( + self.path + .join(FileBackend::url_to_dir(website)) + .join("tokens") + .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))) + ).await { + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), + result => result + } + } + + // Password management. 
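// The two methods below implement the password credential using the
// argon2 crate's PHC string format: enroll_password stores a salted
// argon2id hash in <host>/password, and verify_password parses it back
// and checks a candidate against it. The same round-trip, reduced to a
// self-contained sketch:
//
//     use argon2::{Argon2, password_hash::{
//         rand_core::OsRng, PasswordHash, PasswordHasher,
//         PasswordVerifier, SaltString,
//     }};
//
//     let salt = SaltString::generate(&mut OsRng);
//     let hash = Argon2::default()
//         .hash_password(b"correct horse battery staple", &salt)
//         .unwrap()
//         .to_string();
//     let parsed = PasswordHash::new(&hash).unwrap();
//     assert!(Argon2::default()
//         .verify_password(b"correct horse battery staple", &parsed)
//         .is_ok());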
+ #[tracing::instrument(skip(password))] + async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> { + use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}}; + + let password_filename = self.path + .join(FileBackend::url_to_dir(website)) + .join("password"); + + tracing::debug!("Reading password for {} from {}", website, password_filename.display()); + + match tokio::fs::read_to_string(password_filename).await { + Ok(password_hash) => { + let parsed_hash = { + let hash = password_hash.trim(); + #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash); + PasswordHash::new(hash) + .expect("Password hash should be valid!") + }; + Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) + }, + Err(err) => if err.kind() == std::io::ErrorKind::NotFound { + Ok(false) + } else { + Err(err) + } + } + } + + #[tracing::instrument(skip(password))] + async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> { + use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}}; + + let password_filename = self.path + .join(FileBackend::url_to_dir(website)) + .join("password"); + + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let password_hash = argon2.hash_password(password.as_bytes(), &salt) + .expect("Hashing a password should not error out") + .to_string(); + + tracing::debug!("Enrolling password for {} at {}", website, password_filename.display()); + tokio::fs::write(password_filename, password_hash.as_bytes()).await + } + + // WebAuthn credential management. + #[cfg(feature = "webauthn")] + async fn enroll_webauthn(&self, website: &url::Url, credential: Passkey) -> Result<()> { + todo!() + } + + #[cfg(feature = "webauthn")] + async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<Passkey>> { + // TODO stub! 
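// A real implementation could mirror list_files above: keep one file per
// passkey under a per-host directory (say, <host>/webauthn/) and
// deserialize each entry with serde, which webauthn-rs passkeys support.
// That layout is hypothetical; this commit only ships the stub below.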
+ Ok(vec![]) + } + + #[cfg(feature = "webauthn")] + async fn persist_registration_challenge( + &self, + website: &url::Url, + state: PasskeyRegistration + ) -> Result<String> { + todo!() + } + + #[cfg(feature = "webauthn")] + async fn retrieve_registration_challenge( + &self, + website: &url::Url, + challenge_id: &str + ) -> Result<PasskeyRegistration> { + todo!() + } + + #[cfg(feature = "webauthn")] + async fn persist_authentication_challenge( + &self, + website: &url::Url, + state: PasskeyAuthentication + ) -> Result<String> { + todo!() + } + + #[cfg(feature = "webauthn")] + async fn retrieve_authentication_challenge( + &self, + website: &url::Url, + challenge_id: &str + ) -> Result<PasskeyAuthentication> { + todo!() + } + + async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> { + let mut creds = vec![]; + + match tokio::fs::metadata(self.path + .join(FileBackend::url_to_dir(website)) + .join("password")) + .await + { + Ok(_) => creds.push(EnrolledCredential::Password), + Err(err) => if err.kind() != std::io::ErrorKind::NotFound { + return Err(err) + } + } + + #[cfg(feature = "webauthn")] + if !self.list_webauthn_pubkeys(website).await?.is_empty() { + creds.push(EnrolledCredential::WebAuthn); + } + + Ok(creds) + } +} diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs new file mode 100644 index 0000000..0ad2702 --- /dev/null +++ b/src/indieauth/mod.rs @@ -0,0 +1,883 @@ +use std::marker::PhantomData; + +use tracing::error; +use serde::Deserialize; +use axum::{ + extract::{Query, Json, Host, Form}, + response::{Html, IntoResponse, Response}, + http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer}, + Extension +}; +#[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] +use axum_extra::extract::cookie::{CookieJar, Cookie}; +use crate::database::Storage; +use kittybox_indieauth::{ + Metadata, IntrospectionEndpointAuthMethod, RevocationEndpointAuthMethod, + Scope, Scopes, PKCEMethod, Error, ErrorKind, ResponseType, + AuthorizationRequest, AuthorizationResponse, + GrantType, GrantRequest, GrantResponse, Profile, + TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest, TokenData +}; +use std::str::FromStr; +use std::ops::Deref; + +pub mod backend; +#[cfg(feature = "webauthn")] +mod webauthn; +use backend::AuthBackend; + +const ACCESS_TOKEN_VALIDITY: u64 = 7 * 24 * 60 * 60; // 7 days +const REFRESH_TOKEN_VALIDITY: u64 = ACCESS_TOKEN_VALIDITY / 7 * 60; // 60 days +/// Internal scope for accessing the token introspection endpoint. 
+const KITTYBOX_TOKEN_STATUS: &str = "kittybox:token_status"; + +pub(crate) struct User<A: AuthBackend>(pub(crate) TokenData, pub(crate) PhantomData<A>); +impl<A: AuthBackend> std::fmt::Debug for User<A> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("User").field(&self.0).finish() + } +} +impl<A: AuthBackend> std::ops::Deref for User<A> { + type Target = TokenData; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub enum IndieAuthResourceError { + InvalidRequest, + Unauthorized, + InvalidToken +} +impl axum::response::IntoResponse for IndieAuthResourceError { + fn into_response(self) -> axum::response::Response { + use IndieAuthResourceError::*; + + match self { + Unauthorized => ( + StatusCode::UNAUTHORIZED, + [("WWW-Authenticate", "Bearer")] + ).into_response(), + InvalidRequest => ( + StatusCode::BAD_REQUEST, + Json(&serde_json::json!({"error": "invalid_request"})) + ).into_response(), + InvalidToken => ( + StatusCode::UNAUTHORIZED, + [("WWW-Authenticate", "Bearer, error=\"invalid_token\"")], + Json(&serde_json::json!({"error": "unauthorized"})) + ).into_response() + } + } +} + +#[async_trait::async_trait] +impl <S: Send + Sync, A: AuthBackend> axum::extract::FromRequestParts<S> for User<A> { + type Rejection = IndieAuthResourceError; + + async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> { + let TypedHeader(Authorization(token)) = + TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) + .await + .map_err(|_| IndieAuthResourceError::Unauthorized)?; + + let axum::Extension(auth) = axum::Extension::<A>::from_request_parts(req, state) + .await + .unwrap(); + + let Host(host) = Host::from_request_parts(req, state) + .await + .map_err(|_| IndieAuthResourceError::InvalidRequest)?; + + auth.get_token( + &format!("https://{host}/").parse().unwrap(), + token.token() + ) + .await + .unwrap() + .ok_or(IndieAuthResourceError::InvalidToken) + .map(|t| User(t, PhantomData)) + } +} + +pub async fn metadata( + Host(host): Host +) -> Metadata { + let issuer: url::Url = format!( + "{}://{}/", + if cfg!(debug_assertions) { + "http" + } else { + "https" + }, + host + ).parse().unwrap(); + + let indieauth: url::Url = issuer.join("/.kittybox/indieauth/").unwrap(); + Metadata { + issuer, + authorization_endpoint: indieauth.join("auth").unwrap(), + token_endpoint: indieauth.join("token").unwrap(), + introspection_endpoint: indieauth.join("token_status").unwrap(), + introspection_endpoint_auth_methods_supported: Some(vec![ + IntrospectionEndpointAuthMethod::Bearer + ]), + revocation_endpoint: Some(indieauth.join("revoke_token").unwrap()), + revocation_endpoint_auth_methods_supported: Some(vec![ + RevocationEndpointAuthMethod::None + ]), + scopes_supported: Some(vec![ + Scope::Create, + Scope::Update, + Scope::Delete, + Scope::Media, + Scope::Profile + ]), + response_types_supported: Some(vec![ResponseType::Code]), + grant_types_supported: Some(vec![GrantType::AuthorizationCode, GrantType::RefreshToken]), + service_documentation: None, + code_challenge_methods_supported: vec![PKCEMethod::S256], + authorization_response_iss_parameter_supported: Some(true), + userinfo_endpoint: Some(indieauth.join("userinfo").unwrap()), + } +} + +async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( + Host(host): Host, + Query(request): Query<AuthorizationRequest>, + Extension(db): Extension<D>, + Extension(http): Extension<reqwest::Client>, + Extension(auth): Extension<A> +) -> 
Response { + let me = format!("https://{host}/").parse().unwrap(); + let h_app = { + tracing::debug!("Sending request to {} to fetch metadata", request.client_id); + match http.get(request.client_id.clone()).send().await { + Ok(response) => { + let url = response.url().clone(); + let text = response.text().await.unwrap(); + tracing::debug!("Received {} bytes in response", text.len()); + match microformats::from_html(&text, url) { + Ok(mf2) => { + if let Some(relation) = mf2.rels.items.get(&request.redirect_uri) { + if !relation.rels.iter().any(|i| i == "redirect_uri") { + return (StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri provided was declared as \ + something other than redirect_uri.") + .into_response() + } + } else if request.redirect_uri.origin() != request.client_id.origin() { + return (StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri didn't match the origin \ + and wasn't explicitly allowed. You were being tricked.") + .into_response() + } + + mf2.items.iter() + .cloned() + .find(|i| (**i).borrow().r#type.iter() + .any(|i| *i == microformats::types::Class::from_str("h-app").unwrap() + || *i == microformats::types::Class::from_str("h-x-app").unwrap())) + .map(|i| serde_json::to_value(i.borrow().deref()).unwrap()) + }, + Err(err) => { + tracing::error!("Error parsing application metadata: {}", err); + return (StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "Parsing application metadata failed.").into_response() + } + } + }, + Err(err) => { + tracing::error!("Error fetching application metadata: {}", err); + return (StatusCode::INTERNAL_SERVER_ERROR, + [("Content-Type", "text/plain")], + "Fetching application metadata failed.").into_response() + } + } + }; + + tracing::debug!("Application metadata: {:#?}", h_app); + + Html(kittybox_frontend_renderer::Template { + title: "Confirm sign-in via IndieAuth", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: kittybox_frontend_renderer::AuthorizationRequestPage { + request, + credentials: auth.list_user_credential_types(&me).await.unwrap(), + user: db.get_post(me.as_str()).await.unwrap().unwrap(), + app: h_app + }.to_string(), + }.to_string()) + .into_response() +} + +#[derive(Deserialize, Debug)] +#[serde(untagged)] +enum Credential { + Password(String), + #[cfg(feature = "webauthn")] + WebAuthn(::webauthn::prelude::PublicKeyCredential) +} + +#[derive(Deserialize, Debug)] +struct AuthorizationConfirmation { + authorization_method: Credential, + request: AuthorizationRequest +} + +async fn verify_credential<A: AuthBackend>( + auth: &A, + website: &url::Url, + credential: Credential, + #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] + challenge_id: Option<&str> +) -> std::io::Result<bool> { + match credential { + Credential::Password(password) => auth.verify_password(website, password).await, + #[cfg(feature = "webauthn")] + Credential::WebAuthn(credential) => webauthn::verify( + auth, + website, + credential, + challenge_id.unwrap() + ).await + } +} + +#[tracing::instrument(skip(backend, confirmation))] +async fn authorization_endpoint_confirm<A: AuthBackend>( + Host(host): Host, + Extension(backend): Extension<A>, + cookies: CookieJar, + Json(confirmation): Json<AuthorizationConfirmation>, +) -> Response { + tracing::debug!("Received authorization confirmation from user"); + #[cfg(feature = "webauthn")] + let challenge_id = cookies.get(webauthn::CHALLENGE_ID_COOKIE) + .map(|cookie| cookie.value()); + #[cfg(not(feature = 
"webauthn"))] + let challenge_id = None; + + let website = format!("https://{}/", host).parse().unwrap(); + let AuthorizationConfirmation { + authorization_method: credential, + request: mut auth + } = confirmation; + match verify_credential(&backend, &website, credential, challenge_id).await { + Ok(verified) => if !verified { + error!("User failed verification, bailing out."); + return StatusCode::UNAUTHORIZED.into_response(); + }, + Err(err) => { + error!("Error while verifying credential: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + } + // Insert the correct `me` value into the request + // + // From this point, the `me` value that hits the backend is + // guaranteed to be authoritative and correct, and can be safely + // unwrapped. + auth.me = Some(website.clone()); + // Cloning these two values, because we can't destructure + // the AuthorizationRequest - we need it for the code + let state = auth.state.clone(); + let redirect_uri = auth.redirect_uri.clone(); + + let code = match backend.create_code(auth).await { + Ok(code) => code, + Err(err) => { + error!("Error creating authorization code: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + + let location = { + let mut uri = redirect_uri; + uri.set_query(Some(&serde_urlencoded::to_string( + AuthorizationResponse { code, state, iss: website } + ).unwrap())); + + uri + }; + + // DO NOT SET `StatusCode::FOUND` here! `fetch()` cannot read from + // redirects, it can only follow them or choose to receive an + // opaque response instead that is completely useless + (StatusCode::NO_CONTENT, + [("Location", location.as_str())], + #[cfg(feature = "webauthn")] + cookies.remove(Cookie::named(webauthn::CHALLENGE_ID_COOKIE)) + ) + .into_response() +} + +#[tracing::instrument(skip(backend, db))] +async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( + Host(host): Host, + Extension(backend): Extension<A>, + Extension(db): Extension<D>, + Form(grant): Form<GrantRequest>, +) -> Response { + match grant { + GrantRequest::AuthorizationCode { + code, + client_id, + redirect_uri, + code_verifier + } => { + let request: AuthorizationRequest = match backend.get_code(&code).await { + Ok(Some(request)) => request, + Ok(None) => return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None + }.into_response(), + Err(err) => { + tracing::error!("Error retrieving auth request: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + if client_id != request.client_id { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This authorization code isn't yours.".to_string()), + error_uri: None + }.into_response() + } + if redirect_uri != request.redirect_uri { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), + error_uri: None + }.into_response() + } + if !request.code_challenge.verify(code_verifier) { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The PKCE challenge failed.".to_string()), + // are RFCs considered human-readable? 
😝 + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() + }.into_response() + } + let me: url::Url = format!("https://{}/", host).parse().unwrap(); + if request.me.unwrap() != me { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This authorization endpoint does not serve this user.".to_string()), + error_uri: None + }.into_response() + } + let profile = if request.scope.as_ref() + .map(|s| s.has(&Scope::Profile)) + .unwrap_or_default() + { + match get_profile( + db, + me.as_str(), + request.scope.as_ref() + .map(|s| s.has(&Scope::Email)) + .unwrap_or_default() + ).await { + Ok(profile) => { + tracing::debug!("Retrieved profile: {:?}", profile); + profile + }, + Err(err) => { + tracing::error!("Error retrieving profile from database: {}", err); + + return StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + } else { + None + }; + + GrantResponse::ProfileUrl { me, profile }.into_response() + }, + _ => Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided grant_type is unusable on this endpoint.".to_string()), + error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code".parse().ok() + }.into_response() + } +} + +#[tracing::instrument(skip(backend, db))] +async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( + Host(host): Host, + Extension(backend): Extension<A>, + Extension(db): Extension<D>, + Form(grant): Form<GrantRequest>, +) -> Response { + #[inline] + fn prepare_access_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { + TokenData { + me, client_id, scope, + exp: (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) + .as_secs() + .into(), + iat: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + .into() + } + } + + #[inline] + fn prepare_refresh_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { + TokenData { + me, client_id, scope, + exp: (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) + .as_secs() + .into(), + iat: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + .into() + } + } + + let me: url::Url = format!("https://{}/", host).parse().unwrap(); + + match grant { + GrantRequest::AuthorizationCode { + code, + client_id, + redirect_uri, + code_verifier + } => { + let request: AuthorizationRequest = match backend.get_code(&code).await { + Ok(Some(request)) => request, + Ok(None) => return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None + }.into_response(), + Err(err) => { + tracing::error!("Error retrieving auth request: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + + tracing::debug!("Retrieved authorization request: {:?}", request); + + let scope = if let Some(scope) = request.scope { scope } else { + return Error { + kind: ErrorKind::InvalidScope, + msg: Some("Tokens cannot be issued if no scopes are requested.".to_string()), + error_uri: "https://indieauth.spec.indieweb.org/#access-token-response".parse().ok() + }.into_response(); + }; + if client_id != request.client_id { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This authorization code isn't yours.".to_string()), + error_uri: None + }.into_response() + } + if redirect_uri != 
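// The PKCE verification just below is RFC 7636's S256 method: the stored
// code_challenge must equal the base64url-encoded (unpadded) SHA-256 of
// the code_verifier. Conceptually (sketch using the sha2 and base64
// crates, not the actual kittybox_indieauth implementation):
//
//     use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
//     use sha2::{Digest, Sha256};
//
//     fn s256_matches(challenge: &str, verifier: &str) -> bool {
//         URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes())) == challenge
//     }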
request.redirect_uri { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), + error_uri: None + }.into_response() + } + if !request.code_challenge.verify(code_verifier) { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The PKCE challenge failed.".to_string()), + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() + }.into_response(); + } + + // Note: we can trust the `request.me` value, since we set + // it earlier before generating the authorization code + if request.me.unwrap() != me { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This authorization endpoint does not serve this user.".to_string()), + error_uri: None + }.into_response() + } + + let profile = if dbg!(scope.has(&Scope::Profile)) { + match get_profile( + db, + me.as_str(), + scope.has(&Scope::Email) + ).await { + Ok(profile) => dbg!(profile), + Err(err) => { + tracing::error!("Error retrieving profile from database: {}", err); + + return StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + } else { + None + }; + + let access_token = match backend.create_token( + prepare_access_token(me.clone(), client_id.clone(), scope.clone()) + ).await { + Ok(token) => token, + Err(err) => { + tracing::error!("Error creating access token: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + // TODO: only create refresh token if user allows it + let refresh_token = match backend.create_refresh_token( + prepare_refresh_token(me.clone(), client_id, scope.clone()) + ).await { + Ok(token) => token, + Err(err) => { + tracing::error!("Error creating refresh token: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + + GrantResponse::AccessToken { + me, + profile, + access_token, + token_type: kittybox_indieauth::TokenType::Bearer, + scope: Some(scope), + expires_in: Some(ACCESS_TOKEN_VALIDITY), + refresh_token: Some(refresh_token) + }.into_response() + }, + GrantRequest::RefreshToken { + refresh_token, + client_id, + scope + } => { + let data = match backend.get_refresh_token(&me, &refresh_token).await { + Ok(Some(token)) => token, + Ok(None) => return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This refresh token is not valid.".to_string()), + error_uri: None + }.into_response(), + Err(err) => { + tracing::error!("Error retrieving refresh token: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + }; + + if data.client_id != client_id { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This refresh token is not yours.".to_string()), + error_uri: None + }.into_response(); + } + + let scope = if let Some(scope) = scope { + if !data.scope.has_all(scope.as_ref()) { + return Error { + kind: ErrorKind::InvalidScope, + msg: Some("You can't request additional scopes through the refresh token grant.".to_string()), + error_uri: None + }.into_response(); + } + + scope + } else { + // Note: check skipped because of redundancy (comparing a scope list with itself) + data.scope + }; + + + let profile = if scope.has(&Scope::Profile) { + match get_profile( + db, + data.me.as_str(), + scope.has(&Scope::Email) + ).await { + Ok(profile) => profile, + Err(err) => { + tracing::error!("Error retrieving profile from database: {}", err); + + return StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + } else { + None + }; + + let access_token = match backend.create_token( + 
prepare_access_token(data.me.clone(), client_id.clone(), scope.clone()) + ).await { + Ok(token) => token, + Err(err) => { + tracing::error!("Error creating access token: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + + let old_refresh_token = refresh_token; + let refresh_token = match backend.create_refresh_token( + prepare_refresh_token(data.me.clone(), client_id, scope.clone()) + ).await { + Ok(token) => token, + Err(err) => { + tracing::error!("Error creating refresh token: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + if let Err(err) = backend.revoke_refresh_token(&me, &old_refresh_token).await { + tracing::error!("Error revoking refresh token: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + + GrantResponse::AccessToken { + me: data.me, + profile, + access_token, + token_type: kittybox_indieauth::TokenType::Bearer, + scope: Some(scope), + expires_in: Some(ACCESS_TOKEN_VALIDITY), + refresh_token: Some(refresh_token) + }.into_response() + } + } +} + +#[tracing::instrument(skip(backend, token_request))] +async fn introspection_endpoint_post<A: AuthBackend>( + Host(host): Host, + TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>, + Extension(backend): Extension<A>, + Form(token_request): Form<TokenIntrospectionRequest>, +) -> Response { + use serde_json::json; + + let me: url::Url = format!("https://{}/", host).parse().unwrap(); + + // Check authentication first + match backend.get_token(&me, auth_token.token()).await { + Ok(Some(token)) => if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) { + return (StatusCode::UNAUTHORIZED, Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + }))).into_response(); + }, + Ok(None) => return (StatusCode::UNAUTHORIZED, Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + }))).into_response(), + Err(err) => { + tracing::error!("Error retrieving token data for introspection: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + let response: TokenIntrospectionResponse = match backend.get_token(&me, &token_request.token).await { + Ok(maybe_data) => maybe_data.into(), + Err(err) => { + tracing::error!("Error retrieving token data: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + }; + + response.into_response() +} + +async fn revocation_endpoint_post<A: AuthBackend>( + Host(host): Host, + Extension(backend): Extension<A>, + Form(revocation): Form<TokenRevocationRequest>, +) -> impl IntoResponse { + let me: url::Url = format!("https://{}/", host).parse().unwrap(); + + if let Err(err) = tokio::try_join!( + backend.revoke_token(&me, &revocation.token), + backend.revoke_refresh_token(&me, &revocation.token) + ) { + tracing::error!("Error revoking token: {}", err); + + StatusCode::INTERNAL_SERVER_ERROR + } else { + StatusCode::OK + } +} + +async fn get_profile<D: Storage + 'static>( + db: D, + url: &str, + email: bool +) -> crate::database::Result<Option<Profile>> { + Ok(db.get_post(url).await?.map(|mut mf2| { + // Ruthlessly manually destructure the MF2 document to save memory + let name = match mf2["properties"]["name"][0].take() { + serde_json::Value::String(s) => Some(s), + _ => None + }; + let url = match mf2["properties"]["uid"][0].take() { + serde_json::Value::String(s) => s.parse().ok(), + _ => None + }; + let photo = match mf2["properties"]["photo"][0].take() { + serde_json::Value::String(s) => s.parse().ok(), + _ => 
None + }; + let email = if email { + match mf2["properties"]["email"][0].take() { + serde_json::Value::String(s) => Some(s), + _ => None + } + } else { + None + }; + + Profile { name, url, photo, email } + })) +} + +async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( + Host(host): Host, + TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>, + Extension(backend): Extension<A>, + Extension(db): Extension<D> +) -> Response { + use serde_json::json; + + let me: url::Url = format!("https://{}/", host).parse().unwrap(); + + match backend.get_token(&me, auth_token.token()).await { + Ok(Some(token)) => { + if token.expired() { + return (StatusCode::UNAUTHORIZED, Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + }))).into_response(); + } + if !token.scope.has(&Scope::Profile) { + return (StatusCode::UNAUTHORIZED, Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + }))).into_response(); + } + + match get_profile(db, me.as_str(), token.scope.has(&Scope::Email)).await { + Ok(Some(profile)) => profile.into_response(), + Ok(None) => Json(json!({ + // We do this because ResourceErrorKind is IndieAuth errors only + "error": "invalid_request" + })).into_response(), + Err(err) => { + tracing::error!("Error retrieving profile from database: {}", err); + + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + }, + Ok(None) => Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + })).into_response(), + Err(err) => { + tracing::error!("Error reading token: {}", err); + + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } +} + +#[must_use] +pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: reqwest::Client) -> axum::Router { + use axum::routing::{Router, get, post}; + + Router::new() + .nest( + "/.kittybox/indieauth", + Router::new() + .route("/metadata", + get(metadata)) + .route( + "/auth", + get(authorization_endpoint_get::<A, D>) + .post(authorization_endpoint_post::<A, D>)) + .route( + "/auth/confirm", + post(authorization_endpoint_confirm::<A>)) + .route( + "/token", + post(token_endpoint_post::<A, D>)) + .route( + "/token_status", + post(introspection_endpoint_post::<A>)) + .route( + "/revoke_token", + post(revocation_endpoint_post::<A>)) + .route( + "/userinfo", + get(userinfo_endpoint_get::<A, D>)) + + .route("/webauthn/pre_register", + get( + #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, D>, + #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) + ) + ) + .layer(tower_http::cors::CorsLayer::new() + .allow_methods([ + axum::http::Method::GET, + axum::http::Method::POST + ]) + .allow_origin(tower_http::cors::Any)) + .layer(Extension(backend)) + // I don't really like the fact that I have to use the whole database + // If I could, I would've designed a separate trait for getting profiles + // And made databases implement it, for example + .layer(Extension(db)) + .layer(Extension(http)) + ) + .route( + "/.well-known/oauth-authorization-server", + get(|| std::future::ready( + (StatusCode::FOUND, + [("Location", + "/.kittybox/indieauth/metadata")] + ).into_response() + )) + ) +} + +#[cfg(test)] +mod tests { + #[test] + fn test_deserialize_authorization_confirmation() { + use super::{Credential, AuthorizationConfirmation}; + + let confirmation = serde_json::from_str::<AuthorizationConfirmation>(r#"{ + "request":{ + "response_type": "code", + "client_id": "https://quill.p3k.io/", + 
"redirect_uri": "https://quill.p3k.io/", + "state": "10101010", + "code_challenge": "awooooooooooo", + "code_challenge_method": "S256", + "scope": "create+media" + }, + "authorization_method": "swordfish" + }"#).unwrap(); + + match confirmation.authorization_method { + Credential::Password(password) => assert_eq!(password.as_str(), "swordfish"), + #[allow(unreachable_patterns)] + other => panic!("Incorrect credential: {:?}", other) + } + assert_eq!(confirmation.request.state.as_ref(), "10101010"); + } +} diff --git a/src/indieauth/webauthn.rs b/src/indieauth/webauthn.rs new file mode 100644 index 0000000..ea3ad3d --- /dev/null +++ b/src/indieauth/webauthn.rs @@ -0,0 +1,140 @@ +use axum::{ + extract::{Json, Host}, + response::{IntoResponse, Response}, + http::StatusCode, Extension, TypedHeader, headers::{authorization::Bearer, Authorization} +}; +use axum_extra::extract::cookie::{CookieJar, Cookie}; + +use super::backend::AuthBackend; +use crate::database::Storage; + +pub(crate) const CHALLENGE_ID_COOKIE: &str = "kittybox_webauthn_challenge_id"; + +macro_rules! bail { + ($msg:literal, $err:expr) => { + { + ::tracing::error!($msg, $err); + return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } +} + +pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( + Host(host): Host, + Extension(db): Extension<D>, + Extension(auth): Extension<A>, + cookies: CookieJar +) -> Response { + let uid = format!("https://{}/", host.clone()); + let uid_url: url::Url = uid.parse().unwrap(); + // This will not find an h-card in onboarding! + let display_name = match db.get_post(&uid).await { + Ok(hcard) => match hcard { + Some(mut hcard) => { + match hcard["properties"]["uid"][0].take() { + serde_json::Value::String(name) => name, + _ => String::default() + } + }, + None => String::default() + }, + Err(err) => bail!("Error retrieving h-card: {}", err) + }; + + let webauthn = webauthn::WebauthnBuilder::new( + &host, + &uid_url + ) + .unwrap() + .rp_name("Kittybox") + .build() + .unwrap(); + + let (challenge, state) = match webauthn.start_passkey_registration( + // Note: using a nil uuid here is fine + // Because the user corresponds to a website anyway + // We do not track multiple users + webauthn::prelude::Uuid::nil(), + &uid, + &display_name, + Some(vec![]) + ) { + Ok((challenge, state)) => (challenge, state), + Err(err) => bail!("Error generating WebAuthn registration data: {}", err) + }; + + match auth.persist_registration_challenge(&uid_url, state).await { + Ok(challenge_id) => ( + cookies.add( + Cookie::build(CHALLENGE_ID_COOKIE, challenge_id) + .secure(true) + .finish() + ), + Json(challenge) + ).into_response(), + Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err) + } +} + +pub async fn webauthn_register<A: AuthBackend>( + Host(host): Host, + Json(credential): Json<webauthn::prelude::RegisterPublicKeyCredential>, + // TODO determine if we can use a cookie maybe? 
+ user_credential: Option<TypedHeader<Authorization<Bearer>>>, + Extension(auth): Extension<A> +) -> Response { + let uid = format!("https://{}/", host.clone()); + let uid_url: url::Url = uid.parse().unwrap(); + + let pubkeys = match auth.list_webauthn_pubkeys(&uid_url).await { + Ok(pubkeys) => pubkeys, + Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err) + }; + + if !pubkeys.is_empty() { + if let Some(TypedHeader(Authorization(token))) = user_credential { + // TODO check validity of the credential + } else { + return StatusCode::UNAUTHORIZED.into_response() + } + } + + return StatusCode::OK.into_response() +} + +pub(crate) async fn verify<A: AuthBackend>( + auth: &A, + website: &url::Url, + credential: webauthn::prelude::PublicKeyCredential, + challenge_id: &str +) -> std::io::Result<bool> { + let host = website.host_str().unwrap(); + + let webauthn = webauthn::WebauthnBuilder::new( + host, + website + ) + .unwrap() + .rp_name("Kittybox") + .build() + .unwrap(); + + match webauthn.finish_passkey_authentication( + &credential, + &auth.retrieve_authentication_challenge(&website, challenge_id).await? + ) { + Err(err) => { + tracing::error!("WebAuthn error: {}", err); + Ok(false) + }, + Ok(authentication_result) => { + let counter = authentication_result.counter(); + let cred_id = authentication_result.cred_id(); + + if authentication_result.needs_update() { + todo!() + } + Ok(true) + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c1bd965 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,93 @@ +#![forbid(unsafe_code)] +#![warn(clippy::todo)] + +/// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. +pub mod database; +pub mod frontend; +pub mod media; +pub mod micropub; +pub mod indieauth; +pub mod webmentions; + +pub mod companion { + use std::{collections::HashMap, sync::Arc}; + use axum::{ + extract::{Extension, Path}, + response::{IntoResponse, Response} + }; + + #[derive(Debug, Clone, Copy)] + struct Resource { + data: &'static [u8], + mime: &'static str + } + + impl IntoResponse for &Resource { + fn into_response(self) -> Response { + (axum::http::StatusCode::OK, + [("Content-Type", self.mime)], + self.data).into_response() + } + } + + // TODO replace with the "phf" crate someday + type ResourceTable = Arc<HashMap<&'static str, Resource>>; + + #[tracing::instrument] + async fn map_to_static( + Path(name): Path<String>, + Extension(resources): Extension<ResourceTable> + ) -> Response { + tracing::debug!("Searching for {} in the resource table...", name); + match resources.get(name.as_str()) { + Some(res) => res.into_response(), + None => { + #[cfg(debug_assertions)] tracing::error!("Not found"); + + (axum::http::StatusCode::NOT_FOUND, + [("Content-Type", "text/plain")], + "Not found. Sorry.".as_bytes()).into_response() + } + } + } + + #[must_use] + pub fn router() -> axum::Router { + let resources: ResourceTable = { + let mut map = HashMap::new(); + + macro_rules! register_resource { + ($map:ident, $prefix:expr, ($filename:literal, $mime:literal)) => {{ + $map.insert($filename, Resource { + data: include_bytes!(concat!($prefix, $filename)), + mime: $mime + }) + }}; + ($map:ident, $prefix:expr, ($filename:literal, $mime:literal), $( ($f:literal, $m:literal) ),+) => {{ + register_resource!($map, $prefix, ($filename, $mime)); + register_resource!($map, $prefix, $(($f, $m)),+); + }}; + } + + register_resource! 
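// Each (filename, mime) pair in the invocation below expands to one
// map.insert call that embeds the file contents at build time; the first
// entry, for illustration, becomes roughly:
//
//     map.insert("index.html", Resource {
//         data: include_bytes!(concat!(env!("OUT_DIR"), "/companion/", "index.html")),
//         mime: "text/html; charset=\"utf-8\"",
//     });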
{ + map, + concat!(env!("OUT_DIR"), "/", "companion", "/"), + ("index.html", "text/html; charset=\"utf-8\""), + ("main.js", "text/javascript"), + ("micropub_api.js", "text/javascript"), + ("indieauth.js", "text/javascript"), + ("base64.js", "text/javascript"), + ("style.css", "text/css") + }; + + Arc::new(map) + }; + + axum::Router::new() + .route( + "/:filename", + axum::routing::get(map_to_static) + .layer(Extension(resources)) + ) + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..6389489 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,489 @@ +use kittybox::database::FileStorage; +use std::{env, time::Duration, sync::Arc}; +use tracing::error; + +fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router { + match blobstore_uri.split_once(':').unwrap().0 { + "file" => { + let folder = std::path::PathBuf::from( + blobstore_uri.strip_prefix("file://").unwrap() + ); + let blobstore = kittybox::media::storage::file::FileStore::new(folder); + + kittybox::media::router::<_, _>(blobstore, auth_backend) + }, + other => unimplemented!("Unsupported backend: {other}") + } +} + +async fn compose_kittybox_with_auth<A>( + http: reqwest::Client, + auth_backend: A, + backend_uri: &str, + blobstore_uri: &str, + job_queue_uri: &str, + jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, + cancellation_token: &tokio_util::sync::CancellationToken +) -> (axum::Router, kittybox::webmentions::SupervisedTask) +where A: kittybox::indieauth::backend::AuthBackend +{ + match backend_uri.split_once(':').unwrap().0 { + "file" => { + let database = { + let folder = backend_uri.strip_prefix("file://").unwrap(); + let path = std::path::PathBuf::from(folder); + + match kittybox::database::FileStorage::new(path).await { + Ok(db) => db, + Err(err) => { + error!("Error creating database: {:?}", err); + std::process::exit(1); + } + } + }; + + // Technically, if we don't construct the micropub router, + // we could use some wrapper that makes the database + // read-only. + // + // This would allow to exclude all code to write to the + // database and separate reader and writer processes of + // Kittybox to improve security. 
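// One shape such a wrapper could take (hypothetical; nothing in this
// tree implements it yet) is a newtype that forwards reads and refuses
// writes:
//
//     #[derive(Clone)]
//     struct ReadOnlyStorage<D: kittybox::database::Storage>(D);
//     // get_post, get_channels, get_setting, etc. delegate to self.0;
//     // put_post, set_setting, etc. return a permission-denied error
//     // without ever touching the underlying database.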
+ let homepage: axum::routing::MethodRouter<_> = axum::routing::get( + kittybox::frontend::homepage::<FileStorage> + ) + .layer(axum::Extension(database.clone())); + let fallback = axum::routing::get( + kittybox::frontend::catchall::<FileStorage> + ) + .layer(axum::Extension(database.clone())); + + let micropub = kittybox::micropub::router( + database.clone(), + http.clone(), + auth_backend.clone(), + Arc::clone(jobset) + ); + let onboarding = kittybox::frontend::onboarding::router( + database.clone(), http.clone(), Arc::clone(jobset) + ); + + + let (webmention, task) = kittybox::webmentions::router( + kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), + database.clone(), + http.clone(), + cancellation_token.clone() + ); + + let router = axum::Router::new() + .route("/", homepage) + .fallback(fallback) + .route("/.kittybox/micropub", micropub) + .route("/.kittybox/onboarding", onboarding) + .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) + .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) + .merge(webmention) + .route( + "/.kittybox/health", + axum::routing::get(health_check::<kittybox::database::FileStorage>) + .layer(axum::Extension(database)) + ); + + (router, task) + }, + "redis" => unimplemented!("Redis backend is not supported."), + #[cfg(feature = "postgres")] + "postgres" => { + use kittybox::database::PostgresStorage; + + let database = { + match PostgresStorage::new(backend_uri).await { + Ok(db) => db, + Err(err) => { + error!("Error creating database: {:?}", err); + std::process::exit(1); + } + } + }; + + // Technically, if we don't construct the micropub router, + // we could use some wrapper that makes the database + // read-only. + // + // This would allow to exclude all code to write to the + // database and separate reader and writer processes of + // Kittybox to improve security. 
+ let homepage: axum::routing::MethodRouter<_> = axum::routing::get( + kittybox::frontend::homepage::<PostgresStorage> + ) + .layer(axum::Extension(database.clone())); + let fallback = axum::routing::get( + kittybox::frontend::catchall::<PostgresStorage> + ) + .layer(axum::Extension(database.clone())); + + let micropub = kittybox::micropub::router( + database.clone(), + http.clone(), + auth_backend.clone(), + Arc::clone(jobset) + ); + let onboarding = kittybox::frontend::onboarding::router( + database.clone(), http.clone(), Arc::clone(jobset) + ); + + let (webmention, task) = kittybox::webmentions::router( + kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), + database.clone(), + http.clone(), + cancellation_token.clone() + ); + + let router = axum::Router::new() + .route("/", homepage) + .fallback(fallback) + .route("/.kittybox/micropub", micropub) + .route("/.kittybox/onboarding", onboarding) + .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) + .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) + .merge(webmention) + .route( + "/.kittybox/health", + axum::routing::get(health_check::<kittybox::database::PostgresStorage>) + .layer(axum::Extension(database)) + ); + + (router, task) + }, + other => unimplemented!("Unsupported backend: {other}") + } +} + +async fn compose_kittybox( + backend_uri: &str, + blobstore_uri: &str, + authstore_uri: &str, + job_queue_uri: &str, + jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, + cancellation_token: &tokio_util::sync::CancellationToken +) -> (axum::Router, kittybox::webmentions::SupervisedTask) { + let http: reqwest::Client = { + #[allow(unused_mut)] + let mut builder = reqwest::Client::builder() + .user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); + if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { + // TODO: add a root certificate if there's an environment variable pointing at it + for path in certs.split(':') { + let metadata = match tokio::fs::metadata(path).await { + Ok(metadata) => metadata, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + tracing::error!("TLS root certificate {} not found, skipping...", path); + continue; + } + Err(err) => panic!("Error loading TLS certificates: {}", err) + }; + if metadata.is_dir() { + let mut dir = tokio::fs::read_dir(path).await.unwrap(); + while let Ok(Some(file)) = dir.next_entry().await { + let pem = tokio::fs::read(file.path()).await.unwrap(); + builder = builder.add_root_certificate( + reqwest::Certificate::from_pem(&pem).unwrap() + ); + } + } else { + let pem = tokio::fs::read(path).await.unwrap(); + builder = builder.add_root_certificate( + reqwest::Certificate::from_pem(&pem).unwrap() + ); + } + } + } + + builder.build().unwrap() + }; + + let (router, task) = match authstore_uri.split_once(':').unwrap().0 { + "file" => { + let auth_backend = { + let folder = authstore_uri + .strip_prefix("file://") + .unwrap(); + kittybox::indieauth::backend::fs::FileBackend::new(folder) + }; + + compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, job_queue_uri, jobset, cancellation_token).await + } + other => unimplemented!("Unsupported backend: {other}") + }; + + let router = router + .route( + "/.kittybox/static/:path", + axum::routing::get(kittybox::frontend::statics) + ) + .route("/.kittybox/coffee", teapot_route()) + .nest("/.kittybox/micropub/client", kittybox::companion::router()) + 
.layer(tower_http::trace::TraceLayer::new_for_http()) + .layer(tower_http::catch_panic::CatchPanicLayer::new()); + + (router, task) +} + +fn teapot_route() -> axum::routing::MethodRouter { + axum::routing::get(|| async { + use axum::http::{header, StatusCode}; + (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") + }) +} + +async fn health_check</*A, B, */D>( + //axum::Extension(auth): axum::Extension<A>, + //axum::Extension(blob): axum::Extension<B>, + axum::Extension(data): axum::Extension<D>, +) -> impl axum::response::IntoResponse +where + //A: kittybox::indieauth::backend::AuthBackend, + //B: kittybox::media::storage::MediaStore, + D: kittybox::database::Storage +{ + (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) +} + +#[tokio::main] +async fn main() { + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; + + let tracing_registry = Registry::default() + .with(EnvFilter::from_default_env()) + .with( + #[cfg(debug_assertions)] + tracing_tree::HierarchicalLayer::new(2) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_verbose_exit(true), + #[cfg(not(debug_assertions))] + tracing_subscriber::fmt::layer().json() + .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) + ); + // In debug builds, also log to JSON, but to file. + #[cfg(debug_assertions)] + let tracing_registry = tracing_registry.with( + tracing_subscriber::fmt::layer() + .json() + .with_writer({ + let instant = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + move || std::fs::OpenOptions::new() + .append(true) + .create(true) + .open( + format!( + "{}.log.json", + instant + .as_secs_f64() + .to_string() + .replace('.', "_") + ) + ).unwrap() + }) + ); + tracing_registry.init(); + + tracing::info!("Starting the kittybox server..."); + + let backend_uri: String = env::var("BACKEND_URI") + .unwrap_or_else(|_| { + error!("BACKEND_URI is not set, cannot find a database"); + std::process::exit(1); + }); + let blobstore_uri: String = env::var("BLOBSTORE_URI") + .unwrap_or_else(|_| { + error!("BLOBSTORE_URI is not set, can't find media store"); + std::process::exit(1); + }); + + let authstore_uri: String = env::var("AUTH_STORE_URI") + .unwrap_or_else(|_| { + error!("AUTH_STORE_URI is not set, can't find authentication store"); + std::process::exit(1); + }); + + let job_queue_uri: String = env::var("JOB_QUEUE_URI") + .unwrap_or_else(|_| { + error!("JOB_QUEUE_URI is not set, can't find job queue"); + std::process::exit(1); + }); + + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new())); + + let (router, webmentions_task) = compose_kittybox( + backend_uri.as_str(), + blobstore_uri.as_str(), + authstore_uri.as_str(), + job_queue_uri.as_str(), + &jobset, + &cancellation_token + ).await; + + let mut servers: Vec<hyper::server::Server<hyper::server::conn::AddrIncoming, _>> = vec![]; + + let build_hyper = |tcp: std::net::TcpListener| { + tracing::info!("Listening on {}", tcp.local_addr().unwrap()); + // Set the socket to non-blocking so tokio can poll it + // properly -- this is the async magic! 
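+ // (Sockets normally reach the listenfd loop below via systemd
+ // socket activation; during development a wrapper such as
+ // systemfd can stand in, e.g.
+ // systemfd --no-pid -s http::8080 -- cargo run
+ // Both are assumptions about the environment, not something
+ // this code checks.)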
+ tcp.set_nonblocking(true).unwrap(); + + hyper::server::Server::from_tcp(tcp).unwrap() + // Otherwise Chrome keeps connections open for too long + .tcp_keepalive(Some(Duration::from_secs(30 * 60))) + .serve(router.clone().into_make_service()) + }; + + let mut listenfd = listenfd::ListenFd::from_env(); + for i in 0..(listenfd.len()) { + match listenfd.take_tcp_listener(i) { + Ok(Some(tcp)) => servers.push(build_hyper(tcp)), + Ok(None) => {}, + Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) + } + } + // TODO this requires the `hyperlocal` crate + //#[rustfmt::skip] + /*#[cfg(unix)] { + let build_hyper_unix = |unix: std::os::unix::net::UnixListener| { + { + use std::os::linux::net::SocketAddrExt; + + let local_addr = unix.local_addr().unwrap(); + if let Some(pathname) = local_addr.as_pathname() { + tracing::info!("Listening on unix:{}", pathname.display()); + } else if let Some(name) = { + #[cfg(linux)] + local_addr.as_abstract_name(); + #[cfg(not(linux))] + None::<&[u8]> + } { + tracing::info!("Listening on unix:@{}", String::from_utf8_lossy(name)); + } else { + tracing::info!("Listening on unnamed unix socket"); + } + } + unix.set_nonblocking(true).unwrap(); + + hyper::server::Server::builder(unix) + .serve(router.clone().into_make_service()) + }; + for i in 0..(listenfd.len()) { + match listenfd.take_unix_listener(i) { + Ok(Some(unix)) => servers.push(build_hyper_unix(unix)), + Ok(None) => {}, + Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) + } + } + }*/ + if servers.is_empty() { + servers.push(build_hyper({ + let listen_addr = env::var("SERVE_AT") + .ok() + .unwrap_or_else(|| "[::]:8080".to_string()) + .parse::<std::net::SocketAddr>() + .unwrap_or_else(|e| { + error!("Cannot parse SERVE_AT: {}", e); + std::process::exit(1); + }); + + std::net::TcpListener::bind(listen_addr).unwrap() + })) + } + // Drop the remaining copy of the router + // to get rid of an extra reference to `jobset` + drop(router); + // Polling streams mutates them + let mut servers_futures = Box::pin(servers.into_iter() + .map( + #[cfg(not(tokio_unstable))] |server| tokio::task::spawn( + server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) + ), + #[cfg(tokio_unstable)] |server| { + tokio::task::Builder::new() + .name(format!("Kittybox HTTP acceptor: {}", server.local_addr()).as_str()) + .spawn( + server.with_graceful_shutdown( + cancellation_token.clone().cancelled_owned() + ) + ) + .unwrap() + } + ) + .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), hyper::Error>>>>() + ); + + #[cfg(not(unix))] + let shutdown_signal = tokio::signal::ctrl_c(); + #[cfg(unix)] + let shutdown_signal = { + use tokio::signal::unix::{signal, SignalKind}; + + async move { + let mut interrupt = signal(SignalKind::interrupt()) + .expect("Failed to set up SIGINT handler"); + let mut terminate = signal(SignalKind::terminate()) + .expect("Failed to setup SIGTERM handler"); + + tokio::select! { + _ = terminate.recv() => {}, + _ = interrupt.recv() => {}, + } + } + }; + use futures_util::stream::StreamExt; + + let exitcode: i32 = tokio::select! { + // Poll the servers stream for errors. 
+ // If any error out, shut down the entire operation + // + // We do this because there might not be a good way + // to recover from some errors without external help + Some(Err(e)) = servers_futures.next() => { + tracing::error!("Error in HTTP server: {}", e); + tracing::error!("Shutting down because of error."); + cancellation_token.cancel(); + + 1 + } + _ = cancellation_token.cancelled() => { + tracing::info!("Signal caught from watchdog."); + + 0 + } + _ = shutdown_signal => { + tracing::info!("Shutdown requested by signal."); + cancellation_token.cancel(); + + 0 + } + }; + + tracing::info!("Waiting for unfinished background tasks..."); + + let _ = tokio::join!( + webmentions_task, + Box::pin(futures_util::future::join_all( + servers_futures.iter_mut().collect::<Vec<_>>() + )), + ); + let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset) + .expect("Dangling jobset references present") + .into_inner(); + while (jobset.join_next().await).is_some() {} + tracing::info!("Shutdown complete, exiting."); + std::process::exit(exitcode); + +} diff --git a/src/media/mod.rs b/src/media/mod.rs new file mode 100644 index 0000000..71f875e --- /dev/null +++ b/src/media/mod.rs @@ -0,0 +1,141 @@ +use std::convert::TryFrom; + +use axum::{ + extract::{Extension, Host, multipart::Multipart, Path}, + response::{IntoResponse, Response}, + headers::{Header, HeaderValue, IfNoneMatch, HeaderMapExt}, + TypedHeader, +}; +use kittybox_util::error::{MicropubError, ErrorType}; +use kittybox_indieauth::Scope; +use crate::indieauth::{User, backend::AuthBackend}; + +pub mod storage; +use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; +pub use storage::file::FileStore; + +impl From<MediaStoreError> for MicropubError { + fn from(err: MediaStoreError) -> Self { + Self { + error: ErrorType::InternalServerError, + error_description: format!("{}", err) + } + } +} + +#[tracing::instrument(skip(blobstore))] +pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( + Extension(blobstore): Extension<S>, + user: User<A>, + mut upload: Multipart +) -> Response { + if !user.check_scope(&Scope::Media) { + return MicropubError { + error: ErrorType::NotAuthorized, + error_description: "Interacting with the media storage requires the \"media\" scope.".to_owned() + }.into_response(); + } + let host = user.me.host().unwrap().to_string() + &user.me.port().map(|i| format!(":{}", i)).unwrap_or_default(); + let field = match upload.next_field().await { + Ok(Some(field)) => field, + Ok(None) => { + return MicropubError { + error: ErrorType::InvalidRequest, + error_description: "Send multipart/form-data with one field named file".to_owned() + }.into_response(); + }, + Err(err) => { + return MicropubError { + error: ErrorType::InternalServerError, + error_description: format!("Error while parsing multipart/form-data: {}", err) + }.into_response(); + }, + }; + let metadata: Metadata = (&field).into(); + match blobstore.write_streaming(&host, metadata, field).await { + Ok(filename) => IntoResponse::into_response(( + axum::http::StatusCode::CREATED, + [ + ("Location", user.me.join( + &format!(".kittybox/media/uploads/{}", filename) + ).unwrap().as_str()) + ] + )), + Err(err) => MicropubError::from(err).into_response() + } +} + +#[tracing::instrument(skip(blobstore))] +pub(crate) async fn serve<S: MediaStore>( + Host(host): Host, + Path(path): Path<String>, + if_none_match: Option<TypedHeader<IfNoneMatch>>, + Extension(blobstore): Extension<S> +) -> Response { + use axum::http::StatusCode; + tracing::debug!("Searching 
for file..."); + match blobstore.read_streaming(&host, path.as_str()).await { + Ok((metadata, stream)) => { + tracing::debug!("Metadata: {:?}", metadata); + + let etag = if let Some(etag) = metadata.etag { + let etag = format!("\"{}\"", etag).parse::<axum::headers::ETag>().unwrap(); + + if let Some(TypedHeader(if_none_match)) = if_none_match { + tracing::debug!("If-None-Match: {:?}", if_none_match); + // If-None-Match is a negative precondition that + // returns 304 when it doesn't match because it + // only matches when file is different + if !if_none_match.precondition_passes(&etag) { + return StatusCode::NOT_MODIFIED.into_response() + } + } + + Some(etag) + } else { None }; + + let mut r = Response::builder(); + { + let headers = r.headers_mut().unwrap(); + headers.insert( + "Content-Type", + HeaderValue::from_str( + metadata.content_type + .as_deref() + .unwrap_or("application/octet-stream") + ).unwrap() + ); + if let Some(length) = metadata.length { + headers.insert( + "Content-Length", + HeaderValue::from_str(&length.to_string()).unwrap() + ); + } + if let Some(etag) = etag { + headers.typed_insert(etag); + } + } + r.body(axum::body::StreamBody::new(stream)) + .unwrap() + .into_response() + }, + Err(err) => match err.kind() { + ErrorKind::NotFound => { + IntoResponse::into_response(StatusCode::NOT_FOUND) + }, + _ => { + tracing::error!("{}", err); + IntoResponse::into_response(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } +} + +#[must_use] +pub fn router<S: MediaStore, A: AuthBackend>(blobstore: S, auth: A) -> axum::Router { + axum::Router::new() + .route("/", axum::routing::post(upload::<S, A>)) + .route("/uploads/*file", axum::routing::get(serve::<S>)) + .layer(axum::Extension(blobstore)) + .layer(axum::Extension(auth)) +} diff --git a/src/media/storage/file.rs b/src/media/storage/file.rs new file mode 100644 index 0000000..0aaaa3b --- /dev/null +++ b/src/media/storage/file.rs @@ -0,0 +1,434 @@ +use super::{Metadata, ErrorKind, MediaStore, MediaStoreError, Result}; +use async_trait::async_trait; +use std::{path::PathBuf, fmt::Debug}; +use tokio::fs::OpenOptions; +use tokio::io::{BufReader, BufWriter, AsyncWriteExt, AsyncSeekExt}; +use futures::{StreamExt, TryStreamExt}; +use std::ops::{Bound, RangeBounds, Neg}; +use std::pin::Pin; +use sha2::Digest; +use futures::FutureExt; +use tracing::{debug, error}; + +const BUF_CAPACITY: usize = 16 * 1024; + +#[derive(Clone)] +pub struct FileStore { + base: PathBuf, +} + +impl From<tokio::io::Error> for MediaStoreError { + fn from(source: tokio::io::Error) -> Self { + Self { + msg: format!("file I/O error: {}", source), + kind: match source.kind() { + std::io::ErrorKind::NotFound => ErrorKind::NotFound, + _ => ErrorKind::Backend + }, + source: Some(Box::new(source)), + } + } +} + +impl FileStore { + pub fn new<T: Into<PathBuf>>(base: T) -> Self { + Self { base: base.into() } + } + + async fn mktemp(&self) -> Result<(PathBuf, BufWriter<tokio::fs::File>)> { + kittybox_util::fs::mktemp(&self.base, "temp", 16) + .await + .map(|(name, file)| (name, BufWriter::new(file))) + .map_err(Into::into) + } +} + +#[async_trait] +impl MediaStore for FileStore { + + #[tracing::instrument(skip(self, content))] + async fn write_streaming<T>( + &self, + domain: &str, + mut metadata: Metadata, + mut content: T, + ) -> Result<String> + where + T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug + { + let (tempfilepath, mut tempfile) = self.mktemp().await?; + debug!("Temporary file 
opened for storing pending upload: {}", tempfilepath.display()); + let mut hasher = sha2::Sha256::new(); + let mut length: usize = 0; + + while let Some(chunk) = content.next().await { + let chunk = chunk.map_err(|err| MediaStoreError { + kind: ErrorKind::Backend, + source: Some(Box::new(err)), + msg: "Failed to read a data chunk".to_owned() + })?; + debug!("Read {} bytes from the stream", chunk.len()); + length += chunk.len(); + let (write_result, _hasher) = tokio::join!( + { + let chunk = chunk.clone(); + let tempfile = &mut tempfile; + async move { + tempfile.write_all(&*chunk).await + } + }, + { + let chunk = chunk.clone(); + tokio::task::spawn_blocking(move || { + hasher.update(&*chunk); + + hasher + }).map(|r| r.unwrap()) + } + ); + if let Err(err) = write_result { + error!("Error while writing pending upload: {}", err); + drop(tempfile); + // this is just cleanup, nothing fails if it fails + // though temporary files might take up space on the hard drive + // We'll clean them when maintenance time comes + #[allow(unused_must_use)] + { tokio::fs::remove_file(tempfilepath).await; } + return Err(err.into()); + } + hasher = _hasher; + } + // Manually flush the buffer and drop the handle to close the file + tempfile.flush().await?; + tempfile.into_inner().sync_all().await?; + + let hash = hasher.finalize(); + debug!("Pending upload hash: {}", hex::encode(&hash)); + let filename = format!( + "{}/{}/{}/{}/{}", + hex::encode([hash[0]]), + hex::encode([hash[1]]), + hex::encode([hash[2]]), + hex::encode([hash[3]]), + hex::encode(&hash[4..32]) + ); + let domain_str = domain.to_string(); + let filepath = self.base.join(domain_str.as_str()).join(&filename); + let metafilename = filename.clone() + ".json"; + let metapath = self.base.join(domain_str.as_str()).join(&metafilename); + let metatemppath = self.base.join(domain_str.as_str()).join(metafilename + ".tmp"); + metadata.length = std::num::NonZeroUsize::new(length); + metadata.etag = Some(hex::encode(&hash)); + debug!("File path: {}, metadata: {}", filepath.display(), metapath.display()); + { + let parent = filepath.parent().unwrap(); + tokio::fs::create_dir_all(parent).await?; + } + let mut meta = OpenOptions::new() + .create_new(true) + .write(true) + .open(&metatemppath) + .await?; + meta.write_all(&serde_json::to_vec(&metadata).unwrap()).await?; + tokio::fs::rename(tempfilepath, filepath).await?; + tokio::fs::rename(metatemppath, metapath).await?; + Ok(filename) + } + + #[tracing::instrument(skip(self))] + async fn read_streaming( + &self, + domain: &str, + filename: &str, + ) -> Result<(Metadata, Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>)> { + debug!("Domain: {}, filename: {}", domain, filename); + let path = self.base.join(domain).join(filename); + debug!("Path: {}", path.display()); + + let file = OpenOptions::new() + .read(true) + .open(path) + .await?; + let meta = self.metadata(domain, filename).await?; + + Ok((meta, Box::pin( + tokio_util::io::ReaderStream::new( + // TODO: determine if BufReader provides benefit here + // From the logs it looks like we're reading 4KiB at a time + // Buffering file contents seems to double download speed + // How to benchmark this? 
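+ // (One plausible approach: a criterion benchmark that drives
+ // read_streaming over a large fixture with and without this
+ // BufReader wrapper and compares throughput. Left as an idea,
+ // not included in this commit.)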
+ BufReader::with_capacity(BUF_CAPACITY, file) + ) + // Sprinkle some salt in form of protective log wrapping + .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())) + ))) + } + + #[tracing::instrument(skip(self))] + async fn metadata(&self, domain: &str, filename: &str) -> Result<Metadata> { + let metapath = self.base.join(domain).join(format!("{}.json", filename)); + debug!("Metadata path: {}", metapath.display()); + + let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?) + .map_err(|err| MediaStoreError { + kind: ErrorKind::Json, + msg: format!("{}", err), + source: Some(Box::new(err)) + })?; + + Ok(meta) + } + + #[tracing::instrument(skip(self))] + async fn stream_range( + &self, + domain: &str, + filename: &str, + range: (Bound<u64>, Bound<u64>) + ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> { + let path = self.base.join(format!("{}/{}", domain, filename)); + let metapath = self.base.join(format!("{}/{}.json", domain, filename)); + debug!("Path: {}, metadata: {}", path.display(), metapath.display()); + + let mut file = OpenOptions::new() + .read(true) + .open(path) + .await?; + + let start = match range { + (Bound::Included(bound), _) => { + debug!("Seeking {} bytes forward...", bound); + file.seek(std::io::SeekFrom::Start(bound)).await? + } + (Bound::Excluded(_), _) => unreachable!(), + (Bound::Unbounded, Bound::Included(bound)) => { + // Seek to the end minus the bounded bytes + debug!("Seeking {} bytes back from the end...", bound); + file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())).await? + }, + (Bound::Unbounded, Bound::Unbounded) => 0, + (_, Bound::Excluded(_)) => unreachable!() + }; + + let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity(BUF_CAPACITY, file))) + .map_ok({ + let mut bytes_read = 0usize; + let len = match range { + (_, Bound::Unbounded) => None, + (Bound::Unbounded, Bound::Included(bound)) => Some(bound), + (_, Bound::Included(bound)) => Some(bound + 1 - start), + (_, Bound::Excluded(_)) => unreachable!() + }; + move |chunk| { + debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); + bytes_read += chunk.len(); + if let Some(len) = len.map(|len| len.try_into().unwrap()) { + if bytes_read > len { + if bytes_read - len > chunk.len() { + return None + } + debug!("Truncating last {} bytes", bytes_read - len); + return Some(chunk.slice(..chunk.len() - (bytes_read - len))) + } + } + + Some(chunk) + } + }) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + // Will never panic, because the moment the stream yields + // a None, it is considered exhausted. + .map_ok(|x| x.unwrap()); + + return Ok(Box::pin(stream)) + } + + + async fn delete(&self, domain: &str, filename: &str) -> Result<()> { + let path = self.base.join(format!("{}/{}", domain, filename)); + + Ok(tokio::fs::remove_file(path).await?) 
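+ // NOTE: this removes only the blob itself; the "<filename>.json"
+ // metadata sidecar written by write_streaming stays behind and
+ // would need a separate cleanup pass.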
+ } +} + +#[cfg(test)] +mod tests { + use super::{Metadata, FileStore, MediaStore}; + use std::ops::Bound; + use tokio::io::AsyncReadExt; + + #[tokio::test] + #[tracing_test::traced_test] + async fn test_ranges() { + let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); + let store = FileStore::new(tempdir.path()); + + let file: &[u8] = include_bytes!("./file.rs"); + let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let metadata = Metadata { + filename: Some("file.rs".to_string()), + content_type: Some("text/plain".to_string()), + length: None, + etag: None, + }; + + // write through the interface + let filename = store.write_streaming( + "fireburn.ru", + metadata, stream + ).await.unwrap(); + + tracing::debug!("Writing complete."); + + // Ensure the file is there + let content = tokio::fs::read( + tempdir.path() + .join("fireburn.ru") + .join(&filename) + ).await.unwrap(); + assert_eq!(content, file); + + tracing::debug!("Reading range from the start..."); + // try to read range + let range = { + let stream = store.stream_range( + "fireburn.ru", &filename, + (Bound::Included(0), Bound::Included(299)) + ).await.unwrap(); + + let mut reader = tokio_util::io::StreamReader::new(stream); + + let mut buf = Vec::default(); + reader.read_to_end(&mut buf).await.unwrap(); + + buf + }; + + assert_eq!(range.len(), 300); + assert_eq!(range.as_slice(), &file[..=299]); + + tracing::debug!("Reading range from the middle..."); + + let range = { + let stream = store.stream_range( + "fireburn.ru", &filename, + (Bound::Included(150), Bound::Included(449)) + ).await.unwrap(); + + let mut reader = tokio_util::io::StreamReader::new(stream); + + let mut buf = Vec::default(); + reader.read_to_end(&mut buf).await.unwrap(); + + buf + }; + + assert_eq!(range.len(), 300); + assert_eq!(range.as_slice(), &file[150..=449]); + + tracing::debug!("Reading range from the end..."); + let range = { + let stream = store.stream_range( + "fireburn.ru", &filename, + // Note: the `headers` crate parses bounds in a + // non-standard way, where unbounded start actually + // means getting things from the end... 
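+ // (In other words, an HTTP suffix request such as
+ // "Range: bytes=-300" is expected to surface here as the
+ // (Unbounded, Included(300)) pair below.)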
+ (Bound::Unbounded, Bound::Included(300)) + ).await.unwrap(); + + let mut reader = tokio_util::io::StreamReader::new(stream); + + let mut buf = Vec::default(); + reader.read_to_end(&mut buf).await.unwrap(); + + buf + }; + + assert_eq!(range.len(), 300); + assert_eq!(range.as_slice(), &file[file.len()-300..file.len()]); + + tracing::debug!("Reading the whole file..."); + // try to read range + let range = { + let stream = store.stream_range( + "fireburn.ru", &("/".to_string() + &filename), + (Bound::Unbounded, Bound::Unbounded) + ).await.unwrap(); + + let mut reader = tokio_util::io::StreamReader::new(stream); + + let mut buf = Vec::default(); + reader.read_to_end(&mut buf).await.unwrap(); + + buf + }; + + assert_eq!(range.len(), file.len()); + assert_eq!(range.as_slice(), file); + } + + + #[tokio::test] + #[tracing_test::traced_test] + async fn test_streaming_read_write() { + let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); + let store = FileStore::new(tempdir.path()); + + let file: &[u8] = include_bytes!("./file.rs"); + let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let metadata = Metadata { + filename: Some("style.css".to_string()), + content_type: Some("text/css".to_string()), + length: None, + etag: None, + }; + + // write through the interface + let filename = store.write_streaming( + "fireburn.ru", + metadata, stream + ).await.unwrap(); + println!("{}, {}", filename, tempdir.path() + .join("fireburn.ru") + .join(&filename) + .display()); + let content = tokio::fs::read( + tempdir.path() + .join("fireburn.ru") + .join(&filename) + ).await.unwrap(); + assert_eq!(content, file); + + // check internal metadata format + let meta: Metadata = serde_json::from_slice(&tokio::fs::read( + tempdir.path() + .join("fireburn.ru") + .join(filename.clone() + ".json") + ).await.unwrap()).unwrap(); + assert_eq!(meta.content_type.as_deref(), Some("text/css")); + assert_eq!(meta.filename.as_deref(), Some("style.css")); + assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); + assert!(meta.etag.is_some()); + + // read back the data using the interface + let (metadata, read_back) = { + let (metadata, stream) = store.read_streaming( + "fireburn.ru", + &filename + ).await.unwrap(); + let mut reader = tokio_util::io::StreamReader::new(stream); + + let mut buf = Vec::default(); + reader.read_to_end(&mut buf).await.unwrap(); + + (metadata, buf) + }; + + assert_eq!(read_back, file); + assert_eq!(metadata.content_type.as_deref(), Some("text/css")); + assert_eq!(meta.filename.as_deref(), Some("style.css")); + assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); + assert!(meta.etag.is_some()); + + } +} diff --git a/src/media/storage/mod.rs b/src/media/storage/mod.rs new file mode 100644 index 0000000..020999c --- /dev/null +++ b/src/media/storage/mod.rs @@ -0,0 +1,177 @@ +use async_trait::async_trait; +use axum::extract::multipart::Field; +use tokio_stream::Stream; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::ops::Bound; +use std::pin::Pin; +use std::fmt::Debug; +use std::num::NonZeroUsize; + +pub mod file; + +#[derive(Debug, Deserialize, Serialize)] +pub struct Metadata { + /// Content type of the file. If None, the content-type is considered undefined. + pub content_type: Option<String>, + /// The original filename that was passed. + pub filename: Option<String>, + /// The recorded length of the file. + pub length: Option<NonZeroUsize>, + /// The e-tag of a file. 
Note: it must be a strong e-tag, for example, a hash. + pub etag: Option<String>, +} +impl From<&Field<'_>> for Metadata { + fn from(field: &Field<'_>) -> Self { + Self { + content_type: field.content_type() + .map(|i| i.to_owned()), + filename: field.file_name() + .map(|i| i.to_owned()), + length: None, + etag: None, + } + } +} + + +#[derive(Debug, Clone, Copy)] +pub enum ErrorKind { + Backend, + Permission, + Json, + NotFound, + Other, +} + +#[derive(Debug)] +pub struct MediaStoreError { + kind: ErrorKind, + source: Option<Box<dyn std::error::Error + Send + Sync>>, + msg: String, +} + +impl MediaStoreError { + pub fn kind(&self) -> ErrorKind { + self.kind + } +} + +impl std::error::Error for MediaStoreError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|i| i.as_ref() as &dyn std::error::Error) + } +} + +impl std::fmt::Display for MediaStoreError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {}", + match self.kind { + ErrorKind::Backend => "media storage backend error", + ErrorKind::Permission => "permission denied", + ErrorKind::Json => "failed to parse json", + ErrorKind::NotFound => "blob not found", + ErrorKind::Other => "unknown media storage error", + }, + self.msg + ) + } +} + +pub type Result<T> = std::result::Result<T, MediaStoreError>; + +#[async_trait] +pub trait MediaStore: 'static + Send + Sync + Clone { + async fn write_streaming<T>( + &self, + domain: &str, + metadata: Metadata, + content: T, + ) -> Result<String> + where + T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug; + + async fn read_streaming( + &self, + domain: &str, + filename: &str, + ) -> Result<(Metadata, Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>)>; + + async fn stream_range( + &self, + domain: &str, + filename: &str, + range: (Bound<u64>, Bound<u64>) + ) -> Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>> { + use futures::stream::TryStreamExt; + use tracing::debug; + let (metadata, mut stream) = self.read_streaming(domain, filename).await?; + let length = metadata.length.unwrap().get(); + + use Bound::*; + let (start, end): (usize, usize) = match range { + (Unbounded, Unbounded) => return Ok(stream), + (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), + (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), + (Included(start), Included(end)) => (start.try_into().unwrap(), end.try_into().unwrap()), + (_, _) => unreachable!() + }; + + stream = Box::pin( + stream.map_ok({ + let mut bytes_skipped = 0usize; + let mut bytes_read = 0usize; + + move |chunk| { + debug!("Skipped {}/{} bytes, chunk len {}", bytes_skipped, start, chunk.len()); + let chunk = if bytes_skipped < start { + let need_to_skip = start - bytes_skipped; + if chunk.len() < need_to_skip { + return None + } + debug!("Skipping {} bytes", need_to_skip); + bytes_skipped += need_to_skip; + + chunk.slice(need_to_skip..) 
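+ // (Suspected issue: the `return None` above does not advance
+ // bytes_skipped, so a start offset larger than the first chunk
+ // would keep retrying the full skip and yield nothing. In
+ // practice FileStore overrides stream_range with a seek-based
+ // implementation, so this fallback currently goes unused.)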
+ } else { + chunk + }; + + debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); + bytes_read += chunk.len(); + + if bytes_read > length { + if bytes_read - length > chunk.len() { + return None + } + debug!("Truncating last {} bytes", bytes_read - length); + return Some(chunk.slice(..chunk.len() - (bytes_read - length))) + } + + Some(chunk) + } + }) + .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + .map_ok(|x| x.unwrap()) + ); + + return Ok(stream); + } + + /// Read metadata for a file. + /// + /// The default implementation uses the `read_streaming` method + /// and drops the stream containing file content. + async fn metadata(&self, domain: &str, filename: &str) -> Result<Metadata> { + self.read_streaming(domain, filename) + .await + .map(|(meta, stream)| meta) + } + + async fn delete(&self, domain: &str, filename: &str) -> Result<()>; +} diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..e13fcb9 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,21 @@ +#![allow(unused_imports, dead_code)] +use async_trait::async_trait; +use lazy_static::lazy_static; +use prometheus::Encoder; +use std::time::{Duration, Instant}; + +// TODO: Vendor in the Metrics struct from warp_prometheus and rework the path matching algorithm + +pub fn metrics(path_includes: Vec<String>) -> warp::log::Log<impl Fn(warp::log::Info) + Clone> { + let metrics = warp_prometheus::Metrics::new(prometheus::default_registry(), &path_includes); + warp::log::custom(move |info| metrics.http_metrics(info)) +} + +pub fn gather() -> Vec<u8> { + let mut buffer: Vec<u8> = vec![]; + let encoder = prometheus::TextEncoder::new(); + let metric_families = prometheus::gather(); + encoder.encode(&metric_families, &mut buffer).unwrap(); + + buffer +} diff --git a/src/micropub/get.rs b/src/micropub/get.rs new file mode 100644 index 0000000..718714a --- /dev/null +++ b/src/micropub/get.rs @@ -0,0 +1,82 @@ +use crate::database::{MicropubChannel, Storage}; +use crate::indieauth::User; +use crate::ApplicationState; +use tide::prelude::{json, Deserialize}; +use tide::{Request, Response, Result}; + +#[derive(Deserialize)] +struct QueryOptions { + q: String, + url: Option<String>, +} + +pub async fn get_handler<Backend>(req: Request<ApplicationState<Backend>>) -> Result +where + Backend: Storage + Send + Sync, +{ + let user = req.ext::<User>().unwrap(); + let backend = &req.state().storage; + let media_endpoint = &req.state().media_endpoint; + let query = req.query::<QueryOptions>().unwrap_or(QueryOptions { + q: "".to_string(), + url: None, + }); + match &*query.q { + "config" => { + let channels: Vec<MicropubChannel>; + match backend.get_channels(user.me.as_str()).await { + Ok(chans) => channels = chans, + Err(err) => return Ok(err.into()) + } + Ok(Response::builder(200).body(json!({ + "q": ["source", "config", "channel"], + "channels": channels, + "media-endpoint": media_endpoint + })).build()) + }, + "channel" => { + let channels: Vec<MicropubChannel>; + match backend.get_channels(user.me.as_str()).await { + Ok(chans) => channels = chans, + Err(err) => return Ok(err.into()) + } + Ok(Response::builder(200).body(json!(channels)).build()) + } + "source" => { + if user.check_scope("create") || user.check_scope("update") || user.check_scope("delete") || user.check_scope("undelete") { + if let Some(url) = query.url { + match backend.get_post(&url).await { + Ok(post) => if let Some(post) = post { + 
Ok(Response::builder(200).body(post).build()) + } else { + Ok(Response::builder(404).build()) + }, + Err(err) => Ok(err.into()) + } + } else { + Ok(Response::builder(400).body(json!({ + "error": "invalid_request", + "error_description": "Please provide `url`." + })).build()) + } + } else { + Ok(Response::builder(401).body(json!({ + "error": "insufficient_scope", + "error_description": "You don't have the required scopes to proceed.", + "scope": "update" + })).build()) + } + }, + // TODO: ?q=food, ?q=geo, ?q=contacts + // Depends on indexing posts + // Errors + "" => Ok(Response::builder(400).body(json!({ + "error": "invalid_request", + "error_description": "No ?q= parameter specified. Try ?q=config maybe?" + })).build()), + _ => Ok(Response::builder(400).body(json!({ + "error": "invalid_request", + "error_description": "Unsupported ?q= query. Try ?q=config and see the q array for supported values." + })).build()) + } +} diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs new file mode 100644 index 0000000..02eee6e --- /dev/null +++ b/src/micropub/mod.rs @@ -0,0 +1,846 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::database::{MicropubChannel, Storage, StorageError}; +use crate::indieauth::backend::AuthBackend; +use crate::indieauth::User; +use crate::micropub::util::form_to_mf2_json; +use axum::extract::{BodyStream, Query, Host}; +use axum::headers::ContentType; +use axum::response::{IntoResponse, Response}; +use axum::TypedHeader; +use axum::{http::StatusCode, Extension}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tokio::sync::Mutex; +use tokio::task::JoinSet; +use tracing::{debug, error, info, warn}; +use kittybox_indieauth::{Scope, TokenData}; +use kittybox_util::{MicropubError, ErrorType}; + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +enum QueryType { + Source, + Config, + Channel, + SyndicateTo, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct MicropubQuery { + q: QueryType, + url: Option<String>, +} + +impl From<StorageError> for MicropubError { + fn from(err: StorageError) -> Self { + Self { + error: match err.kind() { + crate::database::ErrorKind::NotFound => ErrorType::NotFound, + _ => ErrorType::InternalServerError, + }, + error_description: format!("Backend error: {}", err), + } + } +} + +mod util; +pub(crate) use util::normalize_mf2; + +#[derive(Debug)] +struct FetchedPostContext { + url: url::Url, + mf2: serde_json::Value, + webmention: Option<url::Url>, +} + +fn populate_reply_context( + mf2: &serde_json::Value, + prop: &str, + ctxs: &[FetchedPostContext], +) -> Option<Vec<serde_json::Value>> { + mf2["properties"][prop].as_array().map(|array| { + array + .iter() + // TODO: This seems to be O(n^2) and I don't like it. 
+ // Switching `ctxs` to a hashmap might speed it up to O(n) + // The key would be the URL/UID + .map(|i| ctxs + .iter() + .find(|ctx| Some(ctx.url.as_str()) == i.as_str()) + .and_then(|ctx| ctx.mf2["items"].get(0)) + .unwrap_or(i)) + .cloned() + .collect::<Vec<serde_json::Value>>() + }) +} + +#[tracing::instrument(skip(db))] +async fn background_processing<D: 'static + Storage>( + db: D, + mf2: serde_json::Value, + http: reqwest::Client, +) -> () { + // TODO: Post-processing the post (aka second write pass) + // - [x] Download rich reply contexts + // - [ ] Syndicate the post if requested, add links to the syndicated copies + // - [ ] Send WebSub notifications to the hub (if we happen to have one) + // - [x] Send webmentions + + use futures_util::StreamExt; + + let uid: &str = mf2["properties"]["uid"][0].as_str().unwrap(); + + let context_props = ["in-reply-to", "like-of", "repost-of", "bookmark-of"]; + let mut context_urls: Vec<url::Url> = vec![]; + for prop in &context_props { + if let Some(array) = mf2["properties"][prop].as_array() { + context_urls.extend( + array + .iter() + .filter_map(|v| v.as_str()) + .filter_map(|v| v.parse::<url::Url>().ok()), + ); + } + } + // TODO parse HTML in e-content and add links found here + context_urls.sort_unstable_by_key(|u| u.to_string()); + context_urls.dedup(); + + // TODO: Make a stream to fetch all these posts and convert them to MF2 + let post_contexts = { + let http = &http; + tokio_stream::iter(context_urls.into_iter()) + .then(move |url: url::Url| http.get(url).send()) + .filter_map(|response| futures::future::ready(response.ok())) + .filter(|response| futures::future::ready(response.status() == 200)) + .filter_map(|response: reqwest::Response| async move { + // 1. We need to preserve the URL + // 2. We need to get the HTML for MF2 processing + // 3. We need to get the webmention endpoint address + // All of that can be done in one go. + let url = response.url().clone(); + // TODO parse link headers + let links = response + .headers() + .get_all(hyper::http::header::LINK) + .iter() + .cloned() + .collect::<Vec<hyper::http::HeaderValue>>(); + let html = response.text().await; + if html.is_err() { + return None; + } + let html = html.unwrap(); + let mf2 = microformats::from_html(&html, url.clone()).unwrap(); + // TODO use first Link: header if available + let webmention: Option<url::Url> = mf2 + .rels + .by_rels() + .get("webmention") + .and_then(|i| i.first().cloned()); + + dbg!(Some(FetchedPostContext { + url, + mf2: serde_json::to_value(mf2).unwrap(), + webmention + })) + }) + .collect::<Vec<FetchedPostContext>>() + .await + }; + + let mut update = MicropubUpdate { + replace: Some(Default::default()), + ..Default::default() + }; + for prop in context_props { + if let Some(json) = populate_reply_context(&mf2, prop, &post_contexts) { + update.replace.as_mut().unwrap().insert(prop.to_owned(), json); + } + } + if !update.replace.as_ref().unwrap().is_empty() { + if let Err(err) = db.update_post(uid, update).await { + error!("Failed to update post with rich reply contexts: {}", err); + } + } + + // At this point we can start syndicating the post. + // Currently we don't really support any syndication endpoints, but still! 
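+ // The commented-out block below sketches roughly what that
+ // loop could look like once a real syndication target exists: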
+ /*if let Some(syndicate_to) = mf2["properties"]["mp-syndicate-to"].as_array() { + let http = &http; + tokio_stream::iter(syndicate_to) + .filter_map(|i| futures::future::ready(i.as_str())) + .for_each_concurrent(3, |s: &str| async move { + #[allow(clippy::match_single_binding)] + match s { + _ => { + todo!("Syndicate to generic webmention-aware service {}", s); + } + // TODO special handling for non-webmention-aware services like the birdsite + } + }) + .await; + }*/ + + { + let http = &http; + tokio_stream::iter( + post_contexts + .into_iter() + .filter(|ctx| ctx.webmention.is_some()), + ) + .for_each_concurrent(2, |ctx| async move { + let mut map = std::collections::HashMap::new(); + map.insert("source", uid); + map.insert("target", ctx.url.as_str()); + + match http + .post(ctx.webmention.unwrap().clone()) + .form(&map) + .send() + .await + { + Ok(res) => { + if !res.status().is_success() { + warn!( + "Failed to send a webmention for {}: got HTTP {}", + ctx.url, + res.status() + ); + } else { + info!( + "Sent a webmention to {}, got HTTP {}", + ctx.url, + res.status() + ) + } + } + Err(err) => warn!("Failed to send a webmention for {}: {}", ctx.url, err), + } + }) + .await; + } +} + +// TODO actually save the post to the database and schedule post-processing +pub(crate) async fn _post<D: 'static + Storage>( + user: &TokenData, + uid: String, + mf2: serde_json::Value, + db: D, + http: reqwest::Client, + jobset: Arc<Mutex<JoinSet<()>>>, +) -> Result<Response, MicropubError> { + // Here, we have the following guarantees: + // - The MF2-JSON document is normalized (guaranteed by normalize_mf2) + // - The MF2-JSON document contains a UID + // - The MF2-JSON document's URL list contains its UID + // - The MF2-JSON document's "content" field contains an HTML blob, if present + // - The MF2-JSON document's publishing datetime is present + // - The MF2-JSON document's target channels are set + // - The MF2-JSON document's author is set + + // Security check! Do we have an OAuth2 scope to proceed? + if !user.check_scope(&Scope::Create) { + return Err(MicropubError { + error: ErrorType::InvalidScope, + error_description: "Not enough privileges - try acquiring the \"create\" scope." + .to_owned(), + }); + } + + // Security check #2! Are we posting to our own website? + if !uid.starts_with(user.me.as_str()) + || mf2["properties"]["channel"] + .as_array() + .unwrap_or(&vec![]) + .iter() + .any(|url| !url.as_str().unwrap().starts_with(user.me.as_str())) + { + return Err(MicropubError { + error: ErrorType::Forbidden, + error_description: "You're posting to a website that's not yours.".to_owned(), + }); + } + + // Security check #3! Are we overwriting an existing document? + if db.post_exists(&uid).await? 
{ + return Err(MicropubError { + error: ErrorType::AlreadyExists, + error_description: "UID clash was detected, operation aborted.".to_owned(), + }); + } + let user_domain = format!( + "{}{}", + user.me.host_str().unwrap(), + user.me.port() + .map(|port| format!(":{}", port)) + .unwrap_or_default() + ); + // Save the post + tracing::debug!("Saving post to database..."); + db.put_post(&mf2, &user_domain).await?; + + let mut channels = mf2["properties"]["channel"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap_or("")) + .filter(|i| !i.is_empty()); + + let default_channel = user + .me + .join(util::DEFAULT_CHANNEL_PATH) + .unwrap() + .to_string(); + let vcards_channel = user + .me + .join(util::CONTACTS_CHANNEL_PATH) + .unwrap() + .to_string(); + let food_channel = user.me.join(util::FOOD_CHANNEL_PATH).unwrap().to_string(); + let default_channels = vec![default_channel, vcards_channel, food_channel]; + + for chan in &mut channels { + debug!("Adding post {} to channel {}", uid, chan); + if db.post_exists(chan).await? { + db.add_to_feed(chan, &uid).await?; + } else if default_channels.iter().any(|i| chan == i) { + util::create_feed(&db, &uid, chan, user).await?; + } else { + warn!("Ignoring non-existent channel: {}", chan); + } + } + + let reply = + IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); + + #[cfg(not(tokio_unstable))] + jobset.lock().await.spawn(background_processing(db, mf2, http)); + #[cfg(tokio_unstable)] + jobset.lock().await.build_task() + .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) + .spawn(background_processing(db, mf2, http)); + + Ok(reply) +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "snake_case")] +enum ActionType { + Delete, + Update, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum MicropubPropertyDeletion { + Properties(Vec<String>), + Values(HashMap<String, Vec<serde_json::Value>>) +} +#[derive(Serialize, Deserialize)] +struct MicropubFormAction { + action: ActionType, + url: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct MicropubAction { + action: ActionType, + url: String, + #[serde(flatten)] + #[serde(skip_serializing_if = "Option::is_none")] + update: Option<MicropubUpdate> +} + +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct MicropubUpdate { + #[serde(skip_serializing_if = "Option::is_none")] + pub replace: Option<HashMap<String, Vec<serde_json::Value>>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub add: Option<HashMap<String, Vec<serde_json::Value>>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub delete: Option<MicropubPropertyDeletion>, + +} + +impl From<MicropubFormAction> for MicropubAction { + fn from(a: MicropubFormAction) -> Self { + debug_assert!(matches!(a.action, ActionType::Delete)); + Self { + action: a.action, + url: a.url, + update: None + } + } +} + +#[tracing::instrument(skip(db))] +async fn post_action<D: Storage, A: AuthBackend>( + action: MicropubAction, + db: D, + user: User<A>, +) -> Result<(), MicropubError> { + let uri = if let Ok(uri) = action.url.parse::<hyper::Uri>() { + uri + } else { + return Err(MicropubError { + error: ErrorType::InvalidRequest, + error_description: "Your URL doesn't parse properly.".to_owned(), + }); + }; + + if uri.authority().unwrap() + != user + .me + .as_str() + .parse::<hyper::Uri>() + .unwrap() + .authority() + .unwrap() + { + return Err(MicropubError { + error: ErrorType::Forbidden, + error_description: "Don't 
tamper with others' posts!".to_owned(), + }); + } + + match action.action { + ActionType::Delete => { + if !user.check_scope(&Scope::Delete) { + return Err(MicropubError { + error: ErrorType::InvalidScope, + error_description: "You need a \"delete\" scope for this.".to_owned(), + }); + } + + db.delete_post(&action.url).await? + } + ActionType::Update => { + if !user.check_scope(&Scope::Update) { + return Err(MicropubError { + error: ErrorType::InvalidScope, + error_description: "You need an \"update\" scope for this.".to_owned(), + }); + } + + db.update_post( + &action.url, + action.update.ok_or(MicropubError { + error: ErrorType::InvalidRequest, + error_description: "Update request is not set.".to_owned(), + })? + ) + .await? + } + } + + Ok(()) +} + +enum PostBody { + Action(MicropubAction), + MF2(serde_json::Value), +} + +#[tracing::instrument] +async fn dispatch_body( + mut body: BodyStream, + content_type: ContentType, +) -> Result<PostBody, MicropubError> { + let body: Vec<u8> = { + debug!("Buffering body..."); + use tokio_stream::StreamExt; + let mut buf = Vec::default(); + + while let Some(chunk) = body.next().await { + buf.extend_from_slice(&chunk.unwrap()) + } + + buf + }; + + debug!("Content-Type: {:?}", content_type); + if content_type == ContentType::json() { + if let Ok(action) = serde_json::from_slice::<MicropubAction>(&body) { + Ok(PostBody::Action(action)) + } else if let Ok(body) = serde_json::from_slice::<serde_json::Value>(&body) { + // quick sanity check + if !body.is_object() || !body["type"].is_array() { + return Err(MicropubError { + error: ErrorType::InvalidRequest, + error_description: "Invalid MF2-JSON detected: `.` should be an object, `.type` should be an array of MF2 types".to_owned() + }); + } + + Ok(PostBody::MF2(body)) + } else { + Err(MicropubError { + error: ErrorType::InvalidRequest, + error_description: "Invalid JSON object passed.".to_owned(), + }) + } + } else if content_type == ContentType::form_url_encoded() { + if let Ok(body) = serde_urlencoded::from_bytes::<MicropubFormAction>(&body) { + Ok(PostBody::Action(body.into())) + } else if let Ok(body) = serde_urlencoded::from_bytes::<Vec<(String, String)>>(&body) { + Ok(PostBody::MF2(form_to_mf2_json(body))) + } else { + Err(MicropubError { + error: ErrorType::InvalidRequest, + error_description: "Invalid form-encoded data. Try h=entry&content=Hello!" + .to_owned(), + }) + } + } else { + Err(MicropubError::new( + ErrorType::UnsupportedMediaType, + "This Content-Type is not recognized. 
Try application/json instead?", + )) + } +} + +#[tracing::instrument(skip(db, http))] +pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( + Extension(db): Extension<D>, + Extension(http): Extension<reqwest::Client>, + Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, + TypedHeader(content_type): TypedHeader<ContentType>, + user: User<A>, + body: BodyStream, +) -> axum::response::Response { + match dispatch_body(body, content_type).await { + Ok(PostBody::Action(action)) => match post_action(action, db, user).await { + Ok(()) => Response::default(), + Err(err) => err.into_response(), + }, + Ok(PostBody::MF2(mf2)) => { + let (uid, mf2) = normalize_mf2(mf2, &user); + match _post(&user, uid, mf2, db, http, jobset).await { + Ok(response) => response, + Err(err) => err.into_response(), + } + } + Err(err) => err.into_response(), + } +} + +#[tracing::instrument(skip(db))] +pub(crate) async fn query<D: Storage, A: AuthBackend>( + Extension(db): Extension<D>, + query: Option<Query<MicropubQuery>>, + Host(host): Host, + user: User<A>, +) -> axum::response::Response { + // We handle the invalid query case manually to return a + // MicropubError instead of HTTP 422 + let query = if let Some(Query(query)) = query { + query + } else { + return MicropubError::new( + ErrorType::InvalidRequest, + "Invalid query provided. Try ?q=config to see what you can do." + ).into_response(); + }; + + if axum::http::Uri::try_from(user.me.as_str()) + .unwrap() + .authority() + .unwrap() + != &host + { + return MicropubError::new( + ErrorType::NotAuthorized, + "This website doesn't belong to you.", + ) + .into_response(); + } + + let user_domain = format!( + "{}{}", + user.me.host_str().unwrap(), + user.me.port() + .map(|port| format!(":{}", port)) + .unwrap_or_default() + ); + match query.q { + QueryType::Config => { + let channels: Vec<MicropubChannel> = match db.get_channels(user.me.as_str()).await { + Ok(chans) => chans, + Err(err) => { + return MicropubError::new( + ErrorType::InternalServerError, + &format!("Error fetching channels: {}", err), + ) + .into_response() + } + }; + + axum::response::Json(json!({ + "q": [ + QueryType::Source, + QueryType::Config, + QueryType::Channel, + QueryType::SyndicateTo + ], + "channels": channels, + "_kittybox_authority": user.me.as_str(), + "syndicate-to": [], + "media-endpoint": user.me.join("/.kittybox/media").unwrap().as_str() + })) + .into_response() + } + QueryType::Source => { + match query.url { + Some(url) => { + match db.get_post(&url).await { + Ok(some) => match some { + Some(post) => axum::response::Json(&post).into_response(), + None => MicropubError::new( + ErrorType::NotFound, + "The specified MF2 object was not found in database.", + ) + .into_response(), + }, + Err(err) => MicropubError::new( + ErrorType::InternalServerError, + &format!("Backend error: {}", err), + ) + .into_response(), + } + } + None => { + // Here, one should probably attempt to query at least the main feed and collect posts + // Using a pre-made query function can't be done because it does unneeded filtering + // Don't implement for now, this is optional + MicropubError::new( + ErrorType::InvalidRequest, + "Querying for post list is not implemented yet.", + ) + .into_response() + } + } + } + QueryType::Channel => match db.get_channels(&user_domain).await { + Ok(chans) => axum::response::Json(json!({ "channels": chans })).into_response(), + Err(err) => MicropubError::new( + ErrorType::InternalServerError, + &format!("Error fetching channels: {}", err), + ) + .into_response(), + 
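+ // (Note: the ?q=config arm above fetches channels keyed by
+ // user.me.as_str(), while this arm uses user_domain; one of
+ // the two is presumably the canonical key.)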
}, + QueryType::SyndicateTo => { + axum::response::Json(json!({ "syndicate-to": [] })).into_response() + } + } +} + +#[must_use] +pub fn router<S, A>( + storage: S, + http: reqwest::Client, + auth: A, + jobset: Arc<Mutex<JoinSet<()>>> +) -> axum::routing::MethodRouter +where + S: Storage + 'static, + A: AuthBackend +{ + axum::routing::get(query::<S, A>) + .post(post::<S, A>) + .layer::<_, _, std::convert::Infallible>(tower_http::cors::CorsLayer::new() + .allow_methods([ + axum::http::Method::GET, + axum::http::Method::POST, + ]) + .allow_origin(tower_http::cors::Any)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) +} + +#[cfg(test)] +#[allow(dead_code)] +impl MicropubQuery { + fn config() -> Self { + Self { + q: QueryType::Config, + url: None, + } + } + + fn source(url: &str) -> Self { + Self { + q: QueryType::Source, + url: Some(url.to_owned()), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{database::Storage, micropub::MicropubError}; + use hyper::body::HttpBody; + use serde_json::json; + use tokio::sync::Mutex; + + use super::FetchedPostContext; + use kittybox_indieauth::{Scopes, Scope, TokenData}; + use axum::extract::Host; + + #[test] + fn test_populate_reply_context() { + let already_expanded_reply_ctx = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Hello world!"] + } + }); + let mf2 = json!({ + "type": ["h-entry"], + "properties": { + "like-of": [ + "https://fireburn.ru/posts/example", + already_expanded_reply_ctx, + "https://fireburn.ru/posts/non-existent" + ] + } + }); + let test_ctx = json!({ + "type": ["h-entry"], + "properties": { + "content": ["This is a post which was reacted to."] + } + }); + let reply_contexts = vec![FetchedPostContext { + url: "https://fireburn.ru/posts/example".parse().unwrap(), + mf2: json!({ "items": [test_ctx] }), + webmention: None, + }]; + + let like_of = super::populate_reply_context(&mf2, "like-of", &reply_contexts).unwrap(); + + assert_eq!(like_of[0], test_ctx); + assert_eq!(like_of[1], already_expanded_reply_ctx); + assert_eq!(like_of[2], "https://fireburn.ru/posts/non-existent"); + } + + #[tokio::test] + async fn test_post_reject_scope() { + let db = crate::database::MemoryStorage::new(); + + let post = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Hello world!"] + } + }); + let user = TokenData { + me: "https://localhost:8080/".parse().unwrap(), + client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), + scope: Scopes::new(vec![Scope::Profile]), + iat: None, exp: None + }; + let (uid, mf2) = super::normalize_mf2(post, &user); + + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) + .await + .unwrap_err(); + + assert_eq!(err.error, super::ErrorType::InvalidScope); + + let hashmap = db.mapping.read().await; + assert!(hashmap.is_empty()); + } + + #[tokio::test] + async fn test_post_reject_different_user() { + let db = crate::database::MemoryStorage::new(); + + let post = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Hello world!"], + "uid": ["https://fireburn.ru/posts/hello"], + "url": ["https://fireburn.ru/posts/hello"] + } + }); + let user = TokenData { + me: "https://aaronparecki.com/".parse().unwrap(), + client_id: 
"https://kittybox.fireburn.ru/".parse().unwrap(), + scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), + iat: None, exp: None + }; + let (uid, mf2) = super::normalize_mf2(post, &user); + + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) + .await + .unwrap_err(); + + assert_eq!(err.error, super::ErrorType::Forbidden); + + let hashmap = db.mapping.read().await; + assert!(hashmap.is_empty()); + } + + #[tokio::test] + async fn test_post_mf2() { + let db = crate::database::MemoryStorage::new(); + + let post = json!({ + "type": ["h-entry"], + "properties": { + "content": ["Hello world!"] + } + }); + let user = TokenData { + me: "https://localhost:8080/".parse().unwrap(), + client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), + scope: Scopes::new(vec![Scope::Profile, Scope::Create]), + iat: None, exp: None + }; + let (uid, mf2) = super::normalize_mf2(post, &user); + + let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) + .await + .unwrap(); + + assert!(res.headers().contains_key("Location")); + let location = res.headers().get("Location").unwrap(); + assert!(db.post_exists(location.to_str().unwrap()).await.unwrap()); + assert!(db + .post_exists("https://localhost:8080/feeds/main") + .await + .unwrap()); + } + + #[tokio::test] + async fn test_query_foreign_url() { + let mut res = super::query( + axum::Extension(crate::database::MemoryStorage::new()), + Some(axum::extract::Query(super::MicropubQuery::source( + "https://aaronparecki.com/feeds/main", + ))), + Host("aaronparecki.com".to_owned()), + crate::indieauth::User::<crate::indieauth::backend::fs::FileBackend>( + TokenData { + me: "https://fireburn.ru/".parse().unwrap(), + client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), + scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), + iat: None, exp: None + }, std::marker::PhantomData + ) + ) + .await; + + assert_eq!(res.status(), 401); + let body = res.body_mut().data().await.unwrap().unwrap(); + let json: MicropubError = serde_json::from_slice(&body as &[u8]).unwrap(); + assert_eq!(json.error, super::ErrorType::NotAuthorized); + } +} diff --git a/src/micropub/util.rs b/src/micropub/util.rs new file mode 100644 index 0000000..940d7c3 --- /dev/null +++ b/src/micropub/util.rs @@ -0,0 +1,444 @@ +use crate::database::Storage; +use kittybox_indieauth::TokenData; +use chrono::prelude::*; +use core::iter::Iterator; +use newbase60::num_to_sxg; +use serde_json::json; +use std::convert::TryInto; + +pub(crate) const DEFAULT_CHANNEL_PATH: &str = "/feeds/main"; +const DEFAULT_CHANNEL_NAME: &str = "Main feed"; +pub(crate) const CONTACTS_CHANNEL_PATH: &str = "/feeds/vcards"; +const CONTACTS_CHANNEL_NAME: &str = "My address book"; +pub(crate) const FOOD_CHANNEL_PATH: &str = "/feeds/food"; +const FOOD_CHANNEL_NAME: &str = "My recipe book"; + +fn get_folder_from_type(post_type: &str) -> String { + (match post_type { + "h-feed" => "feeds/", + "h-card" => "vcards/", + "h-event" => "events/", + "h-food" => "food/", + _ => "posts/", + }) + .to_string() +} + +/// Reset the datetime to a proper datetime. +/// Do not attempt to recover the information. +/// Do not pass GO. Do not collect $200. 
+fn reset_dt(post: &mut serde_json::Value) -> DateTime<FixedOffset> { + let curtime: DateTime<Local> = Local::now(); + post["properties"]["published"] = json!([curtime.to_rfc3339()]); + chrono::DateTime::from(curtime) +} + +pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> (String, serde_json::Value) { + // Normalize the MF2 object here. + let me = &user.me; + let folder = get_folder_from_type(body["type"][0].as_str().unwrap()); + let published: DateTime<FixedOffset> = + if let Some(dt) = body["properties"]["published"][0].as_str() { + // Check if the datetime is parsable. + match DateTime::parse_from_rfc3339(dt) { + Ok(dt) => dt, + Err(_) => reset_dt(&mut body), + } + } else { + // Set the datetime. + // Note: this code block duplicates functionality with the above failsafe. + // Consider refactoring it to a helper function? + reset_dt(&mut body) + }; + match body["properties"]["uid"][0].as_str() { + None => { + let uid = serde_json::Value::String( + me.join( + &(folder.clone() + + &num_to_sxg(published.timestamp_millis().try_into().unwrap())), + ) + .unwrap() + .to_string(), + ); + body["properties"]["uid"] = serde_json::Value::Array(vec![uid.clone()]); + match body["properties"]["url"].as_array_mut() { + Some(array) => array.push(uid), + None => body["properties"]["url"] = body["properties"]["uid"].clone(), + } + } + Some(uid_str) => { + let uid = uid_str.to_string(); + match body["properties"]["url"].as_array_mut() { + Some(array) => { + if !array.iter().any(|i| i.as_str().unwrap_or("") == uid) { + array.push(serde_json::Value::String(uid)) + } + } + None => body["properties"]["url"] = body["properties"]["uid"].clone(), + } + } + } + if let Some(slugs) = body["properties"]["mp-slug"].as_array() { + let new_urls = slugs + .iter() + .map(|i| i.as_str().unwrap_or("")) + .filter(|i| i != &"") + .map(|i| me.join(&((&folder).clone() + i)).unwrap().to_string()) + .collect::<Vec<String>>(); + let urls = body["properties"]["url"].as_array_mut().unwrap(); + new_urls.iter().for_each(|i| urls.push(json!(i))); + } + let props = body["properties"].as_object_mut().unwrap(); + props.remove("mp-slug"); + + if body["properties"]["content"][0].is_string() { + // Convert the content to HTML using the `markdown` crate + body["properties"]["content"] = json!([{ + "html": markdown::to_html(body["properties"]["content"][0].as_str().unwrap()), + "value": body["properties"]["content"][0] + }]) + } + // TODO: apply this normalization to editing too + if body["properties"]["mp-channel"].is_array() { + let mut additional_channels = body["properties"]["mp-channel"].as_array().unwrap().clone(); + if let Some(array) = body["properties"]["channel"].as_array_mut() { + array.append(&mut additional_channels); + } else { + body["properties"]["channel"] = json!(additional_channels) + } + body["properties"] + .as_object_mut() + .unwrap() + .remove("mp-channel"); + } else if body["properties"]["mp-channel"].is_string() { + let chan = body["properties"]["mp-channel"] + .as_str() + .unwrap() + .to_owned(); + if let Some(array) = body["properties"]["channel"].as_array_mut() { + array.push(json!(chan)) + } else { + body["properties"]["channel"] = json!([chan]); + } + body["properties"] + .as_object_mut() + .unwrap() + .remove("mp-channel"); + } + if body["properties"]["channel"][0].as_str().is_none() { + match body["type"][0].as_str() { + Some("h-entry") => { + // Set the channel to the main channel... 
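+                // (by resolving DEFAULT_CHANNEL_PATH against the user's `me` URL)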
+ // TODO find like posts and move them to separate private channel + let default_channel = me.join(DEFAULT_CHANNEL_PATH).unwrap().to_string(); + + body["properties"]["channel"] = json!([default_channel]); + } + Some("h-card") => { + let default_channel = me.join(CONTACTS_CHANNEL_PATH).unwrap().to_string(); + + body["properties"]["channel"] = json!([default_channel]); + } + Some("h-food") => { + let default_channel = me.join(FOOD_CHANNEL_PATH).unwrap().to_string(); + + body["properties"]["channel"] = json!([default_channel]); + } + // TODO h-event + /*"h-event" => { + let default_channel + },*/ + _ => { + body["properties"]["channel"] = json!([]); + } + } + } + body["properties"]["posted-with"] = json!([user.client_id]); + if body["properties"]["author"][0].as_str().is_none() { + body["properties"]["author"] = json!([me.as_str()]) + } + // TODO: maybe highlight #hashtags? + // Find other processing to do and insert it here + return ( + body["properties"]["uid"][0].as_str().unwrap().to_string(), + body, + ); +} + +pub(crate) fn form_to_mf2_json(form: Vec<(String, String)>) -> serde_json::Value { + let mut mf2 = json!({"type": [], "properties": {}}); + for (k, v) in form { + if k == "h" { + mf2["type"] + .as_array_mut() + .unwrap() + .push(json!("h-".to_string() + &v)); + } else if k != "access_token" { + let key = k.strip_suffix("[]").unwrap_or(&k); + match mf2["properties"][key].as_array_mut() { + Some(prop) => prop.push(json!(v)), + None => mf2["properties"][key] = json!([v]), + } + } + } + if mf2["type"].as_array().unwrap().is_empty() { + mf2["type"].as_array_mut().unwrap().push(json!("h-entry")); + } + mf2 +} + +pub(crate) async fn create_feed( + storage: &impl Storage, + uid: &str, + channel: &str, + user: &TokenData, +) -> crate::database::Result<()> { + let path = url::Url::parse(channel).unwrap().path().to_string(); + + let name = match path.as_str() { + DEFAULT_CHANNEL_PATH => DEFAULT_CHANNEL_NAME, + CONTACTS_CHANNEL_PATH => CONTACTS_CHANNEL_NAME, + FOOD_CHANNEL_PATH => FOOD_CHANNEL_NAME, + _ => panic!("Tried to create an unknown default feed!"), + }; + + let (_, feed) = normalize_mf2( + json!({ + "type": ["h-feed"], + "properties": { + "name": [name], + "uid": [channel] + }, + }), + user, + ); + storage.put_post(&feed, user.me.as_str()).await?; + storage.add_to_feed(channel, uid).await +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn token_data() -> TokenData { + TokenData { + me: "https://fireburn.ru/".parse().unwrap(), + client_id: "https://quill.p3k.io/".parse().unwrap(), + scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), + exp: Some(u64::MAX), + iat: Some(0) + } + } + + #[test] + fn test_form_to_mf2() { + assert_eq!( + super::form_to_mf2_json( + serde_urlencoded::from_str("h=entry&content=something%20interesting").unwrap() + ), + json!({ + "type": ["h-entry"], + "properties": { + "content": ["something interesting"] + } + }) + ) + } + + #[test] + fn test_no_replace_uid() { + let mf2 = json!({ + "type": ["h-card"], + "properties": { + "uid": ["https://fireburn.ru/"], + "name": ["Vika Nezrimaya"], + "note": ["A crazy programmer girl who wants some hugs"] + } + }); + + let (uid, normalized) = normalize_mf2( + mf2.clone(), + &token_data(), + ); + assert_eq!( + normalized["properties"]["uid"][0], mf2["properties"]["uid"][0], + "UID was replaced" + ); + assert_eq!( + normalized["properties"]["uid"][0], uid, + "Returned post location doesn't match UID" + ); + } + + #[test] + fn test_mp_channel() { + let mf2 = json!({ + 
"type": ["h-entry"], + "properties": { + "uid": ["https://fireburn.ru/posts/test"], + "content": [{"html": "<p>Hello world!</p>"}], + "mp-channel": ["https://fireburn.ru/feeds/test"] + } + }); + + let (_, normalized) = normalize_mf2( + mf2.clone(), + &token_data(), + ); + + assert_eq!( + normalized["properties"]["channel"], + mf2["properties"]["mp-channel"] + ); + } + + #[test] + fn test_mp_channel_as_string() { + let mf2 = json!({ + "type": ["h-entry"], + "properties": { + "uid": ["https://fireburn.ru/posts/test"], + "content": [{"html": "<p>Hello world!</p>"}], + "mp-channel": "https://fireburn.ru/feeds/test" + } + }); + + let (_, normalized) = normalize_mf2( + mf2.clone(), + &token_data(), + ); + + assert_eq!( + normalized["properties"]["channel"][0], + mf2["properties"]["mp-channel"] + ); + } + + #[test] + fn test_normalize_mf2() { + let mf2 = json!({ + "type": ["h-entry"], + "properties": { + "content": ["This is content!"] + } + }); + + let (uid, post) = normalize_mf2( + mf2, + &token_data(), + ); + assert_eq!( + post["properties"]["published"] + .as_array() + .expect("post['published'] is undefined") + .len(), + 1, + "Post doesn't have a published time" + ); + DateTime::parse_from_rfc3339(post["properties"]["published"][0].as_str().unwrap()) + .expect("Couldn't parse date from rfc3339"); + assert!( + !post["properties"]["url"] + .as_array() + .expect("post['url'] is undefined") + .is_empty(), + "Post doesn't have any URLs" + ); + assert_eq!( + post["properties"]["uid"] + .as_array() + .expect("post['uid'] is undefined") + .len(), + 1, + "Post doesn't have a single UID" + ); + assert_eq!( + post["properties"]["uid"][0], uid, + "UID of a post and its supposed location don't match" + ); + assert!( + uid.starts_with("https://fireburn.ru/posts/"), + "The post namespace is incorrect" + ); + assert_eq!( + post["properties"]["content"][0]["html"] + .as_str() + .expect("Post doesn't have a rich content object") + .trim(), + "<p>This is content!</p>", + "Parsed Markdown content doesn't match expected HTML" + ); + assert_eq!( + post["properties"]["channel"][0], "https://fireburn.ru/feeds/main", + "Post isn't posted to the main channel" + ); + assert_eq!( + post["properties"]["author"][0], "https://fireburn.ru/", + "Post author is unknown" + ); + } + + #[test] + fn test_mp_slug() { + let mf2 = json!({ + "type": ["h-entry"], + "properties": { + "content": ["This is content!"], + "mp-slug": ["hello-post"] + }, + }); + + let (_, post) = normalize_mf2( + mf2, + &token_data(), + ); + assert!( + post["properties"]["url"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap()) + .any(|i| i == "https://fireburn.ru/posts/hello-post"), + "Didn't found an URL pointing to the location expected by the mp-slug semantics" + ); + assert!( + post["properties"]["mp-slug"].as_array().is_none(), + "mp-slug wasn't deleted from the array!" 
+ ) + } + + #[test] + fn test_normalize_feed() { + let mf2 = json!({ + "type": ["h-feed"], + "properties": { + "name": "Main feed", + "mp-slug": ["main"] + } + }); + + let (uid, post) = normalize_mf2( + mf2, + &token_data(), + ); + assert_eq!( + post["properties"]["uid"][0], uid, + "UID of a post and its supposed location don't match" + ); + assert_eq!(post["properties"]["author"][0], "https://fireburn.ru/"); + assert!( + post["properties"]["url"] + .as_array() + .unwrap() + .iter() + .map(|i| i.as_str().unwrap()) + .any(|i| i == "https://fireburn.ru/feeds/main"), + "Didn't found an URL pointing to the location expected by the mp-slug semantics" + ); + assert!( + post["properties"]["mp-slug"].as_array().is_none(), + "mp-slug wasn't deleted from the array!" + ) + } +} diff --git a/src/tokenauth.rs b/src/tokenauth.rs new file mode 100644 index 0000000..244a045 --- /dev/null +++ b/src/tokenauth.rs @@ -0,0 +1,358 @@ +use serde::{Deserialize, Serialize}; +use url::Url; + +#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] +pub struct User { + pub me: Url, + pub client_id: Url, + scope: String, +} + +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ErrorKind { + PermissionDenied, + NotAuthorized, + TokenEndpointError, + JsonParsing, + InvalidHeader, + Other, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct TokenEndpointError { + error: String, + error_description: String, +} + +#[derive(Debug)] +pub struct IndieAuthError { + source: Option<Box<dyn std::error::Error + Send + Sync>>, + kind: ErrorKind, + msg: String, +} + +impl std::error::Error for IndieAuthError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|e| e.as_ref() as &dyn std::error::Error) + } +} + +impl std::fmt::Display for IndieAuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {}", + match self.kind { + ErrorKind::TokenEndpointError => "token endpoint returned an error: ", + ErrorKind::JsonParsing => "error while parsing token endpoint response: ", + ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", + ErrorKind::PermissionDenied => "token endpoint rejected the token: ", + ErrorKind::InvalidHeader => "authorization header parsing error: ", + ErrorKind::Other => "token endpoint communication error: ", + }, + self.msg + ) + } +} + +impl From<serde_json::Error> for IndieAuthError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::JsonParsing, + } + } +} + +impl From<reqwest::Error> for IndieAuthError { + fn from(err: reqwest::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Other, + } + } +} + +impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { + fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { + Self { + msg: format!("{:?}", err.reason()), + source: Some(Box::new(err)), + kind: ErrorKind::InvalidHeader, + } + } +} + +impl axum::response::IntoResponse for IndieAuthError { + fn into_response(self) -> axum::response::Response { + let status_code: StatusCode = match self.kind { + ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, + ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, + ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, + ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED, + ErrorKind::Other => 
StatusCode::INTERNAL_SERVER_ERROR, + }; + + let body = serde_json::json!({ + "error": match self.kind { + ErrorKind::PermissionDenied => "forbidden", + ErrorKind::NotAuthorized => "unauthorized", + ErrorKind::TokenEndpointError => "token_endpoint_error", + ErrorKind::JsonParsing => "invalid_request", + ErrorKind::InvalidHeader => "unauthorized", + ErrorKind::Other => "unknown_error", + }, + "error_description": self.msg + }); + + (status_code, axum::response::Json(body)).into_response() + } +} + +impl User { + pub fn check_scope(&self, scope: &str) -> bool { + self.scopes().any(|i| i == scope) + } + pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> { + self.scope.split_ascii_whitespace() + } + pub fn new(me: &str, client_id: &str, scope: &str) -> Self { + Self { + me: Url::parse(me).unwrap(), + client_id: Url::parse(client_id).unwrap(), + scope: scope.to_string(), + } + } +} + +use axum::{ + extract::{Extension, FromRequest, RequestParts, TypedHeader}, + headers::{ + authorization::{Bearer, Credentials}, + Authorization, + }, + http::StatusCode, +}; + +// this newtype is required due to axum::Extension retrieving items by type +// it's based on compiler magic matching extensions by their type's hashes +#[derive(Debug, Clone)] +pub struct TokenEndpoint(pub url::Url); + +#[async_trait::async_trait] +impl<B> FromRequest<B> for User +where + B: Send, +{ + type Rejection = IndieAuthError; + + #[cfg_attr( + all(debug_assertions, not(test)), + allow(unreachable_code, unused_variables) + )] + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + // Return a fake user if we're running a debug build + // I don't wanna bother with authentication + #[cfg(all(debug_assertions, not(test)))] + return Ok(User::new( + "http://localhost:8080/", + "https://quill.p3k.io/", + "create update delete media", + )); + + let TypedHeader(Authorization(token)) = + TypedHeader::<Authorization<Bearer>>::from_request(req) + .await + .map_err(IndieAuthError::from)?; + + let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = + Extension::from_request(req).await.unwrap(); + + let Extension(http): Extension<reqwest::Client> = + Extension::from_request(req).await.unwrap(); + + match http + .get(token_endpoint) + .header("Authorization", token.encode()) + .header("Accept", "application/json") + .send() + .await + { + Ok(res) => match res.status() { + StatusCode::OK => match res.json::<serde_json::Value>().await { + Ok(json) => match serde_json::from_value::<User>(json.clone()) { + Ok(user) => Ok(user), + Err(err) => { + if let Some(false) = json["active"].as_bool() { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: "The token is not active for this user.".to_owned(), + }) + } else { + Err(IndieAuthError::from(err)) + } + } + }, + Err(err) => Err(IndieAuthError::from(err)), + }, + StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { + Ok(err) => { + if err.error == "unauthorized" { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: err.error_description, + }) + } else { + Err(IndieAuthError { + source: None, + kind: ErrorKind::TokenEndpointError, + msg: err.error_description, + }) + } + } + Err(err) => Err(IndieAuthError::from(err)), + }, + _ => Err(IndieAuthError { + source: None, + msg: format!("Token endpoint returned {}", res.status()), + kind: ErrorKind::TokenEndpointError, + }), + }, + Err(err) => Err(IndieAuthError::from(err)), + } + } +} + +#[cfg(test)] +mod tests { + use 
super::User; + use axum::{ + extract::FromRequest, + http::{Method, Request}, + }; + use wiremock::{MockServer, Mock, ResponseTemplate}; + use wiremock::matchers::{method, path, header}; + + #[test] + fn user_scopes_are_checkable() { + let user = User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + ); + + assert!(user.check_scope("create")); + assert!(!user.check_scope("delete")); + } + + #[inline] + fn get_http_client() -> reqwest::Client { + reqwest::Client::new() + } + + fn request<A: Into<Option<&'static str>>>( + auth: A, + endpoint: String, + ) -> Request<()> { + let request = Request::builder().method(Method::GET); + + match auth.into() { + Some(auth) => request.header("Authorization", auth), + None => request, + } + .extension(super::TokenEndpoint(endpoint.parse().unwrap())) + .extension(get_http_client()) + .body(()) + .unwrap() + } + + #[tokio::test] + async fn test_require_token_with_token() { + let server = MockServer::start().await; + + Mock::given(path("/token")) + .and(header("Authorization", "Bearer token")) + .respond_with(ResponseTemplate::new(200) + .set_body_json(User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + )) + ) + .mount(&server) + .await; + + let request = request("Bearer token", format!("{}/token", &server.uri())); + let mut parts = axum::extract::RequestParts::new(request); + let user = User::from_request(&mut parts).await.unwrap(); + + assert_eq!(user.me.as_str(), "https://fireburn.ru/") + } + + #[tokio::test] + async fn test_require_token_fake_token() { + let server = MockServer::start().await; + + Mock::given(path("/refuse_token")) + .respond_with(ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"active": false})) + ) + .mount(&server) + .await; + + let request = request("Bearer token", format!("{}/refuse_token", &server.uri())); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); + + assert_eq!(err.kind, super::ErrorKind::NotAuthorized) + } + + #[tokio::test] + async fn test_require_token_no_token() { + let server = MockServer::start().await; + + Mock::given(path("/should_never_be_called")) + .respond_with(ResponseTemplate::new(500)) + .expect(0) + .mount(&server) + .await; + + let request = request(None, format!("{}/should_never_be_called", &server.uri())); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); + + assert_eq!(err.kind, super::ErrorKind::InvalidHeader); + } + + #[tokio::test] + async fn test_require_token_400_error_unauthorized() { + let server = MockServer::start().await; + + Mock::given(path("/refuse_token_with_400")) + .and(header("Authorization", "Bearer token")) + .respond_with(ResponseTemplate::new(400) + .set_body_json(serde_json::json!({ + "error": "unauthorized", + "error_description": "The token provided was malformed" + })) + ) + .mount(&server) + .await; + + let request = request( + "Bearer token", + format!("{}/refuse_token_with_400", &server.uri()), + ); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); + + assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + } +} diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs new file mode 100644 index 0000000..f7322f7 --- /dev/null +++ b/src/webmentions/check.rs @@ -0,0 +1,113 @@ +use std::{cell::RefCell, rc::Rc}; +use microformats::{types::PropertyValue, 
html5ever::{self, tendril::TendrilSink}};
+use kittybox_util::MentionType;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("microformats error: {0}")]
+    Microformats(#[from] microformats::Error),
+    // #[error("json error: {0}")]
+    // Json(#[from] serde_json::Error),
+    #[error("url parse error: {0}")]
+    UrlParse(#[from] url::ParseError),
+}
+
+#[tracing::instrument]
+pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> {
+    tracing::debug!("Parsing MF2 markup...");
+    // First, check the document for MF2 markup
+    let document = microformats::from_html(document.as_ref(), base_url.clone())?;
+
+    // Get an iterator of all items
+    let items_iter = document.items.iter()
+        .map(AsRef::as_ref)
+        .map(RefCell::borrow);
+
+    for item in items_iter {
+        tracing::debug!("Processing item: {:?}", item);
+
+        let props = item.properties.borrow();
+        for (prop, interaction_type) in [
+            ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like),
+            ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost)
+        ] {
+            if let Some(propvals) = props.get(prop) {
+                tracing::debug!("Has a u-{} property", prop);
+                for val in propvals {
+                    if let PropertyValue::Url(url) = val {
+                        if url == link {
+                            tracing::debug!("URL matches! Webmention is valid");
+                            return Ok(Some((interaction_type, serde_json::to_value(&*item).unwrap())))
+                        }
+                    }
+                }
+            }
+        }
+        // Process `content`
+        tracing::debug!("Processing e-content...");
+        if let Some(PropertyValue::Fragment(content)) = props.get("content")
+            .map(Vec::as_slice)
+            .unwrap_or_default()
+            .first()
+        {
+            tracing::debug!("Parsing HTML data...");
+            let root = html5ever::parse_document(html5ever::rcdom::RcDom::default(), Default::default())
+                .from_utf8()
+                .one(content.html.to_owned().as_bytes())
+                .document;
+
+            // This is a trick to unwrap recursion into a loop
+            //
+            // A list of unprocessed nodes is made. Then, in each
+            // iteration, the list is "taken" and replaced with an
+            // empty list, which is populated with nodes for the next
+            // iteration of the loop.
+            //
+            // An empty list means all nodes were processed.
+            let mut unprocessed_nodes: Vec<Rc<html5ever::rcdom::Node>> = root.children.borrow().iter().cloned().collect();
+            while !unprocessed_nodes.is_empty() {
+                // "Take" the list out of its memory slot, replace it with an empty list
+                let nodes = std::mem::take(&mut unprocessed_nodes);
+                tracing::debug!("Processing list of {} nodes", nodes.len());
+                'nodes_loop: for node in nodes.into_iter() {
+                    // Add children nodes to the list for the next iteration
+                    unprocessed_nodes.extend(node.children.borrow().iter().cloned());
+
+                    if let html5ever::rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data {
+                        // If it's not `<a>`, skip it
+                        if name.local != *"a" { continue; }
+                        let mut is_mention: bool = false;
+                        for attr in attrs.borrow().iter() {
+                            if attr.name.local == *"rel" {
+                                // Don't count `rel="nofollow"` links: web crawlers are told
+                                // to ignore them, so they are useless for driving visitors.
+                                if attr.value
+                                    .as_ref()
+                                    .split([',', ' '])
+                                    .any(|v| v == "nofollow")
+                                {
+                                    // Skip the entire node.
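+                                    // (note: this node's children were already queued above,
+                                    // so only the link itself is skipped)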
+ continue 'nodes_loop; + } + } + // if it's not `<a href="...">`, skip it + if attr.name.local != *"href" { continue; } + // Be forgiving in parsing URLs, and resolve them against the base URL + if let Ok(url) = base_url.join(attr.value.as_ref()) { + if &url == link { + is_mention = true; + } + } + } + if is_mention { + return Ok(Some((MentionType::Mention, serde_json::to_value(&*item).unwrap()))); + } + } + } + } + + } + } + + Ok(None) +} diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs new file mode 100644 index 0000000..95ea870 --- /dev/null +++ b/src/webmentions/mod.rs @@ -0,0 +1,195 @@ +use axum::{Form, response::{IntoResponse, Response}, Extension}; +use axum::http::StatusCode; +use tracing::error; + +use crate::database::{Storage, StorageError}; +use self::queue::JobQueue; +pub mod queue; + +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] +pub struct Webmention { + source: String, + target: String, +} + +impl queue::JobItem for Webmention {} +impl queue::PostgresJobItem for Webmention { + const DATABASE_NAME: &'static str = "kittybox_webmention.incoming_webmention_queue"; + const NOTIFICATION_CHANNEL: &'static str = "incoming_webmention"; +} + +async fn accept_webmention<Q: JobQueue<Webmention>>( + Extension(queue): Extension<Q>, + Form(webmention): Form<Webmention>, +) -> Response { + if let Err(err) = webmention.source.parse::<url::Url>() { + return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + } + if let Err(err) = webmention.target.parse::<url::Url>() { + return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + } + + match queue.put(&webmention).await { + Ok(id) => (StatusCode::ACCEPTED, [ + ("Location", format!("/.kittybox/webmention/{id}")) + ]).into_response(), + Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ + ("Content-Type", "text/plain") + ], err.to_string()).into_response() + } +} + +pub fn router<Q: JobQueue<Webmention>, S: Storage + 'static>( + queue: Q, db: S, http: reqwest::Client, + cancellation_token: tokio_util::sync::CancellationToken +) -> (axum::Router, SupervisedTask) { + // Automatically spawn a background task to handle webmentions + let bgtask_handle = supervised_webmentions_task(queue.clone(), db, http, cancellation_token); + + let router = axum::Router::new() + .route("/.kittybox/webmention", + axum::routing::post(accept_webmention::<Q>) + ) + .layer(Extension(queue)); + + (router, bgtask_handle) +} + +#[derive(thiserror::Error, Debug)] +pub enum SupervisorError { + #[error("the task was explicitly cancelled")] + Cancelled +} + +pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; + +pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask +where + E: std::error::Error + std::fmt::Debug + Send + 'static, + A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static, + F: FnMut() -> A + Send + 'static +{ + + let supervisor_future = async move { + loop { + // Don't spawn the task if we are already cancelled, but + // have somehow missed it (probably because the task + // crashed and we immediately received a cancellation + // request after noticing the crashed task) + if cancellation_token.is_cancelled() { + return Err(SupervisorError::Cancelled) + } + let task = tokio::task::spawn(f()); + tokio::select! 
{
+                _ = cancellation_token.cancelled() => {
+                    tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>());
+                    return Err(SupervisorError::Cancelled)
+                }
+                task_result = task => match task_result {
+                    Err(e) => tracing::error!("background task {:?} exited unexpectedly: {}", std::any::type_name::<A>(), e),
+                    Ok(Err(e)) => tracing::error!("background task {:?} returned error: {}", std::any::type_name::<A>(), e),
+                    Ok(Ok(_)) => unreachable!("task's Ok is Infallible")
+                }
+            }
+            tracing::debug!("Sleeping for a little while to back off...");
+            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+        }
+    };
+    #[cfg(not(tokio_unstable))]
+    return tokio::task::spawn(supervisor_future);
+    #[cfg(tokio_unstable)]
+    return tokio::task::Builder::new()
+        .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str())
+        .spawn(supervisor_future)
+        .unwrap();
+}
+
+mod check;
+
+#[derive(thiserror::Error, Debug)]
+enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> {
+    #[error("queue error: {0}")]
+    Queue(#[from] Q),
+    #[error("storage error: {0}")]
+    Storage(StorageError)
+}
+
+async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest::Client) -> Result<std::convert::Infallible, Error<Q::Error>> {
+    use futures_util::StreamExt;
+    use self::queue::Job;
+
+    let mut stream = queue.into_stream().await?;
+    while let Some(item) = stream.next().await.transpose()? {
+        let job = item.job();
+        let (source, target) = (
+            job.source.parse::<url::Url>().unwrap(),
+            job.target.parse::<url::Url>().unwrap()
+        );
+
+        let (code, text) = match http.get(source.clone()).send().await {
+            Ok(response) => {
+                let code = response.status();
+                if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) {
+                    error!("error processing webmention: webpage fetch returned {}", code);
+                    continue;
+                }
+                match response.text().await {
+                    Ok(text) => (code, text),
+                    Err(err) => {
+                        error!("error processing webmention: error fetching webpage text: {}", err);
+                        continue
+                    }
+                }
+            }
+            Err(err) => {
+                error!("error processing webmention: error requesting webpage: {}", err);
+                continue
+            }
+        };
+
+        if code == StatusCode::GONE {
+            todo!("removing webmentions is not implemented yet");
+            // db.remove_webmention(target.as_str(), source.as_str()).await.map_err(Error::<Q::Error>::Storage)?;
+        } else {
+            // Verify webmention
+            let (mention_type, mut mention) = match tokio::task::block_in_place({
+                || check::check_mention(text, &source, &target)
+            }) {
+                Ok(Some(result)) => result,
+                Ok(None) => {
+                    error!("webmention {} -> {} invalid, rejecting", source, target);
+                    item.done().await?;
+                    continue;
+                }
+                Err(err) => {
+                    error!("error processing webmention: error checking webmention: {}", err);
+                    continue;
+                }
+            };
+
+            {
+                mention["type"] = serde_json::json!(["h-cite"]);
+
+                if !mention["properties"].as_object().unwrap().contains_key("uid") {
+                    let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned();
+                    let props = mention["properties"].as_object_mut().unwrap();
+                    props.insert("uid".to_owned(), serde_json::Value::Array(
+                        vec![serde_json::Value::String(url)])
+                    );
+                }
+            }
+
+            db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?;
+
+            // The webmention was saved; mark the job as done so it is deleted
+            // from the queue instead of being retried as a failure on drop.
+            item.done().await?;
+        }
+    }
+    unreachable!()
+}
+
+fn supervised_webmentions_task<Q: JobQueue<Webmention>, S: Storage + 'static>(
+    queue: Q, db: S,
+    http: reqwest::Client,
+    cancellation_token:
tokio_util::sync::CancellationToken
+) -> SupervisedTask {
+    supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(queue.clone(), db.clone(), http.clone()), cancellation_token)
+}
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs
new file mode 100644
index 0000000..b811e71
--- /dev/null
+++ b/src/webmentions/queue.rs
@@ -0,0 +1,303 @@
+use std::{pin::Pin, str::FromStr};
+
+use futures_util::{Stream, StreamExt};
+use sqlx::{postgres::PgListener, Executor};
+use uuid::Uuid;
+
+use super::Webmention;
+
+static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention");
+
+pub use kittybox_util::queue::{JobQueue, JobItem, Job};
+
+pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
+    const DATABASE_NAME: &'static str;
+    const NOTIFICATION_CHANNEL: &'static str;
+}
+
+#[derive(sqlx::FromRow)]
+struct PostgresJobRow<T: PostgresJobItem> {
+    id: Uuid,
+    #[sqlx(flatten)]
+    job: T
+}
+
+#[derive(Debug)]
+pub struct PostgresJob<T: PostgresJobItem> {
+    id: Uuid,
+    job: T,
+    // This will normally always be Some, except on drop
+    txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
+    runtime_handle: tokio::runtime::Handle,
+}
+
+
+impl<T: PostgresJobItem> Drop for PostgresJob<T> {
+    // This is an emulation of "async drop" — the struct retains a
+    // runtime handle, which it uses to spawn a future that does
+    // the actual cleanup.
+    //
+    // Of course, this is not portable between runtimes, but I don't
+    // care about that, since Kittybox is designed to work within the
+    // Tokio ecosystem.
+    fn drop(&mut self) {
+        // A job completed with `done()` has already had its transaction
+        // taken out and committed; only a job dropped with the
+        // transaction still in place has actually failed.
+        if let Some(mut txn) = self.txn.take() {
+            tracing::error!("Job {:?} failed, incrementing attempts...", &self);
+            let id = self.id;
+            self.runtime_handle.spawn(async move {
+                tracing::debug!("Constructing query to increment attempts for job {}...", id);
+                // UPDATE <T::DATABASE_NAME> SET attempts = attempts + 1 WHERE id = $1
+                sqlx::query_builder::QueryBuilder::new("UPDATE ")
+                    // This is safe from a SQL injection standpoint, since it is a constant.
+                    .push(T::DATABASE_NAME)
+                    .push(" SET attempts = attempts + 1")
+                    .push(" WHERE id = ")
+                    .push_bind(id)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                sqlx::query_builder::QueryBuilder::new("NOTIFY ")
+                    .push(T::NOTIFICATION_CHANNEL)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                txn.commit().await.unwrap();
+            });
+        }
+    }
+}
+
+#[cfg(test)]
+impl<T: PostgresJobItem> PostgresJob<T> {
+    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
+        sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ")
+            .push(T::DATABASE_NAME)
+            .push(" WHERE id = ")
+            .push_bind(self.id)
+            .build_query_as::<(i32,)>()
+            // It's safe to unwrap here, because we "take" the txn only on drop or commit,
+            // where it's passed by value, not by reference.
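+            // (`as_deref_mut()` below borrows the transaction out of the
+            // `Option` as `&mut PgConnection`, via `Transaction`'s `DerefMut`.)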
+ .fetch_one(self.txn.as_deref_mut().unwrap()) + .await + .map(|(i,)| i as usize) + } +} + +#[async_trait::async_trait] +impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { + fn job(&self) -> &Webmention { + &self.job + } + async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { + tracing::debug!("Deleting {} from the job queue", self.id); + sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1") + .bind(self.id) + .execute(self.txn.as_deref_mut().unwrap()) + .await?; + + self.txn.take().unwrap().commit().await + } +} + +pub struct PostgresJobQueue<T> { + db: sqlx::PgPool, + _phantom: std::marker::PhantomData<T> +} +impl<T> Clone for PostgresJobQueue<T> { + fn clone(&self) -> Self { + Self { + db: self.db.clone(), + _phantom: std::marker::PhantomData + } + } +} + +impl PostgresJobQueue<Webmention> { + pub async fn new(uri: &str) -> Result<Self, sqlx::Error> { + let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? + .options([("search_path", "kittybox_webmention")]); + if let Ok(password_file) = std::env::var("PGPASS_FILE") { + let password = tokio::fs::read_to_string(password_file).await.unwrap(); + options = options.password(&password); + } else if let Ok(password) = std::env::var("PGPASS") { + options = options.password(&password) + } + Self::from_pool( + sqlx::postgres::PgPoolOptions::new() + .max_connections(50) + .connect_with(options) + .await? + ).await + + } + + pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> { + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?; + MIGRATOR.run(&db).await?; + Ok(Self { db, _phantom: std::marker::PhantomData }) + } +} + +#[async_trait::async_trait] +impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { + type Job = PostgresJob<Webmention>; + type Error = sqlx::Error; + + async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> { + let mut txn = self.db.begin().await?; + + match sqlx::query_as::<_, PostgresJobRow<Webmention>>( + "SELECT id, source, target FROM kittybox_webmention.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1" + ) + .fetch_optional(&mut *txn) + .await? + { + Some(job_row) => { + return Ok(Some(Self::Job { + id: job_row.id, + job: job_row.job, + txn: Some(txn), + runtime_handle: tokio::runtime::Handle::current(), + })) + }, + None => Ok(None) + } + } + + async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> { + sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox_webmention.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id") + .bind(item.source.as_str()) + .bind(item.target.as_str()) + .fetch_one(&self.db) + .await + } + + async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> { + let mut listener = PgListener::connect_with(&self.db).await?; + listener.listen("incoming_webmention").await?; + + let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), { + let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener)); + move |_| { + let queue = self.clone(); + let listener = listener.clone(); + async move { + loop { + match queue.get_one().await? 
{ + Some(item) => return Ok(Some((item, ()))), + None => { + listener.lock().await.recv().await?; + continue + } + } + } + } + } + }).boxed(); + + Ok(stream) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR}; + use futures_util::StreamExt; + + #[sqlx::test(migrator = "MIGRATOR")] + #[tracing_test::traced_test] + async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { + let test_webmention = Webmention { + source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + }; + + let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; + tracing::debug!("Putting webmention into queue"); + queue.put(&test_webmention).await?; + { + let mut job_description = queue.get_one().await?.unwrap(); + assert_eq!(job_description.job(), &test_webmention); + assert_eq!(job_description.attempts().await?, 0); + } + tracing::debug!("Creating a stream"); + let mut stream = queue.clone().into_stream().await?; + + { + let mut guard = stream.next().await.transpose()?.unwrap(); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 1); + if let Some(item) = queue.get_one().await? { + panic!("Unexpected item {:?} returned from job queue!", item) + }; + } + + { + let mut guard = stream.next().await.transpose()?.unwrap(); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 2); + guard.done().await?; + } + + match queue.get_one().await? { + Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), + None => Ok(()) + } + } + + #[sqlx::test(migrator = "MIGRATOR")] + #[tracing_test::traced_test] + async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { + let test_webmention = Webmention { + source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + }; + + let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; + tracing::debug!("Putting webmention into queue"); + queue.put(&test_webmention).await?; + tracing::debug!("Creating a stream"); + let mut stream = queue.clone().into_stream().await?; + + // Synchronisation barrier that will be useful later + let barrier = Arc::new(tokio::sync::Barrier::new(2)); + { + // Get one job guard from a queue + let mut guard = stream.next().await.transpose()?.unwrap(); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 0); + + tokio::task::spawn({ + let barrier = barrier.clone(); + async move { + // Wait for the signal to drop the guard! + barrier.wait().await; + + drop(guard) + } + }); + } + tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); + + let future = tokio::task::spawn( + tokio::time::timeout( + std::time::Duration::from_secs(10), async move { + stream.next().await.unwrap().unwrap() + } + ) + ); + // Let the other task drop the guard it is holding + barrier.wait().await; + let mut guard = future.await + .expect("Timeout on fetching item") + .expect("Job queue error"); + assert_eq!(guard.job(), &test_webmention); + assert_eq!(guard.attempts().await?, 1); + + Ok(()) + } +} |
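+
+// A minimal wiring sketch, assuming a Postgres connection URI and a storage
+// backend are at hand; `postgres_uri` and `db` are illustrative names, not
+// items defined elsewhere in this crate.
+#[allow(dead_code)]
+async fn example_wiring<S: crate::database::Storage + 'static>(
+    postgres_uri: &str,
+    db: S,
+) -> Result<(axum::Router, crate::webmentions::SupervisedTask), sqlx::Error> {
+    // Connects to Postgres and runs the queue migrations.
+    let queue = PostgresJobQueue::<Webmention>::new(postgres_uri).await?;
+    // The token lets the caller shut the background task down gracefully.
+    let cancellation_token = tokio_util::sync::CancellationToken::new();
+    // `router` mounts the webmention endpoint and spawns the supervised
+    // background task that drains this queue.
+    Ok(crate::webmentions::router(
+        queue,
+        db,
+        reqwest::Client::new(),
+        cancellation_token,
+    ))
+}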