40 files changed, 2893 insertions, 1791 deletions
diff --git a/.zed/settings.json b/.zed/settings.json
new file mode 100644
index 0000000..09cfd7b
--- /dev/null
+++ b/.zed/settings.json
@@ -0,0 +1,8 @@
+// Folder-specific settings
+//
+// For a full list of overridable settings, and general information on folder-specific settings,
+// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
+{
+  "format_on_save": "on",
+  "languages": { "Rust": { "format_on_save": "language_server" } }
+}
diff --git a/Cargo.toml b/Cargo.toml
index 3e085db..141016e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,7 +14,11 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tokio_unstable)'] }
 [features]
 default = ["rustls", "postgres"]
 webauthn = ["openssl", "dep:webauthn"]
-openssl = ["reqwest/native-tls", "reqwest/native-tls-alpn", "sqlx/tls-native-tls"]
+openssl = [
+    "reqwest/native-tls",
+    "reqwest/native-tls-alpn",
+    "sqlx/tls-native-tls",
+]
 rustls = ["reqwest/rustls-tls-webpki-roots", "sqlx/tls-rustls"]
 cli = ["dep:clap", "dep:anyhow"]
 postgres = ["sqlx", "kittybox-util/sqlx"]
@@ -123,7 +127,11 @@ wiremock = "0.6.2"
 anyhow = { version = "1.0.95", optional = true }
 argon2 = { version = "0.5.3", features = ["std"] }
 axum = { workspace = true, features = ["multipart", "json", "form", "macros"] }
-axum-extra = { version = "0.10.0", features = ["cookie", "cookie-signed", "typed-header"] }
+axum-extra = { version = "0.10.0", features = [
+    "cookie",
+    "cookie-signed",
+    "typed-header",
+] }
 bytes = "1.9.0"
 chrono = { workspace = true }
 clap = { workspace = true, features = ["derive"], optional = true }
@@ -132,7 +140,9 @@ either = "1.13.0"
 futures = { workspace = true }
 futures-util = { workspace = true }
 html5ever = "=0.27.0"
-http-cache-reqwest = { version = "0.15.0", default-features = false, features = ["manager-moka"] }
+http-cache-reqwest = { version = "0.15.0", default-features = false, features = [
+    "manager-moka",
+] }
 hyper = "1.5.2"
 lazy_static = "1.5.0"
 listenfd = "1.0.1"
@@ -142,27 +152,51 @@ mime = "0.3.17"
 newbase60 = "0.1.4"
 prometheus = { version = "0.13.4", features = ["process"] }
 rand = { workspace = true }
-redis = { version = "0.27.6", features = ["aio", "tokio-comp"], optional = true }
+redis = { version = "0.27.6", features = [
+    "aio",
+    "tokio-comp",
+], optional = true }
 relative-path = "1.9.3"
-reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "brotli", "json", "stream"] }
+reqwest = { version = "0.12.12", default-features = false, features = [
+    "gzip",
+    "brotli",
+    "json",
+    "stream",
+] }
 reqwest-middleware = "0.4.0"
 serde = { workspace = true }
 serde_json = { workspace = true }
 serde_urlencoded = { workspace = true }
 serde_variant = { workspace = true }
 sha2 = { workspace = true }
-sqlparser = { version = "0.53.0", features = ["serde", "serde_json"], optional = true }
-sqlx = { workspace = true, features = ["uuid", "chrono", "postgres", "runtime-tokio"], optional = true }
+sqlparser = { version = "0.53.0", features = [
+    "serde",
+    "serde_json",
+], optional = true }
+sqlx = { workspace = true, features = [
+    "uuid",
+    "chrono",
+    "postgres",
+    "runtime-tokio",
+], optional = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["full", "tracing"] }
 tokio-stream = { workspace = true, features = ["time", "net"] }
 tokio-util = { workspace = true, features = ["io-util"] }
 tower = { workspace = true, features = ["tracing"] }
-tower-http = { version = "0.6.2", features = ["trace", "cors", "catch-panic", "sensitive-headers", "set-header"] }
+tower-http = { version = "0.6.2", features = [
+    "trace",
+    "cors",
+    "catch-panic",
+    "sensitive-headers",
+    "set-header",
+] }
 tracing = { workspace = true, features = [] }
 tracing-log = { workspace = true }
 tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
 tracing-tree = { workspace = true }
 url = { workspace = true }
 uuid = { workspace = true, features = ["v4"] }
-webauthn = { version = "0.5.0", package = "webauthn-rs", features = ["danger-allow-state-serialisation"], optional = true }
+webauthn = { version = "0.5.0", package = "webauthn-rs", features = [
+    "danger-allow-state-serialisation",
+], optional = true }
diff --git a/build.rs b/build.rs
index dcfa332..5db39e0 100644
--- a/build.rs
+++ b/build.rs
@@ -22,9 +22,6 @@ fn main() {
     }
     let companion_in = std::path::Path::new("companion-lite");
     for file in ["index.html", "style.css"] {
-        std::fs::copy(
-            companion_in.join(file),
-            companion_out.join(file)
-        ).unwrap();
+        std::fs::copy(companion_in.join(file), companion_out.join(file)).unwrap();
     }
 }
diff --git a/examples/password-hasher.rs b/examples/password-hasher.rs
index 92de7f7..3c88a40 100644
--- a/examples/password-hasher.rs
+++ b/examples/password-hasher.rs
@@ -1,6 +1,9 @@
 use std::io::Write;
 
-use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, PasswordHash, PasswordVerifier, SaltString}};
+use argon2::{
+    password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
+    Argon2,
+};
 
 fn main() -> std::io::Result<()> {
     eprint!("Type a password: ");
@@ -15,19 +18,19 @@ fn main() -> std::io::Result<()> {
     let salt = SaltString::generate(&mut OsRng);
     let argon2 = Argon2::default();
     //eprintln!("{}", password.trim());
-    let password_hash = argon2.hash_password(password.trim().as_bytes(), &salt)
+    let password_hash = argon2
+        .hash_password(password.trim().as_bytes(), &salt)
         .expect("Hashing a password should not error out")
         .serialize();
 
     println!("{}", password_hash.as_str());
 
     assert!(Argon2::default()
-        .verify_password(
-            password.trim().as_bytes(),
-            &PasswordHash::new(password_hash.as_str())
-                .expect("Password hash should be valid")
-        ).is_ok()
-    );
+        .verify_password(
+            password.trim().as_bytes(),
+            &PasswordHash::new(password_hash.as_str()).expect("Password hash should be valid")
+        )
+        .is_ok());
 
     Ok(())
 }
diff --git a/examples/sql.rs b/examples/sql.rs
index 5d552da..59db1eb 100644
--- a/examples/sql.rs
+++ b/examples/sql.rs
@@ -15,8 +15,11 @@ fn sanitize(expr: &Expr) -> Result<(), Error> {
     match expr {
         Expr::Identifier(_) => Ok(()),
         Expr::CompoundIdentifier(_) => Ok(()),
-        Expr::JsonAccess { left, operator: _, right } => sanitize(left)
-            .and(sanitize(right)),
+        Expr::JsonAccess {
+            left,
+            operator: _,
+            right,
+        } => sanitize(left).and(sanitize(right)),
         Expr::CompositeAccess { expr, key: _ } => sanitize(expr),
         Expr::IsFalse(subexpr) => sanitize(subexpr),
         Expr::IsNotFalse(subexpr) => sanitize(subexpr),
@@ -26,84 +29,180 @@ fn sanitize(expr: &Expr) -> Result<(), Error> {
         Expr::IsNotNull(subexpr) => sanitize(subexpr),
         Expr::IsUnknown(subexpr) => sanitize(subexpr),
         Expr::IsNotUnknown(subexpr) => sanitize(subexpr),
-        Expr::IsDistinctFrom(left, right) => sanitize(left)
-            .and(sanitize(right)),
-        Expr::IsNotDistinctFrom(left, right) => sanitize(left)
-            .and(sanitize(right)),
-        Expr::InList { expr, list, negated: _ } => sanitize(expr)
-            .and(list.iter().try_for_each(sanitize)),
-        Expr::InSubquery { expr: _, subquery, negated: _ } => Err(Error::SubqueryDetected(subquery.as_ref())),
-        Expr::InUnnest { expr, array_expr, negated: _ } => sanitize(expr).and(sanitize(array_expr)),
-        Expr::Between { expr, negated: _, low, high } => sanitize(expr)
-            .and(sanitize(low))
-            .and(sanitize(high)),
-        Expr::BinaryOp { left, op: _, right } => sanitize(left)
-            .and(sanitize(right)),
-        Expr::Like { negated: _, expr, pattern, escape_char: _ } => sanitize(expr)
-            .and(sanitize(pattern)),
-        Expr::ILike { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)),
-        Expr::SimilarTo { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)),
-        Expr::RLike { negated: _, expr, pattern, regexp: _ } => sanitize(expr).and(sanitize(pattern)),
-        Expr::AnyOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)),
-        Expr::AllOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)),
+        Expr::IsDistinctFrom(left, right) => sanitize(left).and(sanitize(right)),
+        Expr::IsNotDistinctFrom(left, right) => sanitize(left).and(sanitize(right)),
+        Expr::InList {
+            expr,
+            list,
+            negated: _,
+        } => sanitize(expr).and(list.iter().try_for_each(sanitize)),
+        Expr::InSubquery {
+            expr: _,
+            subquery,
+            negated: _,
+        } => Err(Error::SubqueryDetected(subquery.as_ref())),
+        Expr::InUnnest {
+            expr,
+            array_expr,
+            negated: _,
+        } => sanitize(expr).and(sanitize(array_expr)),
+        Expr::Between {
+            expr,
+            negated: _,
+            low,
+            high,
+        } => sanitize(expr).and(sanitize(low)).and(sanitize(high)),
+        Expr::BinaryOp { left, op: _, right } => sanitize(left).and(sanitize(right)),
+        Expr::Like {
+            negated: _,
+            expr,
+            pattern,
+            escape_char: _,
+        } => sanitize(expr).and(sanitize(pattern)),
+        Expr::ILike {
+            negated: _,
+            expr,
+            pattern,
+            escape_char: _,
+        } => sanitize(expr).and(sanitize(pattern)),
+        Expr::SimilarTo {
+            negated: _,
+            expr,
+            pattern,
+            escape_char: _,
+        } => sanitize(expr).and(sanitize(pattern)),
+        Expr::RLike {
+            negated: _,
+            expr,
+            pattern,
+            regexp: _,
+        } => sanitize(expr).and(sanitize(pattern)),
+        Expr::AnyOp {
+            left,
+            compare_op: _,
+            right,
+        } => sanitize(left).and(sanitize(right)),
+        Expr::AllOp {
+            left,
+            compare_op: _,
+            right,
+        } => sanitize(left).and(sanitize(right)),
         Expr::UnaryOp { op: _, expr } => sanitize(expr),
-        Expr::Convert { expr, data_type: _, charset: _, target_before_value: _ } => sanitize(expr),
-        Expr::Cast { expr, data_type: _, format: _ } => sanitize(expr),
-        Expr::TryCast { expr, data_type: _, format: _ } => sanitize(expr),
-        Expr::SafeCast { expr, data_type: _, format: _ } => sanitize(expr),
-        Expr::AtTimeZone { timestamp, time_zone: _ } => sanitize(timestamp),
+        Expr::Convert {
+            expr,
+            data_type: _,
+            charset: _,
+            target_before_value: _,
+        } => sanitize(expr),
+        Expr::Cast {
+            expr,
+            data_type: _,
+            format: _,
+        } => sanitize(expr),
+        Expr::TryCast {
+            expr,
+            data_type: _,
+            format: _,
+        } => sanitize(expr),
+        Expr::SafeCast {
+            expr,
+            data_type: _,
+            format: _,
+        } => sanitize(expr),
+        Expr::AtTimeZone {
+            timestamp,
+            time_zone: _,
+        } => sanitize(timestamp),
         Expr::Extract { field: _, expr } => sanitize(expr),
         Expr::Ceil { expr, field: _ } => sanitize(expr),
         Expr::Floor { expr, field: _ } => sanitize(expr),
         Expr::Position { expr, r#in } => sanitize(expr).and(sanitize(r#in)),
-        Expr::Substring { expr, substring_from, substring_for, special: _ } => sanitize(expr)
+        Expr::Substring {
+            expr,
+            substring_from,
+            substring_for,
+            special: _,
+        } => sanitize(expr)
             .and(substring_from.as_deref().map(sanitize).unwrap_or(Ok(())))
             .and(substring_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
-        Expr::Trim { expr, trim_where: _, trim_what, trim_characters } => sanitize(expr)
+        Expr::Trim {
+            expr,
+            trim_where: _,
+            trim_what,
+            trim_characters,
+        } => sanitize(expr)
             .and(trim_what.as_deref().map(sanitize).unwrap_or(Ok(())))
             .and(
                 trim_characters
                    .as_ref()
                    .map(|v| v.iter())
                    .map(|mut iter| iter.try_for_each(sanitize))
-                   .unwrap_or(Ok(()))
+                   .unwrap_or(Ok(())),
             ),
-        Expr::Overlay { expr, overlay_what, overlay_from, overlay_for } => sanitize(expr)
+        Expr::Overlay {
+            expr,
+            overlay_what,
+            overlay_from,
+            overlay_for,
+        } => sanitize(expr)
             .and(sanitize(overlay_what))
             .and(sanitize(overlay_from))
             .and(overlay_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
         Expr::Collate { expr, collation: _ } => sanitize(expr),
         Expr::Nested(subexpr) => sanitize(subexpr),
         Expr::Value(_) => Ok(()),
-        Expr::IntroducedString { introducer: _, value: _ } => Ok(()),
-        Expr::TypedString { data_type: _, value: _ } => Ok(()),
-        Expr::MapAccess { column, keys } => sanitize(column).and(keys.iter().try_for_each(sanitize)),
+        Expr::IntroducedString {
+            introducer: _,
+            value: _,
+        } => Ok(()),
+        Expr::TypedString {
+            data_type: _,
+            value: _,
+        } => Ok(()),
+        Expr::MapAccess { column, keys } => {
+            sanitize(column).and(keys.iter().try_for_each(sanitize))
+        }
         Expr::Function(func) => Err(Error::FunctionCallDetected(func)),
-        Expr::AggregateExpressionWithFilter { expr, filter } => sanitize(expr).and(sanitize(filter)),
-        Expr::Case { operand, conditions, results, else_result } => conditions.iter()
+        Expr::AggregateExpressionWithFilter { expr, filter } => {
+            sanitize(expr).and(sanitize(filter))
+        }
+        Expr::Case {
+            operand,
+            conditions,
+            results,
+            else_result,
+        } => conditions
+            .iter()
             .chain(results)
             .chain(operand.iter().map(std::borrow::Borrow::borrow))
             .chain(else_result.iter().map(std::borrow::Borrow::borrow))
             .try_for_each(sanitize),
-        Expr::Exists { subquery, negated: _ } => Err(Error::SubqueryDetected(subquery)),
+        Expr::Exists {
+            subquery,
+            negated: _,
+        } => Err(Error::SubqueryDetected(subquery)),
         Expr::Subquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
         Expr::ArraySubquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
         Expr::ListAgg(agg) => sanitize(&agg.expr),
         Expr::ArrayAgg(agg) => sanitize(&agg.expr),
-        Expr::GroupingSets(sets) => sets.iter()
+        Expr::GroupingSets(sets) => sets
+            .iter()
             .map(|i| i.iter())
             .try_for_each(|mut si| si.try_for_each(sanitize)),
-        Expr::Cube(cube) => cube.iter()
+        Expr::Cube(cube) => cube
+            .iter()
            .map(|i| i.iter())
            .try_for_each(|mut si| si.try_for_each(sanitize)),
-        Expr::Rollup(rollup) => rollup.iter()
+        Expr::Rollup(rollup) => rollup
+            .iter()
            .map(|i| i.iter())
            .try_for_each(|mut si| si.try_for_each(sanitize)),
         Expr::Tuple(tuple) => tuple.iter().try_for_each(sanitize),
         Expr::Struct { values, fields: _ } => values.iter().try_for_each(sanitize),
         Expr::Named { expr, name: _ } => sanitize(expr),
-        Expr::ArrayIndex { obj, indexes } => sanitize(obj)
-            .and(indexes.iter().try_for_each(sanitize)),
+        Expr::ArrayIndex { obj, indexes } => {
+            sanitize(obj).and(indexes.iter().try_for_each(sanitize))
+        }
         Expr::Array(array) => array.elem.iter().try_for_each(sanitize),
         Expr::Interval(interval) => sanitize(&interval.value),
         Expr::MatchAgainst { .. } => Ok(()),
@@ -115,7 +214,8 @@ fn sanitize(expr: &Expr) -> Result<(), Error> {
 fn main() -> Result<(), sqlparser::parser::ParserError> {
     let query = std::env::args().skip(1).take(1).next().unwrap();
 
-    static DIALECT: sqlparser::dialect::PostgreSqlDialect = sqlparser::dialect::PostgreSqlDialect {};
+    static DIALECT: sqlparser::dialect::PostgreSqlDialect =
+        sqlparser::dialect::PostgreSqlDialect {};
     let parser: sqlparser::parser::Parser<'static> = sqlparser::parser::Parser::new(&DIALECT);
 
     let expr = parser.try_with_sql(&query)?.parse_expr()?;
@@ -125,6 +225,6 @@ fn main() -> Result<(), sqlparser::parser::ParserError> {
             eprintln!("{}", err);
         }
     }
-    
+
     Ok(())
 }
diff --git a/indieauth/src/lib.rs b/indieauth/src/lib.rs
index 459d943..1582318 100644
--- a/indieauth/src/lib.rs
+++ b/indieauth/src/lib.rs
@@ -20,13 +20,13 @@
 //! [`axum`]: https://github.com/tokio-rs/axum
 use std::borrow::Cow;
 
-use serde::{Serialize, Deserialize};
+use serde::{Deserialize, Serialize};
 use url::Url;
 
 mod scopes;
 pub use self::scopes::{Scope, Scopes};
 mod pkce;
-pub use self::pkce::{PKCEMethod, PKCEVerifier, PKCEChallenge};
+pub use self::pkce::{PKCEChallenge, PKCEMethod, PKCEVerifier};
 
 // Re-export rand crate just to be sure.
 pub use rand;
@@ -48,7 +48,7 @@ pub enum IntrospectionEndpointAuthMethod {
     /// TLS client auth with a certificate signed by a valid CA.
     TlsClientAuth,
     /// TLS client auth with a self-signed certificate.
-    SelfSignedTlsClientAuth
+    SelfSignedTlsClientAuth,
 }
 
 /// Authentication methods supported by the revocation endpoint.
@@ -64,7 +64,7 @@ pub enum IntrospectionEndpointAuthMethod {
 pub enum RevocationEndpointAuthMethod {
     /// No authentication is required to access an endpoint declaring
     /// this value.
-    None
+    None,
 }
 
 /// The response types supported by the authorization endpoint.
@@ -80,7 +80,7 @@ pub enum ResponseType {
     /// This response type requires a valid access token.
     ///
     /// [AutoAuth spec]: https://github.com/sknebel/AutoAuth/blob/master/AutoAuth.md#allowing-external-clients-to-obtain-tokens
-    ExternalToken
+    ExternalToken,
 }
 // TODO serde_variant
 impl ResponseType {
@@ -108,7 +108,7 @@ pub enum GrantType {
     /// The refresh token grant, allowing to exchange a refresh token
     /// for a fresh access token and a new refresh token, to
     /// facilitate long-term access.
-    RefreshToken
+    RefreshToken,
 }
 
 /// OAuth 2.0 Authorization Server Metadata in application to the IndieAuth protocol.
@@ -220,7 +220,7 @@ pub struct Metadata {
     /// registration.
     #[serde(skip_serializing_if = "ref_identity")]
     #[serde(default = "Default::default")]
-    pub client_id_metadata_document_supported: bool
+    pub client_id_metadata_document_supported: bool,
 }
 
 impl std::fmt::Debug for Metadata {
@@ -230,31 +230,59 @@ impl std::fmt::Debug for Metadata {
             .field("authorization_endpoint", &self.issuer.as_str())
             .field("token_endpoint", &self.issuer.as_str())
             .field("introspection_endpoint", &self.issuer.as_str())
-            .field("introspection_endpoint_auth_methods_supported", &self.introspection_endpoint_auth_methods_supported)
-            .field("revocation_endpoint", &self.revocation_endpoint.as_ref().map(Url::as_str))
-            .field("revocation_endpoint_auth_methods_supported", &self.revocation_endpoint_auth_methods_supported)
+            .field(
+                "introspection_endpoint_auth_methods_supported",
+                &self.introspection_endpoint_auth_methods_supported,
+            )
+            .field(
+                "revocation_endpoint",
+                &self.revocation_endpoint.as_ref().map(Url::as_str),
+            )
+            .field(
+                "revocation_endpoint_auth_methods_supported",
+                &self.revocation_endpoint_auth_methods_supported,
+            )
             .field("scopes_supported", &self.scopes_supported)
             .field("response_types_supported", &self.response_types_supported)
             .field("grant_types_supported", &self.grant_types_supported)
-            .field("service_documentation", &self.service_documentation.as_ref().map(Url::as_str))
-            .field("code_challenge_methods_supported", &self.code_challenge_methods_supported)
-            .field("authorization_response_iss_parameter_supported", &self.authorization_response_iss_parameter_supported)
-            .field("userinfo_endpoint", &self.userinfo_endpoint.as_ref().map(Url::as_str))
-            .field("client_id_metadata_document_supported", &self.client_id_metadata_document_supported)
+            .field(
+                "service_documentation",
+                &self.service_documentation.as_ref().map(Url::as_str),
+            )
+            .field(
+                "code_challenge_methods_supported",
+                &self.code_challenge_methods_supported,
+            )
+            .field(
+                "authorization_response_iss_parameter_supported",
+                &self.authorization_response_iss_parameter_supported,
+            )
+            .field(
+                "userinfo_endpoint",
+                &self.userinfo_endpoint.as_ref().map(Url::as_str),
+            )
+            .field(
+                "client_id_metadata_document_supported",
+                &self.client_id_metadata_document_supported,
+            )
             .finish()
     }
 }
 
-fn ref_identity(v: &bool) -> bool { *v }
+fn ref_identity(v: &bool) -> bool {
+    *v
+}
 
 #[cfg(feature = "axum")]
 impl axum_core::response::IntoResponse for Metadata {
     fn into_response(self) -> axum_core::response::Response {
         use http::StatusCode;
 
-        (StatusCode::OK,
-         [("Content-Type", "application/json")],
-         serde_json::to_vec(&self).unwrap())
+        (
+            StatusCode::OK,
+            [("Content-Type", "application/json")],
+            serde_json::to_vec(&self).unwrap(),
+        )
             .into_response()
     }
 }
@@ -306,7 +334,7 @@ pub struct ClientMetadata {
     pub software_version: Option<Cow<'static, str>>,
     /// URI for the homepage of this client's owners
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub homepage_uri: Option<Url>
+    pub homepage_uri: Option<Url>,
 }
 
 /// Error that occurs when creating [`ClientMetadata`] with mismatched `client_id` and `client_uri`.
@@ -328,12 +356,15 @@ impl ClientMetadata {
     /// Returns `()` if the `client_uri` is not a prefix of `client_id` as required by the IndieAuth
     /// spec.
     pub fn new(client_id: url::Url, client_uri: url::Url) -> Result<Self, ClientIdMismatch> {
-        if client_id.as_str().as_bytes()[..client_uri.as_str().len()] != *client_uri.as_str().as_bytes() {
+        if client_id.as_str().as_bytes()[..client_uri.as_str().len()]
+            != *client_uri.as_str().as_bytes()
+        {
             return Err(ClientIdMismatch);
         }
 
         Ok(Self {
-            client_id, client_uri,
+            client_id,
+            client_uri,
             client_name: None,
             logo_uri: None,
             redirect_uris: None,
@@ -363,14 +394,15 @@ impl axum_core::response::IntoResponse for ClientMetadata {
     fn into_response(self) -> axum_core::response::Response {
         use http::StatusCode;
 
-        (StatusCode::OK,
-         [("Content-Type", "application/json")],
-         serde_json::to_vec(&self).unwrap())
+        (
+            StatusCode::OK,
+            [("Content-Type", "application/json")],
+            serde_json::to_vec(&self).unwrap(),
+        )
             .into_response()
     }
 }
 
-
 /// User profile to be returned from the userinfo endpoint and when
 /// the `profile` scope was requested.
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -387,7 +419,7 @@ pub struct Profile {
     /// User's email, if they've chosen to reveal it. This is guarded
     /// by the `email` scope.
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub email: Option<String>
+    pub email: Option<String>,
 }
 
 #[cfg(feature = "axum")]
@@ -395,9 +427,11 @@ impl axum_core::response::IntoResponse for Profile {
     fn into_response(self) -> axum_core::response::Response {
         use http::StatusCode;
 
-        (StatusCode::OK,
-         [("Content-Type", "application/json")],
-         serde_json::to_vec(&self).unwrap())
+        (
+            StatusCode::OK,
+            [("Content-Type", "application/json")],
+            serde_json::to_vec(&self).unwrap(),
+        )
             .into_response()
     }
 }
@@ -422,13 +456,13 @@ impl State {
     /// Generate a random state string of 128 bytes in length, using
     /// the provided random number generator.
     pub fn from_rng(rng: &mut (impl rand::CryptoRng + rand::Rng)) -> Self {
-        use rand::{Rng, distributions::Alphanumeric};
+        use rand::{distributions::Alphanumeric, Rng};
 
-        let bytes = rng.sample_iter(&Alphanumeric)
+        let bytes = rng
+            .sample_iter(&Alphanumeric)
             .take(128)
             .collect::<Vec<u8>>();
         Self(String::from_utf8(bytes).unwrap())
-
     }
 }
 
 impl AsRef<str> for State {
@@ -511,21 +545,23 @@ impl AuthorizationRequest {
             ("response_type", Cow::Borrowed(self.response_type.as_str())),
             ("client_id", Cow::Borrowed(self.client_id.as_str())),
             ("redirect_uri", Cow::Borrowed(self.redirect_uri.as_str())),
-            ("code_challenge", Cow::Borrowed(self.code_challenge.as_str())),
-            ("code_challenge_method", Cow::Borrowed(self.code_challenge.method().as_str())),
-            ("state", Cow::Borrowed(self.state.as_ref()))
+            (
+                "code_challenge",
+                Cow::Borrowed(self.code_challenge.as_str()),
+            ),
+            (
+                "code_challenge_method",
+                Cow::Borrowed(self.code_challenge.method().as_str()),
+            ),
+            ("state", Cow::Borrowed(self.state.as_ref())),
         ];
 
         if let Some(ref scope) = self.scope {
-            v.push(
-                ("scope", Cow::Owned(scope.to_string()))
-            );
+            v.push(("scope", Cow::Owned(scope.to_string())));
         }
 
         if let Some(ref me) = self.me {
-            v.push(
-                ("me", Cow::Borrowed(me.as_str()))
-            );
+            v.push(("me", Cow::Borrowed(me.as_str())));
         }
 
         v
@@ -558,17 +594,22 @@ pub struct AutoAuthRequest {
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct AutoAuthCallbackData {
     state: State,
-    callback_url: Url
+    callback_url: Url,
 }
 
 #[inline(always)]
-fn deserialize_secs<'de, D: serde::de::Deserializer<'de>>(d: D) -> Result<std::time::Duration, D::Error> {
+fn deserialize_secs<'de, D: serde::de::Deserializer<'de>>(
+    d: D,
+) -> Result<std::time::Duration, D::Error> {
     use serde::Deserialize;
     Ok(std::time::Duration::from_secs(u64::deserialize(d)?))
 }
 
 #[inline(always)]
-fn serialize_secs<S: serde::ser::Serializer>(d: &std::time::Duration, s: S) -> Result<S::Ok, S::Error> {
+fn serialize_secs<S: serde::ser::Serializer>(
+    d: &std::time::Duration,
+    s: S,
+) -> Result<S::Ok, S::Error> {
     s.serialize_u64(std::time::Duration::as_secs(d))
 }
 
@@ -578,7 +619,7 @@ pub struct AutoAuthPollingResponse {
     request_id: State,
     #[serde(serialize_with = "serialize_secs")]
     #[serde(deserialize_with = "deserialize_secs")]
-    interval: std::time::Duration
+    interval: std::time::Duration,
 }
 
 /// The authorization response that must be appended to the
@@ -610,10 +651,9 @@ pub struct AuthorizationResponse {
     /// authorization server.
     ///
     /// [oauth2-iss]: https://www.ietf.org/archive/id/draft-ietf-oauth-iss-auth-resp-02.html
-    pub iss: Url
+    pub iss: Url,
 }
 
-
 /// A special grant request that is used in the AutoAuth ceremony.
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
 pub struct AutoAuthCodeGrant {
@@ -636,7 +676,7 @@ pub struct AutoAuthCodeGrant {
     callback_url: Url,
     /// The user's URL. Will be used to confirm the authorization
     /// endpoint's authority.
-    me: Url
+    me: Url,
 }
 
 /// A grant request that continues the IndieAuth ceremony.
@@ -655,7 +695,7 @@ pub enum GrantRequest {
         redirect_uri: Url,
         /// The PKCE code verifier that was used to create the code
         /// challenge.
-        code_verifier: PKCEVerifier
+        code_verifier: PKCEVerifier,
     },
     /// Use a refresh token to get a fresh access token and a new
     /// matching refresh token.
@@ -670,8 +710,8 @@ pub enum GrantRequest {
         ///
         /// This cannot be used to gain new scopes -- you need to
         /// start over if you need new scopes from the user.
-        scope: Option<Scopes>
-    }
+        scope: Option<Scopes>,
+    },
 }
 
 /// Token type, as described in [RFC6749][].
@@ -685,7 +725,7 @@ pub enum TokenType {
     /// IndieAuth uses.
     ///
     /// [RFC6750]: https://www.rfc-editor.org/rfc/rfc6750
-    Bearer
+    Bearer,
 }
 
 /// The response to a successful [`GrantRequest`].
@@ -722,14 +762,14 @@ pub enum GrantResponse {
         profile: Option<Profile>,
         /// The refresh token, if it was issued.
         #[serde(skip_serializing_if = "Option::is_none")]
-        refresh_token: Option<String>
+        refresh_token: Option<String>,
     },
     /// A profile URL response, that only contains the profile URL and
     /// the profile, if it was requested.
     ///
     /// This is suitable for confirming the identity of the user, but
     /// no more than that.
-    ProfileUrl(ProfileUrl)
+    ProfileUrl(ProfileUrl),
 }
 
 /// The contents of a profile URL response.
@@ -739,7 +779,7 @@ pub struct ProfileUrl {
     pub me: Url,
     /// The user's profile information, if it was requested.
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub profile: Option<Profile>
+    pub profile: Option<Profile>,
 }
 
 #[cfg(feature = "axum")]
@@ -747,12 +787,15 @@ impl axum_core::response::IntoResponse for GrantResponse {
     fn into_response(self) -> axum_core::response::Response {
         use http::StatusCode;
 
-        (StatusCode::OK,
-         [("Content-Type", "application/json"),
-          ("Cache-Control", "no-store"),
-          ("Pragma", "no-cache")
-         ],
-         serde_json::to_vec(&self).unwrap())
+        (
+            StatusCode::OK,
+            [
+                ("Content-Type", "application/json"),
+                ("Cache-Control", "no-store"),
+                ("Pragma", "no-cache"),
+            ],
+            serde_json::to_vec(&self).unwrap(),
+        )
             .into_response()
     }
 }
@@ -766,7 +809,7 @@ impl axum_core::response::IntoResponse for GrantResponse {
 pub enum RequestMaybeAuthorizationEndpoint {
     Authorization(AuthorizationRequest),
     Grant(GrantRequest),
-    AutoAuth(AutoAuthCodeGrant)
+    AutoAuth(AutoAuthCodeGrant),
 }
 
 /// A token introspection request that can be handled by the token
@@ -778,7 +821,7 @@ pub enum RequestMaybeAuthorizationEndpoint {
 #[derive(Debug, Serialize, Deserialize)]
 pub struct TokenIntrospectionRequest {
     /// The token for which data was requested.
-    pub token: String
+    pub token: String,
 }
 
 /// Data for a token that will be returned by the introspection
@@ -800,7 +843,7 @@ pub struct TokenData {
     /// The issue date, represented in the same format as the
     /// [`exp`][TokenData::exp] field.
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub iat: Option<u64>
+    pub iat: Option<u64>,
 }
 
 impl TokenData {
@@ -809,24 +852,25 @@ impl TokenData {
         use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
         self.exp
-            .map(|exp| SystemTime::now()
-                 .duration_since(UNIX_EPOCH)
-                 .unwrap_or(Duration::ZERO)
-                 .as_secs() >= exp)
+            .map(|exp| {
+                SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .unwrap_or(Duration::ZERO)
+                    .as_secs()
+                    >= exp
+            })
             .unwrap_or_default()
     }
 
     /// Return a timestamp at which the token is not considered valid anymore.
     pub fn expires_at(&self) -> Option<std::time::SystemTime> {
-        self.exp.map(|time| {
-            std::time::UNIX_EPOCH + std::time::Duration::from_secs(time)
-        })
+        self.exp
+            .map(|time| std::time::UNIX_EPOCH + std::time::Duration::from_secs(time))
     }
 
     /// Return a timestamp describing when the token was issued.
     pub fn issued_at(&self) -> Option<std::time::SystemTime> {
-        self.iat.map(|time| {
-            std::time::UNIX_EPOCH + std::time::Duration::from_secs(time)
-        })
+        self.iat
+            .map(|time| std::time::UNIX_EPOCH + std::time::Duration::from_secs(time))
     }
 
     /// Check if a certain scope is allowed for this token.
@@ -849,18 +893,24 @@ pub struct TokenIntrospectionResponse {
     active: bool,
     #[serde(flatten)]
     #[serde(skip_serializing_if = "Option::is_none")]
-    data: Option<TokenData>
+    data: Option<TokenData>,
 }
 // These wrappers and impls should take care of making use of this
 // type as painless as possible.
 impl TokenIntrospectionResponse {
     /// Indicate that this token is not valid.
     pub fn inactive() -> Self {
-        Self { active: false, data: None }
+        Self {
+            active: false,
+            data: None,
+        }
     }
 
     /// Indicate that this token is valid, and provide data about it.
     pub fn active(data: TokenData) -> Self {
-        Self { active: true, data: Some(data) }
+        Self {
+            active: true,
+            data: Some(data),
+        }
     }
 
     /// Check if the endpoint reports this token as valid.
     pub fn is_active(&self) -> bool {
@@ -870,7 +920,7 @@ impl TokenIntrospectionResponse {
     /// Get data contained in the response, if the token is valid.
     pub fn data(&self) -> Option<&TokenData> {
         if !self.active {
-            return None
+            return None;
         }
         self.data.as_ref()
     }
@@ -882,7 +932,10 @@ impl Default for TokenIntrospectionResponse {
     }
 }
 impl From<Option<TokenData>> for TokenIntrospectionResponse {
     fn from(data: Option<TokenData>) -> Self {
-        Self { active: data.is_some(), data }
+        Self {
+            active: data.is_some(),
+            data,
+        }
     }
 }
 impl From<TokenIntrospectionResponse> for Option<TokenData> {
@@ -896,9 +949,11 @@ impl axum_core::response::IntoResponse for TokenIntrospectionResponse {
     fn into_response(self) -> axum_core::response::Response {
         use http::StatusCode;
 
-        (StatusCode::OK,
-         [("Content-Type", "application/json")],
-         serde_json::to_vec(&self).unwrap())
+        (
+            StatusCode::OK,
+            [("Content-Type", "application/json")],
+            serde_json::to_vec(&self).unwrap(),
+        )
             .into_response()
     }
 }
@@ -908,7 +963,7 @@ impl axum_core::response::IntoResponse for TokenIntrospectionResponse {
 #[derive(Debug, Serialize, Deserialize)]
 pub struct TokenRevocationRequest {
     /// The token that needs to be revoked in case it is valid.
-    pub token: String
+    pub token: String,
 }
 
 /// Types of errors that a resource server (IndieAuth consumer) can
@@ -969,7 +1024,6 @@ pub enum ErrorKind {
     /// AutoAuth/OAuth2 Device Flow: Access was denied by the
     /// authorization endpoint.
     AccessDenied,
-
 }
 // TODO consider relying on serde_variant for these conversions
 impl AsRef<str> for ErrorKind {
@@ -1005,13 +1059,15 @@ pub struct Error {
     pub msg: Option<String>,
     /// An URL to documentation describing what went wrong and how to
     /// fix it.
-    pub error_uri: Option<url::Url>
+    pub error_uri: Option<url::Url>,
 }
 
 impl From<ErrorKind> for Error {
     fn from(kind: ErrorKind) -> Error {
         Error {
-            kind, msg: None, error_uri: None
+            kind,
+            msg: None,
+            error_uri: None,
         }
     }
 }
@@ -1037,9 +1093,11 @@ impl axum_core::response::IntoResponse for self::Error {
     fn into_response(self) -> axum_core::response::Response {
         use http::StatusCode;
 
-        (StatusCode::BAD_REQUEST,
-         [("Content-Type", "application/json")],
-         serde_json::to_vec(&self).unwrap())
+        (
+            StatusCode::BAD_REQUEST,
+            [("Content-Type", "application/json")],
+            serde_json::to_vec(&self).unwrap(),
+        )
             .into_response()
     }
 }
@@ -1052,17 +1110,23 @@ mod tests {
     fn test_serialize_deserialize_grant_request() {
         let authorization_code: GrantRequest = GrantRequest::AuthorizationCode {
             client_id: "https://kittybox.fireburn.ru/".parse().unwrap(),
-            redirect_uri: "https://kittybox.fireburn.ru/.kittybox/login/redirect".parse().unwrap(),
+            redirect_uri: "https://kittybox.fireburn.ru/.kittybox/login/redirect"
+                .parse()
+                .unwrap(),
             code_verifier: PKCEVerifier("helloworld".to_string()),
-            code: "hithere".to_owned()
+            code: "hithere".to_owned(),
         };
         let serialized = serde_urlencoded::to_string([
             ("grant_type", "authorization_code"),
             ("code", "hithere"),
             ("client_id", "https://kittybox.fireburn.ru/"),
-            ("redirect_uri", "https://kittybox.fireburn.ru/.kittybox/login/redirect"),
+            (
+                "redirect_uri",
+                "https://kittybox.fireburn.ru/.kittybox/login/redirect",
+            ),
             ("code_verifier", "helloworld"),
-        ]).unwrap();
+        ])
+        .unwrap();
 
         let deserialized = serde_urlencoded::from_str(&serialized).unwrap();
 
diff --git a/indieauth/src/pkce.rs b/indieauth/src/pkce.rs
index 8dcf9b1..6233016 100644
--- a/indieauth/src/pkce.rs
+++ b/indieauth/src/pkce.rs
@@ -1,6 +1,6 @@
-use serde::{Serialize, Deserialize};
-use sha2::{Sha256, Digest};
 use data_encoding::BASE64URL;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
 
 /// Methods to use for PKCE challenges.
 #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize, Default)]
@@ -10,7 +10,7 @@ pub enum PKCEMethod {
     S256,
     /// Plain string by itself. Please don't use this.
     #[serde(rename = "snake_case")]
-    Plain
+    Plain,
 }
 
 impl PKCEMethod {
@@ -18,7 +18,7 @@ impl PKCEMethod {
     pub fn as_str(&self) -> &'static str {
         match self {
             PKCEMethod::S256 => "S256",
-            PKCEMethod::Plain => "plain"
+            PKCEMethod::Plain => "plain",
         }
     }
 }
@@ -57,7 +57,7 @@ impl PKCEVerifier {
     /// Generate a new PKCE string of 128 bytes in length, using
     /// the provided random number generator.
     pub fn from_rng(rng: &mut (impl rand::CryptoRng + rand::Rng)) -> Self {
-        use rand::{Rng, distributions::Alphanumeric};
+        use rand::{distributions::Alphanumeric, Rng};
 
         let bytes = rng
             .sample_iter(&Alphanumeric)
@@ -65,7 +65,6 @@ impl PKCEVerifier {
             .collect::<Vec<u8>>();
         Self(String::from_utf8(bytes).unwrap())
     }
-
 }
 
 /// A PKCE challenge as described in [RFC7636].
@@ -75,7 +74,7 @@ impl PKCEVerifier {
 pub struct PKCEChallenge {
     code_challenge: String,
     #[serde(rename = "code_challenge_method")]
-    method: PKCEMethod
+    method: PKCEMethod,
 }
 
 impl PKCEChallenge {
@@ -92,10 +91,10 @@ impl PKCEChallenge {
                     challenge.retain(|c| c != '=');
 
                     challenge
-                },
+                }
                 PKCEMethod::Plain => code_verifier.to_string(),
             },
-            method
+            method,
         }
     }
 
@@ -130,17 +129,21 @@ impl PKCEChallenge {
 
 #[cfg(test)]
 mod tests {
-    use super::{PKCEMethod, PKCEVerifier, PKCEChallenge};
+    use super::{PKCEChallenge, PKCEMethod, PKCEVerifier};
 
     #[test]
     /// A snapshot test generated using [Aaron Parecki's PKCE
     /// tools](https://example-app.com/pkce) that checks for a
     /// conforming challenge.
     fn test_pkce_challenge_verification() {
-        let verifier = PKCEVerifier("ec03310e4e90f7bc988af05384060c3c1afeae4bb4d0f648c5c06b63".to_owned());
+        let verifier =
+            PKCEVerifier("ec03310e4e90f7bc988af05384060c3c1afeae4bb4d0f648c5c06b63".to_owned());
 
         let challenge = PKCEChallenge::new(&verifier, PKCEMethod::S256);
 
-        assert_eq!(challenge.as_str(), "aB8OG20Rh8UoQ9gFhI0YvPkx4dDW2MBspBKGXL6j6Wg");
+        assert_eq!(
+            challenge.as_str(),
+            "aB8OG20Rh8UoQ9gFhI0YvPkx4dDW2MBspBKGXL6j6Wg"
+        );
     }
 }
diff --git a/indieauth/src/scopes.rs b/indieauth/src/scopes.rs
index 1157996..295b0c8 100644
--- a/indieauth/src/scopes.rs
+++ b/indieauth/src/scopes.rs
@@ -1,12 +1,8 @@
 use std::str::FromStr;
 
 use serde::{
-    Serialize, Serializer,
-    Deserialize,
-    de::{
-        Deserializer, Visitor,
-        Error as DeserializeError
-    }
+    de::{Deserializer, Error as DeserializeError, Visitor},
+    Deserialize, Serialize, Serializer,
 };
 
 /// Various scopes that can be requested through IndieAuth.
@@ -36,7 +32,7 @@ pub enum Scope {
     /// Allows to receive email in the profile information.
     Email,
     /// Custom scope not included above.
-    Custom(String)
+    Custom(String),
 }
 impl Scope {
     /// Create a custom scope from a string slice.
@@ -61,25 +57,25 @@ impl AsRef<str> for Scope {
             Channels => "channels",
             Profile => "profile",
             Email => "email",
-            Custom(s) => s.as_ref()
+            Custom(s) => s.as_ref(),
         }
     }
 }
 impl From<&str> for Scope {
     fn from(scope: &str) -> Self {
         match scope {
-            "create"   => Scope::Create,
-            "update"   => Scope::Update,
-            "delete"   => Scope::Delete,
-            "media"    => Scope::Media,
-            "read"     => Scope::Read,
-            "follow"   => Scope::Follow,
-            "mute"     => Scope::Mute,
-            "block"    => Scope::Block,
+            "create" => Scope::Create,
+            "update" => Scope::Update,
+            "delete" => Scope::Delete,
+            "media" => Scope::Media,
+            "read" => Scope::Read,
+            "follow" => Scope::Follow,
+            "mute" => Scope::Mute,
+            "block" => Scope::Block,
             "channels" => Scope::Channels,
-            "profile"  => Scope::Profile,
-            "email"    => Scope::Email,
-            other => Scope::custom(other)
+            "profile" => Scope::Profile,
+            "email" => Scope::Email,
+            other => Scope::custom(other),
         }
     }
 }
@@ -106,7 +102,8 @@ impl Scopes {
     }
 
     /// Ensure all of the requested scopes are in the list.
     pub fn has_all(&self, scopes: &[Scope]) -> bool {
-        scopes.iter()
+        scopes
+            .iter()
             .map(|s1| self.iter().any(|s2| s1 == s2))
             .all(|s| s)
     }
@@ -123,8 +120,7 @@ impl AsRef<[Scope]> for Scopes {
 impl std::fmt::Display for Scopes {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let mut iter = self.0.iter()
-            .peekable();
+        let mut iter = self.0.iter().peekable();
         while let Some(scope) = iter.next() {
             f.write_str(scope.as_ref())?;
             if iter.peek().is_some() {
@@ -139,15 +135,18 @@ impl FromStr for Scopes {
     type Err = std::convert::Infallible;
 
     fn from_str(value: &str) -> Result<Self, Self::Err> {
-        Ok(Self(value.split_ascii_whitespace()
+        Ok(Self(
+            value
+                .split_ascii_whitespace()
                 .map(Scope::from)
-                .collect::<Vec<Scope>>()))
+                .collect::<Vec<Scope>>(),
+        ))
     }
 }
 
 impl Serialize for Scopes {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
-        S: Serializer
+        S: Serializer,
     {
         serializer.serialize_str(&self.to_string())
     }
@@ -163,16 +162,15 @@ impl<'de> Visitor<'de> for ScopeVisitor {
 
     fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
     where
-        E: DeserializeError
+        E: DeserializeError,
     {
         Ok(Scopes::from_str(value).unwrap())
     }
 }
 
 impl<'de> Deserialize<'de> for Scopes {
-
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
-        D: Deserializer<'de>
+        D: Deserializer<'de>,
     {
         deserializer.deserialize_str(ScopeVisitor)
     }
@@ -185,29 +183,31 @@ mod tests {
     #[test]
     fn test_serde_vec_scope() {
         let scopes = vec![
-            Scope::Create, Scope::Update, Scope::Delete,
+            Scope::Create,
+            Scope::Update,
+            Scope::Delete,
             Scope::Media,
-            Scope::custom("kittybox_internal_access")
+            Scope::custom("kittybox_internal_access"),
         ];
-        let scope_serialized = serde_json::to_value(
-            Scopes::new(scopes.clone())
-        ).unwrap();
+        let scope_serialized = serde_json::to_value(Scopes::new(scopes.clone())).unwrap();
         let scope_str = scope_serialized.as_str().unwrap();
-        assert_eq!(scope_str, "create update delete media kittybox_internal_access");
+        assert_eq!(
+            scope_str,
+            "create update delete media kittybox_internal_access"
+        );
 
-        assert!(serde_json::from_value::<Scopes>(scope_serialized).unwrap().has_all(&scopes))
+        assert!(serde_json::from_value::<Scopes>(scope_serialized)
+            .unwrap()
+            .has_all(&scopes))
     }
 
     #[test]
     fn test_scope_has_all() {
-        let scopes = Scopes(vec![
-            Scope::Create, Scope::Update, Scope::custom("draft")
-        ]);
+        let scopes = Scopes(vec![Scope::Create, Scope::Update, Scope::custom("draft")]);
 
         assert!(scopes.has_all(&[Scope::Create, Scope::custom("draft")]));
 
         assert!(!scopes.has_all(&[Scope::Read, Scope::custom("kittybox_internal_access")]));
     }
-
 }
diff --git a/src/bin/kittybox-check-webmention.rs b/src/bin/kittybox-check-webmention.rs
index b43980e..a9e5957 100644
--- a/src/bin/kittybox-check-webmention.rs
+++ b/src/bin/kittybox-check-webmention.rs
@@ -7,7 +7,7 @@ enum Error {
     #[error("reqwest error: {0}")]
     Http(#[from] reqwest::Error),
     #[error("webmention check error: {0}")]
-    Webmention(#[from] WebmentionError)
+    Webmention(#[from] WebmentionError),
 }
 
 #[derive(Parser, Debug)]
@@ -21,7 +21,7 @@ struct Args {
     #[clap(value_parser)]
     url: url::Url,
     #[clap(value_parser)]
-    link: url::Url
+    link: url::Url,
 }
 
 #[tokio::main]
@@ -30,10 +30,11 @@ async fn main() -> Result<(), Error> {
 
     let http: reqwest::Client = {
         #[allow(unused_mut)]
-        let mut builder = reqwest::Client::builder()
-            .user_agent(concat!(
-                env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")
-            ));
+        let mut builder = reqwest::Client::builder().user_agent(concat!(
+            env!("CARGO_PKG_NAME"),
+            "/",
+            env!("CARGO_PKG_VERSION")
+        ));
 
         builder.build().unwrap()
     };
diff --git a/src/bin/kittybox-indieauth-helper.rs b/src/bin/kittybox-indieauth-helper.rs
index f4ad679..0725aac 100644
--- a/src/bin/kittybox-indieauth-helper.rs
+++ b/src/bin/kittybox-indieauth-helper.rs
@@ -1,13 +1,11 @@
+use clap::Parser;
 use futures::{FutureExt, TryFutureExt};
 use kittybox_indieauth::{
-    AuthorizationRequest, PKCEVerifier,
-    PKCEChallenge, PKCEMethod, GrantRequest, Scope,
-    AuthorizationResponse, GrantResponse,
-    Error as IndieauthError
+    AuthorizationRequest, AuthorizationResponse, Error as IndieauthError, GrantRequest,
+    GrantResponse, PKCEChallenge, PKCEMethod, PKCEVerifier, Scope,
 };
-use clap::Parser;
-use tokio::net::TcpListener;
 use std::{borrow::Cow, future::IntoFuture, io::Write};
+use tokio::net::TcpListener;
 
 const DEFAULT_CLIENT_ID: &str = "https://kittybox.fireburn.ru/indieauth-helper.html";
 const DEFAULT_REDIRECT_URI: &str = "http://localhost:60000/callback";
@@ -21,7 +19,7 @@ enum Error {
     #[error("url parsing error: {0}")]
     UrlParse(#[from] url::ParseError),
     #[error("indieauth flow error: {0}")]
-    IndieAuth(#[from] IndieauthError)
+    IndieAuth(#[from] IndieauthError),
 }
 
 #[derive(Parser, Debug)]
@@ -46,20 +44,20 @@ struct Args {
     client_id: url::Url,
     /// Redirect URI to declare. Note: This will break the flow, use only for testing UI.
     #[clap(long, value_parser)]
-    redirect_uri: Option<url::Url>
+    redirect_uri: Option<url::Url>,
 }
 
-
 #[tokio::main]
 async fn main() -> Result<(), Error> {
     let args = Args::parse();
 
     let http: reqwest::Client = {
         #[allow(unused_mut)]
-        let mut builder = reqwest::Client::builder()
-            .user_agent(concat!(
-                env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")
-            ));
+        let mut builder = reqwest::Client::builder().user_agent(concat!(
+            env!("CARGO_PKG_NAME"),
+            "/",
+            env!("CARGO_PKG_VERSION")
+        ));
         // This only works on debug builds. Don't get any funny thoughts.
         #[cfg(debug_assertions)]
         if std::env::var("KITTYBOX_DANGER_INSECURE_TLS")
@@ -71,12 +69,14 @@ async fn main() -> Result<(), Error> {
         builder.build().unwrap()
     };
 
-    let redirect_uri: url::Url = args.redirect_uri
+    let redirect_uri: url::Url = args
+        .redirect_uri
         .clone()
         .unwrap_or_else(|| DEFAULT_REDIRECT_URI.parse().unwrap());
 
    eprintln!("Checking .well-known for metadata...");
-    let metadata = http.get(args.me.join("/.well-known/oauth-authorization-server")?)
+    let metadata = http
+        .get(args.me.join("/.well-known/oauth-authorization-server")?)
         .header("Accept", "application/json")
         .send()
         .await?
@@ -92,7 +92,7 @@ async fn main() -> Result<(), Error> {
         state: kittybox_indieauth::State::new(),
         code_challenge: PKCEChallenge::new(&verifier, PKCEMethod::default()),
         scope: Some(kittybox_indieauth::Scopes::new(args.scope)),
-        me: Some(args.me)
+        me: Some(args.me),
     };
 
     let indieauth_url = {
@@ -103,12 +103,18 @@ async fn main() -> Result<(), Error> {
         url
     };
 
-    eprintln!("Please visit the following URL in your browser:\n\n    {}\n", indieauth_url.as_str());
+    eprintln!(
+        "Please visit the following URL in your browser:\n\n    {}\n",
+        indieauth_url.as_str()
+    );
 
     #[cfg(target_os = "linux")]
-    match std::process::Command::new("xdg-open").arg(indieauth_url.as_str()).spawn() {
+    match std::process::Command::new("xdg-open")
+        .arg(indieauth_url.as_str())
+        .spawn()
+    {
         Ok(child) => drop(child),
-        Err(err) => eprintln!("Couldn't xdg-open: {}", err)
+        Err(err) => eprintln!("Couldn't xdg-open: {}", err),
     }
 
     if args.redirect_uri.is_some() {
@@ -123,32 +129,38 @@ async fn main() -> Result<(), Error> {
 
     let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));
 
-    let router = axum::Router::new()
-        .route("/callback", axum::routing::get(
+    let router = axum::Router::new().route(
+        "/callback",
+        axum::routing::get(
             move |Query(response): Query<AuthorizationResponse>| async move {
                 if let Some(tx) = tx.lock_owned().await.take() {
                     tx.send(response).unwrap();
 
-                    (axum::http::StatusCode::OK,
-                     [("Content-Type", "text/plain")],
-                     "Thank you! This window can now be closed.")
+                    (
+                        axum::http::StatusCode::OK,
+                        [("Content-Type", "text/plain")],
+                        "Thank you! This window can now be closed.",
+                    )
                         .into_response()
                 } else {
-                    (axum::http::StatusCode::BAD_REQUEST,
-                     [("Content-Type", "text/plain")],
-                     "Oops. The callback was already received. Did you click twice?")
+                    (
+                        axum::http::StatusCode::BAD_REQUEST,
+                        [("Content-Type", "text/plain")],
+                        "Oops. The callback was already received. Did you click twice?",
+                    )
                         .into_response()
                 }
-            }
-        ));
+            },
+        ),
+    );
 
-    use std::net::{SocketAddr, IpAddr, Ipv4Addr};
+    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
     let server = axum::serve(
-        TcpListener::bind(
-            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST),60000)
-        ).await.unwrap(),
-        router.into_make_service()
+        TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 60000))
+            .await
+            .unwrap(),
+        router.into_make_service(),
    );
 
     tokio::task::spawn(server.into_future())
@@ -175,12 +187,13 @@ async fn main() -> Result<(), Error> {
         #[cfg(not(debug_assertions))]
         std::process::exit(1);
     }
-    let response: Result<GrantResponse, IndieauthError> = http.post(metadata.token_endpoint)
+    let response: Result<GrantResponse, IndieauthError> = http
+        .post(metadata.token_endpoint)
         .form(&GrantRequest::AuthorizationCode {
             code: authorization_response.code,
            client_id: args.client_id,
             redirect_uri,
-            code_verifier: verifier
+            code_verifier: verifier,
         })
         .header("Accept", "application/json")
         .send()
@@ -201,9 +214,14 @@ async fn main() -> Result<(), Error> {
         refresh_token,
         scope,
         ..
-    } = response? {
-        eprintln!("Congratulations, {}, access token is ready! {}",
-                  profile.as_ref().and_then(|p| p.name.as_deref()).unwrap_or(me.as_str()),
+    } = response?
+    {
+        eprintln!(
+            "Congratulations, {}, access token is ready! {}",
+            profile
+                .as_ref()
+                .and_then(|p| p.name.as_deref())
+                .unwrap_or(me.as_str()),
             if let Some(exp) = expires_in {
                 Cow::Owned(format!("It expires in {exp} seconds."))
             } else {
diff --git a/src/bin/kittybox-mf2.rs b/src/bin/kittybox-mf2.rs
index 0cd89b4..b6f4999 100644
--- a/src/bin/kittybox-mf2.rs
+++ b/src/bin/kittybox-mf2.rs
@@ -37,8 +37,9 @@ async fn main() -> Result<(), Error> {
             .with_indent_lines(true)
             .with_verbose_exit(true),
         #[cfg(not(debug_assertions))]
-        tracing_subscriber::fmt::layer().json()
-            .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock()))
+        tracing_subscriber::fmt::layer()
+            .json()
+            .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())),
     );
 
     tracing_registry.init();
@@ -46,10 +47,11 @@ async fn main() -> Result<(), Error> {
 
     let http: reqwest::Client = {
         #[allow(unused_mut)]
-        let mut builder = reqwest::Client::builder()
-            .user_agent(concat!(
-                env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")
-            ));
+        let mut builder = reqwest::Client::builder().user_agent(concat!(
+            env!("CARGO_PKG_NAME"),
+            "/",
+            env!("CARGO_PKG_VERSION")
+        ));
 
         builder.build().unwrap()
     };
diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs
index b9f27b2..5c93beb 100644
--- a/src/database/file/mod.rs
+++ b/src/database/file/mod.rs
@@ -1,6 +1,6 @@
 //#![warn(clippy::unwrap_used)]
-use crate::database::{ErrorKind, Result, settings, Storage, StorageError};
-use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion};
+use crate::database::{settings, ErrorKind, Result, Storage, StorageError};
+use crate::micropub::{MicropubPropertyDeletion, MicropubUpdate};
 use futures::{stream, StreamExt, TryStreamExt};
 use kittybox_util::MentionType;
 use serde_json::json;
@@ -247,7 +247,9 @@ async fn hydrate_author<S: Storage>(
 impl Storage for FileStorage {
     async fn new(url: &'_ url::Url) -> Result<Self> {
         // TODO: sanity check
-        Ok(Self { root_dir: PathBuf::from(url.path()) })
+        Ok(Self {
+            root_dir: PathBuf::from(url.path()),
+        })
     }
     #[tracing::instrument(skip(self))]
     async fn categories(&self, url: &str) -> Result<Vec<String>> {
@@ -259,7 +261,7 @@ impl Storage for FileStorage {
         // perform well.
         Err(std::io::Error::new(
             std::io::ErrorKind::Unsupported,
-            "?q=category queries are not implemented due to resource constraints"
+            "?q=category queries are not implemented due to resource constraints",
         ))?
     }
 
@@ -340,7 +342,10 @@ impl Storage for FileStorage {
         file.sync_all().await?;
         drop(file);
         tokio::fs::rename(&tempfile, &path).await?;
-        tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?;
+        tokio::fs::File::open(path.parent().unwrap())
+            .await?
+            .sync_all()
+            .await?;
 
         if let Some(urls) = post["properties"]["url"].as_array() {
             for url in urls.iter().map(|i| i.as_str().unwrap()) {
@@ -350,8 +355,8 @@ impl Storage for FileStorage {
                         "{}{}",
                         url.host_str().unwrap(),
                         url.port()
-                           .map(|port| format!(":{}", port))
-                           .unwrap_or_default()
+                            .map(|port| format!(":{}", port))
+                            .unwrap_or_default()
                     )
                 };
                 if url != key && url_domain == user.authority() {
@@ -410,26 +415,24 @@ impl Storage for FileStorage {
                 .create(false)
                 .open(&path)
                 .await
-            {
-                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
-                    Vec::default()
-                }
-                Err(err) => {
-                    // Propagate the error upwards
-                    return Err(err.into());
-                }
-                Ok(mut file) => {
-                    let mut content = String::new();
-                    file.read_to_string(&mut content).await?;
-                    drop(file);
-
-                    if !content.is_empty() {
-                        serde_json::from_str(&content)?
-                    } else {
-                        Vec::default()
-                    }
-                }
-            }
+            {
+                Err(err) if err.kind() == std::io::ErrorKind::NotFound => Vec::default(),
+                Err(err) => {
+                    // Propagate the error upwards
+                    return Err(err.into());
+                }
+                Ok(mut file) => {
+                    let mut content = String::new();
+                    file.read_to_string(&mut content).await?;
+                    drop(file);
+
+                    if !content.is_empty() {
+                        serde_json::from_str(&content)?
+                    } else {
+                        Vec::default()
+                    }
+                }
+            }
         };
 
         channels.push(super::MicropubChannel {
@@ -444,7 +447,10 @@ impl Storage for FileStorage {
         tempfile.sync_all().await?;
         drop(tempfile);
         tokio::fs::rename(tempfilename, &path).await?;
-        tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?;
+        tokio::fs::File::open(path.parent().unwrap())
+            .await?
+            .sync_all()
+            .await?;
     }
     Ok(())
 }
@@ -476,7 +482,10 @@ impl Storage for FileStorage {
             temp.sync_all().await?;
             drop(temp);
             tokio::fs::rename(tempfilename, &path).await?;
-            tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?;
+            tokio::fs::File::open(path.parent().unwrap())
+                .await?
+                .sync_all()
+                .await?;
 
             (json, new_json)
         };
@@ -486,7 +495,9 @@ impl Storage for FileStorage {
 
     #[tracing::instrument(skip(self, f), fields(f = std::any::type_name::<F>()))]
     async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
-        &self, url: &str, f: F
+        &self,
+        url: &str,
+        f: F,
     ) -> Result<(serde_json::Value, serde_json::Value)> {
         todo!("update_with is not yet implemented due to special requirements of the file backend")
     }
@@ -526,25 +537,25 @@ impl Storage for FileStorage {
         url: &'_ str,
         cursor: Option<&'_ str>,
         limit: usize,
-        user: Option<&url::Url>
+        user: Option<&url::Url>,
     ) -> Result<Option<(serde_json::Value, Option<String>)>> {
         #[allow(deprecated)]
-        Ok(self.read_feed_with_limit(
-            url,
-            cursor,
-            limit,
-            user
-        ).await?
+        Ok(self
+            .read_feed_with_limit(url, cursor, limit, user)
+            .await?
            .map(|feed| {
-                tracing::debug!("Feed: {:#}", serde_json::Value::Array(
-                    feed["children"]
-                        .as_array()
-                        .map(|v| v.as_slice())
-                        .unwrap_or_default()
-                        .iter()
-                        .map(|mf2| mf2["properties"]["uid"][0].clone())
-                        .collect::<Vec<_>>()
-                ));
+                tracing::debug!(
+                    "Feed: {:#}",
+                    serde_json::Value::Array(
+                        feed["children"]
+                            .as_array()
+                            .map(|v| v.as_slice())
+                            .unwrap_or_default()
+                            .iter()
+                            .map(|mf2| mf2["properties"]["uid"][0].clone())
+                            .collect::<Vec<_>>()
+                    )
+                );
                 let cursor: Option<String> = feed["children"]
                     .as_array()
                     .map(|v| v.as_slice())
@@ -553,8 +564,7 @@ impl Storage for FileStorage {
                     .map(|v| v["properties"]["uid"][0].as_str().unwrap().to_owned());
                 tracing::debug!("Extracted the cursor: {:?}", cursor);
                 (feed, cursor)
-            })
-        )
+            }))
     }
 
     #[tracing::instrument(skip(self))]
@@ -574,9 +584,12 @@ impl Storage for FileStorage {
         let children: Vec<serde_json::Value> = match feed["children"].take() {
             serde_json::Value::Array(children) => children,
             // We've already checked it's an array
-            _ => unreachable!()
+            _ => unreachable!(),
         };
-        tracing::debug!("Full children array: {:#}", serde_json::Value::Array(children.clone()));
+        tracing::debug!(
+            "Full children array: {:#}",
+            serde_json::Value::Array(children.clone())
+        );
         let mut posts_iter = children
             .into_iter()
             .map(|s: serde_json::Value| s.as_str().unwrap().to_string());
@@ -589,7 +602,7 @@ impl Storage for FileStorage {
         // incredibly long feeds.
         if let Some(after) = after {
             tokio::task::block_in_place(|| {
-                for s in posts_iter.by_ref() { 
+                for s in posts_iter.by_ref() {
                     if s == after {
                         break;
                     }
@@ -655,12 +668,19 @@ impl Storage for FileStorage {
         let settings: HashMap<&str, serde_json::Value> = serde_json::from_str(&content)?;
         match settings.get(S::ID) {
             Some(value) => Ok(serde_json::from_value::<S>(value.clone())?),
-            None => Err(StorageError::from_static(ErrorKind::Backend, "Setting not set"))
+            None => Err(StorageError::from_static(
+                ErrorKind::Backend,
+                "Setting not set",
+            )),
         }
     }
 
     #[tracing::instrument(skip(self))]
-    async fn set_setting<S: settings::Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> {
+    async fn set_setting<S: settings::Setting>(
+        &self,
+        user: &url::Url,
+        value: S::Data,
+    ) -> Result<()> {
         let mut path = relative_path::RelativePathBuf::new();
         path.push(user.authority());
         path.push("settings");
@@ -704,20 +724,28 @@ impl Storage for FileStorage {
         tempfile.sync_all().await?;
         drop(tempfile);
         tokio::fs::rename(temppath, &path).await?;
-        tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?;
+        tokio::fs::File::open(path.parent().unwrap())
+            .await?
+            .sync_all()
+            .await?;
         Ok(())
     }
 
     #[tracing::instrument(skip(self))]
-    async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> {
+    async fn add_or_update_webmention(
+        &self,
+        target: &str,
+        mention_type: MentionType,
+        mention: serde_json::Value,
+    ) -> Result<()> {
         let path = url_to_path(&self.root_dir, target);
         let tempfilename = path.with_extension("tmp");
 
         let mut temp = OpenOptions::new()
-          .write(true)
-          .create_new(true)
-          .open(&tempfilename)
-          .await?;
+            .write(true)
+            .create_new(true)
+            .open(&tempfilename)
+            .await?;
         let mut file = OpenOptions::new().read(true).open(&path).await?;
 
         let mut post: serde_json::Value = {
@@ -752,13 +780,20 @@ impl Storage for FileStorage {
         temp.sync_all().await?;
         drop(temp);
         tokio::fs::rename(tempfilename, &path).await?;
-        tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?;
+        tokio::fs::File::open(path.parent().unwrap())
+            .await?
+            .sync_all()
+            .await?;
 
         Ok(())
     }
 
-    async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> {
+    async fn all_posts<'this>(
+        &'this self,
+        user: &url::Url,
+    ) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> {
         todo!();
-        #[allow(unreachable_code)] Ok(futures::stream::empty()) // for type inference
+        #[allow(unreachable_code)]
+        Ok(futures::stream::empty()) // for type inference
     }
 }
diff --git a/src/database/memory.rs b/src/database/memory.rs
index c2ceb85..75f04de 100644
--- a/src/database/memory.rs
+++ b/src/database/memory.rs
@@ -5,7 +5,7 @@ use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::sync::RwLock;
 
-use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError};
+use crate::database::{settings, ErrorKind, MicropubChannel, Result, Storage, StorageError};
 
 #[derive(Clone, Debug, Default)]
 /// A simple in-memory store for testing purposes.
@@ -90,9 +90,16 @@ impl Storage for MemoryStorage {
         Ok(())
     }
 
-    async fn update_post(&self, url: &'_ str, update: crate::micropub::MicropubUpdate) -> Result<()> {
+    async fn update_post(
+        &self,
+        url: &'_ str,
+        update: crate::micropub::MicropubUpdate,
+    ) -> Result<()> {
         let mut guard = self.mapping.write().await;
-        let mut post = guard.get_mut(url).ok_or(StorageError::from_static(ErrorKind::NotFound, "The specified post wasn't found in the database."))?;
+        let mut post = guard.get_mut(url).ok_or(StorageError::from_static(
+            ErrorKind::NotFound,
+            "The specified post wasn't found in the database.",
+        ))?;
 
         use crate::micropub::MicropubPropertyDeletion;
 
@@ -208,7 +215,7 @@ impl Storage for MemoryStorage {
         url: &'_ str,
         cursor: Option<&'_ str>,
         limit: usize,
-        user: Option<&url::Url>
+        user: Option<&url::Url>,
     ) -> Result<Option<(serde_json::Value, Option<String>)>> {
         todo!()
     }
@@ -224,25 +231,39 @@ impl Storage for MemoryStorage {
     }
 
     #[allow(unused_variables)]
-    async fn set_setting<S: settings::Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> {
+    async fn set_setting<S: settings::Setting>(
+        &self,
+        user: &url::Url,
+        value: S::Data,
+    ) -> Result<()> {
         todo!()
     }
 
     #[allow(unused_variables)]
-    async fn add_or_update_webmention(&self, target: &str, mention_type: kittybox_util::MentionType, mention: serde_json::Value) -> Result<()> {
+    async fn add_or_update_webmention(
+        &self,
+        target: &str,
+        mention_type: kittybox_util::MentionType,
+        mention: serde_json::Value,
+    ) -> Result<()> {
         todo!()
     }
 
     #[allow(unused_variables)]
     async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
-        &self, url: &str, f: F
+        &self,
+        url: &str,
+        f: F,
     ) -> Result<(serde_json::Value, serde_json::Value)> {
         todo!()
     }
 
-    async fn all_posts<'this>(&'this self, _user: &url::Url) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> {
+    async fn all_posts<'this>(
+        &'this self,
+        _user: &url::Url,
+    ) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> {
         todo!();
-        #[allow(unreachable_code)] Ok(futures::stream::pending())
+        #[allow(unreachable_code)]
+        Ok(futures::stream::pending())
     }
-
 }
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 4390ae7..de51c2c 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -177,7 +177,7 @@ impl StorageError {
         Self {
             msg: Cow::Borrowed(msg),
             source: None,
-            kind
+            kind,
         }
     }
     /// Create a StorageError using another arbitrary Error as a source.
@@ -219,27 +219,34 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync {
     fn post_exists(&self, url: &str) -> impl Future<Output = Result<bool>> + Send;
 
     /// Load a post from the database in MF2-JSON format, deserialized from JSON.
-    fn get_post(&self, url: &str) -> impl Future<Output = Result<Option<serde_json::Value>>> + Send;
+    fn get_post(&self, url: &str)
+        -> impl Future<Output = Result<Option<serde_json::Value>>> + Send;
 
     /// Save a post to the database as an MF2-JSON structure.
     ///
     /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined.
-    fn put_post(&self, post: &serde_json::Value, user: &url::Url) -> impl Future<Output = Result<()>> + Send;
+    fn put_post(
+        &self,
+        post: &serde_json::Value,
+        user: &url::Url,
+    ) -> impl Future<Output = Result<()>> + Send;
 
     /// Add post to feed. Some database implementations might have optimized ways to do this.
#[tracing::instrument(skip(self))] fn add_to_feed(&self, feed: &str, post: &str) -> impl Future<Output = Result<()>> + Send { tracing::debug!("Inserting {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"add": {"children": [post]}})).unwrap() + self.update_post( + feed, + serde_json::from_value(serde_json::json!({"add": {"children": [post]}})).unwrap(), ) } /// Remove post from feed. Some database implementations might have optimized ways to do this. #[tracing::instrument(skip(self))] fn remove_from_feed(&self, feed: &str, post: &str) -> impl Future<Output = Result<()>> + Send { tracing::debug!("Removing {} from {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"delete": {"children": [post]}})).unwrap() + self.update_post( + feed, + serde_json::from_value(serde_json::json!({"delete": {"children": [post]}})).unwrap(), ) } @@ -254,7 +261,11 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Default implementation calls [`Storage::update_with`] and uses /// [`update.apply`][MicropubUpdate::apply] to update the post. - fn update_post(&self, url: &str, update: MicropubUpdate) -> impl Future<Output = Result<()>> + Send { + fn update_post( + &self, + url: &str, + update: MicropubUpdate, + ) -> impl Future<Output = Result<()>> + Send { let fut = self.update_with(url, |post| { update.apply(post); }); @@ -274,12 +285,17 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Returns old post and the new post after editing. fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> impl Future<Output = Result<(serde_json::Value, serde_json::Value)>> + Send; /// Get a list of channels available for the user represented by /// the `user` domain to write to. - fn get_channels(&self, user: &url::Url) -> impl Future<Output = Result<Vec<MicropubChannel>>> + Send; + fn get_channels( + &self, + user: &url::Url, + ) -> impl Future<Output = Result<Vec<MicropubChannel>>> + Send; /// Fetch a feed at `url` and return an h-feed object containing /// `limit` posts after a post by url `after`, filtering the content @@ -329,7 +345,7 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> impl Future<Output = Result<Option<(serde_json::Value, Option<String>)>>> + Send; /// Deletes a post from the database irreversibly. Must be idempotent. @@ -339,7 +355,11 @@ fn get_setting<S: Setting>(&self, user: &url::Url) -> impl Future<Output = Result<S>> + Send; /// Commits a setting to the setting store. - fn set_setting<S: Setting>(&self, user: &url::Url, value: S::Data) -> impl Future<Output = Result<()>> + Send; + fn set_setting<S: Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> impl Future<Output = Result<()>> + Send; /// Add (or update) a webmention on a certain post. /// /// Besides, it may even allow for nice tricks like storing the /// webmentions separately and rehydrating them on feed reads.
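// The "add or update" contract above implies deduplication: a webmention
// resubmitted with the same `uid` should replace the stored copy rather
// than append a duplicate. A sketch of that merge step (`merge_mention`
// is a hypothetical helper, keyed on `uid` as the round-trip tests later
// in this diff exercise):
fn merge_mention(post: &mut serde_json::Value, mention: serde_json::Value) {
    let uid = mention["properties"]["uid"][0].clone();
    let props = &mut post["properties"];
    if !props["comment"].is_array() {
        props["comment"] = serde_json::json!([]);
    }
    let comments = props["comment"].as_array_mut().unwrap();
    match comments.iter_mut().find(|c| c["properties"]["uid"][0] == uid) {
        Some(existing) => *existing = mention, // same uid: replace in place
        None => comments.push(mention),        // new mention: append
    }
}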
- fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> impl Future<Output = Result<()>> + Send; + fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> impl Future<Output = Result<()>> + Send; /// Return a stream of all posts ever made by a certain user, in /// reverse-chronological order. - fn all_posts<'this>(&'this self, user: &url::Url) -> impl Future<Output = Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this>> + Send; + fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> impl Future<Output = Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this>> + Send; } #[cfg(test)] @@ -464,7 +492,8 @@ mod tests { "replace": { "content": ["Different test content"] } - })).unwrap(), + })) + .unwrap(), ) .await .unwrap(); @@ -511,7 +540,10 @@ mod tests { .put_post(&feed, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - let chans = backend.get_channels(&"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + let chans = backend + .get_channels(&"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); assert_eq!(chans.len(), 1); assert_eq!( chans[0], @@ -526,16 +558,16 @@ mod tests { backend .set_setting::<settings::SiteName>( &"https://fireburn.ru/".parse().unwrap(), - "Vika's Hideout".to_owned() + "Vika's Hideout".to_owned(), ) .await .unwrap(); assert_eq!( backend - .get_setting::<settings::SiteName>(&"https://fireburn.ru/".parse().unwrap()) - .await - .unwrap() - .as_ref(), + .get_setting::<settings::SiteName>(&"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap() + .as_ref(), "Vika's Hideout" ); } @@ -597,11 +629,9 @@ mod tests { async fn test_feed_pagination<Backend: Storage>(backend: Backend) { let posts = { - let mut posts = std::iter::from_fn( - || Some(gen_random_post("fireburn.ru")) - ) - .take(40) - .collect::<Vec<serde_json::Value>>(); + let mut posts = std::iter::from_fn(|| Some(gen_random_post("fireburn.ru"))) + .take(40) + .collect::<Vec<serde_json::Value>>(); // Reverse the array so it's in reverse-chronological order posts.reverse(); @@ -629,7 +659,10 @@ mod tests { .put_post(post, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - backend.add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()).await.unwrap(); + backend + .add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()) + .await + .unwrap(); } let limit: usize = 10; @@ -648,23 +681,16 @@ mod tests { .unwrap() .iter() .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10], + .collect::<Vec<_>>()[0..10], posts .iter() .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10] + .collect::<Vec<_>>()[0..10] ); tracing::debug!("Continuing with cursor: {:?}", cursor); let (result2, cursor2) = backend - .read_feed_with_cursor( - key, - cursor.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -676,12 +702,7 @@ mod tests { tracing::debug!("Continuing with cursor: {:?}", cursor); let (result3, cursor3) = backend - .read_feed_with_cursor( - key, - cursor2.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor2.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -693,12 +714,7 @@ mod tests { tracing::debug!("Continuing with cursor: {:?}", cursor); let (result4, _) = backend - .read_feed_with_cursor( - key, - cursor3.as_deref(), - limit, - None, - ) + 
.read_feed_with_cursor(key, cursor3.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -725,24 +741,43 @@ mod tests { async fn test_webmention_addition<Backend: Storage>(db: Backend) { let post = gen_random_post("fireburn.ru"); - db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); const TYPE: MentionType = MentionType::Reply; let target = post["properties"]["uid"][0].as_str().unwrap(); let mut reply = gen_random_mention("aaronparecki.com", TYPE, target); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(post, read_post); - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); + db.add_or_update_webmention(target, TYPE, reply.clone()) + .await + .unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post["properties"]["comment"][0], reply); - reply["properties"]["content"][0] = json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); + reply["properties"]["content"][0] = + json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + db.add_or_update_webmention(target, TYPE, reply.clone()) + .await + .unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post["properties"]["comment"][0], reply); } @@ -752,16 +787,20 @@ mod tests { let post = { let mut post = gen_random_post("fireburn.ru"); let urls = post["properties"]["url"].as_array_mut().unwrap(); - urls.push(serde_json::Value::String( - PERMALINK.to_owned() - )); + urls.push(serde_json::Value::String(PERMALINK.to_owned())); post }; - db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); for i in post["properties"]["url"].as_array().unwrap() { - let (read_post, _) = db.read_feed_with_cursor(i.as_str().unwrap(), None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(i.as_str().unwrap(), None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post, post); } } @@ -786,7 +825,7 @@ mod tests { async fn $func_name() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); let backend = super::super::FileStorage { - root_dir: tempdir.path().to_path_buf() + root_dir: tempdir.path().to_path_buf(), }; super::$func_name(backend).await } @@ -800,7 +839,7 @@ mod tests { #[tracing_test::traced_test] async fn $func_name( pool_opts: sqlx::postgres::PgPoolOptions, - connect_opts: sqlx::postgres::PgConnectOptions + connect_opts: sqlx::postgres::PgConnectOptions, ) -> Result<(), sqlx::Error> { let db = { //use sqlx::ConnectOptions; diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs index af19fea..ec67efa 100644 --- a/src/database/postgres/mod.rs +++ b/src/database/postgres/mod.rs @@ -5,7 +5,7 @@ use kittybox_util::{micropub::Channel as MicropubChannel, MentionType}; use sqlx::{ConnectOptions, Executor, PgPool}; use 
super::settings::Setting; -use super::{Storage, Result, StorageError, ErrorKind}; +use super::{ErrorKind, Result, Storage, StorageError}; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); @@ -14,7 +14,7 @@ impl From<sqlx::Error> for StorageError { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx error: {}", &value)), - Box::new(value) + Box::new(value), ) } } @@ -24,7 +24,7 @@ impl From<sqlx::migrate::MigrateError> for StorageError { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx migration error: {}", &value)), - Box::new(value) + Box::new(value), ) } } @@ -32,14 +32,15 @@ impl From<sqlx::migrate::MigrateError> for StorageError { /// Micropub storage that uses a PostgreSQL database. #[derive(Debug, Clone)] pub struct PostgresStorage { - db: PgPool + db: PgPool, } impl PostgresStorage { /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], /// running appropriate migrations. pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")) + .await?; MIGRATOR.run(&db).await?; Ok(Self { db }) } @@ -50,19 +51,22 @@ impl Storage for PostgresStorage { /// migrations on the database. async fn new(url: &'_ url::Url) -> Result<Self> { tracing::debug!("Postgres URL: {url}"); - let options = sqlx::postgres::PgConnectOptions::from_url(url)? - .options([("search_path", "kittybox")]); + let options = + sqlx::postgres::PgConnectOptions::from_url(url)?.options([("search_path", "kittybox")]); Self::from_pool( sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) - .await? - ).await - + .await?, + ) + .await } - async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> { let authority = user.authority().to_owned(); Ok( sqlx::query_scalar::<_, serde_json::Value>("SELECT mf2 FROM kittybox.mf2_json WHERE owner = $1 ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC") @@ -74,18 +78,20 @@ #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result<Vec<String>> { - sqlx::query_scalar::<_, String>(" + sqlx::query_scalar::<_, String>( + " SELECT jsonb_array_elements(mf2['properties']['category']) AS category FROM kittybox.mf2_json WHERE jsonb_typeof(mf2['properties']['category']) = 'array' AND uid LIKE ($1 || '%') GROUP BY category ORDER BY count(*) DESC -") - .bind(url) - .fetch_all(&self.db) - .await - .map_err(|err| err.into()) +", + ) + .bind(url) + .fetch_all(&self.db) + .await + .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn post_exists(&self, url: &str) -> Result<bool> { @@ -98,13 +104,14 @@ WHERE #[tracing::instrument(skip(self))] async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { - sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1") - .bind(url) - .fetch_optional(&self.db) - .await - .map(|v| v.map(|v| v.0)) - .map_err(|err| err.into()) - + sqlx::query_as::<_, (serde_json::Value,)>( + "SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? 
$1", + ) + .bind(url) + .fetch_optional(&self.db) + .await + .map(|v| v.map(|v| v.0)) + .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] @@ -122,13 +129,15 @@ WHERE #[tracing::instrument(skip(self))] async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { tracing::debug!("Inserting {} into {}", post, feed); - sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") - .bind(feed) - .bind(post) - .execute(&self.db) - .await - .map(|_| ()) - .map_err(Into::into) + sqlx::query( + "INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING", + ) + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) } #[tracing::instrument(skip(self))] @@ -143,7 +152,12 @@ WHERE } #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> Result<()> { let mut txn = self.db.begin().await?; let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") @@ -190,7 +204,9 @@ WHERE #[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { tracing::debug!("Updating post {}", url); let mut txn = self.db.begin().await?; @@ -250,12 +266,12 @@ WHERE url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { let mut txn = self.db.begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") - .execute(&mut *txn) - .await?; + .execute(&mut *txn) + .await?; tracing::debug!("Started txn: {:?}", txn); let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 @@ -273,11 +289,17 @@ SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2 // The second query is very long and will probably be extremely // expensive. 
It's best to skip it on types where it doesn't make sense // (Kittybox doesn't support rendering children on non-feeds) - if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { + if !feed["type"] + .as_array() + .unwrap() + .iter() + .any(|t| *t == serde_json::json!("h-feed")) + { return Ok(Some((feed, None))); } - feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" + feed["children"] = sqlx::query_scalar::<_, serde_json::Value>( + " SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json INNER JOIN kittybox.children ON mf2_json.uid = children.child @@ -302,17 +324,19 @@ WHERE ) AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC -LIMIT $2" +LIMIT $2", ) - .bind(url) - .bind(limit as i64) - .bind(user.map(url::Url::as_str)) - .bind(cursor) - .fetch_all(&mut *txn) - .await - .map(serde_json::Value::Array)?; - - let new_cursor = feed["children"].as_array().unwrap() + .bind(url) + .bind(limit as i64) + .bind(user.map(url::Url::as_str)) + .bind(cursor) + .fetch_all(&mut *txn) + .await + .map(serde_json::Value::Array)?; + + let new_cursor = feed["children"] + .as_array() + .unwrap() .last() .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); @@ -335,7 +359,7 @@ LIMIT $2" .await { Ok((value,)) => Ok(serde_json::from_value(value)?), - Err(err) => Err(err.into()) + Err(err) => Err(err.into()), } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 9ba1a69..94b8aa7 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -12,12 +12,10 @@ use tracing::{debug, error}; //pub mod login; pub mod onboarding; +pub use kittybox_frontend_renderer::assets::statics; use kittybox_frontend_renderer::{ - Entry, Feed, VCard, - ErrorPage, Template, MainPage, - POSTS_PER_PAGE + Entry, ErrorPage, Feed, MainPage, Template, VCard, POSTS_PER_PAGE, }; -pub use kittybox_frontend_renderer::assets::statics; #[derive(Debug, Deserialize)] pub struct QueryParams { @@ -106,7 +104,7 @@ pub fn filter_post( .map(|i| -> &str { match i { serde_json::Value::String(ref author) => author.as_str(), - mf2 => mf2["properties"]["uid"][0].as_str().unwrap() + mf2 => mf2["properties"]["uid"][0].as_str().unwrap(), } }) .map(|i| i.parse().unwrap()) @@ -116,11 +114,13 @@ pub fn filter_post( .unwrap_or("public"); let audience = { let mut audience = author_list.clone(); - audience.extend(post["properties"]["audience"] - .as_array() - .unwrap_or(&empty_vec) - .iter() - .map(|i| i.as_str().unwrap().parse().unwrap())); + audience.extend( + post["properties"]["audience"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().parse().unwrap()), + ); audience }; @@ -134,7 +134,10 @@ pub fn filter_post( let location_visibility = post["properties"]["location-visibility"][0] .as_str() .unwrap_or("private"); - tracing::debug!("Post contains location, location privacy = {}", location_visibility); + tracing::debug!( + "Post contains location, location privacy = {}", + location_visibility + ); let mut author = post["properties"]["author"] .as_array() .unwrap_or(&empty_vec) @@ -155,16 +158,18 @@ pub fn filter_post( post["properties"]["author"] = serde_json::Value::Array( children .into_iter() - .filter_map(|post| if post.is_string() { - Some(post) - } else { - filter_post(post, user) + .filter_map(|post| { + if post.is_string() { + Some(post) + } else { + filter_post(post, user) + } }) - .collect::<Vec<serde_json::Value>>() + .collect::<Vec<serde_json::Value>>(), ); - }, - 
serde_json::Value::Null => {}, - other => post["properties"]["author"] = other + } + serde_json::Value::Null => {} + other => post["properties"]["author"] = other, } match post["children"].take() { @@ -173,11 +178,11 @@ pub fn filter_post( children .into_iter() .filter_map(|post| filter_post(post, user)) - .collect::<Vec<serde_json::Value>>() + .collect::<Vec<serde_json::Value>>(), ); - }, - serde_json::Value::Null => {}, - other => post["children"] = other + } + serde_json::Value::Null => {} + other => post["children"] = other, } Some(post) } @@ -209,7 +214,7 @@ async fn get_post_from_database<S: Storage>( )) } } - } + }, None => Err(FrontendError::with_code( StatusCode::NOT_FOUND, "Post not found in the database", @@ -240,7 +245,7 @@ pub async fn homepage<D: Storage>( Host(host): Host, Query(query): Query<QueryParams>, State(db): State<D>, - session: Option<crate::Session> + session: Option<crate::Session>, ) -> impl IntoResponse { // This is stupid, but there is no other way. let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); @@ -252,7 +257,7 @@ pub async fn homepage<D: Storage>( ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); let user = session.as_deref().map(|s| &s.me); @@ -268,18 +273,16 @@ pub async fn homepage<D: Storage>( // btw is it more efficient to fetch these in parallel? let (blogname, webring, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_setting::<crate::database::settings::Webring>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); if user.is_some() { headers.insert( axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("private") + axum::http::HeaderValue::from_static("private"), ); } // Render the homepage @@ -295,12 +298,13 @@ pub async fn homepage<D: Storage>( feed: &hfeed, card: &hcard, cursor: cursor.as_deref(), - webring: crate::database::settings::Setting::into_inner(webring) + webring: crate::database::settings::Setting::into_inner(webring), } .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } Err(err) => { if err.code == StatusCode::NOT_FOUND { @@ -310,19 +314,20 @@ pub async fn homepage<D: Storage>( StatusCode::FOUND, [(axum::http::header::LOCATION, "/.kittybox/onboarding")], String::default(), - ).into_response() + ) + .into_response() } else { error!("Error while fetching h-card and/or h-feed: {}", err); // Return the error let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( - err.code(), headers, + err.code(), + headers, Template { title: blogname.as_ref(), blog_name: blogname.as_ref(), @@ -335,7 +340,8 @@ pub async fn homepage<D: Storage>( .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } } } @@ -351,17 +357,13 @@ pub async fn catchall<D: Storage>( ) -> impl IntoResponse { let user: Option<&url::Url> = session.as_deref().map(|p| &p.me); let host = url::Url::parse(&format!("https://{}/", host)).unwrap(); - let path = host - .clone() - .join(uri.path()) - .unwrap(); + let path = host.clone().join(uri.path()).unwrap(); match 
get_post_from_database(&db, path.as_str(), query.after, user).await { Ok((post, cursor)) => { let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&host).map(|i| i.unwrap_or_default()) ); let mut headers = axum::http::HeaderMap::new(); @@ -371,12 +373,12 @@ pub async fn catchall<D: Storage>( ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); if user.is_some() { headers.insert( axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("private") + axum::http::HeaderValue::from_static("private"), ); } @@ -384,19 +386,20 @@ pub async fn catchall<D: Storage>( let last_modified = post["properties"]["updated"] .as_array() .and_then(|v| v.last()) - .or_else(|| post["properties"]["published"] - .as_array() - .and_then(|v| v.last()) - ) + .or_else(|| { + post["properties"]["published"] + .as_array() + .and_then(|v| v.last()) + }) .and_then(serde_json::Value::as_str) - .and_then(|dt| chrono::DateTime::<chrono::FixedOffset>::parse_from_rfc3339(dt).ok()); + .and_then(|dt| { + chrono::DateTime::<chrono::FixedOffset>::parse_from_rfc3339(dt).ok() + }); if let Some(last_modified) = last_modified { - headers.typed_insert( - axum_extra::headers::LastModified::from( - std::time::SystemTime::from(last_modified) - ) - ); + headers.typed_insert(axum_extra::headers::LastModified::from( + std::time::SystemTime::from(last_modified), + )); } } @@ -410,8 +413,16 @@ pub async fn catchall<D: Storage>( feeds: channels, user: session.as_deref(), content: match post.pointer("/type/0").and_then(|i| i.as_str()) { - Some("h-entry") => Entry { post: &post, from_feed: false, }.to_string(), - Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(), + Some("h-entry") => Entry { + post: &post, + from_feed: false, + } + .to_string(), + Some("h-feed") => Feed { + feed: &post, + cursor: cursor.as_deref(), + } + .to_string(), Some("h-card") => VCard { card: &post }.to_string(), unknown => { unimplemented!("Template for MF2-JSON type {:?}", unknown) @@ -419,13 +430,13 @@ pub async fn catchall<D: Storage>( }, } .to_string(), - ).into_response() + ) + .into_response() } Err(err) => { let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&host).map(|i| i.unwrap_or_default()) ); ( @@ -446,7 +457,8 @@ pub async fn catchall<D: Storage>( .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } } } diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs index bf313cf..3b53911 100644 --- a/src/frontend/onboarding.rs +++ b/src/frontend/onboarding.rs @@ -10,7 +10,7 @@ use axum::{ use axum_extra::extract::Host; use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template}; use serde::Deserialize; -use tokio::{task::JoinSet, sync::Mutex}; +use tokio::{sync::Mutex, task::JoinSet}; use tracing::{debug, error}; use super::FrontendError; @@ -64,7 +64,8 @@ async fn onboard<D: Storage + 'static>( me: user_uid.clone(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), - iat: None, exp: None + iat: None, + exp: None, }; tracing::debug!("User data: {:?}", user); @@ -99,19 +100,21 @@ async fn onboard<D: 
Storage + 'static>( continue; }; debug!("Creating feed {} with slug {}", &feed.name, &feed.slug); - let crate::micropub::util::NormalizedPost { id: _, post: feed } = crate::micropub::normalize_mf2( - serde_json::json!({ - "type": ["h-feed"], - "properties": {"name": [feed.name], "mp-slug": [feed.slug]} - }), - &user, - ); + let crate::micropub::util::NormalizedPost { id: _, post: feed } = + crate::micropub::normalize_mf2( + serde_json::json!({ + "type": ["h-feed"], + "properties": {"name": [feed.name], "mp-slug": [feed.slug]} + }), + &user, + ); db.put_post(&feed, &user.me) .await .map_err(FrontendError::from)?; } - let crate::micropub::util::NormalizedPost { id: uid, post } = crate::micropub::normalize_mf2(data.first_post, &user); + let crate::micropub::util::NormalizedPost { id: uid, post } = + crate::micropub::normalize_mf2(data.first_post, &user); tracing::debug!("Posting first post {}...", uid); crate::micropub::_post(&user, uid, post, db, http, jobset) .await @@ -169,6 +172,5 @@ where reqwest_middleware::ClientWithMiddleware: FromRef<St>, St: Clone + Send + Sync + 'static, { - axum::routing::get(get) - .post(post::<S>) + axum::routing::get(get).post(post::<S>) } diff --git a/src/indieauth/backend.rs b/src/indieauth/backend.rs index b913256..9215adf 100644 --- a/src/indieauth/backend.rs +++ b/src/indieauth/backend.rs @@ -1,9 +1,7 @@ -use std::future::Future; -use std::collections::HashMap; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData -}; +use kittybox_indieauth::{AuthorizationRequest, TokenData}; pub use kittybox_util::auth::EnrolledCredential; +use std::collections::HashMap; +use std::future::Future; type Result<T> = std::io::Result<T>; @@ -20,33 +18,72 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { /// Note for implementors: the [`AuthorizationRequest::me`] value /// is guaranteed to be [`Some(url::Url)`][Option::Some] and can /// be trusted to be correct and non-malicious. - fn create_code(&self, data: AuthorizationRequest) -> impl Future<Output = Result<String>> + Send; + fn create_code( + &self, + data: AuthorizationRequest, + ) -> impl Future<Output = Result<String>> + Send; /// Retrieve an authorization request using the one-time /// code. Implementations must sanitize the `code` field to /// prevent exploits, and must check if the code should still be /// valid at this point in time (validity interval is left up to /// the implementation, but is recommended to be no more than 10 /// minutes). - fn get_code(&self, code: &str) -> impl Future<Output = Result<Option<AuthorizationRequest>>> + Send; + fn get_code( + &self, + code: &str, + ) -> impl Future<Output = Result<Option<AuthorizationRequest>>> + Send; // Token management. fn create_token(&self, data: TokenData) -> impl Future<Output = Result<String>> + Send; - fn get_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<Option<TokenData>>> + Send; - fn list_tokens(&self, website: &url::Url) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; - fn revoke_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<()>> + Send; + fn get_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<Option<TokenData>>> + Send; + fn list_tokens( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; + fn revoke_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<()>> + Send; // Refresh token management. 
fn create_refresh_token(&self, data: TokenData) -> impl Future<Output = Result<String>> + Send; - fn get_refresh_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<Option<TokenData>>> + Send; - fn list_refresh_tokens(&self, website: &url::Url) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; - fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<()>> + Send; + fn get_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<Option<TokenData>>> + Send; + fn list_refresh_tokens( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; + fn revoke_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<()>> + Send; // Password management. /// Verify a password. #[must_use] - fn verify_password(&self, website: &url::Url, password: String) -> impl Future<Output = Result<bool>> + Send; + fn verify_password( + &self, + website: &url::Url, + password: String, + ) -> impl Future<Output = Result<bool>> + Send; /// Enroll a password credential for a user. Only one password /// credential must exist for a given user. - fn enroll_password(&self, website: &url::Url, password: String) -> impl Future<Output = Result<()>> + Send; + fn enroll_password( + &self, + website: &url::Url, + password: String, + ) -> impl Future<Output = Result<()>> + Send; /// List currently enrolled credential types for a given user. - fn list_user_credential_types(&self, website: &url::Url) -> impl Future<Output = Result<Vec<EnrolledCredential>>> + Send; + fn list_user_credential_types( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<Vec<EnrolledCredential>>> + Send; // WebAuthn credential management. #[cfg(feature = "webauthn")] /// Enroll a WebAuthn authenticator public key for this user. @@ -56,10 +93,17 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { /// This function can also be used to overwrite a passkey with an /// updated version after using /// [webauthn::prelude::Passkey::update_credential()]. - fn enroll_webauthn(&self, website: &url::Url, credential: webauthn::prelude::Passkey) -> impl Future<Output = Result<()>> + Send; + fn enroll_webauthn( + &self, + website: &url::Url, + credential: webauthn::prelude::Passkey, + ) -> impl Future<Output = Result<()>> + Send; #[cfg(feature = "webauthn")] /// List currently enrolled WebAuthn authenticators for a given user. - fn list_webauthn_pubkeys(&self, website: &url::Url) -> impl Future<Output = Result<Vec<webauthn::prelude::Passkey>>> + Send; + fn list_webauthn_pubkeys( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<Vec<webauthn::prelude::Passkey>>> + Send; #[cfg(feature = "webauthn")] /// Persist registration challenge state for a little while so it /// can be used later. @@ -69,7 +113,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn persist_registration_challenge( &self, website: &url::Url, - state: webauthn::prelude::PasskeyRegistration + state: webauthn::prelude::PasskeyRegistration, ) -> impl Future<Output = Result<String>> + Send; #[cfg(feature = "webauthn")] /// Retrieve a persisted registration challenge. 
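// Taken together, the listing and revocation methods above admit simple
// housekeeping passes over a user's grants; for instance, a sweep that
// drops expired access tokens (`sweep_expired` is a hypothetical caller,
// not part of this trait):
async fn sweep_expired<A: AuthBackend>(auth: &A, website: &url::Url) -> std::io::Result<()> {
    for (token, data) in auth.list_tokens(website).await? {
        if data.expired() {
            // `TokenData::expired()` is the same check the file backend
            // applies lazily on reads elsewhere in this diff.
            auth.revoke_token(website, &token).await?;
        }
    }
    Ok(())
}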
@@ -78,7 +122,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn retrieve_registration_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> impl Future<Output = Result<webauthn::prelude::PasskeyRegistration>> + Send; #[cfg(feature = "webauthn")] /// Persist authentication challenge state for a little while so @@ -92,7 +136,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn persist_authentication_challenge( &self, website: &url::Url, - state: webauthn::prelude::PasskeyAuthentication + state: webauthn::prelude::PasskeyAuthentication, ) -> impl Future<Output = Result<String>> + Send; #[cfg(feature = "webauthn")] /// Retrieve a persisted authentication challenge. @@ -101,7 +145,6 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn retrieve_authentication_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> impl Future<Output = Result<webauthn::prelude::PasskeyAuthentication>> + Send; - } diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs index f74fbbc..26466fe 100644 --- a/src/indieauth/backend/fs.rs +++ b/src/indieauth/backend/fs.rs @@ -1,13 +1,16 @@ -use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}}; - -use super::{AuthBackend, Result, EnrolledCredential}; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData +use std::{ + borrow::Cow, + collections::HashMap, + path::PathBuf, + time::{Duration, SystemTime}, }; + +use super::{AuthBackend, EnrolledCredential, Result}; +use kittybox_indieauth::{AuthorizationRequest, TokenData}; use serde::de::DeserializeOwned; -use tokio::{task::spawn_blocking, io::AsyncReadExt}; +use tokio::{io::AsyncReadExt, task::spawn_blocking}; #[cfg(feature = "webauthn")] -use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication}; +use webauthn::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; const CODE_LENGTH: usize = 16; const TOKEN_LENGTH: usize = 128; @@ -29,7 +32,8 @@ impl FileBackend { } else { let mut s = String::with_capacity(filename.len()); - filename.chars() + filename + .chars() .filter(|c| c.is_alphanumeric()) .for_each(|c| s.push(c)); @@ -38,41 +42,41 @@ impl FileBackend { } #[inline] - async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>( + async fn serialize_to_file< + T: 'static + serde::ser::Serialize + Send, + B: Into<Option<&'static str>>, + >( &self, dir: &str, basename: B, length: usize, - data: T + data: T, ) -> Result<String> { let basename = basename.into(); let has_ext = basename.is_some(); - let (filename, mut file) = kittybox_util::fs::mktemp( - self.path.join(dir), basename, length - ) + let (filename, mut file) = kittybox_util::fs::mktemp(self.path.join(dir), basename, length) .await .map(|(name, file)| (name, file.try_into_std().unwrap()))?; spawn_blocking(move || serde_json::to_writer(&mut file, &data)) .await - .unwrap_or_else(|e| panic!( - "Panic while serializing {}: {}", - std::any::type_name::<T>(), - e - )) + .unwrap_or_else(|e| { + panic!( + "Panic while serializing {}: {}", + std::any::type_name::<T>(), + e + ) + }) .map(move |_| { (if has_ext { - filename - .extension() - + filename.extension() } else { - filename - .file_name() + filename.file_name() }) - .unwrap() - .to_str() - .unwrap() - .to_owned() + .unwrap() + .to_str() + .unwrap() + .to_owned() }) .map_err(|err| err.into()) } @@ -86,17 +90,15 @@ impl FileBackend { ) -> Result<Option<(PathBuf, SystemTime, T)>> where T: 
serde::de::DeserializeOwned + Send, - B: Into<Option<&'static str>> + B: Into<Option<&'static str>>, { let basename = basename.into(); - let path = self.path - .join(dir) - .join(format!( - "{}{}{}", - basename.unwrap_or(""), - if basename.is_none() { "" } else { "." }, - FileBackend::sanitize_for_path(filename) - )); + let path = self.path.join(dir).join(format!( + "{}{}{}", + basename.unwrap_or(""), + if basename.is_none() { "" } else { "." }, + FileBackend::sanitize_for_path(filename) + )); let data = match tokio::fs::File::open(&path).await { Ok(mut file) => { @@ -106,13 +108,15 @@ impl FileBackend { match serde_json::from_slice::<'_, T>(buf.as_slice()) { Ok(data) => data, - Err(err) => return Err(err.into()) + Err(err) => return Err(err.into()), + } + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + return Ok(None); + } else { + return Err(err); } - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - return Ok(None) - } else { - return Err(err) } }; @@ -125,7 +129,8 @@ impl FileBackend { #[tracing::instrument] fn url_to_dir(url: &url::Url) -> String { let host = url.host_str().unwrap(); - let port = url.port() + let port = url + .port() .map(|port| Cow::Owned(format!(":{}", port))) .unwrap_or(Cow::Borrowed("")); @@ -135,23 +140,26 @@ impl FileBackend { async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>( &'this self, dir: &'dir str, - prefix: &'static str + prefix: &'static str, ) -> Result<HashMap<String, T>> { let dir = self.path.join(dir); let mut hashmap = HashMap::new(); let mut readdir = match tokio::fs::read_dir(dir).await { Ok(readdir) => readdir, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - // empty hashmap - return Ok(hashmap); - } else { - return Err(err); + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + // empty hashmap + return Ok(hashmap); + } else { + return Err(err); + } } }; while let Some(entry) = readdir.next_entry().await? { // safe to unwrap; filenames are alphanumeric - let filename = entry.file_name() + let filename = entry + .file_name() .into_string() .expect("token filenames should be alphanumeric!"); if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) { @@ -166,16 +174,19 @@ impl FileBackend { Err(err) => { tracing::error!( "Error decoding token data from file {}: {}", - entry.path().display(), err + entry.path().display(), + err ); continue; } }; - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - continue - } else { - return Err(err) + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + continue; + } else { + return Err(err); + } } } } @@ -194,19 +205,27 @@ impl AuthBackend for FileBackend { path = base.join(&format!(".{}", path.path())).unwrap(); } - tracing::debug!("Initializing File auth backend: {} -> {}", orig_path, path.path()); + tracing::debug!( + "Initializing File auth backend: {} -> {}", + orig_path, + path.path() + ); Ok(Self { - path: std::path::PathBuf::from(path.path()) + path: std::path::PathBuf::from(path.path()), }) } // Authorization code management. async fn create_code(&self, data: AuthorizationRequest) -> Result<String> { - self.serialize_to_file("codes", None, CODE_LENGTH, data).await + self.serialize_to_file("codes", None, CODE_LENGTH, data) + .await } async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> { - match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? 
{ + match self + .deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()) + .await? + { Some((path, ctime, data)) => { if let Err(err) = tokio::fs::remove_file(path).await { tracing::error!("Failed to clean up authorization code: {}", err); @@ -217,23 +236,28 @@ impl AuthBackend for FileBackend { } else { Ok(Some(data)) } - }, - None => Ok(None) + } + None => Ok(None), } } // Token management. async fn create_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await + self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data) + .await } async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "access", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { + match self + .deserialize_from_file::<TokenData, _>( + &dir, + "access", + FileBackend::sanitize_for_path(token).as_ref(), + ) + .await? + { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { @@ -243,8 +267,8 @@ impl AuthBackend for FileBackend { } else { Ok(Some(token)) } - }, - None => Ok(None) + } + None => Ok(None), } } @@ -258,25 +282,36 @@ impl AuthBackend for FileBackend { self.path .join(FileBackend::url_to_dir(website)) .join("tokens") - .join(format!("access.{}", FileBackend::sanitize_for_path(token))) - ).await { + .join(format!("access.{}", FileBackend::sanitize_for_path(token))), + ) + .await + { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result + result => result, } } // Refresh token management. async fn create_refresh_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await + self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data) + .await } - async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { + async fn get_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "refresh", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { + match self + .deserialize_from_file::<TokenData, _>( + &dir, + "refresh", + FileBackend::sanitize_for_path(token).as_ref(), + ) + .await? + { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { @@ -286,8 +321,8 @@ impl AuthBackend for FileBackend { } else { Ok(Some(token)) } - }, - None => Ok(None) + } + None => Ok(None), } } @@ -301,57 +336,80 @@ impl AuthBackend for FileBackend { self.path .join(FileBackend::url_to_dir(website)) .join("tokens") - .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))) - ).await { + .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))), + ) + .await + { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result + result => result, } } // Password management. 
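// A quick end-to-end of the password methods implemented below, going
// through the `AuthBackend` trait (values made up; `demo_password` is a
// hypothetical caller):
async fn demo_password<A: AuthBackend>(auth: &A, website: &url::Url) -> std::io::Result<bool> {
    auth.enroll_password(website, "correct horse battery staple".to_string())
        .await?;
    // Returns Ok(false) for a wrong guess rather than an error.
    auth.verify_password(website, "hunter2".to_string()).await
}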
#[tracing::instrument(skip(password))] async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> { - use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}}; + use argon2::{ + password_hash::{PasswordHash, PasswordVerifier}, + Argon2, + }; - let password_filename = self.path + let password_filename = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); - tracing::debug!("Reading password for {} from {}", website, password_filename.display()); + tracing::debug!( + "Reading password for {} from {}", + website, + password_filename.display() + ); match tokio::fs::read_to_string(password_filename).await { Ok(password_hash) => { let parsed_hash = { let hash = password_hash.trim(); - #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash); - PasswordHash::new(hash) - .expect("Password hash should be valid!") + #[cfg(debug_assertions)] + tracing::debug!("Password hash: {}", hash); + PasswordHash::new(hash).expect("Password hash should be valid!") }; - Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - Ok(false) - } else { - Err(err) + Ok(Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .is_ok()) + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + Ok(false) + } else { + Err(err) + } } } } #[tracing::instrument(skip(password))] async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> { - use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}}; + use argon2::{ + password_hash::{rand_core::OsRng, PasswordHasher, SaltString}, + Argon2, + }; - let password_filename = self.path + let password_filename = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); - let password_hash = argon2.hash_password(password.as_bytes(), &salt) + let password_hash = argon2 + .hash_password(password.as_bytes(), &salt) .expect("Hashing a password should not error out") .to_string(); - tracing::debug!("Enrolling password for {} at {}", website, password_filename.display()); + tracing::debug!( + "Enrolling password for {} at {}", + website, + password_filename.display() + ); tokio::fs::write(password_filename, password_hash.as_bytes()).await } @@ -371,7 +429,7 @@ impl AuthBackend for FileBackend { async fn persist_registration_challenge( &self, website: &url::Url, - state: PasskeyRegistration + state: PasskeyRegistration, ) -> Result<String> { todo!() } @@ -380,7 +438,7 @@ impl AuthBackend for FileBackend { async fn retrieve_registration_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> Result<PasskeyRegistration> { todo!() } @@ -389,7 +447,7 @@ impl AuthBackend for FileBackend { async fn persist_authentication_challenge( &self, website: &url::Url, - state: PasskeyAuthentication + state: PasskeyAuthentication, ) -> Result<String> { todo!() } @@ -398,24 +456,28 @@ impl AuthBackend for FileBackend { async fn retrieve_authentication_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> Result<PasskeyAuthentication> { todo!() } #[tracing::instrument(skip(self))] - async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> { + async fn list_user_credential_types( + &self, + website: &url::Url, + ) -> Result<Vec<EnrolledCredential>> { let mut creds = vec![]; - let 
password_file = self.path + let password_file = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); tracing::debug!("Password file for {}: {}", website, password_file.display()); - match tokio::fs::metadata(password_file) - .await - { + match tokio::fs::metadata(password_file).await { Ok(_) => creds.push(EnrolledCredential::Password), - Err(err) => if err.kind() != std::io::ErrorKind::NotFound { - return Err(err) + Err(err) => { + if err.kind() != std::io::ErrorKind::NotFound { + return Err(err); + } } } diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 00ae393..2f90a19 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,18 +1,29 @@ -use std::marker::PhantomData; -use microformats::types::Class; -use tracing::error; -use serde::Deserialize; +use crate::database::Storage; use axum::{ - extract::{Form, FromRef, Json, Query, State}, http::StatusCode, response::{Html, IntoResponse, Response} + extract::{Form, FromRef, Json, Query, State}, + http::StatusCode, + response::{Html, IntoResponse, Response}, }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] -use axum_extra::extract::{Host, cookie::{CookieJar, Cookie}}; -use axum_extra::{headers::{authorization::Bearer, Authorization, ContentType, HeaderMapExt}, TypedHeader}; -use crate::database::Storage; +use axum_extra::extract::{ + cookie::{Cookie, CookieJar}, + Host, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization, ContentType, HeaderMapExt}, + TypedHeader, +}; use kittybox_indieauth::{ - AuthorizationRequest, AuthorizationResponse, ClientMetadata, Error, ErrorKind, GrantRequest, GrantResponse, GrantType, IntrospectionEndpointAuthMethod, Metadata, PKCEMethod, Profile, ProfileUrl, ResponseType, RevocationEndpointAuthMethod, Scope, Scopes, TokenData, TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest + AuthorizationRequest, AuthorizationResponse, ClientMetadata, Error, ErrorKind, GrantRequest, + GrantResponse, GrantType, IntrospectionEndpointAuthMethod, Metadata, PKCEMethod, Profile, + ProfileUrl, ResponseType, RevocationEndpointAuthMethod, Scope, Scopes, TokenData, + TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest, }; +use microformats::types::Class; +use serde::Deserialize; +use std::marker::PhantomData; use std::str::FromStr; +use tracing::error; pub mod backend; #[cfg(feature = "webauthn")] @@ -41,35 +52,42 @@ impl<A: AuthBackend> std::ops::Deref for User<A> { pub enum IndieAuthResourceError { InvalidRequest, Unauthorized, - InvalidToken + InvalidToken, } impl axum::response::IntoResponse for IndieAuthResourceError { fn into_response(self) -> axum::response::Response { use IndieAuthResourceError::*; match self { - Unauthorized => ( - StatusCode::UNAUTHORIZED, - [("WWW-Authenticate", "Bearer")] - ).into_response(), + Unauthorized => { + (StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Bearer")]).into_response() + } InvalidRequest => ( StatusCode::BAD_REQUEST, - Json(&serde_json::json!({"error": "invalid_request"})) - ).into_response(), + Json(&serde_json::json!({"error": "invalid_request"})), + ) + .into_response(), InvalidToken => ( StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Bearer, error=\"invalid_token\"")], - Json(&serde_json::json!({"error": "not_authorized"})) - ).into_response() + Json(&serde_json::json!({"error": "not_authorized"})), + ) + .into_response(), } } } -impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::OptionalFromRequestParts<St> for User<A> 
{ +impl<A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> + axum::extract::OptionalFromRequestParts<St> for User<A> +{ type Rejection = <Self as axum::extract::FromRequestParts<St>>::Rejection; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Option<Self>, Self::Rejection> { - let res = <Self as axum::extract::FromRequestParts<St>>::from_request_parts(req, state).await; + async fn from_request_parts( + req: &mut axum::http::request::Parts, + state: &St, + ) -> Result<Option<Self>, Self::Rejection> { + let res = + <Self as axum::extract::FromRequestParts<St>>::from_request_parts(req, state).await; match res { Ok(user) => Ok(Some(user)), @@ -79,14 +97,19 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext } } -impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts<St> for User<A> { +impl<A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> + axum::extract::FromRequestParts<St> for User<A> +{ type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Self, Self::Rejection> { + async fn from_request_parts( + req: &mut axum::http::request::Parts, + state: &St, + ) -> Result<Self, Self::Rejection> { let TypedHeader(Authorization(token)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) - .await - .map_err(|_| IndieAuthResourceError::Unauthorized)?; + .await + .map_err(|_| IndieAuthResourceError::Unauthorized)?; let auth = A::from_ref(state); @@ -94,10 +117,7 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext .await .map_err(|_| IndieAuthResourceError::InvalidRequest)?; - auth.get_token( - &format!("https://{host}/").parse().unwrap(), - token.token() - ) + auth.get_token(&format!("https://{host}/").parse().unwrap(), token.token()) .await .unwrap() .ok_or(IndieAuthResourceError::InvalidToken) @@ -105,9 +125,7 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext } } -pub async fn metadata( - Host(host): Host -) -> Metadata { +pub async fn metadata(Host(host): Host) -> Metadata { let issuer: url::Url = format!("https://{}/", host).parse().unwrap(); let indieauth: url::Url = issuer.join("/.kittybox/indieauth/").unwrap(); @@ -117,18 +135,16 @@ pub async fn metadata( token_endpoint: indieauth.join("token").unwrap(), introspection_endpoint: indieauth.join("token_status").unwrap(), introspection_endpoint_auth_methods_supported: Some(vec![ - IntrospectionEndpointAuthMethod::Bearer + IntrospectionEndpointAuthMethod::Bearer, ]), revocation_endpoint: Some(indieauth.join("revoke_token").unwrap()), - revocation_endpoint_auth_methods_supported: Some(vec![ - RevocationEndpointAuthMethod::None - ]), + revocation_endpoint_auth_methods_supported: Some(vec![RevocationEndpointAuthMethod::None]), scopes_supported: Some(vec![ Scope::Create, Scope::Update, Scope::Delete, Scope::Media, - Scope::Profile + Scope::Profile, ]), response_types_supported: Some(vec![ResponseType::Code]), grant_types_supported: Some(vec![GrantType::AuthorizationCode, GrantType::RefreshToken]), @@ -145,27 +161,39 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Query(request): Query<AuthorizationRequest>, State(db): State<D>, State(http): State<reqwest_middleware::ClientWithMiddleware>, - State(auth): State<A> + State(auth): State<A>, ) -> Response { let me: url::Url = format!("https://{host}/").parse().unwrap(); // 
XXX: attempt fetching OAuth application metadata - let h_app: ClientMetadata = if request.client_id.domain().unwrap() == "localhost" && me.domain().unwrap() != "localhost" { + let h_app: ClientMetadata = if request.client_id.domain().unwrap() == "localhost" + && me.domain().unwrap() != "localhost" + { // If client is localhost, but we aren't localhost, generate synthetic metadata. tracing::warn!("Client is localhost, not fetching metadata"); - let mut metadata = ClientMetadata::new(request.client_id.clone(), request.client_id.clone()).unwrap(); + let mut metadata = + ClientMetadata::new(request.client_id.clone(), request.client_id.clone()).unwrap(); metadata.client_name = Some("Your locally hosted app".to_string()); metadata } else { tracing::debug!("Sending request to {} to fetch metadata", request.client_id); - let metadata_request = http.get(request.client_id.clone()) + let metadata_request = http + .get(request.client_id.clone()) .header("Accept", "application/json, text/html"); - match metadata_request.send().await - .and_then(|res| res.error_for_status() - .map_err(reqwest_middleware::Error::Reqwest)) - { - Ok(response) if response.headers().typed_get::<ContentType>().to_owned().map(mime::Mime::from).map(|m| m.type_() == "text" && m.subtype() == "html").unwrap_or(false) => { + match metadata_request.send().await.and_then(|res| { + res.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }) { + Ok(response) + if response + .headers() + .typed_get::<ContentType>() + .to_owned() + .map(mime::Mime::from) + .map(|m| m.type_() == "text" && m.subtype() == "html") + .unwrap_or(false) => + { let url = response.url().clone(); let text = response.text().await.unwrap(); tracing::debug!("Received {} bytes in response", text.len()); @@ -173,76 +201,95 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Ok(mf2) => { if let Some(relation) = mf2.rels.items.get(&request.redirect_uri) { if !relation.rels.iter().any(|i| i == "redirect_uri") { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri provided was declared as \ - something other than redirect_uri.") - .into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri provided was declared as \ + something other than redirect_uri.", + ) + .into_response(); } } else if request.redirect_uri.origin() != request.client_id.origin() { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri didn't match the origin \ - and wasn't explicitly allowed. You were being tricked.") - .into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri didn't match the origin \ + and wasn't explicitly allowed. You were being tricked.", + ) + .into_response(); } - if let Some(app) = mf2.items + if let Some(app) = mf2 + .items .iter() - .find(|&i| i.r#type.iter() - .any(|i| { + .find(|&i| { + i.r#type.iter().any(|i| { *i == Class::from_str("h-app").unwrap() || *i == Class::from_str("h-x-app").unwrap() }) - ) + }) .cloned() { // Create a synthetic metadata document. Be forgiving. 
let mut metadata = ClientMetadata::new( request.client_id.clone(), - app.properties.get("url") + app.properties + .get("url") .and_then(|v| v.first()) .and_then(|i| match i { - microformats::types::PropertyValue::Url(url) => Some(url.clone()), - _ => None + microformats::types::PropertyValue::Url(url) => { + Some(url.clone()) + } + _ => None, }) - .unwrap_or_else(|| request.client_id.clone()) - ).unwrap(); + .unwrap_or_else(|| request.client_id.clone()), + ) + .unwrap(); - metadata.client_name = app.properties.get("name") + metadata.client_name = app + .properties + .get("name") .and_then(|v| v.first()) .and_then(|i| match i { - microformats::types::PropertyValue::Plain(name) => Some(name.to_owned()), - _ => None + microformats::types::PropertyValue::Plain(name) => { + Some(name.to_owned()) + } + _ => None, }); metadata.redirect_uris = mf2.rels.by_rels().remove("redirect_uri"); metadata } else { - return (StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], "No h-app or JSON application metadata found.").into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "No h-app or JSON application metadata found.", + ) + .into_response(); } - }, + } Err(err) => { tracing::error!("Error parsing application metadata: {}", err); return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - "Parsing h-app metadata failed.").into_response() + "Parsing h-app metadata failed.", + ) + .into_response(); } } - }, + } Ok(response) => match response.json::<ClientMetadata>().await { - Ok(client_metadata) => { - client_metadata - }, + Ok(client_metadata) => client_metadata, Err(err) => { tracing::error!("Error parsing JSON application metadata: {}", err); return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - format!("Parsing OAuth2 JSON app metadata failed: {}", err) - ).into_response() + format!("Parsing OAuth2 JSON app metadata failed: {}", err), + ) + .into_response(); } }, Err(err) => { @@ -250,27 +297,32 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - format!("Fetching app metadata failed: {}", err) - ).into_response() + format!("Fetching app metadata failed: {}", err), + ) + .into_response(); } } }; tracing::debug!("Application metadata: {:#?}", h_app); - Html(kittybox_frontend_renderer::Template { - title: "Confirm sign-in via IndieAuth", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: kittybox_frontend_renderer::AuthorizationRequestPage { - request, - credentials: auth.list_user_credential_types(&me).await.unwrap(), - user: db.get_post(me.as_str()).await.unwrap().unwrap(), - app: h_app - }.to_string(), - }.to_string()) - .into_response() + Html( + kittybox_frontend_renderer::Template { + title: "Confirm sign-in via IndieAuth", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: kittybox_frontend_renderer::AuthorizationRequestPage { + request, + credentials: auth.list_user_credential_types(&me).await.unwrap(), + user: db.get_post(me.as_str()).await.unwrap().unwrap(), + app: h_app, + } + .to_string(), + } + .to_string(), + ) + .into_response() } #[derive(Deserialize, Debug)] @@ -278,7 +330,7 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( enum Credential { Password(String), #[cfg(feature = "webauthn")] - WebAuthn(::webauthn::prelude::PublicKeyCredential) + WebAuthn(::webauthn::prelude::PublicKeyCredential), } // The IndieAuth standard doesn't prescribe a format for confirming @@ 
-291,7 +343,7 @@ enum Credential { #[derive(Deserialize, Debug)] struct AuthorizationConfirmation { authorization_method: Credential, - request: AuthorizationRequest + request: AuthorizationRequest, } #[tracing::instrument(skip(auth, credential))] @@ -299,18 +351,14 @@ async fn verify_credential<A: AuthBackend>( auth: &A, website: &url::Url, credential: Credential, - #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] - challenge_id: Option<&str> + #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] challenge_id: Option<&str>, ) -> std::io::Result<bool> { match credential { Credential::Password(password) => auth.verify_password(website, password).await, #[cfg(feature = "webauthn")] - Credential::WebAuthn(credential) => webauthn::verify( - auth, - website, - credential, - challenge_id.unwrap() - ).await + Credential::WebAuthn(credential) => { + webauthn::verify(auth, website, credential, challenge_id.unwrap()).await + } } } @@ -323,7 +371,8 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( ) -> Response { tracing::debug!("Received authorization confirmation from user"); #[cfg(feature = "webauthn")] - let challenge_id = cookies.get(webauthn::CHALLENGE_ID_COOKIE) + let challenge_id = cookies + .get(webauthn::CHALLENGE_ID_COOKIE) .map(|cookie| cookie.value()); #[cfg(not(feature = "webauthn"))] let challenge_id = None; @@ -331,14 +380,16 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( let website = format!("https://{}/", host).parse().unwrap(); let AuthorizationConfirmation { authorization_method: credential, - request: mut auth + request: mut auth, } = confirmation; match verify_credential(&backend, &website, credential, challenge_id).await { - Ok(verified) => if !verified { - error!("User failed verification, bailing out."); - return StatusCode::UNAUTHORIZED.into_response(); - }, + Ok(verified) => { + if !verified { + error!("User failed verification, bailing out."); + return StatusCode::UNAUTHORIZED.into_response(); + } + } Err(err) => { error!("Error while verifying credential: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -365,9 +416,14 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( let location = { let mut uri = redirect_uri; - uri.set_query(Some(&serde_urlencoded::to_string( - AuthorizationResponse { code, state, iss: website } - ).unwrap())); + uri.set_query(Some( + &serde_urlencoded::to_string(AuthorizationResponse { + code, + state, + iss: website, + }) + .unwrap(), + )); uri }; @@ -375,10 +431,11 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( // DO NOT SET `StatusCode::FOUND` here! 
`fetch()` cannot read from // redirects, it can only follow them or choose to receive an // opaque response instead that is completely useless - (StatusCode::NO_CONTENT, - [("Location", location.as_str())], - #[cfg(feature = "webauthn")] - cookies.remove(Cookie::from(webauthn::CHALLENGE_ID_COOKIE)) + ( + StatusCode::NO_CONTENT, + [("Location", location.as_str())], + #[cfg(feature = "webauthn")] + cookies.remove(Cookie::from(webauthn::CHALLENGE_ID_COOKIE)), ) .into_response() } @@ -396,15 +453,18 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( code, client_id, redirect_uri, - code_verifier + code_verifier, } => { let request: AuthorizationRequest = match backend.get_code(&code).await { Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving auth request: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -414,51 +474,66 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } if redirect_uri != request.redirect_uri { return Error { kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() + msg: Some( + "This redirect_uri doesn't match the one the code has been sent to." + .to_string(), + ), + error_uri: None, + } + .into_response(); } if !request.code_challenge.verify(code_verifier) { return Error { kind: ErrorKind::InvalidGrant, msg: Some("The PKCE challenge failed.".to_string()), // are RFCs considered human-readable? 
😝 - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response() + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6" + .parse() + .ok(), + } + .into_response(); } let me: url::Url = format!("https://{}/", host).parse().unwrap(); if request.me.unwrap() != me { return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } - let profile = if request.scope.as_ref() - .map(|s| s.has(&Scope::Profile)) - .unwrap_or_default() + let profile = if request + .scope + .as_ref() + .map(|s| s.has(&Scope::Profile)) + .unwrap_or_default() { match get_profile( db, me.as_str(), - request.scope.as_ref() + request + .scope + .as_ref() .map(|s| s.has(&Scope::Email)) - .unwrap_or_default() - ).await { + .unwrap_or_default(), + ) + .await + { Ok(profile) => { tracing::debug!("Retrieved profile: {:?}", profile); profile - }, + } Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { @@ -466,12 +541,15 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( }; GrantResponse::ProfileUrl(ProfileUrl { me, profile }).into_response() - }, + } _ => Error { kind: ErrorKind::InvalidGrant, msg: Some("The provided grant_type is unusable on this endpoint.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code".parse().ok() - }.into_response() + error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code" + .parse() + .ok(), + } + .into_response(), } } @@ -485,36 +563,40 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( #[inline] fn prepare_access_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { TokenData { - me, client_id, scope, + me, + client_id, + scope, exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) - .as_secs() - .into(), + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) + .as_secs() + .into(), iat: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() - .into() + .into(), } } #[inline] fn prepare_refresh_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { TokenData { - me, client_id, scope, + me, + client_id, + scope, exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) - .as_secs() - .into(), + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) + .as_secs() + .into(), iat: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() - .into() + .into(), } } @@ -525,15 +607,18 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( code, client_id, redirect_uri, - code_verifier + code_verifier, } => { let request: AuthorizationRequest = match backend.get_code(&code).await { Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: 
ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving auth request: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -542,33 +627,46 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( tracing::debug!("Retrieved authorization request: {:?}", request); - let scope = if let Some(scope) = request.scope { scope } else { + let scope = if let Some(scope) = request.scope { + scope + } else { return Error { kind: ErrorKind::InvalidScope, msg: Some("Tokens cannot be issued if no scopes are requested.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#access-token-response".parse().ok() - }.into_response(); + error_uri: "https://indieauth.spec.indieweb.org/#access-token-response" + .parse() + .ok(), + } + .into_response(); }; if client_id != request.client_id { return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } if redirect_uri != request.redirect_uri { return Error { kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() + msg: Some( + "This redirect_uri doesn't match the one the code has been sent to." + .to_string(), + ), + error_uri: None, + } + .into_response(); } if !request.code_challenge.verify(code_verifier) { return Error { kind: ErrorKind::InvalidGrant, msg: Some("The PKCE challenge failed.".to_string()), - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response(); + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6" + .parse() + .ok(), + } + .into_response(); } // Note: we can trust the `request.me` value, since we set @@ -577,30 +675,32 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } let profile = if dbg!(scope.has(&Scope::Profile)) { - match get_profile( - db, - me.as_str(), - scope.has(&Scope::Email) - ).await { + match get_profile(db, me.as_str(), scope.has(&Scope::Email)).await { Ok(profile) => dbg!(profile), Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { None }; - let access_token = match backend.create_token( - prepare_access_token(me.clone(), client_id.clone(), scope.clone()) - ).await { + let access_token = match backend + .create_token(prepare_access_token( + me.clone(), + client_id.clone(), + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating access token: {}", err); @@ -608,9 +708,10 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( } }; // TODO: only create refresh token if user allows it - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(me.clone(), client_id, scope.clone()) - ).await { + let refresh_token = match backend + .create_refresh_token(prepare_refresh_token(me.clone(), client_id, scope.clone())) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating 
refresh token: {}", err); @@ -626,24 +727,28 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( scope: Some(scope), expires_in: Some(ACCESS_TOKEN_VALIDITY), refresh_token: Some(refresh_token), - state: None - }.into_response() - }, + state: None, + } + .into_response() + } GrantRequest::RefreshToken { refresh_token, client_id, - scope + scope, } => { let data = match backend.get_refresh_token(&me, &refresh_token).await { Ok(Some(token)) => token, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This refresh token is not valid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This refresh token is not valid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } }; @@ -651,17 +756,22 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This refresh token is not yours.".to_string()), - error_uri: None - }.into_response(); + error_uri: None, + } + .into_response(); } let scope = if let Some(scope) = scope { if !data.scope.has_all(scope.as_ref()) { return Error { kind: ErrorKind::InvalidScope, - msg: Some("You can't request additional scopes through the refresh token grant.".to_string()), - error_uri: None - }.into_response(); + msg: Some( + "You can't request additional scopes through the refresh token grant." + .to_string(), + ), + error_uri: None, + } + .into_response(); } scope @@ -670,27 +780,27 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( data.scope }; - let profile = if scope.has(&Scope::Profile) { - match get_profile( - db, - data.me.as_str(), - scope.has(&Scope::Email) - ).await { + match get_profile(db, data.me.as_str(), scope.has(&Scope::Email)).await { Ok(profile) => profile, Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { None }; - let access_token = match backend.create_token( - prepare_access_token(data.me.clone(), client_id.clone(), scope.clone()) - ).await { + let access_token = match backend + .create_token(prepare_access_token( + data.me.clone(), + client_id.clone(), + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating access token: {}", err); @@ -699,9 +809,14 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( }; let old_refresh_token = refresh_token; - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(data.me.clone(), client_id, scope.clone()) - ).await { + let refresh_token = match backend + .create_refresh_token(prepare_refresh_token( + data.me.clone(), + client_id, + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating refresh token: {}", err); @@ -721,8 +836,9 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( scope: Some(scope), expires_in: Some(ACCESS_TOKEN_VALIDITY), refresh_token: Some(refresh_token), - state: None - }.into_response() + state: None, + } + .into_response() } } } @@ -740,26 +856,39 @@ async fn introspection_endpoint_post<A: AuthBackend>( // Check authentication first match backend.get_token(&me, 
auth_token.token()).await {
-        Ok(Some(token)) => if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) {
-            return (StatusCode::UNAUTHORIZED, Json(json!({
-                "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope
-            }))).into_response();
-        },
-        Ok(None) => return (StatusCode::UNAUTHORIZED, Json(json!({
-            "error": kittybox_indieauth::ResourceErrorKind::InvalidToken
-        }))).into_response(),
+        Ok(Some(token)) => {
+            if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) {
+                return (
+                    StatusCode::UNAUTHORIZED,
+                    Json(json!({
+                        "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope
+                    })),
+                )
+                    .into_response();
+            }
+        }
+        Ok(None) => {
+            return (
+                StatusCode::UNAUTHORIZED,
+                Json(json!({
+                    "error": kittybox_indieauth::ResourceErrorKind::InvalidToken
+                })),
+            )
+                .into_response()
+        }
         Err(err) => {
             tracing::error!("Error retrieving token data for introspection: {}", err);
-            return StatusCode::INTERNAL_SERVER_ERROR.into_response()
+            return StatusCode::INTERNAL_SERVER_ERROR.into_response();
         }
     }

-    let response: TokenIntrospectionResponse = match backend.get_token(&me, &token_request.token).await {
-        Ok(maybe_data) => maybe_data.into(),
-        Err(err) => {
-            tracing::error!("Error retrieving token data: {}", err);
-            return StatusCode::INTERNAL_SERVER_ERROR.into_response()
-        }
-    };
+    let response: TokenIntrospectionResponse =
+        match backend.get_token(&me, &token_request.token).await {
+            Ok(maybe_data) => maybe_data.into(),
+            Err(err) => {
+                tracing::error!("Error retrieving token data: {}", err);
+                return StatusCode::INTERNAL_SERVER_ERROR.into_response();
+            }
+        };

     response.into_response()
 }
@@ -787,7 +916,7 @@ async fn revocation_endpoint_post<A: AuthBackend>(
 async fn get_profile<D: Storage + 'static>(
     db: D,
     url: &str,
-    email: bool
+    email: bool,
 ) -> crate::database::Result<Option<Profile>> {
     fn get_first(v: serde_json::Value) -> Option<String> {
         match v {
@@ -796,10 +925,10 @@ async fn get_profile<D: Storage + 'static>(
                 match a.pop() {
                     Some(serde_json::Value::String(s)) => Some(s),
                     Some(serde_json::Value::Object(mut o)) => o.remove("value").and_then(get_first),
-                    _ => None
+                    _ => None,
                 }
-            },
-            _ => None
+            }
+            _ => None,
         }
     }

@@ -807,15 +936,26 @@ async fn get_profile<D: Storage + 'static>(
         // Ruthlessly manually destructure the MF2 document to save memory
         let mut properties = match mf2.as_object_mut().unwrap().remove("properties") {
             Some(serde_json::Value::Object(props)) => props,
-            _ => unreachable!()
+            _ => unreachable!(),
         };
         drop(mf2);

         let name = properties.remove("name").and_then(get_first);
-        let url = properties.remove("uid").and_then(get_first).and_then(|u| u.parse().ok());
-        let photo = properties.remove("photo").and_then(get_first).and_then(|u| u.parse().ok());
+        let url = properties
+            .remove("uid")
+            .and_then(get_first)
+            .and_then(|u| u.parse().ok());
+        let photo = properties
+            .remove("photo")
+            .and_then(get_first)
+            .and_then(|u| u.parse().ok());
         // Only expose the e-mail when the Email scope was granted; the
         // profile property read here is "email", not a second read of "name".
         let email = if email {
             properties.remove("email").and_then(get_first)
         } else {
             None
         };

-        Profile { name, url, photo, email }
+        Profile {
+            name,
+            url,
+            photo,
+            email,
+        }
     }))
 }

@@ -823,7 +963,7 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>(
     Host(host): Host,
     TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>,
     State(backend): State<A>,
-    State(db): State<D>
+    State(db): State<D>,
 ) -> Response {
     use serde_json::json;

@@ -832,14 +972,22 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>(
     match backend.get_token(&me, auth_token.token()).await {
         Ok(Some(token)) => {
             if 
token.expired() { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + })), + ) + .into_response(); } if !token.scope.has(&Scope::Profile) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + })), + ) + .into_response(); } match get_profile(db, me.as_str(), token.scope.has(&Scope::Email)).await { @@ -847,17 +995,19 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( Ok(None) => Json(json!({ // We do this because ResourceErrorKind is IndieAuth errors only "error": "invalid_request" - })).into_response(), + })) + .into_response(), Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); StatusCode::INTERNAL_SERVER_ERROR.into_response() } } - }, + } Ok(None) => Json(json!({ "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - })).into_response(), + })) + .into_response(), Err(err) => { tracing::error!("Error reading token: {}", err); @@ -871,57 +1021,51 @@ where S: Storage + FromRef<St> + 'static, A: AuthBackend + FromRef<St>, reqwest_middleware::ClientWithMiddleware: FromRef<St>, - St: Clone + Send + Sync + 'static + St: Clone + Send + Sync + 'static, { - use axum::routing::{Router, get, post}; + use axum::routing::{get, post, Router}; Router::new() .nest( "/.kittybox/indieauth", Router::new() - .route("/metadata", - get(metadata)) + .route("/metadata", get(metadata)) .route( "/auth", get(authorization_endpoint_get::<A, S>) - .post(authorization_endpoint_post::<A, S>)) - .route( - "/auth/confirm", - post(authorization_endpoint_confirm::<A>)) - .route( - "/token", - post(token_endpoint_post::<A, S>)) - .route( - "/token_status", - post(introspection_endpoint_post::<A>)) - .route( - "/revoke_token", - post(revocation_endpoint_post::<A>)) + .post(authorization_endpoint_post::<A, S>), + ) + .route("/auth/confirm", post(authorization_endpoint_confirm::<A>)) + .route("/token", post(token_endpoint_post::<A, S>)) + .route("/token_status", post(introspection_endpoint_post::<A>)) + .route("/revoke_token", post(revocation_endpoint_post::<A>)) + .route("/userinfo", get(userinfo_endpoint_get::<A, S>)) .route( - "/userinfo", - get(userinfo_endpoint_get::<A, S>)) - - .route("/webauthn/pre_register", - get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, S>, - #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) - ) + "/webauthn/pre_register", + get( + #[cfg(feature = "webauthn")] + webauthn::webauthn_pre_register::<A, S>, + #[cfg(not(feature = "webauthn"))] + || std::future::ready(axum::http::StatusCode::NOT_FOUND), + ), ) - .layer(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST - ]) - .allow_origin(tower_http::cors::Any)) + .layer( + tower_http::cors::CorsLayer::new() + .allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_origin(tower_http::cors::Any), + ), ) .route( "/.well-known/oauth-authorization-server", - get(|| std::future::ready( - (StatusCode::FOUND, - [("Location", - "/.kittybox/indieauth/metadata")] - ).into_response() - )) + get(|| { + std::future::ready( + ( + StatusCode::FOUND, + 
[("Location", "/.kittybox/indieauth/metadata")], + ) + .into_response(), + ) + }), ) } @@ -929,9 +1073,10 @@ where mod tests { #[test] fn test_deserialize_authorization_confirmation() { - use super::{Credential, AuthorizationConfirmation}; + use super::{AuthorizationConfirmation, Credential}; - let confirmation = serde_json::from_str::<AuthorizationConfirmation>(r#"{ + let confirmation = serde_json::from_str::<AuthorizationConfirmation>( + r#"{ "request":{ "response_type": "code", "client_id": "https://quill.p3k.io/", @@ -942,12 +1087,14 @@ mod tests { "scope": "create+media" }, "authorization_method": "swordfish" - }"#).unwrap(); + }"#, + ) + .unwrap(); match confirmation.authorization_method { Credential::Password(password) => assert_eq!(password.as_str(), "swordfish"), #[allow(unreachable_patterns)] - other => panic!("Incorrect credential: {:?}", other) + other => panic!("Incorrect credential: {:?}", other), } assert_eq!(confirmation.request.state.as_ref(), "10101010"); } diff --git a/src/indieauth/webauthn.rs b/src/indieauth/webauthn.rs index 0757e72..80d210c 100644 --- a/src/indieauth/webauthn.rs +++ b/src/indieauth/webauthn.rs @@ -1,10 +1,17 @@ use axum::{ extract::Json, + http::StatusCode, response::{IntoResponse, Response}, - http::StatusCode, Extension + Extension, +}; +use axum_extra::extract::{ + cookie::{Cookie, CookieJar}, + Host, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + TypedHeader, }; -use axum_extra::extract::{Host, cookie::{CookieJar, Cookie}}; -use axum_extra::{TypedHeader, headers::{authorization::Bearer, Authorization}}; use super::backend::AuthBackend; use crate::database::Storage; @@ -12,40 +19,33 @@ use crate::database::Storage; pub(crate) const CHALLENGE_ID_COOKIE: &str = "kittybox_webauthn_challenge_id"; macro_rules! bail { - ($msg:literal, $err:expr) => { - { - ::tracing::error!($msg, $err); - return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } + ($msg:literal, $err:expr) => {{ + ::tracing::error!($msg, $err); + return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response(); + }}; } pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( Host(host): Host, Extension(db): Extension<D>, Extension(auth): Extension<A>, - cookies: CookieJar + cookies: CookieJar, ) -> Response { let uid = format!("https://{}/", host.clone()); let uid_url: url::Url = uid.parse().unwrap(); // This will not find an h-card in onboarding! 
let display_name = match db.get_post(&uid).await { Ok(hcard) => match hcard { - Some(mut hcard) => { - match hcard["properties"]["uid"][0].take() { - serde_json::Value::String(name) => name, - _ => String::default() - } + Some(mut hcard) => match hcard["properties"]["uid"][0].take() { + serde_json::Value::String(name) => name, + _ => String::default(), }, - None => String::default() + None => String::default(), }, - Err(err) => bail!("Error retrieving h-card: {}", err) + Err(err) => bail!("Error retrieving h-card: {}", err), }; - let webauthn = webauthn::WebauthnBuilder::new( - &host, - &uid_url - ) + let webauthn = webauthn::WebauthnBuilder::new(&host, &uid_url) .unwrap() .rp_name("Kittybox") .build() @@ -58,10 +58,10 @@ pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( webauthn::prelude::Uuid::nil(), &uid, &display_name, - Some(vec![]) + Some(vec![]), ) { Ok((challenge, state)) => (challenge, state), - Err(err) => bail!("Error generating WebAuthn registration data: {}", err) + Err(err) => bail!("Error generating WebAuthn registration data: {}", err), }; match auth.persist_registration_challenge(&uid_url, state).await { @@ -69,11 +69,12 @@ pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( cookies.add( Cookie::build((CHALLENGE_ID_COOKIE, challenge_id)) .secure(true) - .finish() + .finish(), ), - Json(challenge) - ).into_response(), - Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err) + Json(challenge), + ) + .into_response(), + Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err), } } @@ -82,39 +83,36 @@ pub async fn webauthn_register<A: AuthBackend>( Json(credential): Json<webauthn::prelude::RegisterPublicKeyCredential>, // TODO determine if we can use a cookie maybe? user_credential: Option<TypedHeader<Authorization<Bearer>>>, - Extension(auth): Extension<A> + Extension(auth): Extension<A>, ) -> Response { let uid = format!("https://{}/", host.clone()); let uid_url: url::Url = uid.parse().unwrap(); let pubkeys = match auth.list_webauthn_pubkeys(&uid_url).await { Ok(pubkeys) => pubkeys, - Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err) + Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err), }; if !pubkeys.is_empty() { if let Some(TypedHeader(Authorization(token))) = user_credential { // TODO check validity of the credential } else { - return StatusCode::UNAUTHORIZED.into_response() + return StatusCode::UNAUTHORIZED.into_response(); } } - return StatusCode::OK.into_response() + return StatusCode::OK.into_response(); } pub(crate) async fn verify<A: AuthBackend>( auth: &A, website: &url::Url, credential: webauthn::prelude::PublicKeyCredential, - challenge_id: &str + challenge_id: &str, ) -> std::io::Result<bool> { let host = website.host_str().unwrap(); - let webauthn = webauthn::WebauthnBuilder::new( - host, - website - ) + let webauthn = webauthn::WebauthnBuilder::new(host, website) .unwrap() .rp_name("Kittybox") .build() @@ -122,12 +120,14 @@ pub(crate) async fn verify<A: AuthBackend>( match webauthn.finish_passkey_authentication( &credential, - &auth.retrieve_authentication_challenge(&website, challenge_id).await? 
+ &auth + .retrieve_authentication_challenge(&website, challenge_id) + .await?, ) { Err(err) => { tracing::error!("WebAuthn error: {}", err); Ok(false) - }, + } Ok(authentication_result) => { let counter = authentication_result.counter(); let cred_id = authentication_result.cred_id(); diff --git a/src/lib.rs b/src/lib.rs index 4aeaca5..a52db4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,22 +4,28 @@ use std::sync::Arc; use axum::extract::{FromRef, FromRequestParts, OptionalFromRequestParts}; -use axum_extra::extract::{cookie::{Cookie, Key}, SignedCookieJar}; +use axum_extra::extract::{ + cookie::{Cookie, Key}, + SignedCookieJar, +}; use database::{FileStorage, PostgresStorage, Storage}; use indieauth::backend::{AuthBackend, FileBackend as FileAuthBackend}; use kittybox_util::queue::JobQueue; -use media::storage::{MediaStore, file::FileStore as FileMediaStore}; -use tokio::{sync::{Mutex, RwLock}, task::JoinSet}; +use media::storage::{file::FileStore as FileMediaStore, MediaStore}; +use tokio::{ + sync::{Mutex, RwLock}, + task::JoinSet, +}; use webmentions::queue::PostgresJobQueue; /// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. pub mod database; pub mod frontend; +pub mod indieauth; +pub mod login; pub mod media; pub mod micropub; -pub mod indieauth; pub mod webmentions; -pub mod login; //pub mod admin; const OAUTH2_SOFTWARE_ID: &str = "6f2eee84-c22c-4c9e-b900-10d4e97273c8"; @@ -27,10 +33,10 @@ const OAUTH2_SOFTWARE_ID: &str = "6f2eee84-c22c-4c9e-b900-10d4e97273c8"; #[derive(Clone)] pub struct AppState<A, S, M, Q> where -A: AuthBackend + Sized + 'static, -S: Storage + Sized + 'static, -M: MediaStore + Sized + 'static, -Q: JobQueue<webmentions::Webmention> + Sized + A: AuthBackend + Sized + 'static, + S: Storage + Sized + 'static, + M: MediaStore + Sized + 'static, + Q: JobQueue<webmentions::Webmention> + Sized, { pub auth_backend: A, pub storage: S, @@ -39,7 +45,7 @@ Q: JobQueue<webmentions::Webmention> + Sized pub http: reqwest_middleware::ClientWithMiddleware, pub background_jobs: Arc<Mutex<JoinSet<()>>>, pub cookie_key: Key, - pub session_store: SessionStore + pub session_store: SessionStore, } pub type SessionStore = Arc<RwLock<std::collections::HashMap<uuid::Uuid, Session>>>; @@ -60,7 +66,11 @@ pub struct NoSessionError; impl axum::response::IntoResponse for NoSessionError { fn into_response(self) -> axum::response::Response { // TODO: prettier error message - (axum::http::StatusCode::UNAUTHORIZED, "You are not logged in, but this page requires a session.").into_response() + ( + axum::http::StatusCode::UNAUTHORIZED, + "You are not logged in, but this page requires a session.", + ) + .into_response() } } @@ -72,11 +82,17 @@ where { type Rejection = std::convert::Infallible; - async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Option<Self>, Self::Rejection> { - let jar = SignedCookieJar::<Key>::from_request_parts(parts, state).await.unwrap(); + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + let jar = SignedCookieJar::<Key>::from_request_parts(parts, state) + .await + .unwrap(); let session_store = SessionStore::from_ref(state).read_owned().await; - Ok(jar.get("session_id") + Ok(jar + .get("session_id") .as_ref() .map(Cookie::value_trimmed) .and_then(|id| uuid::Uuid::parse_str(id).ok()) @@ -103,7 +119,10 @@ where // have to repeat this magic invocation. 
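// A hypothetical way to trim this repetition would be a small declarative
// macro (sketch only; `impl_from_ref!` does not exist in this codebase):
//
//     macro_rules! impl_from_ref {
//         ($target:ty, $field:ident) => {
//             impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for $target
//             where
//                 A: AuthBackend,
//                 S: Storage,
//                 M: MediaStore,
//                 Q: JobQueue<webmentions::Webmention>,
//             {
//                 fn from_ref(input: &AppState<A, S, M, Q>) -> Self {
//                     input.$field.clone()
//                 }
//             }
//         };
//     }
//
// The wrinkle, and the likely reason the impls are written out by hand: some
// targets occupy one of the AppState parameters themselves (e.g.
// `AppState<Self, S, M, Q>`), so a single macro arm cannot cover every slot.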
impl<S, M, Q> FromRef<AppState<Self, S, M, Q>> for FileAuthBackend -where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<Self, S, M, Q>) -> Self { input.auth_backend.clone() @@ -111,7 +130,10 @@ where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> } impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for PostgresStorage -where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { input.storage.clone() @@ -119,7 +141,10 @@ where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> } impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for FileStorage -where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { input.storage.clone() @@ -128,7 +153,10 @@ where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> impl<A, S, Q> FromRef<AppState<A, S, Self, Q>> for FileMediaStore // where A: AuthBackend, S: Storage -where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, Self, Q>) -> Self { input.media_store.clone() @@ -136,7 +164,11 @@ where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Key -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.cookie_key.clone() @@ -144,7 +176,11 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for reqwest_middleware::ClientWithMiddleware -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.http.clone() @@ -152,7 +188,11 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Arc<Mutex<JoinSet<()>>> -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.background_jobs.clone() @@ -161,7 +201,10 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen #[cfg(feature = "sqlx")] impl<A, S, M> FromRef<AppState<A, S, M, Self>> for PostgresJobQueue<webmentions::Webmention> -where A: AuthBackend, S: Storage, M: MediaStore +where + A: AuthBackend, + S: Storage, + M: MediaStore, { fn from_ref(input: &AppState<A, S, M, Self>) -> Self { input.job_queue.clone() @@ -169,7 +212,11 @@ where A: AuthBackend, S: Storage, M: MediaStore } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for SessionStore -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: 
JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.session_store.clone() @@ -177,23 +224,26 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } pub mod companion { - use std::{collections::HashMap, sync::Arc}; use axum::{ extract::{Extension, Path}, - response::{IntoResponse, Response} + response::{IntoResponse, Response}, }; + use std::{collections::HashMap, sync::Arc}; #[derive(Debug, Clone, Copy)] struct Resource { data: &'static [u8], - mime: &'static str + mime: &'static str, } impl IntoResponse for &Resource { fn into_response(self) -> Response { - (axum::http::StatusCode::OK, - [("Content-Type", self.mime)], - self.data).into_response() + ( + axum::http::StatusCode::OK, + [("Content-Type", self.mime)], + self.data, + ) + .into_response() } } @@ -203,17 +253,21 @@ pub mod companion { #[tracing::instrument] async fn map_to_static( Path(name): Path<String>, - Extension(resources): Extension<ResourceTable> + Extension(resources): Extension<ResourceTable>, ) -> Response { tracing::debug!("Searching for {} in the resource table...", name); match resources.get(name.as_str()) { Some(res) => res.into_response(), None => { - #[cfg(debug_assertions)] tracing::error!("Not found"); + #[cfg(debug_assertions)] + tracing::error!("Not found"); - (axum::http::StatusCode::NOT_FOUND, - [("Content-Type", "text/plain")], - "Not found. Sorry.".as_bytes()).into_response() + ( + axum::http::StatusCode::NOT_FOUND, + [("Content-Type", "text/plain")], + "Not found. Sorry.".as_bytes(), + ) + .into_response() } } } @@ -249,47 +303,52 @@ pub mod companion { Arc::new(map) }; - axum::Router::new() - .route( - "/{filename}", - axum::routing::get(map_to_static) - .layer(Extension(resources)) - ) + axum::Router::new().route( + "/{filename}", + axum::routing::get(map_to_static).layer(Extension(resources)), + ) } } async fn teapot_route() -> impl axum::response::IntoResponse { use axum::http::{header, StatusCode}; - (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") + ( + StatusCode::IM_A_TEAPOT, + [(header::CONTENT_TYPE, "text/plain")], + "Sorry, can't brew coffee yet!", + ) } async fn health_check<D>( axum::extract::State(data): axum::extract::State<D>, ) -> impl axum::response::IntoResponse where - D: crate::database::Storage + D: crate::database::Storage, { (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) } pub async fn compose_kittybox<St, A, S, M, Q>() -> axum::Router<St> where -A: AuthBackend + 'static + FromRef<St>, -S: Storage + 'static + FromRef<St>, -M: MediaStore + 'static + FromRef<St>, -Q: kittybox_util::queue::JobQueue<crate::webmentions::Webmention> + FromRef<St>, -reqwest_middleware::ClientWithMiddleware: FromRef<St>, -Arc<Mutex<JoinSet<()>>>: FromRef<St>, -crate::SessionStore: FromRef<St>, -axum_extra::extract::cookie::Key: FromRef<St>, -St: Clone + Send + Sync + 'static + A: AuthBackend + 'static + FromRef<St>, + S: Storage + 'static + FromRef<St>, + M: MediaStore + 'static + FromRef<St>, + Q: kittybox_util::queue::JobQueue<crate::webmentions::Webmention> + FromRef<St>, + reqwest_middleware::ClientWithMiddleware: FromRef<St>, + Arc<Mutex<JoinSet<()>>>: FromRef<St>, + crate::SessionStore: FromRef<St>, + axum_extra::extract::cookie::Key: FromRef<St>, + St: Clone + Send + Sync + 'static, { use axum::routing::get; axum::Router::new() .route("/", get(crate::frontend::homepage::<S>)) .fallback(get(crate::frontend::catchall::<S>)) .route("/.kittybox/micropub", 
crate::micropub::router::<A, S, St>()) - .route("/.kittybox/onboarding", crate::frontend::onboarding::router::<St, S>()) + .route( + "/.kittybox/onboarding", + crate::frontend::onboarding::router::<St, S>(), + ) .nest("/.kittybox/media", crate::media::router::<St, A, M>()) .merge(crate::indieauth::router::<St, A, S>()) .merge(crate::webmentions::router::<St, Q>()) @@ -297,34 +356,36 @@ St: Clone + Send + Sync + 'static .nest("/.kittybox/login", crate::login::router::<St, S>()) .route( "/.kittybox/static/{*path}", - axum::routing::get(crate::frontend::statics) + axum::routing::get(crate::frontend::statics), ) .route("/.kittybox/coffee", get(teapot_route)) - .nest("/.kittybox/micropub/client", crate::companion::router::<St>()) + .nest( + "/.kittybox/micropub/client", + crate::companion::router::<St>(), + ) .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::catch_panic::CatchPanicLayer::new()) - .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ - axum::http::header::AUTHORIZATION, - axum::http::header::COOKIE, - axum::http::header::SET_COOKIE, - ])) + .layer( + tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ + axum::http::header::AUTHORIZATION, + axum::http::header::COOKIE, + axum::http::header::SET_COOKIE, + ]), + ) .layer(tower_http::set_header::SetResponseHeaderLayer::appending( axum::http::header::CONTENT_SECURITY_POLICY, - axum::http::HeaderValue::from_static( - concat!( - "default-src 'none';", // Do not allow unknown things we didn't foresee. - "img-src https:;", // Allow hotlinking images from anywhere. - "form-action 'self';", // Only allow sending forms back to us. - "media-src 'self';", // Only allow embedding media from us. - "script-src 'self';", // Only run scripts we serve. - "style-src 'self';", // Only use styles we serve. - "base-uri 'none';", // Do not allow to change the base URI. - "object-src 'none';", // Do not allow to embed objects (Flash/ActiveX). - - // Allow embedding the Bandcamp player for jam posts. - // TODO: perhaps make this policy customizable?… - "frame-src 'self' https://bandcamp.com/EmbeddedPlayer/;" - ) - ) + axum::http::HeaderValue::from_static(concat!( + "default-src 'none';", // Do not allow unknown things we didn't foresee. + "img-src https:;", // Allow hotlinking images from anywhere. + "form-action 'self';", // Only allow sending forms back to us. + "media-src 'self';", // Only allow embedding media from us. + "script-src 'self';", // Only run scripts we serve. + "style-src 'self';", // Only use styles we serve. + "base-uri 'none';", // Do not allow to change the base URI. + "object-src 'none';", // Do not allow to embed objects (Flash/ActiveX). + // Allow embedding the Bandcamp player for jam posts. 
+ // TODO: perhaps make this policy customizable?… + "frame-src 'self' https://bandcamp.com/EmbeddedPlayer/;" + )), )) } diff --git a/src/login.rs b/src/login.rs index eaa787c..3038d9c 100644 --- a/src/login.rs +++ b/src/login.rs @@ -1,10 +1,25 @@ use std::{borrow::Cow, str::FromStr}; +use axum::{ + extract::{FromRef, Query, State}, + http::HeaderValue, + response::IntoResponse, + Form, +}; +use axum_extra::{ + extract::{ + cookie::{self, Cookie}, + Host, SignedCookieJar, + }, + headers::HeaderMapExt, + TypedHeader, +}; use futures_util::FutureExt; -use axum::{extract::{FromRef, Query, State}, http::HeaderValue, response::IntoResponse, Form}; -use axum_extra::{extract::{Host, cookie::{self, Cookie}, SignedCookieJar}, headers::HeaderMapExt, TypedHeader}; -use hyper::{header::{CACHE_CONTROL, LOCATION}, StatusCode}; -use kittybox_frontend_renderer::{Template, LoginPage, LogoutPage}; +use hyper::{ + header::{CACHE_CONTROL, LOCATION}, + StatusCode, +}; +use kittybox_frontend_renderer::{LoginPage, LogoutPage, Template}; use kittybox_indieauth::{AuthorizationResponse, Error, GrantType, PKCEVerifier, Scope, Scopes}; use sha2::Digest; @@ -13,14 +28,13 @@ use crate::database::Storage; /// Show a login page. async fn get<S: Storage + Send + Sync + 'static>( State(db): State<S>, - Host(host): Host + Host(host): Host, ) -> impl axum::response::IntoResponse { let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( @@ -34,14 +48,15 @@ async fn get<S: Storage + Send + Sync + 'static>( blog_name: blogname.as_ref(), feeds: channels, user: None, - content: LoginPage {}.to_string() - }.to_string() + content: LoginPage {}.to_string(), + } + .to_string(), ) } #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] struct LoginForm { - url: url::Url + url: url::Url, } /// Accept login and start the IndieAuth dance. 
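// The handler below drives the client side of the flow with PKCE (RFC 7636).
// For reference, the S256 transform that `PKCEChallenge::new` is expected to
// apply is base64url-without-padding over the SHA-256 of the verifier; a
// minimal sketch using the sha2 and data_encoding crates (both already in
// the dependency tree) would be:
//
//     use sha2::{Digest, Sha256};
//
//     fn s256_challenge(verifier: &str) -> String {
//         let digest = Sha256::digest(verifier.as_bytes());
//         data_encoding::BASE64URL_NOPAD.encode(&digest)
//     }
//
// The verifier itself is stashed in a Secure, HttpOnly cookie below, so only
// the /.kittybox/login/finish callback on this host can later redeem the code.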
@@ -60,10 +75,12 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); - let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host).parse().unwrap(); + let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host) + .parse() + .unwrap(); let redirect_uri = { let mut uri = client_id.clone(); uri.set_path("/.kittybox/login/finish"); @@ -71,11 +88,15 @@ async fn post( }; let indieauth_state = kittybox_indieauth::AuthorizationRequest { response_type: kittybox_indieauth::ResponseType::Code, - client_id, redirect_uri, + client_id, + redirect_uri, state: kittybox_indieauth::State::new(), - code_challenge: kittybox_indieauth::PKCEChallenge::new(&code_verifier, kittybox_indieauth::PKCEMethod::S256), + code_challenge: kittybox_indieauth::PKCEChallenge::new( + &code_verifier, + kittybox_indieauth::PKCEMethod::S256, + ), scope: Some(Scopes::new(vec![Scope::Profile])), - me: Some(form.url.clone()) + me: Some(form.url.clone()), }; // Fetch the user's homepage, determine their authorization endpoint @@ -89,8 +110,9 @@ async fn post( tracing::error!("Error fetching homepage: {:?}", err); return ( StatusCode::BAD_REQUEST, - format!("couldn't fetch your homepage: {}", err) - ).into_response() + format!("couldn't fetch your homepage: {}", err), + ) + .into_response(); } }; @@ -106,22 +128,27 @@ async fn post( // .collect::<Vec<axum_extra::headers::Link>>(); // // todo!("parse Link: headers") - + let body = match response.text().await { Ok(body) => match microformats::from_html(&body, form.url) { Ok(mf2) => mf2, - Err(err) => return ( - StatusCode::BAD_REQUEST, - format!("error while parsing your homepage with mf2: {}", err) - ).into_response() + Err(err) => { + return ( + StatusCode::BAD_REQUEST, + format!("error while parsing your homepage with mf2: {}", err), + ) + .into_response() + } }, - Err(err) => return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("error while fetching your homepage: {}", err) - ).into_response() + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("error while fetching your homepage: {}", err), + ) + .into_response() + } }; - let mut iss: Option<url::Url> = None; let mut authorization_endpoint = match body .rels @@ -139,10 +166,22 @@ async fn post( Ok(metadata) => { iss = Some(metadata.issuer); metadata.authorization_endpoint - }, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't parse your oauth2 metadata: {}", err)).into_response() + } + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't parse your oauth2 metadata: {}", err), + ) + .into_response() + } }, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't fetch your oauth2 metadata: {}", err)).into_response() + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't fetch your oauth2 metadata: {}", err), + ) + .into_response() + } }, None => match body .rels @@ -151,13 +190,17 @@ async fn post( .map(|v| v.as_slice()) .unwrap_or_default() .first() - .cloned() { - Some(authorization_endpoint) => authorization_endpoint, - None => return ( + .cloned() + { + Some(authorization_endpoint) => authorization_endpoint, + None => { + return ( StatusCode::BAD_REQUEST, - "no authorization endpoint was found on your homepage." 
- ).into_response() + "no authorization endpoint was found on your homepage.", + ) + .into_response() } + }, }; cookies = cookies.add( @@ -166,7 +209,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); if let Some(iss) = iss { @@ -176,7 +219,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); } @@ -186,7 +229,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); authorization_endpoint @@ -194,9 +237,12 @@ async fn post( .extend_pairs(indieauth_state.as_query_pairs().iter()); tracing::debug!("Forwarding user to {}", authorization_endpoint); - (StatusCode::FOUND, [ - ("Location", authorization_endpoint.to_string()), - ], cookies).into_response() + ( + StatusCode::FOUND, + [("Location", authorization_endpoint.to_string())], + cookies, + ) + .into_response() } /// Accept the return of the IndieAuth dance. Set a cookie for the @@ -208,7 +254,9 @@ async fn callback( State(http): State<reqwest_middleware::ClientWithMiddleware>, State(session_store): State<crate::SessionStore>, ) -> axum::response::Response { - let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host).parse().unwrap(); + let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host) + .parse() + .unwrap(); let redirect_uri = { let mut uri = client_id.clone(); uri.set_path("/.kittybox/login/finish"); @@ -218,7 +266,8 @@ async fn callback( let me: url::Url = cookie_jar.get("me").unwrap().value().parse().unwrap(); let code_verifier: PKCEVerifier = cookie_jar.get("code_verifier").unwrap().value().into(); - let authorization_endpoint: url::Url = cookie_jar.get("authorization_endpoint") + let authorization_endpoint: url::Url = cookie_jar + .get("authorization_endpoint") .and_then(|v| v.value().parse().ok()) .unwrap(); match cookie_jar.get("iss").and_then(|c| c.value().parse().ok()) { @@ -232,24 +281,59 @@ async fn callback( code: response.code, client_id, redirect_uri, - code_verifier, + code_verifier, }; - tracing::debug!("POSTing {:?} to authorization endpoint {}", grant_request, authorization_endpoint); - let res = match http.post(authorization_endpoint) + tracing::debug!( + "POSTing {:?} to authorization endpoint {}", + grant_request, + authorization_endpoint + ); + let res = match http + .post(authorization_endpoint) .form(&grant_request) .header(reqwest::header::ACCEPT, "application/json") .send() .await { - Ok(res) if res.status().is_success() => match res.json::<kittybox_indieauth::GrantResponse>().await { - Ok(grant) => grant, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error parsing authorization endpoint response: {}", err)).into_response() - }, + Ok(res) if res.status().is_success() => { + match res.json::<kittybox_indieauth::GrantResponse>().await { + Ok(grant) => grant, + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error parsing authorization endpoint response: {}", err), + ) + .into_response() + } + } + } Ok(res) => match res.json::<Error>().await { - Ok(err) => return (StatusCode::BAD_REQUEST, [(CACHE_CONTROL, "no-store")], err.to_string()).into_response(), - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error parsing indieauth error: {}", err)).into_response() + Ok(err) => { + return ( + StatusCode::BAD_REQUEST, + [(CACHE_CONTROL, "no-store")], + err.to_string(), + ) + .into_response() + } + 
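// Error taxonomy for redeeming the authorization code, as handled here: a
// non-success reply that parses as an IndieAuth `Error` becomes a 400
// carrying the error's own message (arm above); an unparseable error body or
// a transport-level failure becomes a 500 (arms below). Every branch sets
// Cache-Control: no-store so intermediaries never cache an auth response.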
Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error parsing indieauth error: {}", err), + ) + .into_response() + } + }, + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error redeeming authorization code: {}", err), + ) + .into_response() } - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error redeeming authorization code: {}", err)).into_response() }; let profile = match res { @@ -265,19 +349,28 @@ async fn callback( let uuid = uuid::Uuid::new_v4(); session_store.write().await.insert(uuid, session); let cookies = cookie_jar - .add(Cookie::build(("session_id", uuid.to_string())) - .expires(None) - .secure(true) - .http_only(true) - .path("/") - .build() + .add( + Cookie::build(("session_id", uuid.to_string())) + .expires(None) + .secure(true) + .http_only(true) + .path("/") + .build(), ) .remove("authorization_endpoint") .remove("me") .remove("iss") .remove("code_verifier"); - (StatusCode::FOUND, [(LOCATION, HeaderValue::from_static("/")), (CACHE_CONTROL, HeaderValue::from_static("no-store"))], dbg!(cookies)).into_response() + ( + StatusCode::FOUND, + [ + (LOCATION, HeaderValue::from_static("/")), + (CACHE_CONTROL, HeaderValue::from_static("no-store")), + ], + dbg!(cookies), + ) + .into_response() } /// Show the form necessary for logout. If JS is enabled, @@ -288,32 +381,42 @@ async fn callback( /// stupid enough to execute JS and send a POST request though, that's /// on the crawler. async fn logout_page() -> impl axum::response::IntoResponse { - (StatusCode::OK, [("Content-Type", "text/html")], Template { - title: "Signing out...", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: LogoutPage {}.to_string() - }.to_string()) + ( + StatusCode::OK, + [("Content-Type", "text/html")], + Template { + title: "Signing out...", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: LogoutPage {}.to_string(), + } + .to_string(), + ) } /// Erase the necessary cookies for login and invalidate the session. 
async fn logout(
     mut cookies: SignedCookieJar,
-    State(session_store): State<crate::SessionStore>
-) -> (StatusCode, [(&'static str, &'static str); 1], SignedCookieJar) {
-    if let Some(id) = cookies.get("session_id")
+    State(session_store): State<crate::SessionStore>,
+) -> (
+    StatusCode,
+    [(&'static str, &'static str); 1],
+    SignedCookieJar,
+) {
+    if let Some(id) = cookies
+        .get("session_id")
         .and_then(|c| uuid::Uuid::parse_str(c.value_trimmed()).ok())
     {
         session_store.write().await.remove(&id);
     }

-    cookies = cookies.remove("me")
+    cookies = cookies
+        .remove("me")
         .remove("iss")
         .remove("authorization_endpoint")
         .remove("code_verifier")
         .remove("session_id");
-
     (StatusCode::FOUND, [("Location", "/")], cookies)
 }
@@ -343,7 +446,7 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>(
     };
     if let Some(cached) = cached {
         if cached.precondition_passes(&etag) {
-            return StatusCode::NOT_MODIFIED.into_response()
+            return StatusCode::NOT_MODIFIED.into_response();
         }
     }
     let client_uri: url::Url = format!("https://{}/", host).parse().unwrap();
@@ -356,7 +459,13 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>(

     let mut metadata = kittybox_indieauth::ClientMetadata::new(client_id, client_uri).unwrap();

-    metadata.client_name = Some(storage.get_setting::<crate::database::settings::SiteName>(&metadata.client_uri).await.unwrap_or_default().0);
+    metadata.client_name = Some(
+        storage
+            .get_setting::<crate::database::settings::SiteName>(&metadata.client_uri)
+            .await
+            .unwrap_or_default()
+            .0,
+    );
     metadata.grant_types = Some(vec![GrantType::AuthorizationCode]);
     // We don't request anything more than the profile scope.
     metadata.scope = Some(Scopes::new(vec![Scope::Profile]));
@@ -368,15 +477,18 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>(
     // identity providers, or json to match newest spec
     let mut response = metadata.into_response();
     // Indicate to upstream caches this endpoint does different things depending on the Accept: header.
-    response.headers_mut().append("Vary", HeaderValue::from_static("Accept"));
+    response
+        .headers_mut()
+        .append("Vary", HeaderValue::from_static("Accept"));
     // Cache this metadata for ten minutes.
-    response.headers_mut().append("Cache-Control", HeaderValue::from_static("max-age=600"));
+    response
+        .headers_mut()
+        .append("Cache-Control", HeaderValue::from_static("max-age=600"));
     response.headers_mut().typed_insert(etag);

     response
 }

-
 /// Produce a router for all of the above.
pub fn router<St, S>() -> axum::routing::Router<St> where diff --git a/src/main.rs b/src/main.rs index bd3684e..984745a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,11 @@ -use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention, compose_kittybox}; -use tokio::{sync::Mutex, task::JoinSet}; +use kittybox::{ + compose_kittybox, database::Storage, indieauth::backend::AuthBackend, + media::storage::MediaStore, webmentions::Webmention, +}; use std::{env, future::IntoFuture, sync::Arc}; +use tokio::{sync::Mutex, task::JoinSet}; use tracing::error; - #[tokio::main] async fn main() { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; @@ -17,32 +19,28 @@ async fn main() { .with_indent_lines(true) .with_verbose_exit(true), #[cfg(not(debug_assertions))] - tracing_subscriber::fmt::layer().json() - .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) + tracing_subscriber::fmt::layer() + .json() + .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())), ); // In debug builds, also log to JSON, but to file. #[cfg(debug_assertions)] - let tracing_registry = tracing_registry.with( - tracing_subscriber::fmt::layer() - .json() - .with_writer({ - let instant = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap(); - move || std::fs::OpenOptions::new() + let tracing_registry = + tracing_registry.with(tracing_subscriber::fmt::layer().json().with_writer({ + let instant = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + move || { + std::fs::OpenOptions::new() .append(true) .create(true) - .open( - format!( - "{}.log.json", - instant - .as_secs_f64() - .to_string() - .replace('.', "_") - ) - ).unwrap() - }) - ); + .open(format!( + "{}.log.json", + instant.as_secs_f64().to_string().replace('.', "_") + )) + .unwrap() + } + })); tracing_registry.init(); tracing::info!("Starting the kittybox server..."); @@ -79,12 +77,15 @@ async fn main() { }); // TODO: load from environment - let cookie_key = axum_extra::extract::cookie::Key::from(&env::var("COOKIE_KEY") - .as_deref() - .map(|s| data_encoding::BASE64_MIME_PERMISSIVE.decode(s.as_bytes()) - .expect("Invalid cookie key: must be base64 encoded") - ) - .unwrap() + let cookie_key = axum_extra::extract::cookie::Key::from( + &env::var("COOKIE_KEY") + .as_deref() + .map(|s| { + data_encoding::BASE64_MIME_PERMISSIVE + .decode(s.as_bytes()) + .expect("Invalid cookie key: must be base64 encoded") + }) + .unwrap(), ); let cancellation_token = tokio_util::sync::CancellationToken::new(); @@ -93,12 +94,11 @@ async fn main() { let http: reqwest_middleware::ClientWithMiddleware = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { // TODO: add a root certificate if there's an environment variable pointing at it for path in certs.split(':') { @@ -108,21 +108,19 @@ async fn main() { tracing::error!("TLS root certificate {} not found, skipping...", path); continue; } - Err(err) => panic!("Error loading TLS certificates: {}", err) + Err(err) => panic!("Error loading TLS certificates: {}", err), }; if metadata.is_dir() { let mut dir = tokio::fs::read_dir(path).await.unwrap(); 
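// Aside: the branch above in miniature. Adding one extra PKI root to
// the shared reqwest client is just (sketch, using only the `reqwest`
// and `tokio` APIs already imported here; the path is hypothetical):
//
//     let pem = tokio::fs::read("/etc/kittybox/ca.pem").await.unwrap();
//     builder = builder
//         .add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap());
//
// KITTYBOX_CUSTOM_PKI_ROOTS holds a colon-separated list of PEM files
// or directories of PEM files, parsed PATH-style by this loop.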
while let Ok(Some(file)) = dir.next_entry().await { let pem = tokio::fs::read(file.path()).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); + builder = builder + .add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap()); } } else { let pem = tokio::fs::read(path).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); + builder = + builder.add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap()); } } } @@ -151,7 +149,7 @@ async fn main() { let job_queue_type = job_queue_uri.scheme(); macro_rules! compose_kittybox { - ($auth:ty, $store:ty, $media:ty, $queue:ty) => { { + ($auth:ty, $store:ty, $media:ty, $queue:ty) => {{ type AuthBackend = $auth; type Storage = $store; type MediaStore = $media; @@ -193,36 +191,43 @@ async fn main() { }; type St = kittybox::AppState<AuthBackend, Storage, MediaStore, JobQueue>; - let stateful_router = compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; - let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>(&state, cancellation_token.clone()); + let stateful_router = + compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; + let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>( + &state, + cancellation_token.clone(), + ); let router = stateful_router.with_state(state); (router, task) - } } + }}; } - let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) { - ("file", "file", "file", "postgres") => { - compose_kittybox!( - kittybox::indieauth::backend::fs::FileBackend, - kittybox::database::FileStorage, - kittybox::media::storage::file::FileStore, - kittybox::webmentions::queue::PostgresJobQueue<Webmention> - ) - }, - ("file", "postgres", "file", "postgres") => { - compose_kittybox!( - kittybox::indieauth::backend::fs::FileBackend, - kittybox::database::PostgresStorage, - kittybox::media::storage::file::FileStore, - kittybox::webmentions::queue::PostgresJobQueue<Webmention> - ) - }, - (_, _, _, _) => { - // TODO: refine this error. - panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI"); - } - }; + let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = + match (authstore_type, backend_type, blobstore_type, job_queue_type) { + ("file", "file", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::FileStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + } + ("file", "postgres", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::PostgresStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + } + (_, _, _, _) => { + // TODO: refine this error. 
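// (Reaching this arm means one of AUTH_STORE_URI, BACKEND_URI,
// BLOBSTORE_URI or JOB_QUEUE_URI used a scheme combination that no
// `compose_kittybox!` arm above is wired up for; only the
// file/file/file/postgres and file/postgres/file/postgres
// combinations are currently supported.)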
+ panic!( + "Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI" + ); + } + }; let mut servers: Vec<axum::serve::Serve<_, _, _>> = vec![]; @@ -238,7 +243,7 @@ async fn main() { // .serve(router.clone().into_make_service()) axum::serve( tokio::net::TcpListener::from_std(tcp).unwrap(), - router.clone() + router.clone(), ) }; @@ -246,8 +251,8 @@ async fn main() { for i in 0..(listenfd.len()) { match listenfd.take_tcp_listener(i) { Ok(Some(tcp)) => servers.push(build_hyper(tcp)), - Ok(None) => {}, - Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) + Ok(None) => {} + Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err), } } // TODO this requires the `hyperlocal` crate @@ -302,24 +307,35 @@ async fn main() { // to get rid of an extra reference to `jobset` drop(router); // Polling streams mutates them - let mut servers_futures = Box::pin(servers.into_iter() - .map( - #[cfg(not(tokio_unstable))] |server| tokio::task::spawn( - server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) - .into_future() - ), - #[cfg(tokio_unstable)] |server| { - tokio::task::Builder::new() - .name(format!("Kittybox HTTP acceptor: {:?}", server).as_str()) - .spawn( - server.with_graceful_shutdown( - cancellation_token.clone().cancelled_owned() - ).into_future() + let mut servers_futures = Box::pin( + servers + .into_iter() + .map( + #[cfg(not(tokio_unstable))] + |server| { + tokio::task::spawn( + server + .with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) + .into_future(), ) - .unwrap() - } - ) - .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), std::io::Error>>>>() + }, + #[cfg(tokio_unstable)] + |server| { + tokio::task::Builder::new() + .name(format!("Kittybox HTTP acceptor: {:?}", server).as_str()) + .spawn( + server + .with_graceful_shutdown( + cancellation_token.clone().cancelled_owned(), + ) + .into_future(), + ) + .unwrap() + }, + ) + .collect::<futures_util::stream::FuturesUnordered< + tokio::task::JoinHandle<Result<(), std::io::Error>>, + >>(), ); #[cfg(not(unix))] @@ -329,10 +345,10 @@ async fn main() { use tokio::signal::unix::{signal, SignalKind}; async move { - let mut interrupt = signal(SignalKind::interrupt()) - .expect("Failed to set up SIGINT handler"); - let mut terminate = signal(SignalKind::terminate()) - .expect("Failed to setup SIGTERM handler"); + let mut interrupt = + signal(SignalKind::interrupt()).expect("Failed to set up SIGINT handler"); + let mut terminate = + signal(SignalKind::terminate()).expect("Failed to setup SIGTERM handler"); tokio::select! 
{ _ = terminate.recv() => {}, diff --git a/src/media/mod.rs b/src/media/mod.rs index 6f263b6..7e52414 100644 --- a/src/media/mod.rs +++ b/src/media/mod.rs @@ -1,22 +1,23 @@ +use crate::indieauth::{backend::AuthBackend, User}; use axum::{ - extract::{multipart::Multipart, FromRef, Path, State}, response::{IntoResponse, Response} + extract::{multipart::Multipart, FromRef, Path, State}, + response::{IntoResponse, Response}, }; -use axum_extra::headers::{ContentLength, HeaderMapExt, HeaderValue, IfNoneMatch}; use axum_extra::extract::Host; +use axum_extra::headers::{ContentLength, HeaderMapExt, HeaderValue, IfNoneMatch}; use axum_extra::TypedHeader; -use kittybox_util::micropub::{Error as MicropubError, ErrorKind as ErrorType}; use kittybox_indieauth::Scope; -use crate::indieauth::{backend::AuthBackend, User}; +use kittybox_util::micropub::{Error as MicropubError, ErrorKind as ErrorType}; pub mod storage; -use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; pub use storage::file::FileStore; +use storage::{ErrorKind, MediaStore, MediaStoreError, Metadata}; impl From<MediaStoreError> for MicropubError { fn from(err: MediaStoreError) -> Self { Self::new( ErrorType::InternalServerError, - format!("media store error: {}", err) + format!("media store error: {}", err), ) } } @@ -25,13 +26,14 @@ impl From<MediaStoreError> for MicropubError { pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( State(blobstore): State<S>, user: User<A>, - mut upload: Multipart + mut upload: Multipart, ) -> Response { if !user.check_scope(&Scope::Media) { return MicropubError::from_static( ErrorType::NotAuthorized, - "Interacting with the media storage requires the \"media\" scope." - ).into_response(); + "Interacting with the media storage requires the \"media\" scope.", + ) + .into_response(); } let host = user.me.authority(); let field = match upload.next_field().await { @@ -39,27 +41,31 @@ pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( Ok(None) => { return MicropubError::from_static( ErrorType::InvalidRequest, - "Send multipart/form-data with one field named file" - ).into_response(); - }, + "Send multipart/form-data with one field named file", + ) + .into_response(); + } Err(err) => { return MicropubError::new( ErrorType::InternalServerError, - format!("Error while parsing multipart/form-data: {}", err) - ).into_response(); - }, + format!("Error while parsing multipart/form-data: {}", err), + ) + .into_response(); + } }; let metadata: Metadata = (&field).into(); match blobstore.write_streaming(host, metadata, field).await { Ok(filename) => IntoResponse::into_response(( axum::http::StatusCode::CREATED, - [ - ("Location", user.me.join( - &format!(".kittybox/media/uploads/{}", filename) - ).unwrap().as_str()) - ] + [( + "Location", + user.me + .join(&format!(".kittybox/media/uploads/{}", filename)) + .unwrap() + .as_str(), + )], )), - Err(err) => MicropubError::from(err).into_response() + Err(err) => MicropubError::from(err).into_response(), } } @@ -68,7 +74,7 @@ pub(crate) async fn serve<S: MediaStore>( Host(host): Host, Path(path): Path<String>, if_none_match: Option<TypedHeader<IfNoneMatch>>, - State(blobstore): State<S> + State(blobstore): State<S>, ) -> Response { use axum::http::StatusCode; tracing::debug!("Searching for file..."); @@ -77,7 +83,9 @@ pub(crate) async fn serve<S: MediaStore>( tracing::debug!("Metadata: {:?}", metadata); let etag = if let Some(etag) = metadata.etag { - let etag = format!("\"{}\"", etag).parse::<axum_extra::headers::ETag>().unwrap(); + let etag = 
format!("\"{}\"", etag) + .parse::<axum_extra::headers::ETag>() + .unwrap(); if let Some(TypedHeader(if_none_match)) = if_none_match { tracing::debug!("If-None-Match: {:?}", if_none_match); @@ -85,12 +93,14 @@ pub(crate) async fn serve<S: MediaStore>( // returns 304 when it doesn't match because it // only matches when file is different if !if_none_match.precondition_passes(&etag) { - return StatusCode::NOT_MODIFIED.into_response() + return StatusCode::NOT_MODIFIED.into_response(); } } Some(etag) - } else { None }; + } else { + None + }; let mut r = Response::builder(); { @@ -98,14 +108,16 @@ pub(crate) async fn serve<S: MediaStore>( headers.insert( "Content-Type", HeaderValue::from_str( - metadata.content_type + metadata + .content_type .as_deref() - .unwrap_or("application/octet-stream") - ).unwrap() + .unwrap_or("application/octet-stream"), + ) + .unwrap(), ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); if let Some(length) = metadata.length { headers.typed_insert(ContentLength(length.get().try_into().unwrap())); @@ -117,23 +129,22 @@ pub(crate) async fn serve<S: MediaStore>( r.body(axum::body::Body::from_stream(stream)) .unwrap() .into_response() - }, + } Err(err) => match err.kind() { - ErrorKind::NotFound => { - IntoResponse::into_response(StatusCode::NOT_FOUND) - }, + ErrorKind::NotFound => IntoResponse::into_response(StatusCode::NOT_FOUND), _ => { tracing::error!("{}", err); IntoResponse::into_response(StatusCode::INTERNAL_SERVER_ERROR) } - } + }, } } -pub fn router<St, A, M>() -> axum::Router<St> where +pub fn router<St, A, M>() -> axum::Router<St> +where A: AuthBackend + FromRef<St>, M: MediaStore + FromRef<St>, - St: Clone + Send + Sync + 'static + St: Clone + Send + Sync + 'static, { axum::Router::new() .route("/", axum::routing::post(upload::<M, A>)) diff --git a/src/media/storage/file.rs b/src/media/storage/file.rs index 4cd0ece..5198a4c 100644 --- a/src/media/storage/file.rs +++ b/src/media/storage/file.rs @@ -1,12 +1,12 @@ -use super::{Metadata, ErrorKind, MediaStore, MediaStoreError, Result}; -use std::{path::PathBuf, fmt::Debug}; -use tokio::fs::OpenOptions; -use tokio::io::{BufReader, BufWriter, AsyncWriteExt, AsyncSeekExt}; +use super::{ErrorKind, MediaStore, MediaStoreError, Metadata, Result}; +use futures::FutureExt; use futures::{StreamExt, TryStreamExt}; +use sha2::Digest; use std::ops::{Bound, Neg}; use std::pin::Pin; -use sha2::Digest; -use futures::FutureExt; +use std::{fmt::Debug, path::PathBuf}; +use tokio::fs::OpenOptions; +use tokio::io::{AsyncSeekExt, AsyncWriteExt, BufReader, BufWriter}; use tracing::{debug, error}; const BUF_CAPACITY: usize = 16 * 1024; @@ -22,7 +22,7 @@ impl From<tokio::io::Error> for MediaStoreError { msg: format!("file I/O error: {}", source), kind: match source.kind() { std::io::ErrorKind::NotFound => ErrorKind::NotFound, - _ => ErrorKind::Backend + _ => ErrorKind::Backend, }, source: Some(Box::new(source)), } @@ -40,7 +40,9 @@ impl FileStore { impl MediaStore for FileStore { async fn new(url: &'_ url::Url) -> Result<Self> { - Ok(Self { base: url.path().into() }) + Ok(Self { + base: url.path().into(), + }) } #[tracing::instrument(skip(self, content))] @@ -51,10 +53,17 @@ impl MediaStore for FileStore { mut content: T, ) -> Result<String> where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug + T: tokio_stream::Stream< + Item = 
std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>, + > + Unpin + + Send + + Debug, { let (tempfilepath, mut tempfile) = self.mktemp().await?; - debug!("Temporary file opened for storing pending upload: {}", tempfilepath.display()); + debug!( + "Temporary file opened for storing pending upload: {}", + tempfilepath.display() + ); let mut hasher = sha2::Sha256::new(); let mut length: usize = 0; @@ -62,7 +71,7 @@ impl MediaStore for FileStore { let chunk = chunk.map_err(|err| MediaStoreError { kind: ErrorKind::Backend, source: Some(Box::new(err)), - msg: "Failed to read a data chunk".to_owned() + msg: "Failed to read a data chunk".to_owned(), })?; debug!("Read {} bytes from the stream", chunk.len()); length += chunk.len(); @@ -70,9 +79,7 @@ impl MediaStore for FileStore { { let chunk = chunk.clone(); let tempfile = &mut tempfile; - async move { - tempfile.write_all(&chunk).await - } + async move { tempfile.write_all(&chunk).await } }, { let chunk = chunk.clone(); @@ -80,7 +87,8 @@ impl MediaStore for FileStore { hasher.update(&*chunk); hasher - }).map(|r| r.unwrap()) + }) + .map(|r| r.unwrap()) } ); if let Err(err) = write_result { @@ -90,7 +98,9 @@ impl MediaStore for FileStore { // though temporary files might take up space on the hard drive // We'll clean them when maintenance time comes #[allow(unused_must_use)] - { tokio::fs::remove_file(tempfilepath).await; } + { + tokio::fs::remove_file(tempfilepath).await; + } return Err(err.into()); } hasher = _hasher; @@ -113,10 +123,17 @@ impl MediaStore for FileStore { let filepath = self.base.join(domain_str.as_str()).join(&filename); let metafilename = filename.clone() + ".json"; let metapath = self.base.join(domain_str.as_str()).join(&metafilename); - let metatemppath = self.base.join(domain_str.as_str()).join(metafilename + ".tmp"); + let metatemppath = self + .base + .join(domain_str.as_str()) + .join(metafilename + ".tmp"); metadata.length = std::num::NonZeroUsize::new(length); metadata.etag = Some(hash); - debug!("File path: {}, metadata: {}", filepath.display(), metapath.display()); + debug!( + "File path: {}, metadata: {}", + filepath.display(), + metapath.display() + ); { let parent = filepath.parent().unwrap(); tokio::fs::create_dir_all(parent).await?; @@ -126,7 +143,8 @@ impl MediaStore for FileStore { .write(true) .open(&metatemppath) .await?; - meta.write_all(&serde_json::to_vec(&metadata).unwrap()).await?; + meta.write_all(&serde_json::to_vec(&metadata).unwrap()) + .await?; tokio::fs::rename(tempfilepath, filepath).await?; tokio::fs::rename(metatemppath, metapath).await?; Ok(filename) @@ -138,28 +156,31 @@ impl MediaStore for FileStore { &self, domain: &str, filename: &str, - ) -> Result<(Metadata, Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>)> { + ) -> Result<( + Metadata, + Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>, + )> { debug!("Domain: {}, filename: {}", domain, filename); let path = self.base.join(domain).join(filename); debug!("Path: {}", path.display()); - let file = OpenOptions::new() - .read(true) - .open(path) - .await?; + let file = OpenOptions::new().read(true).open(path).await?; let meta = self.metadata(domain, filename).await?; - Ok((meta, Box::pin( - tokio_util::io::ReaderStream::new( - // TODO: determine if BufReader provides benefit here - // From the logs it looks like we're reading 4KiB at a time - // Buffering file contents seems to double download speed - // How to benchmark this? 
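// Aside on the upload path above: each multipart chunk is written to
// the temp file and folded into the SHA-256 hasher concurrently, and
// the temp file is only renamed into place once the hash (its final
// filename) is known. The core of that pattern, as a sketch using the
// crates already in play here (`tokio`, `sha2`, `futures`):
//
//     let (write_result, hasher) = futures::join!(
//         async { tempfile.write_all(&chunk).await }, // disk I/O
//         async { h.update(&chunk); h },              // hashing
//     );
//
// `bytes::Bytes` clones are reference-counted, so handing the same
// chunk to both futures is cheap.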
- BufReader::with_capacity(BUF_CAPACITY, file) - ) - // Sprinkle some salt in form of protective log wrapping - .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())) - ))) + Ok(( + meta, + Box::pin( + tokio_util::io::ReaderStream::new( + // TODO: determine if BufReader provides benefit here + // From the logs it looks like we're reading 4KiB at a time + // Buffering file contents seems to double download speed + // How to benchmark this? + BufReader::with_capacity(BUF_CAPACITY, file), + ) + // Sprinkle some salt in form of protective log wrapping + .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())), + ), + )) } #[tracing::instrument(skip(self))] @@ -167,12 +188,13 @@ impl MediaStore for FileStore { let metapath = self.base.join(domain).join(format!("{}.json", filename)); debug!("Metadata path: {}", metapath.display()); - let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?) - .map_err(|err| MediaStoreError { + let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?).map_err(|err| { + MediaStoreError { kind: ErrorKind::Json, msg: format!("{}", err), - source: Some(Box::new(err)) - })?; + source: Some(Box::new(err)), + } + })?; Ok(meta) } @@ -182,16 +204,14 @@ impl MediaStore for FileStore { &self, domain: &str, filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> { + range: (Bound<u64>, Bound<u64>), + ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> + { let path = self.base.join(format!("{}/{}", domain, filename)); let metapath = self.base.join(format!("{}/{}.json", domain, filename)); debug!("Path: {}, metadata: {}", path.display(), metapath.display()); - let mut file = OpenOptions::new() - .read(true) - .open(path) - .await?; + let mut file = OpenOptions::new().read(true).open(path).await?; let start = match range { (Bound::Included(bound), _) => { @@ -202,45 +222,52 @@ impl MediaStore for FileStore { (Bound::Unbounded, Bound::Included(bound)) => { // Seek to the end minus the bounded bytes debug!("Seeking {} bytes back from the end...", bound); - file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())).await? - }, + file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())) + .await? 
+ } (Bound::Unbounded, Bound::Unbounded) => 0, - (_, Bound::Excluded(_)) => unreachable!() + (_, Bound::Excluded(_)) => unreachable!(), }; - let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity(BUF_CAPACITY, file))) - .map_ok({ - let mut bytes_read = 0usize; - let len = match range { - (_, Bound::Unbounded) => None, - (Bound::Unbounded, Bound::Included(bound)) => Some(bound), - (_, Bound::Included(bound)) => Some(bound + 1 - start), - (_, Bound::Excluded(_)) => unreachable!() - }; - move |chunk| { - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - if let Some(len) = len.map(|len| len.try_into().unwrap()) { - if bytes_read > len { - if bytes_read - len > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - len); - return Some(chunk.slice(..chunk.len() - (bytes_read - len))) + let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity( + BUF_CAPACITY, + file, + ))) + .map_ok({ + let mut bytes_read = 0usize; + let len = match range { + (_, Bound::Unbounded) => None, + (Bound::Unbounded, Bound::Included(bound)) => Some(bound), + (_, Bound::Included(bound)) => Some(bound + 1 - start), + (_, Bound::Excluded(_)) => unreachable!(), + }; + move |chunk| { + debug!( + "Read {} bytes from file, {} in this chunk", + bytes_read, + chunk.len() + ); + bytes_read += chunk.len(); + if let Some(len) = len.map(|len| len.try_into().unwrap()) { + if bytes_read > len { + if bytes_read - len > chunk.len() { + return None; } + debug!("Truncating last {} bytes", bytes_read - len); + return Some(chunk.slice(..chunk.len() - (bytes_read - len))); } - - Some(chunk) } - }) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - // Will never panic, because the moment the stream yields - // a None, it is considered exhausted. - .map_ok(|x| x.unwrap()); - return Ok(Box::pin(stream)) - } + Some(chunk) + } + }) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + // Will never panic, because the moment the stream yields + // a None, it is considered exhausted. 
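// (The Option dance above: chunks are mapped to None once the range
// is exhausted, `try_take_while` then ends the stream at the first
// None, so the trailing `unwrap` below can never observe one.)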
+ .map_ok(|x| x.unwrap()); + return Ok(Box::pin(stream)); + } async fn delete(&self, domain: &str, filename: &str) -> Result<()> { let path = self.base.join(format!("{}/{}", domain, filename)); @@ -251,7 +278,7 @@ impl MediaStore for FileStore { #[cfg(test)] mod tests { - use super::{Metadata, FileStore, MediaStore}; + use super::{FileStore, MediaStore, Metadata}; use std::ops::Bound; use tokio::io::AsyncReadExt; @@ -259,10 +286,15 @@ mod tests { #[tracing_test::traced_test] async fn test_ranges() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore { base: tempdir.path().to_path_buf() }; + let store = FileStore { + base: tempdir.path().to_path_buf(), + }; let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let stream = tokio_stream::iter( + file.chunks(100) + .map(|i| Ok(bytes::Bytes::copy_from_slice(i))), + ); let metadata = Metadata { filename: Some("file.rs".to_string()), content_type: Some("text/plain".to_string()), @@ -271,28 +303,30 @@ mod tests { }; // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); + let filename = store + .write_streaming("fireburn.ru", metadata, stream) + .await + .unwrap(); tracing::debug!("Writing complete."); // Ensure the file is there - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); + let content = tokio::fs::read(tempdir.path().join("fireburn.ru").join(&filename)) + .await + .unwrap(); assert_eq!(content, file); tracing::debug!("Reading range from the start..."); // try to read range let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(0), Bound::Included(299)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + (Bound::Included(0), Bound::Included(299)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -308,10 +342,14 @@ mod tests { tracing::debug!("Reading range from the middle..."); let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(150), Bound::Included(449)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + (Bound::Included(150), Bound::Included(449)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -326,13 +364,17 @@ mod tests { tracing::debug!("Reading range from the end..."); let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - // Note: the `headers` crate parses bounds in a - // non-standard way, where unbounded start actually - // means getting things from the end... - (Bound::Unbounded, Bound::Included(300)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + // Note: the `headers` crate parses bounds in a + // non-standard way, where unbounded start actually + // means getting things from the end... 
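// (Concretely: a `Range: bytes=-300` request reaches this code as
// (Bound::Unbounded, Bound::Included(300)) and is served as the final
// 300 bytes of the file, via the SeekFrom::End seek in the
// implementation, which is exactly what this test asserts.)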
+ (Bound::Unbounded, Bound::Included(300)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -343,15 +385,19 @@ mod tests { }; assert_eq!(range.len(), 300); - assert_eq!(range.as_slice(), &file[file.len()-300..file.len()]); + assert_eq!(range.as_slice(), &file[file.len() - 300..file.len()]); tracing::debug!("Reading the whole file..."); // try to read range let range = { - let stream = store.stream_range( - "fireburn.ru", &("/".to_string() + &filename), - (Bound::Unbounded, Bound::Unbounded) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &("/".to_string() + &filename), + (Bound::Unbounded, Bound::Unbounded), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -365,15 +411,19 @@ mod tests { assert_eq!(range.as_slice(), file); } - #[tokio::test] #[tracing_test::traced_test] async fn test_streaming_read_write() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore { base: tempdir.path().to_path_buf() }; + let store = FileStore { + base: tempdir.path().to_path_buf(), + }; let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let stream = tokio_stream::iter( + file.chunks(100) + .map(|i| Ok(bytes::Bytes::copy_from_slice(i))), + ); let metadata = Metadata { filename: Some("style.css".to_string()), content_type: Some("text/css".to_string()), @@ -382,27 +432,32 @@ mod tests { }; // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); - println!("{}, {}", filename, tempdir.path() - .join("fireburn.ru") - .join(&filename) - .display()); - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); + let filename = store + .write_streaming("fireburn.ru", metadata, stream) + .await + .unwrap(); + println!( + "{}, {}", + filename, + tempdir.path().join("fireburn.ru").join(&filename).display() + ); + let content = tokio::fs::read(tempdir.path().join("fireburn.ru").join(&filename)) + .await + .unwrap(); assert_eq!(content, file); // check internal metadata format - let meta: Metadata = serde_json::from_slice(&tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(filename.clone() + ".json") - ).await.unwrap()).unwrap(); + let meta: Metadata = serde_json::from_slice( + &tokio::fs::read( + tempdir + .path() + .join("fireburn.ru") + .join(filename.clone() + ".json"), + ) + .await + .unwrap(), + ) + .unwrap(); assert_eq!(meta.content_type.as_deref(), Some("text/css")); assert_eq!(meta.filename.as_deref(), Some("style.css")); assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); @@ -410,10 +465,10 @@ mod tests { // read back the data using the interface let (metadata, read_back) = { - let (metadata, stream) = store.read_streaming( - "fireburn.ru", - &filename - ).await.unwrap(); + let (metadata, stream) = store + .read_streaming("fireburn.ru", &filename) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); let mut buf = Vec::default(); @@ -427,6 +482,5 @@ mod tests { assert_eq!(meta.filename.as_deref(), Some("style.css")); assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); assert!(meta.etag.is_some()); - } } diff --git a/src/media/storage/mod.rs b/src/media/storage/mod.rs index 3583247..5658071 100644 --- a/src/media/storage/mod.rs +++ b/src/media/storage/mod.rs @@ -1,12 +1,12 @@ use 
axum::extract::multipart::Field; -use tokio_stream::Stream; use bytes::Bytes; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use std::future::Future; +use std::num::NonZeroUsize; use std::ops::Bound; use std::pin::Pin; -use std::fmt::Debug; -use std::num::NonZeroUsize; +use tokio_stream::Stream; pub mod file; @@ -24,17 +24,14 @@ pub struct Metadata { impl From<&Field<'_>> for Metadata { fn from(field: &Field<'_>) -> Self { Self { - content_type: field.content_type() - .map(|i| i.to_owned()), - filename: field.file_name() - .map(|i| i.to_owned()), + content_type: field.content_type().map(|i| i.to_owned()), + filename: field.file_name().map(|i| i.to_owned()), length: None, etag: None, } } } - #[derive(Debug, Clone, Copy)] pub enum ErrorKind { Backend, @@ -95,88 +92,116 @@ pub trait MediaStore: 'static + Send + Sync + Clone { content: T, ) -> impl Future<Output = Result<String>> + Send where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug; + T: tokio_stream::Stream< + Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>, + > + Unpin + + Send + + Debug; #[allow(clippy::type_complexity)] fn read_streaming( &self, domain: &str, filename: &str, - ) -> impl Future<Output = Result< - (Metadata, Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>) - >> + Send; + ) -> impl Future< + Output = Result<( + Metadata, + Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>, + )>, + > + Send; fn stream_range( &self, domain: &str, filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> impl Future<Output = Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>>> + Send { async move { - use futures::stream::TryStreamExt; - use tracing::debug; - let (metadata, mut stream) = self.read_streaming(domain, filename).await?; - let length = metadata.length.unwrap().get(); - - use Bound::*; - let (start, end): (usize, usize) = match range { - (Unbounded, Unbounded) => return Ok(stream), - (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), - (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), - (Included(start), Included(end)) => (start.try_into().unwrap(), end.try_into().unwrap()), - (_, _) => unreachable!() - }; - - stream = Box::pin( - stream.map_ok({ - let mut bytes_skipped = 0usize; - let mut bytes_read = 0usize; - - move |chunk| { - debug!("Skipped {}/{} bytes, chunk len {}", bytes_skipped, start, chunk.len()); - let chunk = if bytes_skipped < start { - let need_to_skip = start - bytes_skipped; - if chunk.len() < need_to_skip { - return None - } - debug!("Skipping {} bytes", need_to_skip); - bytes_skipped += need_to_skip; - - chunk.slice(need_to_skip..) 
- } else { - chunk - }; - - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - - if bytes_read > length { - if bytes_read - length > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - length); - return Some(chunk.slice(..chunk.len() - (bytes_read - length))) - } - - Some(chunk) + range: (Bound<u64>, Bound<u64>), + ) -> impl Future<Output = Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>>> + Send + { + async move { + use futures::stream::TryStreamExt; + use tracing::debug; + let (metadata, mut stream) = self.read_streaming(domain, filename).await?; + let length = metadata.length.unwrap().get(); + + use Bound::*; + let (start, end): (usize, usize) = match range { + (Unbounded, Unbounded) => return Ok(stream), + (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), + (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), + (Included(start), Included(end)) => { + (start.try_into().unwrap(), end.try_into().unwrap()) } - }) - .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - .map_ok(|x| x.unwrap()) - ); + (_, _) => unreachable!(), + }; + + stream = Box::pin( + stream + .map_ok({ + let mut bytes_skipped = 0usize; + let mut bytes_read = 0usize; + + move |chunk| { + debug!( + "Skipped {}/{} bytes, chunk len {}", + bytes_skipped, + start, + chunk.len() + ); + let chunk = if bytes_skipped < start { + let need_to_skip = start - bytes_skipped; + if chunk.len() < need_to_skip { + return None; + } + debug!("Skipping {} bytes", need_to_skip); + bytes_skipped += need_to_skip; + + chunk.slice(need_to_skip..) + } else { + chunk + }; + + debug!( + "Read {} bytes from file, {} in this chunk", + bytes_read, + chunk.len() + ); + bytes_read += chunk.len(); + + if bytes_read > length { + if bytes_read - length > chunk.len() { + return None; + } + debug!("Truncating last {} bytes", bytes_read - length); + return Some(chunk.slice(..chunk.len() - (bytes_read - length))); + } + + Some(chunk) + } + }) + .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + .map_ok(|x| x.unwrap()), + ); - Ok(stream) - } } + Ok(stream) + } + } /// Read metadata for a file. /// /// The default implementation uses the `read_streaming` method /// and drops the stream containing file content. 
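As context for the provided method below: the trait leans on return-position `impl Trait` in traits (stable since Rust 1.75), where a default method can be a thin async wrapper over a required primitive. The shape of the pattern, reduced to a sketch (the names are illustrative, not from this codebase):

use std::future::Future;

// A provided trait method built on a required one, mirroring how
// `metadata` below is derived from `read_streaming`.
trait Store: Sync {
    type Meta: Send;

    fn read(&self, name: &str) -> impl Future<Output = (Self::Meta, Vec<u8>)> + Send;

    // Provided: fetch metadata by reading and dropping the content.
    fn metadata(&self, name: &str) -> impl Future<Output = Self::Meta> + Send {
        async move { self.read(name).await.0 }
    }
}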
- fn metadata(&self, domain: &str, filename: &str) -> impl Future<Output = Result<Metadata>> + Send { async move { - self.read_streaming(domain, filename) - .await - .map(|(meta, _)| meta) - } } + fn metadata( + &self, + domain: &str, + filename: &str, + ) -> impl Future<Output = Result<Metadata>> + Send { + async move { + self.read_streaming(domain, filename) + .await + .map(|(meta, _)| meta) + } + } fn delete(&self, domain: &str, filename: &str) -> impl Future<Output = Result<()>> + Send; } diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 8505ae5..5e11033 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -1,26 +1,26 @@ use std::collections::HashMap; +use std::sync::Arc; use url::Url; use util::NormalizedPost; -use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; use crate::indieauth::User; use crate::micropub::util::form_to_mf2_json; -use axum::extract::{FromRef, Query, State}; use axum::body::Body as BodyStream; +use axum::extract::{FromRef, Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; use axum_extra::extract::Host; use axum_extra::headers::ContentType; -use axum::response::{IntoResponse, Response}; use axum_extra::TypedHeader; -use axum::http::StatusCode; +use kittybox_indieauth::{Scope, TokenData}; +use kittybox_util::micropub::{Error as MicropubError, ErrorKind, QueryType}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::Mutex; use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; -use kittybox_indieauth::{Scope, TokenData}; -use kittybox_util::micropub::{Error as MicropubError, ErrorKind, QueryType}; #[derive(Serialize, Deserialize, Debug)] pub struct MicropubQuery { @@ -35,7 +35,7 @@ impl From<StorageError> for MicropubError { crate::database::ErrorKind::NotFound => ErrorKind::NotFound, _ => ErrorKind::InternalServerError, }, - format!("backend error: {}", err) + format!("backend error: {}", err), ) } } @@ -59,7 +59,8 @@ fn populate_reply_context( array .iter() .map(|i| { - let mut item = i.as_str() + let mut item = i + .as_str() .and_then(|i| i.parse::<Url>().ok()) .and_then(|url| ctxs.get(&url)) .and_then(|ctx| ctx.mf2["items"].get(0)) @@ -69,7 +70,12 @@ fn populate_reply_context( if item.is_object() && (i != &item) { if let Some(props) = item["properties"].as_object_mut() { // Fixup the item: if it lacks a URL, add one. 
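// (Context for the check below: MF2 JSON stores every property as an
// array, so a post's URL lives at properties.url[0]. "Has a URL" is
// therefore "has a non-empty url array", not mere key presence.)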
- if !props.get("url").and_then(serde_json::Value::as_array).map(|a| !a.is_empty()).unwrap_or(false) { + if !props + .get("url") + .and_then(serde_json::Value::as_array) + .map(|a| !a.is_empty()) + .unwrap_or(false) + { props.insert("url".to_owned(), json!([i.as_str()])); } } @@ -145,11 +151,14 @@ async fn background_processing<D: 'static + Storage>( .get("webmention") .and_then(|i| i.first().cloned()); - dbg!(Some((url.clone(), FetchedPostContext { - url, - mf2: serde_json::to_value(mf2).unwrap(), - webmention - }))) + dbg!(Some(( + url.clone(), + FetchedPostContext { + url, + mf2: serde_json::to_value(mf2).unwrap(), + webmention + } + ))) }) .collect::<HashMap<Url, FetchedPostContext>>() .await @@ -161,7 +170,11 @@ async fn background_processing<D: 'static + Storage>( }; for prop in context_props { if let Some(json) = populate_reply_context(&mf2, prop, &post_contexts) { - update.replace.as_mut().unwrap().insert(prop.to_owned(), json); + update + .replace + .as_mut() + .unwrap() + .insert(prop.to_owned(), json); } } if !update.replace.as_ref().unwrap().is_empty() { @@ -250,7 +263,7 @@ pub(crate) async fn _post<D: 'static + Storage>( if !user.check_scope(&Scope::Create) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "Not enough privileges - try acquiring the \"create\" scope." + "Not enough privileges - try acquiring the \"create\" scope.", )); } @@ -264,7 +277,7 @@ pub(crate) async fn _post<D: 'static + Storage>( { return Err(MicropubError::from_static( ErrorKind::Forbidden, - "You're posting to a website that's not yours." + "You're posting to a website that's not yours.", )); } @@ -272,7 +285,7 @@ pub(crate) async fn _post<D: 'static + Storage>( if db.post_exists(&uid).await? { return Err(MicropubError::from_static( ErrorKind::AlreadyExists, - "UID clash was detected, operation aborted." 
+ "UID clash was detected, operation aborted.", )); } // Save the post @@ -309,13 +322,18 @@ pub(crate) async fn _post<D: 'static + Storage>( } } - let reply = - IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); + let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); #[cfg(not(tokio_unstable))] - let _ = jobset.lock().await.spawn(background_processing(db, mf2, http)); + let _ = jobset + .lock() + .await + .spawn(background_processing(db, mf2, http)); #[cfg(tokio_unstable)] - let _ = jobset.lock().await.build_task() + let _ = jobset + .lock() + .await + .build_task() .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) .spawn(background_processing(db, mf2, http)); @@ -333,7 +351,7 @@ enum ActionType { #[serde(untagged)] pub enum MicropubPropertyDeletion { Properties(Vec<String>), - Values(HashMap<String, Vec<serde_json::Value>>) + Values(HashMap<String, Vec<serde_json::Value>>), } #[derive(Serialize, Deserialize)] struct MicropubFormAction { @@ -347,7 +365,7 @@ pub struct MicropubAction { url: String, #[serde(flatten)] #[serde(skip_serializing_if = "Option::is_none")] - update: Option<MicropubUpdate> + update: Option<MicropubUpdate>, } #[derive(Serialize, Deserialize, Debug, Default)] @@ -362,39 +380,43 @@ pub struct MicropubUpdate { impl MicropubUpdate { pub fn check_validity(&self) -> Result<(), MicropubError> { if let Some(add) = &self.add { - if add.iter().map(|(k, _)| k.as_str()).any(|k| { - k.to_lowercase().as_str() == "uid" - }) { + if add + .iter() + .map(|(k, _)| k.as_str()) + .any(|k| k.to_lowercase().as_str() == "uid") + { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } if let Some(replace) = &self.replace { - if replace.iter().map(|(k, _)| k.as_str()).any(|k| { - k.to_lowercase().as_str() == "uid" - }) { + if replace + .iter() + .map(|(k, _)| k.as_str()) + .any(|k| k.to_lowercase().as_str() == "uid") + { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } let iter = match &self.delete { Some(MicropubPropertyDeletion::Properties(keys)) => { Some(Box::new(keys.iter().map(|k| k.as_str())) as Box<dyn Iterator<Item = &str>>) - }, + } Some(MicropubPropertyDeletion::Values(map)) => { Some(Box::new(map.iter().map(|(k, _)| k.as_str())) as Box<dyn Iterator<Item = &str>>) - }, + } None => None, }; if let Some(mut iter) = iter { if iter.any(|k| k.to_lowercase().as_str() == "uid") { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } @@ -412,8 +434,9 @@ impl MicropubUpdate { } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = self.delete { if let Some(props) = post["properties"].as_object_mut() { for (key, values) in delete { - if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { - prop.retain(|v| { values.iter().all(|i| i != v) }) + if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) + { + prop.retain(|v| values.iter().all(|i| i != v)) } } } @@ -428,7 +451,10 @@ impl MicropubUpdate { if let Some(add) = self.add { if let Some(props) = post["properties"].as_object_mut() { for (key, value) in add { - if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { + if let Some(prop) = props + 
.get_mut(&key) + .and_then(serde_json::Value::as_array_mut) + { prop.extend_from_slice(value.as_slice()); } else { props.insert(key, serde_json::Value::Array(value)); @@ -445,7 +471,7 @@ impl From<MicropubFormAction> for MicropubAction { Self { action: a.action, url: a.url, - update: None + update: None, } } } @@ -458,10 +484,12 @@ async fn post_action<D: Storage, A: AuthBackend>( ) -> Result<(), MicropubError> { let uri = match action.url.parse::<hyper::Uri>() { Ok(uri) => uri, - Err(err) => return Err(MicropubError::new( - ErrorKind::InvalidRequest, - format!("url parsing error: {}", err) - )) + Err(err) => { + return Err(MicropubError::new( + ErrorKind::InvalidRequest, + format!("url parsing error: {}", err), + )) + } }; if uri.authority().unwrap() @@ -475,7 +503,7 @@ async fn post_action<D: Storage, A: AuthBackend>( { return Err(MicropubError::from_static( ErrorKind::Forbidden, - "Don't tamper with others' posts!" + "Don't tamper with others' posts!", )); } @@ -484,7 +512,7 @@ async fn post_action<D: Storage, A: AuthBackend>( if !user.check_scope(&Scope::Delete) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "You need a \"delete\" scope for this." + "You need a \"delete\" scope for this.", )); } @@ -494,7 +522,7 @@ async fn post_action<D: Storage, A: AuthBackend>( if !user.check_scope(&Scope::Update) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "You need an \"update\" scope for this." + "You need an \"update\" scope for this.", )); } @@ -503,7 +531,7 @@ async fn post_action<D: Storage, A: AuthBackend>( } else { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update request is not set." + "Update request is not set.", )); }; @@ -555,7 +583,7 @@ async fn dispatch_body( } else { Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid JSON object passed." + "Invalid JSON object passed.", )) } } else if content_type == ContentType::form_url_encoded() { @@ -566,7 +594,7 @@ async fn dispatch_body( } else { Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid form-encoded data. Try h=entry&content=Hello!" + "Invalid form-encoded data. Try h=entry&content=Hello!", )) } } else { @@ -605,7 +633,10 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( #[tracing::instrument(skip(db))] pub(crate) async fn query<D: Storage, A: AuthBackend>( State(db): State<D>, - query: Result<Query<MicropubQuery>, <Query<MicropubQuery> as axum::extract::FromRequestParts<()>>::Rejection>, + query: Result< + Query<MicropubQuery>, + <Query<MicropubQuery> as axum::extract::FromRequestParts<()>>::Rejection, + >, Host(host): Host, user: User<A>, ) -> axum::response::Response { @@ -616,8 +647,9 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( } else { return MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid query provided. Try ?q=config to see what you can do." - ).into_response(); + "Invalid query provided. Try ?q=config to see what you can do.", + ) + .into_response(); }; if axum::http::Uri::try_from(user.me.as_str()) @@ -630,7 +662,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::NotAuthorized, "This website doesn't belong to you.", ) - .into_response(); + .into_response(); } // TODO: consider replacing by `user.me.authority()`? 
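The update paths above are plain `serde_json` surgery on the MF2 document. A minimal, self-contained sketch of the "add" semantics, with hypothetical post data and `serde_json` as used throughout this diff:

use serde_json::json;

fn main() {
    let mut post = json!({
        "type": ["h-entry"],
        "properties": { "content": ["Hello!"], "category": ["test"] }
    });

    // "add": extend the property's array when it exists, insert it
    // wholesale when it does not, mirroring the branch above.
    let (key, value) = ("category".to_string(), vec![json!("demo")]);
    if let Some(props) = post["properties"].as_object_mut() {
        if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) {
            prop.extend_from_slice(value.as_slice());
        } else {
            props.insert(key, serde_json::Value::Array(value));
        }
    }
    assert_eq!(post["properties"]["category"], json!(["test", "demo"]));
}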
@@ -644,7 +676,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InternalServerError, format!("Error fetching channels: {}", err), ) - .into_response() + .into_response() } }; @@ -654,35 +686,36 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( QueryType::Config, QueryType::Channel, QueryType::SyndicateTo, - QueryType::Category + QueryType::Category, ], channels: Some(channels), syndicate_to: None, media_endpoint: Some(user.me.join("/.kittybox/media").unwrap()), other: { let mut map = std::collections::HashMap::new(); - map.insert("kittybox_authority".to_string(), serde_json::Value::String(user.me.to_string())); + map.insert( + "kittybox_authority".to_string(), + serde_json::Value::String(user.me.to_string()), + ); map - } + }, }) - .into_response() + .into_response() } QueryType::Source => { match query.url { - Some(url) => { - match db.get_post(&url).await { - Ok(some) => match some { - Some(post) => axum::response::Json(&post).into_response(), - None => MicropubError::from_static( - ErrorKind::NotFound, - "The specified MF2 object was not found in database.", - ) - .into_response(), - }, - Err(err) => MicropubError::from(err).into_response(), - } - } + Some(url) => match db.get_post(&url).await { + Ok(some) => match some { + Some(post) => axum::response::Json(&post).into_response(), + None => MicropubError::from_static( + ErrorKind::NotFound, + "The specified MF2 object was not found in database.", + ) + .into_response(), + }, + Err(err) => MicropubError::from(err).into_response(), + }, None => { // Here, one should probably attempt to query at least the main feed and collect posts // Using a pre-made query function can't be done because it does unneeded filtering @@ -691,7 +724,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InvalidRequest, "Querying for post list is not implemented yet.", ) - .into_response() + .into_response() } } } @@ -701,46 +734,45 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InternalServerError, format!("error fetching channels: backend error: {}", err), ) - .into_response(), + .into_response(), }, QueryType::SyndicateTo => { axum::response::Json(json!({ "syndicate-to": [] })).into_response() - }, + } QueryType::Category => { let categories = match db.categories(user_domain).await { Ok(categories) => categories, Err(err) => { return MicropubError::new( ErrorKind::InternalServerError, - format!("error fetching categories: backend error: {}", err) - ).into_response() + format!("error fetching categories: backend error: {}", err), + ) + .into_response() } }; axum::response::Json(json!({ "categories": categories })).into_response() - }, - QueryType::Unknown(q) => return MicropubError::new( - ErrorKind::InvalidRequest, - format!("Invalid query: {}", q) - ).into_response(), + } + QueryType::Unknown(q) => { + return MicropubError::new(ErrorKind::InvalidRequest, format!("Invalid query: {}", q)) + .into_response() + } } } - pub fn router<A, S, St: Send + Sync + Clone + 'static>() -> axum::routing::MethodRouter<St> where S: Storage + FromRef<St> + 'static, A: AuthBackend + FromRef<St>, reqwest_middleware::ClientWithMiddleware: FromRef<St>, - Arc<Mutex<JoinSet<()>>>: FromRef<St> + Arc<Mutex<JoinSet<()>>>: FromRef<St>, { axum::routing::get(query::<S, A>) .post(post::<S, A>) - .layer::<_, _>(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST, - ]) - .allow_origin(tower_http::cors::Any)) + .layer::<_, _>( + tower_http::cors::CorsLayer::new() + 
.allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_origin(tower_http::cors::Any), + ) } #[cfg(test)] @@ -765,16 +797,19 @@ impl MicropubQuery { mod tests { use std::sync::Arc; - use crate::{database::Storage, micropub::{util::NormalizedPost, MicropubError}}; + use crate::{ + database::Storage, + micropub::{util::NormalizedPost, MicropubError}, + }; use bytes::Bytes; use futures::StreamExt; use serde_json::json; use tokio::sync::Mutex; use super::FetchedPostContext; - use kittybox_indieauth::{Scopes, Scope, TokenData}; use axum::extract::State; use axum_extra::extract::Host; + use kittybox_indieauth::{Scope, Scopes, TokenData}; #[test] fn test_populate_reply_context() { @@ -801,16 +836,27 @@ mod tests { } }); let fetched_ctx_url: url::Url = "https://fireburn.ru/posts/example".parse().unwrap(); - let reply_contexts = vec![(fetched_ctx_url.clone(), FetchedPostContext { - url: fetched_ctx_url.clone(), - mf2: json!({ "items": [test_ctx] }), - webmention: None, - })].into_iter().collect(); + let reply_contexts = vec![( + fetched_ctx_url.clone(), + FetchedPostContext { + url: fetched_ctx_url.clone(), + mf2: json!({ "items": [test_ctx] }), + webmention: None, + }, + )] + .into_iter() + .collect(); let like_of = super::populate_reply_context(&mf2, "like-of", &reply_contexts).unwrap(); - assert_eq!(like_of[0]["properties"]["content"], test_ctx["properties"]["content"]); - assert_eq!(like_of[0]["properties"]["url"][0].as_str().unwrap(), reply_contexts[&fetched_ctx_url].url.as_str()); + assert_eq!( + like_of[0]["properties"]["content"], + test_ctx["properties"]["content"] + ); + assert_eq!( + like_of[0]["properties"]["url"][0].as_str().unwrap(), + reply_contexts[&fetched_ctx_url].url.as_str() + ); assert_eq!(like_of[1], already_expanded_reply_ctx); assert_eq!(like_of[2], "https://fireburn.ru/posts/non-existent"); @@ -830,20 +876,21 @@ mod tests { me: "https://localhost:8080/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: Scopes::new(vec![Scope::Profile]), - iat: None, exp: None + iat: None, + exp: None, }; let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let err = super::_post( - &user, id, post, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap_err(); + .await + .unwrap_err(); assert_eq!(err.error, super::ErrorKind::InvalidScope); @@ -866,21 +913,27 @@ mod tests { let user = TokenData { me: "https://aaronparecki.com/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None + scope: Scopes::new(vec![ + Scope::Profile, + Scope::Create, + Scope::Update, + Scope::Media, + ]), + iat: None, + exp: None, }; let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let err = super::_post( - &user, id, post, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap_err(); + .await 
+ .unwrap_err(); assert_eq!(err.error, super::ErrorKind::Forbidden); @@ -902,20 +955,21 @@ mod tests { me: "https://localhost:8080/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: Scopes::new(vec![Scope::Profile, Scope::Create]), - iat: None, exp: None + iat: None, + exp: None, }; let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let res = super::_post( - &user, id, post, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap(); + .await + .unwrap(); assert!(res.headers().contains_key("Location")); let location = res.headers().get("Location").unwrap(); @@ -938,10 +992,17 @@ mod tests { TokenData { me: "https://fireburn.ru/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None - }, std::marker::PhantomData - ) + scope: Scopes::new(vec![ + Scope::Profile, + Scope::Create, + Scope::Update, + Scope::Media, + ]), + iat: None, + exp: None, + }, + std::marker::PhantomData, + ), ) .await; @@ -954,7 +1015,10 @@ mod tests { .into_iter() .map(Result::unwrap) .by_ref() - .fold(Vec::new(), |mut a, i| { a.extend(i); a}); + .fold(Vec::new(), |mut a, i| { + a.extend(i); + a + }); let json: MicropubError = serde_json::from_slice(&body as &[u8]).unwrap(); assert_eq!(json.error, super::ErrorKind::NotAuthorized); } diff --git a/src/micropub/util.rs b/src/micropub/util.rs index 99aec8e..8c5d5e9 100644 --- a/src/micropub/util.rs +++ b/src/micropub/util.rs @@ -1,7 +1,7 @@ use crate::database::Storage; -use kittybox_indieauth::TokenData; use chrono::prelude::*; use core::iter::Iterator; +use kittybox_indieauth::TokenData; use newbase60::num_to_sxg; use serde_json::json; use std::convert::TryInto; @@ -35,7 +35,7 @@ fn reset_dt(post: &mut serde_json::Value) -> DateTime<FixedOffset> { pub struct NormalizedPost { pub id: String, - pub post: serde_json::Value + pub post: serde_json::Value, } pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> NormalizedPost { @@ -142,12 +142,12 @@ pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> Normalize } // If there is no explicit channels, and the post is not marked as "unlisted", // post it to one of the default channels that makes sense for the post type. - if body["properties"]["channel"][0].as_str().is_none() && (!body["properties"]["visibility"] - .as_array() - .map(|v| v.contains( - &serde_json::Value::String("unlisted".to_owned()) - )).unwrap_or(false) - ) { + if body["properties"]["channel"][0].as_str().is_none() + && (!body["properties"]["visibility"] + .as_array() + .map(|v| v.contains(&serde_json::Value::String("unlisted".to_owned()))) + .unwrap_or(false)) + { match body["type"][0].as_str() { Some("h-entry") => { // Set the channel to the main channel... 
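Related context for the tests that follow: the channel-defaulting rule in `normalize_mf2` above keys off the post's `visibility` property. In MF2 JSON that check reduces to the following sketch, with a hypothetical post:

use serde_json::json;

fn main() {
    let post = json!({
        "type": ["h-entry"],
        "properties": {
            "content": ["A quiet note"],
            "visibility": ["unlisted"]
        }
    });

    // Mirrors normalize_mf2: only a post whose visibility array
    // contains "unlisted" is exempted from the default channels.
    let unlisted = post["properties"]["visibility"]
        .as_array()
        .map(|v| v.contains(&json!("unlisted")))
        .unwrap_or(false);
    assert!(unlisted);
}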
@@ -249,7 +249,7 @@ mod tests { client_id: "https://quill.p3k.io/".parse().unwrap(), scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), exp: Some(u64::MAX), - iat: Some(0) + iat: Some(0), } } @@ -279,12 +279,15 @@ mod tests { } }); - let NormalizedPost { id: _, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data() - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert!( - normalized["properties"]["channel"].as_array().unwrap_or(&vec![]).is_empty(), + normalized["properties"]["channel"] + .as_array() + .unwrap_or(&vec![]) + .is_empty(), "Returned post was added to a channel despite the `unlisted` visibility" ); } @@ -300,10 +303,10 @@ mod tests { } }); - let NormalizedPost { id, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["uid"][0], mf2["properties"]["uid"][0], "UID was replaced" @@ -325,10 +328,10 @@ mod tests { } }); - let NormalizedPost { id: _, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["channel"], @@ -347,10 +350,10 @@ mod tests { } }); - let NormalizedPost { id: _, post: normalized } = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["channel"][0], @@ -367,10 +370,7 @@ mod tests { } }); - let NormalizedPost { id, post } = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id, post } = normalize_mf2(mf2, &token_data()); assert_eq!( post["properties"]["published"] .as_array() @@ -432,10 +432,7 @@ mod tests { }, }); - let NormalizedPost { id: _, post } = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id: _, post } = normalize_mf2(mf2, &token_data()); assert!( post["properties"]["url"] .as_array() @@ -461,10 +458,7 @@ mod tests { } }); - let NormalizedPost { id, post } = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id, post } = normalize_mf2(mf2, &token_data()); assert_eq!( post["properties"]["uid"][0], id, "UID of a post and its supposed location don't match" diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs index 683cc6b..380f4db 100644 --- a/src/webmentions/check.rs +++ b/src/webmentions/check.rs @@ -1,7 +1,7 @@ -use std::rc::Rc; -use microformats::types::PropertyValue; use html5ever::{self, tendril::TendrilSink}; use kittybox_util::MentionType; +use microformats::types::PropertyValue; +use std::rc::Rc; // TODO: replace. 
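The `num_to_sxg` import earlier in this file's hunk hints at how post IDs are minted: NewBase60 yields short, URL-safe identifiers. A sketch of the idea, assuming a timestamp-based scheme (what exactly `normalize_mf2` encodes is not visible in this hunk):

```rust
use chrono::Utc;
use newbase60::num_to_sxg;

// NewBase60 packs a Unix timestamp into a handful of URL-safe
// characters, which makes for compact post slugs.
fn main() {
    let seconds = Utc::now().timestamp() as u64;
    println!("short id: {}", num_to_sxg(seconds));
}
```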
mod rcdom; @@ -17,7 +17,11 @@ pub enum Error { } #[tracing::instrument] -pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> { +pub fn check_mention( + document: impl AsRef<str> + std::fmt::Debug, + base_url: &url::Url, + link: &url::Url, +) -> Result<Option<(MentionType, serde_json::Value)>, Error> { tracing::debug!("Parsing MF2 markup..."); // First, check the document for MF2 markup let document = microformats::from_html(document.as_ref(), base_url.clone())?; @@ -29,8 +33,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url tracing::debug!("Processing item: {:?}", item); for (prop, interaction_type) in [ - ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like), - ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost) + ("in-reply-to", MentionType::Reply), + ("like-of", MentionType::Like), + ("bookmark-of", MentionType::Bookmark), + ("repost-of", MentionType::Repost), ] { if let Some(propvals) = item.properties.get(prop) { tracing::debug!("Has a u-{} property", prop); @@ -38,7 +44,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url if let PropertyValue::Url(url) = val { if url == link { tracing::debug!("URL matches! Webmention is valid"); - return Ok(Some((interaction_type, serde_json::to_value(item).unwrap()))) + return Ok(Some(( + interaction_type, + serde_json::to_value(item).unwrap(), + ))); } } } @@ -46,7 +55,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } // Process `content` tracing::debug!("Processing e-content..."); - if let Some(PropertyValue::Fragment(content)) = item.properties.get("content") + if let Some(PropertyValue::Fragment(content)) = item + .properties + .get("content") .map(Vec::as_slice) .unwrap_or_default() .first() @@ -65,7 +76,8 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url // iteration of the loop. // // Empty list means all nodes were processed. - let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = root.children.borrow().iter().cloned().collect(); + let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = + root.children.borrow().iter().cloned().collect(); while !unprocessed_nodes.is_empty() { // "Take" the list out of its memory slot, replace it with an empty list let nodes = std::mem::take(&mut unprocessed_nodes); @@ -74,15 +86,23 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url // Add children nodes to the list for the next iteration unprocessed_nodes.extend(node.children.borrow().iter().cloned()); - if let rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data { + if let rcdom::NodeData::Element { + ref name, + ref attrs, + .. 
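A caller's view of `check_mention` as reformatted below: a `Some` result names the interaction type and carries the matched MF2 item. The URLs and surrounding plumbing here are illustrative:

```rust
use crate::webmentions::check::check_mention; // path within this crate
use kittybox_util::MentionType;

// Fetch the claimed source page and verify it really links to `target`.
async fn verify(http: &reqwest::Client) -> anyhow::Result<()> {
    let source: url::Url = "https://example.com/reply".parse()?;
    let target: url::Url = "https://fireburn.ru/posts/example".parse()?;
    let body = http.get(source.clone()).send().await?.text().await?;
    match check_mention(body, &source, &target)? {
        Some((MentionType::Reply, item)) => {
            // store `item` (an h-cite object) as a reply context
            let _ = item;
        }
        Some((_other_type, _item)) => { /* store as like/repost/bookmark/mention */ }
        None => { /* no qualifying link; reject the webmention */ }
    }
    Ok(())
}
```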
+ } = node.data + { // If it's not `<a>`, skip it - if name.local != *"a" { continue; } + if name.local != *"a" { + continue; + } let mut is_mention: bool = false; for attr in attrs.borrow().iter() { if attr.name.local == *"rel" { // Don't count `rel="nofollow"` links — a web crawler should ignore them // and so for purposes of driving visitors they are useless - if attr.value + if attr + .value .as_ref() .split([',', ' ']) .any(|v| v == "nofollow") @@ -92,7 +112,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } } // if it's not `<a href="...">`, skip it - if attr.name.local != *"href" { continue; } + if attr.name.local != *"href" { + continue; + } // Be forgiving in parsing URLs, and resolve them against the base URL if let Ok(url) = base_url.join(attr.value.as_ref()) { if &url == link { @@ -101,12 +123,14 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } } if is_mention { - return Ok(Some((MentionType::Mention, serde_json::to_value(item).unwrap()))); + return Ok(Some(( + MentionType::Mention, + serde_json::to_value(item).unwrap(), + ))); } } } } - } } diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs index 91b274b..57f9a57 100644 --- a/src/webmentions/mod.rs +++ b/src/webmentions/mod.rs @@ -1,9 +1,14 @@ -use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form}; use axum::http::StatusCode; +use axum::{ + extract::{FromRef, State}, + response::{IntoResponse, Response}, + routing::post, + Form, +}; use tracing::error; -use crate::database::{Storage, StorageError}; use self::queue::JobQueue; +use crate::database::{Storage, StorageError}; pub mod queue; #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -24,40 +29,46 @@ async fn accept_webmention<Q: JobQueue<Webmention>>( Form(webmention): Form<Webmention>, ) -> Response { if let Err(err) = webmention.source.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + return (StatusCode::BAD_REQUEST, err.to_string()).into_response(); } if let Err(err) = webmention.target.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + return (StatusCode::BAD_REQUEST, err.to_string()).into_response(); } match queue.put(&webmention).await { Ok(_id) => StatusCode::ACCEPTED.into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ - ("Content-Type", "text/plain") - ], err.to_string()).into_response() + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + [("Content-Type", "text/plain")], + err.to_string(), + ) + .into_response(), } } -pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>() -> axum::Router<St> { - axum::Router::new() - .route("/.kittybox/webmention", post(accept_webmention::<Q>)) +pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>( +) -> axum::Router<St> { + axum::Router::new().route("/.kittybox/webmention", post(accept_webmention::<Q>)) } #[derive(thiserror::Error, Debug)] pub enum SupervisorError { #[error("the task was explicitly cancelled")] - Cancelled + Cancelled, } -pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; +pub type SupervisedTask = + tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; -pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask +pub fn supervisor<E, A, F>( + mut f: F, + 
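The `rel` handling just reformatted is easy to get wrong, since `rel` is a space- or comma-separated token list rather than a single value. A minimal sketch of that check:

```rust
// A link is ignored for webmention purposes when its rel list
// contains "nofollow" (token list split on spaces or commas).
fn is_nofollow(rel: &str) -> bool {
    rel.split([',', ' ']).any(|v| v == "nofollow")
}

fn main() {
    assert!(is_nofollow("external nofollow"));
    assert!(is_nofollow("nofollow,ugc"));
    assert!(!is_nofollow("me"));
}
```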
cancellation_token: tokio_util::sync::CancellationToken, +) -> SupervisedTask where E: std::error::Error + std::fmt::Debug + Send + 'static, A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static, - F: FnMut() -> A + Send + 'static + F: FnMut() -> A + Send + 'static, { - let supervisor_future = async move { loop { // Don't spawn the task if we are already cancelled, but @@ -65,7 +76,7 @@ where // crashed and we immediately received a cancellation // request after noticing the crashed task) if cancellation_token.is_cancelled() { - return Err(SupervisorError::Cancelled) + return Err(SupervisorError::Cancelled); } let task = tokio::task::spawn(f()); tokio::select! { @@ -87,7 +98,13 @@ where return tokio::task::spawn(supervisor_future); #[cfg(tokio_unstable)] return tokio::task::Builder::new() - .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str()) + .name( + format!( + "supervisor for background task {}", + std::any::type_name::<A>() + ) + .as_str(), + ) .spawn(supervisor_future) .unwrap(); } @@ -99,39 +116,55 @@ enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> { #[error("queue error: {0}")] Queue(#[from] Q), #[error("storage error: {0}")] - Storage(StorageError) + Storage(StorageError), } -async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest_middleware::ClientWithMiddleware) -> Result<std::convert::Infallible, Error<Q::Error>> { - use futures_util::StreamExt; +async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>( + queue: Q, + db: S, + http: reqwest_middleware::ClientWithMiddleware, +) -> Result<std::convert::Infallible, Error<Q::Error>> { use self::queue::Job; + use futures_util::StreamExt; let mut stream = queue.into_stream().await?; while let Some(item) = stream.next().await.transpose()? 
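The supervisor reformatted above follows a standard respawn loop: refuse to start when already cancelled, spawn the task, then race completion against cancellation. A simplified sketch with the error type reduced to `String` (the real code is generic over `E: std::error::Error` and returns `Infallible` on the success path):

```rust
use tokio_util::sync::CancellationToken;

async fn supervise<F, Fut>(mut f: F, token: CancellationToken)
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
{
    loop {
        // Don't respawn if cancellation already arrived.
        if token.is_cancelled() {
            return;
        }
        let task = tokio::spawn(f());
        tokio::select! {
            _ = token.cancelled() => return,
            result = task => {
                // A finished (i.e. crashed) task is restarted on the
                // next loop iteration.
                if let Ok(Err(err)) = result {
                    eprintln!("task failed, restarting: {err}");
                }
            }
        }
    }
}
```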
{ let job = item.job(); let (source, target) = ( job.source.parse::<url::Url>().unwrap(), - job.target.parse::<url::Url>().unwrap() + job.target.parse::<url::Url>().unwrap(), ); let (code, text) = match http.get(source.clone()).send().await { Ok(response) => { let code = response.status(); - if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) { - error!("error processing webmention: webpage fetch returned {}", code); + if ![StatusCode::OK, StatusCode::GONE] + .iter() + .any(|i| i == &code) + { + error!( + "error processing webmention: webpage fetch returned {}", + code + ); continue; } match response.text().await { Ok(text) => (code, text), Err(err) => { - error!("error processing webmention: error fetching webpage text: {}", err); - continue + error!( + "error processing webmention: error fetching webpage text: {}", + err + ); + continue; } } } Err(err) => { - error!("error processing webmention: error requesting webpage: {}", err); - continue + error!( + "error processing webmention: error requesting webpage: {}", + err + ); + continue; } }; @@ -150,7 +183,10 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's continue; } Err(err) => { - error!("error processing webmention: error checking webmention: {}", err); + error!( + "error processing webmention: error checking webmention: {}", + err + ); continue; } }; @@ -158,31 +194,47 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's { mention["type"] = serde_json::json!(["h-cite"]); - if !mention["properties"].as_object().unwrap().contains_key("uid") { - let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned(); + if !mention["properties"] + .as_object() + .unwrap() + .contains_key("uid") + { + let url = mention["properties"]["url"][0] + .as_str() + .unwrap_or_else(|| target.as_str()) + .to_owned(); let props = mention["properties"].as_object_mut().unwrap(); - props.insert("uid".to_owned(), serde_json::Value::Array( - vec![serde_json::Value::String(url)]) + props.insert( + "uid".to_owned(), + serde_json::Value::Array(vec![serde_json::Value::String(url)]), ); } } - db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?; + db.add_or_update_webmention(target.as_str(), mention_type, mention) + .await + .map_err(Error::<Q::Error>::Storage)?; } } unreachable!() } -pub fn supervised_webmentions_task<St: Send + Sync + 'static, S: Storage + FromRef<St> + 'static, Q: JobQueue<Webmention> + FromRef<St> + 'static>( +pub fn supervised_webmentions_task< + St: Send + Sync + 'static, + S: Storage + FromRef<St> + 'static, + Q: JobQueue<Webmention> + FromRef<St> + 'static, +>( state: &St, - cancellation_token: tokio_util::sync::CancellationToken + cancellation_token: tokio_util::sync::CancellationToken, ) -> SupervisedTask -where reqwest_middleware::ClientWithMiddleware: FromRef<St> +where + reqwest_middleware::ClientWithMiddleware: FromRef<St>, { let queue = Q::from_ref(state); let storage = S::from_ref(state); let http = reqwest_middleware::ClientWithMiddleware::from_ref(state); - supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue( - queue.clone(), storage.clone(), http.clone() - ), cancellation_token) + supervisor::<Error<Q::Error>, _, _>( + move || process_webmentions_from_queue(queue.clone(), storage.clone(), http.clone()), + cancellation_token, + ) } diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs index 52bcdfa..a33de1a 100644 --- 
a/src/webmentions/queue.rs +++ b/src/webmentions/queue.rs @@ -6,7 +6,7 @@ use super::Webmention; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention"); -pub use kittybox_util::queue::{JobQueue, JobItem, Job, JobStream}; +pub use kittybox_util::queue::{Job, JobItem, JobQueue, JobStream}; pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> { const DATABASE_NAME: &'static str; @@ -17,7 +17,7 @@ pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRo struct PostgresJobRow<T: PostgresJobItem> { id: Uuid, #[sqlx(flatten)] - job: T + job: T, } #[derive(Debug)] @@ -29,7 +29,6 @@ pub struct PostgresJob<T: PostgresJobItem> { runtime_handle: tokio::runtime::Handle, } - impl<T: PostgresJobItem> Drop for PostgresJob<T> { // This is an emulation of "async drop" — the struct retains a // runtime handle, which it uses to block on a future that does @@ -87,7 +86,9 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { fn job(&self) -> &Webmention { &self.job } - async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { + async fn done( + mut self, + ) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { tracing::debug!("Deleting {} from the job queue", self.id); sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1") .bind(self.id) @@ -100,13 +101,13 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { pub struct PostgresJobQueue<T> { db: sqlx::PgPool, - _phantom: std::marker::PhantomData<T> + _phantom: std::marker::PhantomData<T>, } impl<T> Clone for PostgresJobQueue<T> { fn clone(&self) -> Self { Self { db: self.db.clone(), - _phantom: std::marker::PhantomData + _phantom: std::marker::PhantomData, } } } @@ -120,15 +121,21 @@ impl PostgresJobQueue<Webmention> { sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) - .await? - ).await - + .await?, + ) + .await } pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?; + db.execute(sqlx::query( + "CREATE SCHEMA IF NOT EXISTS kittybox_webmention", + )) + .await?; MIGRATOR.run(&db).await?; - Ok(Self { db, _phantom: std::marker::PhantomData }) + Ok(Self { + db, + _phantom: std::marker::PhantomData, + }) } } @@ -180,13 +187,14 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { Some(item) => return Ok(Some((item, ()))), None => { listener.lock().await.recv().await?; - continue + continue; } } } } } - }).boxed(); + }) + .boxed(); Ok(stream) } @@ -196,7 +204,7 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { mod tests { use std::sync::Arc; - use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR}; + use super::{Job, JobQueue, PostgresJobQueue, Webmention, MIGRATOR}; use futures_util::StreamExt; #[sqlx::test(migrator = "MIGRATOR")] @@ -204,7 +212,7 @@ mod tests { async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(), }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; @@ -236,7 +244,7 @@ mod tests { match queue.get_one().await? 
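The `Drop` impl described above ("async drop" emulation) is a pattern worth spelling out: the guard captures a runtime handle when constructed and uses it to run async cleanup synchronously on drop. A generic sketch; the cleanup body is a stand-in, not the queue's actual SQL:

```rust
struct CleanupGuard {
    handle: tokio::runtime::Handle,
}

impl Drop for CleanupGuard {
    fn drop(&mut self) {
        let handle = self.handle.clone();
        // block_in_place keeps a multi-threaded runtime responsive
        // while this worker thread blocks on the cleanup future.
        tokio::task::block_in_place(move || {
            handle.block_on(async {
                // e.g. release the row lock / return the job to the queue
            });
        });
    }
}

// Typically constructed inside async code:
// let guard = CleanupGuard { handle: tokio::runtime::Handle::current() };
```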
{ Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), - None => Ok(()) + None => Ok(()), } } @@ -245,7 +253,7 @@ mod tests { async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(), }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; @@ -272,18 +280,18 @@ mod tests { } }); } - tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); + tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap_err(); - let future = tokio::task::spawn( - tokio::time::timeout( - std::time::Duration::from_secs(10), async move { - stream.next().await.unwrap().unwrap() - } - ) - ); + let future = tokio::task::spawn(tokio::time::timeout( + std::time::Duration::from_secs(10), + async move { stream.next().await.unwrap().unwrap() }, + )); // Let the other task drop the guard it is holding barrier.wait().await; - let mut guard = future.await + let mut guard = future + .await .expect("Timeout on fetching item") .expect("Job queue error"); assert_eq!(guard.job(), &test_webmention); diff --git a/templates/build.rs b/templates/build.rs index 5a62855..057666b 100644 --- a/templates/build.rs +++ b/templates/build.rs @@ -22,8 +22,7 @@ fn main() -> Result<(), std::io::Error> { println!("cargo:rerun-if-changed=assets/"); let assets_path = std::path::Path::new("assets"); - let mut assets = WalkDir::new(assets_path) - .into_iter(); + let mut assets = WalkDir::new(assets_path).into_iter(); while let Some(Ok(entry)) = assets.next() { eprintln!("Processing {}", entry.path().display()); let out_path = out_dir.join(entry.path().strip_prefix(assets_path).unwrap()); @@ -31,11 +30,15 @@ fn main() -> Result<(), std::io::Error> { eprintln!("Creating directory {}", &out_path.display()); if let Err(err) = std::fs::create_dir(&out_path) { if err.kind() != std::io::ErrorKind::AlreadyExists { - return Err(err) + return Err(err); } } } else { - eprintln!("Copying {} to {}", entry.path().display(), out_path.display()); + eprintln!( + "Copying {} to {}", + entry.path().display(), + out_path.display() + ); std::fs::copy(entry.path(), &out_path)?; } } @@ -43,16 +46,11 @@ fn main() -> Result<(), std::io::Error> { let walker = WalkDir::new(&out_dir) .into_iter() .map(Result::unwrap) - .filter(|e| { - e.file_type().is_file() && e.path().extension().unwrap() != "gz" - }); + .filter(|e| e.file_type().is_file() && e.path().extension().unwrap() != "gz"); for entry in walker { let normal_path = entry.path(); let gzip_path = normal_path.with_extension({ - let mut extension = normal_path - .extension() - .unwrap() - .to_owned(); + let mut extension = normal_path.extension().unwrap().to_owned(); extension.push(OsStr::new(".gz")); extension }); diff --git a/templates/src/lib.rs b/templates/src/lib.rs index d9fe86b..fde0dab 100644 --- a/templates/src/lib.rs +++ b/templates/src/lib.rs @@ -7,23 +7,23 @@ pub use indieauth::AuthorizationRequestPage; mod login; pub use login::{LoginPage, LogoutPage}; mod mf2; -pub use mf2::{Entry, VCard, Feed, Food, POSTS_PER_PAGE}; +pub use mf2::{Entry, Feed, Food, VCard, POSTS_PER_PAGE}; pub mod admin; pub mod assets { - use axum::response::{IntoResponse, Response}; use axum::extract::Path; + use axum::http::header::{ + CACHE_CONTROL, CONTENT_ENCODING, 
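The build script above pairs each asset with a `.gz` sibling; the compression step itself sits outside this hunk, so the following is a sketch under the assumption that something like flate2 produces the `.gz` files:

```rust
use std::io::Write;

// Write `<path>.<ext>.gz` next to the input file, mirroring the
// `gzip_path` naming visible in the diff (flate2 is an assumption).
fn gzip_file(path: &std::path::Path) -> std::io::Result<()> {
    let data = std::fs::read(path)?;
    let mut ext = path.extension().expect("asset has an extension").to_owned();
    ext.push(".gz");
    let file = std::fs::File::create(path.with_extension(ext))?;
    let mut encoder = flate2::write::GzEncoder::new(file, flate2::Compression::best());
    encoder.write_all(&data)?;
    encoder.finish()?;
    Ok(())
}
```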
CONTENT_TYPE, X_CONTENT_TYPE_OPTIONS, + }; use axum::http::StatusCode; - use axum::http::header::{CONTENT_TYPE, CONTENT_ENCODING, CACHE_CONTROL, X_CONTENT_TYPE_OPTIONS}; + use axum::response::{IntoResponse, Response}; const ASSETS: include_dir::Dir<'static> = include_dir::include_dir!("$OUT_DIR/"); const CACHE_FOR_A_DAY: &str = "max-age=86400"; const GZIP: &str = "gzip"; - pub async fn statics( - Path(path): Path<String> - ) -> Response { + pub async fn statics(Path(path): Path<String>) -> Response { let content_type: &'static str = if path.ends_with(".js") { "application/javascript" } else if path.ends_with(".css") { @@ -35,24 +35,30 @@ pub mod assets { }; match ASSETS.get_file(path.clone() + ".gz") { - Some(file) => (StatusCode::OK, - [ - (CONTENT_TYPE, content_type), - (CONTENT_ENCODING, GZIP), - (CACHE_CONTROL, CACHE_FOR_A_DAY), - (X_CONTENT_TYPE_OPTIONS, "nosniff") - ], - file.contents()).into_response(), + Some(file) => ( + StatusCode::OK, + [ + (CONTENT_TYPE, content_type), + (CONTENT_ENCODING, GZIP), + (CACHE_CONTROL, CACHE_FOR_A_DAY), + (X_CONTENT_TYPE_OPTIONS, "nosniff"), + ], + file.contents(), + ) + .into_response(), None => match ASSETS.get_file(path) { - Some(file) => (StatusCode::OK, - [ - (CONTENT_TYPE, content_type), - (CACHE_CONTROL, CACHE_FOR_A_DAY), - (X_CONTENT_TYPE_OPTIONS, "nosniff") - ], - file.contents()).into_response(), - None => StatusCode::NOT_FOUND.into_response() - } + Some(file) => ( + StatusCode::OK, + [ + (CONTENT_TYPE, content_type), + (CACHE_CONTROL, CACHE_FOR_A_DAY), + (X_CONTENT_TYPE_OPTIONS, "nosniff"), + ], + file.contents(), + ) + .into_response(), + None => StatusCode::NOT_FOUND.into_response(), + }, } } } @@ -107,11 +113,11 @@ mod tests { let dt = time::OffsetDateTime::now_utc() .to_offset( time::UtcOffset::from_hms( - rand::distributions::Uniform::new(-11, 12) - .sample(&mut rand::thread_rng()), + rand::distributions::Uniform::new(-11, 12).sample(&mut rand::thread_rng()), if rand::random::<bool>() { 0 } else { 30 }, - 0 - ).unwrap() + 0, + ) + .unwrap(), ) .format(&time::format_description::well_known::Rfc3339) .unwrap(); @@ -218,14 +224,15 @@ mod tests { // potentially with an offset? 
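The `statics` handler reformatted above implements a simple precompression scheme: look for `<path>.gz` first and serve it with `Content-Encoding: gzip`, otherwise fall back to the plain file. Reduced to the lookup alone (directory macro path as in the diff):

```rust
use include_dir::{include_dir, Dir};

static ASSETS: Dir<'static> = include_dir!("$OUT_DIR/");

// Returns the file contents plus a flag saying whether the result is
// the gzipped variant (and therefore needs Content-Encoding: gzip).
fn lookup(path: &str) -> Option<(&'static [u8], bool)> {
    if let Some(file) = ASSETS.get_file(format!("{path}.gz")) {
        return Some((file.contents(), true));
    }
    ASSETS.get_file(path).map(|file| (file.contents(), false))
}
```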
let offset = item.as_offset().unwrap().data; let date = item.as_date().unwrap().data; - let time = item.as_time().unwrap().data; + let time = item.as_time().unwrap().data; let dt = date.with_time(time).assume_offset(offset); let expected = time::OffsetDateTime::parse( mf2["properties"]["published"][0].as_str().unwrap(), - &time::format_description::well_known::Rfc3339 - ).unwrap(); - + &time::format_description::well_known::Rfc3339, + ) + .unwrap(); + assert_eq!(dt, expected); } else { unreachable!() @@ -235,7 +242,8 @@ mod tests { fn check_e_content(mf2: &serde_json::Value, item: &Item) { assert!(item.properties.contains_key("content")); - if let Some(PropertyValue::Fragment(content)) = item.properties.get("content").and_then(|v| v.first()) + if let Some(PropertyValue::Fragment(content)) = + item.properties.get("content").and_then(|v| v.first()) { assert_eq!( content.html, @@ -250,7 +258,11 @@ mod tests { fn test_note() { let mf2 = gen_random_post(&rand::random::<Domain>().to_string(), PostType::Note); - let html = crate::mf2::Entry { post: &mf2, from_feed: false, }.to_string(); + let html = crate::mf2::Entry { + post: &mf2, + from_feed: false, + } + .to_string(); let url: Url = mf2 .pointer("/properties/uid/0") @@ -259,7 +271,12 @@ mod tests { .unwrap(); let parsed: Document = microformats::from_html(&html, url.clone()).unwrap(); - if let Some(item) = parsed.into_iter().find(|i| i.properties.get("url").unwrap().contains(&PropertyValue::Url(url.clone()))) { + if let Some(item) = parsed.into_iter().find(|i| { + i.properties + .get("url") + .unwrap() + .contains(&PropertyValue::Url(url.clone())) + }) { let props = &item.properties; check_e_content(&mf2, &item); @@ -281,7 +298,11 @@ mod tests { #[test] fn test_article() { let mf2 = gen_random_post(&rand::random::<Domain>().to_string(), PostType::Article); - let html = crate::mf2::Entry { post: &mf2, from_feed: false, }.to_string(); + let html = crate::mf2::Entry { + post: &mf2, + from_feed: false, + } + .to_string(); let url: Url = mf2 .pointer("/properties/uid/0") .and_then(|i| i.as_str()) @@ -289,8 +310,12 @@ mod tests { .unwrap(); let parsed: Document = microformats::from_html(&html, url.clone()).unwrap(); - if let Some(item) = parsed.into_iter().find(|i| i.properties.get("url").unwrap().contains(&PropertyValue::Url(url.clone()))) { - + if let Some(item) = parsed.into_iter().find(|i| { + i.properties + .get("url") + .unwrap() + .contains(&PropertyValue::Url(url.clone())) + }) { check_e_content(&mf2, &item); check_dt_published(&mf2, &item); assert!(item.properties.contains_key("uid")); @@ -302,7 +327,9 @@ mod tests { .iter() .any(|i| i == item.properties.get("uid").and_then(|v| v.first()).unwrap())); assert!(item.properties.contains_key("name")); - if let Some(PropertyValue::Plain(name)) = item.properties.get("name").and_then(|v| v.first()) { + if let Some(PropertyValue::Plain(name)) = + item.properties.get("name").and_then(|v| v.first()) + { assert_eq!( name, mf2.pointer("/properties/name/0") @@ -338,7 +365,11 @@ mod tests { .and_then(|i| i.as_str()) .and_then(|u| u.parse().ok()) .unwrap(); - let html = crate::mf2::Entry { post: &mf2, from_feed: false, }.to_string(); + let html = crate::mf2::Entry { + post: &mf2, + from_feed: false, + } + .to_string(); let parsed: Document = microformats::from_html(&html, url.clone()).unwrap(); if let Some(item) = parsed.items.first() { diff --git a/templates/src/templates.rs b/templates/src/templates.rs index d2734f8..5772b4d 100644 --- a/templates/src/templates.rs +++ b/templates/src/templates.rs @@ 
-1,7 +1,7 @@ #![allow(clippy::needless_lifetimes)] +use crate::{Feed, VCard}; use http::StatusCode; use kittybox_util::micropub::Channel; -use crate::{Feed, VCard}; markup::define! { Template<'a>(title: &'a str, blog_name: &'a str, feeds: Vec<Channel>, user: Option<&'a kittybox_indieauth::ProfileUrl>, content: String) { diff --git a/tower-watchdog/src/lib.rs b/tower-watchdog/src/lib.rs index 9a5c609..e0be313 100644 --- a/tower-watchdog/src/lib.rs +++ b/tower-watchdog/src/lib.rs @@ -27,22 +27,45 @@ impl<S> tower_layer::Layer<S> for WatchdogLayer { fn layer(&self, inner: S) -> Self::Service { Self::Service { pet: self.pet.clone(), - inner + inner, } } } pub struct WatchdogService<S> { pet: watchdog::Pet, - inner: S + inner: S, } -impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::Debug + 'static> tower_service::Service<Request> for WatchdogService<S> { +impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::Debug + 'static> + tower_service::Service<Request> for WatchdogService<S> +{ type Response = S::Response; type Error = S::Error; - type Future = std::pin::Pin<Box<futures::future::Then<std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), tokio::sync::mpsc::error::SendError<()>>> + Send>>, std::pin::Pin<Box<S::Future>>, Box<dyn FnOnce(Result<(), tokio::sync::mpsc::error::SendError<()>>) -> std::pin::Pin<Box<S::Future>>>>>>; + type Future = std::pin::Pin< + Box< + futures::future::Then< + std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result<(), tokio::sync::mpsc::error::SendError<()>>, + > + Send, + >, + >, + std::pin::Pin<Box<S::Future>>, + Box< + dyn FnOnce( + Result<(), tokio::sync::mpsc::error::SendError<()>>, + ) -> std::pin::Pin<Box<S::Future>>, + >, + >, + >, + >; - fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> { + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) } @@ -57,7 +80,11 @@ impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::De std::mem::swap(&mut self.inner, &mut inner); let pet = self.pet.clone(); - Box::pin(pet.pet_owned().boxed().then(Box::new(move |_| Box::pin(inner.call(request))))) + Box::pin( + pet.pet_owned() + .boxed() + .then(Box::new(move |_| Box::pin(inner.call(request)))), + ) } } @@ -84,7 +111,10 @@ mod tests { for i in 100..=1_000 { if i != 1000 { assert!(mock.poll_ready().is_ready()); - let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i)).then(|()| mock.call(()))); + let request = Box::pin( + tokio::time::sleep(std::time::Duration::from_millis(i)) + .then(|()| mock.call(())), + ); tokio::select! { _ = &mut watchdog_future => panic!("Watchdog called earlier than response!"), _ = request => {}, @@ -94,7 +124,10 @@ mod tests { // We use `+ 1` here, because the watchdog behavior is // subject to a data race if a request arrives in the // same tick. - let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i + 1)).then(|()| mock.call(()))); + let request = Box::pin( + tokio::time::sleep(std::time::Duration::from_millis(i + 1)) + .then(|()| mock.call(())), + ); tokio::select! 
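The fully spelled-out `Future` type in the watchdog service above exists to avoid boxing the whole chain twice per call; the same service can be written more compactly with an eagerly boxed future. A sketch, substituting a plain mpsc sender for the crate's `watchdog::Pet` handle:

```rust
use futures::future::BoxFuture;
use std::task::{Context, Poll};

struct SimpleWatchdog<S> {
    pet: tokio::sync::mpsc::Sender<()>,
    inner: S,
}

impl<S, Request> tower_service::Service<Request> for SimpleWatchdog<S>
where
    S: tower_service::Service<Request> + Clone + Send + 'static,
    S::Future: Send,
    Request: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<S::Response, S::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // Swap a fresh clone into `self` and use the original, which
        // poll_ready has already readied (the same dance as above).
        let mut inner = self.inner.clone();
        std::mem::swap(&mut self.inner, &mut inner);
        let pet = self.pet.clone();
        Box::pin(async move {
            // Pet the watchdog first, then forward the request.
            let _ = pet.send(()).await;
            inner.call(request).await
        })
    }
}
```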
{ _ = &mut watchdog_future => { }, diff --git a/util/src/fs.rs b/util/src/fs.rs index 6a7a5b4..ea9dadd 100644 --- a/util/src/fs.rs +++ b/util/src/fs.rs @@ -1,6 +1,6 @@ +use rand::{distributions::Alphanumeric, Rng}; use std::io::{self, Result}; use std::path::{Path, PathBuf}; -use rand::{Rng, distributions::Alphanumeric}; use tokio::fs; /// Create a temporary file named `temp.[a-zA-Z0-9]{length}` in @@ -20,7 +20,7 @@ use tokio::fs; pub async fn mktemp<T, B>(dir: T, basename: B, length: usize) -> Result<(PathBuf, fs::File)> where T: AsRef<Path>, - B: Into<Option<&'static str>> + B: Into<Option<&'static str>>, { let dir = dir.as_ref(); let basename = basename.into().unwrap_or(""); @@ -33,9 +33,9 @@ where if basename.is_empty() { "" } else { "." }, { let string = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(length) - .collect::<Vec<u8>>(); + .sample_iter(&Alphanumeric) + .take(length) + .collect::<Vec<u8>>(); String::from_utf8(string).unwrap() } )); @@ -49,8 +49,8 @@ where Ok(file) => return Ok((filename, file)), Err(err) => match err.kind() { io::ErrorKind::AlreadyExists => continue, - _ => return Err(err) - } + _ => return Err(err), + }, } } } diff --git a/util/src/lib.rs b/util/src/lib.rs index cb5f666..0c5df49 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -17,7 +17,7 @@ pub enum MentionType { Bookmark, /// A plain link without MF2 annotations. #[default] - Mention + Mention, } /// Common data-types useful in creating smart authentication systems. @@ -29,7 +29,7 @@ pub mod auth { /// used to recover from a lost passkey. Password, /// Denotes availability of one or more passkeys. - WebAuthn + WebAuthn, } } diff --git a/util/src/micropub.rs b/util/src/micropub.rs index 9d2c525..1f8008b 100644 --- a/util/src/micropub.rs +++ b/util/src/micropub.rs @@ -21,7 +21,7 @@ pub enum QueryType { Category, /// Unsupported query type // TODO: make this take a lifetime parameter for zero-copy deserialization if possible? - Unknown(std::borrow::Cow<'static, str>) + Unknown(std::borrow::Cow<'static, str>), } /// Data structure representing a Micropub channel in the ?q=channels output. @@ -42,7 +42,7 @@ pub struct SyndicationDestination { /// The syndication destination's UID, opaque to the client. pub uid: String, /// A human-friendly name. - pub name: String + pub name: String, } fn default_q_list() -> Vec<QueryType> { @@ -67,7 +67,7 @@ pub struct Config { pub media_endpoint: Option<url::Url>, /// Other unspecified keys, sometimes implementation-defined. 
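A quick usage sketch for `mktemp` as reformatted above (module path per this diff; the directory, prefix, and suffix length are illustrative):

```rust
#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Creates e.g. /tmp/upload.Ab3XyZ9q, retrying internally if the
    // randomly chosen name already exists.
    let (path, _file) = kittybox_util::fs::mktemp("/tmp", "upload", 8).await?;
    println!("created {}", path.display());
    Ok(())
}
```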
#[serde(flatten)] - pub other: HashMap<String, serde_json::Value> + pub other: HashMap<String, serde_json::Value>, } #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] @@ -145,14 +145,17 @@ impl Error { pub const fn from_static(error: ErrorKind, error_description: &'static str) -> Self { Self { error, - error_description: Some(std::borrow::Cow::Borrowed(error_description)) + error_description: Some(std::borrow::Cow::Borrowed(error_description)), } } } impl From<ErrorKind> for Error { fn from(error: ErrorKind) -> Self { - Self { error, error_description: None } + Self { + error, + error_description: None, + } } } @@ -190,4 +193,3 @@ impl axum_core::response::IntoResponse for Error { )) } } - diff --git a/util/src/queue.rs b/util/src/queue.rs index edbec86..b32fdc5 100644 --- a/util/src/queue.rs +++ b/util/src/queue.rs @@ -1,5 +1,5 @@ -use std::future::Future; use futures_util::Stream; +use std::future::Future; use std::pin::Pin; use uuid::Uuid; @@ -44,7 +44,9 @@ pub trait JobQueue<T: JobItem>: Send + Sync + Sized + Clone + 'static { /// /// Note that one item may be returned several times if it is not /// marked as done. - fn into_stream(self) -> impl Future<Output = Result<JobStream<Self::Job, Self::Error>, Self::Error>> + Send; + fn into_stream( + self, + ) -> impl Future<Output = Result<JobStream<Self::Job, Self::Error>, Self::Error>> + Send; } /// A job description yielded from a job queue.
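Finally, the `into_stream` signature reformatted above is the consumption side of the job queue. A sketch of a generic consumer mirroring the loop in `process_webmentions_from_queue` earlier in this diff (the helper name is mine):

```rust
use futures_util::StreamExt;
use kittybox_util::queue::{Job, JobItem, JobQueue};

async fn drain<T: JobItem, Q: JobQueue<T>>(queue: Q) -> Result<(), Q::Error> {
    let mut stream = queue.into_stream().await?;
    while let Some(item) = stream.next().await.transpose()? {
        // ... process item.job() here ...
        // Mark the job done; otherwise the queue may deliver it again.
        item.done().await?;
    }
    Ok(())
}
```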