| author | Vika <vika@fireburn.ru> | 2023-07-29 21:59:56 +0300 |
|---|---|---|
| committer | Vika <vika@fireburn.ru> | 2023-07-29 21:59:56 +0300 |
| commit | 0617663b249f9ca488e5de652108b17d67fbaf45 (patch) | |
| tree | 11564b6c8fa37bf9203a0a4cc1c4e9cc088cb1a5 /kittybox-rs/src | |
| parent | 26c2b79f6a6380ae3224e9309b9f3352f5717bd7 (diff) | |
| download | kittybox-0617663b249f9ca488e5de652108b17d67fbaf45.tar.zst | |
Moved the entire Kittybox tree into the root
Diffstat (limited to 'kittybox-rs/src')
31 files changed, 0 insertions, 9450 deletions
diff --git a/kittybox-rs/src/bin/kittybox-check-webmention.rs b/kittybox-rs/src/bin/kittybox-check-webmention.rs deleted file mode 100644 index f02032c..0000000 --- a/kittybox-rs/src/bin/kittybox-check-webmention.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::cell::{RefCell, Ref}; -use std::rc::Rc; - -use clap::Parser; -use microformats::types::PropertyValue; -use microformats::html5ever; -use microformats::html5ever::tendril::TendrilSink; - -#[derive(thiserror::Error, Debug)] -enum Error { - #[error("http request error: {0}")] - Http(#[from] reqwest::Error), - #[error("microformats error: {0}")] - Microformats(#[from] microformats::Error), - #[error("json error: {0}")] - Json(#[from] serde_json::Error), - #[error("url parse error: {0}")] - UrlParse(#[from] url::ParseError), -} - -use kittybox_util::MentionType; - -fn check_mention(document: impl AsRef<str>, base_url: &url::Url, link: &url::Url) -> Result<Option<MentionType>, Error> { - // First, check the document for MF2 markup - let document = microformats::from_html(document.as_ref(), base_url.clone())?; - - // Get an iterator of all items - let items_iter = document.items.iter() - .map(AsRef::as_ref) - .map(RefCell::borrow); - - for item in items_iter { - let props = item.properties.borrow(); - for (prop, interaction_type) in [ - ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like), - ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost) - ] { - if let Some(propvals) = props.get(prop) { - for val in propvals { - if let PropertyValue::Url(url) = val { - if url == link { - return Ok(Some(interaction_type)) - } - } - } - } - } - // Process `content` - if let Some(PropertyValue::Fragment(content)) = props.get("content") - .map(Vec::as_slice) - .unwrap_or_default() - .first() - { - let root = html5ever::parse_document(html5ever::rcdom::RcDom::default(), Default::default()) - .from_utf8() - .one(content.html.to_owned().as_bytes()) - .document; - - // This is a trick to unwrap recursion into a loop - // - // A list of unprocessed node is made. Then, in each - // iteration, the list is "taken" and replaced with an - // empty list, which is populated with nodes for the next - // iteration of the loop. - // - // Empty list means all nodes were processed. - let mut unprocessed_nodes: Vec<Rc<html5ever::rcdom::Node>> = root.children.borrow().iter().cloned().collect(); - while unprocessed_nodes.len() > 0 { - // "Take" the list out of its memory slot, replace it with an empty list - let nodes = std::mem::take(&mut unprocessed_nodes); - 'nodes_loop: for node in nodes.into_iter() { - // Add children nodes to the list for the next iteration - unprocessed_nodes.extend(node.children.borrow().iter().cloned()); - - if let html5ever::rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data { - // If it's not `<a>`, skip it - if name.local != *"a" { continue; } - let mut is_mention: bool = false; - for attr in attrs.borrow().iter() { - if attr.name.local == *"rel" { - // Don't count `rel="nofollow"` links — a web crawler should ignore them - // and so for purposes of driving visitors they are useless - if attr.value - .as_ref() - .split([',', ' ']) - .any(|v| v == "nofollow") - { - // Skip the entire node. 
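The traversal above unrolls recursion into a worklist, as its inline comment explains. A minimal standalone sketch of that take-and-extend pattern, using a toy `Node` type in place of html5ever's `RcDom` (names here are illustrative):

```rust
// A self-contained sketch of the take-and-extend worklist pattern
// described above. `Node` is a toy tree type, not html5ever's RcDom.
struct Node {
    children: Vec<Node>,
}

fn visit_breadth_first(root: Node, mut visit: impl FnMut(&Node)) {
    let mut unprocessed: Vec<Node> = vec![root];
    while !unprocessed.is_empty() {
        // "Take" the list out of its slot, leaving an empty one behind;
        // this iteration's nodes queue their children for the next one.
        let nodes = std::mem::take(&mut unprocessed);
        for node in nodes {
            visit(&node);
            unprocessed.extend(node.children);
        }
    }
}

fn main() {
    let tree = Node { children: vec![Node { children: vec![] }] };
    let mut count = 0;
    visit_breadth_first(tree, |_| count += 1);
    assert_eq!(count, 2);
}
```

Taking the vector with `std::mem::take` avoids borrowing `unprocessed` while it is being refilled, which is what makes the single-loop form possible.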
- continue 'nodes_loop; - } - } - // if it's not `<a href="...">`, skip it - if attr.name.local != *"href" { continue; } - // Be forgiving in parsing URLs, and resolve them against the base URL - if let Ok(url) = base_url.join(attr.value.as_ref()) { - if &url == link { - is_mention = true; - } - } - } - if is_mention { - return Ok(Some(MentionType::Mention)); - } - } - } - } - - } - } - - Ok(None) -} - -#[derive(Parser, Debug)] -#[clap( - name = "kittybox-check-webmention", - author = "Vika <vika@fireburn.ru>", - version = env!("CARGO_PKG_VERSION"), - about = "Verify an incoming webmention" -)] -struct Args { - #[clap(value_parser)] - url: url::Url, - #[clap(value_parser)] - link: url::Url -} - -#[tokio::main] -async fn main() -> Result<(), self::Error> { - let args = Args::parse(); - - let http: reqwest::Client = { - #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); - - builder.build().unwrap() - }; - - let response = http.get(args.url.clone()).send().await?; - let text = response.text().await?; - - if let Some(mention_type) = check_mention(text, &args.url, &args.link)? { - println!("{:?}", mention_type); - - Ok(()) - } else { - std::process::exit(1) - } -} diff --git a/kittybox-rs/src/bin/kittybox-indieauth-helper.rs b/kittybox-rs/src/bin/kittybox-indieauth-helper.rs deleted file mode 100644 index 3377ec3..0000000 --- a/kittybox-rs/src/bin/kittybox-indieauth-helper.rs +++ /dev/null @@ -1,233 +0,0 @@ -use kittybox_indieauth::{ - AuthorizationRequest, PKCEVerifier, - PKCEChallenge, PKCEMethod, GrantRequest, Scope, - AuthorizationResponse, TokenData, GrantResponse -}; -use clap::Parser; -use std::{borrow::Cow, io::Write}; - -const DEFAULT_CLIENT_ID: &str = "https://kittybox.fireburn.ru/indieauth-helper.html"; -const DEFAULT_REDIRECT_URI: &str = "http://localhost:60000/callback"; - -#[derive(Debug, thiserror::Error)] -enum Error { - #[error("i/o error: {0}")] - IO(#[from] std::io::Error), - #[error("http request error: {0}")] - HTTP(#[from] reqwest::Error), - #[error("urlencoded encoding error: {0}")] - UrlencodedEncoding(#[from] serde_urlencoded::ser::Error), - #[error("url parsing error: {0}")] - UrlParse(#[from] url::ParseError), - #[error("indieauth flow error: {0}")] - IndieAuth(Cow<'static, str>) -} - -#[derive(Parser, Debug)] -#[clap( - name = "kittybox-indieauth-helper", - author = "Vika <vika@fireburn.ru>", - version = env!("CARGO_PKG_VERSION"), - about = "Retrieve an IndieAuth token for debugging", - long_about = None -)] -struct Args { - /// Profile URL to use for initiating IndieAuth metadata discovery. - #[clap(value_parser)] - me: url::Url, - /// Scopes to request for the token. - /// - /// All IndieAuth scopes are supported, including arbitrary custom scopes. - #[clap(short, long)] - scope: Vec<Scope>, - /// Client ID to use when requesting a token. - #[clap(short, long, value_parser, default_value = DEFAULT_CLIENT_ID)] - client_id: url::Url, - /// Redirect URI to declare. Note: This will break the flow, use only for testing UI. 
- #[clap(long, value_parser)] - redirect_uri: Option<url::Url> -} - -fn append_query_string<T: serde::Serialize>( - url: &url::Url, - query: T -) -> Result<url::Url, Error> { - let mut new_url = url.clone(); - let mut query = serde_urlencoded::to_string(query)?; - if let Some(old_query) = url.query() { - query.push('&'); - query.push_str(old_query); - } - new_url.set_query(Some(&query)); - - Ok(new_url) -} - -#[tokio::main] -async fn main() -> Result<(), Error> { - let args = Args::parse(); - - let http: reqwest::Client = { - #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); - - builder.build().unwrap() - }; - - let redirect_uri: url::Url = args.redirect_uri - .clone() - .unwrap_or_else(|| DEFAULT_REDIRECT_URI.parse().unwrap()); - - eprintln!("Checking .well-known for metadata..."); - let metadata = http.get(args.me.join("/.well-known/oauth-authorization-server")?) - .header("Accept", "application/json") - .send() - .await? - .json::<kittybox_indieauth::Metadata>() - .await?; - - let verifier = PKCEVerifier::new(); - - let authorization_request = AuthorizationRequest { - response_type: kittybox_indieauth::ResponseType::Code, - client_id: args.client_id.clone(), - redirect_uri: redirect_uri.clone(), - state: kittybox_indieauth::State::new(), - code_challenge: PKCEChallenge::new(&verifier, PKCEMethod::default()), - scope: Some(kittybox_indieauth::Scopes::new(args.scope)), - me: Some(args.me) - }; - - let indieauth_url = append_query_string( - &metadata.authorization_endpoint, - authorization_request - )?; - - eprintln!("Please visit the following URL in your browser:\n\n {}\n", indieauth_url.as_str()); - - if args.redirect_uri.is_some() { - eprintln!("Custom redirect URI specified, won't be able to catch authorization response."); - std::process::exit(0); - } - - // Prepare a callback - let (tx, rx) = tokio::sync::oneshot::channel::<AuthorizationResponse>(); - let server = { - use axum::{routing::get, extract::Query, response::IntoResponse}; - - let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx))); - - let router = axum::Router::new() - .route("/callback", axum::routing::get( - move |query: Option<Query<AuthorizationResponse>>| async move { - if let Some(Query(response)) = query { - if let Some(tx) = tx.lock_owned().await.take() { - tx.send(response).unwrap(); - - (axum::http::StatusCode::OK, - [("Content-Type", "text/plain")], - "Thank you! This window can now be closed.") - .into_response() - } else { - (axum::http::StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "Oops. The callback was already received. 
Did you click twice?") - .into_response() - } - } else { - axum::http::StatusCode::BAD_REQUEST.into_response() - } - } - )); - - use std::net::{SocketAddr, IpAddr, Ipv4Addr}; - - let server = hyper::server::Server::bind( - &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST),60000) - ) - .serve(router.into_make_service()); - - tokio::task::spawn(server) - }; - - let authorization_response = rx.await.unwrap(); - - // Clean up after the server - tokio::task::spawn(async move { - // Wait for the server to settle -- it might need to send its response - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - // Abort the future -- this should kill the server - server.abort(); - }); - - eprintln!("Got authorization response: {:#?}", authorization_response); - eprint!("Checking issuer field..."); - std::io::stderr().lock().flush()?; - - if dbg!(authorization_response.iss.as_str()) == dbg!(metadata.issuer.as_str()) { - eprintln!(" Done"); - } else { - eprintln!(" Failed"); - #[cfg(not(debug_assertions))] - std::process::exit(1); - } - let grant_response: GrantResponse = http.post(metadata.token_endpoint) - .form(&GrantRequest::AuthorizationCode { - code: authorization_response.code, - client_id: args.client_id, - redirect_uri, - code_verifier: verifier - }) - .header("Accept", "application/json") - .send() - .await? - .json() - .await?; - - if let GrantResponse::AccessToken { - me, - profile, - access_token, - expires_in, - refresh_token, - token_type, - scope - } = grant_response { - eprintln!("Congratulations, {}, access token is ready! {}", - me.as_str(), - if let Some(exp) = expires_in { - format!("It expires in {exp} seconds.") - } else { - format!("It seems to have unlimited duration.") - } - ); - println!("{}", access_token); - if let Some(refresh_token) = refresh_token { - eprintln!("Save this refresh token, it will come in handy:"); - println!("{}", refresh_token); - }; - - if let Some(profile) = profile { - eprintln!("\nThe token endpoint returned some profile information:"); - if let Some(name) = profile.name { - eprintln!(" - Name: {name}") - } - if let Some(url) = profile.url { - eprintln!(" - URL: {url}") - } - if let Some(photo) = profile.photo { - eprintln!(" - Photo: {photo}") - } - if let Some(email) = profile.email { - eprintln!(" - Email: {email}") - } - } - - Ok(()) - } else { - return Err(Error::IndieAuth(Cow::Borrowed("IndieAuth token endpoint did not return an access token grant."))); - } -} diff --git a/kittybox-rs/src/bin/kittybox-mf2.rs b/kittybox-rs/src/bin/kittybox-mf2.rs deleted file mode 100644 index 4366cb8..0000000 --- a/kittybox-rs/src/bin/kittybox-mf2.rs +++ /dev/null @@ -1,49 +0,0 @@ -use clap::Parser; - -#[derive(Parser, Debug)] -#[clap( - name = "kittybox-mf2", - author = "Vika <vika@fireburn.ru>", - version = env!("CARGO_PKG_VERSION"), - about = "Fetch HTML and turn it into MF2-JSON" -)] -struct Args { - #[clap(value_parser)] - url: url::Url, -} - -#[derive(thiserror::Error, Debug)] -enum Error { - #[error("http request error: {0}")] - Http(#[from] reqwest::Error), - #[error("microformats error: {0}")] - Microformats(#[from] microformats::Error), - #[error("json error: {0}")] - Json(#[from] serde_json::Error), - #[error("url parse error: {0}")] - UrlParse(#[from] url::ParseError), -} - -#[tokio::main] -async fn main() -> Result<(), Error> { - let args = Args::parse(); - - let http: reqwest::Client = { - #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); 
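One detail worth noting in the helper above: the callback route shares a consume-once `oneshot` sender behind an `Arc<Mutex<Option<...>>>`, so a handler that may fire more than once can still send exactly once. A condensed sketch of just that handoff, with the HTTP handler simulated by a spawned task:

```rust
use std::sync::Arc;
use tokio::sync::{oneshot, Mutex};

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel::<String>();
    // Option + take() turns a consume-once sender into something a
    // shared handler can use; a second caller finds None and can
    // report "callback already received" instead of panicking.
    let tx = Arc::new(Mutex::new(Some(tx)));

    let handler = {
        let tx = Arc::clone(&tx);
        tokio::spawn(async move {
            if let Some(tx) = tx.lock().await.take() {
                let _ = tx.send("authorization response".to_owned());
            }
        })
    };

    handler.await.unwrap();
    println!("got: {}", rx.await.unwrap());
}
```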
- - builder.build().unwrap() - }; - - let response = http.get(args.url.clone()).send().await?; - let text = response.text().await?; - - let mf2 = microformats::from_html(text.as_ref(), args.url)?; - - println!("{}", serde_json::to_string_pretty(&mf2)?); - - Ok(()) -} diff --git a/kittybox-rs/src/bin/kittybox_bulk_import.rs b/kittybox-rs/src/bin/kittybox_bulk_import.rs deleted file mode 100644 index 7e1f6af..0000000 --- a/kittybox-rs/src/bin/kittybox_bulk_import.rs +++ /dev/null @@ -1,66 +0,0 @@ -use anyhow::{anyhow, bail, Context, Result}; -use std::fs::File; -use std::io; - -#[async_std::main] -async fn main() -> Result<()> { - let args = std::env::args().collect::<Vec<String>>(); - if args.iter().skip(1).any(|s| s == "--help") { - println!("Usage: {} <url> [file]", args[0]); - println!("\nIf launched with no arguments, reads from stdin."); - println!( - "\nUse KITTYBOX_AUTH_TOKEN environment variable to authorize to the Micropub endpoint." - ); - std::process::exit(0); - } - - let token = std::env::var("KITTYBOX_AUTH_TOKEN") - .map_err(|_| anyhow!("No auth token found! Use KITTYBOX_AUTH_TOKEN env variable."))?; - let data: Vec<serde_json::Value> = (if args.len() == 2 || (args.len() == 3 && args[2] == "-") { - serde_json::from_reader(io::stdin()) - } else if args.len() == 3 { - serde_json::from_reader(File::open(&args[2]).with_context(|| "Error opening input file")?) - } else { - bail!("See `{} --help` for usage.", args[0]); - }) - .with_context(|| "Error while loading the input file")?; - - let url = surf::Url::parse(&args[1])?; - let client = surf::Client::new(); - - let iter = data.into_iter(); - - for post in iter { - println!( - "Processing {}...", - post["properties"]["url"][0] - .as_str() - .or_else(|| post["properties"]["published"][0] - .as_str() - .or_else(|| post["properties"]["name"][0] - .as_str() - .or(Some("<unidentified post>")))) - .unwrap() - ); - match client - .post(&url) - .body(surf::http::Body::from_string(serde_json::to_string(&post)?)) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", &token)) - .send() - .await - { - Ok(mut response) => { - if response.status() == 201 || response.status() == 202 { - println!("Posted at {}", response.header("location").unwrap().last()); - } else { - println!("Error: {:?}", response.body_string().await); - } - } - Err(err) => { - println!("{}", err); - } - } - } - Ok(()) -} diff --git a/kittybox-rs/src/bin/kittybox_database_converter.rs b/kittybox-rs/src/bin/kittybox_database_converter.rs deleted file mode 100644 index bc355c9..0000000 --- a/kittybox-rs/src/bin/kittybox_database_converter.rs +++ /dev/null @@ -1,106 +0,0 @@ -use anyhow::{anyhow, Context}; -use kittybox::database::FileStorage; -use kittybox::database::Storage; -use redis::{self, AsyncCommands}; -use std::collections::HashMap; - -/// Convert from a Redis storage to a new storage new_storage. 
-async fn convert_from_redis<S: Storage>(from: String, new_storage: S) -> anyhow::Result<()> { - let db = redis::Client::open(from).context("Failed to open the Redis connection")?; - - let mut conn = db - .get_async_std_connection() - .await - .context("Failed to connect to Redis")?; - - // Rebinding to convince the borrow checker we're not smuggling stuff outta scope - let storage = &new_storage; - - let mut stream = conn.hscan::<_, String>("posts").await?; - - while let Some(key) = stream.next_item().await { - let value = serde_json::from_str::<serde_json::Value>( - &stream - .next_item() - .await - .ok_or(anyhow!("Failed to find a corresponding value for the key"))?, - )?; - - println!("{}, {:?}", key, value); - - if value["see_other"].is_string() { - continue; - } - - let user = &(url::Url::parse(value["properties"]["uid"][0].as_str().unwrap()) - .unwrap() - .origin() - .ascii_serialization() - .clone() - + "/"); - if let Err(err) = storage.clone().put_post(&value, user).await { - eprintln!("Error saving post: {}", err); - } - } - - let mut stream: redis::AsyncIter<String> = conn.scan_match("settings_*").await?; - while let Some(key) = stream.next_item().await { - let mut conn = db - .get_async_std_connection() - .await - .context("Failed to connect to Redis")?; - let user = key.strip_prefix("settings_").unwrap(); - match conn - .hgetall::<&str, HashMap<String, String>>(&key) - .await - .context(format!("Failed getting settings from key {}", key)) - { - Ok(settings) => { - for (k, v) in settings.iter() { - if let Err(e) = storage - .set_setting(k, user, v) - .await - .with_context(|| format!("Failed setting {} for {}", k, user)) - { - eprintln!("{}", e); - } - } - } - Err(e) => { - eprintln!("{}", e); - } - } - } - - Ok(()) -} - -#[async_std::main] -async fn main() -> anyhow::Result<()> { - let mut args = std::env::args(); - args.next(); // skip argv[0] - let old_uri = args - .next() - .ok_or_else(|| anyhow!("No import source is provided."))?; - let new_uri = args - .next() - .ok_or_else(|| anyhow!("No import destination is provided."))?; - - let storage = if new_uri.starts_with("file:") { - let folder = new_uri.strip_prefix("file://").unwrap(); - let path = std::path::PathBuf::from(folder); - Box::new( - FileStorage::new(path) - .await - .context("Failed to construct the file storage")?, - ) - } else { - anyhow::bail!("Cannot construct the storage abstraction for destination storage. Check the storage type?"); - }; - - if old_uri.starts_with("redis") { - convert_from_redis(old_uri, *storage).await? 
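The converter above derives the storage user from each post's `uid` by taking the URL's origin and appending a slash. A small sketch of that derivation (the function name is illustrative):

```rust
// Derive the storage "user" key from a post's `uid`, as the converter
// above does: scheme://host[:port] plus a trailing slash.
fn user_from_uid(uid: &str) -> Result<String, url::ParseError> {
    let url = url::Url::parse(uid)?;
    Ok(url.origin().ascii_serialization() + "/")
}

fn main() {
    assert_eq!(
        user_from_uid("https://fireburn.ru/posts/hello").unwrap(),
        "https://fireburn.ru/"
    );
}
```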
- } - - Ok(()) -} diff --git a/kittybox-rs/src/database/file/mod.rs b/kittybox-rs/src/database/file/mod.rs deleted file mode 100644 index 27d3da1..0000000 --- a/kittybox-rs/src/database/file/mod.rs +++ /dev/null @@ -1,733 +0,0 @@ -//#![warn(clippy::unwrap_used)] -use crate::database::{ErrorKind, Result, settings, Storage, StorageError}; -use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; -use async_trait::async_trait; -use futures::{stream, StreamExt, TryStreamExt}; -use kittybox_util::MentionType; -use serde_json::json; -use std::borrow::Cow; -use std::collections::HashMap; -use std::io::ErrorKind as IOErrorKind; -use std::path::{Path, PathBuf}; -use tokio::fs::{File, OpenOptions}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::task::spawn_blocking; -use tracing::{debug, error}; - -impl From<std::io::Error> for StorageError { - fn from(source: std::io::Error) -> Self { - Self::with_source( - match source.kind() { - IOErrorKind::NotFound => ErrorKind::NotFound, - IOErrorKind::AlreadyExists => ErrorKind::Conflict, - _ => ErrorKind::Backend, - }, - Cow::Owned(format!("file I/O error: {}", &source)), - Box::new(source), - ) - } -} - -impl From<tokio::time::error::Elapsed> for StorageError { - fn from(source: tokio::time::error::Elapsed) -> Self { - Self::with_source( - ErrorKind::Backend, - Cow::Borrowed("timeout on I/O operation"), - Box::new(source), - ) - } -} - -// Copied from https://stackoverflow.com/questions/39340924 -// This routine is adapted from the *old* Path's `path_relative_from` -// function, which works differently from the new `relative_from` function. -// In particular, this handles the case on unix where both paths are -// absolute but with only the root as the common directory. -fn path_relative_from(path: &Path, base: &Path) -> Option<PathBuf> { - use std::path::Component; - - if path.is_absolute() != base.is_absolute() { - if path.is_absolute() { - Some(PathBuf::from(path)) - } else { - None - } - } else { - let mut ita = path.components(); - let mut itb = base.components(); - let mut comps: Vec<Component> = vec![]; - loop { - match (ita.next(), itb.next()) { - (None, None) => break, - (Some(a), None) => { - comps.push(a); - comps.extend(ita.by_ref()); - break; - } - (None, _) => comps.push(Component::ParentDir), - (Some(a), Some(b)) if comps.is_empty() && a == b => (), - (Some(a), Some(b)) if b == Component::CurDir => comps.push(a), - (Some(_), Some(b)) if b == Component::ParentDir => return None, - (Some(a), Some(_)) => { - comps.push(Component::ParentDir); - for _ in itb { - comps.push(Component::ParentDir); - } - comps.push(a); - comps.extend(ita.by_ref()); - break; - } - } - } - Some(comps.iter().map(|c| c.as_os_str()).collect()) - } -} - -#[allow(clippy::unwrap_used, clippy::expect_used)] -#[cfg(test)] -mod tests { - #[test] - fn test_relative_path_resolving() { - let path1 = std::path::Path::new("/home/vika/Projects/kittybox"); - let path2 = std::path::Path::new("/home/vika/Projects/nixpkgs"); - let relative_path = super::path_relative_from(path2, path1).unwrap(); - - assert_eq!(relative_path, std::path::Path::new("../nixpkgs")) - } -} - -// TODO: Check that the path ACTUALLY IS INSIDE THE ROOT FOLDER -// This could be checked by completely resolving the path -// and checking if it has a common prefix -fn url_to_path(root: &Path, url: &str) -> PathBuf { - let path = url_to_relative_path(url).to_logical_path(root); - if !path.starts_with(root) { - // TODO: handle more gracefully - panic!("Security error: {:?} is not a prefix of {:?}", 
path, root) - } else { - path - } -} - -fn url_to_relative_path(url: &str) -> relative_path::RelativePathBuf { - let url = url::Url::try_from(url).expect("Couldn't parse a URL"); - let mut path = relative_path::RelativePathBuf::new(); - let user_domain = format!( - "{}{}", - url.host_str().unwrap(), - url.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() - ); - path.push(user_domain + url.path() + ".json"); - - path -} - -fn modify_post(post: &serde_json::Value, update: MicropubUpdate) -> Result<serde_json::Value> { - let mut post = post.clone(); - - let mut add_keys: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); - let mut remove_keys: Vec<String> = vec![]; - let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); - - if let Some(MicropubPropertyDeletion::Properties(delete)) = update.delete { - remove_keys.extend(delete.iter().cloned()); - } else if let Some(MicropubPropertyDeletion::Values(delete)) = update.delete { - for (k, v) in delete { - remove_values - .entry(k.to_string()) - .or_default() - .extend(v.clone()); - } - } - if let Some(add) = update.add { - for (k, v) in add { - add_keys.insert(k.to_string(), v.clone()); - } - } - if let Some(replace) = update.replace { - for (k, v) in replace { - remove_keys.push(k.to_string()); - add_keys.insert(k.to_string(), v.clone()); - } - } - - if let Some(props) = post["properties"].as_object_mut() { - for k in remove_keys { - props.remove(&k); - } - } - for (k, v) in remove_values { - let k = &k; - let props = if k == "children" { - &mut post - } else { - &mut post["properties"] - }; - v.iter().for_each(|v| { - if let Some(vec) = props[k].as_array_mut() { - if let Some(index) = vec.iter().position(|w| w == v) { - vec.remove(index); - } - } - }); - } - for (k, v) in add_keys { - tracing::debug!("Adding k/v to post: {} => {:?}", k, v); - let props = if k == "children" { - &mut post - } else { - &mut post["properties"] - }; - if let Some(prop) = props[&k].as_array_mut() { - if k == "children" { - v.into_iter().rev().for_each(|v| prop.insert(0, v)); - } else { - prop.extend(v.into_iter()); - } - } else { - props[&k] = serde_json::Value::Array(v) - } - } - Ok(post) -} - -#[derive(Clone, Debug)] -/// A backend using a folder with JSON files as a backing store. -/// Uses symbolic links to represent a many-to-one mapping of URLs to a post. -pub struct FileStorage { - root_dir: PathBuf, -} - -impl FileStorage { - /// Create a new storage wrapping a folder specified by root_dir. - pub async fn new(root_dir: PathBuf) -> Result<Self> { - // TODO check if the dir is writable - Ok(Self { root_dir }) - } -} - -async fn hydrate_author<S: Storage>( - feed: &mut serde_json::Value, - user: &'_ Option<String>, - storage: &S, -) { - let url = feed["properties"]["uid"][0] - .as_str() - .expect("MF2 value should have a UID set! 
Check if you used normalize_mf2 before recording the post!"); - if let Some(author) = feed["properties"]["author"].as_array().cloned() { - if !feed["type"] - .as_array() - .expect("MF2 value should have a type set!") - .iter() - .any(|i| i == "h-card") - { - let author_list: Vec<serde_json::Value> = stream::iter(author.iter()) - .then(|i| async move { - if let Some(i) = i.as_str() { - match storage.get_post(i).await { - Ok(post) => match post { - Some(post) => post, - None => json!(i), - }, - Err(e) => { - error!("Error while hydrating post {}: {}", url, e); - json!(i) - } - } - } else { - i.clone() - } - }) - .collect::<Vec<_>>() - .await; - if let Some(props) = feed["properties"].as_object_mut() { - props["author"] = json!(author_list); - } else { - feed["properties"] = json!({ "author": author_list }); - } - } - } -} - -#[async_trait] -impl Storage for FileStorage { - #[tracing::instrument(skip(self))] - async fn post_exists(&self, url: &str) -> Result<bool> { - let path = url_to_path(&self.root_dir, url); - debug!("Checking if {:?} exists...", path); - /*let result = match tokio::fs::metadata(path).await { - Ok(metadata) => { - Ok(true) - }, - Err(err) => { - if err.kind() == IOErrorKind::NotFound { - Ok(false) - } else { - Err(err.into()) - } - } - };*/ - #[allow(clippy::unwrap_used)] // JoinHandle captures panics, this closure shouldn't panic - Ok(spawn_blocking(move || path.is_file()).await.unwrap()) - } - - #[tracing::instrument(skip(self))] - async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { - let path = url_to_path(&self.root_dir, url); - // TODO: check that the path actually belongs to the dir of user who requested it - // it's not like you CAN access someone else's private posts with it - // so it's not exactly a security issue, but it's still not good - debug!("Opening {:?}", path); - - match File::open(&path).await { - Ok(mut file) => { - let mut content = String::new(); - // Typechecks because OS magic acts on references - // to FDs as if they were behind a mutex - AsyncReadExt::read_to_string(&mut file, &mut content).await?; - debug!( - "Read {} bytes successfully from {:?}", - content.as_bytes().len(), - &path - ); - Ok(Some(serde_json::from_str(&content)?)) - } - Err(err) => { - if err.kind() == IOErrorKind::NotFound { - Ok(None) - } else { - Err(err.into()) - } - } - } - } - - #[tracing::instrument(skip(self))] - async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { - let key = post["properties"]["uid"][0] - .as_str() - .expect("Tried to save a post without UID"); - let path = url_to_path(&self.root_dir, key); - let tempfile = (&path).with_extension("tmp"); - debug!("Creating {:?}", path); - - let parent = path - .parent() - .expect("Parent for this directory should always exist") - .to_owned(); - tokio::fs::create_dir_all(&parent).await?; - - let mut file = tokio::fs::OpenOptions::new() - .write(true) - .create_new(true) - .open(&tempfile) - .await?; - - file.write_all(post.to_string().as_bytes()).await?; - file.flush().await?; - file.sync_all().await?; - drop(file); - tokio::fs::rename(&tempfile, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; - - if let Some(urls) = post["properties"]["url"].as_array() { - for url in urls.iter().map(|i| i.as_str().unwrap()) { - let url_domain = { - let url = url::Url::parse(url).unwrap(); - format!( - "{}{}", - url.host_str().unwrap(), - url.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() - ) - }; - if url != key && 
url_domain == user { - let link = url_to_path(&self.root_dir, url); - debug!("Creating a symlink at {:?}", link); - let orig = path.clone(); - // We're supposed to have a parent here. - let basedir = link.parent().ok_or_else(|| { - StorageError::from_static( - ErrorKind::Backend, - "Failed to calculate parent directory when creating a symlink", - ) - })?; - let relative = path_relative_from(&orig, basedir).unwrap(); - println!("{:?} - {:?} = {:?}", &orig, &basedir, &relative); - tokio::fs::symlink(relative, link).await?; - } - } - } - - if post["type"] - .as_array() - .unwrap() - .iter() - .any(|s| s.as_str() == Some("h-feed")) - { - tracing::debug!("Adding to channel list..."); - // Add the h-feed to the channel list - let path = { - let mut path = relative_path::RelativePathBuf::new(); - path.push(user); - path.push("channels"); - - path.to_path(&self.root_dir) - }; - tokio::fs::create_dir_all(path.parent().unwrap()).await?; - tracing::debug!("Channels file path: {}", path.display()); - let tempfilename = path.with_extension("tmp"); - let channel_name = post["properties"]["name"][0] - .as_str() - .map(|s| s.to_string()) - .unwrap_or_else(String::default); - let key = key.to_string(); - tracing::debug!("Opening temporary file to modify chnanels..."); - let mut tempfile = OpenOptions::new() - .write(true) - .create_new(true) - .open(&tempfilename) - .await?; - tracing::debug!("Opening real channel file..."); - let mut channels: Vec<super::MicropubChannel> = { - match OpenOptions::new() - .read(true) - .write(false) - .truncate(false) - .create(false) - .open(&path) - .await - { - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - Vec::default() - } - Err(err) => { - // Propagate the error upwards - return Err(err.into()); - } - Ok(mut file) => { - let mut content = String::new(); - file.read_to_string(&mut content).await?; - drop(file); - - if !content.is_empty() { - serde_json::from_str(&content)? 
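`put_post` above persists JSON with a write-to-temporary, fsync, rename, then fsync-the-parent-directory sequence. A standalone sketch of that crash-safe pattern, assuming tokio and condensed error handling:

```rust
use tokio::io::AsyncWriteExt;

// Crash-safe write: create a sibling ".tmp" file exclusively, write and
// fsync it, atomically rename it over the target, then fsync the parent
// directory so the rename itself is durable.
async fn write_atomically(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    let tmp = path.with_extension("tmp");
    let mut file = tokio::fs::OpenOptions::new()
        .write(true)
        .create_new(true) // fail loudly instead of clobbering another writer's tempfile
        .open(&tmp)
        .await?;
    file.write_all(data).await?;
    file.flush().await?;
    file.sync_all().await?;
    drop(file);
    tokio::fs::rename(&tmp, path).await?;
    tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?;
    Ok(())
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let path = std::env::temp_dir().join("kittybox-example.json");
    write_atomically(&path, br#"{"type": ["h-entry"]}"#).await
}
```

`create_new(true)` doubles as a crude lock: a leftover tempfile from a concurrent writer makes the operation fail instead of silently interleaving writes.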
- } else { - Vec::default() - } - } - } - }; - - channels.push(super::MicropubChannel { - uid: key.to_string(), - name: channel_name, - }); - - tempfile - .write_all(serde_json::to_string(&channels)?.as_bytes()) - .await?; - tempfile.flush().await?; - tempfile.sync_all().await?; - drop(tempfile); - tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; - } - Ok(()) - } - - #[tracing::instrument(skip(self))] - async fn update_post(&self, url: &str, update: MicropubUpdate) -> Result<()> { - let path = url_to_path(&self.root_dir, url); - let tempfilename = path.with_extension("tmp"); - #[allow(unused_variables)] - let (old_json, new_json) = { - let mut temp = OpenOptions::new() - .write(true) - .create_new(true) - .open(&tempfilename) - .await?; - let mut file = OpenOptions::new().read(true).open(&path).await?; - - let mut content = String::new(); - file.read_to_string(&mut content).await?; - let json: serde_json::Value = serde_json::from_str(&content)?; - drop(file); - // Apply the editing algorithms - let new_json = modify_post(&json, update)?; - - temp.write_all(new_json.to_string().as_bytes()).await?; - temp.flush().await?; - temp.sync_all().await?; - drop(temp); - tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; - - (json, new_json) - }; - // TODO check if URLs changed between old and new JSON - Ok(()) - } - - #[tracing::instrument(skip(self))] - async fn get_channels(&self, user: &'_ str) -> Result<Vec<super::MicropubChannel>> { - let mut path = relative_path::RelativePathBuf::new(); - path.push(user); - path.push("channels"); - - let path = path.to_path(&self.root_dir); - tracing::debug!("Channels file path: {}", path.display()); - - match File::open(&path).await { - Ok(mut f) => { - let mut content = String::new(); - f.read_to_string(&mut content).await?; - // This should not happen, but if it does, handle it gracefully - if content.is_empty() { - return Ok(vec![]); - } - let channels: Vec<super::MicropubChannel> = serde_json::from_str(&content)?; - Ok(channels) - } - Err(e) => { - if e.kind() == IOErrorKind::NotFound { - Ok(vec![]) - } else { - Err(e.into()) - } - } - } - } - - async fn read_feed_with_cursor( - &self, - url: &'_ str, - cursor: Option<&'_ str>, - limit: usize, - user: Option<&'_ str> - ) -> Result<Option<(serde_json::Value, Option<String>)>> { - Ok(self.read_feed_with_limit( - url, - &cursor.map(|v| v.to_owned()), - limit, - &user.map(|v| v.to_owned()) - ).await? - .map(|feed| { - tracing::debug!("Feed: {:#}", serde_json::Value::Array( - feed["children"] - .as_array() - .map(|v| v.as_slice()) - .unwrap_or_default() - .iter() - .map(|mf2| mf2["properties"]["uid"][0].clone()) - .collect::<Vec<_>>() - )); - let cursor: Option<String> = feed["children"] - .as_array() - .map(|v| v.as_slice()) - .unwrap_or_default() - .last() - .map(|v| v["properties"]["uid"][0].as_str().unwrap().to_owned()); - tracing::debug!("Extracted the cursor: {:?}", cursor); - (feed, cursor) - }) - ) - } - - #[tracing::instrument(skip(self))] - async fn read_feed_with_limit( - &self, - url: &'_ str, - after: &'_ Option<String>, - limit: usize, - user: &'_ Option<String>, - ) -> Result<Option<serde_json::Value>> { - if let Some(mut feed) = self.get_post(url).await? 
{ - if feed["children"].is_array() { - // Take this out of the MF2-JSON document to save memory - // - // This uses a clever match with enum destructuring - // to extract the underlying Vec without cloning it - let children: Vec<serde_json::Value> = match feed["children"].take() { - serde_json::Value::Array(children) => children, - // We've already checked it's an array - _ => unreachable!() - }; - tracing::debug!("Full children array: {:#}", serde_json::Value::Array(children.clone())); - let mut posts_iter = children - .into_iter() - .map(|s: serde_json::Value| s.as_str().unwrap().to_string()); - // Note: we can't actually use `skip_while` here because we end up emitting `after`. - // This imperative snippet consumes after instead of emitting it, allowing the - // stream of posts to return only those items that truly come *after* that one. - // If I would implement an Iter combinator like this, I would call it `skip_until` - if let Some(after) = after { - for s in posts_iter.by_ref() { - if &s == after { - break; - } - } - }; - let posts = stream::iter(posts_iter) - .map(|url: String| async move { self.get_post(&url).await }) - .buffered(std::cmp::min(3, limit)) - // Hack to unwrap the Option and sieve out broken links - // Broken links return None, and Stream::filter_map skips Nones. - .try_filter_map(|post: Option<serde_json::Value>| async move { Ok(post) }) - .and_then(|mut post| async move { - hydrate_author(&mut post, user, self).await; - Ok(post) - }) - .take(limit); - - match posts.try_collect::<Vec<serde_json::Value>>().await { - Ok(posts) => feed["children"] = serde_json::json!(posts), - Err(err) => { - return Err(StorageError::with_source( - ErrorKind::Other, - Cow::Owned(format!("Feed assembly error: {}", &err)), - Box::new(err), - )); - } - } - } - hydrate_author(&mut feed, user, self).await; - Ok(Some(feed)) - } else { - Ok(None) - } - } - - #[tracing::instrument(skip(self))] - async fn delete_post(&self, url: &'_ str) -> Result<()> { - let path = url_to_path(&self.root_dir, url); - if let Err(e) = tokio::fs::remove_file(path).await { - Err(e.into()) - } else { - // TODO check for dangling references in the channel list - Ok(()) - } - } - - #[tracing::instrument(skip(self))] - async fn get_setting<S: settings::Setting<'a>, 'a>(&self, user: &'_ str) -> Result<S> { - debug!("User for getting settings: {}", user); - let mut path = relative_path::RelativePathBuf::new(); - path.push(user); - path.push("settings"); - - let path = path.to_path(&self.root_dir); - debug!("Getting settings from {:?}", &path); - - let mut file = File::open(path).await?; - let mut content = String::new(); - file.read_to_string(&mut content).await?; - - let settings: HashMap<&str, serde_json::Value> = serde_json::from_str(&content)?; - match settings.get(S::ID) { - Some(value) => Ok(serde_json::from_value::<S>(value.clone())?), - None => Err(StorageError::from_static(ErrorKind::Backend, "Setting not set")) - } - } - - #[tracing::instrument(skip(self))] - async fn set_setting<S: settings::Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { - let mut path = relative_path::RelativePathBuf::new(); - path.push(user); - path.push("settings"); - - let path = path.to_path(&self.root_dir); - let temppath = path.with_extension("tmp"); - - let parent = path.parent().unwrap().to_owned(); - tokio::fs::create_dir_all(&parent).await?; - - let mut tempfile = OpenOptions::new() - .write(true) - .create_new(true) - .open(&temppath) - .await?; - - let mut settings: HashMap<String, serde_json::Value> 
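The pagination comment above notes that `skip_while` would re-emit the `after` marker itself, hence the imperative consume-then-break loop. The `skip_until`-style combinator it imagines might look like this (a hypothetical helper, not part of the codebase):

```rust
// Advance `iter` past the first occurrence of `after`, so iteration
// resumes strictly after the marker; a no-op when no marker is given.
fn skip_until_after<I>(mut iter: I, after: Option<&str>) -> I
where
    I: Iterator<Item = String>,
{
    if let Some(after) = after {
        for item in iter.by_ref() {
            if item == after {
                break;
            }
        }
    }
    iter
}

fn main() {
    let urls = vec!["a".to_string(), "b".into(), "c".into()];
    let rest: Vec<String> = skip_until_after(urls.into_iter(), Some("b")).collect();
    assert_eq!(rest, vec!["c".to_string()]);
}
```

After the skip, `take(limit)` yields the next page, and the last item's `uid` becomes the cursor for the following request.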
= match File::open(&path).await { - Ok(mut f) => { - let mut content = String::new(); - f.read_to_string(&mut content).await?; - if content.is_empty() { - Default::default() - } else { - serde_json::from_str(&content)? - } - } - Err(err) => { - if err.kind() == IOErrorKind::NotFound { - Default::default() - } else { - return Err(err.into()); - } - } - }; - settings.insert(S::ID.to_owned(), serde_json::to_value(S::new(value))?); - - tempfile - .write_all(serde_json::to_string(&settings)?.as_bytes()) - .await?; - tempfile.flush().await?; - tempfile.sync_all().await?; - drop(tempfile); - tokio::fs::rename(temppath, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; - Ok(()) - } - - #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { - let path = url_to_path(&self.root_dir, target); - let tempfilename = path.with_extension("tmp"); - - let mut temp = OpenOptions::new() - .write(true) - .create_new(true) - .open(&tempfilename) - .await?; - let mut file = OpenOptions::new().read(true).open(&path).await?; - - let mut post: serde_json::Value = { - let mut content = String::new(); - file.read_to_string(&mut content).await?; - drop(file); - - serde_json::from_str(&content)? - }; - - let key: &'static str = match mention_type { - MentionType::Reply => "comment", - MentionType::Like => "like", - MentionType::Repost => "repost", - MentionType::Bookmark => "bookmark", - MentionType::Mention => "mention", - }; - let mention_uid = mention["properties"]["uid"][0].clone(); - if let Some(values) = post["properties"][key].as_array_mut() { - for value in values.iter_mut() { - if value["properties"]["uid"][0] == mention_uid { - *value = mention; - break; - } - } - } else { - post["properties"][key] = serde_json::Value::Array(vec![mention]); - } - - temp.write_all(post.to_string().as_bytes()).await?; - temp.flush().await?; - temp.sync_all().await?; - drop(temp); - tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; - - Ok(()) - } -} diff --git a/kittybox-rs/src/database/memory.rs b/kittybox-rs/src/database/memory.rs deleted file mode 100644 index 6339e7a..0000000 --- a/kittybox-rs/src/database/memory.rs +++ /dev/null @@ -1,249 +0,0 @@ -#![allow(clippy::todo)] -use async_trait::async_trait; -use futures_util::FutureExt; -use serde_json::json; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; - -use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError}; - -#[derive(Clone, Debug)] -pub struct MemoryStorage { - pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>, - pub channels: Arc<RwLock<HashMap<String, Vec<String>>>>, -} - -#[async_trait] -impl Storage for MemoryStorage { - async fn post_exists(&self, url: &str) -> Result<bool> { - return Ok(self.mapping.read().await.contains_key(url)); - } - - async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { - let mapping = self.mapping.read().await; - match mapping.get(url) { - Some(val) => { - if let Some(new_url) = val["see_other"].as_str() { - match mapping.get(new_url) { - Some(val) => Ok(Some(val.clone())), - None => { - drop(mapping); - self.mapping.write().await.remove(url); - Ok(None) - } - } - } else { - Ok(Some(val.clone())) - } - } - _ => Ok(None), - } - } - - async fn put_post(&self, post: &'_ serde_json::Value, _user: &'_ str) -> Result<()> { - let mapping 
= &mut self.mapping.write().await; - let key: &str = match post["properties"]["uid"][0].as_str() { - Some(uid) => uid, - None => { - return Err(StorageError::from_static( - ErrorKind::Other, - "post doesn't have a UID", - )) - } - }; - mapping.insert(key.to_string(), post.clone()); - if post["properties"]["url"].is_array() { - for url in post["properties"]["url"] - .as_array() - .unwrap() - .iter() - .map(|i| i.as_str().unwrap().to_string()) - { - if url != key { - mapping.insert(url, json!({ "see_other": key })); - } - } - } - if post["type"] - .as_array() - .unwrap() - .iter() - .any(|i| i == "h-feed") - { - // This is a feed. Add it to the channels array if it's not already there. - println!("{:#}", post); - self.channels - .write() - .await - .entry( - post["properties"]["author"][0] - .as_str() - .unwrap() - .to_string(), - ) - .or_insert_with(Vec::new) - .push(key.to_string()) - } - Ok(()) - } - - async fn update_post(&self, url: &'_ str, update: crate::micropub::MicropubUpdate) -> Result<()> { - let mut guard = self.mapping.write().await; - let mut post = guard.get_mut(url).ok_or(StorageError::from_static(ErrorKind::NotFound, "The specified post wasn't found in the database."))?; - - use crate::micropub::MicropubPropertyDeletion; - - let mut add_keys: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); - let mut remove_keys: Vec<String> = vec![]; - let mut remove_values: HashMap<String, Vec<serde_json::Value>> = HashMap::new(); - - if let Some(MicropubPropertyDeletion::Properties(delete)) = update.delete { - remove_keys.extend(delete.iter().cloned()); - } else if let Some(MicropubPropertyDeletion::Values(delete)) = update.delete { - for (k, v) in delete { - remove_values - .entry(k.to_string()) - .or_default() - .extend(v.clone()); - } - } - if let Some(add) = update.add { - for (k, v) in add { - add_keys.insert(k.to_string(), v.clone()); - } - } - if let Some(replace) = update.replace { - for (k, v) in replace { - remove_keys.push(k.to_string()); - add_keys.insert(k.to_string(), v.clone()); - } - } - - if let Some(props) = post["properties"].as_object_mut() { - for k in remove_keys { - props.remove(&k); - } - } - for (k, v) in remove_values { - let k = &k; - let props = if k == "children" { - &mut post - } else { - &mut post["properties"] - }; - v.iter().for_each(|v| { - if let Some(vec) = props[k].as_array_mut() { - if let Some(index) = vec.iter().position(|w| w == v) { - vec.remove(index); - } - } - }); - } - for (k, v) in add_keys { - tracing::debug!("Adding k/v to post: {} => {:?}", k, v); - let props = if k == "children" { - &mut post - } else { - &mut post["properties"] - }; - if let Some(prop) = props[&k].as_array_mut() { - if k == "children" { - v.into_iter().rev().for_each(|v| prop.insert(0, v)); - } else { - prop.extend(v.into_iter()); - } - } else { - props[&k] = serde_json::Value::Array(v) - } - } - - Ok(()) - } - - async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { - match self.channels.read().await.get(user) { - Some(channels) => Ok(futures_util::future::join_all( - channels - .iter() - .map(|channel| { - self.get_post(channel).map(|result| result.unwrap()).map( - |post: Option<serde_json::Value>| { - post.map(|post| MicropubChannel { - uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), - name: post["properties"]["name"][0] - .as_str() - .unwrap() - .to_string(), - }) - }, - ) - }) - .collect::<Vec<_>>(), - ) - .await - .into_iter() - .flatten() - .collect::<Vec<_>>()), - None => Ok(vec![]), - } - } - - 
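In the in-memory backend above, alternative URLs are stored as small `see_other` stubs pointing at the canonical key, and `get_post` chases that indirection. A compact sketch of the lookup:

```rust
use std::collections::HashMap;
use serde_json::json;

// Resolve a URL to a post, following one level of `see_other`
// indirection, as the in-memory backend above does.
fn resolve<'a>(
    mapping: &'a HashMap<String, serde_json::Value>,
    url: &str,
) -> Option<&'a serde_json::Value> {
    let value = mapping.get(url)?;
    match value["see_other"].as_str() {
        Some(canonical) => mapping.get(canonical),
        None => Some(value),
    }
}

fn main() {
    let mut posts = HashMap::new();
    posts.insert(
        "https://example.com/posts/hello".to_string(),
        json!({"type": ["h-entry"], "properties": {}}),
    );
    posts.insert(
        "https://example.com/posts/alias".to_string(),
        json!({"see_other": "https://example.com/posts/hello"}),
    );
    assert!(resolve(&posts, "https://example.com/posts/alias").is_some());
}
```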
#[allow(unused_variables)] - async fn read_feed_with_limit( - &self, - url: &'_ str, - after: &'_ Option<String>, - limit: usize, - user: &'_ Option<String>, - ) -> Result<Option<serde_json::Value>> { - todo!() - } - - #[allow(unused_variables)] - async fn read_feed_with_cursor( - &self, - url: &'_ str, - cursor: Option<&'_ str>, - limit: usize, - user: Option<&'_ str> - ) -> Result<Option<(serde_json::Value, Option<String>)>> { - todo!() - } - - async fn delete_post(&self, url: &'_ str) -> Result<()> { - self.mapping.write().await.remove(url); - Ok(()) - } - - #[allow(unused_variables)] - async fn get_setting<S: settings::Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S> { - todo!() - } - - #[allow(unused_variables)] - async fn set_setting<S: settings::Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { - todo!() - } - - #[allow(unused_variables)] - async fn add_or_update_webmention(&self, target: &str, mention_type: kittybox_util::MentionType, mention: serde_json::Value) -> Result<()> { - todo!() - } - -} - -impl Default for MemoryStorage { - fn default() -> Self { - Self::new() - } -} - -impl MemoryStorage { - pub fn new() -> Self { - Self { - mapping: Arc::new(RwLock::new(HashMap::new())), - channels: Arc::new(RwLock::new(HashMap::new())), - } - } -} diff --git a/kittybox-rs/src/database/mod.rs b/kittybox-rs/src/database/mod.rs deleted file mode 100644 index b4b70b2..0000000 --- a/kittybox-rs/src/database/mod.rs +++ /dev/null @@ -1,793 +0,0 @@ -#![warn(missing_docs)] -use std::borrow::Cow; - -use async_trait::async_trait; -use kittybox_util::MentionType; - -mod file; -pub use crate::database::file::FileStorage; -use crate::micropub::MicropubUpdate; -#[cfg(feature = "postgres")] -mod postgres; -#[cfg(feature = "postgres")] -pub use postgres::PostgresStorage; - -#[cfg(test)] -mod memory; -#[cfg(test)] -pub use crate::database::memory::MemoryStorage; - -pub use kittybox_util::MicropubChannel; - -use self::settings::Setting; - -/// Enum representing different errors that might occur during the database query. -#[derive(Debug, Clone, Copy)] -pub enum ErrorKind { - /// Backend error (e.g. database connection error) - Backend, - /// Error due to insufficient contextual permissions for the query - PermissionDenied, - /// Error due to the database being unable to parse JSON returned from the backing storage. - /// Usually indicative of someone fiddling with the database manually instead of using proper tools. - JsonParsing, - /// - ErrorKind::NotFound - equivalent to a 404 error. Note, some requests return an Option, - /// in which case None is also equivalent to a 404. - NotFound, - /// The user's query or request to the database was malformed. Used whenever the database processes - /// the user's query directly, such as when editing posts inside of the database (e.g. Redis backend) - BadRequest, - /// the user's query collided with an in-flight request and needs to be retried - Conflict, - /// - ErrorKind::Other - when something so weird happens that it becomes undescribable. - Other, -} - -/// Settings that can be stored in the database. -pub mod settings { - mod private { - pub trait Sealed {} - } - - /// A trait for various settings that should be contained here. - /// - /// **Note**: this trait is sealed to prevent external - /// implementations, as it wouldn't make sense to add new settings - /// that aren't used by Kittybox itself. 
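The doc comment above says the `Setting` trait is sealed. The underlying pattern, reduced to its core with illustrative names, is a public trait whose supertrait downstream crates cannot name:

```rust
// The sealed-trait pattern: a public trait with a pub-in-private
// supertrait. Downstream crates can see and use `Setting`, but cannot
// name `private::Sealed`, so they cannot implement `Setting` themselves.
mod settings {
    mod private {
        pub trait Sealed {}
    }

    pub trait Setting: private::Sealed {
        const ID: &'static str;
    }

    pub struct SiteName(pub String);
    impl private::Sealed for SiteName {}
    impl Setting for SiteName {
        const ID: &'static str = "site_name";
    }
}

fn main() {
    println!("{}", <settings::SiteName as settings::Setting>::ID);
}
```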
- pub trait Setting<'de>: private::Sealed + std::fmt::Debug + Default + Clone + serde::Serialize + serde::de::DeserializeOwned + /*From<Settings> +*/ Send + Sync { - type Data: std::fmt::Debug + Send + Sync; - const ID: &'static str; - - /// Unwrap the setting type, returning owned data contained within. - fn into_inner(self) -> Self::Data; - /// Create a new instance of this type containing certain data. - fn new(data: Self::Data) -> Self; - } - - /// A website's title, shown in the header. - #[derive(Debug, serde::Deserialize, serde::Serialize, Clone, PartialEq, Eq)] - pub struct SiteName(String); - impl Default for SiteName { - fn default() -> Self { - Self("Kittybox".to_string()) - } - } - impl AsRef<str> for SiteName { - fn as_ref(&self) -> &str { - self.0.as_str() - } - } - impl private::Sealed for SiteName {} - impl Setting<'_> for SiteName { - type Data = String; - const ID: &'static str = "site_name"; - - fn into_inner(self) -> String { - self.0 - } - fn new(data: Self::Data) -> Self { - Self(data) - } - } - impl SiteName { - fn from_str(data: &str) -> Self { - Self(data.to_owned()) - } - } - - /// Participation status in the IndieWeb Webring: https://🕸💍.ws/dashboard - #[derive(Debug, Default, serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq)] - pub struct Webring(bool); - impl private::Sealed for Webring {} - impl Setting<'_> for Webring { - type Data = bool; - const ID: &'static str = "webring"; - - fn into_inner(self) -> Self::Data { - self.0 - } - - fn new(data: Self::Data) -> Self { - Self(data) - } - } -} - -/// Error signalled from the database. -#[derive(Debug)] -pub struct StorageError { - msg: std::borrow::Cow<'static, str>, - source: Option<Box<dyn std::error::Error + Send + Sync>>, - kind: ErrorKind, -} - -impl std::error::Error for StorageError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source - .as_ref() - .map(|e| e.as_ref() as &dyn std::error::Error) - } -} -impl From<serde_json::Error> for StorageError { - fn from(err: serde_json::Error) -> Self { - Self { - msg: std::borrow::Cow::Owned(format!("{}", err)), - source: Some(Box::new(err)), - kind: ErrorKind::JsonParsing, - } - } -} -impl std::fmt::Display for StorageError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}: {}", - match self.kind { - ErrorKind::Backend => "backend error", - ErrorKind::JsonParsing => "JSON parsing error", - ErrorKind::PermissionDenied => "permission denied", - ErrorKind::NotFound => "not found", - ErrorKind::BadRequest => "bad request", - ErrorKind::Conflict => "conflict with an in-flight request or existing data", - ErrorKind::Other => "generic storage layer error", - }, - self.msg - ) - } -} -impl serde::Serialize for StorageError { - fn serialize<S: serde::Serializer>( - &self, - serializer: S, - ) -> std::result::Result<S::Ok, S::Error> { - serializer.serialize_str(&self.to_string()) - } -} -impl StorageError { - /// Create a new StorageError of an ErrorKind with a message. - pub fn new(kind: ErrorKind, msg: String) -> Self { - Self { - msg: Cow::Owned(msg), - source: None, - kind, - } - } - /// Create a new StorageError of an ErrorKind with a message from - /// a static string. - /// - /// This saves an allocation for a new string and is the preferred - /// way in case the error message doesn't change. 
- pub fn from_static(kind: ErrorKind, msg: &'static str) -> Self { - Self { - msg: Cow::Borrowed(msg), - source: None, - kind - } - } - /// Create a StorageError using another arbitrary Error as a source. - pub fn with_source( - kind: ErrorKind, - msg: std::borrow::Cow<'static, str>, - source: Box<dyn std::error::Error + Send + Sync>, - ) -> Self { - Self { - msg, - source: Some(source), - kind, - } - } - /// Get the kind of an error. - pub fn kind(&self) -> ErrorKind { - self.kind - } - /// Get the message as a string slice. - pub fn msg(&self) -> &str { - &self.msg - } -} - -/// A special Result type for the Micropub backing storage. -pub type Result<T> = std::result::Result<T, StorageError>; - -/// A storage backend for the Micropub server. -/// -/// Implementations should note that all methods listed on this trait MUST be fully atomic -/// or lock the database so that write conflicts or reading half-written data should not occur. -#[async_trait] -pub trait Storage: std::fmt::Debug + Clone + Send + Sync { - /// Check if a post exists in the database. - async fn post_exists(&self, url: &str) -> Result<bool>; - - /// Load a post from the database in MF2-JSON format, deserialized from JSON. - async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>>; - - /// Save a post to the database as an MF2-JSON structure. - /// - /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. - async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()>; - - /// Add post to feed. Some database implementations might have optimized ways to do this. - #[tracing::instrument(skip(self))] - async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { - tracing::debug!("Inserting {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"add": {"children": [post]}})).unwrap() - ).await - } - /// Remove post from feed. Some database implementations might have optimized ways to do this. - #[tracing::instrument(skip(self))] - async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { - tracing::debug!("Removing {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"delete": {"children": [post]}})).unwrap() - ).await - } - - /// Modify a post using an update object as defined in the - /// Micropub spec. - /// - /// Note to implementors: the update operation MUST be atomic and - /// SHOULD lock the database to prevent two clients overwriting - /// each other's changes or simply corrupting something. Rejecting - /// is allowed in case of concurrent updates if waiting for a lock - /// cannot be done. - async fn update_post(&self, url: &str, update: MicropubUpdate) -> Result<()>; - - /// Get a list of channels available for the user represented by - /// the `user` domain to write to. - async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>>; - - /// Fetch a feed at `url` and return an h-feed object containing - /// `limit` posts after a post by url `after`, filtering the content - /// in context of a user specified by `user` (or an anonymous user). - /// - /// This method MUST hydrate the `author` property with an h-card - /// from the database by replacing URLs with corresponding h-cards. 
- /// - /// When encountering posts which the `user` is not authorized to - /// access, this method MUST elide such posts (as an optimization - /// for the frontend) and not return them, but still return up to - /// `limit` posts (to not reveal the hidden posts' presence). - /// - /// Note for implementors: if you use streams to fetch posts in - /// parallel from the database, preferably make this method use a - /// connection pool to reduce overhead of creating a database - /// connection per post for parallel fetching. - async fn read_feed_with_limit( - &self, - url: &'_ str, - after: &'_ Option<String>, - limit: usize, - user: &'_ Option<String>, - ) -> Result<Option<serde_json::Value>>; - - /// Fetch a feed at `url` and return an h-feed object containing - /// `limit` posts after a `cursor` (filtering the content in - /// context of a user specified by `user`, or an anonymous user), - /// as well as a new cursor to paginate with. - /// - /// This method MUST hydrate the `author` property with an h-card - /// from the database by replacing URLs with corresponding h-cards. - /// - /// When encountering posts which the `user` is not authorized to - /// access, this method MUST elide such posts (as an optimization - /// for the frontend) and not return them, but still return an - /// amount of posts as close to `limit` as possible (to avoid - /// revealing the existence of the hidden post). - /// - /// Note for implementors: if you use streams to fetch posts in - /// parallel from the database, preferably make this method use a - /// connection pool to reduce overhead of creating a database - /// connection per post for parallel fetching. - async fn read_feed_with_cursor( - &self, - url: &'_ str, - cursor: Option<&'_ str>, - limit: usize, - user: Option<&'_ str> - ) -> Result<Option<(serde_json::Value, Option<String>)>>; - - /// Deletes a post from the database irreversibly. Must be idempotent. - async fn delete_post(&self, url: &'_ str) -> Result<()>; - - /// Gets a setting from the setting store and passes the result. - async fn get_setting<S: Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S>; - - /// Commits a setting to the setting store. - async fn set_setting<S: Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()>; - - /// Add (or update) a webmention on a certian post. - /// - /// The MF2 object describing the webmention content will always - /// be of type `h-cite`, and the `uid` property on the object will - /// always be set. - /// - /// The rationale for this function is as follows: webmentions - /// might be duplicated, and we need to deduplicate them first. As - /// we lack support for transactions and locking posts on the - /// database, the only way is to implement the operation on the - /// database itself. - /// - /// Besides, it may even allow for nice tricks like storing the - /// webmentions separately and rehydrating them on feed reads. 
- async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()>; -} - -#[cfg(test)] -mod tests { - use super::settings; - - use super::{MicropubChannel, Storage}; - use kittybox_util::MentionType; - use serde_json::json; - - async fn test_basic_operations<Backend: Storage>(backend: Backend) { - let post: serde_json::Value = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Test content"], - "author": ["https://fireburn.ru/"], - "uid": ["https://fireburn.ru/posts/hello"], - "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] - } - }); - let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); - let alt_url = post["properties"]["url"][1].as_str().unwrap().to_string(); - - // Reading and writing - backend - .put_post(&post, "fireburn.ru") - .await - .unwrap(); - if let Some(returned_post) = backend.get_post(&key).await.unwrap() { - assert!(returned_post.is_object()); - assert_eq!( - returned_post["type"].as_array().unwrap().len(), - post["type"].as_array().unwrap().len() - ); - assert_eq!( - returned_post["type"].as_array().unwrap(), - post["type"].as_array().unwrap() - ); - let props: &serde_json::Map<String, serde_json::Value> = - post["properties"].as_object().unwrap(); - for key in props.keys() { - assert_eq!( - returned_post["properties"][key].as_array().unwrap(), - post["properties"][key].as_array().unwrap() - ) - } - } else { - panic!("For some reason the backend did not return the post.") - } - // Check the alternative URL - it should return the same post - if let Ok(Some(returned_post)) = backend.get_post(&alt_url).await { - assert!(returned_post.is_object()); - assert_eq!( - returned_post["type"].as_array().unwrap().len(), - post["type"].as_array().unwrap().len() - ); - assert_eq!( - returned_post["type"].as_array().unwrap(), - post["type"].as_array().unwrap() - ); - let props: &serde_json::Map<String, serde_json::Value> = - post["properties"].as_object().unwrap(); - for key in props.keys() { - assert_eq!( - returned_post["properties"][key].as_array().unwrap(), - post["properties"][key].as_array().unwrap() - ) - } - } else { - panic!("For some reason the backend did not return the post.") - } - } - - /// Note: this is merely a smoke check and is in no way comprehensive. 
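Both branches of `test_basic_operations` above repeat the same field-by-field comparison; were this kept, it could be folded into a small helper along these lines (a sketch, the name is invented):

fn assert_same_mf2(returned: &serde_json::Value, expected: &serde_json::Value) {
    // The type arrays must match exactly.
    assert_eq!(
        returned["type"].as_array().unwrap(),
        expected["type"].as_array().unwrap()
    );
    // Every property on the expected post must round-trip unchanged.
    for key in expected["properties"].as_object().unwrap().keys() {
        assert_eq!(
            returned["properties"][key].as_array().unwrap(),
            expected["properties"][key].as_array().unwrap()
        );
    }
}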
- // TODO updates for feeds must update children using special logic - async fn test_update<Backend: Storage>(backend: Backend) { - let post: serde_json::Value = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Test content"], - "author": ["https://fireburn.ru/"], - "uid": ["https://fireburn.ru/posts/hello"], - "url": ["https://fireburn.ru/posts/hello", "https://fireburn.ru/posts/test"] - } - }); - let key = post["properties"]["uid"][0].as_str().unwrap().to_string(); - - // Reading and writing - backend - .put_post(&post, "fireburn.ru") - .await - .unwrap(); - - backend - .update_post( - &key, - serde_json::from_value(json!({ - "url": &key, - "add": { - "category": ["testing"], - }, - "replace": { - "content": ["Different test content"] - } - })).unwrap(), - ) - .await - .unwrap(); - - match backend.get_post(&key).await { - Ok(Some(returned_post)) => { - assert!(returned_post.is_object()); - assert_eq!( - returned_post["type"].as_array().unwrap().len(), - post["type"].as_array().unwrap().len() - ); - assert_eq!( - returned_post["type"].as_array().unwrap(), - post["type"].as_array().unwrap() - ); - assert_eq!( - returned_post["properties"]["content"][0].as_str().unwrap(), - "Different test content" - ); - assert_eq!( - returned_post["properties"]["category"].as_array().unwrap(), - &vec![json!("testing")] - ); - } - something_else => { - something_else - .expect("Shouldn't error") - .expect("Should have the post"); - } - } - } - - async fn test_get_channel_list<Backend: Storage>(backend: Backend) { - let feed = json!({ - "type": ["h-feed"], - "properties": { - "name": ["Main Page"], - "author": ["https://fireburn.ru/"], - "uid": ["https://fireburn.ru/feeds/main"] - }, - "children": [] - }); - backend - .put_post(&feed, "fireburn.ru") - .await - .unwrap(); - let chans = backend.get_channels("fireburn.ru").await.unwrap(); - assert_eq!(chans.len(), 1); - assert_eq!( - chans[0], - MicropubChannel { - uid: "https://fireburn.ru/feeds/main".to_string(), - name: "Main Page".to_string() - } - ); - } - - async fn test_settings<Backend: Storage>(backend: Backend) { - backend - .set_setting::<settings::SiteName>( - "https://fireburn.ru/", - "Vika's Hideout".to_owned() - ) - .await - .unwrap(); - assert_eq!( - backend - .get_setting::<settings::SiteName>("https://fireburn.ru/") - .await - .unwrap() - .as_ref(), - "Vika's Hideout" - ); - } - - fn gen_random_post(domain: &str) -> serde_json::Value { - use faker_rand::lorem::{Paragraphs, Word}; - - let uid = format!( - "https://{domain}/posts/{}-{}-{}", - rand::random::<Word>(), - rand::random::<Word>(), - rand::random::<Word>() - ); - - let time = chrono::Local::now().to_rfc3339(); - let post = json!({ - "type": ["h-entry"], - "properties": { - "content": [rand::random::<Paragraphs>().to_string()], - "uid": [&uid], - "url": [&uid], - "published": [&time] - } - }); - - post - } - - fn gen_random_mention(domain: &str, mention_type: MentionType, url: &str) -> serde_json::Value { - use faker_rand::lorem::{Paragraphs, Word}; - - let uid = format!( - "https://{domain}/posts/{}-{}-{}", - rand::random::<Word>(), - rand::random::<Word>(), - rand::random::<Word>() - ); - - let time = chrono::Local::now().to_rfc3339(); - let post = json!({ - "type": ["h-cite"], - "properties": { - "content": [rand::random::<Paragraphs>().to_string()], - "uid": [&uid], - "url": [&uid], - "published": [&time], - (match mention_type { - MentionType::Reply => "in-reply-to", - MentionType::Like => "like-of", - MentionType::Repost => "repost-of", - MentionType::Bookmark => 
"bookmark-of", - MentionType::Mention => unimplemented!(), - }): [url] - } - }); - - post - } - - async fn test_feed_pagination<Backend: Storage>(backend: Backend) { - let posts = { - let mut posts = std::iter::from_fn( - || Some(gen_random_post("fireburn.ru")) - ) - .take(40) - .collect::<Vec<serde_json::Value>>(); - - // Reverse the array so it's in reverse-chronological order - posts.reverse(); - - posts - }; - - let feed = json!({ - "type": ["h-feed"], - "properties": { - "name": ["Main Page"], - "author": ["https://fireburn.ru/"], - "uid": ["https://fireburn.ru/feeds/main"] - }, - }); - let key = feed["properties"]["uid"][0].as_str().unwrap(); - - backend - .put_post(&feed, "fireburn.ru") - .await - .unwrap(); - - for (i, post) in posts.iter().rev().enumerate() { - backend - .put_post(post, "fireburn.ru") - .await - .unwrap(); - backend.add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()).await.unwrap(); - } - - let limit: usize = 10; - - tracing::debug!("Starting feed reading..."); - let (result, cursor) = backend - .read_feed_with_cursor(key, None, limit, None) - .await - .unwrap() - .unwrap(); - - assert_eq!(result["children"].as_array().unwrap().len(), limit); - assert_eq!( - result["children"] - .as_array() - .unwrap() - .iter() - .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10], - posts - .iter() - .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10] - ); - - tracing::debug!("Continuing with cursor: {:?}", cursor); - let (result2, cursor2) = backend - .read_feed_with_cursor( - key, - cursor.as_deref(), - limit, - None, - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - result2["children"].as_array().unwrap()[0..10], - posts[10..20] - ); - - tracing::debug!("Continuing with cursor: {:?}", cursor); - let (result3, cursor3) = backend - .read_feed_with_cursor( - key, - cursor2.as_deref(), - limit, - None, - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - result3["children"].as_array().unwrap()[0..10], - posts[20..30] - ); - - tracing::debug!("Continuing with cursor: {:?}", cursor); - let (result4, _) = backend - .read_feed_with_cursor( - key, - cursor3.as_deref(), - limit, - None, - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - result4["children"].as_array().unwrap()[0..10], - posts[30..40] - ); - - // Regression test for #4 - // - // Results for a bogus cursor are undefined, so we aren't - // checking them. But the function at least shouldn't hang. 
- let nonsense_after = Some("1010101010"); - let _ = tokio::time::timeout(tokio::time::Duration::from_secs(10), async move { - backend - .read_feed_with_cursor(key, nonsense_after, limit, None) - .await - }) - .await - .expect("Operation should not hang: see https://gitlab.com/kittybox/kittybox/-/issues/4"); - } - - async fn test_webmention_addition<Backend: Storage>(db: Backend) { - let post = gen_random_post("fireburn.ru"); - - db.put_post(&post, "fireburn.ru").await.unwrap(); - const TYPE: MentionType = MentionType::Reply; - - let target = post["properties"]["uid"][0].as_str().unwrap(); - let mut reply = gen_random_mention("aaronparecki.com", TYPE, target); - - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); - assert_eq!(post, read_post); - - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); - - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); - assert_eq!(read_post["properties"]["comment"][0], reply); - - reply["properties"]["content"][0] = json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); - - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); - assert_eq!(read_post["properties"]["comment"][0], reply); - } - - async fn test_pretty_permalinks<Backend: Storage>(db: Backend) { - const PERMALINK: &str = "https://fireburn.ru/posts/pretty-permalink"; - - let post = { - let mut post = gen_random_post("fireburn.ru"); - let urls = post["properties"]["url"].as_array_mut().unwrap(); - urls.push(serde_json::Value::String( - PERMALINK.to_owned() - )); - - post - }; - db.put_post(&post, "fireburn.ru").await.unwrap(); - - for i in post["properties"]["url"].as_array().unwrap() { - let (read_post, _) = db.read_feed_with_cursor(i.as_str().unwrap(), None, 20, None).await.unwrap().unwrap(); - assert_eq!(read_post, post); - } - } - /// Automatically generates a test suite for - macro_rules! test_all { - ($func_name:ident, $mod_name:ident) => { - mod $mod_name { - $func_name!(test_basic_operations); - $func_name!(test_get_channel_list); - $func_name!(test_settings); - $func_name!(test_update); - $func_name!(test_feed_pagination); - $func_name!(test_webmention_addition); - $func_name!(test_pretty_permalinks); - } - }; - } - macro_rules! file_test { - ($func_name:ident) => { - #[tokio::test] - #[tracing_test::traced_test] - async fn $func_name() { - let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let backend = super::super::FileStorage::new( - tempdir.path().to_path_buf() - ) - .await - .unwrap(); - super::$func_name(backend).await - } - }; - } - - macro_rules! postgres_test { - ($func_name:ident) => { - #[cfg(feature = "sqlx")] - #[sqlx::test] - #[tracing_test::traced_test] - async fn $func_name( - pool_opts: sqlx::postgres::PgPoolOptions, - connect_opts: sqlx::postgres::PgConnectOptions - ) -> Result<(), sqlx::Error> { - let db = { - //use sqlx::ConnectOptions; - //connect_opts.log_statements(log::LevelFilter::Debug); - - pool_opts.connect_with(connect_opts).await? 
- };
- let backend = super::super::PostgresStorage::from_pool(db).await.unwrap();
-
- Ok(super::$func_name(backend).await)
- }
- };
- }
-
- test_all!(file_test, file);
- test_all!(postgres_test, postgres);
-}
diff --git a/kittybox-rs/src/database/postgres/mod.rs b/kittybox-rs/src/database/postgres/mod.rs
deleted file mode 100644
index 9176d12..0000000
--- a/kittybox-rs/src/database/postgres/mod.rs
+++ /dev/null
@@ -1,416 +0,0 @@
-#![allow(unused_variables)]
-use std::borrow::Cow;
-use std::str::FromStr;
-
-use kittybox_util::{MicropubChannel, MentionType};
-use sqlx::{PgPool, Executor};
-use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion};
-
-use super::settings::Setting;
-use super::{Storage, Result, StorageError, ErrorKind};
-
-static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
-
-impl From<sqlx::Error> for StorageError {
- fn from(value: sqlx::Error) -> Self {
- Self::with_source(
- super::ErrorKind::Backend,
- Cow::Owned(format!("sqlx error: {}", &value)),
- Box::new(value)
- )
- }
-}
-
-impl From<sqlx::migrate::MigrateError> for StorageError {
- fn from(value: sqlx::migrate::MigrateError) -> Self {
- Self::with_source(
- super::ErrorKind::Backend,
- Cow::Owned(format!("sqlx migration error: {}", &value)),
- Box::new(value)
- )
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct PostgresStorage {
- db: PgPool
-}
-
-impl PostgresStorage {
- /// Construct a new [`PostgresStorage`] from a URI string and run
- /// migrations on the database.
- ///
- /// If the `PGPASS_FILE` environment variable is defined, read the
- /// password from the file at the specified path. If, instead,
- /// the `PGPASS` environment variable is present, read the
- /// password from it.
- pub async fn new(uri: &str) -> Result<Self> {
- tracing::debug!("Postgres URL: {uri}");
- let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
- .options([("search_path", "kittybox")]);
- if let Ok(password_file) = std::env::var("PGPASS_FILE") {
- let password = tokio::fs::read_to_string(password_file).await.unwrap();
- options = options.password(&password);
- } else if let Ok(password) = std::env::var("PGPASS") {
- options = options.password(&password)
- }
- Self::from_pool(
- sqlx::postgres::PgPoolOptions::new()
- .max_connections(50)
- .connect_with(options)
- .await?
- ).await
-
- }
-
- /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`],
- /// running appropriate migrations.
- pub async fn from_pool(db: sqlx::PgPool) -> Result<Self> {
- db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?;
- MIGRATOR.run(&db).await?;
- Ok(Self { db })
- }
-}
-
-#[async_trait::async_trait]
-impl Storage for PostgresStorage {
- #[tracing::instrument(skip(self))]
- async fn post_exists(&self, url: &str) -> Result<bool> {
- sqlx::query_as::<_, (bool,)>("SELECT exists(SELECT 1 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1)")
- .bind(url)
- .fetch_one(&self.db)
- .await
- .map(|v| v.0)
- .map_err(|err| err.into())
- }
-
- #[tracing::instrument(skip(self))]
- async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
- sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? 
$1") - .bind(url) - .fetch_optional(&self.db) - .await - .map(|v| v.map(|v| v.0)) - .map_err(|err| err.into()) - - } - - #[tracing::instrument(skip(self))] - async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { - tracing::debug!("New post: {}", post); - sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)") - .bind(post) - .bind(user) - .execute(&self.db) - .await - .map(|_| ()) - .map_err(Into::into) - } - - #[tracing::instrument(skip(self))] - async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { - tracing::debug!("Inserting {} into {}", post, feed); - sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") - .bind(feed) - .bind(post) - .execute(&self.db) - .await - .map(|_| ()) - .map_err(Into::into) - } - - #[tracing::instrument(skip(self))] - async fn remove_from_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { - sqlx::query("DELETE FROM kittybox.children WHERE parent = $1 AND child = $2") - .bind(feed) - .bind(post) - .execute(&self.db) - .await - .map_err(Into::into) - .map(|_| ()) - } - - #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { - let mut txn = self.db.begin().await?; - - let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") - .bind(target) - .fetch_optional(&mut *txn) - .await? - .ok_or(StorageError::from_static( - ErrorKind::NotFound, - "The specified post wasn't found in the database." - ))?; - - tracing::debug!("Loaded post for target {} with uid {}", target, uid); - - let key: &'static str = match mention_type { - MentionType::Reply => "comment", - MentionType::Like => "like", - MentionType::Repost => "repost", - MentionType::Bookmark => "bookmark", - MentionType::Mention => "mention", - }; - - tracing::debug!("Mention type -> key: {}", key); - - let mention_uid = mention["properties"]["uid"][0].clone(); - if let Some(values) = post["properties"][key].as_array_mut() { - for value in values.iter_mut() { - if value["properties"]["uid"][0] == mention_uid { - *value = mention; - break; - } - } - } else { - post["properties"][key] = serde_json::Value::Array(vec![mention]); - } - - sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") - .bind(uid) - .bind(post) - .execute(&mut *txn) - .await?; - - txn.commit().await.map_err(Into::into) - } - #[tracing::instrument(skip(self))] - async fn update_post(&self, url: &'_ str, update: MicropubUpdate) -> Result<()> { - tracing::debug!("Updating post {}", url); - let mut txn = self.db.begin().await?; - let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") - .bind(url) - .fetch_optional(&mut *txn) - .await? - .ok_or(StorageError::from_static( - ErrorKind::NotFound, - "The specified post wasn't found in the database." 
- ))?; - - if let Some(MicropubPropertyDeletion::Properties(ref delete)) = update.delete { - if let Some(props) = post["properties"].as_object_mut() { - for key in delete { - props.remove(key); - } - } - } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = update.delete { - if let Some(props) = post["properties"].as_object_mut() { - for (key, values) in delete { - if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { - prop.retain(|v| { values.iter().all(|i| i != v) }) - } - } - } - } - if let Some(replace) = update.replace { - if let Some(props) = post["properties"].as_object_mut() { - for (key, value) in replace { - props.insert(key, serde_json::Value::Array(value)); - } - } - } - if let Some(add) = update.add { - if let Some(props) = post["properties"].as_object_mut() { - for (key, value) in add { - if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { - prop.extend_from_slice(value.as_slice()); - } else { - props.insert(key, serde_json::Value::Array(value)); - } - } - } - } - - sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1") - .bind(uid) - .bind(post) - .execute(&mut *txn) - .await?; - - txn.commit().await.map_err(Into::into) - } - - #[tracing::instrument(skip(self))] - async fn get_channels(&self, user: &'_ str) -> Result<Vec<MicropubChannel>> { - /*sqlx::query_as::<_, MicropubChannel>("SELECT name, uid FROM kittybox.channels WHERE owner = $1") - .bind(user) - .fetch_all(&self.db) - .await - .map_err(|err| err.into())*/ - sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#) - .bind(user) - .fetch_all(&self.db) - .await - .map_err(|err| err.into()) - } - - #[tracing::instrument(skip(self))] - async fn read_feed_with_limit( - &self, - url: &'_ str, - after: &'_ Option<String>, - limit: usize, - user: &'_ Option<String>, - ) -> Result<Option<serde_json::Value>> { - let mut feed = match sqlx::query_as::<_, (serde_json::Value,)>(" -SELECT jsonb_set( - mf2, - '{properties,author,0}', - (SELECT mf2 FROM kittybox.mf2_json - WHERE uid = mf2 #>> '{properties,author,0}') -) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 -") - .bind(url) - .fetch_optional(&self.db) - .await? - .map(|v| v.0) - { - Some(feed) => feed, - None => return Ok(None) - }; - - let posts: Vec<String> = { - let mut posts_iter = feed["children"] - .as_array() - .cloned() - .unwrap_or_default() - .into_iter() - .map(|s| s.as_str().unwrap().to_string()); - if let Some(after) = after { - for s in posts_iter.by_ref() { - if &s == after { - break; - } - } - }; - - posts_iter.take(limit).collect::<Vec<_>>() - }; - feed["children"] = serde_json::Value::Array( - sqlx::query_as::<_, (serde_json::Value,)>(" -SELECT jsonb_set( - mf2, - '{properties,author,0}', - (SELECT mf2 FROM kittybox.mf2_json - WHERE uid = mf2 #>> '{properties,author,0}') -) FROM kittybox.mf2_json -WHERE uid = ANY($1) -ORDER BY mf2 #>> '{properties,published,0}' DESC -") - .bind(&posts[..]) - .fetch_all(&self.db) - .await? 
- .into_iter() - .map(|v| v.0) - .collect::<Vec<_>>() - ); - - Ok(Some(feed)) - - } - - #[tracing::instrument(skip(self))] - async fn read_feed_with_cursor( - &self, - url: &'_ str, - cursor: Option<&'_ str>, - limit: usize, - user: Option<&'_ str> - ) -> Result<Option<(serde_json::Value, Option<String>)>> { - let mut txn = self.db.begin().await?; - sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") - .execute(&mut *txn) - .await?; - tracing::debug!("Started txn: {:?}", txn); - let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" -SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 -") - .bind(url) - .fetch_optional(&mut *txn) - .await? - { - Some(feed) => feed, - None => return Ok(None) - }; - - // Don't query for children if this isn't a feed. - // - // The second query is very long and will probably be extremely - // expensive. It's best to skip it on types where it doesn't make sense - // (Kittybox doesn't support rendering children on non-feeds) - if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { - return Ok(Some((feed, None))); - } - - feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" -SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json -INNER JOIN kittybox.children -ON mf2_json.uid = children.child -WHERE - children.parent = $1 - AND ( - ( - (mf2 #>> '{properties,visibility,0}') = 'public' - OR - NOT (mf2['properties'] ? 'visibility') - ) - OR - ( - $3 != null AND ( - mf2['properties']['audience'] ? $3 - OR mf2['properties']['author'] ? $3 - ) - ) - ) - AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) -ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC -LIMIT $2" - ) - .bind(url) - .bind(limit as i64) - .bind(user) - .bind(cursor) - .fetch_all(&mut *txn) - .await - .map(serde_json::Value::Array)?; - - let new_cursor = feed["children"].as_array().unwrap() - .last() - .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); - - txn.commit().await?; - - Ok(Some((feed, new_cursor))) - } - - #[tracing::instrument(skip(self))] - async fn delete_post(&self, url: &'_ str) -> Result<()> { - todo!() - } - - #[tracing::instrument(skip(self))] - async fn get_setting<S: Setting<'a>, 'a>(&'_ self, user: &'_ str) -> Result<S> { - match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)") - .bind(user) - .bind(S::ID) - .fetch_one(&self.db) - .await - { - Ok((value,)) => Ok(serde_json::from_value(value)?), - Err(err) => Err(err.into()) - } - } - - #[tracing::instrument(skip(self))] - async fn set_setting<S: Setting<'a> + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { - sqlx::query("SELECT kittybox.set_setting($1, $2, $3)") - .bind(user) - .bind(S::ID) - .bind(serde_json::to_value(S::new(value)).unwrap()) - .execute(&self.db) - .await - .map_err(Into::into) - .map(|_| ()) - } -} diff --git a/kittybox-rs/src/database/redis/edit_post.lua b/kittybox-rs/src/database/redis/edit_post.lua deleted file mode 100644 index a398f8d..0000000 --- a/kittybox-rs/src/database/redis/edit_post.lua +++ /dev/null @@ -1,93 +0,0 @@ -local posts = KEYS[1] -local update_desc = cjson.decode(ARGV[2]) -local post = cjson.decode(redis.call("HGET", posts, ARGV[1])) - -local delete_keys = {} -local delete_kvs = {} -local add_keys = {} - -if update_desc.replace ~= nil then - for k, v in pairs(update_desc.replace) do - table.insert(delete_keys, k) - add_keys[k] = v - end -end -if 
update_desc.delete ~= nil then - if update_desc.delete[0] == nil then - -- Table has string keys. Probably! - for k, v in pairs(update_desc.delete) do - delete_kvs[k] = v - end - else - -- Table has numeric keys. Probably! - for i, v in ipairs(update_desc.delete) do - table.insert(delete_keys, v) - end - end -end -if update_desc.add ~= nil then - for k, v in pairs(update_desc.add) do - add_keys[k] = v - end -end - -for i, v in ipairs(delete_keys) do - post["properties"][v] = nil - -- TODO delete URL links -end - -for k, v in pairs(delete_kvs) do - local index = -1 - if k == "children" then - for j, w in ipairs(post[k]) do - if w == v then - index = j - break - end - end - if index > -1 then - table.remove(post[k], index) - end - else - for j, w in ipairs(post["properties"][k]) do - if w == v then - index = j - break - end - end - if index > -1 then - table.remove(post["properties"][k], index) - -- TODO delete URL links - end - end -end - -for k, v in pairs(add_keys) do - if k == "children" then - if post["children"] == nil then - post["children"] = {} - end - for i, w in ipairs(v) do - table.insert(post["children"], 1, w) - end - else - if post["properties"][k] == nil then - post["properties"][k] = {} - end - for i, w in ipairs(v) do - table.insert(post["properties"][k], w) - end - if k == "url" then - redis.call("HSET", posts, v, cjson.encode({ see_other = post["properties"]["uid"][1] })) - elseif k == "channel" then - local feed = cjson.decode(redis.call("HGET", posts, v)) - table.insert(feed["children"], 1, post["properties"]["uid"][1]) - redis.call("HSET", posts, v, cjson.encode(feed)) - end - end -end - -local encoded = cjson.encode(post) -redis.call("SET", "debug", encoded) -redis.call("HSET", posts, post["properties"]["uid"][1], encoded) -return \ No newline at end of file diff --git a/kittybox-rs/src/database/redis/mod.rs b/kittybox-rs/src/database/redis/mod.rs deleted file mode 100644 index 39ee852..0000000 --- a/kittybox-rs/src/database/redis/mod.rs +++ /dev/null @@ -1,398 +0,0 @@ -use async_trait::async_trait; -use futures::stream; -use futures_util::FutureExt; -use futures_util::StreamExt; -use futures_util::TryStream; -use futures_util::TryStreamExt; -use lazy_static::lazy_static; -use log::error; -use mobc::Pool; -use mobc_redis::redis; -use mobc_redis::redis::AsyncCommands; -use mobc_redis::RedisConnectionManager; -use serde_json::json; -use std::time::Duration; - -use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError, filter_post}; -use crate::indieauth::User; - -struct RedisScripts { - edit_post: redis::Script, -} - -impl From<mobc_redis::redis::RedisError> for StorageError { - fn from(err: mobc_redis::redis::RedisError) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::Backend, - } - } -} -impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError { - fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::Backend, - } - } -} - -lazy_static! 
{ - static ref SCRIPTS: RedisScripts = RedisScripts { - edit_post: redis::Script::new(include_str!("./edit_post.lua")) - }; -} -/*#[cfg(feature(lazy_cell))] -static SCRIPTS_CELL: std::cell::LazyCell<RedisScripts> = std::cell::LazyCell::new(|| { - RedisScripts { - edit_post: redis::Script::new(include_str!("./edit_post.lua")) - } -});*/ - -#[derive(Clone)] -pub struct RedisStorage { - // note to future Vika: - // mobc::Pool is actually a fancy name for an Arc - // around a shared connection pool with a manager - // which makes it safe to implement [`Clone`] and - // not worry about new pools being suddenly made - // - // stop worrying and start coding, you dum-dum - redis: mobc::Pool<RedisConnectionManager>, -} - -#[async_trait] -impl Storage for RedisStorage { - async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - Ok(conn - .hget::<String, &str, String>(format!("settings_{}", user), setting) - .await?) - } - - async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - Ok(conn - .hset::<String, &str, &str, ()>(format!("settings_{}", user), setting, value) - .await?) - } - - async fn delete_post<'a>(&self, url: &'a str) -> Result<()> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - Ok(conn.hdel::<&str, &str, ()>("posts", url).await?) - } - - async fn post_exists(&self, url: &str) -> Result<bool> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - Ok(conn.hexists::<&str, &str, bool>("posts", url).await?) - } - - async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - match conn - .hget::<&str, &str, Option<String>>("posts", url) - .await? - { - Some(val) => { - let parsed = serde_json::from_str::<serde_json::Value>(&val)?; - if let Some(new_url) = parsed["see_other"].as_str() { - match conn - .hget::<&str, &str, Option<String>>("posts", new_url) - .await? - { - Some(val) => Ok(Some(serde_json::from_str::<serde_json::Value>(&val)?)), - None => Ok(None), - } - } else { - Ok(Some(parsed)) - } - } - None => Ok(None), - } - } - - async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - let channels = conn - .smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str()) - .await?; - // TODO: use streams here instead of this weird thing... how did I even write this?! 
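The TODO above asks for a stream-based version; a sketch of what that could look like, with a generic `fetch` closure standing in for `self.get_post` and an arbitrary concurrency cap of 4 (both are assumptions, not part of the original code):

use futures_util::{stream, StreamExt};

async fn channels_via_stream<F, Fut>(urls: Vec<String>, fetch: F) -> Vec<serde_json::Value>
where
    F: Fn(String) -> Fut,
    Fut: std::future::Future<Output = Option<serde_json::Value>>,
{
    stream::iter(urls)
        .map(fetch)      // turn each channel URL into a pending fetch
        .buffered(4)     // poll at most 4 fetches concurrently, preserving order
        .filter_map(|post| async move { post }) // drop broken links (None)
        .collect()
        .await
}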
- Ok(futures_util::future::join_all( - channels - .iter() - .map(|channel| { - self.get_post(channel).map(|result| result.unwrap()).map( - |post: Option<serde_json::Value>| { - post.map(|post| MicropubChannel { - uid: post["properties"]["uid"][0].as_str().unwrap().to_string(), - name: post["properties"]["name"][0].as_str().unwrap().to_string(), - }) - }, - ) - }) - .collect::<Vec<_>>(), - ) - .await - .into_iter() - .flatten() - .collect::<Vec<_>>()) - } - - async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - let key: &str; - match post["properties"]["uid"][0].as_str() { - Some(uid) => key = uid, - None => { - return Err(StorageError::new( - ErrorKind::BadRequest, - "post doesn't have a UID", - )) - } - } - conn.hset::<&str, &str, String, ()>("posts", key, post.to_string()) - .await?; - if post["properties"]["url"].is_array() { - for url in post["properties"]["url"] - .as_array() - .unwrap() - .iter() - .map(|i| i.as_str().unwrap().to_string()) - { - if url != key && url.starts_with(user) { - conn.hset::<&str, &str, String, ()>( - "posts", - &url, - json!({ "see_other": key }).to_string(), - ) - .await?; - } - } - } - if post["type"] - .as_array() - .unwrap() - .iter() - .any(|i| i == "h-feed") - { - // This is a feed. Add it to the channels array if it's not already there. - conn.sadd::<String, &str, ()>( - "channels_".to_string() + post["properties"]["author"][0].as_str().unwrap(), - key, - ) - .await? - } - Ok(()) - } - - async fn read_feed_with_limit<'a>( - &self, - url: &'a str, - after: &'a Option<String>, - limit: usize, - user: &'a Option<String>, - ) -> Result<Option<serde_json::Value>> { - let mut conn = self.redis.get().await?; - let mut feed; - match conn - .hget::<&str, &str, Option<String>>("posts", url) - .await - .map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))? - { - Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, - None => return Ok(None), - } - if feed["see_other"].is_string() { - match conn - .hget::<&str, &str, Option<String>>("posts", feed["see_other"].as_str().unwrap()) - .await? - { - Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?, - None => return Ok(None), - } - } - if let Some(post) = filter_post(feed, user) { - feed = post - } else { - return Err(StorageError::new( - ErrorKind::PermissionDenied, - "specified user cannot access this post", - )); - } - if feed["children"].is_array() { - let children = feed["children"].as_array().unwrap(); - let mut posts_iter = children.iter().map(|i| i.as_str().unwrap().to_string()); - if after.is_some() { - loop { - let i = posts_iter.next(); - if &i == after { - break; - } - } - } - async fn fetch_post_for_feed(url: String) -> Option<serde_json::Value> { - return Some(serde_json::json!({})); - } - let posts = stream::iter(posts_iter) - .map(|url: String| async move { - return Ok(fetch_post_for_feed(url).await); - /*match self.redis.get().await { - Ok(mut conn) => { - match conn.hget::<&str, &str, Option<String>>("posts", &url).await { - Ok(post) => match post { - Some(post) => { - Ok(Some(serde_json::from_str(&post)?)) - } - // Happens because of a broken link (result of an improper deletion?) 
- None => Ok(None), - }, - Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error executing a Redis command", Box::new(err))) - } - } - Err(err) => Err(StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(err))) - }*/ - }) - // TODO: determine the optimal value for this buffer - // It will probably depend on how often can you encounter a private post on the page - // It shouldn't be too large, or we'll start fetching too many posts from the database - // It MUST NOT be larger than the typical page size - // It MUST NOT be a significant amount of the connection pool size - //.buffered(std::cmp::min(3, limit)) - // Hack to unwrap the Option and sieve out broken links - // Broken links return None, and Stream::filter_map skips all Nones. - // I wonder if one can use try_flatten() here somehow akin to iters - .try_filter_map(|post| async move { Ok(post) }) - .try_filter_map(|post| async move { - Ok(filter_post(post, user)) - }) - .take(limit); - match posts.try_collect::<Vec<serde_json::Value>>().await { - Ok(posts) => feed["children"] = json!(posts), - Err(err) => { - let e = StorageError::with_source( - ErrorKind::Other, - "An error was encountered while processing the feed", - Box::new(err) - ); - error!("Error while assembling feed: {}", e); - return Err(e); - } - } - } - return Ok(Some(feed)); - } - - async fn update_post<'a>(&self, mut url: &'a str, update: serde_json::Value) -> Result<()> { - let mut conn = self.redis.get().await.map_err(|e| StorageError::with_source(ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e)))?; - if !conn - .hexists::<&str, &str, bool>("posts", url) - .await - .unwrap() - { - return Err(StorageError::new( - ErrorKind::NotFound, - "can't edit a non-existent post", - )); - } - let post: serde_json::Value = - serde_json::from_str(&conn.hget::<&str, &str, String>("posts", url).await?)?; - if let Some(new_url) = post["see_other"].as_str() { - url = new_url - } - Ok(SCRIPTS - .edit_post - .key("posts") - .arg(url) - .arg(update.to_string()) - .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection) - .await?) - } -} - -impl RedisStorage { - /// Create a new RedisDatabase that will connect to Redis at `redis_uri` to store data. 
- pub async fn new(redis_uri: String) -> Result<Self> { - match redis::Client::open(redis_uri) { - Ok(client) => Ok(Self { - redis: Pool::builder() - .max_open(20) - .max_idle(5) - .get_timeout(Some(Duration::from_secs(3))) - .max_lifetime(Some(Duration::from_secs(120))) - .build(RedisConnectionManager::new(client)), - }), - Err(e) => Err(e.into()), - } - } - - pub async fn conn(&self) -> Result<mobc::Connection<mobc_redis::RedisConnectionManager>> { - self.redis.get().await.map_err(|e| StorageError::with_source( - ErrorKind::Backend, "Error getting a connection from the pool", Box::new(e) - )) - } -} - -#[cfg(test)] -pub mod tests { - use mobc_redis::redis; - use std::process; - use std::time::Duration; - - pub struct RedisInstance { - // We just need to hold on to it so it won't get dropped and remove the socket - _tempdir: tempdir::TempDir, - uri: String, - child: std::process::Child, - } - impl Drop for RedisInstance { - fn drop(&mut self) { - self.child.kill().expect("Failed to kill the child!"); - } - } - impl RedisInstance { - pub fn uri(&self) -> &str { - &self.uri - } - } - - pub async fn get_redis_instance() -> RedisInstance { - let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir"); - let socket = tempdir.path().join("redis.sock"); - let redis_child = process::Command::new("redis-server") - .current_dir(&tempdir) - .arg("--port") - .arg("0") - .arg("--unixsocket") - .arg(&socket) - .stdout(process::Stdio::null()) - .stderr(process::Stdio::null()) - .spawn() - .expect("Failed to spawn Redis"); - println!("redis+unix:///{}", socket.to_str().unwrap()); - let uri = format!("redis+unix:///{}", socket.to_str().unwrap()); - // There should be a slight delay, we need to wait for Redis to spin up - let client = redis::Client::open(uri.clone()).unwrap(); - let millisecond = Duration::from_millis(1); - let mut retries: usize = 0; - const MAX_RETRIES: usize = 60 * 1000/*ms*/; - while let Err(err) = client.get_connection() { - if err.is_connection_refusal() { - async_std::task::sleep(millisecond).await; - retries += 1; - if retries > MAX_RETRIES { - panic!("Timeout waiting for Redis, last error: {}", err); - } - } else { - panic!("Could not connect: {}", err); - } - } - - RedisInstance { - uri, - child: redis_child, - _tempdir: tempdir, - } - } -} diff --git a/kittybox-rs/src/frontend/login.rs b/kittybox-rs/src/frontend/login.rs deleted file mode 100644 index c693899..0000000 --- a/kittybox-rs/src/frontend/login.rs +++ /dev/null @@ -1,333 +0,0 @@ -use http_types::Mime; -use log::{debug, error}; -use rand::Rng; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; -use std::convert::TryInto; -use std::str::FromStr; - -use crate::frontend::templates::Template; -use crate::frontend::{FrontendError, IndiewebEndpoints}; -use crate::{database::Storage, ApplicationState}; -use kittybox_frontend_renderer::LoginPage; - -pub async fn form<S: Storage>(req: Request<ApplicationState<S>>) -> Result { - let owner = req.url().origin().ascii_serialization() + "/"; - let storage = &req.state().storage; - let authorization_endpoint = req.state().authorization_endpoint.to_string(); - let token_endpoint = req.state().token_endpoint.to_string(); - let blog_name = storage - .get_setting("site_name", &owner) - .await - .unwrap_or_else(|_| "Kitty Box!".to_string()); - let feeds = storage.get_channels(&owner).await.unwrap_or_default(); - - Ok(Response::builder(200) - .body( - Template { - title: "Sign in with IndieAuth", - blog_name: &blog_name, - endpoints: IndiewebEndpoints { - 
authorization_endpoint, - token_endpoint, - webmention: None, - microsub: None, - }, - feeds, - user: req.session().get("user"), - content: LoginPage {}.to_string(), - } - .to_string(), - ) - .content_type("text/html; charset=utf-8") - .build()) -} - -#[derive(Serialize, Deserialize)] -struct LoginForm { - url: String, -} - -#[derive(Serialize, Deserialize)] -struct IndieAuthClientState { - /// A random value to protect from CSRF attacks. - nonce: String, - /// The user's initial "me" value. - me: String, - /// Authorization endpoint used. - authorization_endpoint: String, -} - -#[derive(Serialize, Deserialize)] -struct IndieAuthRequestParams { - response_type: String, // can only have "code". TODO make an enum - client_id: String, // always a URL. TODO consider making a URL - redirect_uri: surf::Url, // callback URI for IndieAuth - state: String, // CSRF protection, should include randomness and be passed through - code_challenge: String, // base64-encoded PKCE challenge - code_challenge_method: String, // usually "S256". TODO make an enum - scope: Option<String>, // oAuth2 scopes to grant, - me: surf::Url, // User's entered profile URL -} - -/// Handle login requests. Find the IndieAuth authorization endpoint and redirect to it. -pub async fn handler<S: Storage>(mut req: Request<ApplicationState<S>>) -> Result { - let content_type = req.content_type(); - if content_type.is_none() { - return Err(FrontendError::with_code(400, "Use the login form, Luke.").into()); - } - if content_type.unwrap() != Mime::from_str("application/x-www-form-urlencoded").unwrap() { - return Err( - FrontendError::with_code(400, "Login form results must be a urlencoded form").into(), - ); - } - - let form = req.body_form::<LoginForm>().await?; // FIXME check if it returns 400 or 500 on error - let homepage_uri = surf::Url::parse(&form.url)?; - let http = &req.state().http_client; - - let mut fetch_response = http.get(&homepage_uri).send().await?; - if fetch_response.status() != 200 { - return Err(FrontendError::with_code( - 500, - "Error fetching your authorization endpoint. Check if your website's okay.", - ) - .into()); - } - - let mut authorization_endpoint: Option<surf::Url> = None; - if let Some(links) = fetch_response.header("Link") { - // NOTE: this is the same Link header parser used in src/micropub/post.rs:459. - // One should refactor it to a function to use independently and improve later - for link in links.iter().flat_map(|i| i.as_str().split(',')) { - debug!("Trying to match {} as authorization_endpoint", link); - let mut split_link = link.split(';'); - - match split_link.next() { - Some(uri) => { - if let Some(uri) = uri.strip_prefix('<').and_then(|uri| uri.strip_suffix('>')) { - debug!("uri: {}", uri); - for prop in split_link { - debug!("prop: {}", prop); - let lowercased = prop.to_ascii_lowercase(); - let trimmed = lowercased.trim(); - if trimmed == "rel=\"authorization_endpoint\"" - || trimmed == "rel=authorization_endpoint" - { - if let Ok(endpoint) = homepage_uri.join(uri) { - debug!( - "Found authorization endpoint {} for user {}", - endpoint, - homepage_uri.as_str() - ); - authorization_endpoint = Some(endpoint); - break; - } - } - } - } - } - None => continue, - } - } - } - // If the authorization_endpoint is still not found after the Link parsing gauntlet, - // bring out the big guns and parse HTML to find it. 
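The NOTE earlier in this function suggests refactoring the Link header scan into a standalone function. One possible shape for that helper (the name and exact matching rules here are illustrative, not from the original code):

fn find_rel_link(header: &str, rel: &str, base: &surf::Url) -> Option<surf::Url> {
    header.split(',').find_map(|link| {
        let mut parts = link.split(';');
        // The target URI comes first, wrapped in angle brackets.
        let uri = parts.next()?.trim().strip_prefix('<')?.strip_suffix('>')?;
        // Accept both the quoted and unquoted forms of the rel parameter.
        let wanted = [format!("rel=\"{}\"", rel), format!("rel={}", rel)];
        parts
            .map(|p| p.trim().to_ascii_lowercase())
            .any(|p| wanted.contains(&p))
            .then(|| base.join(uri).ok())
            .flatten()
    })
}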
- if authorization_endpoint.is_none() { - let body = fetch_response.body_string().await?; - let pattern = - easy_scraper::Pattern::new(r#"<link rel="authorization_endpoint" href="{{url}}">"#) - .expect("Cannot parse the pattern for authorization_endpoint"); - let matches = pattern.matches(&body); - debug!("Matches for authorization_endpoint in HTML: {:?}", matches); - if !matches.is_empty() { - if let Ok(endpoint) = homepage_uri.join(&matches[0]["url"]) { - debug!( - "Found authorization endpoint {} for user {}", - endpoint, - homepage_uri.as_str() - ); - authorization_endpoint = Some(endpoint) - } - } - }; - // If even after this the authorization endpoint is still not found, bail out. - if authorization_endpoint.is_none() { - error!( - "Couldn't find authorization_endpoint for {}", - homepage_uri.as_str() - ); - return Err(FrontendError::with_code( - 400, - "Your website doesn't support the IndieAuth protocol.", - ) - .into()); - } - let mut authorization_endpoint: surf::Url = authorization_endpoint.unwrap(); - let mut rng = rand::thread_rng(); - let state: String = data_encoding::BASE64URL.encode( - serde_urlencoded::to_string(IndieAuthClientState { - nonce: (0..8) - .map(|_| { - let idx = rng.gen_range(0..INDIEAUTH_PKCE_CHARSET.len()); - INDIEAUTH_PKCE_CHARSET[idx] as char - }) - .collect(), - me: homepage_uri.to_string(), - authorization_endpoint: authorization_endpoint.to_string(), - })? - .as_bytes(), - ); - // PKCE code generation - let code_verifier: String = (0..128) - .map(|_| { - let idx = rng.gen_range(0..INDIEAUTH_PKCE_CHARSET.len()); - INDIEAUTH_PKCE_CHARSET[idx] as char - }) - .collect(); - let mut hasher = Sha256::new(); - hasher.update(code_verifier.as_bytes()); - let code_challenge: String = data_encoding::BASE64URL.encode(&hasher.finalize()); - - authorization_endpoint.set_query(Some(&serde_urlencoded::to_string( - IndieAuthRequestParams { - response_type: "code".to_string(), - client_id: req.url().origin().ascii_serialization(), - redirect_uri: req.url().join("login/callback")?, - state: state.clone(), - code_challenge, - code_challenge_method: "S256".to_string(), - scope: Some("profile".to_string()), - me: homepage_uri, - }, - )?)); - - let cookies = vec![ - format!( - r#"indieauth_state="{}"; Same-Site: None; Secure; Max-Age: 600"#, - state - ), - format!( - r#"indieauth_code_verifier="{}"; Same-Site: None; Secure; Max-Age: 600"#, - code_verifier - ), - ]; - - let cookie_header = cookies - .iter() - .map(|i| -> http_types::headers::HeaderValue { (i as &str).try_into().unwrap() }) - .collect::<Vec<_>>(); - - Ok(Response::builder(302) - .header("Location", authorization_endpoint.to_string()) - .header("Set-Cookie", &*cookie_header) - .build()) -} - -const INDIEAUTH_PKCE_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ - abcdefghijklmnopqrstuvwxyz\ - 1234567890-._~"; - -#[derive(Deserialize)] -struct IndieAuthCallbackResponse { - code: Option<String>, - error: Option<String>, - error_description: Option<String>, - #[allow(dead_code)] - error_uri: Option<String>, - // This needs to be further decoded to receive state back and will always be present - state: String, -} - -impl IndieAuthCallbackResponse { - fn is_successful(&self) -> bool { - self.code.is_some() - } -} - -#[derive(Serialize, Deserialize)] -struct IndieAuthCodeRedeem { - grant_type: String, - code: String, - client_id: String, - redirect_uri: String, - code_verifier: String, -} - -#[derive(Serialize, Deserialize)] -struct IndieWebProfile { - name: Option<String>, - url: Option<String>, - email: 
Option<String>,
- photo: Option<String>,
-}
-
-#[derive(Serialize, Deserialize)]
-struct IndieAuthResponse {
- me: String,
- scope: Option<String>,
- access_token: Option<String>,
- token_type: Option<String>,
- profile: Option<IndieWebProfile>,
-}
-
-/// Handle IndieAuth parameters, fetch the final h-card and redirect the user to the homepage.
-pub async fn callback<S: Storage>(mut req: Request<ApplicationState<S>>) -> Result {
- let params: IndieAuthCallbackResponse = req.query()?;
- let http: &surf::Client = &req.state().http_client;
- let origin = req.url().origin().ascii_serialization();
-
- if req.cookie("indieauth_state").unwrap().value() != params.state {
- return Err(FrontendError::with_code(400, "The state doesn't match. A possible CSRF attack was prevented. Please try again later.").into());
- }
- let state: IndieAuthClientState =
- serde_urlencoded::from_bytes(&data_encoding::BASE64URL.decode(params.state.as_bytes())?)?;
-
- if !params.is_successful() {
- return Err(FrontendError::with_code(
- 400,
- &format!(
- "The authorization endpoint indicated the following error: {:?}: {:?}",
- &params.error, &params.error_description
- ),
- )
- .into());
- }
-
- let authorization_endpoint = surf::Url::parse(&state.authorization_endpoint).unwrap();
- let mut code_response = http
- .post(authorization_endpoint)
- .body_string(serde_urlencoded::to_string(IndieAuthCodeRedeem {
- grant_type: "authorization_code".to_string(),
- code: params.code.unwrap().to_string(),
- client_id: origin.to_string(),
- redirect_uri: origin + "/login/callback",
- code_verifier: req
- .cookie("indieauth_code_verifier")
- .unwrap()
- .value()
- .to_string(),
- })?)
- .header("Content-Type", "application/x-www-form-urlencoded")
- .header("Accept", "application/json")
- .send()
- .await?;
-
- if code_response.status() != 200 {
- return Err(FrontendError::with_code(
- code_response.status(),
- &format!(
- "Authorization endpoint returned an error when redeeming the code: {}",
- code_response.body_string().await? 
- ), - ) - .into()); - } - - let json: IndieAuthResponse = code_response.body_json().await?; - let session = req.session_mut(); - session.insert("user", &json.me)?; - - // TODO redirect to the page user came from - Ok(Response::builder(302).header("Location", "/").build()) -} diff --git a/kittybox-rs/src/frontend/mod.rs b/kittybox-rs/src/frontend/mod.rs deleted file mode 100644 index 7a43532..0000000 --- a/kittybox-rs/src/frontend/mod.rs +++ /dev/null @@ -1,404 +0,0 @@ -use crate::database::{Storage, StorageError}; -use axum::{ - extract::{Host, Path, Query}, - http::{StatusCode, Uri}, - response::IntoResponse, - Extension, -}; -use futures_util::FutureExt; -use serde::Deserialize; -use std::convert::TryInto; -use tracing::{debug, error}; -//pub mod login; -pub mod onboarding; - -use kittybox_frontend_renderer::{ - Entry, Feed, VCard, - ErrorPage, Template, MainPage, - POSTS_PER_PAGE -}; -pub use kittybox_frontend_renderer::assets::statics; - -#[derive(Debug, Deserialize)] -pub struct QueryParams { - after: Option<String>, -} - -#[derive(Debug)] -struct FrontendError { - msg: String, - source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>, - code: StatusCode, -} - -impl FrontendError { - pub fn with_code<C>(code: C, msg: &str) -> Self - where - C: TryInto<StatusCode>, - { - Self { - msg: msg.to_string(), - source: None, - code: code.try_into().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - } - } - pub fn msg(&self) -> &str { - &self.msg - } - pub fn code(&self) -> StatusCode { - self.code - } -} - -impl From<StorageError> for FrontendError { - fn from(err: StorageError) -> Self { - Self { - msg: "Database error".to_string(), - source: Some(Box::new(err)), - code: StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - -impl std::error::Error for FrontendError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source - .as_ref() - .map(|e| e.as_ref() as &(dyn std::error::Error + 'static)) - } -} - -impl std::fmt::Display for FrontendError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.msg)?; - if let Some(err) = std::error::Error::source(&self) { - write!(f, ": {}", err)?; - } - - Ok(()) - } -} - -/// Filter the post according to the value of `user`. 
-/// -/// Anonymous users cannot view private posts and protected locations; -/// Logged-in users can only view private posts targeted at them; -/// Logged-in users can't view private location data -#[tracing::instrument(skip(post), fields(post = %post))] -pub fn filter_post( - mut post: serde_json::Value, - user: Option<&str>, -) -> Option<serde_json::Value> { - if post["properties"]["deleted"][0].is_string() { - tracing::debug!("Deleted post; returning tombstone instead"); - return Some(serde_json::json!({ - "type": post["type"], - "properties": { - "deleted": post["properties"]["deleted"] - } - })); - } - let empty_vec: Vec<serde_json::Value> = vec![]; - let author_list = post["properties"]["author"] - .as_array() - .unwrap_or(&empty_vec) - .iter() - .map(|i| -> &str { - match i { - serde_json::Value::String(ref author) => author.as_str(), - mf2 => mf2["properties"]["uid"][0].as_str().unwrap() - } - }).collect::<Vec<&str>>(); - let visibility = post["properties"]["visibility"][0] - .as_str() - .unwrap_or("public"); - let audience = { - let mut audience = author_list.clone(); - audience.extend(post["properties"]["audience"] - .as_array() - .unwrap_or(&empty_vec) - .iter() - .map(|i| i.as_str().unwrap())); - - audience - }; - tracing::debug!("post audience = {:?}", audience); - if (visibility == "private" && !audience.iter().any(|i| Some(*i) == user)) - || (visibility == "protected" && user.is_none()) - { - return None; - } - if post["properties"]["location"].is_array() { - let location_visibility = post["properties"]["location-visibility"][0] - .as_str() - .unwrap_or("private"); - tracing::debug!("Post contains location, location privacy = {}", location_visibility); - let mut author = post["properties"]["author"] - .as_array() - .unwrap_or(&empty_vec) - .iter() - .map(|i| i.as_str().unwrap()); - if (location_visibility == "private" && !author.any(|i| Some(i) == user)) - || (location_visibility == "protected" && user.is_none()) - { - post["properties"] - .as_object_mut() - .unwrap() - .remove("location"); - } - } - - match post["properties"]["author"].take() { - serde_json::Value::Array(children) => { - post["properties"]["author"] = serde_json::Value::Array( - children - .into_iter() - .filter_map(|post| if post.is_string() { - Some(post) - } else { - filter_post(post, user) - }) - .collect::<Vec<serde_json::Value>>() - ); - }, - serde_json::Value::Null => {}, - other => post["properties"]["author"] = other - } - - match post["children"].take() { - serde_json::Value::Array(children) => { - post["children"] = serde_json::Value::Array( - children - .into_iter() - .filter_map(|post| filter_post(post, user)) - .collect::<Vec<serde_json::Value>>() - ); - }, - serde_json::Value::Null => {}, - other => post["children"] = other - } - Some(post) -} - -async fn get_post_from_database<S: Storage>( - db: &S, - url: &str, - after: Option<String>, - user: &Option<String>, -) -> std::result::Result<(serde_json::Value, Option<String>), FrontendError> { - match db - .read_feed_with_cursor(url, after.as_deref(), POSTS_PER_PAGE, user.as_deref()) - .await - { - Ok(result) => match result { - Some((post, cursor)) => match filter_post(post, user.as_deref()) { - Some(post) => Ok((post, cursor)), - None => { - // TODO: Authentication - if user.is_some() { - Err(FrontendError::with_code( - StatusCode::FORBIDDEN, - "User authenticated AND forbidden to access this resource", - )) - } else { - Err(FrontendError::with_code( - StatusCode::UNAUTHORIZED, - "User needs to authenticate themselves", - )) - } - } - } - 
None => Err(FrontendError::with_code( - StatusCode::NOT_FOUND, - "Post not found in the database", - )), - }, - Err(err) => match err.kind() { - crate::database::ErrorKind::PermissionDenied => { - // TODO: Authentication - if user.is_some() { - Err(FrontendError::with_code( - StatusCode::FORBIDDEN, - "User authenticated AND forbidden to access this resource", - )) - } else { - Err(FrontendError::with_code( - StatusCode::UNAUTHORIZED, - "User needs to authenticate themselves", - )) - } - } - _ => Err(err.into()), - }, - } -} - -#[tracing::instrument(skip(db))] -pub async fn homepage<D: Storage>( - Host(host): Host, - Query(query): Query<QueryParams>, - Extension(db): Extension<D>, -) -> impl IntoResponse { - let user = None; // TODO authentication - let path = format!("https://{}/", host); - let feed_path = format!("https://{}/feeds/main", host); - - match tokio::try_join!( - get_post_from_database(&db, &path, None, &user), - get_post_from_database(&db, &feed_path, query.after, &user) - ) { - Ok(((hcard, _), (hfeed, cursor))) => { - // Here, we know those operations can't really fail - // (or it'll be a transient failure that will show up on - // other requests anyway if it's serious...) - // - // btw is it more efficient to fetch these in parallel? - let (blogname, webring, channels) = tokio::join!( - db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - - db.get_setting::<crate::database::settings::Webring>(&host) - .map(Result::unwrap_or_default), - - db.get_channels(&host).map(|i| i.unwrap_or_default()) - ); - // Render the homepage - ( - StatusCode::OK, - [( - axum::http::header::CONTENT_TYPE, - r#"text/html; charset="utf-8""#, - )], - Template { - title: blogname.as_ref(), - blog_name: blogname.as_ref(), - feeds: channels, - user, - content: MainPage { - feed: &hfeed, - card: &hcard, - cursor: cursor.as_deref(), - webring: crate::database::settings::Setting::into_inner(webring) - } - .to_string(), - } - .to_string(), - ) - } - Err(err) => { - if err.code == StatusCode::NOT_FOUND { - debug!("Transferring to onboarding..."); - // Transfer to onboarding - ( - StatusCode::FOUND, - [(axum::http::header::LOCATION, "/.kittybox/onboarding")], - String::default(), - ) - } else { - error!("Error while fetching h-card and/or h-feed: {}", err); - // Return the error - let (blogname, channels) = tokio::join!( - db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - - db.get_channels(&host).map(|i| i.unwrap_or_default()) - ); - - ( - err.code(), - [( - axum::http::header::CONTENT_TYPE, - r#"text/html; charset="utf-8""#, - )], - Template { - title: blogname.as_ref(), - blog_name: blogname.as_ref(), - feeds: channels, - user, - content: ErrorPage { - code: err.code(), - msg: Some(err.msg().to_string()), - } - .to_string(), - } - .to_string(), - ) - } - } - } -} - -#[tracing::instrument(skip(db))] -pub async fn catchall<D: Storage>( - Extension(db): Extension<D>, - Host(host): Host, - Query(query): Query<QueryParams>, - uri: Uri, -) -> impl IntoResponse { - let user = None; // TODO authentication - let path = url::Url::parse(&format!("https://{}/", host)) - .unwrap() - .join(uri.path()) - .unwrap(); - - match get_post_from_database(&db, path.as_str(), query.after, &user).await { - Ok((post, cursor)) => { - let (blogname, channels) = tokio::join!( - db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - - db.get_channels(&host).map(|i| i.unwrap_or_default()) - ); - // Render 
the homepage - ( - StatusCode::OK, - [( - axum::http::header::CONTENT_TYPE, - r#"text/html; charset="utf-8""#, - )], - Template { - title: blogname.as_ref(), - blog_name: blogname.as_ref(), - feeds: channels, - user, - content: match post.pointer("/type/0").and_then(|i| i.as_str()) { - Some("h-entry") => Entry { post: &post }.to_string(), - Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(), - Some("h-card") => VCard { card: &post }.to_string(), - unknown => { - unimplemented!("Template for MF2-JSON type {:?}", unknown) - } - }, - } - .to_string(), - ) - } - Err(err) => { - let (blogname, channels) = tokio::join!( - db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - - db.get_channels(&host).map(|i| i.unwrap_or_default()) - ); - ( - err.code(), - [( - axum::http::header::CONTENT_TYPE, - r#"text/html; charset="utf-8""#, - )], - Template { - title: blogname.as_ref(), - blog_name: blogname.as_ref(), - feeds: channels, - user, - content: ErrorPage { - code: err.code(), - msg: Some(err.msg().to_owned()), - } - .to_string(), - } - .to_string(), - ) - } - } -} diff --git a/kittybox-rs/src/frontend/onboarding.rs b/kittybox-rs/src/frontend/onboarding.rs deleted file mode 100644 index e44e866..0000000 --- a/kittybox-rs/src/frontend/onboarding.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::sync::Arc; - -use crate::database::{settings, Storage}; -use axum::{ - extract::{Extension, Host}, - http::StatusCode, - response::{Html, IntoResponse}, - Json, -}; -use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template}; -use serde::Deserialize; -use tokio::{task::JoinSet, sync::Mutex}; -use tracing::{debug, error}; - -use super::FrontendError; - -pub async fn get() -> Html<String> { - Html( - Template { - title: "Kittybox - Onboarding", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: OnboardingPage {}.to_string(), - } - .to_string(), - ) -} - -#[derive(Deserialize, Debug)] -struct OnboardingFeed { - slug: String, - name: String, -} - -#[derive(Deserialize, Debug)] -pub struct OnboardingData { - user: serde_json::Value, - first_post: serde_json::Value, - #[serde(default = "OnboardingData::default_blog_name")] - blog_name: String, - feeds: Vec<OnboardingFeed>, -} - -impl OnboardingData { - fn default_blog_name() -> String { - "Kitty Box!".to_owned() - } -} - -#[tracing::instrument(skip(db, http))] -async fn onboard<D: Storage + 'static>( - db: D, - user_uid: url::Url, - data: OnboardingData, - http: reqwest::Client, - jobset: Arc<Mutex<JoinSet<()>>>, -) -> Result<(), FrontendError> { - // Create a user to pass to the backend - // At this point the site belongs to nobody, so it is safe to do - tracing::debug!("Creating user..."); - let user = kittybox_indieauth::TokenData { - me: user_uid.clone(), - client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), - iat: None, exp: None - }; - tracing::debug!("User data: {:?}", user); - - if data.user["type"][0] != "h-card" || data.first_post["type"][0] != "h-entry" { - return Err(FrontendError::with_code( - StatusCode::BAD_REQUEST, - "user and first_post should be an h-card and an h-entry", - )); - } - - tracing::debug!("Setting settings..."); - let user_domain = format!( - "{}{}", - user.me.host_str().unwrap(), - user.me.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() - ); - db.set_setting::<settings::SiteName>(&user_domain, data.blog_name.to_owned()) - .await - 
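
For reference, a request body matching the OnboardingData and OnboardingFeed shapes above could look like the sketch below; every value is illustrative:

    use serde_json::json;

    fn main() {
        let body = json!({
            "user": { "type": ["h-card"], "properties": { "name": ["Example User"] } },
            "first_post": { "type": ["h-entry"], "properties": { "content": ["Hello, world!"] } },
            // blog_name may be omitted; it falls back to "Kitty Box!"
            "blog_name": "Example Blog",
            "feeds": [ { "slug": "main", "name": "Main feed" } ]
        });
        println!("{body}");
    }
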
.map_err(FrontendError::from)?; - - db.set_setting::<settings::Webring>(&user_domain, false) - .await - .map_err(FrontendError::from)?; - - let (_, hcard) = { - let mut hcard = data.user; - hcard["properties"]["uid"] = serde_json::json!([&user_uid]); - crate::micropub::normalize_mf2(hcard, &user) - }; - db.put_post(&hcard, user_domain.as_str()) - .await - .map_err(FrontendError::from)?; - - debug!("Creating feeds..."); - for feed in data.feeds { - if feed.name.is_empty() || feed.slug.is_empty() { - continue; - }; - debug!("Creating feed {} with slug {}", &feed.name, &feed.slug); - let (_, feed) = crate::micropub::normalize_mf2( - serde_json::json!({ - "type": ["h-feed"], - "properties": {"name": [feed.name], "mp-slug": [feed.slug]} - }), - &user, - ); - - db.put_post(&feed, user_uid.as_str()) - .await - .map_err(FrontendError::from)?; - } - let (uid, post) = crate::micropub::normalize_mf2(data.first_post, &user); - tracing::debug!("Posting first post {}...", uid); - crate::micropub::_post(&user, uid, post, db, http, jobset) - .await - .map_err(|e| FrontendError { - msg: "Error while posting the first post".to_string(), - source: Some(Box::new(e)), - code: StatusCode::INTERNAL_SERVER_ERROR, - })?; - - Ok(()) -} - -pub async fn post<D: Storage + 'static>( - Extension(db): Extension<D>, - Host(host): Host, - Extension(http): Extension<reqwest::Client>, - Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, - Json(data): Json<OnboardingData>, -) -> axum::response::Response { - let user_uid = format!("https://{}/", host.as_str()); - - if db.post_exists(&user_uid).await.unwrap() { - IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])) - } else { - match onboard(db, user_uid.parse().unwrap(), data, http, jobset).await { - Ok(()) => IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])), - Err(err) => { - error!("Onboarding error: {}", err); - IntoResponse::into_response(( - err.code(), - Html( - Template { - title: "Kittybox - Onboarding", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: ErrorPage { - code: err.code(), - msg: Some(err.msg().to_string()), - } - .to_string(), - } - .to_string(), - ), - )) - } - } - } -} - -pub fn router<S: Storage + 'static>( - database: S, - http: reqwest::Client, - jobset: Arc<Mutex<JoinSet<()>>>, -) -> axum::routing::MethodRouter { - axum::routing::get(get) - .post(post::<S>) - .layer::<_, _, std::convert::Infallible>(axum::Extension(database)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) - .layer(axum::Extension(jobset)) -} diff --git a/kittybox-rs/src/indieauth/backend.rs b/kittybox-rs/src/indieauth/backend.rs deleted file mode 100644 index 534bcfb..0000000 --- a/kittybox-rs/src/indieauth/backend.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::collections::HashMap; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData -}; -pub use kittybox_util::auth::EnrolledCredential; - -type Result<T> = std::io::Result<T>; - -pub mod fs; -pub use fs::FileBackend; - -#[async_trait::async_trait] -pub trait AuthBackend: Clone + Send + Sync + 'static { - // Authorization code management. - /// Create a one-time OAuth2 authorization code for the passed - /// authorization request, and save it for later retrieval. - /// - /// Note for implementors: the [`AuthorizationRequest::me`] value - /// is guaranteed to be [`Some(url::Url)`][Option::Some] and can - /// be trusted to be correct and non-malicious. 
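
A note on the user_domain computation in the onboarding code above: url::Url::port() is None for a scheme's default port, so the formatting only re-attaches explicit ports. A standalone sketch:

    fn domain_of(me: &url::Url) -> String {
        format!(
            "{}{}",
            me.host_str().unwrap(),
            me.port().map(|port| format!(":{}", port)).unwrap_or_default()
        )
    }

    fn main() {
        let me: url::Url = "https://example.com:8443/".parse().unwrap();
        assert_eq!(domain_of(&me), "example.com:8443");

        let me: url::Url = "https://example.com/".parse().unwrap();
        assert_eq!(domain_of(&me), "example.com");
    }
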
- async fn create_code(&self, data: AuthorizationRequest) -> Result<String>; - /// Retrieve an authorization request using the one-time - /// code. Implementations must sanitize the `code` field to - /// prevent exploits, and must check if the code should still be - /// valid at this point in time (validity interval is left up to - /// the implementation, but is recommended to be no more than 10 - /// minutes). - async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>>; - // Token management. - async fn create_token(&self, data: TokenData) -> Result<String>; - async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>>; - async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>>; - async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()>; - // Refresh token management. - async fn create_refresh_token(&self, data: TokenData) -> Result<String>; - async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>>; - async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>>; - async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()>; - // Password management. - /// Verify a password. - #[must_use] - async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool>; - /// Enroll a password credential for a user. Only one password - /// credential must exist for a given user. - async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()>; - /// List currently enrolled credential types for a given user. - async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>>; - // WebAuthn credential management. - #[cfg(feature = "webauthn")] - /// Enroll a WebAuthn authenticator public key for this user. - /// Multiple public keys may be saved for one user, corresponding - /// to different authenticators used by them. - /// - /// This function can also be used to overwrite a passkey with an - /// updated version after using - /// [webauthn::prelude::Passkey::update_credential()]. - async fn enroll_webauthn(&self, website: &url::Url, credential: webauthn::prelude::Passkey) -> Result<()>; - #[cfg(feature = "webauthn")] - /// List currently enrolled WebAuthn authenticators for a given user. - async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<webauthn::prelude::Passkey>>; - #[cfg(feature = "webauthn")] - /// Persist registration challenge state for a little while so it - /// can be used later. - /// - /// Challenges saved in this manner MUST expire after a little - /// while. 10 minutes is recommended. - async fn persist_registration_challenge( - &self, - website: &url::Url, - state: webauthn::prelude::PasskeyRegistration - ) -> Result<String>; - #[cfg(feature = "webauthn")] - /// Retrieve a persisted registration challenge. - /// - /// The challenge should be deleted after retrieval. - async fn retrieve_registration_challenge( - &self, - website: &url::Url, - challenge_id: &str - ) -> Result<webauthn::prelude::PasskeyRegistration>; - #[cfg(feature = "webauthn")] - /// Persist authentication challenge state for a little while so - /// it can be used later. - /// - /// Challenges saved in this manner MUST expire after a little - /// while. 10 minutes is recommended. - /// - /// To support multiple authentication options, this can return an - /// opaque token that should be set as a cookie. 
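
The challenge-persistence contract described in the doc comments here (opaque ID, expiry after roughly ten minutes, deletion on retrieval) can be reduced to an in-memory sketch; all names below are hypothetical, not part of the trait:

    use std::collections::HashMap;
    use std::time::{Duration, Instant};

    const CHALLENGE_TTL: Duration = Duration::from_secs(600);

    struct Challenges<S>(HashMap<String, (Instant, S)>);

    impl<S> Challenges<S> {
        fn persist(&mut self, id: String, state: S) {
            self.0.insert(id, (Instant::now(), state));
        }
        // Retrieval consumes the challenge and refuses expired entries.
        fn retrieve(&mut self, id: &str) -> Option<S> {
            let (created, state) = self.0.remove(id)?;
            (created.elapsed() <= CHALLENGE_TTL).then_some(state)
        }
    }

    fn main() {
        let mut c = Challenges(HashMap::new());
        c.persist("id1".into(), "registration-state");
        assert_eq!(c.retrieve("id1"), Some("registration-state"));
        assert_eq!(c.retrieve("id1"), None); // consumed on first retrieval
    }
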
- async fn persist_authentication_challenge( - &self, - website: &url::Url, - state: webauthn::prelude::PasskeyAuthentication - ) -> Result<String>; - #[cfg(feature = "webauthn")] - /// Retrieve a persisted authentication challenge. - /// - /// The challenge should be deleted after retrieval. - async fn retrieve_authentication_challenge( - &self, - website: &url::Url, - challenge_id: &str - ) -> Result<webauthn::prelude::PasskeyAuthentication>; - -} diff --git a/kittybox-rs/src/indieauth/backend/fs.rs b/kittybox-rs/src/indieauth/backend/fs.rs deleted file mode 100644 index 600e901..0000000 --- a/kittybox-rs/src/indieauth/backend/fs.rs +++ /dev/null @@ -1,420 +0,0 @@ -use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}}; - -use super::{AuthBackend, Result, EnrolledCredential}; -use async_trait::async_trait; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData -}; -use serde::de::DeserializeOwned; -use tokio::{task::spawn_blocking, io::AsyncReadExt}; -#[cfg(feature = "webauthn")] -use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication}; - -const CODE_LENGTH: usize = 16; -const TOKEN_LENGTH: usize = 128; -const CODE_DURATION: std::time::Duration = std::time::Duration::from_secs(600); - -#[derive(Clone, Debug)] -pub struct FileBackend { - path: PathBuf, -} - -impl FileBackend { - pub fn new<T: Into<PathBuf>>(path: T) -> Self { - Self { - path: path.into() - } - } - - /// Sanitize a filename, leaving only alphanumeric characters. - /// - /// Doesn't allocate a new string unless non-alphanumeric - /// characters are encountered. - fn sanitize_for_path(filename: &'_ str) -> Cow<'_, str> { - if filename.chars().all(char::is_alphanumeric) { - Cow::Borrowed(filename) - } else { - let mut s = String::with_capacity(filename.len()); - - filename.chars() - .filter(|c| c.is_alphanumeric()) - .for_each(|c| s.push(c)); - - Cow::Owned(s) - } - } - - #[inline] - async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>( - &self, - dir: &str, - basename: B, - length: usize, - data: T - ) -> Result<String> { - let basename = basename.into(); - let has_ext = basename.is_some(); - let (filename, mut file) = kittybox_util::fs::mktemp( - self.path.join(dir), basename, length - ) - .await - .map(|(name, file)| (name, file.try_into_std().unwrap()))?; - - spawn_blocking(move || serde_json::to_writer(&mut file, &data)) - .await - .unwrap_or_else(|e| panic!( - "Panic while serializing {}: {}", - std::any::type_name::<T>(), - e - )) - .map(move |_| { - (if has_ext { - filename - .extension() - - } else { - filename - .file_name() - }) - .unwrap() - .to_str() - .unwrap() - .to_owned() - }) - .map_err(|err| err.into()) - } - - #[inline] - async fn deserialize_from_file<'filename, 'this: 'filename, T, B>( - &'this self, - dir: &'filename str, - basename: B, - filename: &'filename str, - ) -> Result<Option<(PathBuf, SystemTime, T)>> - where - T: serde::de::DeserializeOwned + Send, - B: Into<Option<&'static str>> - { - let basename = basename.into(); - let path = self.path - .join(dir) - .join(format!( - "{}{}{}", - basename.unwrap_or(""), - if basename.is_none() { "" } else { "." 
}, - FileBackend::sanitize_for_path(filename) - )); - - let data = match tokio::fs::File::open(&path).await { - Ok(mut file) => { - let mut buf = Vec::new(); - - file.read_to_end(&mut buf).await?; - - match serde_json::from_slice::<'_, T>(buf.as_slice()) { - Ok(data) => data, - Err(err) => return Err(err.into()) - } - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - return Ok(None) - } else { - return Err(err) - } - }; - - let ctime = tokio::fs::metadata(&path).await?.created()?; - - Ok(Some((path, ctime, data))) - } - - #[inline] - fn url_to_dir(url: &url::Url) -> String { - let host = url.host_str().unwrap(); - let port = url.port() - .map(|port| Cow::Owned(format!(":{}", port))) - .unwrap_or(Cow::Borrowed("")); - - format!("{}{}", host, port) - } - - async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>( - &'this self, - dir: &'dir str, - prefix: &'static str - ) -> Result<HashMap<String, T>> { - let dir = self.path.join(dir); - - let mut hashmap = HashMap::new(); - let mut readdir = match tokio::fs::read_dir(dir).await { - Ok(readdir) => readdir, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - // empty hashmap - return Ok(hashmap); - } else { - return Err(err); - } - }; - while let Some(entry) = readdir.next_entry().await? { - // safe to unwrap; filenames are alphanumeric - let filename = entry.file_name() - .into_string() - .expect("token filenames should be alphanumeric!"); - if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) { - match tokio::fs::File::open(entry.path()).await { - Ok(mut file) => { - let mut buf = Vec::new(); - - file.read_to_end(&mut buf).await?; - - match serde_json::from_slice::<'_, T>(buf.as_slice()) { - Ok(data) => hashmap.insert(token.to_string(), data), - Err(err) => { - tracing::error!( - "Error decoding token data from file {}: {}", - entry.path().display(), err - ); - continue; - } - }; - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - continue - } else { - return Err(err) - } - } - } - } - - Ok(hashmap) - } -} - -#[async_trait] -impl AuthBackend for FileBackend { - // Authorization code management. - async fn create_code(&self, data: AuthorizationRequest) -> Result<String> { - self.serialize_to_file("codes", None, CODE_LENGTH, data).await - } - - async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> { - match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? { - Some((path, ctime, data)) => { - if let Err(err) = tokio::fs::remove_file(path).await { - tracing::error!("Failed to clean up authorization code: {}", err); - } - // Err on the safe side in case of clock drift - if ctime.elapsed().unwrap_or(Duration::ZERO) > CODE_DURATION { - Ok(None) - } else { - Ok(Some(data)) - } - }, - None => Ok(None) - } - } - - // Token management. - async fn create_token(&self, data: TokenData) -> Result<String> { - let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await - } - - async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { - let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "access", - FileBackend::sanitize_for_path(token).as_ref() - ).await? 
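
The get_code implementation above treats the code file's creation time as its issue time and compares the elapsed interval against CODE_DURATION; SystemTime::elapsed errs if the timestamp lies in the future, and unwrap_or(Duration::ZERO) maps that case to zero elapsed time, so such a code counts as fresh. A standalone sketch of the check:

    use std::time::{Duration, SystemTime};

    const CODE_DURATION: Duration = Duration::from_secs(600);

    fn expired(ctime: SystemTime) -> bool {
        ctime.elapsed().unwrap_or(Duration::ZERO) > CODE_DURATION
    }

    fn main() {
        assert!(!expired(SystemTime::now()));
        assert!(expired(SystemTime::now() - Duration::from_secs(1200)));
    }
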
{ - Some((path, _, token)) => { - if token.expired() { - if let Err(err) = tokio::fs::remove_file(path).await { - tracing::error!("Failed to remove expired token: {}", err); - } - Ok(None) - } else { - Ok(Some(token)) - } - }, - None => Ok(None) - } - } - - async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> { - let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - self.list_files(&dir, "access").await - } - - async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()> { - match tokio::fs::remove_file( - self.path - .join(FileBackend::url_to_dir(website)) - .join("tokens") - .join(format!("access.{}", FileBackend::sanitize_for_path(token))) - ).await { - Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result - } - } - - // Refresh token management. - async fn create_refresh_token(&self, data: TokenData) -> Result<String> { - let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await - } - - async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { - let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "refresh", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { - Some((path, _, token)) => { - if token.expired() { - if let Err(err) = tokio::fs::remove_file(path).await { - tracing::error!("Failed to remove expired token: {}", err); - } - Ok(None) - } else { - Ok(Some(token)) - } - }, - None => Ok(None) - } - } - - async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> { - let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - self.list_files(&dir, "refresh").await - } - - async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()> { - match tokio::fs::remove_file( - self.path - .join(FileBackend::url_to_dir(website)) - .join("tokens") - .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))) - ).await { - Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result - } - } - - // Password management. 
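
The token methods above all resolve files under a per-site directory, giving a layout of <root>/<host[:port]>/tokens/<prefix>.<token>. A sketch of that path construction (root and token values are illustrative; Unix-style paths assumed):

    use std::path::{Path, PathBuf};

    fn token_path(root: &Path, site_dir: &str, prefix: &str, token: &str) -> PathBuf {
        root.join(site_dir)
            .join("tokens")
            .join(format!("{}.{}", prefix, token))
    }

    fn main() {
        let p = token_path(Path::new("/var/lib/kittybox/auth"), "example.com", "access", "abc123");
        assert_eq!(p, PathBuf::from("/var/lib/kittybox/auth/example.com/tokens/access.abc123"));
    }
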
- #[tracing::instrument(skip(password))] - async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> { - use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}}; - - let password_filename = self.path - .join(FileBackend::url_to_dir(website)) - .join("password"); - - tracing::debug!("Reading password for {} from {}", website, password_filename.display()); - - match tokio::fs::read_to_string(password_filename).await { - Ok(password_hash) => { - let parsed_hash = { - let hash = password_hash.trim(); - #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash); - PasswordHash::new(hash) - .expect("Password hash should be valid!") - }; - Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - Ok(false) - } else { - Err(err) - } - } - } - - #[tracing::instrument(skip(password))] - async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> { - use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}}; - - let password_filename = self.path - .join(FileBackend::url_to_dir(website)) - .join("password"); - - let salt = SaltString::generate(&mut OsRng); - let argon2 = Argon2::default(); - let password_hash = argon2.hash_password(password.as_bytes(), &salt) - .expect("Hashing a password should not error out") - .to_string(); - - tracing::debug!("Enrolling password for {} at {}", website, password_filename.display()); - tokio::fs::write(password_filename, password_hash.as_bytes()).await - } - - // WebAuthn credential management. - #[cfg(feature = "webauthn")] - async fn enroll_webauthn(&self, website: &url::Url, credential: Passkey) -> Result<()> { - todo!() - } - - #[cfg(feature = "webauthn")] - async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<Passkey>> { - // TODO stub! 
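
The verify_password/enroll_password pair above round-trips through the argon2 crate; a condensed standalone version of that hash-then-verify cycle (the password literal is illustrative):

    use argon2::{
        password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
        Argon2,
    };

    fn main() {
        // Enrollment: hash the password with a fresh random salt.
        let salt = SaltString::generate(&mut OsRng);
        let stored = Argon2::default()
            .hash_password(b"correct horse battery staple", &salt)
            .expect("hashing should not fail")
            .to_string();

        // Verification: parse the stored PHC string and check a candidate.
        let parsed = PasswordHash::new(stored.trim()).expect("stored hash should parse");
        assert!(Argon2::default()
            .verify_password(b"correct horse battery staple", &parsed)
            .is_ok());
    }
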
- Ok(vec![]) - } - - #[cfg(feature = "webauthn")] - async fn persist_registration_challenge( - &self, - website: &url::Url, - state: PasskeyRegistration - ) -> Result<String> { - todo!() - } - - #[cfg(feature = "webauthn")] - async fn retrieve_registration_challenge( - &self, - website: &url::Url, - challenge_id: &str - ) -> Result<PasskeyRegistration> { - todo!() - } - - #[cfg(feature = "webauthn")] - async fn persist_authentication_challenge( - &self, - website: &url::Url, - state: PasskeyAuthentication - ) -> Result<String> { - todo!() - } - - #[cfg(feature = "webauthn")] - async fn retrieve_authentication_challenge( - &self, - website: &url::Url, - challenge_id: &str - ) -> Result<PasskeyAuthentication> { - todo!() - } - - async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> { - let mut creds = vec![]; - - match tokio::fs::metadata(self.path - .join(FileBackend::url_to_dir(website)) - .join("password")) - .await - { - Ok(_) => creds.push(EnrolledCredential::Password), - Err(err) => if err.kind() != std::io::ErrorKind::NotFound { - return Err(err) - } - } - - #[cfg(feature = "webauthn")] - if !self.list_webauthn_pubkeys(website).await?.is_empty() { - creds.push(EnrolledCredential::WebAuthn); - } - - Ok(creds) - } -} diff --git a/kittybox-rs/src/indieauth/mod.rs b/kittybox-rs/src/indieauth/mod.rs deleted file mode 100644 index 0ad2702..0000000 --- a/kittybox-rs/src/indieauth/mod.rs +++ /dev/null @@ -1,883 +0,0 @@ -use std::marker::PhantomData; - -use tracing::error; -use serde::Deserialize; -use axum::{ - extract::{Query, Json, Host, Form}, - response::{Html, IntoResponse, Response}, - http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer}, - Extension -}; -#[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] -use axum_extra::extract::cookie::{CookieJar, Cookie}; -use crate::database::Storage; -use kittybox_indieauth::{ - Metadata, IntrospectionEndpointAuthMethod, RevocationEndpointAuthMethod, - Scope, Scopes, PKCEMethod, Error, ErrorKind, ResponseType, - AuthorizationRequest, AuthorizationResponse, - GrantType, GrantRequest, GrantResponse, Profile, - TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest, TokenData -}; -use std::str::FromStr; -use std::ops::Deref; - -pub mod backend; -#[cfg(feature = "webauthn")] -mod webauthn; -use backend::AuthBackend; - -const ACCESS_TOKEN_VALIDITY: u64 = 7 * 24 * 60 * 60; // 7 days -const REFRESH_TOKEN_VALIDITY: u64 = ACCESS_TOKEN_VALIDITY / 7 * 60; // 60 days -/// Internal scope for accessing the token introspection endpoint. 
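
A quick check of the validity arithmetic in the constants above:

    const ACCESS_TOKEN_VALIDITY: u64 = 7 * 24 * 60 * 60; // 7 days
    const REFRESH_TOKEN_VALIDITY: u64 = ACCESS_TOKEN_VALIDITY / 7 * 60; // 60 days

    fn main() {
        assert_eq!(ACCESS_TOKEN_VALIDITY / (24 * 60 * 60), 7);
        assert_eq!(REFRESH_TOKEN_VALIDITY / (24 * 60 * 60), 60);
    }
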
-const KITTYBOX_TOKEN_STATUS: &str = "kittybox:token_status"; - -pub(crate) struct User<A: AuthBackend>(pub(crate) TokenData, pub(crate) PhantomData<A>); -impl<A: AuthBackend> std::fmt::Debug for User<A> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("User").field(&self.0).finish() - } -} -impl<A: AuthBackend> std::ops::Deref for User<A> { - type Target = TokenData; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -pub enum IndieAuthResourceError { - InvalidRequest, - Unauthorized, - InvalidToken -} -impl axum::response::IntoResponse for IndieAuthResourceError { - fn into_response(self) -> axum::response::Response { - use IndieAuthResourceError::*; - - match self { - Unauthorized => ( - StatusCode::UNAUTHORIZED, - [("WWW-Authenticate", "Bearer")] - ).into_response(), - InvalidRequest => ( - StatusCode::BAD_REQUEST, - Json(&serde_json::json!({"error": "invalid_request"})) - ).into_response(), - InvalidToken => ( - StatusCode::UNAUTHORIZED, - [("WWW-Authenticate", "Bearer, error=\"invalid_token\"")], - Json(&serde_json::json!({"error": "unauthorized"})) - ).into_response() - } - } -} - -#[async_trait::async_trait] -impl <S: Send + Sync, A: AuthBackend> axum::extract::FromRequestParts<S> for User<A> { - type Rejection = IndieAuthResourceError; - - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> { - let TypedHeader(Authorization(token)) = - TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) - .await - .map_err(|_| IndieAuthResourceError::Unauthorized)?; - - let axum::Extension(auth) = axum::Extension::<A>::from_request_parts(req, state) - .await - .unwrap(); - - let Host(host) = Host::from_request_parts(req, state) - .await - .map_err(|_| IndieAuthResourceError::InvalidRequest)?; - - auth.get_token( - &format!("https://{host}/").parse().unwrap(), - token.token() - ) - .await - .unwrap() - .ok_or(IndieAuthResourceError::InvalidToken) - .map(|t| User(t, PhantomData)) - } -} - -pub async fn metadata( - Host(host): Host -) -> Metadata { - let issuer: url::Url = format!( - "{}://{}/", - if cfg!(debug_assertions) { - "http" - } else { - "https" - }, - host - ).parse().unwrap(); - - let indieauth: url::Url = issuer.join("/.kittybox/indieauth/").unwrap(); - Metadata { - issuer, - authorization_endpoint: indieauth.join("auth").unwrap(), - token_endpoint: indieauth.join("token").unwrap(), - introspection_endpoint: indieauth.join("token_status").unwrap(), - introspection_endpoint_auth_methods_supported: Some(vec![ - IntrospectionEndpointAuthMethod::Bearer - ]), - revocation_endpoint: Some(indieauth.join("revoke_token").unwrap()), - revocation_endpoint_auth_methods_supported: Some(vec![ - RevocationEndpointAuthMethod::None - ]), - scopes_supported: Some(vec![ - Scope::Create, - Scope::Update, - Scope::Delete, - Scope::Media, - Scope::Profile - ]), - response_types_supported: Some(vec![ResponseType::Code]), - grant_types_supported: Some(vec![GrantType::AuthorizationCode, GrantType::RefreshToken]), - service_documentation: None, - code_challenge_methods_supported: vec![PKCEMethod::S256], - authorization_response_iss_parameter_supported: Some(true), - userinfo_endpoint: Some(indieauth.join("userinfo").unwrap()), - } -} - -async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( - Host(host): Host, - Query(request): Query<AuthorizationRequest>, - Extension(db): Extension<D>, - Extension(http): Extension<reqwest::Client>, - Extension(auth): Extension<A> -) -> 
Response { - let me = format!("https://{host}/").parse().unwrap(); - let h_app = { - tracing::debug!("Sending request to {} to fetch metadata", request.client_id); - match http.get(request.client_id.clone()).send().await { - Ok(response) => { - let url = response.url().clone(); - let text = response.text().await.unwrap(); - tracing::debug!("Received {} bytes in response", text.len()); - match microformats::from_html(&text, url) { - Ok(mf2) => { - if let Some(relation) = mf2.rels.items.get(&request.redirect_uri) { - if !relation.rels.iter().any(|i| i == "redirect_uri") { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri provided was declared as \ - something other than redirect_uri.") - .into_response() - } - } else if request.redirect_uri.origin() != request.client_id.origin() { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri didn't match the origin \ - and wasn't explicitly allowed. You were being tricked.") - .into_response() - } - - mf2.items.iter() - .cloned() - .find(|i| (**i).borrow().r#type.iter() - .any(|i| *i == microformats::types::Class::from_str("h-app").unwrap() - || *i == microformats::types::Class::from_str("h-x-app").unwrap())) - .map(|i| serde_json::to_value(i.borrow().deref()).unwrap()) - }, - Err(err) => { - tracing::error!("Error parsing application metadata: {}", err); - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "Parsing application metadata failed.").into_response() - } - } - }, - Err(err) => { - tracing::error!("Error fetching application metadata: {}", err); - return (StatusCode::INTERNAL_SERVER_ERROR, - [("Content-Type", "text/plain")], - "Fetching application metadata failed.").into_response() - } - } - }; - - tracing::debug!("Application metadata: {:#?}", h_app); - - Html(kittybox_frontend_renderer::Template { - title: "Confirm sign-in via IndieAuth", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: kittybox_frontend_renderer::AuthorizationRequestPage { - request, - credentials: auth.list_user_credential_types(&me).await.unwrap(), - user: db.get_post(me.as_str()).await.unwrap().unwrap(), - app: h_app - }.to_string(), - }.to_string()) - .into_response() -} - -#[derive(Deserialize, Debug)] -#[serde(untagged)] -enum Credential { - Password(String), - #[cfg(feature = "webauthn")] - WebAuthn(::webauthn::prelude::PublicKeyCredential) -} - -#[derive(Deserialize, Debug)] -struct AuthorizationConfirmation { - authorization_method: Credential, - request: AuthorizationRequest -} - -async fn verify_credential<A: AuthBackend>( - auth: &A, - website: &url::Url, - credential: Credential, - #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] - challenge_id: Option<&str> -) -> std::io::Result<bool> { - match credential { - Credential::Password(password) => auth.verify_password(website, password).await, - #[cfg(feature = "webauthn")] - Credential::WebAuthn(credential) => webauthn::verify( - auth, - website, - credential, - challenge_id.unwrap() - ).await - } -} - -#[tracing::instrument(skip(backend, confirmation))] -async fn authorization_endpoint_confirm<A: AuthBackend>( - Host(host): Host, - Extension(backend): Extension<A>, - cookies: CookieJar, - Json(confirmation): Json<AuthorizationConfirmation>, -) -> Response { - tracing::debug!("Received authorization confirmation from user"); - #[cfg(feature = "webauthn")] - let challenge_id = cookies.get(webauthn::CHALLENGE_ID_COOKIE) - .map(|cookie| cookie.value()); - #[cfg(not(feature = 
"webauthn"))] - let challenge_id = None; - - let website = format!("https://{}/", host).parse().unwrap(); - let AuthorizationConfirmation { - authorization_method: credential, - request: mut auth - } = confirmation; - match verify_credential(&backend, &website, credential, challenge_id).await { - Ok(verified) => if !verified { - error!("User failed verification, bailing out."); - return StatusCode::UNAUTHORIZED.into_response(); - }, - Err(err) => { - error!("Error while verifying credential: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - } - // Insert the correct `me` value into the request - // - // From this point, the `me` value that hits the backend is - // guaranteed to be authoritative and correct, and can be safely - // unwrapped. - auth.me = Some(website.clone()); - // Cloning these two values, because we can't destructure - // the AuthorizationRequest - we need it for the code - let state = auth.state.clone(); - let redirect_uri = auth.redirect_uri.clone(); - - let code = match backend.create_code(auth).await { - Ok(code) => code, - Err(err) => { - error!("Error creating authorization code: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - let location = { - let mut uri = redirect_uri; - uri.set_query(Some(&serde_urlencoded::to_string( - AuthorizationResponse { code, state, iss: website } - ).unwrap())); - - uri - }; - - // DO NOT SET `StatusCode::FOUND` here! `fetch()` cannot read from - // redirects, it can only follow them or choose to receive an - // opaque response instead that is completely useless - (StatusCode::NO_CONTENT, - [("Location", location.as_str())], - #[cfg(feature = "webauthn")] - cookies.remove(Cookie::named(webauthn::CHALLENGE_ID_COOKIE)) - ) - .into_response() -} - -#[tracing::instrument(skip(backend, db))] -async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( - Host(host): Host, - Extension(backend): Extension<A>, - Extension(db): Extension<D>, - Form(grant): Form<GrantRequest>, -) -> Response { - match grant { - GrantRequest::AuthorizationCode { - code, - client_id, - redirect_uri, - code_verifier - } => { - let request: AuthorizationRequest = match backend.get_code(&code).await { - Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), - Err(err) => { - tracing::error!("Error retrieving auth request: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - if client_id != request.client_id { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() - } - if redirect_uri != request.redirect_uri { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() - } - if !request.code_challenge.verify(code_verifier) { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The PKCE challenge failed.".to_string()), - // are RFCs considered human-readable? 
😝 - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response() - } - let me: url::Url = format!("https://{}/", host).parse().unwrap(); - if request.me.unwrap() != me { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() - } - let profile = if request.scope.as_ref() - .map(|s| s.has(&Scope::Profile)) - .unwrap_or_default() - { - match get_profile( - db, - me.as_str(), - request.scope.as_ref() - .map(|s| s.has(&Scope::Email)) - .unwrap_or_default() - ).await { - Ok(profile) => { - tracing::debug!("Retrieved profile: {:?}", profile); - profile - }, - Err(err) => { - tracing::error!("Error retrieving profile from database: {}", err); - - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } - } else { - None - }; - - GrantResponse::ProfileUrl { me, profile }.into_response() - }, - _ => Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided grant_type is unusable on this endpoint.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code".parse().ok() - }.into_response() - } -} - -#[tracing::instrument(skip(backend, db))] -async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( - Host(host): Host, - Extension(backend): Extension<A>, - Extension(db): Extension<D>, - Form(grant): Form<GrantRequest>, -) -> Response { - #[inline] - fn prepare_access_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { - TokenData { - me, client_id, scope, - exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) - .as_secs() - .into(), - iat: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - .into() - } - } - - #[inline] - fn prepare_refresh_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { - TokenData { - me, client_id, scope, - exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) - .as_secs() - .into(), - iat: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - .into() - } - } - - let me: url::Url = format!("https://{}/", host).parse().unwrap(); - - match grant { - GrantRequest::AuthorizationCode { - code, - client_id, - redirect_uri, - code_verifier - } => { - let request: AuthorizationRequest = match backend.get_code(&code).await { - Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), - Err(err) => { - tracing::error!("Error retrieving auth request: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - tracing::debug!("Retrieved authorization request: {:?}", request); - - let scope = if let Some(scope) = request.scope { scope } else { - return Error { - kind: ErrorKind::InvalidScope, - msg: Some("Tokens cannot be issued if no scopes are requested.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#access-token-response".parse().ok() - }.into_response(); - }; - if client_id != request.client_id { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() - } - if redirect_uri != 
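
Both grant handlers check PKCE via request.code_challenge.verify(code_verifier). For the advertised S256 method, that amounts to BASE64URL(SHA256(verifier)) without padding; a standalone sketch assuming the sha2 crate and base64 0.21, using the test vector from RFC 7636, appendix B:

    use base64::Engine;
    use sha2::{Digest, Sha256};

    fn s256_challenge(verifier: &str) -> String {
        // BASE64URL-ENCODE(SHA256(ASCII(code_verifier))), no padding.
        base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(Sha256::digest(verifier.as_bytes()))
    }

    fn main() {
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        assert_eq!(
            s256_challenge(verifier),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }
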
request.redirect_uri { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() - } - if !request.code_challenge.verify(code_verifier) { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The PKCE challenge failed.".to_string()), - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response(); - } - - // Note: we can trust the `request.me` value, since we set - // it earlier before generating the authorization code - if request.me.unwrap() != me { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() - } - - let profile = if dbg!(scope.has(&Scope::Profile)) { - match get_profile( - db, - me.as_str(), - scope.has(&Scope::Email) - ).await { - Ok(profile) => dbg!(profile), - Err(err) => { - tracing::error!("Error retrieving profile from database: {}", err); - - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } - } else { - None - }; - - let access_token = match backend.create_token( - prepare_access_token(me.clone(), client_id.clone(), scope.clone()) - ).await { - Ok(token) => token, - Err(err) => { - tracing::error!("Error creating access token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - // TODO: only create refresh token if user allows it - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(me.clone(), client_id, scope.clone()) - ).await { - Ok(token) => token, - Err(err) => { - tracing::error!("Error creating refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - GrantResponse::AccessToken { - me, - profile, - access_token, - token_type: kittybox_indieauth::TokenType::Bearer, - scope: Some(scope), - expires_in: Some(ACCESS_TOKEN_VALIDITY), - refresh_token: Some(refresh_token) - }.into_response() - }, - GrantRequest::RefreshToken { - refresh_token, - client_id, - scope - } => { - let data = match backend.get_refresh_token(&me, &refresh_token).await { - Ok(Some(token)) => token, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This refresh token is not valid.".to_string()), - error_uri: None - }.into_response(), - Err(err) => { - tracing::error!("Error retrieving refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - }; - - if data.client_id != client_id { - return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This refresh token is not yours.".to_string()), - error_uri: None - }.into_response(); - } - - let scope = if let Some(scope) = scope { - if !data.scope.has_all(scope.as_ref()) { - return Error { - kind: ErrorKind::InvalidScope, - msg: Some("You can't request additional scopes through the refresh token grant.".to_string()), - error_uri: None - }.into_response(); - } - - scope - } else { - // Note: check skipped because of redundancy (comparing a scope list with itself) - data.scope - }; - - - let profile = if scope.has(&Scope::Profile) { - match get_profile( - db, - data.me.as_str(), - scope.has(&Scope::Email) - ).await { - Ok(profile) => profile, - Err(err) => { - tracing::error!("Error retrieving profile from database: {}", err); - - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } - } else { - None - }; - - let access_token = match backend.create_token( - 
prepare_access_token(data.me.clone(), client_id.clone(), scope.clone()) - ).await { - Ok(token) => token, - Err(err) => { - tracing::error!("Error creating access token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - let old_refresh_token = refresh_token; - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(data.me.clone(), client_id, scope.clone()) - ).await { - Ok(token) => token, - Err(err) => { - tracing::error!("Error creating refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - if let Err(err) = backend.revoke_refresh_token(&me, &old_refresh_token).await { - tracing::error!("Error revoking refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - - GrantResponse::AccessToken { - me: data.me, - profile, - access_token, - token_type: kittybox_indieauth::TokenType::Bearer, - scope: Some(scope), - expires_in: Some(ACCESS_TOKEN_VALIDITY), - refresh_token: Some(refresh_token) - }.into_response() - } - } -} - -#[tracing::instrument(skip(backend, token_request))] -async fn introspection_endpoint_post<A: AuthBackend>( - Host(host): Host, - TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>, - Extension(backend): Extension<A>, - Form(token_request): Form<TokenIntrospectionRequest>, -) -> Response { - use serde_json::json; - - let me: url::Url = format!("https://{}/", host).parse().unwrap(); - - // Check authentication first - match backend.get_token(&me, auth_token.token()).await { - Ok(Some(token)) => if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); - }, - Ok(None) => return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(), - Err(err) => { - tracing::error!("Error retrieving token data for introspection: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } - let response: TokenIntrospectionResponse = match backend.get_token(&me, &token_request.token).await { - Ok(maybe_data) => maybe_data.into(), - Err(err) => { - tracing::error!("Error retrieving token data: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - }; - - response.into_response() -} - -async fn revocation_endpoint_post<A: AuthBackend>( - Host(host): Host, - Extension(backend): Extension<A>, - Form(revocation): Form<TokenRevocationRequest>, -) -> impl IntoResponse { - let me: url::Url = format!("https://{}/", host).parse().unwrap(); - - if let Err(err) = tokio::try_join!( - backend.revoke_token(&me, &revocation.token), - backend.revoke_refresh_token(&me, &revocation.token) - ) { - tracing::error!("Error revoking token: {}", err); - - StatusCode::INTERNAL_SERVER_ERROR - } else { - StatusCode::OK - } -} - -async fn get_profile<D: Storage + 'static>( - db: D, - url: &str, - email: bool -) -> crate::database::Result<Option<Profile>> { - Ok(db.get_post(url).await?.map(|mut mf2| { - // Ruthlessly manually destructure the MF2 document to save memory - let name = match mf2["properties"]["name"][0].take() { - serde_json::Value::String(s) => Some(s), - _ => None - }; - let url = match mf2["properties"]["uid"][0].take() { - serde_json::Value::String(s) => s.parse().ok(), - _ => None - }; - let photo = match mf2["properties"]["photo"][0].take() { - serde_json::Value::String(s) => s.parse().ok(), - _ => 
None - }; - let email = if email { - match mf2["properties"]["email"][0].take() { - serde_json::Value::String(s) => Some(s), - _ => None - } - } else { - None - }; - - Profile { name, url, photo, email } - })) -} - -async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( - Host(host): Host, - TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>, - Extension(backend): Extension<A>, - Extension(db): Extension<D> -) -> Response { - use serde_json::json; - - let me: url::Url = format!("https://{}/", host).parse().unwrap(); - - match backend.get_token(&me, auth_token.token()).await { - Ok(Some(token)) => { - if token.expired() { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(); - } - if !token.scope.has(&Scope::Profile) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); - } - - match get_profile(db, me.as_str(), token.scope.has(&Scope::Email)).await { - Ok(Some(profile)) => profile.into_response(), - Ok(None) => Json(json!({ - // We do this because ResourceErrorKind is IndieAuth errors only - "error": "invalid_request" - })).into_response(), - Err(err) => { - tracing::error!("Error retrieving profile from database: {}", err); - - StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } - }, - Ok(None) => Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - })).into_response(), - Err(err) => { - tracing::error!("Error reading token: {}", err); - - StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } -} - -#[must_use] -pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: reqwest::Client) -> axum::Router { - use axum::routing::{Router, get, post}; - - Router::new() - .nest( - "/.kittybox/indieauth", - Router::new() - .route("/metadata", - get(metadata)) - .route( - "/auth", - get(authorization_endpoint_get::<A, D>) - .post(authorization_endpoint_post::<A, D>)) - .route( - "/auth/confirm", - post(authorization_endpoint_confirm::<A>)) - .route( - "/token", - post(token_endpoint_post::<A, D>)) - .route( - "/token_status", - post(introspection_endpoint_post::<A>)) - .route( - "/revoke_token", - post(revocation_endpoint_post::<A>)) - .route( - "/userinfo", - get(userinfo_endpoint_get::<A, D>)) - - .route("/webauthn/pre_register", - get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, D>, - #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) - ) - ) - .layer(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST - ]) - .allow_origin(tower_http::cors::Any)) - .layer(Extension(backend)) - // I don't really like the fact that I have to use the whole database - // If I could, I would've designed a separate trait for getting profiles - // And made databases implement it, for example - .layer(Extension(db)) - .layer(Extension(http)) - ) - .route( - "/.well-known/oauth-authorization-server", - get(|| std::future::ready( - (StatusCode::FOUND, - [("Location", - "/.kittybox/indieauth/metadata")] - ).into_response() - )) - ) -} - -#[cfg(test)] -mod tests { - #[test] - fn test_deserialize_authorization_confirmation() { - use super::{Credential, AuthorizationConfirmation}; - - let confirmation = serde_json::from_str::<AuthorizationConfirmation>(r#"{ - "request":{ - "response_type": "code", - "client_id": "https://quill.p3k.io/", - 
"redirect_uri": "https://quill.p3k.io/", - "state": "10101010", - "code_challenge": "awooooooooooo", - "code_challenge_method": "S256", - "scope": "create+media" - }, - "authorization_method": "swordfish" - }"#).unwrap(); - - match confirmation.authorization_method { - Credential::Password(password) => assert_eq!(password.as_str(), "swordfish"), - #[allow(unreachable_patterns)] - other => panic!("Incorrect credential: {:?}", other) - } - assert_eq!(confirmation.request.state.as_ref(), "10101010"); - } -} diff --git a/kittybox-rs/src/indieauth/webauthn.rs b/kittybox-rs/src/indieauth/webauthn.rs deleted file mode 100644 index ea3ad3d..0000000 --- a/kittybox-rs/src/indieauth/webauthn.rs +++ /dev/null @@ -1,140 +0,0 @@ -use axum::{ - extract::{Json, Host}, - response::{IntoResponse, Response}, - http::StatusCode, Extension, TypedHeader, headers::{authorization::Bearer, Authorization} -}; -use axum_extra::extract::cookie::{CookieJar, Cookie}; - -use super::backend::AuthBackend; -use crate::database::Storage; - -pub(crate) const CHALLENGE_ID_COOKIE: &str = "kittybox_webauthn_challenge_id"; - -macro_rules! bail { - ($msg:literal, $err:expr) => { - { - ::tracing::error!($msg, $err); - return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } -} - -pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( - Host(host): Host, - Extension(db): Extension<D>, - Extension(auth): Extension<A>, - cookies: CookieJar -) -> Response { - let uid = format!("https://{}/", host.clone()); - let uid_url: url::Url = uid.parse().unwrap(); - // This will not find an h-card in onboarding! - let display_name = match db.get_post(&uid).await { - Ok(hcard) => match hcard { - Some(mut hcard) => { - match hcard["properties"]["uid"][0].take() { - serde_json::Value::String(name) => name, - _ => String::default() - } - }, - None => String::default() - }, - Err(err) => bail!("Error retrieving h-card: {}", err) - }; - - let webauthn = webauthn::WebauthnBuilder::new( - &host, - &uid_url - ) - .unwrap() - .rp_name("Kittybox") - .build() - .unwrap(); - - let (challenge, state) = match webauthn.start_passkey_registration( - // Note: using a nil uuid here is fine - // Because the user corresponds to a website anyway - // We do not track multiple users - webauthn::prelude::Uuid::nil(), - &uid, - &display_name, - Some(vec![]) - ) { - Ok((challenge, state)) => (challenge, state), - Err(err) => bail!("Error generating WebAuthn registration data: {}", err) - }; - - match auth.persist_registration_challenge(&uid_url, state).await { - Ok(challenge_id) => ( - cookies.add( - Cookie::build(CHALLENGE_ID_COOKIE, challenge_id) - .secure(true) - .finish() - ), - Json(challenge) - ).into_response(), - Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err) - } -} - -pub async fn webauthn_register<A: AuthBackend>( - Host(host): Host, - Json(credential): Json<webauthn::prelude::RegisterPublicKeyCredential>, - // TODO determine if we can use a cookie maybe? 
- user_credential: Option<TypedHeader<Authorization<Bearer>>>, - Extension(auth): Extension<A> -) -> Response { - let uid = format!("https://{}/", host.clone()); - let uid_url: url::Url = uid.parse().unwrap(); - - let pubkeys = match auth.list_webauthn_pubkeys(&uid_url).await { - Ok(pubkeys) => pubkeys, - Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err) - }; - - if !pubkeys.is_empty() { - if let Some(TypedHeader(Authorization(token))) = user_credential { - // TODO check validity of the credential - } else { - return StatusCode::UNAUTHORIZED.into_response() - } - } - - return StatusCode::OK.into_response() -} - -pub(crate) async fn verify<A: AuthBackend>( - auth: &A, - website: &url::Url, - credential: webauthn::prelude::PublicKeyCredential, - challenge_id: &str -) -> std::io::Result<bool> { - let host = website.host_str().unwrap(); - - let webauthn = webauthn::WebauthnBuilder::new( - host, - website - ) - .unwrap() - .rp_name("Kittybox") - .build() - .unwrap(); - - match webauthn.finish_passkey_authentication( - &credential, - &auth.retrieve_authentication_challenge(&website, challenge_id).await? - ) { - Err(err) => { - tracing::error!("WebAuthn error: {}", err); - Ok(false) - }, - Ok(authentication_result) => { - let counter = authentication_result.counter(); - let cred_id = authentication_result.cred_id(); - - if authentication_result.needs_update() { - todo!() - } - Ok(true) - } - } -} diff --git a/kittybox-rs/src/lib.rs b/kittybox-rs/src/lib.rs deleted file mode 100644 index c1bd965..0000000 --- a/kittybox-rs/src/lib.rs +++ /dev/null @@ -1,93 +0,0 @@ -#![forbid(unsafe_code)] -#![warn(clippy::todo)] - -/// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. -pub mod database; -pub mod frontend; -pub mod media; -pub mod micropub; -pub mod indieauth; -pub mod webmentions; - -pub mod companion { - use std::{collections::HashMap, sync::Arc}; - use axum::{ - extract::{Extension, Path}, - response::{IntoResponse, Response} - }; - - #[derive(Debug, Clone, Copy)] - struct Resource { - data: &'static [u8], - mime: &'static str - } - - impl IntoResponse for &Resource { - fn into_response(self) -> Response { - (axum::http::StatusCode::OK, - [("Content-Type", self.mime)], - self.data).into_response() - } - } - - // TODO replace with the "phf" crate someday - type ResourceTable = Arc<HashMap<&'static str, Resource>>; - - #[tracing::instrument] - async fn map_to_static( - Path(name): Path<String>, - Extension(resources): Extension<ResourceTable> - ) -> Response { - tracing::debug!("Searching for {} in the resource table...", name); - match resources.get(name.as_str()) { - Some(res) => res.into_response(), - None => { - #[cfg(debug_assertions)] tracing::error!("Not found"); - - (axum::http::StatusCode::NOT_FOUND, - [("Content-Type", "text/plain")], - "Not found. Sorry.".as_bytes()).into_response() - } - } - } - - #[must_use] - pub fn router() -> axum::Router { - let resources: ResourceTable = { - let mut map = HashMap::new(); - - macro_rules! register_resource { - ($map:ident, $prefix:expr, ($filename:literal, $mime:literal)) => {{ - $map.insert($filename, Resource { - data: include_bytes!(concat!($prefix, $filename)), - mime: $mime - }) - }}; - ($map:ident, $prefix:expr, ($filename:literal, $mime:literal), $( ($f:literal, $m:literal) ),+) => {{ - register_resource!($map, $prefix, ($filename, $mime)); - register_resource!($map, $prefix, $(($f, $m)),+); - }}; - } - - register_resource! 
{ - map, - concat!(env!("OUT_DIR"), "/", "companion", "/"), - ("index.html", "text/html; charset=\"utf-8\""), - ("main.js", "text/javascript"), - ("micropub_api.js", "text/javascript"), - ("indieauth.js", "text/javascript"), - ("base64.js", "text/javascript"), - ("style.css", "text/css") - }; - - Arc::new(map) - }; - - axum::Router::new() - .route( - "/:filename", - axum::routing::get(map_to_static) - .layer(Extension(resources)) - ) - } -} diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs deleted file mode 100644 index 6389489..0000000 --- a/kittybox-rs/src/main.rs +++ /dev/null @@ -1,489 +0,0 @@ -use kittybox::database::FileStorage; -use std::{env, time::Duration, sync::Arc}; -use tracing::error; - -fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router { - match blobstore_uri.split_once(':').unwrap().0 { - "file" => { - let folder = std::path::PathBuf::from( - blobstore_uri.strip_prefix("file://").unwrap() - ); - let blobstore = kittybox::media::storage::file::FileStore::new(folder); - - kittybox::media::router::<_, _>(blobstore, auth_backend) - }, - other => unimplemented!("Unsupported backend: {other}") - } -} - -async fn compose_kittybox_with_auth<A>( - http: reqwest::Client, - auth_backend: A, - backend_uri: &str, - blobstore_uri: &str, - job_queue_uri: &str, - jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, - cancellation_token: &tokio_util::sync::CancellationToken -) -> (axum::Router, kittybox::webmentions::SupervisedTask) -where A: kittybox::indieauth::backend::AuthBackend -{ - match backend_uri.split_once(':').unwrap().0 { - "file" => { - let database = { - let folder = backend_uri.strip_prefix("file://").unwrap(); - let path = std::path::PathBuf::from(folder); - - match kittybox::database::FileStorage::new(path).await { - Ok(db) => db, - Err(err) => { - error!("Error creating database: {:?}", err); - std::process::exit(1); - } - } - }; - - // Technically, if we don't construct the micropub router, - // we could use some wrapper that makes the database - // read-only. - // - // This would allow to exclude all code to write to the - // database and separate reader and writer processes of - // Kittybox to improve security. 
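
init_media and compose_kittybox_with_auth above both dispatch on the URI scheme with split_once(':').unwrap().0, which panics when no colon is present. A standalone reduction of the pattern that instead falls back to the whole string:

    fn scheme_of(uri: &str) -> &str {
        uri.split_once(':').map(|(scheme, _)| scheme).unwrap_or(uri)
    }

    fn main() {
        assert_eq!(scheme_of("file:///var/lib/kittybox"), "file");
        assert_eq!(scheme_of("postgres://localhost/kittybox"), "postgres");
    }
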
- let homepage: axum::routing::MethodRouter<_> = axum::routing::get( - kittybox::frontend::homepage::<FileStorage> - ) - .layer(axum::Extension(database.clone())); - let fallback = axum::routing::get( - kittybox::frontend::catchall::<FileStorage> - ) - .layer(axum::Extension(database.clone())); - - let micropub = kittybox::micropub::router( - database.clone(), - http.clone(), - auth_backend.clone(), - Arc::clone(jobset) - ); - let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone(), Arc::clone(jobset) - ); - - - let (webmention, task) = kittybox::webmentions::router( - kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), - database.clone(), - http.clone(), - cancellation_token.clone() - ); - - let router = axum::Router::new() - .route("/", homepage) - .fallback(fallback) - .route("/.kittybox/micropub", micropub) - .route("/.kittybox/onboarding", onboarding) - .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) - .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) - .merge(webmention) - .route( - "/.kittybox/health", - axum::routing::get(health_check::<kittybox::database::FileStorage>) - .layer(axum::Extension(database)) - ); - - (router, task) - }, - "redis" => unimplemented!("Redis backend is not supported."), - #[cfg(feature = "postgres")] - "postgres" => { - use kittybox::database::PostgresStorage; - - let database = { - match PostgresStorage::new(backend_uri).await { - Ok(db) => db, - Err(err) => { - error!("Error creating database: {:?}", err); - std::process::exit(1); - } - } - }; - - // Technically, if we don't construct the micropub router, - // we could use some wrapper that makes the database - // read-only. - // - // This would allow to exclude all code to write to the - // database and separate reader and writer processes of - // Kittybox to improve security. 
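
The comment above (repeated verbatim in both backend arms) floats the idea of a read-only wrapper around the database. A minimal sketch of the shape such a wrapper could take; the two-method `Storage` trait here is a toy stand-in, not the crate's actual trait:

    #[derive(Clone)]
    struct ReadOnly<S>(S);

    trait Storage {
        fn get_post(&self, url: &str) -> Result<Option<String>, String>;
        fn put_post(&self, post: &str) -> Result<(), String>;
    }

    impl<S: Storage> Storage for ReadOnly<S> {
        fn get_post(&self, url: &str) -> Result<Option<String>, String> {
            self.0.get_post(url) // reads pass through untouched
        }
        fn put_post(&self, _post: &str) -> Result<(), String> {
            Err("storage is mounted read-only".to_string()) // writes always fail
        }
    }

Handing `ReadOnly<S>` to the frontend routes and the bare `S` to the Micropub router would then separate the reader and writer paths at the type level, as the comment suggests.
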
- let homepage: axum::routing::MethodRouter<_> = axum::routing::get( - kittybox::frontend::homepage::<PostgresStorage> - ) - .layer(axum::Extension(database.clone())); - let fallback = axum::routing::get( - kittybox::frontend::catchall::<PostgresStorage> - ) - .layer(axum::Extension(database.clone())); - - let micropub = kittybox::micropub::router( - database.clone(), - http.clone(), - auth_backend.clone(), - Arc::clone(jobset) - ); - let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone(), Arc::clone(jobset) - ); - - let (webmention, task) = kittybox::webmentions::router( - kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), - database.clone(), - http.clone(), - cancellation_token.clone() - ); - - let router = axum::Router::new() - .route("/", homepage) - .fallback(fallback) - .route("/.kittybox/micropub", micropub) - .route("/.kittybox/onboarding", onboarding) - .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) - .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) - .merge(webmention) - .route( - "/.kittybox/health", - axum::routing::get(health_check::<kittybox::database::PostgresStorage>) - .layer(axum::Extension(database)) - ); - - (router, task) - }, - other => unimplemented!("Unsupported backend: {other}") - } -} - -async fn compose_kittybox( - backend_uri: &str, - blobstore_uri: &str, - authstore_uri: &str, - job_queue_uri: &str, - jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, - cancellation_token: &tokio_util::sync::CancellationToken -) -> (axum::Router, kittybox::webmentions::SupervisedTask) { - let http: reqwest::Client = { - #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); - if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { - // TODO: add a root certificate if there's an environment variable pointing at it - for path in certs.split(':') { - let metadata = match tokio::fs::metadata(path).await { - Ok(metadata) => metadata, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - tracing::error!("TLS root certificate {} not found, skipping...", path); - continue; - } - Err(err) => panic!("Error loading TLS certificates: {}", err) - }; - if metadata.is_dir() { - let mut dir = tokio::fs::read_dir(path).await.unwrap(); - while let Ok(Some(file)) = dir.next_entry().await { - let pem = tokio::fs::read(file.path()).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); - } - } else { - let pem = tokio::fs::read(path).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); - } - } - } - - builder.build().unwrap() - }; - - let (router, task) = match authstore_uri.split_once(':').unwrap().0 { - "file" => { - let auth_backend = { - let folder = authstore_uri - .strip_prefix("file://") - .unwrap(); - kittybox::indieauth::backend::fs::FileBackend::new(folder) - }; - - compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, job_queue_uri, jobset, cancellation_token).await - } - other => unimplemented!("Unsupported backend: {other}") - }; - - let router = router - .route( - "/.kittybox/static/:path", - axum::routing::get(kittybox::frontend::statics) - ) - .route("/.kittybox/coffee", teapot_route()) - .nest("/.kittybox/micropub/client", kittybox::companion::router()) - 
.layer(tower_http::trace::TraceLayer::new_for_http()) - .layer(tower_http::catch_panic::CatchPanicLayer::new()); - - (router, task) -} - -fn teapot_route() -> axum::routing::MethodRouter { - axum::routing::get(|| async { - use axum::http::{header, StatusCode}; - (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") - }) -} - -async fn health_check</*A, B, */D>( - //axum::Extension(auth): axum::Extension<A>, - //axum::Extension(blob): axum::Extension<B>, - axum::Extension(data): axum::Extension<D>, -) -> impl axum::response::IntoResponse -where - //A: kittybox::indieauth::backend::AuthBackend, - //B: kittybox::media::storage::MediaStore, - D: kittybox::database::Storage -{ - (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) -} - -#[tokio::main] -async fn main() { - use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; - - let tracing_registry = Registry::default() - .with(EnvFilter::from_default_env()) - .with( - #[cfg(debug_assertions)] - tracing_tree::HierarchicalLayer::new(2) - .with_bracketed_fields(true) - .with_indent_lines(true) - .with_verbose_exit(true), - #[cfg(not(debug_assertions))] - tracing_subscriber::fmt::layer().json() - .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) - ); - // In debug builds, also log to JSON, but to file. - #[cfg(debug_assertions)] - let tracing_registry = tracing_registry.with( - tracing_subscriber::fmt::layer() - .json() - .with_writer({ - let instant = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap(); - move || std::fs::OpenOptions::new() - .append(true) - .create(true) - .open( - format!( - "{}.log.json", - instant - .as_secs_f64() - .to_string() - .replace('.', "_") - ) - ).unwrap() - }) - ); - tracing_registry.init(); - - tracing::info!("Starting the kittybox server..."); - - let backend_uri: String = env::var("BACKEND_URI") - .unwrap_or_else(|_| { - error!("BACKEND_URI is not set, cannot find a database"); - std::process::exit(1); - }); - let blobstore_uri: String = env::var("BLOBSTORE_URI") - .unwrap_or_else(|_| { - error!("BLOBSTORE_URI is not set, can't find media store"); - std::process::exit(1); - }); - - let authstore_uri: String = env::var("AUTH_STORE_URI") - .unwrap_or_else(|_| { - error!("AUTH_STORE_URI is not set, can't find authentication store"); - std::process::exit(1); - }); - - let job_queue_uri: String = env::var("JOB_QUEUE_URI") - .unwrap_or_else(|_| { - error!("JOB_QUEUE_URI is not set, can't find job queue"); - std::process::exit(1); - }); - - let cancellation_token = tokio_util::sync::CancellationToken::new(); - let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new())); - - let (router, webmentions_task) = compose_kittybox( - backend_uri.as_str(), - blobstore_uri.as_str(), - authstore_uri.as_str(), - job_queue_uri.as_str(), - &jobset, - &cancellation_token - ).await; - - let mut servers: Vec<hyper::server::Server<hyper::server::conn::AddrIncoming, _>> = vec![]; - - let build_hyper = |tcp: std::net::TcpListener| { - tracing::info!("Listening on {}", tcp.local_addr().unwrap()); - // Set the socket to non-blocking so tokio can poll it - // properly -- this is the async magic! 
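
The closure below adopts pre-bound listeners, so the non-blocking flag has to be flipped by hand before tokio can poll the socket. A condensed, self-contained version of the same accept-socket setup, assuming the hyper 0.14-era APIs and the listenfd crate used in this file:

    use std::net::TcpListener;

    fn acquire_listener() -> TcpListener {
        // Prefer a socket inherited from the environment (systemd-style
        // socket activation); fall back to binding a fresh one.
        let mut fds = listenfd::ListenFd::from_env();
        let tcp = match fds.take_tcp_listener(0) {
            Ok(Some(tcp)) => tcp,
            _ => TcpListener::bind("[::]:8080").expect("failed to bind fallback address"),
        };
        // Inherited sockets start out blocking; tokio requires non-blocking mode.
        tcp.set_nonblocking(true).expect("failed to set non-blocking mode");
        tcp
    }

    #[tokio::main]
    async fn main() {
        let router = axum::Router::new()
            .route("/", axum::routing::get(|| async { "ok" }));
        hyper::server::Server::from_tcp(acquire_listener())
            .expect("failed to adopt listener")
            .serve(router.into_make_service())
            .await
            .expect("server error");
    }
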
- tcp.set_nonblocking(true).unwrap(); - - hyper::server::Server::from_tcp(tcp).unwrap() - // Otherwise Chrome keeps connections open for too long - .tcp_keepalive(Some(Duration::from_secs(30 * 60))) - .serve(router.clone().into_make_service()) - }; - - let mut listenfd = listenfd::ListenFd::from_env(); - for i in 0..(listenfd.len()) { - match listenfd.take_tcp_listener(i) { - Ok(Some(tcp)) => servers.push(build_hyper(tcp)), - Ok(None) => {}, - Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) - } - } - // TODO this requires the `hyperlocal` crate - //#[rustfmt::skip] - /*#[cfg(unix)] { - let build_hyper_unix = |unix: std::os::unix::net::UnixListener| { - { - use std::os::linux::net::SocketAddrExt; - - let local_addr = unix.local_addr().unwrap(); - if let Some(pathname) = local_addr.as_pathname() { - tracing::info!("Listening on unix:{}", pathname.display()); - } else if let Some(name) = { - #[cfg(linux)] - local_addr.as_abstract_name(); - #[cfg(not(linux))] - None::<&[u8]> - } { - tracing::info!("Listening on unix:@{}", String::from_utf8_lossy(name)); - } else { - tracing::info!("Listening on unnamed unix socket"); - } - } - unix.set_nonblocking(true).unwrap(); - - hyper::server::Server::builder(unix) - .serve(router.clone().into_make_service()) - }; - for i in 0..(listenfd.len()) { - match listenfd.take_unix_listener(i) { - Ok(Some(unix)) => servers.push(build_hyper_unix(unix)), - Ok(None) => {}, - Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) - } - } - }*/ - if servers.is_empty() { - servers.push(build_hyper({ - let listen_addr = env::var("SERVE_AT") - .ok() - .unwrap_or_else(|| "[::]:8080".to_string()) - .parse::<std::net::SocketAddr>() - .unwrap_or_else(|e| { - error!("Cannot parse SERVE_AT: {}", e); - std::process::exit(1); - }); - - std::net::TcpListener::bind(listen_addr).unwrap() - })) - } - // Drop the remaining copy of the router - // to get rid of an extra reference to `jobset` - drop(router); - // Polling streams mutates them - let mut servers_futures = Box::pin(servers.into_iter() - .map( - #[cfg(not(tokio_unstable))] |server| tokio::task::spawn( - server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) - ), - #[cfg(tokio_unstable)] |server| { - tokio::task::Builder::new() - .name(format!("Kittybox HTTP acceptor: {}", server.local_addr()).as_str()) - .spawn( - server.with_graceful_shutdown( - cancellation_token.clone().cancelled_owned() - ) - ) - .unwrap() - } - ) - .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), hyper::Error>>>>() - ); - - #[cfg(not(unix))] - let shutdown_signal = tokio::signal::ctrl_c(); - #[cfg(unix)] - let shutdown_signal = { - use tokio::signal::unix::{signal, SignalKind}; - - async move { - let mut interrupt = signal(SignalKind::interrupt()) - .expect("Failed to set up SIGINT handler"); - let mut terminate = signal(SignalKind::terminate()) - .expect("Failed to setup SIGTERM handler"); - - tokio::select! { - _ = terminate.recv() => {}, - _ = interrupt.recv() => {}, - } - } - }; - use futures_util::stream::StreamExt; - - let exitcode: i32 = tokio::select! { - // Poll the servers stream for errors. 
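
The select! further below fans in three shutdown sources: server errors, the watchdog token, and OS signals. The token half of that pattern in isolation (tokio plus tokio-util, matching this file's dependencies):

    use tokio_util::sync::CancellationToken;

    #[tokio::main]
    async fn main() {
        let token = CancellationToken::new();

        // A background worker that stops as soon as the token fires.
        let worker = tokio::spawn({
            let token = token.clone();
            async move {
                tokio::select! {
                    _ = token.cancelled() => println!("worker: shutting down"),
                    _ = tokio::time::sleep(std::time::Duration::from_secs(3600)) => {}
                }
            }
        });

        // Cancel on Ctrl-C, or observe a cancellation requested elsewhere.
        tokio::select! {
            _ = tokio::signal::ctrl_c() => println!("main: got Ctrl-C"),
            _ = token.cancelled() => println!("main: watchdog fired"),
        }
        token.cancel();
        worker.await.unwrap();
    }
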
- // If any error out, shut down the entire operation - // - // We do this because there might not be a good way - // to recover from some errors without external help - Some(Err(e)) = servers_futures.next() => { - tracing::error!("Error in HTTP server: {}", e); - tracing::error!("Shutting down because of error."); - cancellation_token.cancel(); - - 1 - } - _ = cancellation_token.cancelled() => { - tracing::info!("Signal caught from watchdog."); - - 0 - } - _ = shutdown_signal => { - tracing::info!("Shutdown requested by signal."); - cancellation_token.cancel(); - - 0 - } - }; - - tracing::info!("Waiting for unfinished background tasks..."); - - let _ = tokio::join!( - webmentions_task, - Box::pin(futures_util::future::join_all( - servers_futures.iter_mut().collect::<Vec<_>>() - )), - ); - let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset) - .expect("Dangling jobset references present") - .into_inner(); - while (jobset.join_next().await).is_some() {} - tracing::info!("Shutdown complete, exiting."); - std::process::exit(exitcode); - -} diff --git a/kittybox-rs/src/media/mod.rs b/kittybox-rs/src/media/mod.rs deleted file mode 100644 index 71f875e..0000000 --- a/kittybox-rs/src/media/mod.rs +++ /dev/null @@ -1,141 +0,0 @@ -use std::convert::TryFrom; - -use axum::{ - extract::{Extension, Host, multipart::Multipart, Path}, - response::{IntoResponse, Response}, - headers::{Header, HeaderValue, IfNoneMatch, HeaderMapExt}, - TypedHeader, -}; -use kittybox_util::error::{MicropubError, ErrorType}; -use kittybox_indieauth::Scope; -use crate::indieauth::{User, backend::AuthBackend}; - -pub mod storage; -use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; -pub use storage::file::FileStore; - -impl From<MediaStoreError> for MicropubError { - fn from(err: MediaStoreError) -> Self { - Self { - error: ErrorType::InternalServerError, - error_description: format!("{}", err) - } - } -} - -#[tracing::instrument(skip(blobstore))] -pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( - Extension(blobstore): Extension<S>, - user: User<A>, - mut upload: Multipart -) -> Response { - if !user.check_scope(&Scope::Media) { - return MicropubError { - error: ErrorType::NotAuthorized, - error_description: "Interacting with the media storage requires the \"media\" scope.".to_owned() - }.into_response(); - } - let host = user.me.host().unwrap().to_string() + &user.me.port().map(|i| format!(":{}", i)).unwrap_or_default(); - let field = match upload.next_field().await { - Ok(Some(field)) => field, - Ok(None) => { - return MicropubError { - error: ErrorType::InvalidRequest, - error_description: "Send multipart/form-data with one field named file".to_owned() - }.into_response(); - }, - Err(err) => { - return MicropubError { - error: ErrorType::InternalServerError, - error_description: format!("Error while parsing multipart/form-data: {}", err) - }.into_response(); - }, - }; - let metadata: Metadata = (&field).into(); - match blobstore.write_streaming(&host, metadata, field).await { - Ok(filename) => IntoResponse::into_response(( - axum::http::StatusCode::CREATED, - [ - ("Location", user.me.join( - &format!(".kittybox/media/uploads/{}", filename) - ).unwrap().as_str()) - ] - )), - Err(err) => MicropubError::from(err).into_response() - } -} - -#[tracing::instrument(skip(blobstore))] -pub(crate) async fn serve<S: MediaStore>( - Host(host): Host, - Path(path): Path<String>, - if_none_match: Option<TypedHeader<IfNoneMatch>>, - Extension(blobstore): Extension<S> -) -> Response { - use 
axum::http::StatusCode; - tracing::debug!("Searching for file..."); - match blobstore.read_streaming(&host, path.as_str()).await { - Ok((metadata, stream)) => { - tracing::debug!("Metadata: {:?}", metadata); - - let etag = if let Some(etag) = metadata.etag { - let etag = format!("\"{}\"", etag).parse::<axum::headers::ETag>().unwrap(); - - if let Some(TypedHeader(if_none_match)) = if_none_match { - tracing::debug!("If-None-Match: {:?}", if_none_match); - // If-None-Match is a negative precondition that - // returns 304 when it doesn't match because it - // only matches when file is different - if !if_none_match.precondition_passes(&etag) { - return StatusCode::NOT_MODIFIED.into_response() - } - } - - Some(etag) - } else { None }; - - let mut r = Response::builder(); - { - let headers = r.headers_mut().unwrap(); - headers.insert( - "Content-Type", - HeaderValue::from_str( - metadata.content_type - .as_deref() - .unwrap_or("application/octet-stream") - ).unwrap() - ); - if let Some(length) = metadata.length { - headers.insert( - "Content-Length", - HeaderValue::from_str(&length.to_string()).unwrap() - ); - } - if let Some(etag) = etag { - headers.typed_insert(etag); - } - } - r.body(axum::body::StreamBody::new(stream)) - .unwrap() - .into_response() - }, - Err(err) => match err.kind() { - ErrorKind::NotFound => { - IntoResponse::into_response(StatusCode::NOT_FOUND) - }, - _ => { - tracing::error!("{}", err); - IntoResponse::into_response(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } -} - -#[must_use] -pub fn router<S: MediaStore, A: AuthBackend>(blobstore: S, auth: A) -> axum::Router { - axum::Router::new() - .route("/", axum::routing::post(upload::<S, A>)) - .route("/uploads/*file", axum::routing::get(serve::<S>)) - .layer(axum::Extension(blobstore)) - .layer(axum::Extension(auth)) -} diff --git a/kittybox-rs/src/media/storage/file.rs b/kittybox-rs/src/media/storage/file.rs deleted file mode 100644 index 0aaaa3b..0000000 --- a/kittybox-rs/src/media/storage/file.rs +++ /dev/null @@ -1,434 +0,0 @@ -use super::{Metadata, ErrorKind, MediaStore, MediaStoreError, Result}; -use async_trait::async_trait; -use std::{path::PathBuf, fmt::Debug}; -use tokio::fs::OpenOptions; -use tokio::io::{BufReader, BufWriter, AsyncWriteExt, AsyncSeekExt}; -use futures::{StreamExt, TryStreamExt}; -use std::ops::{Bound, RangeBounds, Neg}; -use std::pin::Pin; -use sha2::Digest; -use futures::FutureExt; -use tracing::{debug, error}; - -const BUF_CAPACITY: usize = 16 * 1024; - -#[derive(Clone)] -pub struct FileStore { - base: PathBuf, -} - -impl From<tokio::io::Error> for MediaStoreError { - fn from(source: tokio::io::Error) -> Self { - Self { - msg: format!("file I/O error: {}", source), - kind: match source.kind() { - std::io::ErrorKind::NotFound => ErrorKind::NotFound, - _ => ErrorKind::Backend - }, - source: Some(Box::new(source)), - } - } -} - -impl FileStore { - pub fn new<T: Into<PathBuf>>(base: T) -> Self { - Self { base: base.into() } - } - - async fn mktemp(&self) -> Result<(PathBuf, BufWriter<tokio::fs::File>)> { - kittybox_util::fs::mktemp(&self.base, "temp", 16) - .await - .map(|(name, file)| (name, BufWriter::new(file))) - .map_err(Into::into) - } -} - -#[async_trait] -impl MediaStore for FileStore { - - #[tracing::instrument(skip(self, content))] - async fn write_streaming<T>( - &self, - domain: &str, - mut metadata: Metadata, - mut content: T, - ) -> Result<String> - where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug 
- { - let (tempfilepath, mut tempfile) = self.mktemp().await?; - debug!("Temporary file opened for storing pending upload: {}", tempfilepath.display()); - let mut hasher = sha2::Sha256::new(); - let mut length: usize = 0; - - while let Some(chunk) = content.next().await { - let chunk = chunk.map_err(|err| MediaStoreError { - kind: ErrorKind::Backend, - source: Some(Box::new(err)), - msg: "Failed to read a data chunk".to_owned() - })?; - debug!("Read {} bytes from the stream", chunk.len()); - length += chunk.len(); - let (write_result, _hasher) = tokio::join!( - { - let chunk = chunk.clone(); - let tempfile = &mut tempfile; - async move { - tempfile.write_all(&*chunk).await - } - }, - { - let chunk = chunk.clone(); - tokio::task::spawn_blocking(move || { - hasher.update(&*chunk); - - hasher - }).map(|r| r.unwrap()) - } - ); - if let Err(err) = write_result { - error!("Error while writing pending upload: {}", err); - drop(tempfile); - // this is just cleanup, nothing fails if it fails - // though temporary files might take up space on the hard drive - // We'll clean them when maintenance time comes - #[allow(unused_must_use)] - { tokio::fs::remove_file(tempfilepath).await; } - return Err(err.into()); - } - hasher = _hasher; - } - // Manually flush the buffer and drop the handle to close the file - tempfile.flush().await?; - tempfile.into_inner().sync_all().await?; - - let hash = hasher.finalize(); - debug!("Pending upload hash: {}", hex::encode(&hash)); - let filename = format!( - "{}/{}/{}/{}/{}", - hex::encode([hash[0]]), - hex::encode([hash[1]]), - hex::encode([hash[2]]), - hex::encode([hash[3]]), - hex::encode(&hash[4..32]) - ); - let domain_str = domain.to_string(); - let filepath = self.base.join(domain_str.as_str()).join(&filename); - let metafilename = filename.clone() + ".json"; - let metapath = self.base.join(domain_str.as_str()).join(&metafilename); - let metatemppath = self.base.join(domain_str.as_str()).join(metafilename + ".tmp"); - metadata.length = std::num::NonZeroUsize::new(length); - metadata.etag = Some(hex::encode(&hash)); - debug!("File path: {}, metadata: {}", filepath.display(), metapath.display()); - { - let parent = filepath.parent().unwrap(); - tokio::fs::create_dir_all(parent).await?; - } - let mut meta = OpenOptions::new() - .create_new(true) - .write(true) - .open(&metatemppath) - .await?; - meta.write_all(&serde_json::to_vec(&metadata).unwrap()).await?; - tokio::fs::rename(tempfilepath, filepath).await?; - tokio::fs::rename(metatemppath, metapath).await?; - Ok(filename) - } - - #[tracing::instrument(skip(self))] - async fn read_streaming( - &self, - domain: &str, - filename: &str, - ) -> Result<(Metadata, Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>)> { - debug!("Domain: {}, filename: {}", domain, filename); - let path = self.base.join(domain).join(filename); - debug!("Path: {}", path.display()); - - let file = OpenOptions::new() - .read(true) - .open(path) - .await?; - let meta = self.metadata(domain, filename).await?; - - Ok((meta, Box::pin( - tokio_util::io::ReaderStream::new( - // TODO: determine if BufReader provides benefit here - // From the logs it looks like we're reading 4KiB at a time - // Buffering file contents seems to double download speed - // How to benchmark this? 
- BufReader::with_capacity(BUF_CAPACITY, file) - ) - // Sprinkle some salt in form of protective log wrapping - .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())) - ))) - } - - #[tracing::instrument(skip(self))] - async fn metadata(&self, domain: &str, filename: &str) -> Result<Metadata> { - let metapath = self.base.join(domain).join(format!("{}.json", filename)); - debug!("Metadata path: {}", metapath.display()); - - let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?) - .map_err(|err| MediaStoreError { - kind: ErrorKind::Json, - msg: format!("{}", err), - source: Some(Box::new(err)) - })?; - - Ok(meta) - } - - #[tracing::instrument(skip(self))] - async fn stream_range( - &self, - domain: &str, - filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> { - let path = self.base.join(format!("{}/{}", domain, filename)); - let metapath = self.base.join(format!("{}/{}.json", domain, filename)); - debug!("Path: {}, metadata: {}", path.display(), metapath.display()); - - let mut file = OpenOptions::new() - .read(true) - .open(path) - .await?; - - let start = match range { - (Bound::Included(bound), _) => { - debug!("Seeking {} bytes forward...", bound); - file.seek(std::io::SeekFrom::Start(bound)).await? - } - (Bound::Excluded(_), _) => unreachable!(), - (Bound::Unbounded, Bound::Included(bound)) => { - // Seek to the end minus the bounded bytes - debug!("Seeking {} bytes back from the end...", bound); - file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())).await? - }, - (Bound::Unbounded, Bound::Unbounded) => 0, - (_, Bound::Excluded(_)) => unreachable!() - }; - - let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity(BUF_CAPACITY, file))) - .map_ok({ - let mut bytes_read = 0usize; - let len = match range { - (_, Bound::Unbounded) => None, - (Bound::Unbounded, Bound::Included(bound)) => Some(bound), - (_, Bound::Included(bound)) => Some(bound + 1 - start), - (_, Bound::Excluded(_)) => unreachable!() - }; - move |chunk| { - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - if let Some(len) = len.map(|len| len.try_into().unwrap()) { - if bytes_read > len { - if bytes_read - len > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - len); - return Some(chunk.slice(..chunk.len() - (bytes_read - len))) - } - } - - Some(chunk) - } - }) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - // Will never panic, because the moment the stream yields - // a None, it is considered exhausted. - .map_ok(|x| x.unwrap()); - - return Ok(Box::pin(stream)) - } - - - async fn delete(&self, domain: &str, filename: &str) -> Result<()> { - let path = self.base.join(format!("{}/{}", domain, filename)); - - Ok(tokio::fs::remove_file(path).await?) 
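
The seek-and-truncate arithmetic in `stream_range` is the fiddly part: inclusive ends, suffix ranges, and off-by-one traps. The same bound-to-(offset, length) mapping as a tiny pure function, with the cases spelled out; `Excluded` bounds never occur for HTTP ranges, mirroring the `unreachable!()` arms above:

    use std::ops::Bound;

    /// Map an HTTP-style byte range onto (offset, length) in a file of `len` bytes.
    fn resolve_range(range: (Bound<u64>, Bound<u64>), len: u64) -> (u64, u64) {
        match range {
            // bytes=start-      → from `start` to the end of the file
            (Bound::Included(start), Bound::Unbounded) => (start, len - start),
            // bytes=start-end   → inclusive on both sides
            (Bound::Included(start), Bound::Included(end)) => (start, end + 1 - start),
            // bytes=-suffix     → the last `suffix` bytes
            (Bound::Unbounded, Bound::Included(suffix)) => (len - suffix, suffix),
            // no range at all   → the whole file
            (Bound::Unbounded, Bound::Unbounded) => (0, len),
            (Bound::Excluded(_), _) | (_, Bound::Excluded(_)) => unreachable!(),
        }
    }

    fn main() {
        // Matches the middle-of-file case in the tests below: 300 bytes from offset 150.
        assert_eq!(resolve_range((Bound::Included(150), Bound::Included(449)), 1000), (150, 300));
        // Matches the suffix semantics noted in the tests: the last 300 bytes.
        assert_eq!(resolve_range((Bound::Unbounded, Bound::Included(300)), 1000), (700, 300));
    }
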
- } -} - -#[cfg(test)] -mod tests { - use super::{Metadata, FileStore, MediaStore}; - use std::ops::Bound; - use tokio::io::AsyncReadExt; - - #[tokio::test] - #[tracing_test::traced_test] - async fn test_ranges() { - let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore::new(tempdir.path()); - - let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); - let metadata = Metadata { - filename: Some("file.rs".to_string()), - content_type: Some("text/plain".to_string()), - length: None, - etag: None, - }; - - // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); - - tracing::debug!("Writing complete."); - - // Ensure the file is there - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); - assert_eq!(content, file); - - tracing::debug!("Reading range from the start..."); - // try to read range - let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(0), Bound::Included(299)) - ).await.unwrap(); - - let mut reader = tokio_util::io::StreamReader::new(stream); - - let mut buf = Vec::default(); - reader.read_to_end(&mut buf).await.unwrap(); - - buf - }; - - assert_eq!(range.len(), 300); - assert_eq!(range.as_slice(), &file[..=299]); - - tracing::debug!("Reading range from the middle..."); - - let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(150), Bound::Included(449)) - ).await.unwrap(); - - let mut reader = tokio_util::io::StreamReader::new(stream); - - let mut buf = Vec::default(); - reader.read_to_end(&mut buf).await.unwrap(); - - buf - }; - - assert_eq!(range.len(), 300); - assert_eq!(range.as_slice(), &file[150..=449]); - - tracing::debug!("Reading range from the end..."); - let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - // Note: the `headers` crate parses bounds in a - // non-standard way, where unbounded start actually - // means getting things from the end... 
- (Bound::Unbounded, Bound::Included(300)) - ).await.unwrap(); - - let mut reader = tokio_util::io::StreamReader::new(stream); - - let mut buf = Vec::default(); - reader.read_to_end(&mut buf).await.unwrap(); - - buf - }; - - assert_eq!(range.len(), 300); - assert_eq!(range.as_slice(), &file[file.len()-300..file.len()]); - - tracing::debug!("Reading the whole file..."); - // try to read range - let range = { - let stream = store.stream_range( - "fireburn.ru", &("/".to_string() + &filename), - (Bound::Unbounded, Bound::Unbounded) - ).await.unwrap(); - - let mut reader = tokio_util::io::StreamReader::new(stream); - - let mut buf = Vec::default(); - reader.read_to_end(&mut buf).await.unwrap(); - - buf - }; - - assert_eq!(range.len(), file.len()); - assert_eq!(range.as_slice(), file); - } - - - #[tokio::test] - #[tracing_test::traced_test] - async fn test_streaming_read_write() { - let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore::new(tempdir.path()); - - let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); - let metadata = Metadata { - filename: Some("style.css".to_string()), - content_type: Some("text/css".to_string()), - length: None, - etag: None, - }; - - // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); - println!("{}, {}", filename, tempdir.path() - .join("fireburn.ru") - .join(&filename) - .display()); - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); - assert_eq!(content, file); - - // check internal metadata format - let meta: Metadata = serde_json::from_slice(&tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(filename.clone() + ".json") - ).await.unwrap()).unwrap(); - assert_eq!(meta.content_type.as_deref(), Some("text/css")); - assert_eq!(meta.filename.as_deref(), Some("style.css")); - assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); - assert!(meta.etag.is_some()); - - // read back the data using the interface - let (metadata, read_back) = { - let (metadata, stream) = store.read_streaming( - "fireburn.ru", - &filename - ).await.unwrap(); - let mut reader = tokio_util::io::StreamReader::new(stream); - - let mut buf = Vec::default(); - reader.read_to_end(&mut buf).await.unwrap(); - - (metadata, buf) - }; - - assert_eq!(read_back, file); - assert_eq!(metadata.content_type.as_deref(), Some("text/css")); - assert_eq!(meta.filename.as_deref(), Some("style.css")); - assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); - assert!(meta.etag.is_some()); - - } -} diff --git a/kittybox-rs/src/media/storage/mod.rs b/kittybox-rs/src/media/storage/mod.rs deleted file mode 100644 index 020999c..0000000 --- a/kittybox-rs/src/media/storage/mod.rs +++ /dev/null @@ -1,177 +0,0 @@ -use async_trait::async_trait; -use axum::extract::multipart::Field; -use tokio_stream::Stream; -use bytes::Bytes; -use serde::{Deserialize, Serialize}; -use std::ops::Bound; -use std::pin::Pin; -use std::fmt::Debug; -use std::num::NonZeroUsize; - -pub mod file; - -#[derive(Debug, Deserialize, Serialize)] -pub struct Metadata { - /// Content type of the file. If None, the content-type is considered undefined. - pub content_type: Option<String>, - /// The original filename that was passed. - pub filename: Option<String>, - /// The recorded length of the file. - pub length: Option<NonZeroUsize>, - /// The e-tag of a file. 
Note: it must be a strong e-tag, for example, a hash. - pub etag: Option<String>, -} -impl From<&Field<'_>> for Metadata { - fn from(field: &Field<'_>) -> Self { - Self { - content_type: field.content_type() - .map(|i| i.to_owned()), - filename: field.file_name() - .map(|i| i.to_owned()), - length: None, - etag: None, - } - } -} - - -#[derive(Debug, Clone, Copy)] -pub enum ErrorKind { - Backend, - Permission, - Json, - NotFound, - Other, -} - -#[derive(Debug)] -pub struct MediaStoreError { - kind: ErrorKind, - source: Option<Box<dyn std::error::Error + Send + Sync>>, - msg: String, -} - -impl MediaStoreError { - pub fn kind(&self) -> ErrorKind { - self.kind - } -} - -impl std::error::Error for MediaStoreError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source - .as_ref() - .map(|i| i.as_ref() as &dyn std::error::Error) - } -} - -impl std::fmt::Display for MediaStoreError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}: {}", - match self.kind { - ErrorKind::Backend => "media storage backend error", - ErrorKind::Permission => "permission denied", - ErrorKind::Json => "failed to parse json", - ErrorKind::NotFound => "blob not found", - ErrorKind::Other => "unknown media storage error", - }, - self.msg - ) - } -} - -pub type Result<T> = std::result::Result<T, MediaStoreError>; - -#[async_trait] -pub trait MediaStore: 'static + Send + Sync + Clone { - async fn write_streaming<T>( - &self, - domain: &str, - metadata: Metadata, - content: T, - ) -> Result<String> - where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug; - - async fn read_streaming( - &self, - domain: &str, - filename: &str, - ) -> Result<(Metadata, Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>)>; - - async fn stream_range( - &self, - domain: &str, - filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>> { - use futures::stream::TryStreamExt; - use tracing::debug; - let (metadata, mut stream) = self.read_streaming(domain, filename).await?; - let length = metadata.length.unwrap().get(); - - use Bound::*; - let (start, end): (usize, usize) = match range { - (Unbounded, Unbounded) => return Ok(stream), - (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), - (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), - (Included(start), Included(end)) => (start.try_into().unwrap(), end.try_into().unwrap()), - (_, _) => unreachable!() - }; - - stream = Box::pin( - stream.map_ok({ - let mut bytes_skipped = 0usize; - let mut bytes_read = 0usize; - - move |chunk| { - debug!("Skipped {}/{} bytes, chunk len {}", bytes_skipped, start, chunk.len()); - let chunk = if bytes_skipped < start { - let need_to_skip = start - bytes_skipped; - if chunk.len() < need_to_skip { - return None - } - debug!("Skipping {} bytes", need_to_skip); - bytes_skipped += need_to_skip; - - chunk.slice(need_to_skip..) 
- } else { - chunk - }; - - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - - if bytes_read > length { - if bytes_read - length > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - length); - return Some(chunk.slice(..chunk.len() - (bytes_read - length))) - } - - Some(chunk) - } - }) - .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - .map_ok(|x| x.unwrap()) - ); - - return Ok(stream); - } - - /// Read metadata for a file. - /// - /// The default implementation uses the `read_streaming` method - /// and drops the stream containing file content. - async fn metadata(&self, domain: &str, filename: &str) -> Result<Metadata> { - self.read_streaming(domain, filename) - .await - .map(|(meta, stream)| meta) - } - - async fn delete(&self, domain: &str, filename: &str) -> Result<()>; -} diff --git a/kittybox-rs/src/metrics.rs b/kittybox-rs/src/metrics.rs deleted file mode 100644 index e13fcb9..0000000 --- a/kittybox-rs/src/metrics.rs +++ /dev/null @@ -1,21 +0,0 @@ -#![allow(unused_imports, dead_code)] -use async_trait::async_trait; -use lazy_static::lazy_static; -use prometheus::Encoder; -use std::time::{Duration, Instant}; - -// TODO: Vendor in the Metrics struct from warp_prometheus and rework the path matching algorithm - -pub fn metrics(path_includes: Vec<String>) -> warp::log::Log<impl Fn(warp::log::Info) + Clone> { - let metrics = warp_prometheus::Metrics::new(prometheus::default_registry(), &path_includes); - warp::log::custom(move |info| metrics.http_metrics(info)) -} - -pub fn gather() -> Vec<u8> { - let mut buffer: Vec<u8> = vec![]; - let encoder = prometheus::TextEncoder::new(); - let metric_families = prometheus::gather(); - encoder.encode(&metric_families, &mut buffer).unwrap(); - - buffer -} diff --git a/kittybox-rs/src/micropub/get.rs b/kittybox-rs/src/micropub/get.rs deleted file mode 100644 index 718714a..0000000 --- a/kittybox-rs/src/micropub/get.rs +++ /dev/null @@ -1,82 +0,0 @@ -use crate::database::{MicropubChannel, Storage}; -use crate::indieauth::User; -use crate::ApplicationState; -use tide::prelude::{json, Deserialize}; -use tide::{Request, Response, Result}; - -#[derive(Deserialize)] -struct QueryOptions { - q: String, - url: Option<String>, -} - -pub async fn get_handler<Backend>(req: Request<ApplicationState<Backend>>) -> Result -where - Backend: Storage + Send + Sync, -{ - let user = req.ext::<User>().unwrap(); - let backend = &req.state().storage; - let media_endpoint = &req.state().media_endpoint; - let query = req.query::<QueryOptions>().unwrap_or(QueryOptions { - q: "".to_string(), - url: None, - }); - match &*query.q { - "config" => { - let channels: Vec<MicropubChannel>; - match backend.get_channels(user.me.as_str()).await { - Ok(chans) => channels = chans, - Err(err) => return Ok(err.into()) - } - Ok(Response::builder(200).body(json!({ - "q": ["source", "config", "channel"], - "channels": channels, - "media-endpoint": media_endpoint - })).build()) - }, - "channel" => { - let channels: Vec<MicropubChannel>; - match backend.get_channels(user.me.as_str()).await { - Ok(chans) => channels = chans, - Err(err) => return Ok(err.into()) - } - Ok(Response::builder(200).body(json!(channels)).build()) - } - "source" => { - if user.check_scope("create") || user.check_scope("update") || user.check_scope("delete") || user.check_scope("undelete") { - if let Some(url) = query.url { - match backend.get_post(&url).await { - 
Ok(post) => if let Some(post) = post { - Ok(Response::builder(200).body(post).build()) - } else { - Ok(Response::builder(404).build()) - }, - Err(err) => Ok(err.into()) - } - } else { - Ok(Response::builder(400).body(json!({ - "error": "invalid_request", - "error_description": "Please provide `url`." - })).build()) - } - } else { - Ok(Response::builder(401).body(json!({ - "error": "insufficient_scope", - "error_description": "You don't have the required scopes to proceed.", - "scope": "update" - })).build()) - } - }, - // TODO: ?q=food, ?q=geo, ?q=contacts - // Depends on indexing posts - // Errors - "" => Ok(Response::builder(400).body(json!({ - "error": "invalid_request", - "error_description": "No ?q= parameter specified. Try ?q=config maybe?" - })).build()), - _ => Ok(Response::builder(400).body(json!({ - "error": "invalid_request", - "error_description": "Unsupported ?q= query. Try ?q=config and see the q array for supported values." - })).build()) - } -} diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs deleted file mode 100644 index 02eee6e..0000000 --- a/kittybox-rs/src/micropub/mod.rs +++ /dev/null @@ -1,846 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use crate::database::{MicropubChannel, Storage, StorageError}; -use crate::indieauth::backend::AuthBackend; -use crate::indieauth::User; -use crate::micropub::util::form_to_mf2_json; -use axum::extract::{BodyStream, Query, Host}; -use axum::headers::ContentType; -use axum::response::{IntoResponse, Response}; -use axum::TypedHeader; -use axum::{http::StatusCode, Extension}; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use tokio::sync::Mutex; -use tokio::task::JoinSet; -use tracing::{debug, error, info, warn}; -use kittybox_indieauth::{Scope, TokenData}; -use kittybox_util::{MicropubError, ErrorType}; - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -enum QueryType { - Source, - Config, - Channel, - SyndicateTo, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct MicropubQuery { - q: QueryType, - url: Option<String>, -} - -impl From<StorageError> for MicropubError { - fn from(err: StorageError) -> Self { - Self { - error: match err.kind() { - crate::database::ErrorKind::NotFound => ErrorType::NotFound, - _ => ErrorType::InternalServerError, - }, - error_description: format!("Backend error: {}", err), - } - } -} - -mod util; -pub(crate) use util::normalize_mf2; - -#[derive(Debug)] -struct FetchedPostContext { - url: url::Url, - mf2: serde_json::Value, - webmention: Option<url::Url>, -} - -fn populate_reply_context( - mf2: &serde_json::Value, - prop: &str, - ctxs: &[FetchedPostContext], -) -> Option<Vec<serde_json::Value>> { - mf2["properties"][prop].as_array().map(|array| { - array - .iter() - // TODO: This seems to be O(n^2) and I don't like it. 
- // Switching `ctxs` to a hashmap might speed it up to O(n) - // The key would be the URL/UID - .map(|i| ctxs - .iter() - .find(|ctx| Some(ctx.url.as_str()) == i.as_str()) - .and_then(|ctx| ctx.mf2["items"].get(0)) - .unwrap_or(i)) - .cloned() - .collect::<Vec<serde_json::Value>>() - }) -} - -#[tracing::instrument(skip(db))] -async fn background_processing<D: 'static + Storage>( - db: D, - mf2: serde_json::Value, - http: reqwest::Client, -) -> () { - // TODO: Post-processing the post (aka second write pass) - // - [x] Download rich reply contexts - // - [ ] Syndicate the post if requested, add links to the syndicated copies - // - [ ] Send WebSub notifications to the hub (if we happen to have one) - // - [x] Send webmentions - - use futures_util::StreamExt; - - let uid: &str = mf2["properties"]["uid"][0].as_str().unwrap(); - - let context_props = ["in-reply-to", "like-of", "repost-of", "bookmark-of"]; - let mut context_urls: Vec<url::Url> = vec![]; - for prop in &context_props { - if let Some(array) = mf2["properties"][prop].as_array() { - context_urls.extend( - array - .iter() - .filter_map(|v| v.as_str()) - .filter_map(|v| v.parse::<url::Url>().ok()), - ); - } - } - // TODO parse HTML in e-content and add links found here - context_urls.sort_unstable_by_key(|u| u.to_string()); - context_urls.dedup(); - - // TODO: Make a stream to fetch all these posts and convert them to MF2 - let post_contexts = { - let http = &http; - tokio_stream::iter(context_urls.into_iter()) - .then(move |url: url::Url| http.get(url).send()) - .filter_map(|response| futures::future::ready(response.ok())) - .filter(|response| futures::future::ready(response.status() == 200)) - .filter_map(|response: reqwest::Response| async move { - // 1. We need to preserve the URL - // 2. We need to get the HTML for MF2 processing - // 3. We need to get the webmention endpoint address - // All of that can be done in one go. - let url = response.url().clone(); - // TODO parse link headers - let links = response - .headers() - .get_all(hyper::http::header::LINK) - .iter() - .cloned() - .collect::<Vec<hyper::http::HeaderValue>>(); - let html = response.text().await; - if html.is_err() { - return None; - } - let html = html.unwrap(); - let mf2 = microformats::from_html(&html, url.clone()).unwrap(); - // TODO use first Link: header if available - let webmention: Option<url::Url> = mf2 - .rels - .by_rels() - .get("webmention") - .and_then(|i| i.first().cloned()); - - dbg!(Some(FetchedPostContext { - url, - mf2: serde_json::to_value(mf2).unwrap(), - webmention - })) - }) - .collect::<Vec<FetchedPostContext>>() - .await - }; - - let mut update = MicropubUpdate { - replace: Some(Default::default()), - ..Default::default() - }; - for prop in context_props { - if let Some(json) = populate_reply_context(&mf2, prop, &post_contexts) { - update.replace.as_mut().unwrap().insert(prop.to_owned(), json); - } - } - if !update.replace.as_ref().unwrap().is_empty() { - if let Err(err) = db.update_post(uid, update).await { - error!("Failed to update post with rich reply contexts: {}", err); - } - } - - // At this point we can start syndicating the post. - // Currently we don't really support any syndication endpoints, but still! 
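
The TODO inside `populate_reply_context` above suggests keying the fetched contexts by URL to avoid the quadratic scan. A self-contained sketch of that variant, with plain serde_json values standing in for `FetchedPostContext`:

    use std::collections::HashMap;

    fn populate_reply_context(
        mf2: &serde_json::Value,
        prop: &str,
        ctxs: &HashMap<String, serde_json::Value>, // URL → first fetched MF2 item
    ) -> Option<Vec<serde_json::Value>> {
        mf2["properties"][prop].as_array().map(|array| {
            array
                .iter()
                .map(|i| {
                    i.as_str()
                        .and_then(|url| ctxs.get(url)) // O(1) lookup instead of a scan
                        .unwrap_or(i)
                })
                .cloned()
                .collect()
        })
    }

    fn main() {
        let mut ctxs = HashMap::new();
        ctxs.insert(
            "https://example.com/post".to_string(),
            serde_json::json!({"type": ["h-entry"]}),
        );
        let mf2 = serde_json::json!({
            "properties": { "like-of": ["https://example.com/post"] }
        });
        let expanded = populate_reply_context(&mf2, "like-of", &ctxs).unwrap();
        assert_eq!(expanded[0], serde_json::json!({"type": ["h-entry"]}));
    }
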
- /*if let Some(syndicate_to) = mf2["properties"]["mp-syndicate-to"].as_array() { - let http = &http; - tokio_stream::iter(syndicate_to) - .filter_map(|i| futures::future::ready(i.as_str())) - .for_each_concurrent(3, |s: &str| async move { - #[allow(clippy::match_single_binding)] - match s { - _ => { - todo!("Syndicate to generic webmention-aware service {}", s); - } - // TODO special handling for non-webmention-aware services like the birdsite - } - }) - .await; - }*/ - - { - let http = &http; - tokio_stream::iter( - post_contexts - .into_iter() - .filter(|ctx| ctx.webmention.is_some()), - ) - .for_each_concurrent(2, |ctx| async move { - let mut map = std::collections::HashMap::new(); - map.insert("source", uid); - map.insert("target", ctx.url.as_str()); - - match http - .post(ctx.webmention.unwrap().clone()) - .form(&map) - .send() - .await - { - Ok(res) => { - if !res.status().is_success() { - warn!( - "Failed to send a webmention for {}: got HTTP {}", - ctx.url, - res.status() - ); - } else { - info!( - "Sent a webmention to {}, got HTTP {}", - ctx.url, - res.status() - ) - } - } - Err(err) => warn!("Failed to send a webmention for {}: {}", ctx.url, err), - } - }) - .await; - } -} - -// TODO actually save the post to the database and schedule post-processing -pub(crate) async fn _post<D: 'static + Storage>( - user: &TokenData, - uid: String, - mf2: serde_json::Value, - db: D, - http: reqwest::Client, - jobset: Arc<Mutex<JoinSet<()>>>, -) -> Result<Response, MicropubError> { - // Here, we have the following guarantees: - // - The MF2-JSON document is normalized (guaranteed by normalize_mf2) - // - The MF2-JSON document contains a UID - // - The MF2-JSON document's URL list contains its UID - // - The MF2-JSON document's "content" field contains an HTML blob, if present - // - The MF2-JSON document's publishing datetime is present - // - The MF2-JSON document's target channels are set - // - The MF2-JSON document's author is set - - // Security check! Do we have an OAuth2 scope to proceed? - if !user.check_scope(&Scope::Create) { - return Err(MicropubError { - error: ErrorType::InvalidScope, - error_description: "Not enough privileges - try acquiring the \"create\" scope." - .to_owned(), - }); - } - - // Security check #2! Are we posting to our own website? - if !uid.starts_with(user.me.as_str()) - || mf2["properties"]["channel"] - .as_array() - .unwrap_or(&vec![]) - .iter() - .any(|url| !url.as_str().unwrap().starts_with(user.me.as_str())) - { - return Err(MicropubError { - error: ErrorType::Forbidden, - error_description: "You're posting to a website that's not yours.".to_owned(), - }); - } - - // Security check #3! Are we overwriting an existing document? - if db.post_exists(&uid).await? 
{ - return Err(MicropubError { - error: ErrorType::AlreadyExists, - error_description: "UID clash was detected, operation aborted.".to_owned(), - }); - } - let user_domain = format!( - "{}{}", - user.me.host_str().unwrap(), - user.me.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() - ); - // Save the post - tracing::debug!("Saving post to database..."); - db.put_post(&mf2, &user_domain).await?; - - let mut channels = mf2["properties"]["channel"] - .as_array() - .unwrap() - .iter() - .map(|i| i.as_str().unwrap_or("")) - .filter(|i| !i.is_empty()); - - let default_channel = user - .me - .join(util::DEFAULT_CHANNEL_PATH) - .unwrap() - .to_string(); - let vcards_channel = user - .me - .join(util::CONTACTS_CHANNEL_PATH) - .unwrap() - .to_string(); - let food_channel = user.me.join(util::FOOD_CHANNEL_PATH).unwrap().to_string(); - let default_channels = vec![default_channel, vcards_channel, food_channel]; - - for chan in &mut channels { - debug!("Adding post {} to channel {}", uid, chan); - if db.post_exists(chan).await? { - db.add_to_feed(chan, &uid).await?; - } else if default_channels.iter().any(|i| chan == i) { - util::create_feed(&db, &uid, chan, user).await?; - } else { - warn!("Ignoring non-existent channel: {}", chan); - } - } - - let reply = - IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); - - #[cfg(not(tokio_unstable))] - jobset.lock().await.spawn(background_processing(db, mf2, http)); - #[cfg(tokio_unstable)] - jobset.lock().await.build_task() - .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) - .spawn(background_processing(db, mf2, http)); - - Ok(reply) -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -enum ActionType { - Delete, - Update, -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(untagged)] -pub enum MicropubPropertyDeletion { - Properties(Vec<String>), - Values(HashMap<String, Vec<serde_json::Value>>) -} -#[derive(Serialize, Deserialize)] -struct MicropubFormAction { - action: ActionType, - url: String, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct MicropubAction { - action: ActionType, - url: String, - #[serde(flatten)] - #[serde(skip_serializing_if = "Option::is_none")] - update: Option<MicropubUpdate> -} - -#[derive(Serialize, Deserialize, Debug, Default)] -pub struct MicropubUpdate { - #[serde(skip_serializing_if = "Option::is_none")] - pub replace: Option<HashMap<String, Vec<serde_json::Value>>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub add: Option<HashMap<String, Vec<serde_json::Value>>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub delete: Option<MicropubPropertyDeletion>, - -} - -impl From<MicropubFormAction> for MicropubAction { - fn from(a: MicropubFormAction) -> Self { - debug_assert!(matches!(a.action, ActionType::Delete)); - Self { - action: a.action, - url: a.url, - update: None - } - } -} - -#[tracing::instrument(skip(db))] -async fn post_action<D: Storage, A: AuthBackend>( - action: MicropubAction, - db: D, - user: User<A>, -) -> Result<(), MicropubError> { - let uri = if let Ok(uri) = action.url.parse::<hyper::Uri>() { - uri - } else { - return Err(MicropubError { - error: ErrorType::InvalidRequest, - error_description: "Your URL doesn't parse properly.".to_owned(), - }); - }; - - if uri.authority().unwrap() - != user - .me - .as_str() - .parse::<hyper::Uri>() - .unwrap() - .authority() - .unwrap() - { - return Err(MicropubError { - error: ErrorType::Forbidden, - error_description: "Don't 
tamper with others' posts!".to_owned(), - }); - } - - match action.action { - ActionType::Delete => { - if !user.check_scope(&Scope::Delete) { - return Err(MicropubError { - error: ErrorType::InvalidScope, - error_description: "You need a \"delete\" scope for this.".to_owned(), - }); - } - - db.delete_post(&action.url).await? - } - ActionType::Update => { - if !user.check_scope(&Scope::Update) { - return Err(MicropubError { - error: ErrorType::InvalidScope, - error_description: "You need an \"update\" scope for this.".to_owned(), - }); - } - - db.update_post( - &action.url, - action.update.ok_or(MicropubError { - error: ErrorType::InvalidRequest, - error_description: "Update request is not set.".to_owned(), - })? - ) - .await? - } - } - - Ok(()) -} - -enum PostBody { - Action(MicropubAction), - MF2(serde_json::Value), -} - -#[tracing::instrument] -async fn dispatch_body( - mut body: BodyStream, - content_type: ContentType, -) -> Result<PostBody, MicropubError> { - let body: Vec<u8> = { - debug!("Buffering body..."); - use tokio_stream::StreamExt; - let mut buf = Vec::default(); - - while let Some(chunk) = body.next().await { - buf.extend_from_slice(&chunk.unwrap()) - } - - buf - }; - - debug!("Content-Type: {:?}", content_type); - if content_type == ContentType::json() { - if let Ok(action) = serde_json::from_slice::<MicropubAction>(&body) { - Ok(PostBody::Action(action)) - } else if let Ok(body) = serde_json::from_slice::<serde_json::Value>(&body) { - // quick sanity check - if !body.is_object() || !body["type"].is_array() { - return Err(MicropubError { - error: ErrorType::InvalidRequest, - error_description: "Invalid MF2-JSON detected: `.` should be an object, `.type` should be an array of MF2 types".to_owned() - }); - } - - Ok(PostBody::MF2(body)) - } else { - Err(MicropubError { - error: ErrorType::InvalidRequest, - error_description: "Invalid JSON object passed.".to_owned(), - }) - } - } else if content_type == ContentType::form_url_encoded() { - if let Ok(body) = serde_urlencoded::from_bytes::<MicropubFormAction>(&body) { - Ok(PostBody::Action(body.into())) - } else if let Ok(body) = serde_urlencoded::from_bytes::<Vec<(String, String)>>(&body) { - Ok(PostBody::MF2(form_to_mf2_json(body))) - } else { - Err(MicropubError { - error: ErrorType::InvalidRequest, - error_description: "Invalid form-encoded data. Try h=entry&content=Hello!" - .to_owned(), - }) - } - } else { - Err(MicropubError::new( - ErrorType::UnsupportedMediaType, - "This Content-Type is not recognized. 
Try application/json instead?", - )) - } -} - -#[tracing::instrument(skip(db, http))] -pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( - Extension(db): Extension<D>, - Extension(http): Extension<reqwest::Client>, - Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, - TypedHeader(content_type): TypedHeader<ContentType>, - user: User<A>, - body: BodyStream, -) -> axum::response::Response { - match dispatch_body(body, content_type).await { - Ok(PostBody::Action(action)) => match post_action(action, db, user).await { - Ok(()) => Response::default(), - Err(err) => err.into_response(), - }, - Ok(PostBody::MF2(mf2)) => { - let (uid, mf2) = normalize_mf2(mf2, &user); - match _post(&user, uid, mf2, db, http, jobset).await { - Ok(response) => response, - Err(err) => err.into_response(), - } - } - Err(err) => err.into_response(), - } -} - -#[tracing::instrument(skip(db))] -pub(crate) async fn query<D: Storage, A: AuthBackend>( - Extension(db): Extension<D>, - query: Option<Query<MicropubQuery>>, - Host(host): Host, - user: User<A>, -) -> axum::response::Response { - // We handle the invalid query case manually to return a - // MicropubError instead of HTTP 422 - let query = if let Some(Query(query)) = query { - query - } else { - return MicropubError::new( - ErrorType::InvalidRequest, - "Invalid query provided. Try ?q=config to see what you can do." - ).into_response(); - }; - - if axum::http::Uri::try_from(user.me.as_str()) - .unwrap() - .authority() - .unwrap() - != &host - { - return MicropubError::new( - ErrorType::NotAuthorized, - "This website doesn't belong to you.", - ) - .into_response(); - } - - let user_domain = format!( - "{}{}", - user.me.host_str().unwrap(), - user.me.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() - ); - match query.q { - QueryType::Config => { - let channels: Vec<MicropubChannel> = match db.get_channels(user.me.as_str()).await { - Ok(chans) => chans, - Err(err) => { - return MicropubError::new( - ErrorType::InternalServerError, - &format!("Error fetching channels: {}", err), - ) - .into_response() - } - }; - - axum::response::Json(json!({ - "q": [ - QueryType::Source, - QueryType::Config, - QueryType::Channel, - QueryType::SyndicateTo - ], - "channels": channels, - "_kittybox_authority": user.me.as_str(), - "syndicate-to": [], - "media-endpoint": user.me.join("/.kittybox/media").unwrap().as_str() - })) - .into_response() - } - QueryType::Source => { - match query.url { - Some(url) => { - match db.get_post(&url).await { - Ok(some) => match some { - Some(post) => axum::response::Json(&post).into_response(), - None => MicropubError::new( - ErrorType::NotFound, - "The specified MF2 object was not found in database.", - ) - .into_response(), - }, - Err(err) => MicropubError::new( - ErrorType::InternalServerError, - &format!("Backend error: {}", err), - ) - .into_response(), - } - } - None => { - // Here, one should probably attempt to query at least the main feed and collect posts - // Using a pre-made query function can't be done because it does unneeded filtering - // Don't implement for now, this is optional - MicropubError::new( - ErrorType::InvalidRequest, - "Querying for post list is not implemented yet.", - ) - .into_response() - } - } - } - QueryType::Channel => match db.get_channels(&user_domain).await { - Ok(chans) => axum::response::Json(json!({ "channels": chans })).into_response(), - Err(err) => MicropubError::new( - ErrorType::InternalServerError, - &format!("Error fetching channels: {}", err), - ) - .into_response(), - 
}, - QueryType::SyndicateTo => { - axum::response::Json(json!({ "syndicate-to": [] })).into_response() - } - } -} - -#[must_use] -pub fn router<S, A>( - storage: S, - http: reqwest::Client, - auth: A, - jobset: Arc<Mutex<JoinSet<()>>> -) -> axum::routing::MethodRouter -where - S: Storage + 'static, - A: AuthBackend -{ - axum::routing::get(query::<S, A>) - .post(post::<S, A>) - .layer::<_, _, std::convert::Infallible>(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST, - ]) - .allow_origin(tower_http::cors::Any)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) -} - -#[cfg(test)] -#[allow(dead_code)] -impl MicropubQuery { - fn config() -> Self { - Self { - q: QueryType::Config, - url: None, - } - } - - fn source(url: &str) -> Self { - Self { - q: QueryType::Source, - url: Some(url.to_owned()), - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::{database::Storage, micropub::MicropubError}; - use hyper::body::HttpBody; - use serde_json::json; - use tokio::sync::Mutex; - - use super::FetchedPostContext; - use kittybox_indieauth::{Scopes, Scope, TokenData}; - use axum::extract::Host; - - #[test] - fn test_populate_reply_context() { - let already_expanded_reply_ctx = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Hello world!"] - } - }); - let mf2 = json!({ - "type": ["h-entry"], - "properties": { - "like-of": [ - "https://fireburn.ru/posts/example", - already_expanded_reply_ctx, - "https://fireburn.ru/posts/non-existent" - ] - } - }); - let test_ctx = json!({ - "type": ["h-entry"], - "properties": { - "content": ["This is a post which was reacted to."] - } - }); - let reply_contexts = vec![FetchedPostContext { - url: "https://fireburn.ru/posts/example".parse().unwrap(), - mf2: json!({ "items": [test_ctx] }), - webmention: None, - }]; - - let like_of = super::populate_reply_context(&mf2, "like-of", &reply_contexts).unwrap(); - - assert_eq!(like_of[0], test_ctx); - assert_eq!(like_of[1], already_expanded_reply_ctx); - assert_eq!(like_of[2], "https://fireburn.ru/posts/non-existent"); - } - - #[tokio::test] - async fn test_post_reject_scope() { - let db = crate::database::MemoryStorage::new(); - - let post = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Hello world!"] - } - }); - let user = TokenData { - me: "https://localhost:8080/".parse().unwrap(), - client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile]), - iat: None, exp: None - }; - let (uid, mf2) = super::normalize_mf2(post, &user); - - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) - .await - .unwrap_err(); - - assert_eq!(err.error, super::ErrorType::InvalidScope); - - let hashmap = db.mapping.read().await; - assert!(hashmap.is_empty()); - } - - #[tokio::test] - async fn test_post_reject_different_user() { - let db = crate::database::MemoryStorage::new(); - - let post = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Hello world!"], - "uid": ["https://fireburn.ru/posts/hello"], - "url": ["https://fireburn.ru/posts/hello"] - } - }); - let user = TokenData { - me: "https://aaronparecki.com/".parse().unwrap(), - client_id: 
"https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None - }; - let (uid, mf2) = super::normalize_mf2(post, &user); - - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) - .await - .unwrap_err(); - - assert_eq!(err.error, super::ErrorType::Forbidden); - - let hashmap = db.mapping.read().await; - assert!(hashmap.is_empty()); - } - - #[tokio::test] - async fn test_post_mf2() { - let db = crate::database::MemoryStorage::new(); - - let post = json!({ - "type": ["h-entry"], - "properties": { - "content": ["Hello world!"] - } - }); - let user = TokenData { - me: "https://localhost:8080/".parse().unwrap(), - client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create]), - iat: None, exp: None - }; - let (uid, mf2) = super::normalize_mf2(post, &user); - - let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) - .await - .unwrap(); - - assert!(res.headers().contains_key("Location")); - let location = res.headers().get("Location").unwrap(); - assert!(db.post_exists(location.to_str().unwrap()).await.unwrap()); - assert!(db - .post_exists("https://localhost:8080/feeds/main") - .await - .unwrap()); - } - - #[tokio::test] - async fn test_query_foreign_url() { - let mut res = super::query( - axum::Extension(crate::database::MemoryStorage::new()), - Some(axum::extract::Query(super::MicropubQuery::source( - "https://aaronparecki.com/feeds/main", - ))), - Host("aaronparecki.com".to_owned()), - crate::indieauth::User::<crate::indieauth::backend::fs::FileBackend>( - TokenData { - me: "https://fireburn.ru/".parse().unwrap(), - client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None - }, std::marker::PhantomData - ) - ) - .await; - - assert_eq!(res.status(), 401); - let body = res.body_mut().data().await.unwrap().unwrap(); - let json: MicropubError = serde_json::from_slice(&body as &[u8]).unwrap(); - assert_eq!(json.error, super::ErrorType::NotAuthorized); - } -} diff --git a/kittybox-rs/src/micropub/util.rs b/kittybox-rs/src/micropub/util.rs deleted file mode 100644 index 940d7c3..0000000 --- a/kittybox-rs/src/micropub/util.rs +++ /dev/null @@ -1,444 +0,0 @@ -use crate::database::Storage; -use kittybox_indieauth::TokenData; -use chrono::prelude::*; -use core::iter::Iterator; -use newbase60::num_to_sxg; -use serde_json::json; -use std::convert::TryInto; - -pub(crate) const DEFAULT_CHANNEL_PATH: &str = "/feeds/main"; -const DEFAULT_CHANNEL_NAME: &str = "Main feed"; -pub(crate) const CONTACTS_CHANNEL_PATH: &str = "/feeds/vcards"; -const CONTACTS_CHANNEL_NAME: &str = "My address book"; -pub(crate) const FOOD_CHANNEL_PATH: &str = "/feeds/food"; -const FOOD_CHANNEL_NAME: &str = "My recipe book"; - -fn get_folder_from_type(post_type: &str) -> String { - (match post_type { - "h-feed" => "feeds/", - "h-card" => "vcards/", - "h-event" => "events/", - "h-food" => "food/", - _ => "posts/", - }) - .to_string() -} - -/// Reset the datetime to a proper datetime. -/// Do not attempt to recover the information. -/// Do not pass GO. Do not collect $200. 
-fn reset_dt(post: &mut serde_json::Value) -> DateTime<FixedOffset> { - let curtime: DateTime<Local> = Local::now(); - post["properties"]["published"] = json!([curtime.to_rfc3339()]); - chrono::DateTime::from(curtime) -} - -pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> (String, serde_json::Value) { - // Normalize the MF2 object here. - let me = &user.me; - let folder = get_folder_from_type(body["type"][0].as_str().unwrap()); - let published: DateTime<FixedOffset> = - if let Some(dt) = body["properties"]["published"][0].as_str() { - // Check if the datetime is parsable. - match DateTime::parse_from_rfc3339(dt) { - Ok(dt) => dt, - Err(_) => reset_dt(&mut body), - } - } else { - // Set the datetime. - // Note: this code block duplicates functionality with the above failsafe. - // Consider refactoring it to a helper function? - reset_dt(&mut body) - }; - match body["properties"]["uid"][0].as_str() { - None => { - let uid = serde_json::Value::String( - me.join( - &(folder.clone() - + &num_to_sxg(published.timestamp_millis().try_into().unwrap())), - ) - .unwrap() - .to_string(), - ); - body["properties"]["uid"] = serde_json::Value::Array(vec![uid.clone()]); - match body["properties"]["url"].as_array_mut() { - Some(array) => array.push(uid), - None => body["properties"]["url"] = body["properties"]["uid"].clone(), - } - } - Some(uid_str) => { - let uid = uid_str.to_string(); - match body["properties"]["url"].as_array_mut() { - Some(array) => { - if !array.iter().any(|i| i.as_str().unwrap_or("") == uid) { - array.push(serde_json::Value::String(uid)) - } - } - None => body["properties"]["url"] = body["properties"]["uid"].clone(), - } - } - } - if let Some(slugs) = body["properties"]["mp-slug"].as_array() { - let new_urls = slugs - .iter() - .map(|i| i.as_str().unwrap_or("")) - .filter(|i| i != &"") - .map(|i| me.join(&((&folder).clone() + i)).unwrap().to_string()) - .collect::<Vec<String>>(); - let urls = body["properties"]["url"].as_array_mut().unwrap(); - new_urls.iter().for_each(|i| urls.push(json!(i))); - } - let props = body["properties"].as_object_mut().unwrap(); - props.remove("mp-slug"); - - if body["properties"]["content"][0].is_string() { - // Convert the content to HTML using the `markdown` crate - body["properties"]["content"] = json!([{ - "html": markdown::to_html(body["properties"]["content"][0].as_str().unwrap()), - "value": body["properties"]["content"][0] - }]) - } - // TODO: apply this normalization to editing too - if body["properties"]["mp-channel"].is_array() { - let mut additional_channels = body["properties"]["mp-channel"].as_array().unwrap().clone(); - if let Some(array) = body["properties"]["channel"].as_array_mut() { - array.append(&mut additional_channels); - } else { - body["properties"]["channel"] = json!(additional_channels) - } - body["properties"] - .as_object_mut() - .unwrap() - .remove("mp-channel"); - } else if body["properties"]["mp-channel"].is_string() { - let chan = body["properties"]["mp-channel"] - .as_str() - .unwrap() - .to_owned(); - if let Some(array) = body["properties"]["channel"].as_array_mut() { - array.push(json!(chan)) - } else { - body["properties"]["channel"] = json!([chan]); - } - body["properties"] - .as_object_mut() - .unwrap() - .remove("mp-channel"); - } - if body["properties"]["channel"][0].as_str().is_none() { - match body["type"][0].as_str() { - Some("h-entry") => { - // Set the channel to the main channel... 
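Before the channel defaults that follow, the datetime fallback and UID minting just shown reduce to a small pipeline. This is an illustrative sketch under the same crates used above (`chrono`, `url`, `newbase60`); `published_or_now` and `derive_uid` are hypothetical names, not Kittybox's API:

```rust
use chrono::{DateTime, FixedOffset, Local};
use newbase60::num_to_sxg;

// Hypothetical: accept a parsable RFC 3339 date or fall back to "now",
// mirroring the parse-or-`reset_dt` logic in `normalize_mf2`.
fn published_or_now(raw: Option<&str>) -> DateTime<FixedOffset> {
    raw.and_then(|dt| DateTime::parse_from_rfc3339(dt).ok())
        .unwrap_or_else(|| Local::now().into())
}

// Hypothetical: join the type-specific folder and a new-base-60 encoding of
// the publish timestamp onto `me`, the same UID shape `normalize_mf2` mints.
fn derive_uid(me: &url::Url, folder: &str, published: DateTime<FixedOffset>) -> url::Url {
    let millis: u64 = published.timestamp_millis().try_into().unwrap();
    me.join(&format!("{}{}", folder, num_to_sxg(millis))).unwrap()
}

fn main() {
    let me: url::Url = "https://fireburn.ru/".parse().unwrap();
    let published = published_or_now(Some("2023-07-29T21:59:56+03:00"));
    let uid = derive_uid(&me, "posts/", published);
    // e.g. https://fireburn.ru/posts/<sxg timestamp>; the suffix varies with time.
    assert!(uid.as_str().starts_with("https://fireburn.ru/posts/"));
}
```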
- // TODO find like posts and move them to separate private channel - let default_channel = me.join(DEFAULT_CHANNEL_PATH).unwrap().to_string(); - - body["properties"]["channel"] = json!([default_channel]); - } - Some("h-card") => { - let default_channel = me.join(CONTACTS_CHANNEL_PATH).unwrap().to_string(); - - body["properties"]["channel"] = json!([default_channel]); - } - Some("h-food") => { - let default_channel = me.join(FOOD_CHANNEL_PATH).unwrap().to_string(); - - body["properties"]["channel"] = json!([default_channel]); - } - // TODO h-event - /*"h-event" => { - let default_channel - },*/ - _ => { - body["properties"]["channel"] = json!([]); - } - } - } - body["properties"]["posted-with"] = json!([user.client_id]); - if body["properties"]["author"][0].as_str().is_none() { - body["properties"]["author"] = json!([me.as_str()]) - } - // TODO: maybe highlight #hashtags? - // Find other processing to do and insert it here - return ( - body["properties"]["uid"][0].as_str().unwrap().to_string(), - body, - ); -} - -pub(crate) fn form_to_mf2_json(form: Vec<(String, String)>) -> serde_json::Value { - let mut mf2 = json!({"type": [], "properties": {}}); - for (k, v) in form { - if k == "h" { - mf2["type"] - .as_array_mut() - .unwrap() - .push(json!("h-".to_string() + &v)); - } else if k != "access_token" { - let key = k.strip_suffix("[]").unwrap_or(&k); - match mf2["properties"][key].as_array_mut() { - Some(prop) => prop.push(json!(v)), - None => mf2["properties"][key] = json!([v]), - } - } - } - if mf2["type"].as_array().unwrap().is_empty() { - mf2["type"].as_array_mut().unwrap().push(json!("h-entry")); - } - mf2 -} - -pub(crate) async fn create_feed( - storage: &impl Storage, - uid: &str, - channel: &str, - user: &TokenData, -) -> crate::database::Result<()> { - let path = url::Url::parse(channel).unwrap().path().to_string(); - - let name = match path.as_str() { - DEFAULT_CHANNEL_PATH => DEFAULT_CHANNEL_NAME, - CONTACTS_CHANNEL_PATH => CONTACTS_CHANNEL_NAME, - FOOD_CHANNEL_PATH => FOOD_CHANNEL_NAME, - _ => panic!("Tried to create an unknown default feed!"), - }; - - let (_, feed) = normalize_mf2( - json!({ - "type": ["h-feed"], - "properties": { - "name": [name], - "uid": [channel] - }, - }), - user, - ); - storage.put_post(&feed, user.me.as_str()).await?; - storage.add_to_feed(channel, uid).await -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - fn token_data() -> TokenData { - TokenData { - me: "https://fireburn.ru/".parse().unwrap(), - client_id: "https://quill.p3k.io/".parse().unwrap(), - scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), - exp: Some(u64::MAX), - iat: Some(0) - } - } - - #[test] - fn test_form_to_mf2() { - assert_eq!( - super::form_to_mf2_json( - serde_urlencoded::from_str("h=entry&content=something%20interesting").unwrap() - ), - json!({ - "type": ["h-entry"], - "properties": { - "content": ["something interesting"] - } - }) - ) - } - - #[test] - fn test_no_replace_uid() { - let mf2 = json!({ - "type": ["h-card"], - "properties": { - "uid": ["https://fireburn.ru/"], - "name": ["Vika Nezrimaya"], - "note": ["A crazy programmer girl who wants some hugs"] - } - }); - - let (uid, normalized) = normalize_mf2( - mf2.clone(), - &token_data(), - ); - assert_eq!( - normalized["properties"]["uid"][0], mf2["properties"]["uid"][0], - "UID was replaced" - ); - assert_eq!( - normalized["properties"]["uid"][0], uid, - "Returned post location doesn't match UID" - ); - } - - #[test] - fn test_mp_channel() { - let mf2 = json!({ - 
"type": ["h-entry"], - "properties": { - "uid": ["https://fireburn.ru/posts/test"], - "content": [{"html": "<p>Hello world!</p>"}], - "mp-channel": ["https://fireburn.ru/feeds/test"] - } - }); - - let (_, normalized) = normalize_mf2( - mf2.clone(), - &token_data(), - ); - - assert_eq!( - normalized["properties"]["channel"], - mf2["properties"]["mp-channel"] - ); - } - - #[test] - fn test_mp_channel_as_string() { - let mf2 = json!({ - "type": ["h-entry"], - "properties": { - "uid": ["https://fireburn.ru/posts/test"], - "content": [{"html": "<p>Hello world!</p>"}], - "mp-channel": "https://fireburn.ru/feeds/test" - } - }); - - let (_, normalized) = normalize_mf2( - mf2.clone(), - &token_data(), - ); - - assert_eq!( - normalized["properties"]["channel"][0], - mf2["properties"]["mp-channel"] - ); - } - - #[test] - fn test_normalize_mf2() { - let mf2 = json!({ - "type": ["h-entry"], - "properties": { - "content": ["This is content!"] - } - }); - - let (uid, post) = normalize_mf2( - mf2, - &token_data(), - ); - assert_eq!( - post["properties"]["published"] - .as_array() - .expect("post['published'] is undefined") - .len(), - 1, - "Post doesn't have a published time" - ); - DateTime::parse_from_rfc3339(post["properties"]["published"][0].as_str().unwrap()) - .expect("Couldn't parse date from rfc3339"); - assert!( - !post["properties"]["url"] - .as_array() - .expect("post['url'] is undefined") - .is_empty(), - "Post doesn't have any URLs" - ); - assert_eq!( - post["properties"]["uid"] - .as_array() - .expect("post['uid'] is undefined") - .len(), - 1, - "Post doesn't have a single UID" - ); - assert_eq!( - post["properties"]["uid"][0], uid, - "UID of a post and its supposed location don't match" - ); - assert!( - uid.starts_with("https://fireburn.ru/posts/"), - "The post namespace is incorrect" - ); - assert_eq!( - post["properties"]["content"][0]["html"] - .as_str() - .expect("Post doesn't have a rich content object") - .trim(), - "<p>This is content!</p>", - "Parsed Markdown content doesn't match expected HTML" - ); - assert_eq!( - post["properties"]["channel"][0], "https://fireburn.ru/feeds/main", - "Post isn't posted to the main channel" - ); - assert_eq!( - post["properties"]["author"][0], "https://fireburn.ru/", - "Post author is unknown" - ); - } - - #[test] - fn test_mp_slug() { - let mf2 = json!({ - "type": ["h-entry"], - "properties": { - "content": ["This is content!"], - "mp-slug": ["hello-post"] - }, - }); - - let (_, post) = normalize_mf2( - mf2, - &token_data(), - ); - assert!( - post["properties"]["url"] - .as_array() - .unwrap() - .iter() - .map(|i| i.as_str().unwrap()) - .any(|i| i == "https://fireburn.ru/posts/hello-post"), - "Didn't found an URL pointing to the location expected by the mp-slug semantics" - ); - assert!( - post["properties"]["mp-slug"].as_array().is_none(), - "mp-slug wasn't deleted from the array!" 
- ) - } - - #[test] - fn test_normalize_feed() { - let mf2 = json!({ - "type": ["h-feed"], - "properties": { - "name": "Main feed", - "mp-slug": ["main"] - } - }); - - let (uid, post) = normalize_mf2( - mf2, - &token_data(), - ); - assert_eq!( - post["properties"]["uid"][0], uid, - "UID of a post and its supposed location don't match" - ); - assert_eq!(post["properties"]["author"][0], "https://fireburn.ru/"); - assert!( - post["properties"]["url"] - .as_array() - .unwrap() - .iter() - .map(|i| i.as_str().unwrap()) - .any(|i| i == "https://fireburn.ru/feeds/main"), - "Didn't found an URL pointing to the location expected by the mp-slug semantics" - ); - assert!( - post["properties"]["mp-slug"].as_array().is_none(), - "mp-slug wasn't deleted from the array!" - ) - } -} diff --git a/kittybox-rs/src/tokenauth.rs b/kittybox-rs/src/tokenauth.rs deleted file mode 100644 index 244a045..0000000 --- a/kittybox-rs/src/tokenauth.rs +++ /dev/null @@ -1,358 +0,0 @@ -use serde::{Deserialize, Serialize}; -use url::Url; - -#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] -pub struct User { - pub me: Url, - pub client_id: Url, - scope: String, -} - -#[derive(Debug, Clone, PartialEq, Copy)] -pub enum ErrorKind { - PermissionDenied, - NotAuthorized, - TokenEndpointError, - JsonParsing, - InvalidHeader, - Other, -} - -#[derive(Deserialize, Serialize, Debug, Clone)] -pub struct TokenEndpointError { - error: String, - error_description: String, -} - -#[derive(Debug)] -pub struct IndieAuthError { - source: Option<Box<dyn std::error::Error + Send + Sync>>, - kind: ErrorKind, - msg: String, -} - -impl std::error::Error for IndieAuthError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source - .as_ref() - .map(|e| e.as_ref() as &dyn std::error::Error) - } -} - -impl std::fmt::Display for IndieAuthError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}: {}", - match self.kind { - ErrorKind::TokenEndpointError => "token endpoint returned an error: ", - ErrorKind::JsonParsing => "error while parsing token endpoint response: ", - ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", - ErrorKind::PermissionDenied => "token endpoint rejected the token: ", - ErrorKind::InvalidHeader => "authorization header parsing error: ", - ErrorKind::Other => "token endpoint communication error: ", - }, - self.msg - ) - } -} - -impl From<serde_json::Error> for IndieAuthError { - fn from(err: serde_json::Error) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::JsonParsing, - } - } -} - -impl From<reqwest::Error> for IndieAuthError { - fn from(err: reqwest::Error) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::Other, - } - } -} - -impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { - fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { - Self { - msg: format!("{:?}", err.reason()), - source: Some(Box::new(err)), - kind: ErrorKind::InvalidHeader, - } - } -} - -impl axum::response::IntoResponse for IndieAuthError { - fn into_response(self) -> axum::response::Response { - let status_code: StatusCode = match self.kind { - ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, - ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, - ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, - ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, - ErrorKind::InvalidHeader => 
StatusCode::UNAUTHORIZED, - ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR, - }; - - let body = serde_json::json!({ - "error": match self.kind { - ErrorKind::PermissionDenied => "forbidden", - ErrorKind::NotAuthorized => "unauthorized", - ErrorKind::TokenEndpointError => "token_endpoint_error", - ErrorKind::JsonParsing => "invalid_request", - ErrorKind::InvalidHeader => "unauthorized", - ErrorKind::Other => "unknown_error", - }, - "error_description": self.msg - }); - - (status_code, axum::response::Json(body)).into_response() - } -} - -impl User { - pub fn check_scope(&self, scope: &str) -> bool { - self.scopes().any(|i| i == scope) - } - pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> { - self.scope.split_ascii_whitespace() - } - pub fn new(me: &str, client_id: &str, scope: &str) -> Self { - Self { - me: Url::parse(me).unwrap(), - client_id: Url::parse(client_id).unwrap(), - scope: scope.to_string(), - } - } -} - -use axum::{ - extract::{Extension, FromRequest, RequestParts, TypedHeader}, - headers::{ - authorization::{Bearer, Credentials}, - Authorization, - }, - http::StatusCode, -}; - -// this newtype is required due to axum::Extension retrieving items by type -// it's based on compiler magic matching extensions by their type's hashes -#[derive(Debug, Clone)] -pub struct TokenEndpoint(pub url::Url); - -#[async_trait::async_trait] -impl<B> FromRequest<B> for User -where - B: Send, -{ - type Rejection = IndieAuthError; - - #[cfg_attr( - all(debug_assertions, not(test)), - allow(unreachable_code, unused_variables) - )] - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - // Return a fake user if we're running a debug build - // I don't wanna bother with authentication - #[cfg(all(debug_assertions, not(test)))] - return Ok(User::new( - "http://localhost:8080/", - "https://quill.p3k.io/", - "create update delete media", - )); - - let TypedHeader(Authorization(token)) = - TypedHeader::<Authorization<Bearer>>::from_request(req) - .await - .map_err(IndieAuthError::from)?; - - let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = - Extension::from_request(req).await.unwrap(); - - let Extension(http): Extension<reqwest::Client> = - Extension::from_request(req).await.unwrap(); - - match http - .get(token_endpoint) - .header("Authorization", token.encode()) - .header("Accept", "application/json") - .send() - .await - { - Ok(res) => match res.status() { - StatusCode::OK => match res.json::<serde_json::Value>().await { - Ok(json) => match serde_json::from_value::<User>(json.clone()) { - Ok(user) => Ok(user), - Err(err) => { - if let Some(false) = json["active"].as_bool() { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: "The token is not active for this user.".to_owned(), - }) - } else { - Err(IndieAuthError::from(err)) - } - } - }, - Err(err) => Err(IndieAuthError::from(err)), - }, - StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { - Ok(err) => { - if err.error == "unauthorized" { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: err.error_description, - }) - } else { - Err(IndieAuthError { - source: None, - kind: ErrorKind::TokenEndpointError, - msg: err.error_description, - }) - } - } - Err(err) => Err(IndieAuthError::from(err)), - }, - _ => Err(IndieAuthError { - source: None, - msg: format!("Token endpoint returned {}", res.status()), - kind: ErrorKind::TokenEndpointError, - }), - }, - Err(err) => Err(IndieAuthError::from(err)), - } - } 
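Stripped of axum's extractor plumbing, the round trip in `from_request` above is a single HTTP call whose JSON answer is either a `User` object or an `"active": false` verdict. A reduced sketch with plain `reqwest` and `serde_json` (the `verify_token` helper is hypothetical; errors are flattened to strings for brevity):

```rust
// Hedged sketch of the token-verification round trip performed by
// `from_request` above: ask the token endpoint about a bearer token and
// treat `"active": false` as "not authorized". Detailed error mapping elided.
async fn verify_token(
    http: &reqwest::Client,
    token_endpoint: &url::Url,
    token: &str,
) -> Result<serde_json::Value, String> {
    let json: serde_json::Value = http
        .get(token_endpoint.clone())
        .header("Authorization", format!("Bearer {}", token))
        .header("Accept", "application/json")
        .send()
        .await
        .map_err(|e| e.to_string())?
        .json()
        .await
        .map_err(|e| e.to_string())?;

    if json["active"].as_bool() == Some(false) {
        return Err("token endpoint reports the token as inactive".into());
    }
    Ok(json)
}
```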
-} - -#[cfg(test)] -mod tests { - use super::User; - use axum::{ - extract::FromRequest, - http::{Method, Request}, - }; - use wiremock::{MockServer, Mock, ResponseTemplate}; - use wiremock::matchers::{method, path, header}; - - #[test] - fn user_scopes_are_checkable() { - let user = User::new( - "https://fireburn.ru/", - "https://quill.p3k.io/", - "create update media", - ); - - assert!(user.check_scope("create")); - assert!(!user.check_scope("delete")); - } - - #[inline] - fn get_http_client() -> reqwest::Client { - reqwest::Client::new() - } - - fn request<A: Into<Option<&'static str>>>( - auth: A, - endpoint: String, - ) -> Request<()> { - let request = Request::builder().method(Method::GET); - - match auth.into() { - Some(auth) => request.header("Authorization", auth), - None => request, - } - .extension(super::TokenEndpoint(endpoint.parse().unwrap())) - .extension(get_http_client()) - .body(()) - .unwrap() - } - - #[tokio::test] - async fn test_require_token_with_token() { - let server = MockServer::start().await; - - Mock::given(path("/token")) - .and(header("Authorization", "Bearer token")) - .respond_with(ResponseTemplate::new(200) - .set_body_json(User::new( - "https://fireburn.ru/", - "https://quill.p3k.io/", - "create update media", - )) - ) - .mount(&server) - .await; - - let request = request("Bearer token", format!("{}/token", &server.uri())); - let mut parts = axum::extract::RequestParts::new(request); - let user = User::from_request(&mut parts).await.unwrap(); - - assert_eq!(user.me.as_str(), "https://fireburn.ru/") - } - - #[tokio::test] - async fn test_require_token_fake_token() { - let server = MockServer::start().await; - - Mock::given(path("/refuse_token")) - .respond_with(ResponseTemplate::new(200) - .set_body_json(serde_json::json!({"active": false})) - ) - .mount(&server) - .await; - - let request = request("Bearer token", format!("{}/refuse_token", &server.uri())); - let mut parts = axum::extract::RequestParts::new(request); - let err = User::from_request(&mut parts).await.unwrap_err(); - - assert_eq!(err.kind, super::ErrorKind::NotAuthorized) - } - - #[tokio::test] - async fn test_require_token_no_token() { - let server = MockServer::start().await; - - Mock::given(path("/should_never_be_called")) - .respond_with(ResponseTemplate::new(500)) - .expect(0) - .mount(&server) - .await; - - let request = request(None, format!("{}/should_never_be_called", &server.uri())); - let mut parts = axum::extract::RequestParts::new(request); - let err = User::from_request(&mut parts).await.unwrap_err(); - - assert_eq!(err.kind, super::ErrorKind::InvalidHeader); - } - - #[tokio::test] - async fn test_require_token_400_error_unauthorized() { - let server = MockServer::start().await; - - Mock::given(path("/refuse_token_with_400")) - .and(header("Authorization", "Bearer token")) - .respond_with(ResponseTemplate::new(400) - .set_body_json(serde_json::json!({ - "error": "unauthorized", - "error_description": "The token provided was malformed" - })) - ) - .mount(&server) - .await; - - let request = request( - "Bearer token", - format!("{}/refuse_token_with_400", &server.uri()), - ); - let mut parts = axum::extract::RequestParts::new(request); - let err = User::from_request(&mut parts).await.unwrap_err(); - - assert_eq!(err.kind, super::ErrorKind::NotAuthorized); - } -} diff --git a/kittybox-rs/src/webmentions/check.rs b/kittybox-rs/src/webmentions/check.rs deleted file mode 100644 index f7322f7..0000000 --- a/kittybox-rs/src/webmentions/check.rs +++ /dev/null @@ -1,113 +0,0 @@ -use 
std::{cell::RefCell, rc::Rc};
-use microformats::{types::PropertyValue, html5ever::{self, tendril::TendrilSink}};
-use kittybox_util::MentionType;
-
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
-    #[error("microformats error: {0}")]
-    Microformats(#[from] microformats::Error),
-    // #[error("json error: {0}")]
-    // Json(#[from] serde_json::Error),
-    #[error("url parse error: {0}")]
-    UrlParse(#[from] url::ParseError),
-}
-
-#[tracing::instrument]
-pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> {
-    tracing::debug!("Parsing MF2 markup...");
-    // First, check the document for MF2 markup
-    let document = microformats::from_html(document.as_ref(), base_url.clone())?;
-
-    // Get an iterator of all items
-    let items_iter = document.items.iter()
-        .map(AsRef::as_ref)
-        .map(RefCell::borrow);
-
-    for item in items_iter {
-        tracing::debug!("Processing item: {:?}", item);
-
-        let props = item.properties.borrow();
-        for (prop, interaction_type) in [
-            ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like),
-            ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost)
-        ] {
-            if let Some(propvals) = props.get(prop) {
-                tracing::debug!("Has a u-{} property", prop);
-                for val in propvals {
-                    if let PropertyValue::Url(url) = val {
-                        if url == link {
-                            tracing::debug!("URL matches! Webmention is valid");
-                            return Ok(Some((interaction_type, serde_json::to_value(&*item).unwrap())))
-                        }
-                    }
-                }
-            }
-        }
-        // Process `content`
-        tracing::debug!("Processing e-content...");
-        if let Some(PropertyValue::Fragment(content)) = props.get("content")
-            .map(Vec::as_slice)
-            .unwrap_or_default()
-            .first()
-        {
-            tracing::debug!("Parsing HTML data...");
-            let root = html5ever::parse_document(html5ever::rcdom::RcDom::default(), Default::default())
-                .from_utf8()
-                .one(content.html.to_owned().as_bytes())
-                .document;
-
-            // This is a trick to unwrap recursion into a loop
-            //
-            // A list of unprocessed nodes is made. Then, in each
-            // iteration, the list is "taken" and replaced with an
-            // empty list, which is populated with nodes for the next
-            // iteration of the loop.
-            //
-            // An empty list means all nodes have been processed.
-            let mut unprocessed_nodes: Vec<Rc<html5ever::rcdom::Node>> = root.children.borrow().iter().cloned().collect();
-            while !unprocessed_nodes.is_empty() {
-                // "Take" the list out of its memory slot, replace it with an empty list
-                let nodes = std::mem::take(&mut unprocessed_nodes);
-                tracing::debug!("Processing list of {} nodes", nodes.len());
-                'nodes_loop: for node in nodes.into_iter() {
-                    // Add children nodes to the list for the next iteration
-                    unprocessed_nodes.extend(node.children.borrow().iter().cloned());
-
-                    if let html5ever::rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data {
-                        // If it's not `<a>`, skip it
-                        if name.local != *"a" { continue; }
-                        let mut is_mention: bool = false;
-                        for attr in attrs.borrow().iter() {
-                            if attr.name.local == *"rel" {
-                                // Don't count `rel="nofollow"` links — a web crawler should ignore them,
-                                // so they are useless for the purpose of driving visitors
-                                if attr.value
-                                    .as_ref()
-                                    .split([',', ' '])
-                                    .any(|v| v == "nofollow")
-                                {
-                                    // Skip the entire node.
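The recursion-to-loop trick documented above (the loop body continues just below with the `continue 'nodes_loop`) is easy to lift out on its own. Here is a self-contained toy version, with a hypothetical `Node` tree standing in for html5ever's `RcDom`; the `std::mem::take` dance is the same:

```rust
use std::rc::Rc;

// Toy tree type for illustration only.
struct Node {
    value: u32,
    children: Vec<Rc<Node>>,
}

fn visit_all(root: Rc<Node>) -> Vec<u32> {
    let mut seen = Vec::new();
    // The worklist replaces the call stack of a recursive traversal.
    let mut unprocessed: Vec<Rc<Node>> = vec![root];
    while !unprocessed.is_empty() {
        // "Take" the list, leaving an empty one to collect the next layer.
        for node in std::mem::take(&mut unprocessed) {
            unprocessed.extend(node.children.iter().cloned());
            seen.push(node.value);
        }
    }
    seen
}

fn main() {
    let tree = Rc::new(Node {
        value: 1,
        children: vec![
            Rc::new(Node { value: 2, children: vec![] }),
            Rc::new(Node { value: 3, children: vec![] }),
        ],
    });
    // Breadth-first order falls out of the layer-by-layer worklist.
    assert_eq!(visit_all(tree), vec![1, 2, 3]);
}
```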
- continue 'nodes_loop; - } - } - // if it's not `<a href="...">`, skip it - if attr.name.local != *"href" { continue; } - // Be forgiving in parsing URLs, and resolve them against the base URL - if let Ok(url) = base_url.join(attr.value.as_ref()) { - if &url == link { - is_mention = true; - } - } - } - if is_mention { - return Ok(Some((MentionType::Mention, serde_json::to_value(&*item).unwrap()))); - } - } - } - } - - } - } - - Ok(None) -} diff --git a/kittybox-rs/src/webmentions/mod.rs b/kittybox-rs/src/webmentions/mod.rs deleted file mode 100644 index 95ea870..0000000 --- a/kittybox-rs/src/webmentions/mod.rs +++ /dev/null @@ -1,195 +0,0 @@ -use axum::{Form, response::{IntoResponse, Response}, Extension}; -use axum::http::StatusCode; -use tracing::error; - -use crate::database::{Storage, StorageError}; -use self::queue::JobQueue; -pub mod queue; - -#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] -pub struct Webmention { - source: String, - target: String, -} - -impl queue::JobItem for Webmention {} -impl queue::PostgresJobItem for Webmention { - const DATABASE_NAME: &'static str = "kittybox_webmention.incoming_webmention_queue"; - const NOTIFICATION_CHANNEL: &'static str = "incoming_webmention"; -} - -async fn accept_webmention<Q: JobQueue<Webmention>>( - Extension(queue): Extension<Q>, - Form(webmention): Form<Webmention>, -) -> Response { - if let Err(err) = webmention.source.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() - } - if let Err(err) = webmention.target.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() - } - - match queue.put(&webmention).await { - Ok(id) => (StatusCode::ACCEPTED, [ - ("Location", format!("/.kittybox/webmention/{id}")) - ]).into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ - ("Content-Type", "text/plain") - ], err.to_string()).into_response() - } -} - -pub fn router<Q: JobQueue<Webmention>, S: Storage + 'static>( - queue: Q, db: S, http: reqwest::Client, - cancellation_token: tokio_util::sync::CancellationToken -) -> (axum::Router, SupervisedTask) { - // Automatically spawn a background task to handle webmentions - let bgtask_handle = supervised_webmentions_task(queue.clone(), db, http, cancellation_token); - - let router = axum::Router::new() - .route("/.kittybox/webmention", - axum::routing::post(accept_webmention::<Q>) - ) - .layer(Extension(queue)); - - (router, bgtask_handle) -} - -#[derive(thiserror::Error, Debug)] -pub enum SupervisorError { - #[error("the task was explicitly cancelled")] - Cancelled -} - -pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; - -pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask -where - E: std::error::Error + std::fmt::Debug + Send + 'static, - A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static, - F: FnMut() -> A + Send + 'static -{ - - let supervisor_future = async move { - loop { - // Don't spawn the task if we are already cancelled, but - // have somehow missed it (probably because the task - // crashed and we immediately received a cancellation - // request after noticing the crashed task) - if cancellation_token.is_cancelled() { - return Err(SupervisorError::Cancelled) - } - let task = tokio::task::spawn(f()); - tokio::select! 
{ - _ = cancellation_token.cancelled() => { - tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>()); - return Err(SupervisorError::Cancelled) - } - task_result = task => match task_result { - Err(e) => tracing::error!("background task {:?} exited unexpectedly: {}", std::any::type_name::<A>(), e), - Ok(Err(e)) => tracing::error!("background task {:?} returned error: {}", std::any::type_name::<A>(), e), - Ok(Ok(_)) => unreachable!("task's Ok is Infallible") - } - } - tracing::debug!("Sleeping for a little while to back-off..."); - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - } - }; - #[cfg(not(tokio_unstable))] - return tokio::task::spawn(supervisor_future); - #[cfg(tokio_unstable)] - return tokio::task::Builder::new() - .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str()) - .spawn(supervisor_future) - .unwrap(); -} - -mod check; - -#[derive(thiserror::Error, Debug)] -enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> { - #[error("queue error: {0}")] - Queue(#[from] Q), - #[error("storage error: {0}")] - Storage(StorageError) -} - -async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest::Client) -> Result<std::convert::Infallible, Error<Q::Error>> { - use futures_util::StreamExt; - use self::queue::Job; - - let mut stream = queue.into_stream().await?; - while let Some(item) = stream.next().await.transpose()? { - let job = item.job(); - let (source, target) = ( - job.source.parse::<url::Url>().unwrap(), - job.target.parse::<url::Url>().unwrap() - ); - - let (code, text) = match http.get(source.clone()).send().await { - Ok(response) => { - let code = response.status(); - if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) { - error!("error processing webmention: webpage fetch returned {}", code); - continue; - } - match response.text().await { - Ok(text) => (code, text), - Err(err) => { - error!("error processing webmention: error fetching webpage text: {}", err); - continue - } - } - } - Err(err) => { - error!("error processing webmention: error requesting webpage: {}", err); - continue - } - }; - - if code == StatusCode::GONE { - todo!("removing webmentions is not implemented yet"); - // db.remove_webmention(target.as_str(), source.as_str()).await.map_err(Error::<Q::Error>::Storage)?; - } else { - // Verify webmention - let (mention_type, mut mention) = match tokio::task::block_in_place({ - || check::check_mention(text, &source, &target) - }) { - Ok(Some(mention_type)) => mention_type, - Ok(None) => { - error!("webmention {} -> {} invalid, rejecting", source, target); - item.done().await?; - continue; - } - Err(err) => { - error!("error processing webmention: error checking webmention: {}", err); - continue; - } - }; - - { - mention["type"] = serde_json::json!(["h-cite"]); - - if !mention["properties"].as_object().unwrap().contains_key("uid") { - let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned(); - let props = mention["properties"].as_object_mut().unwrap(); - props.insert("uid".to_owned(), serde_json::Value::Array( - vec![serde_json::Value::String(url)]) - ); - } - } - - db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?; - } - } - unreachable!() -} - -fn supervised_webmentions_task<Q: JobQueue<Webmention>, S: Storage + 'static>( - queue: Q, db: S, - http: reqwest::Client, - cancellation_token: 
tokio_util::sync::CancellationToken
-) -> SupervisedTask {
-    supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(queue.clone(), db.clone(), http.clone()), cancellation_token)
-}
diff --git a/kittybox-rs/src/webmentions/queue.rs b/kittybox-rs/src/webmentions/queue.rs
deleted file mode 100644
index b811e71..0000000
--- a/kittybox-rs/src/webmentions/queue.rs
+++ /dev/null
@@ -1,303 +0,0 @@
-use std::{pin::Pin, str::FromStr};
-
-use futures_util::{Stream, StreamExt};
-use sqlx::{postgres::PgListener, Executor};
-use uuid::Uuid;
-
-use super::Webmention;
-
-static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention");
-
-pub use kittybox_util::queue::{JobQueue, JobItem, Job};
-
-pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
-    const DATABASE_NAME: &'static str;
-    const NOTIFICATION_CHANNEL: &'static str;
-}
-
-#[derive(sqlx::FromRow)]
-struct PostgresJobRow<T: PostgresJobItem> {
-    id: Uuid,
-    #[sqlx(flatten)]
-    job: T
-}
-
-#[derive(Debug)]
-pub struct PostgresJob<T: PostgresJobItem> {
-    id: Uuid,
-    job: T,
-    // This will normally always be Some, except on drop
-    txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
-    runtime_handle: tokio::runtime::Handle,
-}
-
-
-impl<T: PostgresJobItem> Drop for PostgresJob<T> {
-    // This is an emulation of "async drop" — the struct retains a
-    // runtime handle, which it uses to spawn a future that performs
-    // the actual cleanup.
-    //
-    // Of course, this is not portable between runtimes, but I don't
-    // care about that, since Kittybox is designed to work within the
-    // Tokio ecosystem.
-    fn drop(&mut self) {
-        // If the transaction is still held here, the job was never marked
-        // as done, so count the attempt as failed.
-        if let Some(mut txn) = self.txn.take() {
-            tracing::error!("Job {:?} failed, incrementing attempts...", &self);
-            let id = self.id;
-            self.runtime_handle.spawn(async move {
-                tracing::debug!("Constructing query to increment attempts for job {}...", id);
-                // UPDATE <T::DATABASE_NAME> SET attempts = attempts + 1 WHERE id = $1
-                sqlx::query_builder::QueryBuilder::new("UPDATE ")
-                    // This is safe from a SQL injection standpoint, since it is a constant.
-                    .push(T::DATABASE_NAME)
-                    .push(" SET attempts = attempts + 1")
-                    .push(" WHERE id = ")
-                    .push_bind(id)
-                    .build()
-                    .execute(&mut *txn)
-                    .await
-                    .unwrap();
-                sqlx::query_builder::QueryBuilder::new("NOTIFY ")
-                    .push(T::NOTIFICATION_CHANNEL)
-                    .build()
-                    .execute(&mut *txn)
-                    .await
-                    .unwrap();
-                txn.commit().await.unwrap();
-            });
-        }
-    }
-}
-
-#[cfg(test)]
-impl<T: PostgresJobItem> PostgresJob<T> {
-    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
-        sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ")
-            .push(T::DATABASE_NAME)
-            .push(" WHERE id = ")
-            .push_bind(self.id)
-            .build_query_as::<(i32,)>()
-            // It's safe to unwrap here, because we "take" the txn only on drop or commit,
-            // where it's passed by value, not by reference.
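The `Drop` impl above leans on a pattern worth isolating before the query chain resumes below: a synchronous destructor handing its async cleanup to a retained runtime handle. A minimal toy rendition of the same idea (hypothetical `Guard` type, plain Tokio, not Kittybox code):

```rust
// Illustration of the "async drop" emulation described above: the guard
// owns a Tokio runtime handle and spawns its cleanup future from `Drop`.
struct Guard {
    name: &'static str,
    handle: tokio::runtime::Handle,
}

impl Drop for Guard {
    fn drop(&mut self) {
        let name = self.name;
        // `Drop::drop` cannot be async, so hand the future to the runtime.
        self.handle.spawn(async move {
            println!("cleaning up {name} asynchronously");
        });
    }
}

#[tokio::main]
async fn main() {
    let guard = Guard { name: "job-1", handle: tokio::runtime::Handle::current() };
    drop(guard);
    // Give the spawned cleanup a moment to run before the runtime exits.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
```

As the original comment notes, this ties the type to the runtime that created it; the spawned cleanup is also fire-and-forget, so nothing awaits its completion.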
- .fetch_one(self.txn.as_deref_mut().unwrap()) - .await - .map(|(i,)| i as usize) - } -} - -#[async_trait::async_trait] -impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { - fn job(&self) -> &Webmention { - &self.job - } - async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { - tracing::debug!("Deleting {} from the job queue", self.id); - sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1") - .bind(self.id) - .execute(self.txn.as_deref_mut().unwrap()) - .await?; - - self.txn.take().unwrap().commit().await - } -} - -pub struct PostgresJobQueue<T> { - db: sqlx::PgPool, - _phantom: std::marker::PhantomData<T> -} -impl<T> Clone for PostgresJobQueue<T> { - fn clone(&self) -> Self { - Self { - db: self.db.clone(), - _phantom: std::marker::PhantomData - } - } -} - -impl PostgresJobQueue<Webmention> { - pub async fn new(uri: &str) -> Result<Self, sqlx::Error> { - let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? - .options([("search_path", "kittybox_webmention")]); - if let Ok(password_file) = std::env::var("PGPASS_FILE") { - let password = tokio::fs::read_to_string(password_file).await.unwrap(); - options = options.password(&password); - } else if let Ok(password) = std::env::var("PGPASS") { - options = options.password(&password) - } - Self::from_pool( - sqlx::postgres::PgPoolOptions::new() - .max_connections(50) - .connect_with(options) - .await? - ).await - - } - - pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?; - MIGRATOR.run(&db).await?; - Ok(Self { db, _phantom: std::marker::PhantomData }) - } -} - -#[async_trait::async_trait] -impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { - type Job = PostgresJob<Webmention>; - type Error = sqlx::Error; - - async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> { - let mut txn = self.db.begin().await?; - - match sqlx::query_as::<_, PostgresJobRow<Webmention>>( - "SELECT id, source, target FROM kittybox_webmention.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1" - ) - .fetch_optional(&mut *txn) - .await? - { - Some(job_row) => { - return Ok(Some(Self::Job { - id: job_row.id, - job: job_row.job, - txn: Some(txn), - runtime_handle: tokio::runtime::Handle::current(), - })) - }, - None => Ok(None) - } - } - - async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> { - sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox_webmention.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id") - .bind(item.source.as_str()) - .bind(item.target.as_str()) - .fetch_one(&self.db) - .await - } - - async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> { - let mut listener = PgListener::connect_with(&self.db).await?; - listener.listen("incoming_webmention").await?; - - let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), { - let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener)); - move |_| { - let queue = self.clone(); - let listener = listener.clone(); - async move { - loop { - match queue.get_one().await? 
{ - Some(item) => return Ok(Some((item, ()))), - None => { - listener.lock().await.recv().await?; - continue - } - } - } - } - } - }).boxed(); - - Ok(stream) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR}; - use futures_util::StreamExt; - - #[sqlx::test(migrator = "MIGRATOR")] - #[tracing_test::traced_test] - async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { - let test_webmention = Webmention { - source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() - }; - - let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; - tracing::debug!("Putting webmention into queue"); - queue.put(&test_webmention).await?; - { - let mut job_description = queue.get_one().await?.unwrap(); - assert_eq!(job_description.job(), &test_webmention); - assert_eq!(job_description.attempts().await?, 0); - } - tracing::debug!("Creating a stream"); - let mut stream = queue.clone().into_stream().await?; - - { - let mut guard = stream.next().await.transpose()?.unwrap(); - assert_eq!(guard.job(), &test_webmention); - assert_eq!(guard.attempts().await?, 1); - if let Some(item) = queue.get_one().await? { - panic!("Unexpected item {:?} returned from job queue!", item) - }; - } - - { - let mut guard = stream.next().await.transpose()?.unwrap(); - assert_eq!(guard.job(), &test_webmention); - assert_eq!(guard.attempts().await?, 2); - guard.done().await?; - } - - match queue.get_one().await? { - Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), - None => Ok(()) - } - } - - #[sqlx::test(migrator = "MIGRATOR")] - #[tracing_test::traced_test] - async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { - let test_webmention = Webmention { - source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() - }; - - let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; - tracing::debug!("Putting webmention into queue"); - queue.put(&test_webmention).await?; - tracing::debug!("Creating a stream"); - let mut stream = queue.clone().into_stream().await?; - - // Synchronisation barrier that will be useful later - let barrier = Arc::new(tokio::sync::Barrier::new(2)); - { - // Get one job guard from a queue - let mut guard = stream.next().await.transpose()?.unwrap(); - assert_eq!(guard.job(), &test_webmention); - assert_eq!(guard.attempts().await?, 0); - - tokio::task::spawn({ - let barrier = barrier.clone(); - async move { - // Wait for the signal to drop the guard! - barrier.wait().await; - - drop(guard) - } - }); - } - tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); - - let future = tokio::task::spawn( - tokio::time::timeout( - std::time::Duration::from_secs(10), async move { - stream.next().await.unwrap().unwrap() - } - ) - ); - // Let the other task drop the guard it is holding - barrier.wait().await; - let mut guard = future.await - .expect("Timeout on fetching item") - .expect("Job queue error"); - assert_eq!(guard.job(), &test_webmention); - assert_eq!(guard.attempts().await?, 1); - - Ok(()) - } -} |
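Zooming back out to `webmentions/mod.rs`: the `supervisor` function earlier reduces to a restart loop around a fallible task, with cancellation checked both before each spawn and while the task runs. A stripped-down sketch under toy types (hypothetical `flaky_task`; the back-off is shortened from the original five seconds so the demo finishes quickly):

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

// Stand-in for a background worker that only returns on error.
async fn flaky_task() -> Result<std::convert::Infallible, std::io::Error> {
    Err(std::io::Error::new(std::io::ErrorKind::Other, "crashed"))
}

async fn supervise(token: CancellationToken) {
    loop {
        // Don't respawn if cancellation arrived while we were backing off.
        if token.is_cancelled() {
            return;
        }
        tokio::select! {
            _ = token.cancelled() => return,
            res = tokio::task::spawn(flaky_task()) => match res {
                Err(join_err) => eprintln!("task panicked: {join_err}"),
                Ok(Err(e)) => eprintln!("task returned error: {e}"),
                Ok(Ok(_)) => unreachable!("Ok is Infallible"),
            },
        }
        // Back off briefly before restarting (5 s in the original).
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let sup = tokio::spawn(supervise(token.clone()));
    tokio::time::sleep(Duration::from_millis(20)).await;
    token.cancel();
    sup.await.unwrap();
}
```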