| author | Vika <vika@fireburn.ru> | 2025-04-09 23:31:02 +0300 |
|---|---|---|
| committer | Vika <vika@fireburn.ru> | 2025-04-09 23:31:57 +0300 |
| commit | 8826d9446e6c492db2243b9921e59ce496027bef | (patch) |
| tree | 63738aa9001cb73b11cb0e974e93129bcdf1adbb | /src |
| parent | 519cadfbb298f50cbf819dde757037ab56e2863e | (diff) |
| download | kittybox-8826d9446e6c492db2243b9921e59ce496027bef.tar.zst | |
cargo fmt
Change-Id: I80e81ebba3f0cdf8c094451c9fe3ee4126b8c888
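A commit like this is normally produced by running `cargo fmt` over the workspace and kept honest in CI with `cargo fmt --check`; it changes layout only, not behaviour. The single most repeated rewrite below is rustfmt reflowing an overlong builder chain and giving each macro argument its own line. A minimal sketch of that before/after shape (assuming only a `reqwest` dependency; not a complete program from this repository):

```rust
// Before rustfmt, as in the deleted lines of the hunks below:
//
//     let mut builder = reqwest::Client::builder()
//         .user_agent(concat!(
//             env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")
//         ));
//
// After rustfmt: the chain stays on one line and the `concat!`
// arguments are split, one per line.
fn build_client() -> reqwest::Client {
    let builder = reqwest::Client::builder().user_agent(concat!(
        env!("CARGO_PKG_NAME"),
        "/",
        env!("CARGO_PKG_VERSION")
    ));
    builder.build().unwrap()
}
```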
Diffstat (limited to 'src')
-rw-r--r-- | src/bin/kittybox-check-webmention.rs | 13
-rw-r--r-- | src/bin/kittybox-indieauth-helper.rs | 96
-rw-r--r-- | src/bin/kittybox-mf2.rs | 14
-rw-r--r-- | src/database/file/mod.rs | 159
-rw-r--r-- | src/database/memory.rs | 41
-rw-r--r-- | src/database/mod.rs | 165
-rw-r--r-- | src/database/postgres/mod.rs | 122
-rw-r--r-- | src/frontend/mod.rs | 136
-rw-r--r-- | src/frontend/onboarding.rs | 26
-rw-r--r-- | src/indieauth/backend.rs | 89
-rw-r--r-- | src/indieauth/backend/fs.rs | 282
-rw-r--r-- | src/indieauth/mod.rs | 727
-rw-r--r-- | src/indieauth/webauthn.rs | 76
-rw-r--r-- | src/lib.rs | 211
-rw-r--r-- | src/login.rs | 266
-rw-r--r-- | src/main.rs | 202
-rw-r--r-- | src/media/mod.rs | 83
-rw-r--r-- | src/media/storage/file.rs | 322
-rw-r--r-- | src/media/storage/mod.rs | 169
-rw-r--r-- | src/micropub/mod.rs | 326
-rw-r--r-- | src/micropub/util.rs | 70
-rw-r--r-- | src/webmentions/check.rs | 52
-rw-r--r-- | src/webmentions/mod.rs | 130
-rw-r--r-- | src/webmentions/queue.rs | 60
24 files changed, 2332 insertions, 1505 deletions
diff --git a/src/bin/kittybox-check-webmention.rs b/src/bin/kittybox-check-webmention.rs index b43980e..a9e5957 100644 --- a/src/bin/kittybox-check-webmention.rs +++ b/src/bin/kittybox-check-webmention.rs @@ -7,7 +7,7 @@ enum Error { #[error("reqwest error: {0}")] Http(#[from] reqwest::Error), #[error("webmention check error: {0}")] - Webmention(#[from] WebmentionError) + Webmention(#[from] WebmentionError), } #[derive(Parser, Debug)] @@ -21,7 +21,7 @@ struct Args { #[clap(value_parser)] url: url::Url, #[clap(value_parser)] - link: url::Url + link: url::Url, } #[tokio::main] @@ -30,10 +30,11 @@ async fn main() -> Result<(), Error> { let http: reqwest::Client = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); builder.build().unwrap() }; diff --git a/src/bin/kittybox-indieauth-helper.rs b/src/bin/kittybox-indieauth-helper.rs index f4ad679..0725aac 100644 --- a/src/bin/kittybox-indieauth-helper.rs +++ b/src/bin/kittybox-indieauth-helper.rs @@ -1,13 +1,11 @@ +use clap::Parser; use futures::{FutureExt, TryFutureExt}; use kittybox_indieauth::{ - AuthorizationRequest, PKCEVerifier, - PKCEChallenge, PKCEMethod, GrantRequest, Scope, - AuthorizationResponse, GrantResponse, - Error as IndieauthError + AuthorizationRequest, AuthorizationResponse, Error as IndieauthError, GrantRequest, + GrantResponse, PKCEChallenge, PKCEMethod, PKCEVerifier, Scope, }; -use clap::Parser; -use tokio::net::TcpListener; use std::{borrow::Cow, future::IntoFuture, io::Write}; +use tokio::net::TcpListener; const DEFAULT_CLIENT_ID: &str = "https://kittybox.fireburn.ru/indieauth-helper.html"; const DEFAULT_REDIRECT_URI: &str = "http://localhost:60000/callback"; @@ -21,7 +19,7 @@ enum Error { #[error("url parsing error: {0}")] UrlParse(#[from] url::ParseError), #[error("indieauth flow error: {0}")] - IndieAuth(#[from] IndieauthError) + IndieAuth(#[from] IndieauthError), } #[derive(Parser, Debug)] @@ -46,20 +44,20 @@ struct Args { client_id: url::Url, /// Redirect URI to declare. Note: This will break the flow, use only for testing UI. #[clap(long, value_parser)] - redirect_uri: Option<url::Url> + redirect_uri: Option<url::Url>, } - #[tokio::main] async fn main() -> Result<(), Error> { let args = Args::parse(); let http: reqwest::Client = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); // This only works on debug builds. Don't get any funny thoughts. #[cfg(debug_assertions)] if std::env::var("KITTYBOX_DANGER_INSECURE_TLS") @@ -71,12 +69,14 @@ async fn main() -> Result<(), Error> { builder.build().unwrap() }; - let redirect_uri: url::Url = args.redirect_uri + let redirect_uri: url::Url = args + .redirect_uri .clone() .unwrap_or_else(|| DEFAULT_REDIRECT_URI.parse().unwrap()); eprintln!("Checking .well-known for metadata..."); - let metadata = http.get(args.me.join("/.well-known/oauth-authorization-server")?) + let metadata = http + .get(args.me.join("/.well-known/oauth-authorization-server")?) .header("Accept", "application/json") .send() .await? 
@@ -92,7 +92,7 @@ async fn main() -> Result<(), Error> { state: kittybox_indieauth::State::new(), code_challenge: PKCEChallenge::new(&verifier, PKCEMethod::default()), scope: Some(kittybox_indieauth::Scopes::new(args.scope)), - me: Some(args.me) + me: Some(args.me), }; let indieauth_url = { @@ -103,12 +103,18 @@ async fn main() -> Result<(), Error> { url }; - eprintln!("Please visit the following URL in your browser:\n\n {}\n", indieauth_url.as_str()); + eprintln!( + "Please visit the following URL in your browser:\n\n {}\n", + indieauth_url.as_str() + ); #[cfg(target_os = "linux")] - match std::process::Command::new("xdg-open").arg(indieauth_url.as_str()).spawn() { + match std::process::Command::new("xdg-open") + .arg(indieauth_url.as_str()) + .spawn() + { Ok(child) => drop(child), - Err(err) => eprintln!("Couldn't xdg-open: {}", err) + Err(err) => eprintln!("Couldn't xdg-open: {}", err), } if args.redirect_uri.is_some() { @@ -123,32 +129,38 @@ async fn main() -> Result<(), Error> { let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx))); - let router = axum::Router::new() - .route("/callback", axum::routing::get( + let router = axum::Router::new().route( + "/callback", + axum::routing::get( move |Query(response): Query<AuthorizationResponse>| async move { if let Some(tx) = tx.lock_owned().await.take() { tx.send(response).unwrap(); - (axum::http::StatusCode::OK, - [("Content-Type", "text/plain")], - "Thank you! This window can now be closed.") + ( + axum::http::StatusCode::OK, + [("Content-Type", "text/plain")], + "Thank you! This window can now be closed.", + ) .into_response() } else { - (axum::http::StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "Oops. The callback was already received. Did you click twice?") + ( + axum::http::StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "Oops. The callback was already received. Did you click twice?", + ) .into_response() } - } - )); + }, + ), + ); - use std::net::{SocketAddr, IpAddr, Ipv4Addr}; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; let server = axum::serve( - TcpListener::bind( - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST),60000) - ).await.unwrap(), - router.into_make_service() + TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 60000)) + .await + .unwrap(), + router.into_make_service(), ); tokio::task::spawn(server.into_future()) @@ -175,12 +187,13 @@ async fn main() -> Result<(), Error> { #[cfg(not(debug_assertions))] std::process::exit(1); } - let response: Result<GrantResponse, IndieauthError> = http.post(metadata.token_endpoint) + let response: Result<GrantResponse, IndieauthError> = http + .post(metadata.token_endpoint) .form(&GrantRequest::AuthorizationCode { code: authorization_response.code, client_id: args.client_id, redirect_uri, - code_verifier: verifier + code_verifier: verifier, }) .header("Accept", "application/json") .send() @@ -201,9 +214,14 @@ async fn main() -> Result<(), Error> { refresh_token, scope, .. - } = response? { - eprintln!("Congratulations, {}, access token is ready! {}", - profile.as_ref().and_then(|p| p.name.as_deref()).unwrap_or(me.as_str()), + } = response? + { + eprintln!( + "Congratulations, {}, access token is ready! 
{}", + profile + .as_ref() + .and_then(|p| p.name.as_deref()) + .unwrap_or(me.as_str()), if let Some(exp) = expires_in { Cow::Owned(format!("It expires in {exp} seconds.")) } else { diff --git a/src/bin/kittybox-mf2.rs b/src/bin/kittybox-mf2.rs index 0cd89b4..b6f4999 100644 --- a/src/bin/kittybox-mf2.rs +++ b/src/bin/kittybox-mf2.rs @@ -37,8 +37,9 @@ async fn main() -> Result<(), Error> { .with_indent_lines(true) .with_verbose_exit(true), #[cfg(not(debug_assertions))] - tracing_subscriber::fmt::layer().json() - .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) + tracing_subscriber::fmt::layer() + .json() + .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())), ); tracing_registry.init(); @@ -46,10 +47,11 @@ async fn main() -> Result<(), Error> { let http: reqwest::Client = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); builder.build().unwrap() }; diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs index b9f27b2..5c93beb 100644 --- a/src/database/file/mod.rs +++ b/src/database/file/mod.rs @@ -1,6 +1,6 @@ //#![warn(clippy::unwrap_used)] -use crate::database::{ErrorKind, Result, settings, Storage, StorageError}; -use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; +use crate::database::{settings, ErrorKind, Result, Storage, StorageError}; +use crate::micropub::{MicropubPropertyDeletion, MicropubUpdate}; use futures::{stream, StreamExt, TryStreamExt}; use kittybox_util::MentionType; use serde_json::json; @@ -247,7 +247,9 @@ async fn hydrate_author<S: Storage>( impl Storage for FileStorage { async fn new(url: &'_ url::Url) -> Result<Self> { // TODO: sanity check - Ok(Self { root_dir: PathBuf::from(url.path()) }) + Ok(Self { + root_dir: PathBuf::from(url.path()), + }) } #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result<Vec<String>> { @@ -259,7 +261,7 @@ impl Storage for FileStorage { // perform well. Err(std::io::Error::new( std::io::ErrorKind::Unsupported, - "?q=category queries are not implemented due to resource constraints" + "?q=category queries are not implemented due to resource constraints", ))? } @@ -340,7 +342,10 @@ impl Storage for FileStorage { file.sync_all().await?; drop(file); tokio::fs::rename(&tempfile, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; if let Some(urls) = post["properties"]["url"].as_array() { for url in urls.iter().map(|i| i.as_str().unwrap()) { @@ -350,8 +355,8 @@ impl Storage for FileStorage { "{}{}", url.host_str().unwrap(), url.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() + .map(|port| format!(":{}", port)) + .unwrap_or_default() ) }; if url != key && url_domain == user.authority() { @@ -410,26 +415,24 @@ impl Storage for FileStorage { .create(false) .open(&path) .await - { - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - Vec::default() - } - Err(err) => { - // Propagate the error upwards - return Err(err.into()); - } - Ok(mut file) => { - let mut content = String::new(); - file.read_to_string(&mut content).await?; - drop(file); - - if !content.is_empty() { - serde_json::from_str(&content)? 
- } else { - Vec::default() - } - } - } + { + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Vec::default(), + Err(err) => { + // Propagate the error upwards + return Err(err.into()); + } + Ok(mut file) => { + let mut content = String::new(); + file.read_to_string(&mut content).await?; + drop(file); + + if !content.is_empty() { + serde_json::from_str(&content)? + } else { + Vec::default() + } + } + } }; channels.push(super::MicropubChannel { @@ -444,7 +447,10 @@ impl Storage for FileStorage { tempfile.sync_all().await?; drop(tempfile); tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; } Ok(()) } @@ -476,7 +482,10 @@ impl Storage for FileStorage { temp.sync_all().await?; drop(temp); tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; (json, new_json) }; @@ -486,7 +495,9 @@ impl Storage for FileStorage { #[tracing::instrument(skip(self, f), fields(f = std::any::type_name::<F>()))] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { todo!("update_with is not yet implemented due to special requirements of the file backend") } @@ -526,25 +537,25 @@ impl Storage for FileStorage { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { #[allow(deprecated)] - Ok(self.read_feed_with_limit( - url, - cursor, - limit, - user - ).await? + Ok(self + .read_feed_with_limit(url, cursor, limit, user) + .await? .map(|feed| { - tracing::debug!("Feed: {:#}", serde_json::Value::Array( - feed["children"] - .as_array() - .map(|v| v.as_slice()) - .unwrap_or_default() - .iter() - .map(|mf2| mf2["properties"]["uid"][0].clone()) - .collect::<Vec<_>>() - )); + tracing::debug!( + "Feed: {:#}", + serde_json::Value::Array( + feed["children"] + .as_array() + .map(|v| v.as_slice()) + .unwrap_or_default() + .iter() + .map(|mf2| mf2["properties"]["uid"][0].clone()) + .collect::<Vec<_>>() + ) + ); let cursor: Option<String> = feed["children"] .as_array() .map(|v| v.as_slice()) @@ -553,8 +564,7 @@ impl Storage for FileStorage { .map(|v| v["properties"]["uid"][0].as_str().unwrap().to_owned()); tracing::debug!("Extracted the cursor: {:?}", cursor); (feed, cursor) - }) - ) + })) } #[tracing::instrument(skip(self))] @@ -574,9 +584,12 @@ impl Storage for FileStorage { let children: Vec<serde_json::Value> = match feed["children"].take() { serde_json::Value::Array(children) => children, // We've already checked it's an array - _ => unreachable!() + _ => unreachable!(), }; - tracing::debug!("Full children array: {:#}", serde_json::Value::Array(children.clone())); + tracing::debug!( + "Full children array: {:#}", + serde_json::Value::Array(children.clone()) + ); let mut posts_iter = children .into_iter() .map(|s: serde_json::Value| s.as_str().unwrap().to_string()); @@ -589,7 +602,7 @@ impl Storage for FileStorage { // incredibly long feeds. 
if let Some(after) = after { tokio::task::block_in_place(|| { - for s in posts_iter.by_ref() { + for s in posts_iter.by_ref() { if s == after { break; } @@ -655,12 +668,19 @@ impl Storage for FileStorage { let settings: HashMap<&str, serde_json::Value> = serde_json::from_str(&content)?; match settings.get(S::ID) { Some(value) => Ok(serde_json::from_value::<S>(value.clone())?), - None => Err(StorageError::from_static(ErrorKind::Backend, "Setting not set")) + None => Err(StorageError::from_static( + ErrorKind::Backend, + "Setting not set", + )), } } #[tracing::instrument(skip(self))] - async fn set_setting<S: settings::Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> { + async fn set_setting<S: settings::Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> Result<()> { let mut path = relative_path::RelativePathBuf::new(); path.push(user.authority()); path.push("settings"); @@ -704,20 +724,28 @@ impl Storage for FileStorage { tempfile.sync_all().await?; drop(tempfile); tokio::fs::rename(temppath, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; Ok(()) } #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> Result<()> { let path = url_to_path(&self.root_dir, target); let tempfilename = path.with_extension("tmp"); let mut temp = OpenOptions::new() - .write(true) - .create_new(true) - .open(&tempfilename) - .await?; + .write(true) + .create_new(true) + .open(&tempfilename) + .await?; let mut file = OpenOptions::new().read(true).open(&path).await?; let mut post: serde_json::Value = { @@ -752,13 +780,20 @@ impl Storage for FileStorage { temp.sync_all().await?; drop(temp); tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; Ok(()) } - async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { todo!(); - #[allow(unreachable_code)] Ok(futures::stream::empty()) // for type inference + #[allow(unreachable_code)] + Ok(futures::stream::empty()) // for type inference } } diff --git a/src/database/memory.rs b/src/database/memory.rs index c2ceb85..75f04de 100644 --- a/src/database/memory.rs +++ b/src/database/memory.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError}; +use crate::database::{settings, ErrorKind, MicropubChannel, Result, Storage, StorageError}; #[derive(Clone, Debug, Default)] /// A simple in-memory store for testing purposes. 
@@ -90,9 +90,16 @@ impl Storage for MemoryStorage { Ok(()) } - async fn update_post(&self, url: &'_ str, update: crate::micropub::MicropubUpdate) -> Result<()> { + async fn update_post( + &self, + url: &'_ str, + update: crate::micropub::MicropubUpdate, + ) -> Result<()> { let mut guard = self.mapping.write().await; - let mut post = guard.get_mut(url).ok_or(StorageError::from_static(ErrorKind::NotFound, "The specified post wasn't found in the database."))?; + let mut post = guard.get_mut(url).ok_or(StorageError::from_static( + ErrorKind::NotFound, + "The specified post wasn't found in the database.", + ))?; use crate::micropub::MicropubPropertyDeletion; @@ -208,7 +215,7 @@ impl Storage for MemoryStorage { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { todo!() } @@ -224,25 +231,39 @@ impl Storage for MemoryStorage { } #[allow(unused_variables)] - async fn set_setting<S: settings::Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> { + async fn set_setting<S: settings::Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> Result<()> { todo!() } #[allow(unused_variables)] - async fn add_or_update_webmention(&self, target: &str, mention_type: kittybox_util::MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: kittybox_util::MentionType, + mention: serde_json::Value, + ) -> Result<()> { todo!() } #[allow(unused_variables)] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { todo!() } - async fn all_posts<'this>(&'this self, _user: &url::Url) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + _user: &url::Url, + ) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { todo!(); - #[allow(unreachable_code)] Ok(futures::stream::pending()) + #[allow(unreachable_code)] + Ok(futures::stream::pending()) } - } diff --git a/src/database/mod.rs b/src/database/mod.rs index 4390ae7..de51c2c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -177,7 +177,7 @@ impl StorageError { Self { msg: Cow::Borrowed(msg), source: None, - kind + kind, } } /// Create a StorageError using another arbitrary Error as a source. @@ -219,27 +219,34 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { fn post_exists(&self, url: &str) -> impl Future<Output = Result<bool>> + Send; /// Load a post from the database in MF2-JSON format, deserialized from JSON. - fn get_post(&self, url: &str) -> impl Future<Output = Result<Option<serde_json::Value>>> + Send; + fn get_post(&self, url: &str) + -> impl Future<Output = Result<Option<serde_json::Value>>> + Send; /// Save a post to the database as an MF2-JSON structure. /// /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. - fn put_post(&self, post: &serde_json::Value, user: &url::Url) -> impl Future<Output = Result<()>> + Send; + fn put_post( + &self, + post: &serde_json::Value, + user: &url::Url, + ) -> impl Future<Output = Result<()>> + Send; /// Add post to feed. Some database implementations might have optimized ways to do this. 
#[tracing::instrument(skip(self))] fn add_to_feed(&self, feed: &str, post: &str) -> impl Future<Output = Result<()>> + Send { tracing::debug!("Inserting {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"add": {"children": [post]}})).unwrap() + self.update_post( + feed, + serde_json::from_value(serde_json::json!({"add": {"children": [post]}})).unwrap(), ) } /// Remove post from feed. Some database implementations might have optimized ways to do this. #[tracing::instrument(skip(self))] fn remove_from_feed(&self, feed: &str, post: &str) -> impl Future<Output = Result<()>> + Send { tracing::debug!("Removing {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"delete": {"children": [post]}})).unwrap() + self.update_post( + feed, + serde_json::from_value(serde_json::json!({"delete": {"children": [post]}})).unwrap(), ) } @@ -254,7 +261,11 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Default implementation calls [`Storage::update_with`] and uses /// [`update.apply`][MicropubUpdate::apply] to update the post. - fn update_post(&self, url: &str, update: MicropubUpdate) -> impl Future<Output = Result<()>> + Send { + fn update_post( + &self, + url: &str, + update: MicropubUpdate, + ) -> impl Future<Output = Result<()>> + Send { let fut = self.update_with(url, |post| { update.apply(post); }); @@ -274,12 +285,17 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Returns old post and the new post after editing. fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> impl Future<Output = Result<(serde_json::Value, serde_json::Value)>> + Send; /// Get a list of channels available for the user represented by /// the `user` domain to write to. - fn get_channels(&self, user: &url::Url) -> impl Future<Output = Result<Vec<MicropubChannel>>> + Send; + fn get_channels( + &self, + user: &url::Url, + ) -> impl Future<Output = Result<Vec<MicropubChannel>>> + Send; /// Fetch a feed at `url` and return an h-feed object containing /// `limit` posts after a post by url `after`, filtering the content @@ -329,7 +345,7 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> impl Future<Output = Result<Option<(serde_json::Value, Option<String>)>>> + Send; /// Deletes a post from the database irreversibly. Must be idempotent. @@ -339,7 +355,11 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { fn get_setting<S: Setting>(&self, user: &url::Url) -> impl Future<Output = Result<S>> + Send; /// Commits a setting to the setting store. - fn set_setting<S: Setting>(&self, user: &url::Url, value: S::Data) -> impl Future<Output = Result<()>> + Send; + fn set_setting<S: Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> impl Future<Output = Result<()>> + Send; /// Add (or update) a webmention on a certian post. /// @@ -355,11 +375,19 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Besides, it may even allow for nice tricks like storing the /// webmentions separately and rehydrating them on feed reads. 
- fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> impl Future<Output = Result<()>> + Send; + fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> impl Future<Output = Result<()>> + Send; /// Return a stream of all posts ever made by a certain user, in /// reverse-chronological order. - fn all_posts<'this>(&'this self, user: &url::Url) -> impl Future<Output = Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this>> + Send; + fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> impl Future<Output = Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this>> + Send; } #[cfg(test)] @@ -464,7 +492,8 @@ mod tests { "replace": { "content": ["Different test content"] } - })).unwrap(), + })) + .unwrap(), ) .await .unwrap(); @@ -511,7 +540,10 @@ mod tests { .put_post(&feed, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - let chans = backend.get_channels(&"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + let chans = backend + .get_channels(&"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); assert_eq!(chans.len(), 1); assert_eq!( chans[0], @@ -526,16 +558,16 @@ mod tests { backend .set_setting::<settings::SiteName>( &"https://fireburn.ru/".parse().unwrap(), - "Vika's Hideout".to_owned() + "Vika's Hideout".to_owned(), ) .await .unwrap(); assert_eq!( backend - .get_setting::<settings::SiteName>(&"https://fireburn.ru/".parse().unwrap()) - .await - .unwrap() - .as_ref(), + .get_setting::<settings::SiteName>(&"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap() + .as_ref(), "Vika's Hideout" ); } @@ -597,11 +629,9 @@ mod tests { async fn test_feed_pagination<Backend: Storage>(backend: Backend) { let posts = { - let mut posts = std::iter::from_fn( - || Some(gen_random_post("fireburn.ru")) - ) - .take(40) - .collect::<Vec<serde_json::Value>>(); + let mut posts = std::iter::from_fn(|| Some(gen_random_post("fireburn.ru"))) + .take(40) + .collect::<Vec<serde_json::Value>>(); // Reverse the array so it's in reverse-chronological order posts.reverse(); @@ -629,7 +659,10 @@ mod tests { .put_post(post, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - backend.add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()).await.unwrap(); + backend + .add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()) + .await + .unwrap(); } let limit: usize = 10; @@ -648,23 +681,16 @@ mod tests { .unwrap() .iter() .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10], + .collect::<Vec<_>>()[0..10], posts .iter() .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10] + .collect::<Vec<_>>()[0..10] ); tracing::debug!("Continuing with cursor: {:?}", cursor); let (result2, cursor2) = backend - .read_feed_with_cursor( - key, - cursor.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -676,12 +702,7 @@ mod tests { tracing::debug!("Continuing with cursor: {:?}", cursor); let (result3, cursor3) = backend - .read_feed_with_cursor( - key, - cursor2.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor2.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -693,12 +714,7 @@ mod tests { tracing::debug!("Continuing with cursor: {:?}", cursor); let (result4, _) = backend - .read_feed_with_cursor( - key, - cursor3.as_deref(), - limit, - None, - ) + 
.read_feed_with_cursor(key, cursor3.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -725,24 +741,43 @@ mod tests { async fn test_webmention_addition<Backend: Storage>(db: Backend) { let post = gen_random_post("fireburn.ru"); - db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); const TYPE: MentionType = MentionType::Reply; let target = post["properties"]["uid"][0].as_str().unwrap(); let mut reply = gen_random_mention("aaronparecki.com", TYPE, target); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(post, read_post); - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); + db.add_or_update_webmention(target, TYPE, reply.clone()) + .await + .unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post["properties"]["comment"][0], reply); - reply["properties"]["content"][0] = json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); + reply["properties"]["content"][0] = + json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + db.add_or_update_webmention(target, TYPE, reply.clone()) + .await + .unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post["properties"]["comment"][0], reply); } @@ -752,16 +787,20 @@ mod tests { let post = { let mut post = gen_random_post("fireburn.ru"); let urls = post["properties"]["url"].as_array_mut().unwrap(); - urls.push(serde_json::Value::String( - PERMALINK.to_owned() - )); + urls.push(serde_json::Value::String(PERMALINK.to_owned())); post }; - db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); for i in post["properties"]["url"].as_array().unwrap() { - let (read_post, _) = db.read_feed_with_cursor(i.as_str().unwrap(), None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(i.as_str().unwrap(), None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post, post); } } @@ -786,7 +825,7 @@ mod tests { async fn $func_name() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); let backend = super::super::FileStorage { - root_dir: tempdir.path().to_path_buf() + root_dir: tempdir.path().to_path_buf(), }; super::$func_name(backend).await } @@ -800,7 +839,7 @@ mod tests { #[tracing_test::traced_test] async fn $func_name( pool_opts: sqlx::postgres::PgPoolOptions, - connect_opts: sqlx::postgres::PgConnectOptions + connect_opts: sqlx::postgres::PgConnectOptions, ) -> Result<(), sqlx::Error> { let db = { //use sqlx::ConnectOptions; diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs index af19fea..ec67efa 100644 --- a/src/database/postgres/mod.rs +++ b/src/database/postgres/mod.rs @@ -5,7 +5,7 @@ use kittybox_util::{micropub::Channel as MicropubChannel, MentionType}; use sqlx::{ConnectOptions, Executor, PgPool}; use 
super::settings::Setting; -use super::{Storage, Result, StorageError, ErrorKind}; +use super::{ErrorKind, Result, Storage, StorageError}; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); @@ -14,7 +14,7 @@ impl From<sqlx::Error> for StorageError { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx error: {}", &value)), - Box::new(value) + Box::new(value), ) } } @@ -24,7 +24,7 @@ impl From<sqlx::migrate::MigrateError> for StorageError { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx migration error: {}", &value)), - Box::new(value) + Box::new(value), ) } } @@ -32,14 +32,15 @@ impl From<sqlx::migrate::MigrateError> for StorageError { /// Micropub storage that uses a PostgreSQL database. #[derive(Debug, Clone)] pub struct PostgresStorage { - db: PgPool + db: PgPool, } impl PostgresStorage { /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], /// running appropriate migrations. pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")) + .await?; MIGRATOR.run(&db).await?; Ok(Self { db }) } @@ -50,19 +51,22 @@ impl Storage for PostgresStorage { /// migrations on the database. async fn new(url: &'_ url::Url) -> Result<Self> { tracing::debug!("Postgres URL: {url}"); - let options = sqlx::postgres::PgConnectOptions::from_url(url)? - .options([("search_path", "kittybox")]); + let options = + sqlx::postgres::PgConnectOptions::from_url(url)?.options([("search_path", "kittybox")]); Self::from_pool( sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) - .await? - ).await - + .await?, + ) + .await } - async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> { let authority = user.authority().to_owned(); Ok( sqlx::query_scalar::<_, serde_json::Value>("SELECT mf2 FROM kittybox.mf2_json WHERE owner = $1 ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC") @@ -74,18 +78,20 @@ impl Storage for PostgresStorage { #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result<Vec<String>> { - sqlx::query_scalar::<_, String>(" + sqlx::query_scalar::<_, String>( + " SELECT jsonb_array_elements(mf2['properties']['category']) AS category FROM kittybox.mf2_json WHERE jsonb_typeof(mf2['properties']['category']) = 'array' AND uid LIKE ($1 + '%') GROUP BY category ORDER BY count(*) DESC -") - .bind(url) - .fetch_all(&self.db) - .await - .map_err(|err| err.into()) +", + ) + .bind(url) + .fetch_all(&self.db) + .await + .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn post_exists(&self, url: &str) -> Result<bool> { @@ -98,13 +104,14 @@ WHERE #[tracing::instrument(skip(self))] async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { - sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1") - .bind(url) - .fetch_optional(&self.db) - .await - .map(|v| v.map(|v| v.0)) - .map_err(|err| err.into()) - + sqlx::query_as::<_, (serde_json::Value,)>( + "SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? 
$1", + ) + .bind(url) + .fetch_optional(&self.db) + .await + .map(|v| v.map(|v| v.0)) + .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] @@ -122,13 +129,15 @@ WHERE #[tracing::instrument(skip(self))] async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { tracing::debug!("Inserting {} into {}", post, feed); - sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") - .bind(feed) - .bind(post) - .execute(&self.db) - .await - .map(|_| ()) - .map_err(Into::into) + sqlx::query( + "INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING", + ) + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) } #[tracing::instrument(skip(self))] @@ -143,7 +152,12 @@ WHERE } #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> Result<()> { let mut txn = self.db.begin().await?; let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") @@ -190,7 +204,9 @@ WHERE #[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { tracing::debug!("Updating post {}", url); let mut txn = self.db.begin().await?; @@ -250,12 +266,12 @@ WHERE url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { let mut txn = self.db.begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") - .execute(&mut *txn) - .await?; + .execute(&mut *txn) + .await?; tracing::debug!("Started txn: {:?}", txn); let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 @@ -273,11 +289,17 @@ SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2 // The second query is very long and will probably be extremely // expensive. 
It's best to skip it on types where it doesn't make sense // (Kittybox doesn't support rendering children on non-feeds) - if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { + if !feed["type"] + .as_array() + .unwrap() + .iter() + .any(|t| *t == serde_json::json!("h-feed")) + { return Ok(Some((feed, None))); } - feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" + feed["children"] = sqlx::query_scalar::<_, serde_json::Value>( + " SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json INNER JOIN kittybox.children ON mf2_json.uid = children.child @@ -302,17 +324,19 @@ WHERE ) AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC -LIMIT $2" +LIMIT $2", ) - .bind(url) - .bind(limit as i64) - .bind(user.map(url::Url::as_str)) - .bind(cursor) - .fetch_all(&mut *txn) - .await - .map(serde_json::Value::Array)?; - - let new_cursor = feed["children"].as_array().unwrap() + .bind(url) + .bind(limit as i64) + .bind(user.map(url::Url::as_str)) + .bind(cursor) + .fetch_all(&mut *txn) + .await + .map(serde_json::Value::Array)?; + + let new_cursor = feed["children"] + .as_array() + .unwrap() .last() .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); @@ -335,7 +359,7 @@ LIMIT $2" .await { Ok((value,)) => Ok(serde_json::from_value(value)?), - Err(err) => Err(err.into()) + Err(err) => Err(err.into()), } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 9ba1a69..94b8aa7 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -12,12 +12,10 @@ use tracing::{debug, error}; //pub mod login; pub mod onboarding; +pub use kittybox_frontend_renderer::assets::statics; use kittybox_frontend_renderer::{ - Entry, Feed, VCard, - ErrorPage, Template, MainPage, - POSTS_PER_PAGE + Entry, ErrorPage, Feed, MainPage, Template, VCard, POSTS_PER_PAGE, }; -pub use kittybox_frontend_renderer::assets::statics; #[derive(Debug, Deserialize)] pub struct QueryParams { @@ -106,7 +104,7 @@ pub fn filter_post( .map(|i| -> &str { match i { serde_json::Value::String(ref author) => author.as_str(), - mf2 => mf2["properties"]["uid"][0].as_str().unwrap() + mf2 => mf2["properties"]["uid"][0].as_str().unwrap(), } }) .map(|i| i.parse().unwrap()) @@ -116,11 +114,13 @@ pub fn filter_post( .unwrap_or("public"); let audience = { let mut audience = author_list.clone(); - audience.extend(post["properties"]["audience"] - .as_array() - .unwrap_or(&empty_vec) - .iter() - .map(|i| i.as_str().unwrap().parse().unwrap())); + audience.extend( + post["properties"]["audience"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().parse().unwrap()), + ); audience }; @@ -134,7 +134,10 @@ pub fn filter_post( let location_visibility = post["properties"]["location-visibility"][0] .as_str() .unwrap_or("private"); - tracing::debug!("Post contains location, location privacy = {}", location_visibility); + tracing::debug!( + "Post contains location, location privacy = {}", + location_visibility + ); let mut author = post["properties"]["author"] .as_array() .unwrap_or(&empty_vec) @@ -155,16 +158,18 @@ pub fn filter_post( post["properties"]["author"] = serde_json::Value::Array( children .into_iter() - .filter_map(|post| if post.is_string() { - Some(post) - } else { - filter_post(post, user) + .filter_map(|post| { + if post.is_string() { + Some(post) + } else { + filter_post(post, user) + } }) - .collect::<Vec<serde_json::Value>>() + .collect::<Vec<serde_json::Value>>(), ); - }, - 
serde_json::Value::Null => {}, - other => post["properties"]["author"] = other + } + serde_json::Value::Null => {} + other => post["properties"]["author"] = other, } match post["children"].take() { @@ -173,11 +178,11 @@ pub fn filter_post( children .into_iter() .filter_map(|post| filter_post(post, user)) - .collect::<Vec<serde_json::Value>>() + .collect::<Vec<serde_json::Value>>(), ); - }, - serde_json::Value::Null => {}, - other => post["children"] = other + } + serde_json::Value::Null => {} + other => post["children"] = other, } Some(post) } @@ -209,7 +214,7 @@ async fn get_post_from_database<S: Storage>( )) } } - } + }, None => Err(FrontendError::with_code( StatusCode::NOT_FOUND, "Post not found in the database", @@ -240,7 +245,7 @@ pub async fn homepage<D: Storage>( Host(host): Host, Query(query): Query<QueryParams>, State(db): State<D>, - session: Option<crate::Session> + session: Option<crate::Session>, ) -> impl IntoResponse { // This is stupid, but there is no other way. let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); @@ -252,7 +257,7 @@ pub async fn homepage<D: Storage>( ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); let user = session.as_deref().map(|s| &s.me); @@ -268,18 +273,16 @@ pub async fn homepage<D: Storage>( // btw is it more efficient to fetch these in parallel? let (blogname, webring, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_setting::<crate::database::settings::Webring>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); if user.is_some() { headers.insert( axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("private") + axum::http::HeaderValue::from_static("private"), ); } // Render the homepage @@ -295,12 +298,13 @@ pub async fn homepage<D: Storage>( feed: &hfeed, card: &hcard, cursor: cursor.as_deref(), - webring: crate::database::settings::Setting::into_inner(webring) + webring: crate::database::settings::Setting::into_inner(webring), } .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } Err(err) => { if err.code == StatusCode::NOT_FOUND { @@ -310,19 +314,20 @@ pub async fn homepage<D: Storage>( StatusCode::FOUND, [(axum::http::header::LOCATION, "/.kittybox/onboarding")], String::default(), - ).into_response() + ) + .into_response() } else { error!("Error while fetching h-card and/or h-feed: {}", err); // Return the error let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( - err.code(), headers, + err.code(), + headers, Template { title: blogname.as_ref(), blog_name: blogname.as_ref(), @@ -335,7 +340,8 @@ pub async fn homepage<D: Storage>( .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } } } @@ -351,17 +357,13 @@ pub async fn catchall<D: Storage>( ) -> impl IntoResponse { let user: Option<&url::Url> = session.as_deref().map(|p| &p.me); let host = url::Url::parse(&format!("https://{}/", host)).unwrap(); - let path = host - .clone() - .join(uri.path()) - .unwrap(); + let path = host.clone().join(uri.path()).unwrap(); match 
get_post_from_database(&db, path.as_str(), query.after, user).await { Ok((post, cursor)) => { let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&host).map(|i| i.unwrap_or_default()) ); let mut headers = axum::http::HeaderMap::new(); @@ -371,12 +373,12 @@ pub async fn catchall<D: Storage>( ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); if user.is_some() { headers.insert( axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("private") + axum::http::HeaderValue::from_static("private"), ); } @@ -384,19 +386,20 @@ pub async fn catchall<D: Storage>( let last_modified = post["properties"]["updated"] .as_array() .and_then(|v| v.last()) - .or_else(|| post["properties"]["published"] - .as_array() - .and_then(|v| v.last()) - ) + .or_else(|| { + post["properties"]["published"] + .as_array() + .and_then(|v| v.last()) + }) .and_then(serde_json::Value::as_str) - .and_then(|dt| chrono::DateTime::<chrono::FixedOffset>::parse_from_rfc3339(dt).ok()); + .and_then(|dt| { + chrono::DateTime::<chrono::FixedOffset>::parse_from_rfc3339(dt).ok() + }); if let Some(last_modified) = last_modified { - headers.typed_insert( - axum_extra::headers::LastModified::from( - std::time::SystemTime::from(last_modified) - ) - ); + headers.typed_insert(axum_extra::headers::LastModified::from( + std::time::SystemTime::from(last_modified), + )); } } @@ -410,8 +413,16 @@ pub async fn catchall<D: Storage>( feeds: channels, user: session.as_deref(), content: match post.pointer("/type/0").and_then(|i| i.as_str()) { - Some("h-entry") => Entry { post: &post, from_feed: false, }.to_string(), - Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(), + Some("h-entry") => Entry { + post: &post, + from_feed: false, + } + .to_string(), + Some("h-feed") => Feed { + feed: &post, + cursor: cursor.as_deref(), + } + .to_string(), Some("h-card") => VCard { card: &post }.to_string(), unknown => { unimplemented!("Template for MF2-JSON type {:?}", unknown) @@ -419,13 +430,13 @@ pub async fn catchall<D: Storage>( }, } .to_string(), - ).into_response() + ) + .into_response() } Err(err) => { let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&host).map(|i| i.unwrap_or_default()) ); ( @@ -446,7 +457,8 @@ pub async fn catchall<D: Storage>( .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } } } diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs index bf313cf..3b53911 100644 --- a/src/frontend/onboarding.rs +++ b/src/frontend/onboarding.rs @@ -10,7 +10,7 @@ use axum::{ use axum_extra::extract::Host; use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template}; use serde::Deserialize; -use tokio::{task::JoinSet, sync::Mutex}; +use tokio::{sync::Mutex, task::JoinSet}; use tracing::{debug, error}; use super::FrontendError; @@ -64,7 +64,8 @@ async fn onboard<D: Storage + 'static>( me: user_uid.clone(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), - iat: None, exp: None + iat: None, + exp: None, }; tracing::debug!("User data: {:?}", user); @@ -99,19 +100,21 @@ async fn onboard<D: 
Storage + 'static>( continue; }; debug!("Creating feed {} with slug {}", &feed.name, &feed.slug); - let crate::micropub::util::NormalizedPost { id: _, post: feed } = crate::micropub::normalize_mf2( - serde_json::json!({ - "type": ["h-feed"], - "properties": {"name": [feed.name], "mp-slug": [feed.slug]} - }), - &user, - ); + let crate::micropub::util::NormalizedPost { id: _, post: feed } = + crate::micropub::normalize_mf2( + serde_json::json!({ + "type": ["h-feed"], + "properties": {"name": [feed.name], "mp-slug": [feed.slug]} + }), + &user, + ); db.put_post(&feed, &user.me) .await .map_err(FrontendError::from)?; } - let crate::micropub::util::NormalizedPost { id: uid, post } = crate::micropub::normalize_mf2(data.first_post, &user); + let crate::micropub::util::NormalizedPost { id: uid, post } = + crate::micropub::normalize_mf2(data.first_post, &user); tracing::debug!("Posting first post {}...", uid); crate::micropub::_post(&user, uid, post, db, http, jobset) .await @@ -169,6 +172,5 @@ where reqwest_middleware::ClientWithMiddleware: FromRef<St>, St: Clone + Send + Sync + 'static, { - axum::routing::get(get) - .post(post::<S>) + axum::routing::get(get).post(post::<S>) } diff --git a/src/indieauth/backend.rs b/src/indieauth/backend.rs index b913256..9215adf 100644 --- a/src/indieauth/backend.rs +++ b/src/indieauth/backend.rs @@ -1,9 +1,7 @@ -use std::future::Future; -use std::collections::HashMap; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData -}; +use kittybox_indieauth::{AuthorizationRequest, TokenData}; pub use kittybox_util::auth::EnrolledCredential; +use std::collections::HashMap; +use std::future::Future; type Result<T> = std::io::Result<T>; @@ -20,33 +18,72 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { /// Note for implementors: the [`AuthorizationRequest::me`] value /// is guaranteed to be [`Some(url::Url)`][Option::Some] and can /// be trusted to be correct and non-malicious. - fn create_code(&self, data: AuthorizationRequest) -> impl Future<Output = Result<String>> + Send; + fn create_code( + &self, + data: AuthorizationRequest, + ) -> impl Future<Output = Result<String>> + Send; /// Retreive an authorization request using the one-time /// code. Implementations must sanitize the `code` field to /// prevent exploits, and must check if the code should still be /// valid at this point in time (validity interval is left up to /// the implementation, but is recommended to be no more than 10 /// minutes). - fn get_code(&self, code: &str) -> impl Future<Output = Result<Option<AuthorizationRequest>>> + Send; + fn get_code( + &self, + code: &str, + ) -> impl Future<Output = Result<Option<AuthorizationRequest>>> + Send; // Token management. fn create_token(&self, data: TokenData) -> impl Future<Output = Result<String>> + Send; - fn get_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<Option<TokenData>>> + Send; - fn list_tokens(&self, website: &url::Url) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; - fn revoke_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<()>> + Send; + fn get_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<Option<TokenData>>> + Send; + fn list_tokens( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; + fn revoke_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<()>> + Send; // Refresh token management. 
fn create_refresh_token(&self, data: TokenData) -> impl Future<Output = Result<String>> + Send; - fn get_refresh_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<Option<TokenData>>> + Send; - fn list_refresh_tokens(&self, website: &url::Url) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; - fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<()>> + Send; + fn get_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<Option<TokenData>>> + Send; + fn list_refresh_tokens( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; + fn revoke_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<()>> + Send; // Password management. /// Verify a password. #[must_use] - fn verify_password(&self, website: &url::Url, password: String) -> impl Future<Output = Result<bool>> + Send; + fn verify_password( + &self, + website: &url::Url, + password: String, + ) -> impl Future<Output = Result<bool>> + Send; /// Enroll a password credential for a user. Only one password /// credential must exist for a given user. - fn enroll_password(&self, website: &url::Url, password: String) -> impl Future<Output = Result<()>> + Send; + fn enroll_password( + &self, + website: &url::Url, + password: String, + ) -> impl Future<Output = Result<()>> + Send; /// List currently enrolled credential types for a given user. - fn list_user_credential_types(&self, website: &url::Url) -> impl Future<Output = Result<Vec<EnrolledCredential>>> + Send; + fn list_user_credential_types( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<Vec<EnrolledCredential>>> + Send; // WebAuthn credential management. #[cfg(feature = "webauthn")] /// Enroll a WebAuthn authenticator public key for this user. @@ -56,10 +93,17 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { /// This function can also be used to overwrite a passkey with an /// updated version after using /// [webauthn::prelude::Passkey::update_credential()]. - fn enroll_webauthn(&self, website: &url::Url, credential: webauthn::prelude::Passkey) -> impl Future<Output = Result<()>> + Send; + fn enroll_webauthn( + &self, + website: &url::Url, + credential: webauthn::prelude::Passkey, + ) -> impl Future<Output = Result<()>> + Send; #[cfg(feature = "webauthn")] /// List currently enrolled WebAuthn authenticators for a given user. - fn list_webauthn_pubkeys(&self, website: &url::Url) -> impl Future<Output = Result<Vec<webauthn::prelude::Passkey>>> + Send; + fn list_webauthn_pubkeys( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<Vec<webauthn::prelude::Passkey>>> + Send; #[cfg(feature = "webauthn")] /// Persist registration challenge state for a little while so it /// can be used later. @@ -69,7 +113,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn persist_registration_challenge( &self, website: &url::Url, - state: webauthn::prelude::PasskeyRegistration + state: webauthn::prelude::PasskeyRegistration, ) -> impl Future<Output = Result<String>> + Send; #[cfg(feature = "webauthn")] /// Retrieve a persisted registration challenge. 
@@ -78,7 +122,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn retrieve_registration_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> impl Future<Output = Result<webauthn::prelude::PasskeyRegistration>> + Send; #[cfg(feature = "webauthn")] /// Persist authentication challenge state for a little while so @@ -92,7 +136,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn persist_authentication_challenge( &self, website: &url::Url, - state: webauthn::prelude::PasskeyAuthentication + state: webauthn::prelude::PasskeyAuthentication, ) -> impl Future<Output = Result<String>> + Send; #[cfg(feature = "webauthn")] /// Retrieve a persisted authentication challenge. @@ -101,7 +145,6 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn retrieve_authentication_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> impl Future<Output = Result<webauthn::prelude::PasskeyAuthentication>> + Send; - } diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs index f74fbbc..26466fe 100644 --- a/src/indieauth/backend/fs.rs +++ b/src/indieauth/backend/fs.rs @@ -1,13 +1,16 @@ -use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}}; - -use super::{AuthBackend, Result, EnrolledCredential}; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData +use std::{ + borrow::Cow, + collections::HashMap, + path::PathBuf, + time::{Duration, SystemTime}, }; + +use super::{AuthBackend, EnrolledCredential, Result}; +use kittybox_indieauth::{AuthorizationRequest, TokenData}; use serde::de::DeserializeOwned; -use tokio::{task::spawn_blocking, io::AsyncReadExt}; +use tokio::{io::AsyncReadExt, task::spawn_blocking}; #[cfg(feature = "webauthn")] -use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication}; +use webauthn::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; const CODE_LENGTH: usize = 16; const TOKEN_LENGTH: usize = 128; @@ -29,7 +32,8 @@ impl FileBackend { } else { let mut s = String::with_capacity(filename.len()); - filename.chars() + filename + .chars() .filter(|c| c.is_alphanumeric()) .for_each(|c| s.push(c)); @@ -38,41 +42,41 @@ impl FileBackend { } #[inline] - async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>( + async fn serialize_to_file< + T: 'static + serde::ser::Serialize + Send, + B: Into<Option<&'static str>>, + >( &self, dir: &str, basename: B, length: usize, - data: T + data: T, ) -> Result<String> { let basename = basename.into(); let has_ext = basename.is_some(); - let (filename, mut file) = kittybox_util::fs::mktemp( - self.path.join(dir), basename, length - ) + let (filename, mut file) = kittybox_util::fs::mktemp(self.path.join(dir), basename, length) .await .map(|(name, file)| (name, file.try_into_std().unwrap()))?; spawn_blocking(move || serde_json::to_writer(&mut file, &data)) .await - .unwrap_or_else(|e| panic!( - "Panic while serializing {}: {}", - std::any::type_name::<T>(), - e - )) + .unwrap_or_else(|e| { + panic!( + "Panic while serializing {}: {}", + std::any::type_name::<T>(), + e + ) + }) .map(move |_| { (if has_ext { - filename - .extension() - + filename.extension() } else { - filename - .file_name() + filename.file_name() }) - .unwrap() - .to_str() - .unwrap() - .to_owned() + .unwrap() + .to_str() + .unwrap() + .to_owned() }) .map_err(|err| err.into()) } @@ -86,17 +90,15 @@ impl FileBackend { ) -> Result<Option<(PathBuf, SystemTime, T)>> where T: 
serde::de::DeserializeOwned + Send, - B: Into<Option<&'static str>> + B: Into<Option<&'static str>>, { let basename = basename.into(); - let path = self.path - .join(dir) - .join(format!( - "{}{}{}", - basename.unwrap_or(""), - if basename.is_none() { "" } else { "." }, - FileBackend::sanitize_for_path(filename) - )); + let path = self.path.join(dir).join(format!( + "{}{}{}", + basename.unwrap_or(""), + if basename.is_none() { "" } else { "." }, + FileBackend::sanitize_for_path(filename) + )); let data = match tokio::fs::File::open(&path).await { Ok(mut file) => { @@ -106,13 +108,15 @@ impl FileBackend { match serde_json::from_slice::<'_, T>(buf.as_slice()) { Ok(data) => data, - Err(err) => return Err(err.into()) + Err(err) => return Err(err.into()), + } + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + return Ok(None); + } else { + return Err(err); } - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - return Ok(None) - } else { - return Err(err) } }; @@ -125,7 +129,8 @@ impl FileBackend { #[tracing::instrument] fn url_to_dir(url: &url::Url) -> String { let host = url.host_str().unwrap(); - let port = url.port() + let port = url + .port() .map(|port| Cow::Owned(format!(":{}", port))) .unwrap_or(Cow::Borrowed("")); @@ -135,23 +140,26 @@ impl FileBackend { async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>( &'this self, dir: &'dir str, - prefix: &'static str + prefix: &'static str, ) -> Result<HashMap<String, T>> { let dir = self.path.join(dir); let mut hashmap = HashMap::new(); let mut readdir = match tokio::fs::read_dir(dir).await { Ok(readdir) => readdir, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - // empty hashmap - return Ok(hashmap); - } else { - return Err(err); + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + // empty hashmap + return Ok(hashmap); + } else { + return Err(err); + } } }; while let Some(entry) = readdir.next_entry().await? { // safe to unwrap; filenames are alphanumeric - let filename = entry.file_name() + let filename = entry + .file_name() .into_string() .expect("token filenames should be alphanumeric!"); if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) { @@ -166,16 +174,19 @@ impl FileBackend { Err(err) => { tracing::error!( "Error decoding token data from file {}: {}", - entry.path().display(), err + entry.path().display(), + err ); continue; } }; - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - continue - } else { - return Err(err) + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + continue; + } else { + return Err(err); + } } } } @@ -194,19 +205,27 @@ impl AuthBackend for FileBackend { path = base.join(&format!(".{}", path.path())).unwrap(); } - tracing::debug!("Initializing File auth backend: {} -> {}", orig_path, path.path()); + tracing::debug!( + "Initializing File auth backend: {} -> {}", + orig_path, + path.path() + ); Ok(Self { - path: std::path::PathBuf::from(path.path()) + path: std::path::PathBuf::from(path.path()), }) } // Authorization code management. async fn create_code(&self, data: AuthorizationRequest) -> Result<String> { - self.serialize_to_file("codes", None, CODE_LENGTH, data).await + self.serialize_to_file("codes", None, CODE_LENGTH, data) + .await } async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> { - match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? 
{ + match self + .deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()) + .await? + { Some((path, ctime, data)) => { if let Err(err) = tokio::fs::remove_file(path).await { tracing::error!("Failed to clean up authorization code: {}", err); @@ -217,23 +236,28 @@ impl AuthBackend for FileBackend { } else { Ok(Some(data)) } - }, - None => Ok(None) + } + None => Ok(None), } } // Token management. async fn create_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await + self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data) + .await } async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "access", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { + match self + .deserialize_from_file::<TokenData, _>( + &dir, + "access", + FileBackend::sanitize_for_path(token).as_ref(), + ) + .await? + { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { @@ -243,8 +267,8 @@ impl AuthBackend for FileBackend { } else { Ok(Some(token)) } - }, - None => Ok(None) + } + None => Ok(None), } } @@ -258,25 +282,36 @@ impl AuthBackend for FileBackend { self.path .join(FileBackend::url_to_dir(website)) .join("tokens") - .join(format!("access.{}", FileBackend::sanitize_for_path(token))) - ).await { + .join(format!("access.{}", FileBackend::sanitize_for_path(token))), + ) + .await + { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result + result => result, } } // Refresh token management. async fn create_refresh_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await + self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data) + .await } - async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { + async fn get_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "refresh", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { + match self + .deserialize_from_file::<TokenData, _>( + &dir, + "refresh", + FileBackend::sanitize_for_path(token).as_ref(), + ) + .await? + { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { @@ -286,8 +321,8 @@ impl AuthBackend for FileBackend { } else { Ok(Some(token)) } - }, - None => Ok(None) + } + None => Ok(None), } } @@ -301,57 +336,80 @@ impl AuthBackend for FileBackend { self.path .join(FileBackend::url_to_dir(website)) .join("tokens") - .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))) - ).await { + .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))), + ) + .await + { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result + result => result, } } // Password management. 
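Both password methods below build on the argon2 crate's PHC-string workflow: enrollment stores one salted hash string per user, and verification re-parses that string. The same calls in a self-contained round-trip (editor's sketch, not part of this commit; the literal password is illustrative):

```rust
use argon2::{
    password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
    Argon2,
};

fn argon2_round_trip() {
    // Hash with a fresh random salt; the PHC string embeds the salt and parameters.
    let salt = SaltString::generate(&mut OsRng);
    let hash = Argon2::default()
        .hash_password(b"hunter2", &salt)
        .expect("hashing should not fail")
        .to_string();

    // Verification re-parses the PHC string, so no separate salt storage is needed.
    let parsed = PasswordHash::new(&hash).expect("stored hash should be valid");
    assert!(Argon2::default()
        .verify_password(b"hunter2", &parsed)
        .is_ok());
}
```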
#[tracing::instrument(skip(password))] async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> { - use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}}; + use argon2::{ + password_hash::{PasswordHash, PasswordVerifier}, + Argon2, + }; - let password_filename = self.path + let password_filename = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); - tracing::debug!("Reading password for {} from {}", website, password_filename.display()); + tracing::debug!( + "Reading password for {} from {}", + website, + password_filename.display() + ); match tokio::fs::read_to_string(password_filename).await { Ok(password_hash) => { let parsed_hash = { let hash = password_hash.trim(); - #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash); - PasswordHash::new(hash) - .expect("Password hash should be valid!") + #[cfg(debug_assertions)] + tracing::debug!("Password hash: {}", hash); + PasswordHash::new(hash).expect("Password hash should be valid!") }; - Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - Ok(false) - } else { - Err(err) + Ok(Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .is_ok()) + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + Ok(false) + } else { + Err(err) + } } } } #[tracing::instrument(skip(password))] async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> { - use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}}; + use argon2::{ + password_hash::{rand_core::OsRng, PasswordHasher, SaltString}, + Argon2, + }; - let password_filename = self.path + let password_filename = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); - let password_hash = argon2.hash_password(password.as_bytes(), &salt) + let password_hash = argon2 + .hash_password(password.as_bytes(), &salt) .expect("Hashing a password should not error out") .to_string(); - tracing::debug!("Enrolling password for {} at {}", website, password_filename.display()); + tracing::debug!( + "Enrolling password for {} at {}", + website, + password_filename.display() + ); tokio::fs::write(password_filename, password_hash.as_bytes()).await } @@ -371,7 +429,7 @@ impl AuthBackend for FileBackend { async fn persist_registration_challenge( &self, website: &url::Url, - state: PasskeyRegistration + state: PasskeyRegistration, ) -> Result<String> { todo!() } @@ -380,7 +438,7 @@ impl AuthBackend for FileBackend { async fn retrieve_registration_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> Result<PasskeyRegistration> { todo!() } @@ -389,7 +447,7 @@ impl AuthBackend for FileBackend { async fn persist_authentication_challenge( &self, website: &url::Url, - state: PasskeyAuthentication + state: PasskeyAuthentication, ) -> Result<String> { todo!() } @@ -398,24 +456,28 @@ impl AuthBackend for FileBackend { async fn retrieve_authentication_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> Result<PasskeyAuthentication> { todo!() } #[tracing::instrument(skip(self))] - async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> { + async fn list_user_credential_types( + &self, + website: &url::Url, + ) -> Result<Vec<EnrolledCredential>> { let mut creds = vec![]; - let 
password_file = self.path + let password_file = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); tracing::debug!("Password file for {}: {}", website, password_file.display()); - match tokio::fs::metadata(password_file) - .await - { + match tokio::fs::metadata(password_file).await { Ok(_) => creds.push(EnrolledCredential::Password), - Err(err) => if err.kind() != std::io::ErrorKind::NotFound { - return Err(err) + Err(err) => { + if err.kind() != std::io::ErrorKind::NotFound { + return Err(err); + } } } diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 00ae393..2f90a19 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,18 +1,29 @@ -use std::marker::PhantomData; -use microformats::types::Class; -use tracing::error; -use serde::Deserialize; +use crate::database::Storage; use axum::{ - extract::{Form, FromRef, Json, Query, State}, http::StatusCode, response::{Html, IntoResponse, Response} + extract::{Form, FromRef, Json, Query, State}, + http::StatusCode, + response::{Html, IntoResponse, Response}, }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] -use axum_extra::extract::{Host, cookie::{CookieJar, Cookie}}; -use axum_extra::{headers::{authorization::Bearer, Authorization, ContentType, HeaderMapExt}, TypedHeader}; -use crate::database::Storage; +use axum_extra::extract::{ + cookie::{Cookie, CookieJar}, + Host, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization, ContentType, HeaderMapExt}, + TypedHeader, +}; use kittybox_indieauth::{ - AuthorizationRequest, AuthorizationResponse, ClientMetadata, Error, ErrorKind, GrantRequest, GrantResponse, GrantType, IntrospectionEndpointAuthMethod, Metadata, PKCEMethod, Profile, ProfileUrl, ResponseType, RevocationEndpointAuthMethod, Scope, Scopes, TokenData, TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest + AuthorizationRequest, AuthorizationResponse, ClientMetadata, Error, ErrorKind, GrantRequest, + GrantResponse, GrantType, IntrospectionEndpointAuthMethod, Metadata, PKCEMethod, Profile, + ProfileUrl, ResponseType, RevocationEndpointAuthMethod, Scope, Scopes, TokenData, + TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest, }; +use microformats::types::Class; +use serde::Deserialize; +use std::marker::PhantomData; use std::str::FromStr; +use tracing::error; pub mod backend; #[cfg(feature = "webauthn")] @@ -41,35 +52,42 @@ impl<A: AuthBackend> std::ops::Deref for User<A> { pub enum IndieAuthResourceError { InvalidRequest, Unauthorized, - InvalidToken + InvalidToken, } impl axum::response::IntoResponse for IndieAuthResourceError { fn into_response(self) -> axum::response::Response { use IndieAuthResourceError::*; match self { - Unauthorized => ( - StatusCode::UNAUTHORIZED, - [("WWW-Authenticate", "Bearer")] - ).into_response(), + Unauthorized => { + (StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Bearer")]).into_response() + } InvalidRequest => ( StatusCode::BAD_REQUEST, - Json(&serde_json::json!({"error": "invalid_request"})) - ).into_response(), + Json(&serde_json::json!({"error": "invalid_request"})), + ) + .into_response(), InvalidToken => ( StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Bearer, error=\"invalid_token\"")], - Json(&serde_json::json!({"error": "not_authorized"})) - ).into_response() + Json(&serde_json::json!({"error": "not_authorized"})), + ) + .into_response(), } } } -impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::OptionalFromRequestParts<St> for User<A> 
{ +impl<A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> + axum::extract::OptionalFromRequestParts<St> for User<A> +{ type Rejection = <Self as axum::extract::FromRequestParts<St>>::Rejection; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Option<Self>, Self::Rejection> { - let res = <Self as axum::extract::FromRequestParts<St>>::from_request_parts(req, state).await; + async fn from_request_parts( + req: &mut axum::http::request::Parts, + state: &St, + ) -> Result<Option<Self>, Self::Rejection> { + let res = + <Self as axum::extract::FromRequestParts<St>>::from_request_parts(req, state).await; match res { Ok(user) => Ok(Some(user)), @@ -79,14 +97,19 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext } } -impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts<St> for User<A> { +impl<A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> + axum::extract::FromRequestParts<St> for User<A> +{ type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Self, Self::Rejection> { + async fn from_request_parts( + req: &mut axum::http::request::Parts, + state: &St, + ) -> Result<Self, Self::Rejection> { let TypedHeader(Authorization(token)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) - .await - .map_err(|_| IndieAuthResourceError::Unauthorized)?; + .await + .map_err(|_| IndieAuthResourceError::Unauthorized)?; let auth = A::from_ref(state); @@ -94,10 +117,7 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext .await .map_err(|_| IndieAuthResourceError::InvalidRequest)?; - auth.get_token( - &format!("https://{host}/").parse().unwrap(), - token.token() - ) + auth.get_token(&format!("https://{host}/").parse().unwrap(), token.token()) .await .unwrap() .ok_or(IndieAuthResourceError::InvalidToken) @@ -105,9 +125,7 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext } } -pub async fn metadata( - Host(host): Host -) -> Metadata { +pub async fn metadata(Host(host): Host) -> Metadata { let issuer: url::Url = format!("https://{}/", host).parse().unwrap(); let indieauth: url::Url = issuer.join("/.kittybox/indieauth/").unwrap(); @@ -117,18 +135,16 @@ pub async fn metadata( token_endpoint: indieauth.join("token").unwrap(), introspection_endpoint: indieauth.join("token_status").unwrap(), introspection_endpoint_auth_methods_supported: Some(vec![ - IntrospectionEndpointAuthMethod::Bearer + IntrospectionEndpointAuthMethod::Bearer, ]), revocation_endpoint: Some(indieauth.join("revoke_token").unwrap()), - revocation_endpoint_auth_methods_supported: Some(vec![ - RevocationEndpointAuthMethod::None - ]), + revocation_endpoint_auth_methods_supported: Some(vec![RevocationEndpointAuthMethod::None]), scopes_supported: Some(vec![ Scope::Create, Scope::Update, Scope::Delete, Scope::Media, - Scope::Profile + Scope::Profile, ]), response_types_supported: Some(vec![ResponseType::Code]), grant_types_supported: Some(vec![GrantType::AuthorizationCode, GrantType::RefreshToken]), @@ -145,27 +161,39 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Query(request): Query<AuthorizationRequest>, State(db): State<D>, State(http): State<reqwest_middleware::ClientWithMiddleware>, - State(auth): State<A> + State(auth): State<A>, ) -> Response { let me: url::Url = format!("https://{host}/").parse().unwrap(); // 
XXX: attempt fetching OAuth application metadata - let h_app: ClientMetadata = if request.client_id.domain().unwrap() == "localhost" && me.domain().unwrap() != "localhost" { + let h_app: ClientMetadata = if request.client_id.domain().unwrap() == "localhost" + && me.domain().unwrap() != "localhost" + { // If client is localhost, but we aren't localhost, generate synthetic metadata. tracing::warn!("Client is localhost, not fetching metadata"); - let mut metadata = ClientMetadata::new(request.client_id.clone(), request.client_id.clone()).unwrap(); + let mut metadata = + ClientMetadata::new(request.client_id.clone(), request.client_id.clone()).unwrap(); metadata.client_name = Some("Your locally hosted app".to_string()); metadata } else { tracing::debug!("Sending request to {} to fetch metadata", request.client_id); - let metadata_request = http.get(request.client_id.clone()) + let metadata_request = http + .get(request.client_id.clone()) .header("Accept", "application/json, text/html"); - match metadata_request.send().await - .and_then(|res| res.error_for_status() - .map_err(reqwest_middleware::Error::Reqwest)) - { - Ok(response) if response.headers().typed_get::<ContentType>().to_owned().map(mime::Mime::from).map(|m| m.type_() == "text" && m.subtype() == "html").unwrap_or(false) => { + match metadata_request.send().await.and_then(|res| { + res.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }) { + Ok(response) + if response + .headers() + .typed_get::<ContentType>() + .to_owned() + .map(mime::Mime::from) + .map(|m| m.type_() == "text" && m.subtype() == "html") + .unwrap_or(false) => + { let url = response.url().clone(); let text = response.text().await.unwrap(); tracing::debug!("Received {} bytes in response", text.len()); @@ -173,76 +201,95 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Ok(mf2) => { if let Some(relation) = mf2.rels.items.get(&request.redirect_uri) { if !relation.rels.iter().any(|i| i == "redirect_uri") { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri provided was declared as \ - something other than redirect_uri.") - .into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri provided was declared as \ + something other than redirect_uri.", + ) + .into_response(); } } else if request.redirect_uri.origin() != request.client_id.origin() { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri didn't match the origin \ - and wasn't explicitly allowed. You were being tricked.") - .into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri didn't match the origin \ + and wasn't explicitly allowed. You were being tricked.", + ) + .into_response(); } - if let Some(app) = mf2.items + if let Some(app) = mf2 + .items .iter() - .find(|&i| i.r#type.iter() - .any(|i| { + .find(|&i| { + i.r#type.iter().any(|i| { *i == Class::from_str("h-app").unwrap() || *i == Class::from_str("h-x-app").unwrap() }) - ) + }) .cloned() { // Create a synthetic metadata document. Be forgiving. 
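// Missing h-app properties fall back to the client_id below, so a
// bare-bones client page still yields a usable consent screen.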
let mut metadata = ClientMetadata::new( request.client_id.clone(), - app.properties.get("url") + app.properties + .get("url") .and_then(|v| v.first()) .and_then(|i| match i { - microformats::types::PropertyValue::Url(url) => Some(url.clone()), - _ => None + microformats::types::PropertyValue::Url(url) => { + Some(url.clone()) + } + _ => None, }) - .unwrap_or_else(|| request.client_id.clone()) - ).unwrap(); + .unwrap_or_else(|| request.client_id.clone()), + ) + .unwrap(); - metadata.client_name = app.properties.get("name") + metadata.client_name = app + .properties + .get("name") .and_then(|v| v.first()) .and_then(|i| match i { - microformats::types::PropertyValue::Plain(name) => Some(name.to_owned()), - _ => None + microformats::types::PropertyValue::Plain(name) => { + Some(name.to_owned()) + } + _ => None, }); metadata.redirect_uris = mf2.rels.by_rels().remove("redirect_uri"); metadata } else { - return (StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], "No h-app or JSON application metadata found.").into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "No h-app or JSON application metadata found.", + ) + .into_response(); } - }, + } Err(err) => { tracing::error!("Error parsing application metadata: {}", err); return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - "Parsing h-app metadata failed.").into_response() + "Parsing h-app metadata failed.", + ) + .into_response(); } } - }, + } Ok(response) => match response.json::<ClientMetadata>().await { - Ok(client_metadata) => { - client_metadata - }, + Ok(client_metadata) => client_metadata, Err(err) => { tracing::error!("Error parsing JSON application metadata: {}", err); return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - format!("Parsing OAuth2 JSON app metadata failed: {}", err) - ).into_response() + format!("Parsing OAuth2 JSON app metadata failed: {}", err), + ) + .into_response(); } }, Err(err) => { @@ -250,27 +297,32 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - format!("Fetching app metadata failed: {}", err) - ).into_response() + format!("Fetching app metadata failed: {}", err), + ) + .into_response(); } } }; tracing::debug!("Application metadata: {:#?}", h_app); - Html(kittybox_frontend_renderer::Template { - title: "Confirm sign-in via IndieAuth", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: kittybox_frontend_renderer::AuthorizationRequestPage { - request, - credentials: auth.list_user_credential_types(&me).await.unwrap(), - user: db.get_post(me.as_str()).await.unwrap().unwrap(), - app: h_app - }.to_string(), - }.to_string()) - .into_response() + Html( + kittybox_frontend_renderer::Template { + title: "Confirm sign-in via IndieAuth", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: kittybox_frontend_renderer::AuthorizationRequestPage { + request, + credentials: auth.list_user_credential_types(&me).await.unwrap(), + user: db.get_post(me.as_str()).await.unwrap().unwrap(), + app: h_app, + } + .to_string(), + } + .to_string(), + ) + .into_response() } #[derive(Deserialize, Debug)] @@ -278,7 +330,7 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( enum Credential { Password(String), #[cfg(feature = "webauthn")] - WebAuthn(::webauthn::prelude::PublicKeyCredential) + WebAuthn(::webauthn::prelude::PublicKeyCredential), } // The IndieAuth standard doesn't prescribe a format for confirming @@ 
-291,7 +343,7 @@ enum Credential { #[derive(Deserialize, Debug)] struct AuthorizationConfirmation { authorization_method: Credential, - request: AuthorizationRequest + request: AuthorizationRequest, } #[tracing::instrument(skip(auth, credential))] @@ -299,18 +351,14 @@ async fn verify_credential<A: AuthBackend>( auth: &A, website: &url::Url, credential: Credential, - #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] - challenge_id: Option<&str> + #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] challenge_id: Option<&str>, ) -> std::io::Result<bool> { match credential { Credential::Password(password) => auth.verify_password(website, password).await, #[cfg(feature = "webauthn")] - Credential::WebAuthn(credential) => webauthn::verify( - auth, - website, - credential, - challenge_id.unwrap() - ).await + Credential::WebAuthn(credential) => { + webauthn::verify(auth, website, credential, challenge_id.unwrap()).await + } } } @@ -323,7 +371,8 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( ) -> Response { tracing::debug!("Received authorization confirmation from user"); #[cfg(feature = "webauthn")] - let challenge_id = cookies.get(webauthn::CHALLENGE_ID_COOKIE) + let challenge_id = cookies + .get(webauthn::CHALLENGE_ID_COOKIE) .map(|cookie| cookie.value()); #[cfg(not(feature = "webauthn"))] let challenge_id = None; @@ -331,14 +380,16 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( let website = format!("https://{}/", host).parse().unwrap(); let AuthorizationConfirmation { authorization_method: credential, - request: mut auth + request: mut auth, } = confirmation; match verify_credential(&backend, &website, credential, challenge_id).await { - Ok(verified) => if !verified { - error!("User failed verification, bailing out."); - return StatusCode::UNAUTHORIZED.into_response(); - }, + Ok(verified) => { + if !verified { + error!("User failed verification, bailing out."); + return StatusCode::UNAUTHORIZED.into_response(); + } + } Err(err) => { error!("Error while verifying credential: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -365,9 +416,14 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( let location = { let mut uri = redirect_uri; - uri.set_query(Some(&serde_urlencoded::to_string( - AuthorizationResponse { code, state, iss: website } - ).unwrap())); + uri.set_query(Some( + &serde_urlencoded::to_string(AuthorizationResponse { + code, + state, + iss: website, + }) + .unwrap(), + )); uri }; @@ -375,10 +431,11 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( // DO NOT SET `StatusCode::FOUND` here! 
`fetch()` cannot read from // redirects, it can only follow them or choose to receive an // opaque response instead that is completely useless - (StatusCode::NO_CONTENT, - [("Location", location.as_str())], - #[cfg(feature = "webauthn")] - cookies.remove(Cookie::from(webauthn::CHALLENGE_ID_COOKIE)) + ( + StatusCode::NO_CONTENT, + [("Location", location.as_str())], + #[cfg(feature = "webauthn")] + cookies.remove(Cookie::from(webauthn::CHALLENGE_ID_COOKIE)), ) .into_response() } @@ -396,15 +453,18 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( code, client_id, redirect_uri, - code_verifier + code_verifier, } => { let request: AuthorizationRequest = match backend.get_code(&code).await { Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving auth request: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -414,51 +474,66 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } if redirect_uri != request.redirect_uri { return Error { kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() + msg: Some( + "This redirect_uri doesn't match the one the code has been sent to." + .to_string(), + ), + error_uri: None, + } + .into_response(); } if !request.code_challenge.verify(code_verifier) { return Error { kind: ErrorKind::InvalidGrant, msg: Some("The PKCE challenge failed.".to_string()), // are RFCs considered human-readable? 
😝 - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response() + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6" + .parse() + .ok(), + } + .into_response(); } let me: url::Url = format!("https://{}/", host).parse().unwrap(); if request.me.unwrap() != me { return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } - let profile = if request.scope.as_ref() - .map(|s| s.has(&Scope::Profile)) - .unwrap_or_default() + let profile = if request + .scope + .as_ref() + .map(|s| s.has(&Scope::Profile)) + .unwrap_or_default() { match get_profile( db, me.as_str(), - request.scope.as_ref() + request + .scope + .as_ref() .map(|s| s.has(&Scope::Email)) - .unwrap_or_default() - ).await { + .unwrap_or_default(), + ) + .await + { Ok(profile) => { tracing::debug!("Retrieved profile: {:?}", profile); profile - }, + } Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { @@ -466,12 +541,15 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( }; GrantResponse::ProfileUrl(ProfileUrl { me, profile }).into_response() - }, + } _ => Error { kind: ErrorKind::InvalidGrant, msg: Some("The provided grant_type is unusable on this endpoint.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code".parse().ok() - }.into_response() + error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code" + .parse() + .ok(), + } + .into_response(), } } @@ -485,36 +563,40 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( #[inline] fn prepare_access_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { TokenData { - me, client_id, scope, + me, + client_id, + scope, exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) - .as_secs() - .into(), + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) + .as_secs() + .into(), iat: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() - .into() + .into(), } } #[inline] fn prepare_refresh_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { TokenData { - me, client_id, scope, + me, + client_id, + scope, exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) - .as_secs() - .into(), + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) + .as_secs() + .into(), iat: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() - .into() + .into(), } } @@ -525,15 +607,18 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( code, client_id, redirect_uri, - code_verifier + code_verifier, } => { let request: AuthorizationRequest = match backend.get_code(&code).await { Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: 
ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving auth request: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -542,33 +627,46 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( tracing::debug!("Retrieved authorization request: {:?}", request); - let scope = if let Some(scope) = request.scope { scope } else { + let scope = if let Some(scope) = request.scope { + scope + } else { return Error { kind: ErrorKind::InvalidScope, msg: Some("Tokens cannot be issued if no scopes are requested.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#access-token-response".parse().ok() - }.into_response(); + error_uri: "https://indieauth.spec.indieweb.org/#access-token-response" + .parse() + .ok(), + } + .into_response(); }; if client_id != request.client_id { return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } if redirect_uri != request.redirect_uri { return Error { kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() + msg: Some( + "This redirect_uri doesn't match the one the code has been sent to." + .to_string(), + ), + error_uri: None, + } + .into_response(); } if !request.code_challenge.verify(code_verifier) { return Error { kind: ErrorKind::InvalidGrant, msg: Some("The PKCE challenge failed.".to_string()), - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response(); + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6" + .parse() + .ok(), + } + .into_response(); } // Note: we can trust the `request.me` value, since we set @@ -577,30 +675,32 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } let profile = if dbg!(scope.has(&Scope::Profile)) { - match get_profile( - db, - me.as_str(), - scope.has(&Scope::Email) - ).await { + match get_profile(db, me.as_str(), scope.has(&Scope::Email)).await { Ok(profile) => dbg!(profile), Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { None }; - let access_token = match backend.create_token( - prepare_access_token(me.clone(), client_id.clone(), scope.clone()) - ).await { + let access_token = match backend + .create_token(prepare_access_token( + me.clone(), + client_id.clone(), + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating access token: {}", err); @@ -608,9 +708,10 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( } }; // TODO: only create refresh token if user allows it - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(me.clone(), client_id, scope.clone()) - ).await { + let refresh_token = match backend + .create_refresh_token(prepare_refresh_token(me.clone(), client_id, scope.clone())) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating 
refresh token: {}", err); @@ -626,24 +727,28 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( scope: Some(scope), expires_in: Some(ACCESS_TOKEN_VALIDITY), refresh_token: Some(refresh_token), - state: None - }.into_response() - }, + state: None, + } + .into_response() + } GrantRequest::RefreshToken { refresh_token, client_id, - scope + scope, } => { let data = match backend.get_refresh_token(&me, &refresh_token).await { Ok(Some(token)) => token, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This refresh token is not valid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This refresh token is not valid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } }; @@ -651,17 +756,22 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This refresh token is not yours.".to_string()), - error_uri: None - }.into_response(); + error_uri: None, + } + .into_response(); } let scope = if let Some(scope) = scope { if !data.scope.has_all(scope.as_ref()) { return Error { kind: ErrorKind::InvalidScope, - msg: Some("You can't request additional scopes through the refresh token grant.".to_string()), - error_uri: None - }.into_response(); + msg: Some( + "You can't request additional scopes through the refresh token grant." + .to_string(), + ), + error_uri: None, + } + .into_response(); } scope @@ -670,27 +780,27 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( data.scope }; - let profile = if scope.has(&Scope::Profile) { - match get_profile( - db, - data.me.as_str(), - scope.has(&Scope::Email) - ).await { + match get_profile(db, data.me.as_str(), scope.has(&Scope::Email)).await { Ok(profile) => profile, Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { None }; - let access_token = match backend.create_token( - prepare_access_token(data.me.clone(), client_id.clone(), scope.clone()) - ).await { + let access_token = match backend + .create_token(prepare_access_token( + data.me.clone(), + client_id.clone(), + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating access token: {}", err); @@ -699,9 +809,14 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( }; let old_refresh_token = refresh_token; - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(data.me.clone(), client_id, scope.clone()) - ).await { + let refresh_token = match backend + .create_refresh_token(prepare_refresh_token( + data.me.clone(), + client_id, + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating refresh token: {}", err); @@ -721,8 +836,9 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( scope: Some(scope), expires_in: Some(ACCESS_TOKEN_VALIDITY), refresh_token: Some(refresh_token), - state: None - }.into_response() + state: None, + } + .into_response() } } } @@ -740,26 +856,39 @@ async fn introspection_endpoint_post<A: AuthBackend>( // Check authentication first match backend.get_token(&me, 
auth_token.token()).await { - Ok(Some(token)) => if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); - }, - Ok(None) => return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(), + Ok(Some(token)) => { + if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + })), + ) + .into_response(); + } + } + Ok(None) => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + })), + ) + .into_response() + } Err(err) => { tracing::error!("Error retrieving token data for introspection: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } - let response: TokenIntrospectionResponse = match backend.get_token(&me, &token_request.token).await { - Ok(maybe_data) => maybe_data.into(), - Err(err) => { - tracing::error!("Error retrieving token data: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - }; + let response: TokenIntrospectionResponse = + match backend.get_token(&me, &token_request.token).await { + Ok(maybe_data) => maybe_data.into(), + Err(err) => { + tracing::error!("Error retrieving token data: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; response.into_response() } @@ -787,7 +916,7 @@ async fn revocation_endpoint_post<A: AuthBackend>( async fn get_profile<D: Storage + 'static>( db: D, url: &str, - email: bool + email: bool, ) -> crate::database::Result<Option<Profile>> { fn get_first(v: serde_json::Value) -> Option<String> { match v { @@ -796,10 +925,10 @@ async fn get_profile<D: Storage + 'static>( match a.pop() { Some(serde_json::Value::String(s)) => Some(s), Some(serde_json::Value::Object(mut o)) => o.remove("value").and_then(get_first), - _ => None + _ => None, } - }, - _ => None + } + _ => None, } } @@ -807,15 +936,26 @@ async fn get_profile<D: Storage + 'static>( // Ruthlessly manually destructure the MF2 document to save memory let mut properties = match mf2.as_object_mut().unwrap().remove("properties") { Some(serde_json::Value::Object(props)) => props, - _ => unreachable!() + _ => unreachable!(), }; drop(mf2); let name = properties.remove("name").and_then(get_first); - let url = properties.remove("uid").and_then(get_first).and_then(|u| u.parse().ok()); - let photo = properties.remove("photo").and_then(get_first).and_then(|u| u.parse().ok()); + let url = properties + .remove("uid") + .and_then(get_first) + .and_then(|u| u.parse().ok()); + let photo = properties + .remove("photo") + .and_then(get_first) + .and_then(|u| u.parse().ok()); let email = properties.remove("email").and_then(get_first); - Profile { name, url, photo, email } + Profile { + name, + url, + photo, + email, + } })) } @@ -823,7 +963,7 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( Host(host): Host, TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>, State(backend): State<A>, - State(db): State<D> + State(db): State<D>, ) -> Response { use serde_json::json; @@ -832,14 +972,22 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( match backend.get_token(&me, auth_token.token()).await { Ok(Some(token)) => { if
token.expired() { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + })), + ) + .into_response(); } if !token.scope.has(&Scope::Profile) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + })), + ) + .into_response(); } match get_profile(db, me.as_str(), token.scope.has(&Scope::Email)).await { @@ -847,17 +995,19 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( Ok(None) => Json(json!({ // We do this because ResourceErrorKind is IndieAuth errors only "error": "invalid_request" - })).into_response(), + })) + .into_response(), Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); StatusCode::INTERNAL_SERVER_ERROR.into_response() } } - }, + } Ok(None) => Json(json!({ "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - })).into_response(), + })) + .into_response(), Err(err) => { tracing::error!("Error reading token: {}", err); @@ -871,57 +1021,51 @@ where S: Storage + FromRef<St> + 'static, A: AuthBackend + FromRef<St>, reqwest_middleware::ClientWithMiddleware: FromRef<St>, - St: Clone + Send + Sync + 'static + St: Clone + Send + Sync + 'static, { - use axum::routing::{Router, get, post}; + use axum::routing::{get, post, Router}; Router::new() .nest( "/.kittybox/indieauth", Router::new() - .route("/metadata", - get(metadata)) + .route("/metadata", get(metadata)) .route( "/auth", get(authorization_endpoint_get::<A, S>) - .post(authorization_endpoint_post::<A, S>)) - .route( - "/auth/confirm", - post(authorization_endpoint_confirm::<A>)) - .route( - "/token", - post(token_endpoint_post::<A, S>)) - .route( - "/token_status", - post(introspection_endpoint_post::<A>)) - .route( - "/revoke_token", - post(revocation_endpoint_post::<A>)) + .post(authorization_endpoint_post::<A, S>), + ) + .route("/auth/confirm", post(authorization_endpoint_confirm::<A>)) + .route("/token", post(token_endpoint_post::<A, S>)) + .route("/token_status", post(introspection_endpoint_post::<A>)) + .route("/revoke_token", post(revocation_endpoint_post::<A>)) + .route("/userinfo", get(userinfo_endpoint_get::<A, S>)) .route( - "/userinfo", - get(userinfo_endpoint_get::<A, S>)) - - .route("/webauthn/pre_register", - get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, S>, - #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) - ) + "/webauthn/pre_register", + get( + #[cfg(feature = "webauthn")] + webauthn::webauthn_pre_register::<A, S>, + #[cfg(not(feature = "webauthn"))] + || std::future::ready(axum::http::StatusCode::NOT_FOUND), + ), ) - .layer(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST - ]) - .allow_origin(tower_http::cors::Any)) + .layer( + tower_http::cors::CorsLayer::new() + .allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_origin(tower_http::cors::Any), + ), ) .route( "/.well-known/oauth-authorization-server", - get(|| std::future::ready( - (StatusCode::FOUND, - [("Location", - "/.kittybox/indieauth/metadata")] - ).into_response() - )) + get(|| { + std::future::ready( + ( + StatusCode::FOUND, + 
[("Location", "/.kittybox/indieauth/metadata")], + ) + .into_response(), + ) + }), ) } @@ -929,9 +1073,10 @@ where mod tests { #[test] fn test_deserialize_authorization_confirmation() { - use super::{Credential, AuthorizationConfirmation}; + use super::{AuthorizationConfirmation, Credential}; - let confirmation = serde_json::from_str::<AuthorizationConfirmation>(r#"{ + let confirmation = serde_json::from_str::<AuthorizationConfirmation>( + r#"{ "request":{ "response_type": "code", "client_id": "https://quill.p3k.io/", @@ -942,12 +1087,14 @@ mod tests { "scope": "create+media" }, "authorization_method": "swordfish" - }"#).unwrap(); + }"#, + ) + .unwrap(); match confirmation.authorization_method { Credential::Password(password) => assert_eq!(password.as_str(), "swordfish"), #[allow(unreachable_patterns)] - other => panic!("Incorrect credential: {:?}", other) + other => panic!("Incorrect credential: {:?}", other), } assert_eq!(confirmation.request.state.as_ref(), "10101010"); } diff --git a/src/indieauth/webauthn.rs b/src/indieauth/webauthn.rs index 0757e72..80d210c 100644 --- a/src/indieauth/webauthn.rs +++ b/src/indieauth/webauthn.rs @@ -1,10 +1,17 @@ use axum::{ extract::Json, + http::StatusCode, response::{IntoResponse, Response}, - http::StatusCode, Extension + Extension, +}; +use axum_extra::extract::{ + cookie::{Cookie, CookieJar}, + Host, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + TypedHeader, }; -use axum_extra::extract::{Host, cookie::{CookieJar, Cookie}}; -use axum_extra::{TypedHeader, headers::{authorization::Bearer, Authorization}}; use super::backend::AuthBackend; use crate::database::Storage; @@ -12,40 +19,33 @@ use crate::database::Storage; pub(crate) const CHALLENGE_ID_COOKIE: &str = "kittybox_webauthn_challenge_id"; macro_rules! bail { - ($msg:literal, $err:expr) => { - { - ::tracing::error!($msg, $err); - return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } + ($msg:literal, $err:expr) => {{ + ::tracing::error!($msg, $err); + return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response(); + }}; } pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( Host(host): Host, Extension(db): Extension<D>, Extension(auth): Extension<A>, - cookies: CookieJar + cookies: CookieJar, ) -> Response { let uid = format!("https://{}/", host.clone()); let uid_url: url::Url = uid.parse().unwrap(); // This will not find an h-card in onboarding! 
let display_name = match db.get_post(&uid).await { Ok(hcard) => match hcard { - Some(mut hcard) => { - match hcard["properties"]["uid"][0].take() { - serde_json::Value::String(name) => name, - _ => String::default() - } + Some(mut hcard) => match hcard["properties"]["uid"][0].take() { + serde_json::Value::String(name) => name, + _ => String::default(), }, - None => String::default() + None => String::default(), }, - Err(err) => bail!("Error retrieving h-card: {}", err) + Err(err) => bail!("Error retrieving h-card: {}", err), }; - let webauthn = webauthn::WebauthnBuilder::new( - &host, - &uid_url - ) + let webauthn = webauthn::WebauthnBuilder::new(&host, &uid_url) .unwrap() .rp_name("Kittybox") .build() @@ -58,10 +58,10 @@ pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( webauthn::prelude::Uuid::nil(), &uid, &display_name, - Some(vec![]) + Some(vec![]), ) { Ok((challenge, state)) => (challenge, state), - Err(err) => bail!("Error generating WebAuthn registration data: {}", err) + Err(err) => bail!("Error generating WebAuthn registration data: {}", err), }; match auth.persist_registration_challenge(&uid_url, state).await { @@ -69,11 +69,12 @@ pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( cookies.add( Cookie::build((CHALLENGE_ID_COOKIE, challenge_id)) .secure(true) - .finish() + .finish(), ), - Json(challenge) - ).into_response(), - Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err) + Json(challenge), + ) + .into_response(), + Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err), } } @@ -82,39 +83,36 @@ pub async fn webauthn_register<A: AuthBackend>( Json(credential): Json<webauthn::prelude::RegisterPublicKeyCredential>, // TODO determine if we can use a cookie maybe? user_credential: Option<TypedHeader<Authorization<Bearer>>>, - Extension(auth): Extension<A> + Extension(auth): Extension<A>, ) -> Response { let uid = format!("https://{}/", host.clone()); let uid_url: url::Url = uid.parse().unwrap(); let pubkeys = match auth.list_webauthn_pubkeys(&uid_url).await { Ok(pubkeys) => pubkeys, - Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err) + Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err), }; if !pubkeys.is_empty() { if let Some(TypedHeader(Authorization(token))) = user_credential { // TODO check validity of the credential } else { - return StatusCode::UNAUTHORIZED.into_response() + return StatusCode::UNAUTHORIZED.into_response(); } } - return StatusCode::OK.into_response() + return StatusCode::OK.into_response(); } pub(crate) async fn verify<A: AuthBackend>( auth: &A, website: &url::Url, credential: webauthn::prelude::PublicKeyCredential, - challenge_id: &str + challenge_id: &str, ) -> std::io::Result<bool> { let host = website.host_str().unwrap(); - let webauthn = webauthn::WebauthnBuilder::new( - host, - website - ) + let webauthn = webauthn::WebauthnBuilder::new(host, website) .unwrap() .rp_name("Kittybox") .build() @@ -122,12 +120,14 @@ pub(crate) async fn verify<A: AuthBackend>( match webauthn.finish_passkey_authentication( &credential, - &auth.retrieve_authentication_challenge(&website, challenge_id).await? 
+ &auth + .retrieve_authentication_challenge(&website, challenge_id) + .await?, ) { Err(err) => { tracing::error!("WebAuthn error: {}", err); Ok(false) - }, + } Ok(authentication_result) => { let counter = authentication_result.counter(); let cred_id = authentication_result.cred_id(); diff --git a/src/lib.rs b/src/lib.rs index 4aeaca5..a52db4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,22 +4,28 @@ use std::sync::Arc; use axum::extract::{FromRef, FromRequestParts, OptionalFromRequestParts}; -use axum_extra::extract::{cookie::{Cookie, Key}, SignedCookieJar}; +use axum_extra::extract::{ + cookie::{Cookie, Key}, + SignedCookieJar, +}; use database::{FileStorage, PostgresStorage, Storage}; use indieauth::backend::{AuthBackend, FileBackend as FileAuthBackend}; use kittybox_util::queue::JobQueue; -use media::storage::{MediaStore, file::FileStore as FileMediaStore}; -use tokio::{sync::{Mutex, RwLock}, task::JoinSet}; +use media::storage::{file::FileStore as FileMediaStore, MediaStore}; +use tokio::{ + sync::{Mutex, RwLock}, + task::JoinSet, +}; use webmentions::queue::PostgresJobQueue; /// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. pub mod database; pub mod frontend; +pub mod indieauth; +pub mod login; pub mod media; pub mod micropub; -pub mod indieauth; pub mod webmentions; -pub mod login; //pub mod admin; const OAUTH2_SOFTWARE_ID: &str = "6f2eee84-c22c-4c9e-b900-10d4e97273c8"; @@ -27,10 +33,10 @@ const OAUTH2_SOFTWARE_ID: &str = "6f2eee84-c22c-4c9e-b900-10d4e97273c8"; #[derive(Clone)] pub struct AppState<A, S, M, Q> where -A: AuthBackend + Sized + 'static, -S: Storage + Sized + 'static, -M: MediaStore + Sized + 'static, -Q: JobQueue<webmentions::Webmention> + Sized + A: AuthBackend + Sized + 'static, + S: Storage + Sized + 'static, + M: MediaStore + Sized + 'static, + Q: JobQueue<webmentions::Webmention> + Sized, { pub auth_backend: A, pub storage: S, @@ -39,7 +45,7 @@ Q: JobQueue<webmentions::Webmention> + Sized pub http: reqwest_middleware::ClientWithMiddleware, pub background_jobs: Arc<Mutex<JoinSet<()>>>, pub cookie_key: Key, - pub session_store: SessionStore + pub session_store: SessionStore, } pub type SessionStore = Arc<RwLock<std::collections::HashMap<uuid::Uuid, Session>>>; @@ -60,7 +66,11 @@ pub struct NoSessionError; impl axum::response::IntoResponse for NoSessionError { fn into_response(self) -> axum::response::Response { // TODO: prettier error message - (axum::http::StatusCode::UNAUTHORIZED, "You are not logged in, but this page requires a session.").into_response() + ( + axum::http::StatusCode::UNAUTHORIZED, + "You are not logged in, but this page requires a session.", + ) + .into_response() } } @@ -72,11 +82,17 @@ where { type Rejection = std::convert::Infallible; - async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Option<Self>, Self::Rejection> { - let jar = SignedCookieJar::<Key>::from_request_parts(parts, state).await.unwrap(); + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + let jar = SignedCookieJar::<Key>::from_request_parts(parts, state) + .await + .unwrap(); let session_store = SessionStore::from_ref(state).read_owned().await; - Ok(jar.get("session_id") + Ok(jar + .get("session_id") .as_ref() .map(Cookie::value_trimmed) .and_then(|id| uuid::Uuid::parse_str(id).ok()) @@ -103,7 +119,10 @@ where // have to repeat this magic invocation. 
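// Each impl below simply clones the corresponding field out of AppState.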
impl<S, M, Q> FromRef<AppState<Self, S, M, Q>> for FileAuthBackend -where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<Self, S, M, Q>) -> Self { input.auth_backend.clone() @@ -111,7 +130,10 @@ where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> } impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for PostgresStorage -where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { input.storage.clone() @@ -119,7 +141,10 @@ where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> } impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for FileStorage -where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { input.storage.clone() @@ -128,7 +153,10 @@ where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> impl<A, S, Q> FromRef<AppState<A, S, Self, Q>> for FileMediaStore // where A: AuthBackend, S: Storage -where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, Self, Q>) -> Self { input.media_store.clone() @@ -136,7 +164,11 @@ where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Key -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.cookie_key.clone() @@ -144,7 +176,11 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for reqwest_middleware::ClientWithMiddleware -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.http.clone() @@ -152,7 +188,11 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Arc<Mutex<JoinSet<()>>> -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.background_jobs.clone() @@ -161,7 +201,10 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen #[cfg(feature = "sqlx")] impl<A, S, M> FromRef<AppState<A, S, M, Self>> for PostgresJobQueue<webmentions::Webmention> -where A: AuthBackend, S: Storage, M: MediaStore +where + A: AuthBackend, + S: Storage, + M: MediaStore, { fn from_ref(input: &AppState<A, S, M, Self>) -> Self { input.job_queue.clone() @@ -169,7 +212,11 @@ where A: AuthBackend, S: Storage, M: MediaStore } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for SessionStore -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: 
JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.session_store.clone() @@ -177,23 +224,26 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } pub mod companion { - use std::{collections::HashMap, sync::Arc}; use axum::{ extract::{Extension, Path}, - response::{IntoResponse, Response} + response::{IntoResponse, Response}, }; + use std::{collections::HashMap, sync::Arc}; #[derive(Debug, Clone, Copy)] struct Resource { data: &'static [u8], - mime: &'static str + mime: &'static str, } impl IntoResponse for &Resource { fn into_response(self) -> Response { - (axum::http::StatusCode::OK, - [("Content-Type", self.mime)], - self.data).into_response() + ( + axum::http::StatusCode::OK, + [("Content-Type", self.mime)], + self.data, + ) + .into_response() } } @@ -203,17 +253,21 @@ pub mod companion { #[tracing::instrument] async fn map_to_static( Path(name): Path<String>, - Extension(resources): Extension<ResourceTable> + Extension(resources): Extension<ResourceTable>, ) -> Response { tracing::debug!("Searching for {} in the resource table...", name); match resources.get(name.as_str()) { Some(res) => res.into_response(), None => { - #[cfg(debug_assertions)] tracing::error!("Not found"); + #[cfg(debug_assertions)] + tracing::error!("Not found"); - (axum::http::StatusCode::NOT_FOUND, - [("Content-Type", "text/plain")], - "Not found. Sorry.".as_bytes()).into_response() + ( + axum::http::StatusCode::NOT_FOUND, + [("Content-Type", "text/plain")], + "Not found. Sorry.".as_bytes(), + ) + .into_response() } } } @@ -249,47 +303,52 @@ pub mod companion { Arc::new(map) }; - axum::Router::new() - .route( - "/{filename}", - axum::routing::get(map_to_static) - .layer(Extension(resources)) - ) + axum::Router::new().route( + "/{filename}", + axum::routing::get(map_to_static).layer(Extension(resources)), + ) } } async fn teapot_route() -> impl axum::response::IntoResponse { use axum::http::{header, StatusCode}; - (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") + ( + StatusCode::IM_A_TEAPOT, + [(header::CONTENT_TYPE, "text/plain")], + "Sorry, can't brew coffee yet!", + ) } async fn health_check<D>( axum::extract::State(data): axum::extract::State<D>, ) -> impl axum::response::IntoResponse where - D: crate::database::Storage + D: crate::database::Storage, { (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) } pub async fn compose_kittybox<St, A, S, M, Q>() -> axum::Router<St> where -A: AuthBackend + 'static + FromRef<St>, -S: Storage + 'static + FromRef<St>, -M: MediaStore + 'static + FromRef<St>, -Q: kittybox_util::queue::JobQueue<crate::webmentions::Webmention> + FromRef<St>, -reqwest_middleware::ClientWithMiddleware: FromRef<St>, -Arc<Mutex<JoinSet<()>>>: FromRef<St>, -crate::SessionStore: FromRef<St>, -axum_extra::extract::cookie::Key: FromRef<St>, -St: Clone + Send + Sync + 'static + A: AuthBackend + 'static + FromRef<St>, + S: Storage + 'static + FromRef<St>, + M: MediaStore + 'static + FromRef<St>, + Q: kittybox_util::queue::JobQueue<crate::webmentions::Webmention> + FromRef<St>, + reqwest_middleware::ClientWithMiddleware: FromRef<St>, + Arc<Mutex<JoinSet<()>>>: FromRef<St>, + crate::SessionStore: FromRef<St>, + axum_extra::extract::cookie::Key: FromRef<St>, + St: Clone + Send + Sync + 'static, { use axum::routing::get; axum::Router::new() .route("/", get(crate::frontend::homepage::<S>)) .fallback(get(crate::frontend::catchall::<S>)) .route("/.kittybox/micropub", 
crate::micropub::router::<A, S, St>()) - .route("/.kittybox/onboarding", crate::frontend::onboarding::router::<St, S>()) + .route( + "/.kittybox/onboarding", + crate::frontend::onboarding::router::<St, S>(), + ) .nest("/.kittybox/media", crate::media::router::<St, A, M>()) .merge(crate::indieauth::router::<St, A, S>()) .merge(crate::webmentions::router::<St, Q>()) @@ -297,34 +356,36 @@ St: Clone + Send + Sync + 'static .nest("/.kittybox/login", crate::login::router::<St, S>()) .route( "/.kittybox/static/{*path}", - axum::routing::get(crate::frontend::statics) + axum::routing::get(crate::frontend::statics), ) .route("/.kittybox/coffee", get(teapot_route)) - .nest("/.kittybox/micropub/client", crate::companion::router::<St>()) + .nest( + "/.kittybox/micropub/client", + crate::companion::router::<St>(), + ) .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::catch_panic::CatchPanicLayer::new()) - .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ - axum::http::header::AUTHORIZATION, - axum::http::header::COOKIE, - axum::http::header::SET_COOKIE, - ])) + .layer( + tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ + axum::http::header::AUTHORIZATION, + axum::http::header::COOKIE, + axum::http::header::SET_COOKIE, + ]), + ) .layer(tower_http::set_header::SetResponseHeaderLayer::appending( axum::http::header::CONTENT_SECURITY_POLICY, - axum::http::HeaderValue::from_static( - concat!( - "default-src 'none';", // Do not allow unknown things we didn't foresee. - "img-src https:;", // Allow hotlinking images from anywhere. - "form-action 'self';", // Only allow sending forms back to us. - "media-src 'self';", // Only allow embedding media from us. - "script-src 'self';", // Only run scripts we serve. - "style-src 'self';", // Only use styles we serve. - "base-uri 'none';", // Do not allow to change the base URI. - "object-src 'none';", // Do not allow to embed objects (Flash/ActiveX). - - // Allow embedding the Bandcamp player for jam posts. - // TODO: perhaps make this policy customizable?… - "frame-src 'self' https://bandcamp.com/EmbeddedPlayer/;" - ) - ) + axum::http::HeaderValue::from_static(concat!( + "default-src 'none';", // Do not allow unknown things we didn't foresee. + "img-src https:;", // Allow hotlinking images from anywhere. + "form-action 'self';", // Only allow sending forms back to us. + "media-src 'self';", // Only allow embedding media from us. + "script-src 'self';", // Only run scripts we serve. + "style-src 'self';", // Only use styles we serve. + "base-uri 'none';", // Do not allow to change the base URI. + "object-src 'none';", // Do not allow to embed objects (Flash/ActiveX). + // Allow embedding the Bandcamp player for jam posts. 
+ // TODO: perhaps make this policy customizable?… + "frame-src 'self' https://bandcamp.com/EmbeddedPlayer/;" + )), )) } diff --git a/src/login.rs b/src/login.rs index eaa787c..3038d9c 100644 --- a/src/login.rs +++ b/src/login.rs @@ -1,10 +1,25 @@ use std::{borrow::Cow, str::FromStr}; +use axum::{ + extract::{FromRef, Query, State}, + http::HeaderValue, + response::IntoResponse, + Form, +}; +use axum_extra::{ + extract::{ + cookie::{self, Cookie}, + Host, SignedCookieJar, + }, + headers::HeaderMapExt, + TypedHeader, +}; use futures_util::FutureExt; -use axum::{extract::{FromRef, Query, State}, http::HeaderValue, response::IntoResponse, Form}; -use axum_extra::{extract::{Host, cookie::{self, Cookie}, SignedCookieJar}, headers::HeaderMapExt, TypedHeader}; -use hyper::{header::{CACHE_CONTROL, LOCATION}, StatusCode}; -use kittybox_frontend_renderer::{Template, LoginPage, LogoutPage}; +use hyper::{ + header::{CACHE_CONTROL, LOCATION}, + StatusCode, +}; +use kittybox_frontend_renderer::{LoginPage, LogoutPage, Template}; use kittybox_indieauth::{AuthorizationResponse, Error, GrantType, PKCEVerifier, Scope, Scopes}; use sha2::Digest; @@ -13,14 +28,13 @@ use crate::database::Storage; /// Show a login page. async fn get<S: Storage + Send + Sync + 'static>( State(db): State<S>, - Host(host): Host + Host(host): Host, ) -> impl axum::response::IntoResponse { let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( @@ -34,14 +48,15 @@ async fn get<S: Storage + Send + Sync + 'static>( blog_name: blogname.as_ref(), feeds: channels, user: None, - content: LoginPage {}.to_string() - }.to_string() + content: LoginPage {}.to_string(), + } + .to_string(), ) } #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] struct LoginForm { - url: url::Url + url: url::Url, } /// Accept login and start the IndieAuth dance. 
@@ -60,10 +75,12 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); - let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host).parse().unwrap(); + let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host) + .parse() + .unwrap(); let redirect_uri = { let mut uri = client_id.clone(); uri.set_path("/.kittybox/login/finish"); @@ -71,11 +88,15 @@ async fn post( }; let indieauth_state = kittybox_indieauth::AuthorizationRequest { response_type: kittybox_indieauth::ResponseType::Code, - client_id, redirect_uri, + client_id, + redirect_uri, state: kittybox_indieauth::State::new(), - code_challenge: kittybox_indieauth::PKCEChallenge::new(&code_verifier, kittybox_indieauth::PKCEMethod::S256), + code_challenge: kittybox_indieauth::PKCEChallenge::new( + &code_verifier, + kittybox_indieauth::PKCEMethod::S256, + ), scope: Some(Scopes::new(vec![Scope::Profile])), - me: Some(form.url.clone()) + me: Some(form.url.clone()), }; // Fetch the user's homepage, determine their authorization endpoint @@ -89,8 +110,9 @@ async fn post( tracing::error!("Error fetching homepage: {:?}", err); return ( StatusCode::BAD_REQUEST, - format!("couldn't fetch your homepage: {}", err) - ).into_response() + format!("couldn't fetch your homepage: {}", err), + ) + .into_response(); } }; @@ -106,22 +128,27 @@ async fn post( // .collect::<Vec<axum_extra::headers::Link>>(); // // todo!("parse Link: headers") - + let body = match response.text().await { Ok(body) => match microformats::from_html(&body, form.url) { Ok(mf2) => mf2, - Err(err) => return ( - StatusCode::BAD_REQUEST, - format!("error while parsing your homepage with mf2: {}", err) - ).into_response() + Err(err) => { + return ( + StatusCode::BAD_REQUEST, + format!("error while parsing your homepage with mf2: {}", err), + ) + .into_response() + } }, - Err(err) => return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("error while fetching your homepage: {}", err) - ).into_response() + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("error while fetching your homepage: {}", err), + ) + .into_response() + } }; - let mut iss: Option<url::Url> = None; let mut authorization_endpoint = match body .rels @@ -139,10 +166,22 @@ async fn post( Ok(metadata) => { iss = Some(metadata.issuer); metadata.authorization_endpoint - }, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't parse your oauth2 metadata: {}", err)).into_response() + } + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't parse your oauth2 metadata: {}", err), + ) + .into_response() + } }, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't fetch your oauth2 metadata: {}", err)).into_response() + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't fetch your oauth2 metadata: {}", err), + ) + .into_response() + } }, None => match body .rels @@ -151,13 +190,17 @@ async fn post( .map(|v| v.as_slice()) .unwrap_or_default() .first() - .cloned() { - Some(authorization_endpoint) => authorization_endpoint, - None => return ( + .cloned() + { + Some(authorization_endpoint) => authorization_endpoint, + None => { + return ( StatusCode::BAD_REQUEST, - "no authorization endpoint was found on your homepage." 
- ).into_response() + "no authorization endpoint was found on your homepage.", + ) + .into_response() } + }, }; cookies = cookies.add( @@ -166,7 +209,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); if let Some(iss) = iss { @@ -176,7 +219,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); } @@ -186,7 +229,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); authorization_endpoint @@ -194,9 +237,12 @@ async fn post( .extend_pairs(indieauth_state.as_query_pairs().iter()); tracing::debug!("Forwarding user to {}", authorization_endpoint); - (StatusCode::FOUND, [ - ("Location", authorization_endpoint.to_string()), - ], cookies).into_response() + ( + StatusCode::FOUND, + [("Location", authorization_endpoint.to_string())], + cookies, + ) + .into_response() } /// Accept the return of the IndieAuth dance. Set a cookie for the @@ -208,7 +254,9 @@ async fn callback( State(http): State<reqwest_middleware::ClientWithMiddleware>, State(session_store): State<crate::SessionStore>, ) -> axum::response::Response { - let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host).parse().unwrap(); + let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host) + .parse() + .unwrap(); let redirect_uri = { let mut uri = client_id.clone(); uri.set_path("/.kittybox/login/finish"); @@ -218,7 +266,8 @@ async fn callback( let me: url::Url = cookie_jar.get("me").unwrap().value().parse().unwrap(); let code_verifier: PKCEVerifier = cookie_jar.get("code_verifier").unwrap().value().into(); - let authorization_endpoint: url::Url = cookie_jar.get("authorization_endpoint") + let authorization_endpoint: url::Url = cookie_jar + .get("authorization_endpoint") .and_then(|v| v.value().parse().ok()) .unwrap(); match cookie_jar.get("iss").and_then(|c| c.value().parse().ok()) { @@ -232,24 +281,59 @@ async fn callback( code: response.code, client_id, redirect_uri, - code_verifier, + code_verifier, }; - tracing::debug!("POSTing {:?} to authorization endpoint {}", grant_request, authorization_endpoint); - let res = match http.post(authorization_endpoint) + tracing::debug!( + "POSTing {:?} to authorization endpoint {}", + grant_request, + authorization_endpoint + ); + let res = match http + .post(authorization_endpoint) .form(&grant_request) .header(reqwest::header::ACCEPT, "application/json") .send() .await { - Ok(res) if res.status().is_success() => match res.json::<kittybox_indieauth::GrantResponse>().await { - Ok(grant) => grant, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error parsing authorization endpoint response: {}", err)).into_response() - }, + Ok(res) if res.status().is_success() => { + match res.json::<kittybox_indieauth::GrantResponse>().await { + Ok(grant) => grant, + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error parsing authorization endpoint response: {}", err), + ) + .into_response() + } + } + } Ok(res) => match res.json::<Error>().await { - Ok(err) => return (StatusCode::BAD_REQUEST, [(CACHE_CONTROL, "no-store")], err.to_string()).into_response(), - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error parsing indieauth error: {}", err)).into_response() + Ok(err) => { + return ( + StatusCode::BAD_REQUEST, + [(CACHE_CONTROL, "no-store")], + err.to_string(), + ) + .into_response() + } + 
Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error parsing indieauth error: {}", err), + ) + .into_response() + } + }, + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error redeeming authorization code: {}", err), + ) + .into_response() } - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error redeeming authorization code: {}", err)).into_response() }; let profile = match res { @@ -265,19 +349,28 @@ async fn callback( let uuid = uuid::Uuid::new_v4(); session_store.write().await.insert(uuid, session); let cookies = cookie_jar - .add(Cookie::build(("session_id", uuid.to_string())) - .expires(None) - .secure(true) - .http_only(true) - .path("/") - .build() + .add( + Cookie::build(("session_id", uuid.to_string())) + .expires(None) + .secure(true) + .http_only(true) + .path("/") + .build(), ) .remove("authorization_endpoint") .remove("me") .remove("iss") .remove("code_verifier"); - (StatusCode::FOUND, [(LOCATION, HeaderValue::from_static("/")), (CACHE_CONTROL, HeaderValue::from_static("no-store"))], dbg!(cookies)).into_response() + ( + StatusCode::FOUND, + [ + (LOCATION, HeaderValue::from_static("/")), + (CACHE_CONTROL, HeaderValue::from_static("no-store")), + ], + dbg!(cookies), + ) + .into_response() } /// Show the form necessary for logout. If JS is enabled, @@ -288,32 +381,42 @@ async fn callback( /// stupid enough to execute JS and send a POST request though, that's /// on the crawler. async fn logout_page() -> impl axum::response::IntoResponse { - (StatusCode::OK, [("Content-Type", "text/html")], Template { - title: "Signing out...", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: LogoutPage {}.to_string() - }.to_string()) + ( + StatusCode::OK, + [("Content-Type", "text/html")], + Template { + title: "Signing out...", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: LogoutPage {}.to_string(), + } + .to_string(), + ) } /// Erase the necessary cookies for login and invalidate the session. 
async fn logout(
 mut cookies: SignedCookieJar,
- State(session_store): State<crate::SessionStore>
-) -> (StatusCode, [(&'static str, &'static str); 1], SignedCookieJar) {
- if let Some(id) = cookies.get("session_id")
+ State(session_store): State<crate::SessionStore>,
+) -> (
+ StatusCode,
+ [(&'static str, &'static str); 1],
+ SignedCookieJar,
+) {
+ if let Some(id) = cookies
+ .get("session_id")
 .and_then(|c| uuid::Uuid::parse_str(c.value_trimmed()).ok())
 {
 session_store.write().await.remove(&id);
 }
- cookies = cookies.remove("me")
+ cookies = cookies
+ .remove("me")
 .remove("iss")
 .remove("authorization_endpoint")
 .remove("code_verifier")
 .remove("session_id");
-
 (StatusCode::FOUND, [("Location", "/")], cookies)
}

@@ -343,7 +446,7 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>(
 };
 if let Some(cached) = cached {
 if !cached.precondition_passes(&etag) {
- return StatusCode::NOT_MODIFIED.into_response()
+ return StatusCode::NOT_MODIFIED.into_response();
 }
 }
 let client_uri: url::Url = format!("https://{}/", host).parse().unwrap();
@@ -356,7 +459,13 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>(

 let mut metadata = kittybox_indieauth::ClientMetadata::new(client_id, client_uri).unwrap();

- metadata.client_name = Some(storage.get_setting::<crate::database::settings::SiteName>(&metadata.client_uri).await.unwrap_or_default().0);
+ metadata.client_name = Some(
+ storage
+ .get_setting::<crate::database::settings::SiteName>(&metadata.client_uri)
+ .await
+ .unwrap_or_default()
+ .0,
+ );
 metadata.grant_types = Some(vec![GrantType::AuthorizationCode]);
 // We don't request anything more than the profile scope.
 metadata.scope = Some(Scopes::new(vec![Scope::Profile]));
@@ -368,15 +477,18 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>(
 // identity providers, or json to match newest spec
 let mut response = metadata.into_response();
 // Indicate to upstream caches this endpoint does different things depending on the Accept: header.
- response.headers_mut().append("Vary", HeaderValue::from_static("Accept"));
+ response
+ .headers_mut()
+ .append("Vary", HeaderValue::from_static("Accept"));
 // Cache this metadata for ten minutes.
- response.headers_mut().append("Cache-Control", HeaderValue::from_static("max-age=600"));
+ response
+ .headers_mut()
+ .append("Cache-Control", HeaderValue::from_static("max-age=600"));
 response.headers_mut().typed_insert(etag);

 response
}

-
/// Produce a router for all of the above.
pub fn router<St, S>() -> axum::routing::Router<St> where diff --git a/src/main.rs b/src/main.rs index bd3684e..984745a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,11 @@ -use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention, compose_kittybox}; -use tokio::{sync::Mutex, task::JoinSet}; +use kittybox::{ + compose_kittybox, database::Storage, indieauth::backend::AuthBackend, + media::storage::MediaStore, webmentions::Webmention, +}; use std::{env, future::IntoFuture, sync::Arc}; +use tokio::{sync::Mutex, task::JoinSet}; use tracing::error; - #[tokio::main] async fn main() { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; @@ -17,32 +19,28 @@ async fn main() { .with_indent_lines(true) .with_verbose_exit(true), #[cfg(not(debug_assertions))] - tracing_subscriber::fmt::layer().json() - .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) + tracing_subscriber::fmt::layer() + .json() + .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())), ); // In debug builds, also log to JSON, but to file. #[cfg(debug_assertions)] - let tracing_registry = tracing_registry.with( - tracing_subscriber::fmt::layer() - .json() - .with_writer({ - let instant = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap(); - move || std::fs::OpenOptions::new() + let tracing_registry = + tracing_registry.with(tracing_subscriber::fmt::layer().json().with_writer({ + let instant = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + move || { + std::fs::OpenOptions::new() .append(true) .create(true) - .open( - format!( - "{}.log.json", - instant - .as_secs_f64() - .to_string() - .replace('.', "_") - ) - ).unwrap() - }) - ); + .open(format!( + "{}.log.json", + instant.as_secs_f64().to_string().replace('.', "_") + )) + .unwrap() + } + })); tracing_registry.init(); tracing::info!("Starting the kittybox server..."); @@ -79,12 +77,15 @@ async fn main() { }); // TODO: load from environment - let cookie_key = axum_extra::extract::cookie::Key::from(&env::var("COOKIE_KEY") - .as_deref() - .map(|s| data_encoding::BASE64_MIME_PERMISSIVE.decode(s.as_bytes()) - .expect("Invalid cookie key: must be base64 encoded") - ) - .unwrap() + let cookie_key = axum_extra::extract::cookie::Key::from( + &env::var("COOKIE_KEY") + .as_deref() + .map(|s| { + data_encoding::BASE64_MIME_PERMISSIVE + .decode(s.as_bytes()) + .expect("Invalid cookie key: must be base64 encoded") + }) + .unwrap(), ); let cancellation_token = tokio_util::sync::CancellationToken::new(); @@ -93,12 +94,11 @@ async fn main() { let http: reqwest_middleware::ClientWithMiddleware = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { // TODO: add a root certificate if there's an environment variable pointing at it for path in certs.split(':') { @@ -108,21 +108,19 @@ async fn main() { tracing::error!("TLS root certificate {} not found, skipping...", path); continue; } - Err(err) => panic!("Error loading TLS certificates: {}", err) + Err(err) => panic!("Error loading TLS certificates: {}", err), }; if metadata.is_dir() { let mut dir = tokio::fs::read_dir(path).await.unwrap(); 
while let Ok(Some(file)) = dir.next_entry().await { let pem = tokio::fs::read(file.path()).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); + builder = builder + .add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap()); } } else { let pem = tokio::fs::read(path).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); + builder = + builder.add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap()); } } } @@ -151,7 +149,7 @@ async fn main() { let job_queue_type = job_queue_uri.scheme(); macro_rules! compose_kittybox { - ($auth:ty, $store:ty, $media:ty, $queue:ty) => { { + ($auth:ty, $store:ty, $media:ty, $queue:ty) => {{ type AuthBackend = $auth; type Storage = $store; type MediaStore = $media; @@ -193,36 +191,43 @@ async fn main() { }; type St = kittybox::AppState<AuthBackend, Storage, MediaStore, JobQueue>; - let stateful_router = compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; - let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>(&state, cancellation_token.clone()); + let stateful_router = + compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; + let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>( + &state, + cancellation_token.clone(), + ); let router = stateful_router.with_state(state); (router, task) - } } + }}; } - let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) { - ("file", "file", "file", "postgres") => { - compose_kittybox!( - kittybox::indieauth::backend::fs::FileBackend, - kittybox::database::FileStorage, - kittybox::media::storage::file::FileStore, - kittybox::webmentions::queue::PostgresJobQueue<Webmention> - ) - }, - ("file", "postgres", "file", "postgres") => { - compose_kittybox!( - kittybox::indieauth::backend::fs::FileBackend, - kittybox::database::PostgresStorage, - kittybox::media::storage::file::FileStore, - kittybox::webmentions::queue::PostgresJobQueue<Webmention> - ) - }, - (_, _, _, _) => { - // TODO: refine this error. - panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI"); - } - }; + let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = + match (authstore_type, backend_type, blobstore_type, job_queue_type) { + ("file", "file", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::FileStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + } + ("file", "postgres", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::PostgresStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + } + (_, _, _, _) => { + // TODO: refine this error. 
+ panic!( + "Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI" + ); + } + }; let mut servers: Vec<axum::serve::Serve<_, _, _>> = vec![]; @@ -238,7 +243,7 @@ async fn main() { // .serve(router.clone().into_make_service()) axum::serve( tokio::net::TcpListener::from_std(tcp).unwrap(), - router.clone() + router.clone(), ) }; @@ -246,8 +251,8 @@ async fn main() { for i in 0..(listenfd.len()) { match listenfd.take_tcp_listener(i) { Ok(Some(tcp)) => servers.push(build_hyper(tcp)), - Ok(None) => {}, - Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) + Ok(None) => {} + Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err), } } // TODO this requires the `hyperlocal` crate @@ -302,24 +307,35 @@ async fn main() { // to get rid of an extra reference to `jobset` drop(router); // Polling streams mutates them - let mut servers_futures = Box::pin(servers.into_iter() - .map( - #[cfg(not(tokio_unstable))] |server| tokio::task::spawn( - server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) - .into_future() - ), - #[cfg(tokio_unstable)] |server| { - tokio::task::Builder::new() - .name(format!("Kittybox HTTP acceptor: {:?}", server).as_str()) - .spawn( - server.with_graceful_shutdown( - cancellation_token.clone().cancelled_owned() - ).into_future() + let mut servers_futures = Box::pin( + servers + .into_iter() + .map( + #[cfg(not(tokio_unstable))] + |server| { + tokio::task::spawn( + server + .with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) + .into_future(), ) - .unwrap() - } - ) - .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), std::io::Error>>>>() + }, + #[cfg(tokio_unstable)] + |server| { + tokio::task::Builder::new() + .name(format!("Kittybox HTTP acceptor: {:?}", server).as_str()) + .spawn( + server + .with_graceful_shutdown( + cancellation_token.clone().cancelled_owned(), + ) + .into_future(), + ) + .unwrap() + }, + ) + .collect::<futures_util::stream::FuturesUnordered< + tokio::task::JoinHandle<Result<(), std::io::Error>>, + >>(), ); #[cfg(not(unix))] @@ -329,10 +345,10 @@ async fn main() { use tokio::signal::unix::{signal, SignalKind}; async move { - let mut interrupt = signal(SignalKind::interrupt()) - .expect("Failed to set up SIGINT handler"); - let mut terminate = signal(SignalKind::terminate()) - .expect("Failed to setup SIGTERM handler"); + let mut interrupt = + signal(SignalKind::interrupt()).expect("Failed to set up SIGINT handler"); + let mut terminate = + signal(SignalKind::terminate()).expect("Failed to setup SIGTERM handler"); tokio::select! 
{ _ = terminate.recv() => {}, diff --git a/src/media/mod.rs b/src/media/mod.rs index 6f263b6..7e52414 100644 --- a/src/media/mod.rs +++ b/src/media/mod.rs @@ -1,22 +1,23 @@ +use crate::indieauth::{backend::AuthBackend, User}; use axum::{ - extract::{multipart::Multipart, FromRef, Path, State}, response::{IntoResponse, Response} + extract::{multipart::Multipart, FromRef, Path, State}, + response::{IntoResponse, Response}, }; -use axum_extra::headers::{ContentLength, HeaderMapExt, HeaderValue, IfNoneMatch}; use axum_extra::extract::Host; +use axum_extra::headers::{ContentLength, HeaderMapExt, HeaderValue, IfNoneMatch}; use axum_extra::TypedHeader; -use kittybox_util::micropub::{Error as MicropubError, ErrorKind as ErrorType}; use kittybox_indieauth::Scope; -use crate::indieauth::{backend::AuthBackend, User}; +use kittybox_util::micropub::{Error as MicropubError, ErrorKind as ErrorType}; pub mod storage; -use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; pub use storage::file::FileStore; +use storage::{ErrorKind, MediaStore, MediaStoreError, Metadata}; impl From<MediaStoreError> for MicropubError { fn from(err: MediaStoreError) -> Self { Self::new( ErrorType::InternalServerError, - format!("media store error: {}", err) + format!("media store error: {}", err), ) } } @@ -25,13 +26,14 @@ impl From<MediaStoreError> for MicropubError { pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( State(blobstore): State<S>, user: User<A>, - mut upload: Multipart + mut upload: Multipart, ) -> Response { if !user.check_scope(&Scope::Media) { return MicropubError::from_static( ErrorType::NotAuthorized, - "Interacting with the media storage requires the \"media\" scope." - ).into_response(); + "Interacting with the media storage requires the \"media\" scope.", + ) + .into_response(); } let host = user.me.authority(); let field = match upload.next_field().await { @@ -39,27 +41,31 @@ pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( Ok(None) => { return MicropubError::from_static( ErrorType::InvalidRequest, - "Send multipart/form-data with one field named file" - ).into_response(); - }, + "Send multipart/form-data with one field named file", + ) + .into_response(); + } Err(err) => { return MicropubError::new( ErrorType::InternalServerError, - format!("Error while parsing multipart/form-data: {}", err) - ).into_response(); - }, + format!("Error while parsing multipart/form-data: {}", err), + ) + .into_response(); + } }; let metadata: Metadata = (&field).into(); match blobstore.write_streaming(host, metadata, field).await { Ok(filename) => IntoResponse::into_response(( axum::http::StatusCode::CREATED, - [ - ("Location", user.me.join( - &format!(".kittybox/media/uploads/{}", filename) - ).unwrap().as_str()) - ] + [( + "Location", + user.me + .join(&format!(".kittybox/media/uploads/{}", filename)) + .unwrap() + .as_str(), + )], )), - Err(err) => MicropubError::from(err).into_response() + Err(err) => MicropubError::from(err).into_response(), } } @@ -68,7 +74,7 @@ pub(crate) async fn serve<S: MediaStore>( Host(host): Host, Path(path): Path<String>, if_none_match: Option<TypedHeader<IfNoneMatch>>, - State(blobstore): State<S> + State(blobstore): State<S>, ) -> Response { use axum::http::StatusCode; tracing::debug!("Searching for file..."); @@ -77,7 +83,9 @@ pub(crate) async fn serve<S: MediaStore>( tracing::debug!("Metadata: {:?}", metadata); let etag = if let Some(etag) = metadata.etag { - let etag = format!("\"{}\"", etag).parse::<axum_extra::headers::ETag>().unwrap(); + let etag = 
format!("\"{}\"", etag) + .parse::<axum_extra::headers::ETag>() + .unwrap(); if let Some(TypedHeader(if_none_match)) = if_none_match { tracing::debug!("If-None-Match: {:?}", if_none_match); @@ -85,12 +93,14 @@ pub(crate) async fn serve<S: MediaStore>( // returns 304 when it doesn't match because it // only matches when file is different if !if_none_match.precondition_passes(&etag) { - return StatusCode::NOT_MODIFIED.into_response() + return StatusCode::NOT_MODIFIED.into_response(); } } Some(etag) - } else { None }; + } else { + None + }; let mut r = Response::builder(); { @@ -98,14 +108,16 @@ pub(crate) async fn serve<S: MediaStore>( headers.insert( "Content-Type", HeaderValue::from_str( - metadata.content_type + metadata + .content_type .as_deref() - .unwrap_or("application/octet-stream") - ).unwrap() + .unwrap_or("application/octet-stream"), + ) + .unwrap(), ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); if let Some(length) = metadata.length { headers.typed_insert(ContentLength(length.get().try_into().unwrap())); @@ -117,23 +129,22 @@ pub(crate) async fn serve<S: MediaStore>( r.body(axum::body::Body::from_stream(stream)) .unwrap() .into_response() - }, + } Err(err) => match err.kind() { - ErrorKind::NotFound => { - IntoResponse::into_response(StatusCode::NOT_FOUND) - }, + ErrorKind::NotFound => IntoResponse::into_response(StatusCode::NOT_FOUND), _ => { tracing::error!("{}", err); IntoResponse::into_response(StatusCode::INTERNAL_SERVER_ERROR) } - } + }, } } -pub fn router<St, A, M>() -> axum::Router<St> where +pub fn router<St, A, M>() -> axum::Router<St> +where A: AuthBackend + FromRef<St>, M: MediaStore + FromRef<St>, - St: Clone + Send + Sync + 'static + St: Clone + Send + Sync + 'static, { axum::Router::new() .route("/", axum::routing::post(upload::<M, A>)) diff --git a/src/media/storage/file.rs b/src/media/storage/file.rs index 4cd0ece..5198a4c 100644 --- a/src/media/storage/file.rs +++ b/src/media/storage/file.rs @@ -1,12 +1,12 @@ -use super::{Metadata, ErrorKind, MediaStore, MediaStoreError, Result}; -use std::{path::PathBuf, fmt::Debug}; -use tokio::fs::OpenOptions; -use tokio::io::{BufReader, BufWriter, AsyncWriteExt, AsyncSeekExt}; +use super::{ErrorKind, MediaStore, MediaStoreError, Metadata, Result}; +use futures::FutureExt; use futures::{StreamExt, TryStreamExt}; +use sha2::Digest; use std::ops::{Bound, Neg}; use std::pin::Pin; -use sha2::Digest; -use futures::FutureExt; +use std::{fmt::Debug, path::PathBuf}; +use tokio::fs::OpenOptions; +use tokio::io::{AsyncSeekExt, AsyncWriteExt, BufReader, BufWriter}; use tracing::{debug, error}; const BUF_CAPACITY: usize = 16 * 1024; @@ -22,7 +22,7 @@ impl From<tokio::io::Error> for MediaStoreError { msg: format!("file I/O error: {}", source), kind: match source.kind() { std::io::ErrorKind::NotFound => ErrorKind::NotFound, - _ => ErrorKind::Backend + _ => ErrorKind::Backend, }, source: Some(Box::new(source)), } @@ -40,7 +40,9 @@ impl FileStore { impl MediaStore for FileStore { async fn new(url: &'_ url::Url) -> Result<Self> { - Ok(Self { base: url.path().into() }) + Ok(Self { + base: url.path().into(), + }) } #[tracing::instrument(skip(self, content))] @@ -51,10 +53,17 @@ impl MediaStore for FileStore { mut content: T, ) -> Result<String> where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug + T: tokio_stream::Stream< + Item = 
std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>, + > + Unpin + + Send + + Debug, { let (tempfilepath, mut tempfile) = self.mktemp().await?; - debug!("Temporary file opened for storing pending upload: {}", tempfilepath.display()); + debug!( + "Temporary file opened for storing pending upload: {}", + tempfilepath.display() + ); let mut hasher = sha2::Sha256::new(); let mut length: usize = 0; @@ -62,7 +71,7 @@ impl MediaStore for FileStore { let chunk = chunk.map_err(|err| MediaStoreError { kind: ErrorKind::Backend, source: Some(Box::new(err)), - msg: "Failed to read a data chunk".to_owned() + msg: "Failed to read a data chunk".to_owned(), })?; debug!("Read {} bytes from the stream", chunk.len()); length += chunk.len(); @@ -70,9 +79,7 @@ impl MediaStore for FileStore { { let chunk = chunk.clone(); let tempfile = &mut tempfile; - async move { - tempfile.write_all(&chunk).await - } + async move { tempfile.write_all(&chunk).await } }, { let chunk = chunk.clone(); @@ -80,7 +87,8 @@ impl MediaStore for FileStore { hasher.update(&*chunk); hasher - }).map(|r| r.unwrap()) + }) + .map(|r| r.unwrap()) } ); if let Err(err) = write_result { @@ -90,7 +98,9 @@ impl MediaStore for FileStore { // though temporary files might take up space on the hard drive // We'll clean them when maintenance time comes #[allow(unused_must_use)] - { tokio::fs::remove_file(tempfilepath).await; } + { + tokio::fs::remove_file(tempfilepath).await; + } return Err(err.into()); } hasher = _hasher; @@ -113,10 +123,17 @@ impl MediaStore for FileStore { let filepath = self.base.join(domain_str.as_str()).join(&filename); let metafilename = filename.clone() + ".json"; let metapath = self.base.join(domain_str.as_str()).join(&metafilename); - let metatemppath = self.base.join(domain_str.as_str()).join(metafilename + ".tmp"); + let metatemppath = self + .base + .join(domain_str.as_str()) + .join(metafilename + ".tmp"); metadata.length = std::num::NonZeroUsize::new(length); metadata.etag = Some(hash); - debug!("File path: {}, metadata: {}", filepath.display(), metapath.display()); + debug!( + "File path: {}, metadata: {}", + filepath.display(), + metapath.display() + ); { let parent = filepath.parent().unwrap(); tokio::fs::create_dir_all(parent).await?; @@ -126,7 +143,8 @@ impl MediaStore for FileStore { .write(true) .open(&metatemppath) .await?; - meta.write_all(&serde_json::to_vec(&metadata).unwrap()).await?; + meta.write_all(&serde_json::to_vec(&metadata).unwrap()) + .await?; tokio::fs::rename(tempfilepath, filepath).await?; tokio::fs::rename(metatemppath, metapath).await?; Ok(filename) @@ -138,28 +156,31 @@ impl MediaStore for FileStore { &self, domain: &str, filename: &str, - ) -> Result<(Metadata, Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>)> { + ) -> Result<( + Metadata, + Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>, + )> { debug!("Domain: {}, filename: {}", domain, filename); let path = self.base.join(domain).join(filename); debug!("Path: {}", path.display()); - let file = OpenOptions::new() - .read(true) - .open(path) - .await?; + let file = OpenOptions::new().read(true).open(path).await?; let meta = self.metadata(domain, filename).await?; - Ok((meta, Box::pin( - tokio_util::io::ReaderStream::new( - // TODO: determine if BufReader provides benefit here - // From the logs it looks like we're reading 4KiB at a time - // Buffering file contents seems to double download speed - // How to benchmark this? 
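One cheap first answer to the "how to benchmark this?" question in the comment above: time a full read of the same large file once through a bare `File` and once through the `BufReader` wrapper, then compare. A rough sketch assuming wall-clock timing is good enough for a first estimate; a proper benchmark would control for page-cache effects and use something like criterion, and `path` here is a placeholder:

use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};

// Drain a reader to EOF, returning how long it took.
async fn time_full_read<R: AsyncRead + Unpin>(mut reader: R) -> std::io::Result<Duration> {
    let start = Instant::now();
    let mut buf = [0u8; 4096];
    while reader.read(&mut buf).await? != 0 {}
    Ok(start.elapsed())
}

async fn compare(path: &str) -> std::io::Result<()> {
    let plain = time_full_read(tokio::fs::File::open(path).await?).await?;
    let buffered = time_full_read(BufReader::with_capacity(
        16 * 1024,
        tokio::fs::File::open(path).await?,
    ))
    .await?;
    println!("plain: {:?}, buffered: {:?}", plain, buffered);
    Ok(())
}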
- BufReader::with_capacity(BUF_CAPACITY, file)
- )
- // Sprinkle some salt in form of protective log wrapping
- .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len()))
- )))
+ Ok((
+ meta,
+ Box::pin(
+ tokio_util::io::ReaderStream::new(
+ // TODO: determine if BufReader provides benefit here
+ // From the logs it looks like we're reading 4KiB at a time
+ // Buffering file contents seems to double download speed
+ // How to benchmark this?
+ BufReader::with_capacity(BUF_CAPACITY, file),
+ )
+ // Sprinkle some salt in the form of protective log wrapping
+ .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())),
+ ),
+ ))
 }

 #[tracing::instrument(skip(self))]
@@ -167,12 +188,13 @@ impl MediaStore for FileStore {
 let metapath = self.base.join(domain).join(format!("{}.json", filename));
 debug!("Metadata path: {}", metapath.display());

- let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?)
- .map_err(|err| MediaStoreError {
+ let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?).map_err(|err| {
+ MediaStoreError {
 kind: ErrorKind::Json,
 msg: format!("{}", err),
- source: Some(Box::new(err))
- })?;
+ source: Some(Box::new(err)),
+ }
+ })?;

 Ok(meta)
 }

@@ -182,16 +204,14 @@ impl MediaStore for FileStore {
 &self,
 domain: &str,
 filename: &str,
- range: (Bound<u64>, Bound<u64>)
- ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> {
+ range: (Bound<u64>, Bound<u64>),
+ ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>>
+ {
 let path = self.base.join(format!("{}/{}", domain, filename));
 let metapath = self.base.join(format!("{}/{}.json", domain, filename));
 debug!("Path: {}, metadata: {}", path.display(), metapath.display());

- let mut file = OpenOptions::new()
- .read(true)
- .open(path)
- .await?;
+ let mut file = OpenOptions::new().read(true).open(path).await?;

 let start = match range {
 (Bound::Included(bound), _) => {
@@ -202,45 +222,52 @@ impl MediaStore for FileStore {
 (Bound::Unbounded, Bound::Included(bound)) => {
 // Seek to the end minus the bounded bytes
 debug!("Seeking {} bytes back from the end...", bound);
- file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())).await?
- },
+ file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg()))
+ .await?
+ } (Bound::Unbounded, Bound::Unbounded) => 0, - (_, Bound::Excluded(_)) => unreachable!() + (_, Bound::Excluded(_)) => unreachable!(), }; - let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity(BUF_CAPACITY, file))) - .map_ok({ - let mut bytes_read = 0usize; - let len = match range { - (_, Bound::Unbounded) => None, - (Bound::Unbounded, Bound::Included(bound)) => Some(bound), - (_, Bound::Included(bound)) => Some(bound + 1 - start), - (_, Bound::Excluded(_)) => unreachable!() - }; - move |chunk| { - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - if let Some(len) = len.map(|len| len.try_into().unwrap()) { - if bytes_read > len { - if bytes_read - len > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - len); - return Some(chunk.slice(..chunk.len() - (bytes_read - len))) + let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity( + BUF_CAPACITY, + file, + ))) + .map_ok({ + let mut bytes_read = 0usize; + let len = match range { + (_, Bound::Unbounded) => None, + (Bound::Unbounded, Bound::Included(bound)) => Some(bound), + (_, Bound::Included(bound)) => Some(bound + 1 - start), + (_, Bound::Excluded(_)) => unreachable!(), + }; + move |chunk| { + debug!( + "Read {} bytes from file, {} in this chunk", + bytes_read, + chunk.len() + ); + bytes_read += chunk.len(); + if let Some(len) = len.map(|len| len.try_into().unwrap()) { + if bytes_read > len { + if bytes_read - len > chunk.len() { + return None; } + debug!("Truncating last {} bytes", bytes_read - len); + return Some(chunk.slice(..chunk.len() - (bytes_read - len))); } - - Some(chunk) } - }) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - // Will never panic, because the moment the stream yields - // a None, it is considered exhausted. - .map_ok(|x| x.unwrap()); - return Ok(Box::pin(stream)) - } + Some(chunk) + } + }) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + // Will never panic, because the moment the stream yields + // a None, it is considered exhausted. 
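The invariant this comment relies on (the first `None` terminates the stream, so the `unwrap` one line down can never see a `None`) is easier to convince yourself of in isolation. A minimal self-contained sketch of the same Option-as-terminator trick, using the futures 0.3 combinators this file already imports:

use futures::stream::{self, TryStreamExt};

#[tokio::test]
async fn option_terminator_demo() {
    // Items after the first None can never be observed downstream.
    let s = stream::iter([Ok::<_, ()>(Some(1)), Ok(Some(2)), Ok(None), Ok(Some(3))]);
    let taken: Vec<i32> = s
        .try_take_while(|x| std::future::ready(Ok(x.is_some())))
        .map_ok(|x| x.unwrap()) // safe: a None has already ended the stream
        .try_collect()
        .await
        .unwrap();
    assert_eq!(taken, vec![1, 2]);
}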
+ .map_ok(|x| x.unwrap()); + return Ok(Box::pin(stream)); + } async fn delete(&self, domain: &str, filename: &str) -> Result<()> { let path = self.base.join(format!("{}/{}", domain, filename)); @@ -251,7 +278,7 @@ impl MediaStore for FileStore { #[cfg(test)] mod tests { - use super::{Metadata, FileStore, MediaStore}; + use super::{FileStore, MediaStore, Metadata}; use std::ops::Bound; use tokio::io::AsyncReadExt; @@ -259,10 +286,15 @@ mod tests { #[tracing_test::traced_test] async fn test_ranges() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore { base: tempdir.path().to_path_buf() }; + let store = FileStore { + base: tempdir.path().to_path_buf(), + }; let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let stream = tokio_stream::iter( + file.chunks(100) + .map(|i| Ok(bytes::Bytes::copy_from_slice(i))), + ); let metadata = Metadata { filename: Some("file.rs".to_string()), content_type: Some("text/plain".to_string()), @@ -271,28 +303,30 @@ mod tests { }; // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); + let filename = store + .write_streaming("fireburn.ru", metadata, stream) + .await + .unwrap(); tracing::debug!("Writing complete."); // Ensure the file is there - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); + let content = tokio::fs::read(tempdir.path().join("fireburn.ru").join(&filename)) + .await + .unwrap(); assert_eq!(content, file); tracing::debug!("Reading range from the start..."); // try to read range let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(0), Bound::Included(299)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + (Bound::Included(0), Bound::Included(299)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -308,10 +342,14 @@ mod tests { tracing::debug!("Reading range from the middle..."); let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(150), Bound::Included(449)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + (Bound::Included(150), Bound::Included(449)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -326,13 +364,17 @@ mod tests { tracing::debug!("Reading range from the end..."); let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - // Note: the `headers` crate parses bounds in a - // non-standard way, where unbounded start actually - // means getting things from the end... - (Bound::Unbounded, Bound::Included(300)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + // Note: the `headers` crate parses bounds in a + // non-standard way, where unbounded start actually + // means getting things from the end... 
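To make the note above concrete: per the `headers` crate behavior described here, a suffix request such as `Range: bytes=-300` comes out as `(Bound::Unbounded, Bound::Included(300))`, i.e. "the last 300 bytes", not "bytes 0 through 300". A small illustrative sketch of how such a tuple maps onto a seek position, mirroring the `start` computation in `stream_range` above (the function name is invented for the example):

use std::io::SeekFrom;
use std::ops::Bound;

fn seek_target(range: (Bound<u64>, Bound<u64>)) -> SeekFrom {
    match range {
        // Explicit start: seek forward from the beginning.
        (Bound::Included(start), _) => SeekFrom::Start(start),
        // Suffix form: "the last n bytes" seeks back from the end.
        (Bound::Unbounded, Bound::Included(n)) => SeekFrom::End(-(n as i64)),
        // No bounds at all: read the whole file.
        (Bound::Unbounded, Bound::Unbounded) => SeekFrom::Start(0),
        // The Range header grammar never produces the remaining shapes.
        _ => unreachable!(),
    }
}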
+ (Bound::Unbounded, Bound::Included(300)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -343,15 +385,19 @@ mod tests { }; assert_eq!(range.len(), 300); - assert_eq!(range.as_slice(), &file[file.len()-300..file.len()]); + assert_eq!(range.as_slice(), &file[file.len() - 300..file.len()]); tracing::debug!("Reading the whole file..."); // try to read range let range = { - let stream = store.stream_range( - "fireburn.ru", &("/".to_string() + &filename), - (Bound::Unbounded, Bound::Unbounded) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &("/".to_string() + &filename), + (Bound::Unbounded, Bound::Unbounded), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -365,15 +411,19 @@ mod tests { assert_eq!(range.as_slice(), file); } - #[tokio::test] #[tracing_test::traced_test] async fn test_streaming_read_write() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore { base: tempdir.path().to_path_buf() }; + let store = FileStore { + base: tempdir.path().to_path_buf(), + }; let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let stream = tokio_stream::iter( + file.chunks(100) + .map(|i| Ok(bytes::Bytes::copy_from_slice(i))), + ); let metadata = Metadata { filename: Some("style.css".to_string()), content_type: Some("text/css".to_string()), @@ -382,27 +432,32 @@ mod tests { }; // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); - println!("{}, {}", filename, tempdir.path() - .join("fireburn.ru") - .join(&filename) - .display()); - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); + let filename = store + .write_streaming("fireburn.ru", metadata, stream) + .await + .unwrap(); + println!( + "{}, {}", + filename, + tempdir.path().join("fireburn.ru").join(&filename).display() + ); + let content = tokio::fs::read(tempdir.path().join("fireburn.ru").join(&filename)) + .await + .unwrap(); assert_eq!(content, file); // check internal metadata format - let meta: Metadata = serde_json::from_slice(&tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(filename.clone() + ".json") - ).await.unwrap()).unwrap(); + let meta: Metadata = serde_json::from_slice( + &tokio::fs::read( + tempdir + .path() + .join("fireburn.ru") + .join(filename.clone() + ".json"), + ) + .await + .unwrap(), + ) + .unwrap(); assert_eq!(meta.content_type.as_deref(), Some("text/css")); assert_eq!(meta.filename.as_deref(), Some("style.css")); assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); @@ -410,10 +465,10 @@ mod tests { // read back the data using the interface let (metadata, read_back) = { - let (metadata, stream) = store.read_streaming( - "fireburn.ru", - &filename - ).await.unwrap(); + let (metadata, stream) = store + .read_streaming("fireburn.ru", &filename) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); let mut buf = Vec::default(); @@ -427,6 +482,5 @@ mod tests { assert_eq!(meta.filename.as_deref(), Some("style.css")); assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); assert!(meta.etag.is_some()); - } } diff --git a/src/media/storage/mod.rs b/src/media/storage/mod.rs index 3583247..5658071 100644 --- a/src/media/storage/mod.rs +++ b/src/media/storage/mod.rs @@ -1,12 +1,12 @@ use 
axum::extract::multipart::Field; -use tokio_stream::Stream; use bytes::Bytes; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use std::future::Future; +use std::num::NonZeroUsize; use std::ops::Bound; use std::pin::Pin; -use std::fmt::Debug; -use std::num::NonZeroUsize; +use tokio_stream::Stream; pub mod file; @@ -24,17 +24,14 @@ pub struct Metadata { impl From<&Field<'_>> for Metadata { fn from(field: &Field<'_>) -> Self { Self { - content_type: field.content_type() - .map(|i| i.to_owned()), - filename: field.file_name() - .map(|i| i.to_owned()), + content_type: field.content_type().map(|i| i.to_owned()), + filename: field.file_name().map(|i| i.to_owned()), length: None, etag: None, } } } - #[derive(Debug, Clone, Copy)] pub enum ErrorKind { Backend, @@ -95,88 +92,116 @@ pub trait MediaStore: 'static + Send + Sync + Clone { content: T, ) -> impl Future<Output = Result<String>> + Send where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug; + T: tokio_stream::Stream< + Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>, + > + Unpin + + Send + + Debug; #[allow(clippy::type_complexity)] fn read_streaming( &self, domain: &str, filename: &str, - ) -> impl Future<Output = Result< - (Metadata, Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>) - >> + Send; + ) -> impl Future< + Output = Result<( + Metadata, + Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>, + )>, + > + Send; fn stream_range( &self, domain: &str, filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> impl Future<Output = Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>>> + Send { async move { - use futures::stream::TryStreamExt; - use tracing::debug; - let (metadata, mut stream) = self.read_streaming(domain, filename).await?; - let length = metadata.length.unwrap().get(); - - use Bound::*; - let (start, end): (usize, usize) = match range { - (Unbounded, Unbounded) => return Ok(stream), - (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), - (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), - (Included(start), Included(end)) => (start.try_into().unwrap(), end.try_into().unwrap()), - (_, _) => unreachable!() - }; - - stream = Box::pin( - stream.map_ok({ - let mut bytes_skipped = 0usize; - let mut bytes_read = 0usize; - - move |chunk| { - debug!("Skipped {}/{} bytes, chunk len {}", bytes_skipped, start, chunk.len()); - let chunk = if bytes_skipped < start { - let need_to_skip = start - bytes_skipped; - if chunk.len() < need_to_skip { - return None - } - debug!("Skipping {} bytes", need_to_skip); - bytes_skipped += need_to_skip; - - chunk.slice(need_to_skip..) 
- } else { - chunk - }; - - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - - if bytes_read > length { - if bytes_read - length > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - length); - return Some(chunk.slice(..chunk.len() - (bytes_read - length))) - } - - Some(chunk) + range: (Bound<u64>, Bound<u64>), + ) -> impl Future<Output = Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>>> + Send + { + async move { + use futures::stream::TryStreamExt; + use tracing::debug; + let (metadata, mut stream) = self.read_streaming(domain, filename).await?; + let length = metadata.length.unwrap().get(); + + use Bound::*; + let (start, end): (usize, usize) = match range { + (Unbounded, Unbounded) => return Ok(stream), + (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), + (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), + (Included(start), Included(end)) => { + (start.try_into().unwrap(), end.try_into().unwrap()) } - }) - .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - .map_ok(|x| x.unwrap()) - ); + (_, _) => unreachable!(), + }; + + stream = Box::pin( + stream + .map_ok({ + let mut bytes_skipped = 0usize; + let mut bytes_read = 0usize; + + move |chunk| { + debug!( + "Skipped {}/{} bytes, chunk len {}", + bytes_skipped, + start, + chunk.len() + ); + let chunk = if bytes_skipped < start { + let need_to_skip = start - bytes_skipped; + if chunk.len() < need_to_skip { + return None; + } + debug!("Skipping {} bytes", need_to_skip); + bytes_skipped += need_to_skip; + + chunk.slice(need_to_skip..) + } else { + chunk + }; + + debug!( + "Read {} bytes from file, {} in this chunk", + bytes_read, + chunk.len() + ); + bytes_read += chunk.len(); + + if bytes_read > length { + if bytes_read - length > chunk.len() { + return None; + } + debug!("Truncating last {} bytes", bytes_read - length); + return Some(chunk.slice(..chunk.len() - (bytes_read - length))); + } + + Some(chunk) + } + }) + .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + .map_ok(|x| x.unwrap()), + ); - Ok(stream) - } } + Ok(stream) + } + } /// Read metadata for a file. /// /// The default implementation uses the `read_streaming` method /// and drops the stream containing file content. 
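The provided method described in this doc comment leans on return-position `impl Trait` in traits (stable since Rust 1.75): a trait can ship a default body that returns an `impl Future` built from an `async move` block awaiting a required method, which is how `metadata` defers to `read_streaming` below. A stripped-down sketch of that shape; `Store`, `read` and `first_line` are placeholder names:

use std::future::Future;

trait Store: Send + Sync {
    // Required method: implementors provide the actual read.
    fn read(&self, name: &str) -> impl Future<Output = String> + Send;

    // Provided method: the default body awaits the required one.
    fn first_line(&self, name: &str) -> impl Future<Output = String> + Send {
        async move {
            self.read(name)
                .await
                .lines()
                .next()
                .unwrap_or_default()
                .to_owned()
        }
    }
}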
- fn metadata(&self, domain: &str, filename: &str) -> impl Future<Output = Result<Metadata>> + Send { async move { - self.read_streaming(domain, filename) - .await - .map(|(meta, _)| meta) - } } + fn metadata( + &self, + domain: &str, + filename: &str, + ) -> impl Future<Output = Result<Metadata>> + Send { + async move { + self.read_streaming(domain, filename) + .await + .map(|(meta, _)| meta) + } + } fn delete(&self, domain: &str, filename: &str) -> impl Future<Output = Result<()>> + Send; } diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 8505ae5..5e11033 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -1,26 +1,26 @@ use std::collections::HashMap; +use std::sync::Arc; use url::Url; use util::NormalizedPost; -use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; use crate::indieauth::User; use crate::micropub::util::form_to_mf2_json; -use axum::extract::{FromRef, Query, State}; use axum::body::Body as BodyStream; +use axum::extract::{FromRef, Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; use axum_extra::extract::Host; use axum_extra::headers::ContentType; -use axum::response::{IntoResponse, Response}; use axum_extra::TypedHeader; -use axum::http::StatusCode; +use kittybox_indieauth::{Scope, TokenData}; +use kittybox_util::micropub::{Error as MicropubError, ErrorKind, QueryType}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::Mutex; use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; -use kittybox_indieauth::{Scope, TokenData}; -use kittybox_util::micropub::{Error as MicropubError, ErrorKind, QueryType}; #[derive(Serialize, Deserialize, Debug)] pub struct MicropubQuery { @@ -35,7 +35,7 @@ impl From<StorageError> for MicropubError { crate::database::ErrorKind::NotFound => ErrorKind::NotFound, _ => ErrorKind::InternalServerError, }, - format!("backend error: {}", err) + format!("backend error: {}", err), ) } } @@ -59,7 +59,8 @@ fn populate_reply_context( array .iter() .map(|i| { - let mut item = i.as_str() + let mut item = i + .as_str() .and_then(|i| i.parse::<Url>().ok()) .and_then(|url| ctxs.get(&url)) .and_then(|ctx| ctx.mf2["items"].get(0)) @@ -69,7 +70,12 @@ fn populate_reply_context( if item.is_object() && (i != &item) { if let Some(props) = item["properties"].as_object_mut() { // Fixup the item: if it lacks a URL, add one. 
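The fixup mentioned in the comment above guards against a fetched reply context whose mf2 item carries no `url` property of its own; without it, the rendered post would lose the link it was replying to. The same check in isolation, as a small serde_json sketch (`ensure_url` and `source_url` are illustrative names):

use serde_json::{json, Value};

fn ensure_url(item: &mut Value, source_url: &str) {
    if let Some(props) = item["properties"].as_object_mut() {
        let has_url = props
            .get("url")
            .and_then(Value::as_array)
            .map(|a| !a.is_empty())
            .unwrap_or(false);
        if !has_url {
            // Backfill the URL the context was fetched from.
            props.insert("url".to_owned(), json!([source_url]));
        }
    }
}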
- if !props.get("url").and_then(serde_json::Value::as_array).map(|a| !a.is_empty()).unwrap_or(false) { + if !props + .get("url") + .and_then(serde_json::Value::as_array) + .map(|a| !a.is_empty()) + .unwrap_or(false) + { props.insert("url".to_owned(), json!([i.as_str()])); } } @@ -145,11 +151,14 @@ async fn background_processing<D: 'static + Storage>( .get("webmention") .and_then(|i| i.first().cloned()); - dbg!(Some((url.clone(), FetchedPostContext { - url, - mf2: serde_json::to_value(mf2).unwrap(), - webmention - }))) + dbg!(Some(( + url.clone(), + FetchedPostContext { + url, + mf2: serde_json::to_value(mf2).unwrap(), + webmention + } + ))) }) .collect::<HashMap<Url, FetchedPostContext>>() .await @@ -161,7 +170,11 @@ async fn background_processing<D: 'static + Storage>( }; for prop in context_props { if let Some(json) = populate_reply_context(&mf2, prop, &post_contexts) { - update.replace.as_mut().unwrap().insert(prop.to_owned(), json); + update + .replace + .as_mut() + .unwrap() + .insert(prop.to_owned(), json); } } if !update.replace.as_ref().unwrap().is_empty() { @@ -250,7 +263,7 @@ pub(crate) async fn _post<D: 'static + Storage>( if !user.check_scope(&Scope::Create) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "Not enough privileges - try acquiring the \"create\" scope." + "Not enough privileges - try acquiring the \"create\" scope.", )); } @@ -264,7 +277,7 @@ pub(crate) async fn _post<D: 'static + Storage>( { return Err(MicropubError::from_static( ErrorKind::Forbidden, - "You're posting to a website that's not yours." + "You're posting to a website that's not yours.", )); } @@ -272,7 +285,7 @@ pub(crate) async fn _post<D: 'static + Storage>( if db.post_exists(&uid).await? { return Err(MicropubError::from_static( ErrorKind::AlreadyExists, - "UID clash was detected, operation aborted." 
+ "UID clash was detected, operation aborted.", )); } // Save the post @@ -309,13 +322,18 @@ pub(crate) async fn _post<D: 'static + Storage>( } } - let reply = - IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); + let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); #[cfg(not(tokio_unstable))] - let _ = jobset.lock().await.spawn(background_processing(db, mf2, http)); + let _ = jobset + .lock() + .await + .spawn(background_processing(db, mf2, http)); #[cfg(tokio_unstable)] - let _ = jobset.lock().await.build_task() + let _ = jobset + .lock() + .await + .build_task() .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) .spawn(background_processing(db, mf2, http)); @@ -333,7 +351,7 @@ enum ActionType { #[serde(untagged)] pub enum MicropubPropertyDeletion { Properties(Vec<String>), - Values(HashMap<String, Vec<serde_json::Value>>) + Values(HashMap<String, Vec<serde_json::Value>>), } #[derive(Serialize, Deserialize)] struct MicropubFormAction { @@ -347,7 +365,7 @@ pub struct MicropubAction { url: String, #[serde(flatten)] #[serde(skip_serializing_if = "Option::is_none")] - update: Option<MicropubUpdate> + update: Option<MicropubUpdate>, } #[derive(Serialize, Deserialize, Debug, Default)] @@ -362,39 +380,43 @@ pub struct MicropubUpdate { impl MicropubUpdate { pub fn check_validity(&self) -> Result<(), MicropubError> { if let Some(add) = &self.add { - if add.iter().map(|(k, _)| k.as_str()).any(|k| { - k.to_lowercase().as_str() == "uid" - }) { + if add + .iter() + .map(|(k, _)| k.as_str()) + .any(|k| k.to_lowercase().as_str() == "uid") + { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } if let Some(replace) = &self.replace { - if replace.iter().map(|(k, _)| k.as_str()).any(|k| { - k.to_lowercase().as_str() == "uid" - }) { + if replace + .iter() + .map(|(k, _)| k.as_str()) + .any(|k| k.to_lowercase().as_str() == "uid") + { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } let iter = match &self.delete { Some(MicropubPropertyDeletion::Properties(keys)) => { Some(Box::new(keys.iter().map(|k| k.as_str())) as Box<dyn Iterator<Item = &str>>) - }, + } Some(MicropubPropertyDeletion::Values(map)) => { Some(Box::new(map.iter().map(|(k, _)| k.as_str())) as Box<dyn Iterator<Item = &str>>) - }, + } None => None, }; if let Some(mut iter) = iter { if iter.any(|k| k.to_lowercase().as_str() == "uid") { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } @@ -412,8 +434,9 @@ impl MicropubUpdate { } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = self.delete { if let Some(props) = post["properties"].as_object_mut() { for (key, values) in delete { - if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { - prop.retain(|v| { values.iter().all(|i| i != v) }) + if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) + { + prop.retain(|v| values.iter().all(|i| i != v)) } } } @@ -428,7 +451,10 @@ impl MicropubUpdate { if let Some(add) = self.add { if let Some(props) = post["properties"].as_object_mut() { for (key, value) in add { - if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { + if let Some(prop) = props + 
.get_mut(&key) + .and_then(serde_json::Value::as_array_mut) + { prop.extend_from_slice(value.as_slice()); } else { props.insert(key, serde_json::Value::Array(value)); @@ -445,7 +471,7 @@ impl From<MicropubFormAction> for MicropubAction { Self { action: a.action, url: a.url, - update: None + update: None, } } } @@ -458,10 +484,12 @@ async fn post_action<D: Storage, A: AuthBackend>( ) -> Result<(), MicropubError> { let uri = match action.url.parse::<hyper::Uri>() { Ok(uri) => uri, - Err(err) => return Err(MicropubError::new( - ErrorKind::InvalidRequest, - format!("url parsing error: {}", err) - )) + Err(err) => { + return Err(MicropubError::new( + ErrorKind::InvalidRequest, + format!("url parsing error: {}", err), + )) + } }; if uri.authority().unwrap() @@ -475,7 +503,7 @@ async fn post_action<D: Storage, A: AuthBackend>( { return Err(MicropubError::from_static( ErrorKind::Forbidden, - "Don't tamper with others' posts!" + "Don't tamper with others' posts!", )); } @@ -484,7 +512,7 @@ async fn post_action<D: Storage, A: AuthBackend>( if !user.check_scope(&Scope::Delete) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "You need a \"delete\" scope for this." + "You need a \"delete\" scope for this.", )); } @@ -494,7 +522,7 @@ async fn post_action<D: Storage, A: AuthBackend>( if !user.check_scope(&Scope::Update) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "You need an \"update\" scope for this." + "You need an \"update\" scope for this.", )); } @@ -503,7 +531,7 @@ async fn post_action<D: Storage, A: AuthBackend>( } else { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update request is not set." + "Update request is not set.", )); }; @@ -555,7 +583,7 @@ async fn dispatch_body( } else { Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid JSON object passed." + "Invalid JSON object passed.", )) } } else if content_type == ContentType::form_url_encoded() { @@ -566,7 +594,7 @@ async fn dispatch_body( } else { Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid form-encoded data. Try h=entry&content=Hello!" + "Invalid form-encoded data. Try h=entry&content=Hello!", )) } } else { @@ -605,7 +633,10 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( #[tracing::instrument(skip(db))] pub(crate) async fn query<D: Storage, A: AuthBackend>( State(db): State<D>, - query: Result<Query<MicropubQuery>, <Query<MicropubQuery> as axum::extract::FromRequestParts<()>>::Rejection>, + query: Result< + Query<MicropubQuery>, + <Query<MicropubQuery> as axum::extract::FromRequestParts<()>>::Rejection, + >, Host(host): Host, user: User<A>, ) -> axum::response::Response { @@ -616,8 +647,9 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( } else { return MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid query provided. Try ?q=config to see what you can do." - ).into_response(); + "Invalid query provided. Try ?q=config to see what you can do.", + ) + .into_response(); }; if axum::http::Uri::try_from(user.me.as_str()) @@ -630,7 +662,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::NotAuthorized, "This website doesn't belong to you.", ) - .into_response(); + .into_response(); } // TODO: consider replacing by `user.me.authority()`? 
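Stepping back to the ranged-read rewrite in src/media/storage/mod.rs above: the map_ok closure skips the first `start` bytes of the chunk stream, then truncates the tail once `bytes_read` overshoots `length`, and the surrounding try_skip_while/try_take_while combinators drop the None chunks on either side. A minimal, self-contained sketch of that clipping step, assuming in-order chunks and a half-open range [start, start + len); the helper name and the single-cursor bookkeeping are illustrative, not kittybox's actual code:

use bytes::Bytes;

// Clip one chunk of an in-order byte stream to [start, start + len).
// `pos` counts how many bytes of the stream precede this chunk.
fn clip_to_range(chunk: Bytes, pos: &mut usize, start: usize, len: usize) -> Option<Bytes> {
    let chunk_start = *pos;
    let chunk_end = chunk_start + chunk.len();
    *pos = chunk_end;
    let range_end = start + len;
    // Chunks entirely outside the range become None, to be dropped by
    // combinators such as try_skip_while/try_take_while.
    if chunk_end <= start || chunk_start >= range_end {
        return None;
    }
    let lo = start.saturating_sub(chunk_start);
    let hi = chunk.len() - chunk_end.saturating_sub(range_end);
    Some(chunk.slice(lo..hi))
}

Keeping a single stream cursor instead of two separate counters makes the slice arithmetic easier to audit; the diffed closure threads bytes_skipped and bytes_read through the stream instead.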
@@ -644,7 +676,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InternalServerError, format!("Error fetching channels: {}", err), ) - .into_response() + .into_response() } }; @@ -654,35 +686,36 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( QueryType::Config, QueryType::Channel, QueryType::SyndicateTo, - QueryType::Category + QueryType::Category, ], channels: Some(channels), syndicate_to: None, media_endpoint: Some(user.me.join("/.kittybox/media").unwrap()), other: { let mut map = std::collections::HashMap::new(); - map.insert("kittybox_authority".to_string(), serde_json::Value::String(user.me.to_string())); + map.insert( + "kittybox_authority".to_string(), + serde_json::Value::String(user.me.to_string()), + ); map - } + }, }) - .into_response() + .into_response() } QueryType::Source => { match query.url { - Some(url) => { - match db.get_post(&url).await { - Ok(some) => match some { - Some(post) => axum::response::Json(&post).into_response(), - None => MicropubError::from_static( - ErrorKind::NotFound, - "The specified MF2 object was not found in database.", - ) - .into_response(), - }, - Err(err) => MicropubError::from(err).into_response(), - } - } + Some(url) => match db.get_post(&url).await { + Ok(some) => match some { + Some(post) => axum::response::Json(&post).into_response(), + None => MicropubError::from_static( + ErrorKind::NotFound, + "The specified MF2 object was not found in database.", + ) + .into_response(), + }, + Err(err) => MicropubError::from(err).into_response(), + }, None => { // Here, one should probably attempt to query at least the main feed and collect posts // Using a pre-made query function can't be done because it does unneeded filtering @@ -691,7 +724,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InvalidRequest, "Querying for post list is not implemented yet.", ) - .into_response() + .into_response() } } } @@ -701,46 +734,45 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InternalServerError, format!("error fetching channels: backend error: {}", err), ) - .into_response(), + .into_response(), }, QueryType::SyndicateTo => { axum::response::Json(json!({ "syndicate-to": [] })).into_response() - }, + } QueryType::Category => { let categories = match db.categories(user_domain).await { Ok(categories) => categories, Err(err) => { return MicropubError::new( ErrorKind::InternalServerError, - format!("error fetching categories: backend error: {}", err) - ).into_response() + format!("error fetching categories: backend error: {}", err), + ) + .into_response() } }; axum::response::Json(json!({ "categories": categories })).into_response() - }, - QueryType::Unknown(q) => return MicropubError::new( - ErrorKind::InvalidRequest, - format!("Invalid query: {}", q) - ).into_response(), + } + QueryType::Unknown(q) => { + return MicropubError::new(ErrorKind::InvalidRequest, format!("Invalid query: {}", q)) + .into_response() + } } } - pub fn router<A, S, St: Send + Sync + Clone + 'static>() -> axum::routing::MethodRouter<St> where S: Storage + FromRef<St> + 'static, A: AuthBackend + FromRef<St>, reqwest_middleware::ClientWithMiddleware: FromRef<St>, - Arc<Mutex<JoinSet<()>>>: FromRef<St> + Arc<Mutex<JoinSet<()>>>: FromRef<St>, { axum::routing::get(query::<S, A>) .post(post::<S, A>) - .layer::<_, _>(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST, - ]) - .allow_origin(tower_http::cors::Any)) + .layer::<_, _>( + tower_http::cors::CorsLayer::new() + 
.allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_origin(tower_http::cors::Any), + ) } #[cfg(test)] @@ -765,16 +797,19 @@ impl MicropubQuery { mod tests { use std::sync::Arc; - use crate::{database::Storage, micropub::{util::NormalizedPost, MicropubError}}; + use crate::{ + database::Storage, + micropub::{util::NormalizedPost, MicropubError}, + }; use bytes::Bytes; use futures::StreamExt; use serde_json::json; use tokio::sync::Mutex; use super::FetchedPostContext; - use kittybox_indieauth::{Scopes, Scope, TokenData}; use axum::extract::State; use axum_extra::extract::Host; + use kittybox_indieauth::{Scope, Scopes, TokenData}; #[test] fn test_populate_reply_context() { @@ -801,16 +836,27 @@ mod tests { } }); let fetched_ctx_url: url::Url = "https://fireburn.ru/posts/example".parse().unwrap(); - let reply_contexts = vec![(fetched_ctx_url.clone(), FetchedPostContext { - url: fetched_ctx_url.clone(), - mf2: json!({ "items": [test_ctx] }), - webmention: None, - })].into_iter().collect(); + let reply_contexts = vec![( + fetched_ctx_url.clone(), + FetchedPostContext { + url: fetched_ctx_url.clone(), + mf2: json!({ "items": [test_ctx] }), + webmention: None, + }, + )] + .into_iter() + .collect(); let like_of = super::populate_reply_context(&mf2, "like-of", &reply_contexts).unwrap(); - assert_eq!(like_of[0]["properties"]["content"], test_ctx["properties"]["content"]); - assert_eq!(like_of[0]["properties"]["url"][0].as_str().unwrap(), reply_contexts[&fetched_ctx_url].url.as_str()); + assert_eq!( + like_of[0]["properties"]["content"], + test_ctx["properties"]["content"] + ); + assert_eq!( + like_of[0]["properties"]["url"][0].as_str().unwrap(), + reply_contexts[&fetched_ctx_url].url.as_str() + ); assert_eq!(like_of[1], already_expanded_reply_ctx); assert_eq!(like_of[2], "https://fireburn.ru/posts/non-existent"); @@ -830,20 +876,21 @@ mod tests { me: "https://localhost:8080/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: Scopes::new(vec![Scope::Profile]), - iat: None, exp: None + iat: None, + exp: None, }; let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let err = super::_post( - &user, id, post, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap_err(); + .await + .unwrap_err(); assert_eq!(err.error, super::ErrorKind::InvalidScope); @@ -866,21 +913,27 @@ mod tests { let user = TokenData { me: "https://aaronparecki.com/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None + scope: Scopes::new(vec![ + Scope::Profile, + Scope::Create, + Scope::Update, + Scope::Media, + ]), + iat: None, + exp: None, }; let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let err = super::_post( - &user, id, post, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap_err(); + .await 
+ .unwrap_err(); assert_eq!(err.error, super::ErrorKind::Forbidden); @@ -902,20 +955,21 @@ mod tests { me: "https://localhost:8080/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: Scopes::new(vec![Scope::Profile, Scope::Create]), - iat: None, exp: None + iat: None, + exp: None, }; let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let res = super::_post( - &user, id, post, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap(); + .await + .unwrap(); assert!(res.headers().contains_key("Location")); let location = res.headers().get("Location").unwrap(); @@ -938,10 +992,17 @@ mod tests { TokenData { me: "https://fireburn.ru/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None - }, std::marker::PhantomData - ) + scope: Scopes::new(vec![ + Scope::Profile, + Scope::Create, + Scope::Update, + Scope::Media, + ]), + iat: None, + exp: None, + }, + std::marker::PhantomData, + ), ) .await; @@ -954,7 +1015,10 @@ mod tests { .into_iter() .map(Result::unwrap) .by_ref() - .fold(Vec::new(), |mut a, i| { a.extend(i); a}); + .fold(Vec::new(), |mut a, i| { + a.extend(i); + a + }); let json: MicropubError = serde_json::from_slice(&body as &[u8]).unwrap(); assert_eq!(json.error, super::ErrorKind::NotAuthorized); } diff --git a/src/micropub/util.rs b/src/micropub/util.rs index 99aec8e..8c5d5e9 100644 --- a/src/micropub/util.rs +++ b/src/micropub/util.rs @@ -1,7 +1,7 @@ use crate::database::Storage; -use kittybox_indieauth::TokenData; use chrono::prelude::*; use core::iter::Iterator; +use kittybox_indieauth::TokenData; use newbase60::num_to_sxg; use serde_json::json; use std::convert::TryInto; @@ -35,7 +35,7 @@ fn reset_dt(post: &mut serde_json::Value) -> DateTime<FixedOffset> { pub struct NormalizedPost { pub id: String, - pub post: serde_json::Value + pub post: serde_json::Value, } pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> NormalizedPost { @@ -142,12 +142,12 @@ pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> Normalize } // If there is no explicit channels, and the post is not marked as "unlisted", // post it to one of the default channels that makes sense for the post type. - if body["properties"]["channel"][0].as_str().is_none() && (!body["properties"]["visibility"] - .as_array() - .map(|v| v.contains( - &serde_json::Value::String("unlisted".to_owned()) - )).unwrap_or(false) - ) { + if body["properties"]["channel"][0].as_str().is_none() + && (!body["properties"]["visibility"] + .as_array() + .map(|v| v.contains(&serde_json::Value::String("unlisted".to_owned()))) + .unwrap_or(false)) + { match body["type"][0].as_str() { Some("h-entry") => { // Set the channel to the main channel... 
@@ -249,7 +249,7 @@ mod tests { client_id: "https://quill.p3k.io/".parse().unwrap(), scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), exp: Some(u64::MAX), - iat: Some(0) + iat: Some(0), } } @@ -279,12 +279,15 @@ mod tests { } }); - let NormalizedPost { id: _, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data() - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert!( - normalized["properties"]["channel"].as_array().unwrap_or(&vec![]).is_empty(), + normalized["properties"]["channel"] + .as_array() + .unwrap_or(&vec![]) + .is_empty(), "Returned post was added to a channel despite the `unlisted` visibility" ); } @@ -300,10 +303,10 @@ mod tests { } }); - let NormalizedPost { id, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["uid"][0], mf2["properties"]["uid"][0], "UID was replaced" @@ -325,10 +328,10 @@ mod tests { } }); - let NormalizedPost { id: _, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["channel"], @@ -347,10 +350,10 @@ mod tests { } }); - let NormalizedPost { id: _, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["channel"][0], @@ -367,10 +370,7 @@ mod tests { } }); - let NormalizedPost { id, post } = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id, post } = normalize_mf2(mf2, &token_data()); assert_eq!( post["properties"]["published"] .as_array() @@ -432,10 +432,7 @@ mod tests { }, }); - let NormalizedPost { id: _, post } = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id: _, post } = normalize_mf2(mf2, &token_data()); assert!( post["properties"]["url"] .as_array() @@ -461,10 +458,7 @@ mod tests { } }); - let NormalizedPost { id, post } = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id, post } = normalize_mf2(mf2, &token_data()); assert_eq!( post["properties"]["uid"][0], id, "UID of a post and its supposed location don't match" diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs index 683cc6b..380f4db 100644 --- a/src/webmentions/check.rs +++ b/src/webmentions/check.rs @@ -1,7 +1,7 @@ -use std::rc::Rc; -use microformats::types::PropertyValue; use html5ever::{self, tendril::TendrilSink}; use kittybox_util::MentionType; +use microformats::types::PropertyValue; +use std::rc::Rc; // TODO: replace. 
mod rcdom; @@ -17,7 +17,11 @@ pub enum Error { } #[tracing::instrument] -pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> { +pub fn check_mention( + document: impl AsRef<str> + std::fmt::Debug, + base_url: &url::Url, + link: &url::Url, +) -> Result<Option<(MentionType, serde_json::Value)>, Error> { tracing::debug!("Parsing MF2 markup..."); // First, check the document for MF2 markup let document = microformats::from_html(document.as_ref(), base_url.clone())?; @@ -29,8 +33,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url tracing::debug!("Processing item: {:?}", item); for (prop, interaction_type) in [ - ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like), - ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost) + ("in-reply-to", MentionType::Reply), + ("like-of", MentionType::Like), + ("bookmark-of", MentionType::Bookmark), + ("repost-of", MentionType::Repost), ] { if let Some(propvals) = item.properties.get(prop) { tracing::debug!("Has a u-{} property", prop); @@ -38,7 +44,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url if let PropertyValue::Url(url) = val { if url == link { tracing::debug!("URL matches! Webmention is valid"); - return Ok(Some((interaction_type, serde_json::to_value(item).unwrap()))) + return Ok(Some(( + interaction_type, + serde_json::to_value(item).unwrap(), + ))); } } } @@ -46,7 +55,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } // Process `content` tracing::debug!("Processing e-content..."); - if let Some(PropertyValue::Fragment(content)) = item.properties.get("content") + if let Some(PropertyValue::Fragment(content)) = item + .properties + .get("content") .map(Vec::as_slice) .unwrap_or_default() .first() @@ -65,7 +76,8 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url // iteration of the loop. // // Empty list means all nodes were processed. - let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = root.children.borrow().iter().cloned().collect(); + let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = + root.children.borrow().iter().cloned().collect(); while !unprocessed_nodes.is_empty() { // "Take" the list out of its memory slot, replace it with an empty list let nodes = std::mem::take(&mut unprocessed_nodes); @@ -74,15 +86,23 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url // Add children nodes to the list for the next iteration unprocessed_nodes.extend(node.children.borrow().iter().cloned()); - if let rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data { + if let rcdom::NodeData::Element { + ref name, + ref attrs, + .. 
+ } = node.data + { // If it's not `<a>`, skip it - if name.local != *"a" { continue; } + if name.local != *"a" { + continue; + } let mut is_mention: bool = false; for attr in attrs.borrow().iter() { if attr.name.local == *"rel" { // Don't count `rel="nofollow"` links — a web crawler should ignore them // and so for purposes of driving visitors they are useless - if attr.value + if attr + .value .as_ref() .split([',', ' ']) .any(|v| v == "nofollow") @@ -92,7 +112,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } } // if it's not `<a href="...">`, skip it - if attr.name.local != *"href" { continue; } + if attr.name.local != *"href" { + continue; + } // Be forgiving in parsing URLs, and resolve them against the base URL if let Ok(url) = base_url.join(attr.value.as_ref()) { if &url == link { @@ -101,12 +123,14 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } } if is_mention { - return Ok(Some((MentionType::Mention, serde_json::to_value(item).unwrap()))); + return Ok(Some(( + MentionType::Mention, + serde_json::to_value(item).unwrap(), + ))); } } } } - } } diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs index 91b274b..57f9a57 100644 --- a/src/webmentions/mod.rs +++ b/src/webmentions/mod.rs @@ -1,9 +1,14 @@ -use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form}; use axum::http::StatusCode; +use axum::{ + extract::{FromRef, State}, + response::{IntoResponse, Response}, + routing::post, + Form, +}; use tracing::error; -use crate::database::{Storage, StorageError}; use self::queue::JobQueue; +use crate::database::{Storage, StorageError}; pub mod queue; #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -24,40 +29,46 @@ async fn accept_webmention<Q: JobQueue<Webmention>>( Form(webmention): Form<Webmention>, ) -> Response { if let Err(err) = webmention.source.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + return (StatusCode::BAD_REQUEST, err.to_string()).into_response(); } if let Err(err) = webmention.target.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + return (StatusCode::BAD_REQUEST, err.to_string()).into_response(); } match queue.put(&webmention).await { Ok(_id) => StatusCode::ACCEPTED.into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ - ("Content-Type", "text/plain") - ], err.to_string()).into_response() + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + [("Content-Type", "text/plain")], + err.to_string(), + ) + .into_response(), } } -pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>() -> axum::Router<St> { - axum::Router::new() - .route("/.kittybox/webmention", post(accept_webmention::<Q>)) +pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>( +) -> axum::Router<St> { + axum::Router::new().route("/.kittybox/webmention", post(accept_webmention::<Q>)) } #[derive(thiserror::Error, Debug)] pub enum SupervisorError { #[error("the task was explicitly cancelled")] - Cancelled + Cancelled, } -pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; +pub type SupervisedTask = + tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; -pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask +pub fn supervisor<E, A, F>( + mut f: F, + 
cancellation_token: tokio_util::sync::CancellationToken, +) -> SupervisedTask where E: std::error::Error + std::fmt::Debug + Send + 'static, A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static, - F: FnMut() -> A + Send + 'static + F: FnMut() -> A + Send + 'static, { - let supervisor_future = async move { loop { // Don't spawn the task if we are already cancelled, but @@ -65,7 +76,7 @@ where // crashed and we immediately received a cancellation // request after noticing the crashed task) if cancellation_token.is_cancelled() { - return Err(SupervisorError::Cancelled) + return Err(SupervisorError::Cancelled); } let task = tokio::task::spawn(f()); tokio::select! { @@ -87,7 +98,13 @@ where return tokio::task::spawn(supervisor_future); #[cfg(tokio_unstable)] return tokio::task::Builder::new() - .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str()) + .name( + format!( + "supervisor for background task {}", + std::any::type_name::<A>() + ) + .as_str(), + ) .spawn(supervisor_future) .unwrap(); } @@ -99,39 +116,55 @@ enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> { #[error("queue error: {0}")] Queue(#[from] Q), #[error("storage error: {0}")] - Storage(StorageError) + Storage(StorageError), } -async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest_middleware::ClientWithMiddleware) -> Result<std::convert::Infallible, Error<Q::Error>> { - use futures_util::StreamExt; +async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>( + queue: Q, + db: S, + http: reqwest_middleware::ClientWithMiddleware, +) -> Result<std::convert::Infallible, Error<Q::Error>> { use self::queue::Job; + use futures_util::StreamExt; let mut stream = queue.into_stream().await?; while let Some(item) = stream.next().await.transpose()? 
{ let job = item.job(); let (source, target) = ( job.source.parse::<url::Url>().unwrap(), - job.target.parse::<url::Url>().unwrap() + job.target.parse::<url::Url>().unwrap(), ); let (code, text) = match http.get(source.clone()).send().await { Ok(response) => { let code = response.status(); - if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) { - error!("error processing webmention: webpage fetch returned {}", code); + if ![StatusCode::OK, StatusCode::GONE] + .iter() + .any(|i| i == &code) + { + error!( + "error processing webmention: webpage fetch returned {}", + code + ); continue; } match response.text().await { Ok(text) => (code, text), Err(err) => { - error!("error processing webmention: error fetching webpage text: {}", err); - continue + error!( + "error processing webmention: error fetching webpage text: {}", + err + ); + continue; } } } Err(err) => { - error!("error processing webmention: error requesting webpage: {}", err); - continue + error!( + "error processing webmention: error requesting webpage: {}", + err + ); + continue; } }; @@ -150,7 +183,10 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's continue; } Err(err) => { - error!("error processing webmention: error checking webmention: {}", err); + error!( + "error processing webmention: error checking webmention: {}", + err + ); continue; } }; @@ -158,31 +194,47 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's { mention["type"] = serde_json::json!(["h-cite"]); - if !mention["properties"].as_object().unwrap().contains_key("uid") { - let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned(); + if !mention["properties"] + .as_object() + .unwrap() + .contains_key("uid") + { + let url = mention["properties"]["url"][0] + .as_str() + .unwrap_or_else(|| target.as_str()) + .to_owned(); let props = mention["properties"].as_object_mut().unwrap(); - props.insert("uid".to_owned(), serde_json::Value::Array( - vec![serde_json::Value::String(url)]) + props.insert( + "uid".to_owned(), + serde_json::Value::Array(vec![serde_json::Value::String(url)]), ); } } - db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?; + db.add_or_update_webmention(target.as_str(), mention_type, mention) + .await + .map_err(Error::<Q::Error>::Storage)?; } } unreachable!() } -pub fn supervised_webmentions_task<St: Send + Sync + 'static, S: Storage + FromRef<St> + 'static, Q: JobQueue<Webmention> + FromRef<St> + 'static>( +pub fn supervised_webmentions_task< + St: Send + Sync + 'static, + S: Storage + FromRef<St> + 'static, + Q: JobQueue<Webmention> + FromRef<St> + 'static, +>( state: &St, - cancellation_token: tokio_util::sync::CancellationToken + cancellation_token: tokio_util::sync::CancellationToken, ) -> SupervisedTask -where reqwest_middleware::ClientWithMiddleware: FromRef<St> +where + reqwest_middleware::ClientWithMiddleware: FromRef<St>, { let queue = Q::from_ref(state); let storage = S::from_ref(state); let http = reqwest_middleware::ClientWithMiddleware::from_ref(state); - supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue( - queue.clone(), storage.clone(), http.clone() - ), cancellation_token) + supervisor::<Error<Q::Error>, _, _>( + move || process_webmentions_from_queue(queue.clone(), storage.clone(), http.clone()), + cancellation_token, + ) } diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs index 52bcdfa..a33de1a 100644 --- 
a/src/webmentions/queue.rs +++ b/src/webmentions/queue.rs @@ -6,7 +6,7 @@ use super::Webmention; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention"); -pub use kittybox_util::queue::{JobQueue, JobItem, Job, JobStream}; +pub use kittybox_util::queue::{Job, JobItem, JobQueue, JobStream}; pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> { const DATABASE_NAME: &'static str; @@ -17,7 +17,7 @@ pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRo struct PostgresJobRow<T: PostgresJobItem> { id: Uuid, #[sqlx(flatten)] - job: T + job: T, } #[derive(Debug)] @@ -29,7 +29,6 @@ pub struct PostgresJob<T: PostgresJobItem> { runtime_handle: tokio::runtime::Handle, } - impl<T: PostgresJobItem> Drop for PostgresJob<T> { // This is an emulation of "async drop" — the struct retains a // runtime handle, which it uses to block on a future that does @@ -87,7 +86,9 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { fn job(&self) -> &Webmention { &self.job } - async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { + async fn done( + mut self, + ) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { tracing::debug!("Deleting {} from the job queue", self.id); sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1") .bind(self.id) @@ -100,13 +101,13 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { pub struct PostgresJobQueue<T> { db: sqlx::PgPool, - _phantom: std::marker::PhantomData<T> + _phantom: std::marker::PhantomData<T>, } impl<T> Clone for PostgresJobQueue<T> { fn clone(&self) -> Self { Self { db: self.db.clone(), - _phantom: std::marker::PhantomData + _phantom: std::marker::PhantomData, } } } @@ -120,15 +121,21 @@ impl PostgresJobQueue<Webmention> { sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) - .await? - ).await - + .await?, + ) + .await } pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?; + db.execute(sqlx::query( + "CREATE SCHEMA IF NOT EXISTS kittybox_webmention", + )) + .await?; MIGRATOR.run(&db).await?; - Ok(Self { db, _phantom: std::marker::PhantomData }) + Ok(Self { + db, + _phantom: std::marker::PhantomData, + }) } } @@ -180,13 +187,14 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { Some(item) => return Ok(Some((item, ()))), None => { listener.lock().await.recv().await?; - continue + continue; } } } } } - }).boxed(); + }) + .boxed(); Ok(stream) } @@ -196,7 +204,7 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { mod tests { use std::sync::Arc; - use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR}; + use super::{Job, JobQueue, PostgresJobQueue, Webmention, MIGRATOR}; use futures_util::StreamExt; #[sqlx::test(migrator = "MIGRATOR")] @@ -204,7 +212,7 @@ mod tests { async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(), }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; @@ -236,7 +244,7 @@ mod tests { match queue.get_one().await? 
{ Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), - None => Ok(()) + None => Ok(()), } } @@ -245,7 +253,7 @@ async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(), }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; @@ -272,18 +280,18 @@ } }); } - tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); + tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap_err(); - let future = tokio::task::spawn( - tokio::time::timeout( - std::time::Duration::from_secs(10), async move { - stream.next().await.unwrap().unwrap() - } - ) - ); + let future = tokio::task::spawn(tokio::time::timeout( + std::time::Duration::from_secs(10), + async move { stream.next().await.unwrap().unwrap() }, + )); // Let the other task drop the guard it is holding barrier.wait().await; - let mut guard = future.await + let mut guard = future + .await .expect("Timeout on fetching item") .expect("Job queue error"); assert_eq!(guard.job(), &test_webmention);
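Two of the reformatted pieces above reward a closer look. First, supervisor in src/webmentions/mod.rs is a restart loop: it refuses to spawn if cancellation already happened, then races the spawned task against the cancellation token and respawns it after a crash. A reduced sketch of that shape, with a unit error standing in for SupervisorError and illustrative logging; this is an approximation of the pattern, not the function's exact body:

use std::convert::Infallible;
use tokio_util::sync::CancellationToken;

// Restart loop in the shape of `supervisor`: the factory `f` rebuilds
// the future each time the supervised task dies.
async fn restart_loop<E, Fut, F>(mut f: F, token: CancellationToken) -> Result<Infallible, ()>
where
    E: std::fmt::Debug + Send + 'static,
    Fut: std::future::Future<Output = Result<Infallible, E>> + Send + 'static,
    F: FnMut() -> Fut,
{
    loop {
        // Don't respawn if cancellation arrived while the task was down.
        if token.is_cancelled() {
            return Err(());
        }
        let mut task = tokio::task::spawn(f());
        tokio::select! {
            // The task's Ok type is Infallible, so completing at all means
            // it crashed; log the error and restart on the next iteration.
            result = &mut task => tracing::warn!("supervised task exited: {:?}", result),
            // Cancellation wins the race: stop the task and the loop.
            _ = token.cancelled() => {
                task.abort();
                return Err(());
            }
        }
    }
}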
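Second, the Drop impl for PostgresJob in src/webmentions/queue.rs is introduced as an emulation of "async drop": the struct keeps a tokio::runtime::Handle (visible as the runtime_handle field in the diff) and blocks on a cleanup future when dropped. One common way to spell that pattern, assuming the multi-threaded Tokio runtime, since block_in_place panics on a current-thread runtime; the struct, fields, and cleanup body here are illustrative rather than the actual queue code:

// Guard that runs async cleanup from synchronous Drop by retaining a
// handle to the runtime it was created on (e.g. via Handle::current()).
struct JobGuard {
    id: u64,
    handle: tokio::runtime::Handle,
}

impl Drop for JobGuard {
    fn drop(&mut self) {
        let handle = self.handle.clone();
        let id = self.id;
        tokio::task::block_in_place(move || {
            handle.block_on(async move {
                // Real cleanup would go here, e.g. releasing the claimed
                // queue row so another worker can pick the job up again.
                tracing::debug!("returning job {} to the queue", id);
            })
        });
    }
}

This is also what the test at the end of the section exercises: one consumer drops its guard, and the queued webmention becomes visible to the other consumer again.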