diff options
58 files changed, 3272 insertions, 2548 deletions
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..0644048 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +5cf86bb7849f2b78711a5576bba15299613fe148 # cargo fmt +1e815637e3e15c7eb81b45b51b40253f3ec57ebb # kittybox-html: cargo fmt \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6df432e..b660440 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -/kittybox-rs/target -.direnv +/target +/.direnv result-* result dump.rdb @@ -10,9 +10,8 @@ dump.rdb *~ *.log *.log.json -/kittybox-rs/test-dir -/kittybox-rs/media-store -/kittybox-rs/auth-store -/kittybox-rs/fonts/* -/kittybox-rs/companion-lite/dist + +/media-store +/auth-store +/companion-lite/dist /token.txt diff --git a/.zed/settings.json b/.zed/settings.json new file mode 100644 index 0000000..09cfd7b --- /dev/null +++ b/.zed/settings.json @@ -0,0 +1,8 @@ +// Folder-specific settings +// +// For a full list of overridable settings, and general information on folder-specific settings, +// see the documentation: https://zed.dev/docs/configuring-zed#settings-files +{ + "format_on_save": "on", + "languages": { "Rust": { "format_on_save": "language_server" } } +} diff --git a/Cargo.lock b/Cargo.lock index bc54f91..50946fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1415,21 +1415,6 @@ dependencies = [ ] [[package]] -name = "html" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "944d7db81871c611549302f3014418fedbcfbc46902f97e6a1c4f53e785903d2" -dependencies = [ - "html-sys", -] - -[[package]] -name = "html-sys" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13eca55667a5657dd1b86db77c5fe2d1810e3f9413e9555a2c4c461733dd2573" - -[[package]] name = "html5ever" version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1945,31 +1930,8 @@ dependencies = [ ] [[package]] -name = "kittybox-html" -version = "0.2.0" -dependencies = [ - "axum", - "chrono", - "ellipse", - "faker_rand", - "html", - "http", - "include_dir", - "kittybox-indieauth", - "kittybox-util", - "libflate", - "microformats", - "rand", - "serde_json", - "thiserror 2.0.9", - "time", - "url", - "walkdir", -] - -[[package]] name = "kittybox-indieauth" -version = "0.2.0" +version = "0.3.3" dependencies = [ "axum-core", "data-encoding", diff --git a/Cargo.toml b/Cargo.toml index bf14ded..141016e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,11 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tokio_unstable)'] } [features] default = ["rustls", "postgres"] webauthn = ["openssl", "dep:webauthn"] -openssl = ["reqwest/native-tls", "reqwest/native-tls-alpn", "sqlx/tls-native-tls"] +openssl = [ + "reqwest/native-tls", + "reqwest/native-tls-alpn", + "sqlx/tls-native-tls", +] rustls = ["reqwest/rustls-tls-webpki-roots", "sqlx/tls-rustls"] cli = ["dep:clap", "dep:anyhow"] postgres = ["sqlx", "kittybox-util/sqlx"] @@ -55,7 +59,7 @@ path = "examples/sql.rs" required-features = ["sqlparser"] [workspace] -members = [".", "./util", "./templates", "./indieauth", "./templates-neo", "./tower-watchdog"] +members = [".", "./util", "./templates", "./indieauth", "./tower-watchdog"] default-members = [".", "./util", "./templates", "./indieauth"] [workspace.dependencies] @@ -68,7 +72,6 @@ ellipse = "0.2.0" faker_rand = "0.1.1" futures = "0.3.31" futures-util = "0.3.31" -html = "0.6.3" http = "1.2" include_dir = "0.7.4" libflate = "2.1.0" @@ -107,7 +110,7 @@ features = ["fs", "axum"] version = "0.1.0" path = "./templates" [dependencies.kittybox-indieauth] -version = "0.2.0" +version = "0.3.0" path = "./indieauth" features = ["axum"] @@ -124,7 +127,11 @@ wiremock = "0.6.2" anyhow = { version = "1.0.95", optional = true } argon2 = { version = "0.5.3", features = ["std"] } axum = { workspace = true, features = ["multipart", "json", "form", "macros"] } -axum-extra = { version = "0.10.0", features = ["cookie", "cookie-signed", "typed-header"] } +axum-extra = { version = "0.10.0", features = [ + "cookie", + "cookie-signed", + "typed-header", +] } bytes = "1.9.0" chrono = { workspace = true } clap = { workspace = true, features = ["derive"], optional = true } @@ -133,7 +140,9 @@ either = "1.13.0" futures = { workspace = true } futures-util = { workspace = true } html5ever = "=0.27.0" -http-cache-reqwest = { version = "0.15.0", default-features = false, features = ["manager-moka"] } +http-cache-reqwest = { version = "0.15.0", default-features = false, features = [ + "manager-moka", +] } hyper = "1.5.2" lazy_static = "1.5.0" listenfd = "1.0.1" @@ -143,27 +152,51 @@ mime = "0.3.17" newbase60 = "0.1.4" prometheus = { version = "0.13.4", features = ["process"] } rand = { workspace = true } -redis = { version = "0.27.6", features = ["aio", "tokio-comp"], optional = true } +redis = { version = "0.27.6", features = [ + "aio", + "tokio-comp", +], optional = true } relative-path = "1.9.3" -reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "brotli", "json", "stream"] } +reqwest = { version = "0.12.12", default-features = false, features = [ + "gzip", + "brotli", + "json", + "stream", +] } reqwest-middleware = "0.4.0" serde = { workspace = true } serde_json = { workspace = true } serde_urlencoded = { workspace = true } serde_variant = { workspace = true } sha2 = { workspace = true } -sqlparser = { version = "0.53.0", features = ["serde", "serde_json"], optional = true } -sqlx = { workspace = true, features = ["uuid", "chrono", "postgres", "runtime-tokio"], optional = true } +sqlparser = { version = "0.53.0", features = [ + "serde", + "serde_json", +], optional = true } +sqlx = { workspace = true, features = [ + "uuid", + "chrono", + "postgres", + "runtime-tokio", +], optional = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full", "tracing"] } tokio-stream = { workspace = true, features = ["time", "net"] } tokio-util = { workspace = true, features = ["io-util"] } tower = { workspace = true, features = ["tracing"] } -tower-http = { version = "0.6.2", features = ["trace", "cors", "catch-panic", "sensitive-headers", "set-header"] } +tower-http = { version = "0.6.2", features = [ + "trace", + "cors", + "catch-panic", + "sensitive-headers", + "set-header", +] } tracing = { workspace = true, features = [] } tracing-log = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter", "json"] } tracing-tree = { workspace = true } url = { workspace = true } uuid = { workspace = true, features = ["v4"] } -webauthn = { version = "0.5.0", package = "webauthn-rs", features = ["danger-allow-state-serialisation"], optional = true } +webauthn = { version = "0.5.0", package = "webauthn-rs", features = [ + "danger-allow-state-serialisation", +], optional = true } diff --git a/build.rs b/build.rs index 05eca7a..5db39e0 100644 --- a/build.rs +++ b/build.rs @@ -22,9 +22,6 @@ fn main() { } let companion_in = std::path::Path::new("companion-lite"); for file in ["index.html", "style.css"] { - std::fs::copy( - companion_in.join(file), - &companion_out.join(file) - ).unwrap(); + std::fs::copy(companion_in.join(file), companion_out.join(file)).unwrap(); } } diff --git a/examples/password-hasher.rs b/examples/password-hasher.rs index 92de7f7..3c88a40 100644 --- a/examples/password-hasher.rs +++ b/examples/password-hasher.rs @@ -1,6 +1,9 @@ use std::io::Write; -use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, PasswordHash, PasswordVerifier, SaltString}}; +use argon2::{ + password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, + Argon2, +}; fn main() -> std::io::Result<()> { eprint!("Type a password: "); @@ -15,19 +18,19 @@ fn main() -> std::io::Result<()> { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); //eprintln!("{}", password.trim()); - let password_hash = argon2.hash_password(password.trim().as_bytes(), &salt) + let password_hash = argon2 + .hash_password(password.trim().as_bytes(), &salt) .expect("Hashing a password should not error out") .serialize(); println!("{}", password_hash.as_str()); assert!(Argon2::default() - .verify_password( - password.trim().as_bytes(), - &PasswordHash::new(password_hash.as_str()) - .expect("Password hash should be valid") - ).is_ok() - ); + .verify_password( + password.trim().as_bytes(), + &PasswordHash::new(password_hash.as_str()).expect("Password hash should be valid") + ) + .is_ok()); Ok(()) } diff --git a/examples/sql.rs b/examples/sql.rs index 5d552da..59db1eb 100644 --- a/examples/sql.rs +++ b/examples/sql.rs @@ -15,8 +15,11 @@ fn sanitize(expr: &Expr) -> Result<(), Error> { match expr { Expr::Identifier(_) => Ok(()), Expr::CompoundIdentifier(_) => Ok(()), - Expr::JsonAccess { left, operator: _, right } => sanitize(left) - .and(sanitize(right)), + Expr::JsonAccess { + left, + operator: _, + right, + } => sanitize(left).and(sanitize(right)), Expr::CompositeAccess { expr, key: _ } => sanitize(expr), Expr::IsFalse(subexpr) => sanitize(subexpr), Expr::IsNotFalse(subexpr) => sanitize(subexpr), @@ -26,84 +29,180 @@ fn sanitize(expr: &Expr) -> Result<(), Error> { Expr::IsNotNull(subexpr) => sanitize(subexpr), Expr::IsUnknown(subexpr) => sanitize(subexpr), Expr::IsNotUnknown(subexpr) => sanitize(subexpr), - Expr::IsDistinctFrom(left, right) => sanitize(left) - .and(sanitize(right)), - Expr::IsNotDistinctFrom(left, right) => sanitize(left) - .and(sanitize(right)), - Expr::InList { expr, list, negated: _ } => sanitize(expr) - .and(list.iter().try_for_each(sanitize)), - Expr::InSubquery { expr: _, subquery, negated: _ } => Err(Error::SubqueryDetected(subquery.as_ref())), - Expr::InUnnest { expr, array_expr, negated: _ } => sanitize(expr).and(sanitize(array_expr)), - Expr::Between { expr, negated: _, low, high } => sanitize(expr) - .and(sanitize(low)) - .and(sanitize(high)), - Expr::BinaryOp { left, op: _, right } => sanitize(left) - .and(sanitize(right)), - Expr::Like { negated: _, expr, pattern, escape_char: _ } => sanitize(expr) - .and(sanitize(pattern)), - Expr::ILike { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)), - Expr::SimilarTo { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)), - Expr::RLike { negated: _, expr, pattern, regexp: _ } => sanitize(expr).and(sanitize(pattern)), - Expr::AnyOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)), - Expr::AllOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)), + Expr::IsDistinctFrom(left, right) => sanitize(left).and(sanitize(right)), + Expr::IsNotDistinctFrom(left, right) => sanitize(left).and(sanitize(right)), + Expr::InList { + expr, + list, + negated: _, + } => sanitize(expr).and(list.iter().try_for_each(sanitize)), + Expr::InSubquery { + expr: _, + subquery, + negated: _, + } => Err(Error::SubqueryDetected(subquery.as_ref())), + Expr::InUnnest { + expr, + array_expr, + negated: _, + } => sanitize(expr).and(sanitize(array_expr)), + Expr::Between { + expr, + negated: _, + low, + high, + } => sanitize(expr).and(sanitize(low)).and(sanitize(high)), + Expr::BinaryOp { left, op: _, right } => sanitize(left).and(sanitize(right)), + Expr::Like { + negated: _, + expr, + pattern, + escape_char: _, + } => sanitize(expr).and(sanitize(pattern)), + Expr::ILike { + negated: _, + expr, + pattern, + escape_char: _, + } => sanitize(expr).and(sanitize(pattern)), + Expr::SimilarTo { + negated: _, + expr, + pattern, + escape_char: _, + } => sanitize(expr).and(sanitize(pattern)), + Expr::RLike { + negated: _, + expr, + pattern, + regexp: _, + } => sanitize(expr).and(sanitize(pattern)), + Expr::AnyOp { + left, + compare_op: _, + right, + } => sanitize(left).and(sanitize(right)), + Expr::AllOp { + left, + compare_op: _, + right, + } => sanitize(left).and(sanitize(right)), Expr::UnaryOp { op: _, expr } => sanitize(expr), - Expr::Convert { expr, data_type: _, charset: _, target_before_value: _ } => sanitize(expr), - Expr::Cast { expr, data_type: _, format: _ } => sanitize(expr), - Expr::TryCast { expr, data_type: _, format: _ } => sanitize(expr), - Expr::SafeCast { expr, data_type: _, format: _ } => sanitize(expr), - Expr::AtTimeZone { timestamp, time_zone: _ } => sanitize(timestamp), + Expr::Convert { + expr, + data_type: _, + charset: _, + target_before_value: _, + } => sanitize(expr), + Expr::Cast { + expr, + data_type: _, + format: _, + } => sanitize(expr), + Expr::TryCast { + expr, + data_type: _, + format: _, + } => sanitize(expr), + Expr::SafeCast { + expr, + data_type: _, + format: _, + } => sanitize(expr), + Expr::AtTimeZone { + timestamp, + time_zone: _, + } => sanitize(timestamp), Expr::Extract { field: _, expr } => sanitize(expr), Expr::Ceil { expr, field: _ } => sanitize(expr), Expr::Floor { expr, field: _ } => sanitize(expr), Expr::Position { expr, r#in } => sanitize(expr).and(sanitize(r#in)), - Expr::Substring { expr, substring_from, substring_for, special: _ } => sanitize(expr) + Expr::Substring { + expr, + substring_from, + substring_for, + special: _, + } => sanitize(expr) .and(substring_from.as_deref().map(sanitize).unwrap_or(Ok(()))) .and(substring_for.as_deref().map(sanitize).unwrap_or(Ok(()))), - Expr::Trim { expr, trim_where: _, trim_what, trim_characters } => sanitize(expr) + Expr::Trim { + expr, + trim_where: _, + trim_what, + trim_characters, + } => sanitize(expr) .and(trim_what.as_deref().map(sanitize).unwrap_or(Ok(()))) .and( trim_characters .as_ref() .map(|v| v.iter()) .map(|mut iter| iter.try_for_each(sanitize)) - .unwrap_or(Ok(())) + .unwrap_or(Ok(())), ), - Expr::Overlay { expr, overlay_what, overlay_from, overlay_for } => sanitize(expr) + Expr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => sanitize(expr) .and(sanitize(overlay_what)) .and(sanitize(overlay_from)) .and(overlay_for.as_deref().map(sanitize).unwrap_or(Ok(()))), Expr::Collate { expr, collation: _ } => sanitize(expr), Expr::Nested(subexpr) => sanitize(subexpr), Expr::Value(_) => Ok(()), - Expr::IntroducedString { introducer: _, value: _ } => Ok(()), - Expr::TypedString { data_type: _, value: _ } => Ok(()), - Expr::MapAccess { column, keys } => sanitize(column).and(keys.iter().try_for_each(sanitize)), + Expr::IntroducedString { + introducer: _, + value: _, + } => Ok(()), + Expr::TypedString { + data_type: _, + value: _, + } => Ok(()), + Expr::MapAccess { column, keys } => { + sanitize(column).and(keys.iter().try_for_each(sanitize)) + } Expr::Function(func) => Err(Error::FunctionCallDetected(func)), - Expr::AggregateExpressionWithFilter { expr, filter } => sanitize(expr).and(sanitize(filter)), - Expr::Case { operand, conditions, results, else_result } => conditions.iter() + Expr::AggregateExpressionWithFilter { expr, filter } => { + sanitize(expr).and(sanitize(filter)) + } + Expr::Case { + operand, + conditions, + results, + else_result, + } => conditions + .iter() .chain(results) .chain(operand.iter().map(std::borrow::Borrow::borrow)) .chain(else_result.iter().map(std::borrow::Borrow::borrow)) .try_for_each(sanitize), - Expr::Exists { subquery, negated: _ } => Err(Error::SubqueryDetected(subquery)), + Expr::Exists { + subquery, + negated: _, + } => Err(Error::SubqueryDetected(subquery)), Expr::Subquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())), Expr::ArraySubquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())), Expr::ListAgg(agg) => sanitize(&agg.expr), Expr::ArrayAgg(agg) => sanitize(&agg.expr), - Expr::GroupingSets(sets) => sets.iter() + Expr::GroupingSets(sets) => sets + .iter() .map(|i| i.iter()) .try_for_each(|mut si| si.try_for_each(sanitize)), - Expr::Cube(cube) => cube.iter() + Expr::Cube(cube) => cube + .iter() .map(|i| i.iter()) .try_for_each(|mut si| si.try_for_each(sanitize)), - Expr::Rollup(rollup) => rollup.iter() + Expr::Rollup(rollup) => rollup + .iter() .map(|i| i.iter()) .try_for_each(|mut si| si.try_for_each(sanitize)), Expr::Tuple(tuple) => tuple.iter().try_for_each(sanitize), Expr::Struct { values, fields: _ } => values.iter().try_for_each(sanitize), Expr::Named { expr, name: _ } => sanitize(expr), - Expr::ArrayIndex { obj, indexes } => sanitize(obj) - .and(indexes.iter().try_for_each(sanitize)), + Expr::ArrayIndex { obj, indexes } => { + sanitize(obj).and(indexes.iter().try_for_each(sanitize)) + } Expr::Array(array) => array.elem.iter().try_for_each(sanitize), Expr::Interval(interval) => sanitize(&interval.value), Expr::MatchAgainst { .. } => Ok(()), @@ -115,7 +214,8 @@ fn sanitize(expr: &Expr) -> Result<(), Error> { fn main() -> Result<(), sqlparser::parser::ParserError> { let query = std::env::args().skip(1).take(1).next().unwrap(); - static DIALECT: sqlparser::dialect::PostgreSqlDialect = sqlparser::dialect::PostgreSqlDialect {}; + static DIALECT: sqlparser::dialect::PostgreSqlDialect = + sqlparser::dialect::PostgreSqlDialect {}; let parser: sqlparser::parser::Parser<'static> = sqlparser::parser::Parser::new(&DIALECT); let expr = parser.try_with_sql(&query)?.parse_expr()?; @@ -125,6 +225,6 @@ fn main() -> Result<(), sqlparser::parser::ParserError> { eprintln!("{}", err); } } - + Ok(()) } diff --git a/flake.lock b/flake.lock index bf7a454..5667933 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "crane": { "locked": { - "lastModified": 1734808813, - "narHash": "sha256-3aH/0Y6ajIlfy7j52FGZ+s4icVX0oHhqBzRdlOeztqg=", + "lastModified": 1745022865, + "narHash": "sha256-tXL4qUlyYZEGOHUKUWjmmcvJjjLQ+4U38lPWSc8Cgdo=", "owner": "ipetkov", "repo": "crane", - "rev": "72e2d02dbac80c8c86bf6bf3e785536acf8ee926", + "rev": "25ca4c50039d91ad88cc0b8feacb9ad7f748dedf", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index f820e51..81eee80 100644 --- a/flake.nix +++ b/flake.nix @@ -20,14 +20,14 @@ overlays = [ rust-overlay.overlays.default ]; localSystem = { inherit system; }; }; + # NOTE: `pkgs` here must match `pkgs` used for `callPackage` to ensure + # cross-compilation works. Crane sets the requisite variables automatically. crane' = crane.mkLib pkgs; cargoToml = builtins.fromTOML (builtins.readFile ./Cargo.toml); crane-msrv' = crane'.overrideToolchain (p: p.rust-bin.stable."${cargoToml.package.rust-version}".default); kittybox = pkgs.callPackage ./kittybox.nix { - # TODO: this may break cross-compilation. It may be better to - # inject it as an overlay. However, I am unsure whether Crane - # can recognize it's being passed a cross-compilation set. + # NOTE: See above re: cross-compilation. crane = crane'; nixosTests = { diff --git a/indieauth/Cargo.toml b/indieauth/Cargo.toml index 3bc3864..9213a51 100644 --- a/indieauth/Cargo.toml +++ b/indieauth/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kittybox-indieauth" -version = "0.2.0" +version = "0.3.3" edition = "2021" [features] @@ -9,7 +9,7 @@ axum = ["dep:axum-core", "dep:serde_json", "dep:http"] [dev-dependencies] serde_json = { workspace = true } # A JSON serialization file format -serde_urlencoded = { workspace = true } # `x-www-form-urlencoded` meets Serde +serde_urlencoded = { workspace = true } # `x-www-form-urlencoded` meets Serde [dependencies] axum-core = { workspace = true, optional = true } diff --git a/indieauth/src/lib.rs b/indieauth/src/lib.rs index b3ec098..b10fd0e 100644 --- a/indieauth/src/lib.rs +++ b/indieauth/src/lib.rs @@ -20,13 +20,13 @@ //! [`axum`]: https://github.com/tokio-rs/axum use std::borrow::Cow; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; use url::Url; mod scopes; pub use self::scopes::{Scope, Scopes}; mod pkce; -pub use self::pkce::{PKCEMethod, PKCEVerifier, PKCEChallenge}; +pub use self::pkce::{PKCEChallenge, PKCEMethod, PKCEVerifier}; // Re-export rand crate just to be sure. pub use rand; @@ -34,23 +34,24 @@ pub use rand; /// Authentication methods supported by the introspection endpoint. /// Note that authentication at the introspection endpoint is /// mandatory. -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] #[non_exhaustive] pub enum IntrospectionEndpointAuthMethod { /// `Authorization` header with a `Bearer` token. + #[serde(rename_all = "PascalCase")] Bearer, /// A token passed as part of a POST request. - #[serde(rename = "snake_case")] + #[serde(rename_all = "snake_case")] ClientSecretPost, /// Username and password passed using HTTP Basic authentication. - #[serde(rename = "snake_case")] + #[serde(rename_all = "snake_case")] ClientSecretBasic, /// TLS client auth with a certificate signed by a valid CA. - #[serde(rename = "snake_case")] + #[serde(rename_all = "snake_case")] TlsClientAuth, /// TLS client auth with a self-signed certificate. - #[serde(rename = "snake_case")] - SelfSignedTlsClientAuth + #[serde(rename_all = "snake_case")] + SelfSignedTlsClientAuth, } /// Authentication methods supported by the revocation endpoint. @@ -60,17 +61,17 @@ pub enum IntrospectionEndpointAuthMethod { /// authentication is neccesary to protect tokens. A well-intentioned /// person discovering a leaked token could quickly revoke it without /// disturbing anyone. -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] #[non_exhaustive] pub enum RevocationEndpointAuthMethod { /// No authentication is required to access an endpoint declaring /// this value. - None + None, } /// The response types supported by the authorization endpoint. -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ResponseType { /// An authorization code will be issued if this response type is @@ -82,7 +83,7 @@ pub enum ResponseType { /// This response type requires a valid access token. /// /// [AutoAuth spec]: https://github.com/sknebel/AutoAuth/blob/master/AutoAuth.md#allowing-external-clients-to-obtain-tokens - ExternalToken + ExternalToken, } // TODO serde_variant impl ResponseType { @@ -100,7 +101,7 @@ impl ResponseType { /// This type is strictly for usage in the [`Metadata`] response. For /// grant requests and responses, see [`GrantRequest`] and /// [`GrantResponse`]. -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum GrantType { /// The authorization code grant, allowing to exchange an @@ -110,7 +111,7 @@ pub enum GrantType { /// The refresh token grant, allowing to exchange a refresh token /// for a fresh access token and a new refresh token, to /// facilitate long-term access. - RefreshToken + RefreshToken, } /// OAuth 2.0 Authorization Server Metadata in application to the IndieAuth protocol. @@ -222,7 +223,7 @@ pub struct Metadata { /// registration. #[serde(skip_serializing_if = "ref_identity")] #[serde(default = "Default::default")] - pub client_id_metadata_document_supported: bool + pub client_id_metadata_document_supported: bool, } impl std::fmt::Debug for Metadata { @@ -232,31 +233,59 @@ impl std::fmt::Debug for Metadata { .field("authorization_endpoint", &self.issuer.as_str()) .field("token_endpoint", &self.issuer.as_str()) .field("introspection_endpoint", &self.issuer.as_str()) - .field("introspection_endpoint_auth_methods_supported", &self.introspection_endpoint_auth_methods_supported) - .field("revocation_endpoint", &self.revocation_endpoint.as_ref().map(Url::as_str)) - .field("revocation_endpoint_auth_methods_supported", &self.revocation_endpoint_auth_methods_supported) + .field( + "introspection_endpoint_auth_methods_supported", + &self.introspection_endpoint_auth_methods_supported, + ) + .field( + "revocation_endpoint", + &self.revocation_endpoint.as_ref().map(Url::as_str), + ) + .field( + "revocation_endpoint_auth_methods_supported", + &self.revocation_endpoint_auth_methods_supported, + ) .field("scopes_supported", &self.scopes_supported) .field("response_types_supported", &self.response_types_supported) .field("grant_types_supported", &self.grant_types_supported) - .field("service_documentation", &self.service_documentation.as_ref().map(Url::as_str)) - .field("code_challenge_methods_supported", &self.code_challenge_methods_supported) - .field("authorization_response_iss_parameter_supported", &self.authorization_response_iss_parameter_supported) - .field("userinfo_endpoint", &self.userinfo_endpoint.as_ref().map(Url::as_str)) - .field("client_id_metadata_document_supported", &self.client_id_metadata_document_supported) + .field( + "service_documentation", + &self.service_documentation.as_ref().map(Url::as_str), + ) + .field( + "code_challenge_methods_supported", + &self.code_challenge_methods_supported, + ) + .field( + "authorization_response_iss_parameter_supported", + &self.authorization_response_iss_parameter_supported, + ) + .field( + "userinfo_endpoint", + &self.userinfo_endpoint.as_ref().map(Url::as_str), + ) + .field( + "client_id_metadata_document_supported", + &self.client_id_metadata_document_supported, + ) .finish() } } -fn ref_identity(v: &bool) -> bool { *v } +fn ref_identity(v: &bool) -> bool { + *v +} #[cfg(feature = "axum")] impl axum_core::response::IntoResponse for Metadata { fn into_response(self) -> axum_core::response::Response { use http::StatusCode; - (StatusCode::OK, - [("Content-Type", "application/json")], - serde_json::to_vec(&self).unwrap()) + ( + StatusCode::OK, + [("Content-Type", "application/json")], + serde_json::to_vec(&self).unwrap(), + ) .into_response() } } @@ -308,24 +337,37 @@ pub struct ClientMetadata { pub software_version: Option<Cow<'static, str>>, /// URI for the homepage of this client's owners #[serde(skip_serializing_if = "Option::is_none")] - pub homepage_uri: Option<Url> + pub homepage_uri: Option<Url>, +} + +/// Error that occurs when creating [`ClientMetadata`] with mismatched `client_id` and `client_uri`. +#[derive(Debug)] +pub struct ClientIdMismatch; + +impl std::error::Error for ClientIdMismatch {} +impl std::fmt::Display for ClientIdMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "client_id must be a prefix of client_uri") + } } impl ClientMetadata { - /// Create a new [`ClientMetadata`] with all the optional fields - /// omitted. + /// Create a new [`ClientMetadata`] with all the optional fields omitted. /// /// # Errors /// - /// Returns `()` if the `client_uri` is not a prefix of - /// `client_id` as required by the IndieAuth spec. - pub fn new(client_id: url::Url, client_uri: url::Url) -> Result<Self, ()> { - if client_id.as_str().as_bytes()[..client_uri.as_str().len()] != *client_uri.as_str().as_bytes() { - return Err(()); + /// Returns `()` if the `client_uri` is not a prefix of `client_id` as required by the IndieAuth + /// spec. + pub fn new(client_id: url::Url, client_uri: url::Url) -> Result<Self, ClientIdMismatch> { + if client_id.as_str().as_bytes()[..client_uri.as_str().len()] + != *client_uri.as_str().as_bytes() + { + return Err(ClientIdMismatch); } Ok(Self { - client_id, client_uri, + client_id, + client_uri, client_name: None, logo_uri: None, redirect_uris: None, @@ -355,14 +397,15 @@ impl axum_core::response::IntoResponse for ClientMetadata { fn into_response(self) -> axum_core::response::Response { use http::StatusCode; - (StatusCode::OK, - [("Content-Type", "application/json")], - serde_json::to_vec(&self).unwrap()) + ( + StatusCode::OK, + [("Content-Type", "application/json")], + serde_json::to_vec(&self).unwrap(), + ) .into_response() } } - /// User profile to be returned from the userinfo endpoint and when /// the `profile` scope was requested. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -379,7 +422,7 @@ pub struct Profile { /// User's email, if they've chosen to reveal it. This is guarded /// by the `email` scope. #[serde(skip_serializing_if = "Option::is_none")] - pub email: Option<String> + pub email: Option<String>, } #[cfg(feature = "axum")] @@ -387,9 +430,11 @@ impl axum_core::response::IntoResponse for Profile { fn into_response(self) -> axum_core::response::Response { use http::StatusCode; - (StatusCode::OK, - [("Content-Type", "application/json")], - serde_json::to_vec(&self).unwrap()) + ( + StatusCode::OK, + [("Content-Type", "application/json")], + serde_json::to_vec(&self).unwrap(), + ) .into_response() } } @@ -414,13 +459,13 @@ impl State { /// Generate a random state string of 128 bytes in length, using /// the provided random number generator. pub fn from_rng(rng: &mut (impl rand::CryptoRng + rand::Rng)) -> Self { - use rand::{Rng, distributions::Alphanumeric}; + use rand::{distributions::Alphanumeric, Rng}; - let bytes = rng.sample_iter(&Alphanumeric) + let bytes = rng + .sample_iter(&Alphanumeric) .take(128) .collect::<Vec<u8>>(); Self(String::from_utf8(bytes).unwrap()) - } } impl AsRef<str> for State { @@ -503,21 +548,23 @@ impl AuthorizationRequest { ("response_type", Cow::Borrowed(self.response_type.as_str())), ("client_id", Cow::Borrowed(self.client_id.as_str())), ("redirect_uri", Cow::Borrowed(self.redirect_uri.as_str())), - ("code_challenge", Cow::Borrowed(self.code_challenge.as_str())), - ("code_challenge_method", Cow::Borrowed(self.code_challenge.method().as_str())), - ("state", Cow::Borrowed(self.state.as_ref())) + ( + "code_challenge", + Cow::Borrowed(self.code_challenge.as_str()), + ), + ( + "code_challenge_method", + Cow::Borrowed(self.code_challenge.method().as_str()), + ), + ("state", Cow::Borrowed(self.state.as_ref())), ]; if let Some(ref scope) = self.scope { - v.push( - ("scope", Cow::Owned(scope.to_string())) - ); + v.push(("scope", Cow::Owned(scope.to_string()))); } if let Some(ref me) = self.me { - v.push( - ("me", Cow::Borrowed(me.as_str())) - ); + v.push(("me", Cow::Borrowed(me.as_str()))); } v @@ -550,17 +597,22 @@ pub struct AutoAuthRequest { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AutoAuthCallbackData { state: State, - callback_url: Url + callback_url: Url, } #[inline(always)] -fn deserialize_secs<'de, D: serde::de::Deserializer<'de>>(d: D) -> Result<std::time::Duration, D::Error> { +fn deserialize_secs<'de, D: serde::de::Deserializer<'de>>( + d: D, +) -> Result<std::time::Duration, D::Error> { use serde::Deserialize; Ok(std::time::Duration::from_secs(u64::deserialize(d)?)) } #[inline(always)] -fn serialize_secs<S: serde::ser::Serializer>(d: &std::time::Duration, s: S) -> Result<S::Ok, S::Error> { +fn serialize_secs<S: serde::ser::Serializer>( + d: &std::time::Duration, + s: S, +) -> Result<S::Ok, S::Error> { s.serialize_u64(std::time::Duration::as_secs(d)) } @@ -570,7 +622,7 @@ pub struct AutoAuthPollingResponse { request_id: State, #[serde(serialize_with = "serialize_secs")] #[serde(deserialize_with = "deserialize_secs")] - interval: std::time::Duration + interval: std::time::Duration, } /// The authorization response that must be appended to the @@ -602,10 +654,9 @@ pub struct AuthorizationResponse { /// authorization server. /// /// [oauth2-iss]: https://www.ietf.org/archive/id/draft-ietf-oauth-iss-auth-resp-02.html - pub iss: Url + pub iss: Url, } - /// A special grant request that is used in the AutoAuth ceremony. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct AutoAuthCodeGrant { @@ -628,7 +679,7 @@ pub struct AutoAuthCodeGrant { callback_url: Url, /// The user's URL. Will be used to confirm the authorization /// endpoint's authority. - me: Url + me: Url, } /// A grant request that continues the IndieAuth ceremony. @@ -647,7 +698,7 @@ pub enum GrantRequest { redirect_uri: Url, /// The PKCE code verifier that was used to create the code /// challenge. - code_verifier: PKCEVerifier + code_verifier: PKCEVerifier, }, /// Use a refresh token to get a fresh access token and a new /// matching refresh token. @@ -662,8 +713,8 @@ pub enum GrantRequest { /// /// This cannot be used to gain new scopes -- you need to /// start over if you need new scopes from the user. - scope: Option<Scopes> - } + scope: Option<Scopes>, + }, } /// Token type, as described in [RFC6749][]. @@ -677,7 +728,7 @@ pub enum TokenType { /// IndieAuth uses. /// /// [RFC6750]: https://www.rfc-editor.org/rfc/rfc6750 - Bearer + Bearer, } /// The response to a successful [`GrantRequest`]. @@ -714,14 +765,14 @@ pub enum GrantResponse { profile: Option<Profile>, /// The refresh token, if it was issued. #[serde(skip_serializing_if = "Option::is_none")] - refresh_token: Option<String> + refresh_token: Option<String>, }, /// A profile URL response, that only contains the profile URL and /// the profile, if it was requested. /// /// This is suitable for confirming the identity of the user, but /// no more than that. - ProfileUrl(ProfileUrl) + ProfileUrl(ProfileUrl), } /// The contents of a profile URL response. @@ -731,7 +782,7 @@ pub struct ProfileUrl { pub me: Url, /// The user's profile information, if it was requested. #[serde(skip_serializing_if = "Option::is_none")] - pub profile: Option<Profile> + pub profile: Option<Profile>, } #[cfg(feature = "axum")] @@ -739,12 +790,15 @@ impl axum_core::response::IntoResponse for GrantResponse { fn into_response(self) -> axum_core::response::Response { use http::StatusCode; - (StatusCode::OK, - [("Content-Type", "application/json"), - ("Cache-Control", "no-store"), - ("Pragma", "no-cache") - ], - serde_json::to_vec(&self).unwrap()) + ( + StatusCode::OK, + [ + ("Content-Type", "application/json"), + ("Cache-Control", "no-store"), + ("Pragma", "no-cache"), + ], + serde_json::to_vec(&self).unwrap(), + ) .into_response() } } @@ -758,7 +812,7 @@ impl axum_core::response::IntoResponse for GrantResponse { pub enum RequestMaybeAuthorizationEndpoint { Authorization(AuthorizationRequest), Grant(GrantRequest), - AutoAuth(AutoAuthCodeGrant) + AutoAuth(AutoAuthCodeGrant), } /// A token introspection request that can be handled by the token @@ -770,7 +824,7 @@ pub enum RequestMaybeAuthorizationEndpoint { #[derive(Debug, Serialize, Deserialize)] pub struct TokenIntrospectionRequest { /// The token for which data was requested. - pub token: String + pub token: String, } /// Data for a token that will be returned by the introspection @@ -792,7 +846,7 @@ pub struct TokenData { /// The issue date, represented in the same format as the /// [`exp`][TokenData::exp] field. #[serde(skip_serializing_if = "Option::is_none")] - pub iat: Option<u64> + pub iat: Option<u64>, } impl TokenData { @@ -801,24 +855,25 @@ impl TokenData { use std::time::{Duration, SystemTime, UNIX_EPOCH}; self.exp - .map(|exp| SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or(Duration::ZERO) - .as_secs() >= exp) + .map(|exp| { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs() + >= exp + }) .unwrap_or_default() } /// Return a timestamp at which the token is not considered valid anymore. pub fn expires_at(&self) -> Option<std::time::SystemTime> { - self.exp.map(|time| { - std::time::UNIX_EPOCH + std::time::Duration::from_secs(time) - }) + self.exp + .map(|time| std::time::UNIX_EPOCH + std::time::Duration::from_secs(time)) } /// Return a timestamp describing when the token was issued. pub fn issued_at(&self) -> Option<std::time::SystemTime> { - self.iat.map(|time| { - std::time::UNIX_EPOCH + std::time::Duration::from_secs(time) - }) + self.iat + .map(|time| std::time::UNIX_EPOCH + std::time::Duration::from_secs(time)) } /// Check if a certain scope is allowed for this token. @@ -841,18 +896,24 @@ pub struct TokenIntrospectionResponse { active: bool, #[serde(flatten)] #[serde(skip_serializing_if = "Option::is_none")] - data: Option<TokenData> + data: Option<TokenData>, } // These wrappers and impls should take care of making use of this // type as painless as possible. impl TokenIntrospectionResponse { /// Indicate that this token is not valid. pub fn inactive() -> Self { - Self { active: false, data: None } + Self { + active: false, + data: None, + } } /// Indicate that this token is valid, and provide data about it. pub fn active(data: TokenData) -> Self { - Self { active: true, data: Some(data) } + Self { + active: true, + data: Some(data), + } } /// Check if the endpoint reports this token as valid. pub fn is_active(&self) -> bool { @@ -862,7 +923,7 @@ impl TokenIntrospectionResponse { /// Get data contained in the response, if the token is valid. pub fn data(&self) -> Option<&TokenData> { if !self.active { - return None + return None; } self.data.as_ref() } @@ -874,7 +935,10 @@ impl Default for TokenIntrospectionResponse { } impl From<Option<TokenData>> for TokenIntrospectionResponse { fn from(data: Option<TokenData>) -> Self { - Self { active: data.is_some(), data } + Self { + active: data.is_some(), + data, + } } } impl From<TokenIntrospectionResponse> for Option<TokenData> { @@ -888,9 +952,11 @@ impl axum_core::response::IntoResponse for TokenIntrospectionResponse { fn into_response(self) -> axum_core::response::Response { use http::StatusCode; - (StatusCode::OK, - [("Content-Type", "application/json")], - serde_json::to_vec(&self).unwrap()) + ( + StatusCode::OK, + [("Content-Type", "application/json")], + serde_json::to_vec(&self).unwrap(), + ) .into_response() } } @@ -900,7 +966,7 @@ impl axum_core::response::IntoResponse for TokenIntrospectionResponse { #[derive(Debug, Serialize, Deserialize)] pub struct TokenRevocationRequest { /// The token that needs to be revoked in case it is valid. - pub token: String + pub token: String, } /// Types of errors that a resource server (IndieAuth consumer) can @@ -961,7 +1027,6 @@ pub enum ErrorKind { /// AutoAuth/OAuth2 Device Flow: Access was denied by the /// authorization endpoint. AccessDenied, - } // TODO consider relying on serde_variant for these conversions impl AsRef<str> for ErrorKind { @@ -997,13 +1062,15 @@ pub struct Error { pub msg: Option<String>, /// An URL to documentation describing what went wrong and how to /// fix it. - pub error_uri: Option<url::Url> + pub error_uri: Option<url::Url>, } impl From<ErrorKind> for Error { fn from(kind: ErrorKind) -> Error { Error { - kind, msg: None, error_uri: None + kind, + msg: None, + error_uri: None, } } } @@ -1029,9 +1096,11 @@ impl axum_core::response::IntoResponse for self::Error { fn into_response(self) -> axum_core::response::Response { use http::StatusCode; - (StatusCode::BAD_REQUEST, - [("Content-Type", "application/json")], - serde_json::to_vec(&self).unwrap()) + ( + StatusCode::BAD_REQUEST, + [("Content-Type", "application/json")], + serde_json::to_vec(&self).unwrap(), + ) .into_response() } } @@ -1044,17 +1113,23 @@ mod tests { fn test_serialize_deserialize_grant_request() { let authorization_code: GrantRequest = GrantRequest::AuthorizationCode { client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - redirect_uri: "https://kittybox.fireburn.ru/.kittybox/login/redirect".parse().unwrap(), + redirect_uri: "https://kittybox.fireburn.ru/.kittybox/login/redirect" + .parse() + .unwrap(), code_verifier: PKCEVerifier("helloworld".to_string()), - code: "hithere".to_owned() + code: "hithere".to_owned(), }; let serialized = serde_urlencoded::to_string([ ("grant_type", "authorization_code"), ("code", "hithere"), ("client_id", "https://kittybox.fireburn.ru/"), - ("redirect_uri", "https://kittybox.fireburn.ru/.kittybox/login/redirect"), + ( + "redirect_uri", + "https://kittybox.fireburn.ru/.kittybox/login/redirect", + ), ("code_verifier", "helloworld"), - ]).unwrap(); + ]) + .unwrap(); let deserialized = serde_urlencoded::from_str(&serialized).unwrap(); diff --git a/indieauth/src/pkce.rs b/indieauth/src/pkce.rs index 8dcf9b1..6233016 100644 --- a/indieauth/src/pkce.rs +++ b/indieauth/src/pkce.rs @@ -1,6 +1,6 @@ -use serde::{Serialize, Deserialize}; -use sha2::{Sha256, Digest}; use data_encoding::BASE64URL; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; /// Methods to use for PKCE challenges. #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize, Default)] @@ -10,7 +10,7 @@ pub enum PKCEMethod { S256, /// Plain string by itself. Please don't use this. #[serde(rename = "snake_case")] - Plain + Plain, } impl PKCEMethod { @@ -18,7 +18,7 @@ impl PKCEMethod { pub fn as_str(&self) -> &'static str { match self { PKCEMethod::S256 => "S256", - PKCEMethod::Plain => "plain" + PKCEMethod::Plain => "plain", } } } @@ -57,7 +57,7 @@ impl PKCEVerifier { /// Generate a new PKCE string of 128 bytes in length, using /// the provided random number generator. pub fn from_rng(rng: &mut (impl rand::CryptoRng + rand::Rng)) -> Self { - use rand::{Rng, distributions::Alphanumeric}; + use rand::{distributions::Alphanumeric, Rng}; let bytes = rng .sample_iter(&Alphanumeric) @@ -65,7 +65,6 @@ impl PKCEVerifier { .collect::<Vec<u8>>(); Self(String::from_utf8(bytes).unwrap()) } - } /// A PKCE challenge as described in [RFC7636]. @@ -75,7 +74,7 @@ impl PKCEVerifier { pub struct PKCEChallenge { code_challenge: String, #[serde(rename = "code_challenge_method")] - method: PKCEMethod + method: PKCEMethod, } impl PKCEChallenge { @@ -92,10 +91,10 @@ impl PKCEChallenge { challenge.retain(|c| c != '='); challenge - }, + } PKCEMethod::Plain => code_verifier.to_string(), }, - method + method, } } @@ -130,17 +129,21 @@ impl PKCEChallenge { #[cfg(test)] mod tests { - use super::{PKCEMethod, PKCEVerifier, PKCEChallenge}; + use super::{PKCEChallenge, PKCEMethod, PKCEVerifier}; #[test] /// A snapshot test generated using [Aaron Parecki's PKCE /// tools](https://example-app.com/pkce) that checks for a /// conforming challenge. fn test_pkce_challenge_verification() { - let verifier = PKCEVerifier("ec03310e4e90f7bc988af05384060c3c1afeae4bb4d0f648c5c06b63".to_owned()); + let verifier = + PKCEVerifier("ec03310e4e90f7bc988af05384060c3c1afeae4bb4d0f648c5c06b63".to_owned()); let challenge = PKCEChallenge::new(&verifier, PKCEMethod::S256); - assert_eq!(challenge.as_str(), "aB8OG20Rh8UoQ9gFhI0YvPkx4dDW2MBspBKGXL6j6Wg"); + assert_eq!( + challenge.as_str(), + "aB8OG20Rh8UoQ9gFhI0YvPkx4dDW2MBspBKGXL6j6Wg" + ); } } diff --git a/indieauth/src/scopes.rs b/indieauth/src/scopes.rs index 02ee8dc..7333b5b 100644 --- a/indieauth/src/scopes.rs +++ b/indieauth/src/scopes.rs @@ -1,12 +1,8 @@ use std::str::FromStr; use serde::{ - Serialize, Serializer, - Deserialize, - de::{ - Deserializer, Visitor, - Error as DeserializeError - } + de::{Deserializer, Error as DeserializeError, Visitor}, + Deserialize, Serialize, Serializer, }; /// Various scopes that can be requested through IndieAuth. @@ -36,7 +32,7 @@ pub enum Scope { /// Allows to receive email in the profile information. Email, /// Custom scope not included above. - Custom(String) + Custom(String), } impl Scope { /// Create a custom scope from a string slice. @@ -61,25 +57,25 @@ impl AsRef<str> for Scope { Channels => "channels", Profile => "profile", Email => "email", - Custom(s) => s.as_ref() + Custom(s) => s.as_ref(), } } } impl From<&str> for Scope { fn from(scope: &str) -> Self { match scope { - "create" => Scope::Create, - "update" => Scope::Update, - "delete" => Scope::Delete, - "media" => Scope::Media, - "read" => Scope::Read, - "follow" => Scope::Follow, - "mute" => Scope::Mute, - "block" => Scope::Block, + "create" => Scope::Create, + "update" => Scope::Update, + "delete" => Scope::Delete, + "media" => Scope::Media, + "read" => Scope::Read, + "follow" => Scope::Follow, + "mute" => Scope::Mute, + "block" => Scope::Block, "channels" => Scope::Channels, - "profile" => Scope::Profile, - "email" => Scope::Email, - other => Scope::custom(other) + "profile" => Scope::Profile, + "email" => Scope::Email, + other => Scope::custom(other), } } } @@ -106,7 +102,8 @@ impl Scopes { } /// Ensure all of the requested scopes are in the list. pub fn has_all(&self, scopes: &[Scope]) -> bool { - scopes.iter() + scopes + .iter() .map(|s1| self.iter().any(|s2| s1 == s2)) .all(|s| s) } @@ -114,6 +111,19 @@ impl Scopes { pub fn iter(&self) -> std::slice::Iter<'_, Scope> { self.0.iter() } + + /// Count scopes requested by the application. + pub fn len(&self) -> usize { + self.0.len() + } + + /// See if the application requested any scopes. + /// + /// Some older applications forget to request scopes. This may be used to force a default scope. + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } impl AsRef<[Scope]> for Scopes { fn as_ref(&self) -> &[Scope] { @@ -123,8 +133,7 @@ impl AsRef<[Scope]> for Scopes { impl std::fmt::Display for Scopes { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut iter = self.0.iter() - .peekable(); + let mut iter = self.0.iter().peekable(); while let Some(scope) = iter.next() { f.write_str(scope.as_ref())?; if iter.peek().is_some() { @@ -139,20 +148,24 @@ impl FromStr for Scopes { type Err = std::convert::Infallible; fn from_str(value: &str) -> Result<Self, Self::Err> { - Ok(Self(value.split_ascii_whitespace() + Ok(Self( + value + .split_ascii_whitespace() .map(Scope::from) - .collect::<Vec<Scope>>())) + .collect::<Vec<Scope>>(), + )) } } impl Serialize for Scopes { fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where - S: Serializer + S: Serializer, { serializer.serialize_str(&self.to_string()) } } struct ScopeVisitor; +#[allow(clippy::needless_lifetimes, reason = "serde idiom")] impl<'de> Visitor<'de> for ScopeVisitor { type Value = Scopes; @@ -162,16 +175,15 @@ impl<'de> Visitor<'de> for ScopeVisitor { fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> where - E: DeserializeError + E: DeserializeError, { Ok(Scopes::from_str(value).unwrap()) } } impl<'de> Deserialize<'de> for Scopes { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where - D: Deserializer<'de> + D: Deserializer<'de>, { deserializer.deserialize_str(ScopeVisitor) } @@ -184,29 +196,31 @@ mod tests { #[test] fn test_serde_vec_scope() { let scopes = vec![ - Scope::Create, Scope::Update, Scope::Delete, + Scope::Create, + Scope::Update, + Scope::Delete, Scope::Media, - Scope::custom("kittybox_internal_access") + Scope::custom("kittybox_internal_access"), ]; - let scope_serialized = serde_json::to_value( - Scopes::new(scopes.clone()) - ).unwrap(); + let scope_serialized = serde_json::to_value(Scopes::new(scopes.clone())).unwrap(); let scope_str = scope_serialized.as_str().unwrap(); - assert_eq!(scope_str, "create update delete media kittybox_internal_access"); + assert_eq!( + scope_str, + "create update delete media kittybox_internal_access" + ); - assert!(serde_json::from_value::<Scopes>(scope_serialized).unwrap().has_all(&scopes)) + assert!(serde_json::from_value::<Scopes>(scope_serialized) + .unwrap() + .has_all(&scopes)) } #[test] fn test_scope_has_all() { - let scopes = Scopes(vec![ - Scope::Create, Scope::Update, Scope::custom("draft") - ]); + let scopes = Scopes(vec![Scope::Create, Scope::Update, Scope::custom("draft")]); assert!(scopes.has_all(&[Scope::Create, Scope::custom("draft")])); assert!(!scopes.has_all(&[Scope::Read, Scope::custom("kittybox_internal_access")])); } - } diff --git a/kittybox.1 b/kittybox.1 new file mode 100644 index 0000000..ecda3e0 --- /dev/null +++ b/kittybox.1 @@ -0,0 +1,107 @@ +.TH KITTYBOX 1 "" https://kittybox.fireburn.ru/ +.SH NAME +kittybox \- a CMS using IndieWeb technologies +.SH SYNOPSIS +.SY kittybox +.YS + +.SH DESCRIPTION +.P +.B kittybox +is a full-featured CMS for a personal website which is able to use various +storage backends to store website content. + +It is most suitable for a personal blog, though it probably is capable of being +used for other purposes. + +.SH ENVIRONMENT +.PP +.I $BACKEND_URI +.RS 4 +The URI of the main storage backend to use. Available backends are: +.TP +.EX +.B "postgres://<connection URI>" +.EE +Store website content in a Postgres database. Takes a connection URI. +.TP +.EX +.B "file://<path to a folder>" +.EE +Store website content in a folder on the local filesystem. + +.B NOTE: +This backend is not actively maintained and may not work as expected. +It does not implement some advanced features and will probably not receive +updates often. + +.RE +.PP +.I $AUTHSTORE_URI +.RS 4 +The URI of the authentication backend to use. +This backend is responsible for storing access tokens and short-lived +authorization codes. +Available backends are: +.TP +.EX +.B "file://<path to a folder>" +.EE +Store authentication data in a folder on the filesystem. + +.RE +.PP +.I $BLOBSTORE_URI +.RS 4 +The URI of the media store backend to use. +This backend manages file uploads and storage of post attachments. + +Available backends are: +.TP +.EX +.B "file://<path to a folder>" +.EE +Store file uploads in a content-addressed storage based on a folder. File +contents are hashed using SHA-256, and the hash is used to construct the path. +A small piece of metadata is stored next to the file in JSON format. + +.RE +.PP +.I $JOB_QUEUE_URI +.RS 4 +The URI of the job queue backend to use. +This backend is responsible for some background tasks, like receiving and +validating Webmentions. +Available backends are: +.TP 4 +.EX +.B "postgres://<connection URI>" +.EE +Use PostgreSQL as a job queue. +This works better than one would expect. + +.RE +.PP +.I $COOKIE_KEY +.RS 4 +A key for signing session cookies. +This needs to be kept secret. +.RE + +.SH STANDARDS + +Aaron Parecki, W3C, +.UR https://www.w3.org/TR/micropub/ +.I Micropub +.UE "," +23 May 2017. W3C Recommendation. + +Aaron Parecki, IndieWeb community, +.UR https://indieauth.spec.indieweb.org +.I IndieAuth +.UE "," +11 July 2024. Living Standard. + +.SH SEE ALSO + +.MR postgres 1 diff --git a/kittybox.nix b/kittybox.nix index b078c93..1b591f3 100644 --- a/kittybox.nix +++ b/kittybox.nix @@ -1,4 +1,4 @@ -{ crane, lib, nodePackages +{ crane, lib, installShellFiles, nodePackages , useWebAuthn ? false, openssl, zlib, pkg-config, protobuf , usePostgres ? true, postgresql, postgresqlTestHook , nixosTests }: @@ -9,7 +9,7 @@ assert usePostgres -> postgresql != null && postgresqlTestHook != null; let featureMatrix = features: lib.concatStringsSep " " (lib.attrNames (lib.filterAttrs (k: v: v) features)); - suffixes = [ ".sql" ".ts" ".css" ".html" ".json" ".woff2" ]; + suffixes = [ ".sql" ".ts" ".css" ".html" ".json" ".woff2" ".1" ]; suffixFilter = suffixes: name: type: let base = baseNameOf (toString name); in type == "directory" || lib.any (ext: lib.hasSuffix ext base) suffixes; @@ -34,7 +34,8 @@ let cargoExtraArgs = cargoFeatures; buildInputs = lib.optional useWebAuthn openssl; - nativeBuildInputs = [ nodePackages.typescript ] ++ (lib.optional useWebAuthn pkg-config); + nativeBuildInputs = [ nodePackages.typescript installShellFiles ] + ++ (lib.optional useWebAuthn pkg-config); meta = with lib.meta; { maintainers = with lib.maintainers; [ vikanezrimaya ]; @@ -50,13 +51,17 @@ in crane.buildPackage (args' // { nativeCheckInputs = lib.optionals usePostgres [ postgresql postgresqlTestHook ]; - + # Tests create arbitrary databases; we need to be prepared for that postgresqlTestUserOptions = "LOGIN SUPERUSER"; postgresqlTestSetupPost = '' export DATABASE_URL="postgres://localhost?host=$PGHOST&user=$PGUSER&dbname=$PGDATABASE" ''; + postInstall = '' + installManPage ./kittybox.1 + ''; + passthru = { tests = nixosTests; hasPostgres = usePostgres; diff --git a/src/bin/kittybox-check-webmention.rs b/src/bin/kittybox-check-webmention.rs index b43980e..a9e5957 100644 --- a/src/bin/kittybox-check-webmention.rs +++ b/src/bin/kittybox-check-webmention.rs @@ -7,7 +7,7 @@ enum Error { #[error("reqwest error: {0}")] Http(#[from] reqwest::Error), #[error("webmention check error: {0}")] - Webmention(#[from] WebmentionError) + Webmention(#[from] WebmentionError), } #[derive(Parser, Debug)] @@ -21,7 +21,7 @@ struct Args { #[clap(value_parser)] url: url::Url, #[clap(value_parser)] - link: url::Url + link: url::Url, } #[tokio::main] @@ -30,10 +30,11 @@ async fn main() -> Result<(), Error> { let http: reqwest::Client = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); builder.build().unwrap() }; diff --git a/src/bin/kittybox-indieauth-helper.rs b/src/bin/kittybox-indieauth-helper.rs index f4ad679..0725aac 100644 --- a/src/bin/kittybox-indieauth-helper.rs +++ b/src/bin/kittybox-indieauth-helper.rs @@ -1,13 +1,11 @@ +use clap::Parser; use futures::{FutureExt, TryFutureExt}; use kittybox_indieauth::{ - AuthorizationRequest, PKCEVerifier, - PKCEChallenge, PKCEMethod, GrantRequest, Scope, - AuthorizationResponse, GrantResponse, - Error as IndieauthError + AuthorizationRequest, AuthorizationResponse, Error as IndieauthError, GrantRequest, + GrantResponse, PKCEChallenge, PKCEMethod, PKCEVerifier, Scope, }; -use clap::Parser; -use tokio::net::TcpListener; use std::{borrow::Cow, future::IntoFuture, io::Write}; +use tokio::net::TcpListener; const DEFAULT_CLIENT_ID: &str = "https://kittybox.fireburn.ru/indieauth-helper.html"; const DEFAULT_REDIRECT_URI: &str = "http://localhost:60000/callback"; @@ -21,7 +19,7 @@ enum Error { #[error("url parsing error: {0}")] UrlParse(#[from] url::ParseError), #[error("indieauth flow error: {0}")] - IndieAuth(#[from] IndieauthError) + IndieAuth(#[from] IndieauthError), } #[derive(Parser, Debug)] @@ -46,20 +44,20 @@ struct Args { client_id: url::Url, /// Redirect URI to declare. Note: This will break the flow, use only for testing UI. #[clap(long, value_parser)] - redirect_uri: Option<url::Url> + redirect_uri: Option<url::Url>, } - #[tokio::main] async fn main() -> Result<(), Error> { let args = Args::parse(); let http: reqwest::Client = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); // This only works on debug builds. Don't get any funny thoughts. #[cfg(debug_assertions)] if std::env::var("KITTYBOX_DANGER_INSECURE_TLS") @@ -71,12 +69,14 @@ async fn main() -> Result<(), Error> { builder.build().unwrap() }; - let redirect_uri: url::Url = args.redirect_uri + let redirect_uri: url::Url = args + .redirect_uri .clone() .unwrap_or_else(|| DEFAULT_REDIRECT_URI.parse().unwrap()); eprintln!("Checking .well-known for metadata..."); - let metadata = http.get(args.me.join("/.well-known/oauth-authorization-server")?) + let metadata = http + .get(args.me.join("/.well-known/oauth-authorization-server")?) .header("Accept", "application/json") .send() .await? @@ -92,7 +92,7 @@ async fn main() -> Result<(), Error> { state: kittybox_indieauth::State::new(), code_challenge: PKCEChallenge::new(&verifier, PKCEMethod::default()), scope: Some(kittybox_indieauth::Scopes::new(args.scope)), - me: Some(args.me) + me: Some(args.me), }; let indieauth_url = { @@ -103,12 +103,18 @@ async fn main() -> Result<(), Error> { url }; - eprintln!("Please visit the following URL in your browser:\n\n {}\n", indieauth_url.as_str()); + eprintln!( + "Please visit the following URL in your browser:\n\n {}\n", + indieauth_url.as_str() + ); #[cfg(target_os = "linux")] - match std::process::Command::new("xdg-open").arg(indieauth_url.as_str()).spawn() { + match std::process::Command::new("xdg-open") + .arg(indieauth_url.as_str()) + .spawn() + { Ok(child) => drop(child), - Err(err) => eprintln!("Couldn't xdg-open: {}", err) + Err(err) => eprintln!("Couldn't xdg-open: {}", err), } if args.redirect_uri.is_some() { @@ -123,32 +129,38 @@ async fn main() -> Result<(), Error> { let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx))); - let router = axum::Router::new() - .route("/callback", axum::routing::get( + let router = axum::Router::new().route( + "/callback", + axum::routing::get( move |Query(response): Query<AuthorizationResponse>| async move { if let Some(tx) = tx.lock_owned().await.take() { tx.send(response).unwrap(); - (axum::http::StatusCode::OK, - [("Content-Type", "text/plain")], - "Thank you! This window can now be closed.") + ( + axum::http::StatusCode::OK, + [("Content-Type", "text/plain")], + "Thank you! This window can now be closed.", + ) .into_response() } else { - (axum::http::StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "Oops. The callback was already received. Did you click twice?") + ( + axum::http::StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "Oops. The callback was already received. Did you click twice?", + ) .into_response() } - } - )); + }, + ), + ); - use std::net::{SocketAddr, IpAddr, Ipv4Addr}; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; let server = axum::serve( - TcpListener::bind( - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST),60000) - ).await.unwrap(), - router.into_make_service() + TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 60000)) + .await + .unwrap(), + router.into_make_service(), ); tokio::task::spawn(server.into_future()) @@ -175,12 +187,13 @@ async fn main() -> Result<(), Error> { #[cfg(not(debug_assertions))] std::process::exit(1); } - let response: Result<GrantResponse, IndieauthError> = http.post(metadata.token_endpoint) + let response: Result<GrantResponse, IndieauthError> = http + .post(metadata.token_endpoint) .form(&GrantRequest::AuthorizationCode { code: authorization_response.code, client_id: args.client_id, redirect_uri, - code_verifier: verifier + code_verifier: verifier, }) .header("Accept", "application/json") .send() @@ -201,9 +214,14 @@ async fn main() -> Result<(), Error> { refresh_token, scope, .. - } = response? { - eprintln!("Congratulations, {}, access token is ready! {}", - profile.as_ref().and_then(|p| p.name.as_deref()).unwrap_or(me.as_str()), + } = response? + { + eprintln!( + "Congratulations, {}, access token is ready! {}", + profile + .as_ref() + .and_then(|p| p.name.as_deref()) + .unwrap_or(me.as_str()), if let Some(exp) = expires_in { Cow::Owned(format!("It expires in {exp} seconds.")) } else { diff --git a/src/bin/kittybox-mf2.rs b/src/bin/kittybox-mf2.rs index 0cd89b4..b6f4999 100644 --- a/src/bin/kittybox-mf2.rs +++ b/src/bin/kittybox-mf2.rs @@ -37,8 +37,9 @@ async fn main() -> Result<(), Error> { .with_indent_lines(true) .with_verbose_exit(true), #[cfg(not(debug_assertions))] - tracing_subscriber::fmt::layer().json() - .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) + tracing_subscriber::fmt::layer() + .json() + .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())), ); tracing_registry.init(); @@ -46,10 +47,11 @@ async fn main() -> Result<(), Error> { let http: reqwest::Client = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); builder.build().unwrap() }; diff --git a/src/companion.rs b/src/companion.rs new file mode 100644 index 0000000..debc266 --- /dev/null +++ b/src/companion.rs @@ -0,0 +1,84 @@ +use axum::{ + extract::{Extension, Path}, + response::{IntoResponse, Response}, +}; +use std::{collections::HashMap, sync::Arc}; + +#[derive(Debug, Clone, Copy)] +struct Resource { + data: &'static [u8], + mime: &'static str, +} + +impl IntoResponse for &Resource { + fn into_response(self) -> Response { + ( + axum::http::StatusCode::OK, + [("Content-Type", self.mime)], + self.data, + ) + .into_response() + } +} + +// TODO replace with the "phf" crate someday +type ResourceTable = Arc<HashMap<&'static str, Resource>>; + +#[tracing::instrument] +async fn map_to_static( + Path(name): Path<String>, + Extension(resources): Extension<ResourceTable>, +) -> Response { + tracing::debug!("Searching for {} in the resource table...", name); + match resources.get(name.as_str()) { + Some(res) => res.into_response(), + None => { + #[cfg(debug_assertions)] + tracing::error!("Not found"); + + ( + axum::http::StatusCode::NOT_FOUND, + [("Content-Type", "text/plain")], + "Not found. Sorry.".as_bytes(), + ) + .into_response() + } + } +} + +pub fn router<St: Clone + Send + Sync + 'static>() -> axum::Router<St> { + let resources: ResourceTable = { + let mut map = HashMap::new(); + + macro_rules! register_resource { + ($map:ident, $prefix:expr, ($filename:literal, $mime:literal)) => {{ + $map.insert($filename, Resource { + data: include_bytes!(concat!($prefix, $filename)), + mime: $mime + }) + }}; + ($map:ident, $prefix:expr, ($filename:literal, $mime:literal), $( ($f:literal, $m:literal) ),+) => {{ + register_resource!($map, $prefix, ($filename, $mime)); + register_resource!($map, $prefix, $(($f, $m)),+); + }}; + } + + register_resource! { + map, + concat!(env!("OUT_DIR"), "/", "companion", "/"), + ("index.html", "text/html; charset=\"utf-8\""), + ("main.js", "text/javascript"), + ("micropub_api.js", "text/javascript"), + ("indieauth.js", "text/javascript"), + ("base64.js", "text/javascript"), + ("style.css", "text/css") + }; + + Arc::new(map) + }; + + axum::Router::new().route( + "/{filename}", + axum::routing::get(map_to_static).layer(Extension(resources)), + ) +} diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs index db9bb22..5c93beb 100644 --- a/src/database/file/mod.rs +++ b/src/database/file/mod.rs @@ -1,6 +1,6 @@ //#![warn(clippy::unwrap_used)] -use crate::database::{ErrorKind, Result, settings, Storage, StorageError}; -use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; +use crate::database::{settings, ErrorKind, Result, Storage, StorageError}; +use crate::micropub::{MicropubPropertyDeletion, MicropubUpdate}; use futures::{stream, StreamExt, TryStreamExt}; use kittybox_util::MentionType; use serde_json::json; @@ -247,7 +247,9 @@ async fn hydrate_author<S: Storage>( impl Storage for FileStorage { async fn new(url: &'_ url::Url) -> Result<Self> { // TODO: sanity check - Ok(Self { root_dir: PathBuf::from(url.path()) }) + Ok(Self { + root_dir: PathBuf::from(url.path()), + }) } #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result<Vec<String>> { @@ -259,7 +261,7 @@ impl Storage for FileStorage { // perform well. Err(std::io::Error::new( std::io::ErrorKind::Unsupported, - "?q=category queries are not implemented due to resource constraints" + "?q=category queries are not implemented due to resource constraints", ))? } @@ -340,7 +342,10 @@ impl Storage for FileStorage { file.sync_all().await?; drop(file); tokio::fs::rename(&tempfile, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; if let Some(urls) = post["properties"]["url"].as_array() { for url in urls.iter().map(|i| i.as_str().unwrap()) { @@ -350,8 +355,8 @@ impl Storage for FileStorage { "{}{}", url.host_str().unwrap(), url.port() - .map(|port| format!(":{}", port)) - .unwrap_or_default() + .map(|port| format!(":{}", port)) + .unwrap_or_default() ) }; if url != key && url_domain == user.authority() { @@ -410,26 +415,24 @@ impl Storage for FileStorage { .create(false) .open(&path) .await - { - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - Vec::default() - } - Err(err) => { - // Propagate the error upwards - return Err(err.into()); - } - Ok(mut file) => { - let mut content = String::new(); - file.read_to_string(&mut content).await?; - drop(file); - - if !content.is_empty() { - serde_json::from_str(&content)? - } else { - Vec::default() - } - } - } + { + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Vec::default(), + Err(err) => { + // Propagate the error upwards + return Err(err.into()); + } + Ok(mut file) => { + let mut content = String::new(); + file.read_to_string(&mut content).await?; + drop(file); + + if !content.is_empty() { + serde_json::from_str(&content)? + } else { + Vec::default() + } + } + } }; channels.push(super::MicropubChannel { @@ -444,7 +447,10 @@ impl Storage for FileStorage { tempfile.sync_all().await?; drop(tempfile); tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; } Ok(()) } @@ -476,7 +482,10 @@ impl Storage for FileStorage { temp.sync_all().await?; drop(temp); tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; (json, new_json) }; @@ -486,7 +495,9 @@ impl Storage for FileStorage { #[tracing::instrument(skip(self, f), fields(f = std::any::type_name::<F>()))] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { todo!("update_with is not yet implemented due to special requirements of the file backend") } @@ -526,25 +537,25 @@ impl Storage for FileStorage { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { #[allow(deprecated)] - Ok(self.read_feed_with_limit( - url, - cursor, - limit, - user - ).await? + Ok(self + .read_feed_with_limit(url, cursor, limit, user) + .await? .map(|feed| { - tracing::debug!("Feed: {:#}", serde_json::Value::Array( - feed["children"] - .as_array() - .map(|v| v.as_slice()) - .unwrap_or_default() - .iter() - .map(|mf2| mf2["properties"]["uid"][0].clone()) - .collect::<Vec<_>>() - )); + tracing::debug!( + "Feed: {:#}", + serde_json::Value::Array( + feed["children"] + .as_array() + .map(|v| v.as_slice()) + .unwrap_or_default() + .iter() + .map(|mf2| mf2["properties"]["uid"][0].clone()) + .collect::<Vec<_>>() + ) + ); let cursor: Option<String> = feed["children"] .as_array() .map(|v| v.as_slice()) @@ -553,8 +564,7 @@ impl Storage for FileStorage { .map(|v| v["properties"]["uid"][0].as_str().unwrap().to_owned()); tracing::debug!("Extracted the cursor: {:?}", cursor); (feed, cursor) - }) - ) + })) } #[tracing::instrument(skip(self))] @@ -574,9 +584,12 @@ impl Storage for FileStorage { let children: Vec<serde_json::Value> = match feed["children"].take() { serde_json::Value::Array(children) => children, // We've already checked it's an array - _ => unreachable!() + _ => unreachable!(), }; - tracing::debug!("Full children array: {:#}", serde_json::Value::Array(children.clone())); + tracing::debug!( + "Full children array: {:#}", + serde_json::Value::Array(children.clone()) + ); let mut posts_iter = children .into_iter() .map(|s: serde_json::Value| s.as_str().unwrap().to_string()); @@ -589,7 +602,7 @@ impl Storage for FileStorage { // incredibly long feeds. if let Some(after) = after { tokio::task::block_in_place(|| { - for s in posts_iter.by_ref() { + for s in posts_iter.by_ref() { if s == after { break; } @@ -603,6 +616,7 @@ impl Storage for FileStorage { // Broken links return None, and Stream::filter_map skips Nones. .try_filter_map(|post: Option<serde_json::Value>| async move { Ok(post) }) .and_then(|mut post| async move { + // XXX: N+1 problem, potential sanitization issues hydrate_author(&mut post, user, self).await; Ok(post) }) @@ -654,12 +668,19 @@ impl Storage for FileStorage { let settings: HashMap<&str, serde_json::Value> = serde_json::from_str(&content)?; match settings.get(S::ID) { Some(value) => Ok(serde_json::from_value::<S>(value.clone())?), - None => Err(StorageError::from_static(ErrorKind::Backend, "Setting not set")) + None => Err(StorageError::from_static( + ErrorKind::Backend, + "Setting not set", + )), } } #[tracing::instrument(skip(self))] - async fn set_setting<S: settings::Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> { + async fn set_setting<S: settings::Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> Result<()> { let mut path = relative_path::RelativePathBuf::new(); path.push(user.authority()); path.push("settings"); @@ -703,20 +724,28 @@ impl Storage for FileStorage { tempfile.sync_all().await?; drop(tempfile); tokio::fs::rename(temppath, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; Ok(()) } #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> Result<()> { let path = url_to_path(&self.root_dir, target); let tempfilename = path.with_extension("tmp"); let mut temp = OpenOptions::new() - .write(true) - .create_new(true) - .open(&tempfilename) - .await?; + .write(true) + .create_new(true) + .open(&tempfilename) + .await?; let mut file = OpenOptions::new().read(true).open(&path).await?; let mut post: serde_json::Value = { @@ -751,13 +780,20 @@ impl Storage for FileStorage { temp.sync_all().await?; drop(temp); tokio::fs::rename(tempfilename, &path).await?; - tokio::fs::File::open(path.parent().unwrap()).await?.sync_all().await?; + tokio::fs::File::open(path.parent().unwrap()) + .await? + .sync_all() + .await?; Ok(()) } - async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { todo!(); + #[allow(unreachable_code)] Ok(futures::stream::empty()) // for type inference } } diff --git a/src/database/memory.rs b/src/database/memory.rs index 412deef..75f04de 100644 --- a/src/database/memory.rs +++ b/src/database/memory.rs @@ -1,13 +1,14 @@ -#![allow(clippy::todo)] +#![allow(clippy::todo, missing_docs)] use futures_util::FutureExt; use serde_json::json; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError}; +use crate::database::{settings, ErrorKind, MicropubChannel, Result, Storage, StorageError}; #[derive(Clone, Debug, Default)] +/// A simple in-memory store for testing purposes. pub struct MemoryStorage { pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>, pub channels: Arc<RwLock<HashMap<url::Url, Vec<String>>>>, @@ -89,9 +90,16 @@ impl Storage for MemoryStorage { Ok(()) } - async fn update_post(&self, url: &'_ str, update: crate::micropub::MicropubUpdate) -> Result<()> { + async fn update_post( + &self, + url: &'_ str, + update: crate::micropub::MicropubUpdate, + ) -> Result<()> { let mut guard = self.mapping.write().await; - let mut post = guard.get_mut(url).ok_or(StorageError::from_static(ErrorKind::NotFound, "The specified post wasn't found in the database."))?; + let mut post = guard.get_mut(url).ok_or(StorageError::from_static( + ErrorKind::NotFound, + "The specified post wasn't found in the database.", + ))?; use crate::micropub::MicropubPropertyDeletion; @@ -207,7 +215,7 @@ impl Storage for MemoryStorage { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { todo!() } @@ -223,25 +231,39 @@ impl Storage for MemoryStorage { } #[allow(unused_variables)] - async fn set_setting<S: settings::Setting>(&self, user: &url::Url, value: S::Data) -> Result<()> { + async fn set_setting<S: settings::Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> Result<()> { todo!() } #[allow(unused_variables)] - async fn add_or_update_webmention(&self, target: &str, mention_type: kittybox_util::MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: kittybox_util::MentionType, + mention: serde_json::Value, + ) -> Result<()> { todo!() } #[allow(unused_variables)] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { todo!() } - async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + _user: &url::Url, + ) -> Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this> { todo!(); + #[allow(unreachable_code)] Ok(futures::stream::pending()) } - } diff --git a/src/database/mod.rs b/src/database/mod.rs index 4390ae7..fb6f43c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -43,71 +43,7 @@ pub enum ErrorKind { } /// Settings that can be stored in the database. -pub mod settings { - mod private { - pub trait Sealed {} - } - - /// A trait for various settings that should be contained here. - /// - /// **Note**: this trait is sealed to prevent external - /// implementations, as it wouldn't make sense to add new settings - /// that aren't used by Kittybox itself. - pub trait Setting: private::Sealed + std::fmt::Debug + Default + Clone + serde::Serialize + serde::de::DeserializeOwned + /*From<Settings> +*/ Send + Sync + 'static { - /// The data that the setting carries. - type Data: std::fmt::Debug + Send + Sync; - /// The string ID for the setting, usable as an identifier in the database. - const ID: &'static str; - - /// Unwrap the setting type, returning owned data contained within. - fn into_inner(self) -> Self::Data; - /// Create a new instance of this type containing certain data. - fn new(data: Self::Data) -> Self; - } - - /// A website's title, shown in the header. - #[derive(Debug, serde::Deserialize, serde::Serialize, Clone, PartialEq, Eq)] - pub struct SiteName(pub(crate) String); - impl Default for SiteName { - fn default() -> Self { - Self("Kittybox".to_string()) - } - } - impl AsRef<str> for SiteName { - fn as_ref(&self) -> &str { - self.0.as_str() - } - } - impl private::Sealed for SiteName {} - impl Setting for SiteName { - type Data = String; - const ID: &'static str = "site_name"; - - fn into_inner(self) -> String { - self.0 - } - fn new(data: Self::Data) -> Self { - Self(data) - } - } - - /// Participation status in the IndieWeb Webring: https://🕸💍.ws/dashboard - #[derive(Debug, Default, serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq)] - pub struct Webring(bool); - impl private::Sealed for Webring {} - impl Setting for Webring { - type Data = bool; - const ID: &'static str = "webring"; - - fn into_inner(self) -> Self::Data { - self.0 - } - - fn new(data: Self::Data) -> Self { - Self(data) - } - } -} +pub mod settings; /// Error signalled from the database. #[derive(Debug)] @@ -177,7 +113,7 @@ impl StorageError { Self { msg: Cow::Borrowed(msg), source: None, - kind + kind, } } /// Create a StorageError using another arbitrary Error as a source. @@ -219,27 +155,34 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { fn post_exists(&self, url: &str) -> impl Future<Output = Result<bool>> + Send; /// Load a post from the database in MF2-JSON format, deserialized from JSON. - fn get_post(&self, url: &str) -> impl Future<Output = Result<Option<serde_json::Value>>> + Send; + fn get_post(&self, url: &str) + -> impl Future<Output = Result<Option<serde_json::Value>>> + Send; /// Save a post to the database as an MF2-JSON structure. /// /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. - fn put_post(&self, post: &serde_json::Value, user: &url::Url) -> impl Future<Output = Result<()>> + Send; + fn put_post( + &self, + post: &serde_json::Value, + user: &url::Url, + ) -> impl Future<Output = Result<()>> + Send; /// Add post to feed. Some database implementations might have optimized ways to do this. #[tracing::instrument(skip(self))] fn add_to_feed(&self, feed: &str, post: &str) -> impl Future<Output = Result<()>> + Send { tracing::debug!("Inserting {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"add": {"children": [post]}})).unwrap() + self.update_post( + feed, + serde_json::from_value(serde_json::json!({"add": {"children": [post]}})).unwrap(), ) } /// Remove post from feed. Some database implementations might have optimized ways to do this. #[tracing::instrument(skip(self))] fn remove_from_feed(&self, feed: &str, post: &str) -> impl Future<Output = Result<()>> + Send { tracing::debug!("Removing {} into {} using `update_post`", post, feed); - self.update_post(feed, serde_json::from_value( - serde_json::json!({"delete": {"children": [post]}})).unwrap() + self.update_post( + feed, + serde_json::from_value(serde_json::json!({"delete": {"children": [post]}})).unwrap(), ) } @@ -254,7 +197,11 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Default implementation calls [`Storage::update_with`] and uses /// [`update.apply`][MicropubUpdate::apply] to update the post. - fn update_post(&self, url: &str, update: MicropubUpdate) -> impl Future<Output = Result<()>> + Send { + fn update_post( + &self, + url: &str, + update: MicropubUpdate, + ) -> impl Future<Output = Result<()>> + Send { let fut = self.update_with(url, |post| { update.apply(post); }); @@ -274,12 +221,17 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Returns old post and the new post after editing. fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> impl Future<Output = Result<(serde_json::Value, serde_json::Value)>> + Send; /// Get a list of channels available for the user represented by /// the `user` domain to write to. - fn get_channels(&self, user: &url::Url) -> impl Future<Output = Result<Vec<MicropubChannel>>> + Send; + fn get_channels( + &self, + user: &url::Url, + ) -> impl Future<Output = Result<Vec<MicropubChannel>>> + Send; /// Fetch a feed at `url` and return an h-feed object containing /// `limit` posts after a post by url `after`, filtering the content @@ -329,7 +281,7 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> impl Future<Output = Result<Option<(serde_json::Value, Option<String>)>>> + Send; /// Deletes a post from the database irreversibly. Must be idempotent. @@ -339,7 +291,11 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { fn get_setting<S: Setting>(&self, user: &url::Url) -> impl Future<Output = Result<S>> + Send; /// Commits a setting to the setting store. - fn set_setting<S: Setting>(&self, user: &url::Url, value: S::Data) -> impl Future<Output = Result<()>> + Send; + fn set_setting<S: Setting>( + &self, + user: &url::Url, + value: S::Data, + ) -> impl Future<Output = Result<()>> + Send; /// Add (or update) a webmention on a certian post. /// @@ -355,11 +311,19 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// /// Besides, it may even allow for nice tricks like storing the /// webmentions separately and rehydrating them on feed reads. - fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> impl Future<Output = Result<()>> + Send; + fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> impl Future<Output = Result<()>> + Send; /// Return a stream of all posts ever made by a certain user, in /// reverse-chronological order. - fn all_posts<'this>(&'this self, user: &url::Url) -> impl Future<Output = Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this>> + Send; + fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> impl Future<Output = Result<impl futures::Stream<Item = serde_json::Value> + Send + 'this>> + Send; } #[cfg(test)] @@ -464,7 +428,8 @@ mod tests { "replace": { "content": ["Different test content"] } - })).unwrap(), + })) + .unwrap(), ) .await .unwrap(); @@ -511,7 +476,10 @@ mod tests { .put_post(&feed, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - let chans = backend.get_channels(&"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + let chans = backend + .get_channels(&"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); assert_eq!(chans.len(), 1); assert_eq!( chans[0], @@ -526,16 +494,16 @@ mod tests { backend .set_setting::<settings::SiteName>( &"https://fireburn.ru/".parse().unwrap(), - "Vika's Hideout".to_owned() + "Vika's Hideout".to_owned(), ) .await .unwrap(); assert_eq!( backend - .get_setting::<settings::SiteName>(&"https://fireburn.ru/".parse().unwrap()) - .await - .unwrap() - .as_ref(), + .get_setting::<settings::SiteName>(&"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap() + .as_ref(), "Vika's Hideout" ); } @@ -597,11 +565,9 @@ mod tests { async fn test_feed_pagination<Backend: Storage>(backend: Backend) { let posts = { - let mut posts = std::iter::from_fn( - || Some(gen_random_post("fireburn.ru")) - ) - .take(40) - .collect::<Vec<serde_json::Value>>(); + let mut posts = std::iter::from_fn(|| Some(gen_random_post("fireburn.ru"))) + .take(40) + .collect::<Vec<serde_json::Value>>(); // Reverse the array so it's in reverse-chronological order posts.reverse(); @@ -629,7 +595,10 @@ mod tests { .put_post(post, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - backend.add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()).await.unwrap(); + backend + .add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()) + .await + .unwrap(); } let limit: usize = 10; @@ -648,23 +617,16 @@ mod tests { .unwrap() .iter() .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10], + .collect::<Vec<_>>()[0..10], posts .iter() .map(|post| post["properties"]["uid"][0].as_str().unwrap()) - .collect::<Vec<_>>() - [0..10] + .collect::<Vec<_>>()[0..10] ); tracing::debug!("Continuing with cursor: {:?}", cursor); let (result2, cursor2) = backend - .read_feed_with_cursor( - key, - cursor.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -676,12 +638,7 @@ mod tests { tracing::debug!("Continuing with cursor: {:?}", cursor); let (result3, cursor3) = backend - .read_feed_with_cursor( - key, - cursor2.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor2.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -693,12 +650,7 @@ mod tests { tracing::debug!("Continuing with cursor: {:?}", cursor); let (result4, _) = backend - .read_feed_with_cursor( - key, - cursor3.as_deref(), - limit, - None, - ) + .read_feed_with_cursor(key, cursor3.as_deref(), limit, None) .await .unwrap() .unwrap(); @@ -725,24 +677,43 @@ mod tests { async fn test_webmention_addition<Backend: Storage>(db: Backend) { let post = gen_random_post("fireburn.ru"); - db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); const TYPE: MentionType = MentionType::Reply; let target = post["properties"]["uid"][0].as_str().unwrap(); let mut reply = gen_random_mention("aaronparecki.com", TYPE, target); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(post, read_post); - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); + db.add_or_update_webmention(target, TYPE, reply.clone()) + .await + .unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post["properties"]["comment"][0], reply); - reply["properties"]["content"][0] = json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); + reply["properties"]["content"][0] = + json!(rand::random::<faker_rand::lorem::Paragraphs>().to_string()); - db.add_or_update_webmention(target, TYPE, reply.clone()).await.unwrap(); - let (read_post, _) = db.read_feed_with_cursor(target, None, 20, None).await.unwrap().unwrap(); + db.add_or_update_webmention(target, TYPE, reply.clone()) + .await + .unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(target, None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post["properties"]["comment"][0], reply); } @@ -752,16 +723,20 @@ mod tests { let post = { let mut post = gen_random_post("fireburn.ru"); let urls = post["properties"]["url"].as_array_mut().unwrap(); - urls.push(serde_json::Value::String( - PERMALINK.to_owned() - )); + urls.push(serde_json::Value::String(PERMALINK.to_owned())); post }; - db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()) + .await + .unwrap(); for i in post["properties"]["url"].as_array().unwrap() { - let (read_post, _) = db.read_feed_with_cursor(i.as_str().unwrap(), None, 20, None).await.unwrap().unwrap(); + let (read_post, _) = db + .read_feed_with_cursor(i.as_str().unwrap(), None, 20, None) + .await + .unwrap() + .unwrap(); assert_eq!(read_post, post); } } @@ -786,7 +761,7 @@ mod tests { async fn $func_name() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); let backend = super::super::FileStorage { - root_dir: tempdir.path().to_path_buf() + root_dir: tempdir.path().to_path_buf(), }; super::$func_name(backend).await } @@ -800,7 +775,7 @@ mod tests { #[tracing_test::traced_test] async fn $func_name( pool_opts: sqlx::postgres::PgPoolOptions, - connect_opts: sqlx::postgres::PgConnectOptions + connect_opts: sqlx::postgres::PgConnectOptions, ) -> Result<(), sqlx::Error> { let db = { //use sqlx::ConnectOptions; diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs index af19fea..ec67efa 100644 --- a/src/database/postgres/mod.rs +++ b/src/database/postgres/mod.rs @@ -5,7 +5,7 @@ use kittybox_util::{micropub::Channel as MicropubChannel, MentionType}; use sqlx::{ConnectOptions, Executor, PgPool}; use super::settings::Setting; -use super::{Storage, Result, StorageError, ErrorKind}; +use super::{ErrorKind, Result, Storage, StorageError}; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); @@ -14,7 +14,7 @@ impl From<sqlx::Error> for StorageError { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx error: {}", &value)), - Box::new(value) + Box::new(value), ) } } @@ -24,7 +24,7 @@ impl From<sqlx::migrate::MigrateError> for StorageError { Self::with_source( super::ErrorKind::Backend, Cow::Owned(format!("sqlx migration error: {}", &value)), - Box::new(value) + Box::new(value), ) } } @@ -32,14 +32,15 @@ impl From<sqlx::migrate::MigrateError> for StorageError { /// Micropub storage that uses a PostgreSQL database. #[derive(Debug, Clone)] pub struct PostgresStorage { - db: PgPool + db: PgPool, } impl PostgresStorage { /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], /// running appropriate migrations. pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")) + .await?; MIGRATOR.run(&db).await?; Ok(Self { db }) } @@ -50,19 +51,22 @@ impl Storage for PostgresStorage { /// migrations on the database. async fn new(url: &'_ url::Url) -> Result<Self> { tracing::debug!("Postgres URL: {url}"); - let options = sqlx::postgres::PgConnectOptions::from_url(url)? - .options([("search_path", "kittybox")]); + let options = + sqlx::postgres::PgConnectOptions::from_url(url)?.options([("search_path", "kittybox")]); Self::from_pool( sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) - .await? - ).await - + .await?, + ) + .await } - async fn all_posts<'this>(&'this self, user: &url::Url) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> { + async fn all_posts<'this>( + &'this self, + user: &url::Url, + ) -> Result<impl Stream<Item = serde_json::Value> + Send + 'this> { let authority = user.authority().to_owned(); Ok( sqlx::query_scalar::<_, serde_json::Value>("SELECT mf2 FROM kittybox.mf2_json WHERE owner = $1 ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC") @@ -74,18 +78,20 @@ impl Storage for PostgresStorage { #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result<Vec<String>> { - sqlx::query_scalar::<_, String>(" + sqlx::query_scalar::<_, String>( + " SELECT jsonb_array_elements(mf2['properties']['category']) AS category FROM kittybox.mf2_json WHERE jsonb_typeof(mf2['properties']['category']) = 'array' AND uid LIKE ($1 + '%') GROUP BY category ORDER BY count(*) DESC -") - .bind(url) - .fetch_all(&self.db) - .await - .map_err(|err| err.into()) +", + ) + .bind(url) + .fetch_all(&self.db) + .await + .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] async fn post_exists(&self, url: &str) -> Result<bool> { @@ -98,13 +104,14 @@ WHERE #[tracing::instrument(skip(self))] async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> { - sqlx::query_as::<_, (serde_json::Value,)>("SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1") - .bind(url) - .fetch_optional(&self.db) - .await - .map(|v| v.map(|v| v.0)) - .map_err(|err| err.into()) - + sqlx::query_as::<_, (serde_json::Value,)>( + "SELECT mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1", + ) + .bind(url) + .fetch_optional(&self.db) + .await + .map(|v| v.map(|v| v.0)) + .map_err(|err| err.into()) } #[tracing::instrument(skip(self))] @@ -122,13 +129,15 @@ WHERE #[tracing::instrument(skip(self))] async fn add_to_feed(&self, feed: &'_ str, post: &'_ str) -> Result<()> { tracing::debug!("Inserting {} into {}", post, feed); - sqlx::query("INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING") - .bind(feed) - .bind(post) - .execute(&self.db) - .await - .map(|_| ()) - .map_err(Into::into) + sqlx::query( + "INSERT INTO kittybox.children (parent, child) VALUES ($1, $2) ON CONFLICT DO NOTHING", + ) + .bind(feed) + .bind(post) + .execute(&self.db) + .await + .map(|_| ()) + .map_err(Into::into) } #[tracing::instrument(skip(self))] @@ -143,7 +152,12 @@ WHERE } #[tracing::instrument(skip(self))] - async fn add_or_update_webmention(&self, target: &str, mention_type: MentionType, mention: serde_json::Value) -> Result<()> { + async fn add_or_update_webmention( + &self, + target: &str, + mention_type: MentionType, + mention: serde_json::Value, + ) -> Result<()> { let mut txn = self.db.begin().await?; let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE") @@ -190,7 +204,9 @@ WHERE #[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))] async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>( - &self, url: &str, f: F + &self, + url: &str, + f: F, ) -> Result<(serde_json::Value, serde_json::Value)> { tracing::debug!("Updating post {}", url); let mut txn = self.db.begin().await?; @@ -250,12 +266,12 @@ WHERE url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&url::Url> + user: Option<&url::Url>, ) -> Result<Option<(serde_json::Value, Option<String>)>> { let mut txn = self.db.begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") - .execute(&mut *txn) - .await?; + .execute(&mut *txn) + .await?; tracing::debug!("Started txn: {:?}", txn); let mut feed = match sqlx::query_scalar::<_, serde_json::Value>(" SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 @@ -273,11 +289,17 @@ SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json WHERE uid = $1 OR mf2 // The second query is very long and will probably be extremely // expensive. It's best to skip it on types where it doesn't make sense // (Kittybox doesn't support rendering children on non-feeds) - if !feed["type"].as_array().unwrap().iter().any(|t| *t == serde_json::json!("h-feed")) { + if !feed["type"] + .as_array() + .unwrap() + .iter() + .any(|t| *t == serde_json::json!("h-feed")) + { return Ok(Some((feed, None))); } - feed["children"] = sqlx::query_scalar::<_, serde_json::Value>(" + feed["children"] = sqlx::query_scalar::<_, serde_json::Value>( + " SELECT kittybox.hydrate_author(mf2) FROM kittybox.mf2_json INNER JOIN kittybox.children ON mf2_json.uid = children.child @@ -302,17 +324,19 @@ WHERE ) AND ($4 IS NULL OR ((mf2_json.mf2 #>> '{properties,published,0}') < $4)) ORDER BY (mf2_json.mf2 #>> '{properties,published,0}') DESC -LIMIT $2" +LIMIT $2", ) - .bind(url) - .bind(limit as i64) - .bind(user.map(url::Url::as_str)) - .bind(cursor) - .fetch_all(&mut *txn) - .await - .map(serde_json::Value::Array)?; - - let new_cursor = feed["children"].as_array().unwrap() + .bind(url) + .bind(limit as i64) + .bind(user.map(url::Url::as_str)) + .bind(cursor) + .fetch_all(&mut *txn) + .await + .map(serde_json::Value::Array)?; + + let new_cursor = feed["children"] + .as_array() + .unwrap() .last() .map(|v| v["properties"]["published"][0].as_str().unwrap().to_owned()); @@ -335,7 +359,7 @@ LIMIT $2" .await { Ok((value,)) => Ok(serde_json::from_value(value)?), - Err(err) => Err(err.into()) + Err(err) => Err(err.into()), } } diff --git a/src/database/settings.rs b/src/database/settings.rs new file mode 100644 index 0000000..792a155 --- /dev/null +++ b/src/database/settings.rs @@ -0,0 +1,63 @@ +mod private { + pub trait Sealed {} +} + +/// A trait for various settings that should be contained here. +/// +/// **Note**: this trait is sealed to prevent external +/// implementations, as it wouldn't make sense to add new settings +/// that aren't used by Kittybox itself. +pub trait Setting: private::Sealed + std::fmt::Debug + Default + Clone + serde::Serialize + serde::de::DeserializeOwned + /*From<Settings> +*/ Send + Sync + 'static { + /// The data that the setting carries. + type Data: std::fmt::Debug + Send + Sync; + /// The string ID for the setting, usable as an identifier in the database. + const ID: &'static str; + + /// Unwrap the setting type, returning owned data contained within. + fn into_inner(self) -> Self::Data; + /// Create a new instance of this type containing certain data. + fn new(data: Self::Data) -> Self; +} + +/// A website's title, shown in the header. +#[derive(Debug, serde::Deserialize, serde::Serialize, Clone, PartialEq, Eq)] +pub struct SiteName(pub(crate) String); +impl Default for SiteName { + fn default() -> Self { + Self("Kittybox".to_string()) + } +} +impl AsRef<str> for SiteName { + fn as_ref(&self) -> &str { + self.0.as_str() + } +} +impl private::Sealed for SiteName {} +impl Setting for SiteName { + type Data = String; + const ID: &'static str = "site_name"; + + fn into_inner(self) -> String { + self.0 + } + fn new(data: Self::Data) -> Self { + Self(data) + } +} + +/// Participation status in the IndieWeb Webring: https://🕸💍.ws/dashboard +#[derive(Debug, Default, serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq)] +pub struct Webring(bool); +impl private::Sealed for Webring {} +impl Setting for Webring { + type Data = bool; + const ID: &'static str = "webring"; + + fn into_inner(self) -> Self::Data { + self.0 + } + + fn new(data: Self::Data) -> Self { + Self(data) + } +} diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 9ba1a69..94b8aa7 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -12,12 +12,10 @@ use tracing::{debug, error}; //pub mod login; pub mod onboarding; +pub use kittybox_frontend_renderer::assets::statics; use kittybox_frontend_renderer::{ - Entry, Feed, VCard, - ErrorPage, Template, MainPage, - POSTS_PER_PAGE + Entry, ErrorPage, Feed, MainPage, Template, VCard, POSTS_PER_PAGE, }; -pub use kittybox_frontend_renderer::assets::statics; #[derive(Debug, Deserialize)] pub struct QueryParams { @@ -106,7 +104,7 @@ pub fn filter_post( .map(|i| -> &str { match i { serde_json::Value::String(ref author) => author.as_str(), - mf2 => mf2["properties"]["uid"][0].as_str().unwrap() + mf2 => mf2["properties"]["uid"][0].as_str().unwrap(), } }) .map(|i| i.parse().unwrap()) @@ -116,11 +114,13 @@ pub fn filter_post( .unwrap_or("public"); let audience = { let mut audience = author_list.clone(); - audience.extend(post["properties"]["audience"] - .as_array() - .unwrap_or(&empty_vec) - .iter() - .map(|i| i.as_str().unwrap().parse().unwrap())); + audience.extend( + post["properties"]["audience"] + .as_array() + .unwrap_or(&empty_vec) + .iter() + .map(|i| i.as_str().unwrap().parse().unwrap()), + ); audience }; @@ -134,7 +134,10 @@ pub fn filter_post( let location_visibility = post["properties"]["location-visibility"][0] .as_str() .unwrap_or("private"); - tracing::debug!("Post contains location, location privacy = {}", location_visibility); + tracing::debug!( + "Post contains location, location privacy = {}", + location_visibility + ); let mut author = post["properties"]["author"] .as_array() .unwrap_or(&empty_vec) @@ -155,16 +158,18 @@ pub fn filter_post( post["properties"]["author"] = serde_json::Value::Array( children .into_iter() - .filter_map(|post| if post.is_string() { - Some(post) - } else { - filter_post(post, user) + .filter_map(|post| { + if post.is_string() { + Some(post) + } else { + filter_post(post, user) + } }) - .collect::<Vec<serde_json::Value>>() + .collect::<Vec<serde_json::Value>>(), ); - }, - serde_json::Value::Null => {}, - other => post["properties"]["author"] = other + } + serde_json::Value::Null => {} + other => post["properties"]["author"] = other, } match post["children"].take() { @@ -173,11 +178,11 @@ pub fn filter_post( children .into_iter() .filter_map(|post| filter_post(post, user)) - .collect::<Vec<serde_json::Value>>() + .collect::<Vec<serde_json::Value>>(), ); - }, - serde_json::Value::Null => {}, - other => post["children"] = other + } + serde_json::Value::Null => {} + other => post["children"] = other, } Some(post) } @@ -209,7 +214,7 @@ async fn get_post_from_database<S: Storage>( )) } } - } + }, None => Err(FrontendError::with_code( StatusCode::NOT_FOUND, "Post not found in the database", @@ -240,7 +245,7 @@ pub async fn homepage<D: Storage>( Host(host): Host, Query(query): Query<QueryParams>, State(db): State<D>, - session: Option<crate::Session> + session: Option<crate::Session>, ) -> impl IntoResponse { // This is stupid, but there is no other way. let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); @@ -252,7 +257,7 @@ pub async fn homepage<D: Storage>( ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); let user = session.as_deref().map(|s| &s.me); @@ -268,18 +273,16 @@ pub async fn homepage<D: Storage>( // btw is it more efficient to fetch these in parallel? let (blogname, webring, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_setting::<crate::database::settings::Webring>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); if user.is_some() { headers.insert( axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("private") + axum::http::HeaderValue::from_static("private"), ); } // Render the homepage @@ -295,12 +298,13 @@ pub async fn homepage<D: Storage>( feed: &hfeed, card: &hcard, cursor: cursor.as_deref(), - webring: crate::database::settings::Setting::into_inner(webring) + webring: crate::database::settings::Setting::into_inner(webring), } .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } Err(err) => { if err.code == StatusCode::NOT_FOUND { @@ -310,19 +314,20 @@ pub async fn homepage<D: Storage>( StatusCode::FOUND, [(axum::http::header::LOCATION, "/.kittybox/onboarding")], String::default(), - ).into_response() + ) + .into_response() } else { error!("Error while fetching h-card and/or h-feed: {}", err); // Return the error let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( - err.code(), headers, + err.code(), + headers, Template { title: blogname.as_ref(), blog_name: blogname.as_ref(), @@ -335,7 +340,8 @@ pub async fn homepage<D: Storage>( .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } } } @@ -351,17 +357,13 @@ pub async fn catchall<D: Storage>( ) -> impl IntoResponse { let user: Option<&url::Url> = session.as_deref().map(|p| &p.me); let host = url::Url::parse(&format!("https://{}/", host)).unwrap(); - let path = host - .clone() - .join(uri.path()) - .unwrap(); + let path = host.clone().join(uri.path()).unwrap(); match get_post_from_database(&db, path.as_str(), query.after, user).await { Ok((post, cursor)) => { let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&host).map(|i| i.unwrap_or_default()) ); let mut headers = axum::http::HeaderMap::new(); @@ -371,12 +373,12 @@ pub async fn catchall<D: Storage>( ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); if user.is_some() { headers.insert( axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("private") + axum::http::HeaderValue::from_static("private"), ); } @@ -384,19 +386,20 @@ pub async fn catchall<D: Storage>( let last_modified = post["properties"]["updated"] .as_array() .and_then(|v| v.last()) - .or_else(|| post["properties"]["published"] - .as_array() - .and_then(|v| v.last()) - ) + .or_else(|| { + post["properties"]["published"] + .as_array() + .and_then(|v| v.last()) + }) .and_then(serde_json::Value::as_str) - .and_then(|dt| chrono::DateTime::<chrono::FixedOffset>::parse_from_rfc3339(dt).ok()); + .and_then(|dt| { + chrono::DateTime::<chrono::FixedOffset>::parse_from_rfc3339(dt).ok() + }); if let Some(last_modified) = last_modified { - headers.typed_insert( - axum_extra::headers::LastModified::from( - std::time::SystemTime::from(last_modified) - ) - ); + headers.typed_insert(axum_extra::headers::LastModified::from( + std::time::SystemTime::from(last_modified), + )); } } @@ -410,8 +413,16 @@ pub async fn catchall<D: Storage>( feeds: channels, user: session.as_deref(), content: match post.pointer("/type/0").and_then(|i| i.as_str()) { - Some("h-entry") => Entry { post: &post, from_feed: false, }.to_string(), - Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(), + Some("h-entry") => Entry { + post: &post, + from_feed: false, + } + .to_string(), + Some("h-feed") => Feed { + feed: &post, + cursor: cursor.as_deref(), + } + .to_string(), Some("h-card") => VCard { card: &post }.to_string(), unknown => { unimplemented!("Template for MF2-JSON type {:?}", unknown) @@ -419,13 +430,13 @@ pub async fn catchall<D: Storage>( }, } .to_string(), - ).into_response() + ) + .into_response() } Err(err) => { let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&host) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&host).map(|i| i.unwrap_or_default()) ); ( @@ -446,7 +457,8 @@ pub async fn catchall<D: Storage>( .to_string(), } .to_string(), - ).into_response() + ) + .into_response() } } } diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs index 4588157..3b53911 100644 --- a/src/frontend/onboarding.rs +++ b/src/frontend/onboarding.rs @@ -10,7 +10,7 @@ use axum::{ use axum_extra::extract::Host; use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template}; use serde::Deserialize; -use tokio::{task::JoinSet, sync::Mutex}; +use tokio::{sync::Mutex, task::JoinSet}; use tracing::{debug, error}; use super::FrontendError; @@ -64,7 +64,8 @@ async fn onboard<D: Storage + 'static>( me: user_uid.clone(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), - iat: None, exp: None + iat: None, + exp: None, }; tracing::debug!("User data: {:?}", user); @@ -84,7 +85,7 @@ async fn onboard<D: Storage + 'static>( .await .map_err(FrontendError::from)?; - let (_, hcard) = { + let crate::micropub::util::NormalizedPost { id: _, post: hcard } = { let mut hcard = data.user; hcard["properties"]["uid"] = serde_json::json!([&user_uid]); crate::micropub::normalize_mf2(hcard, &user) @@ -99,19 +100,21 @@ async fn onboard<D: Storage + 'static>( continue; }; debug!("Creating feed {} with slug {}", &feed.name, &feed.slug); - let (_, feed) = crate::micropub::normalize_mf2( - serde_json::json!({ - "type": ["h-feed"], - "properties": {"name": [feed.name], "mp-slug": [feed.slug]} - }), - &user, - ); + let crate::micropub::util::NormalizedPost { id: _, post: feed } = + crate::micropub::normalize_mf2( + serde_json::json!({ + "type": ["h-feed"], + "properties": {"name": [feed.name], "mp-slug": [feed.slug]} + }), + &user, + ); db.put_post(&feed, &user.me) .await .map_err(FrontendError::from)?; } - let (uid, post) = crate::micropub::normalize_mf2(data.first_post, &user); + let crate::micropub::util::NormalizedPost { id: uid, post } = + crate::micropub::normalize_mf2(data.first_post, &user); tracing::debug!("Posting first post {}...", uid); crate::micropub::_post(&user, uid, post, db, http, jobset) .await @@ -169,6 +172,5 @@ where reqwest_middleware::ClientWithMiddleware: FromRef<St>, St: Clone + Send + Sync + 'static, { - axum::routing::get(get) - .post(post::<S>) + axum::routing::get(get).post(post::<S>) } diff --git a/src/indieauth/backend.rs b/src/indieauth/backend.rs index b913256..9215adf 100644 --- a/src/indieauth/backend.rs +++ b/src/indieauth/backend.rs @@ -1,9 +1,7 @@ -use std::future::Future; -use std::collections::HashMap; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData -}; +use kittybox_indieauth::{AuthorizationRequest, TokenData}; pub use kittybox_util::auth::EnrolledCredential; +use std::collections::HashMap; +use std::future::Future; type Result<T> = std::io::Result<T>; @@ -20,33 +18,72 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { /// Note for implementors: the [`AuthorizationRequest::me`] value /// is guaranteed to be [`Some(url::Url)`][Option::Some] and can /// be trusted to be correct and non-malicious. - fn create_code(&self, data: AuthorizationRequest) -> impl Future<Output = Result<String>> + Send; + fn create_code( + &self, + data: AuthorizationRequest, + ) -> impl Future<Output = Result<String>> + Send; /// Retreive an authorization request using the one-time /// code. Implementations must sanitize the `code` field to /// prevent exploits, and must check if the code should still be /// valid at this point in time (validity interval is left up to /// the implementation, but is recommended to be no more than 10 /// minutes). - fn get_code(&self, code: &str) -> impl Future<Output = Result<Option<AuthorizationRequest>>> + Send; + fn get_code( + &self, + code: &str, + ) -> impl Future<Output = Result<Option<AuthorizationRequest>>> + Send; // Token management. fn create_token(&self, data: TokenData) -> impl Future<Output = Result<String>> + Send; - fn get_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<Option<TokenData>>> + Send; - fn list_tokens(&self, website: &url::Url) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; - fn revoke_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<()>> + Send; + fn get_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<Option<TokenData>>> + Send; + fn list_tokens( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; + fn revoke_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<()>> + Send; // Refresh token management. fn create_refresh_token(&self, data: TokenData) -> impl Future<Output = Result<String>> + Send; - fn get_refresh_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<Option<TokenData>>> + Send; - fn list_refresh_tokens(&self, website: &url::Url) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; - fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> impl Future<Output = Result<()>> + Send; + fn get_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<Option<TokenData>>> + Send; + fn list_refresh_tokens( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<HashMap<String, TokenData>>> + Send; + fn revoke_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> impl Future<Output = Result<()>> + Send; // Password management. /// Verify a password. #[must_use] - fn verify_password(&self, website: &url::Url, password: String) -> impl Future<Output = Result<bool>> + Send; + fn verify_password( + &self, + website: &url::Url, + password: String, + ) -> impl Future<Output = Result<bool>> + Send; /// Enroll a password credential for a user. Only one password /// credential must exist for a given user. - fn enroll_password(&self, website: &url::Url, password: String) -> impl Future<Output = Result<()>> + Send; + fn enroll_password( + &self, + website: &url::Url, + password: String, + ) -> impl Future<Output = Result<()>> + Send; /// List currently enrolled credential types for a given user. - fn list_user_credential_types(&self, website: &url::Url) -> impl Future<Output = Result<Vec<EnrolledCredential>>> + Send; + fn list_user_credential_types( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<Vec<EnrolledCredential>>> + Send; // WebAuthn credential management. #[cfg(feature = "webauthn")] /// Enroll a WebAuthn authenticator public key for this user. @@ -56,10 +93,17 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { /// This function can also be used to overwrite a passkey with an /// updated version after using /// [webauthn::prelude::Passkey::update_credential()]. - fn enroll_webauthn(&self, website: &url::Url, credential: webauthn::prelude::Passkey) -> impl Future<Output = Result<()>> + Send; + fn enroll_webauthn( + &self, + website: &url::Url, + credential: webauthn::prelude::Passkey, + ) -> impl Future<Output = Result<()>> + Send; #[cfg(feature = "webauthn")] /// List currently enrolled WebAuthn authenticators for a given user. - fn list_webauthn_pubkeys(&self, website: &url::Url) -> impl Future<Output = Result<Vec<webauthn::prelude::Passkey>>> + Send; + fn list_webauthn_pubkeys( + &self, + website: &url::Url, + ) -> impl Future<Output = Result<Vec<webauthn::prelude::Passkey>>> + Send; #[cfg(feature = "webauthn")] /// Persist registration challenge state for a little while so it /// can be used later. @@ -69,7 +113,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn persist_registration_challenge( &self, website: &url::Url, - state: webauthn::prelude::PasskeyRegistration + state: webauthn::prelude::PasskeyRegistration, ) -> impl Future<Output = Result<String>> + Send; #[cfg(feature = "webauthn")] /// Retrieve a persisted registration challenge. @@ -78,7 +122,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn retrieve_registration_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> impl Future<Output = Result<webauthn::prelude::PasskeyRegistration>> + Send; #[cfg(feature = "webauthn")] /// Persist authentication challenge state for a little while so @@ -92,7 +136,7 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn persist_authentication_challenge( &self, website: &url::Url, - state: webauthn::prelude::PasskeyAuthentication + state: webauthn::prelude::PasskeyAuthentication, ) -> impl Future<Output = Result<String>> + Send; #[cfg(feature = "webauthn")] /// Retrieve a persisted authentication challenge. @@ -101,7 +145,6 @@ pub trait AuthBackend: Clone + Send + Sync + 'static { fn retrieve_authentication_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> impl Future<Output = Result<webauthn::prelude::PasskeyAuthentication>> + Send; - } diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs index f74fbbc..26466fe 100644 --- a/src/indieauth/backend/fs.rs +++ b/src/indieauth/backend/fs.rs @@ -1,13 +1,16 @@ -use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}}; - -use super::{AuthBackend, Result, EnrolledCredential}; -use kittybox_indieauth::{ - AuthorizationRequest, TokenData +use std::{ + borrow::Cow, + collections::HashMap, + path::PathBuf, + time::{Duration, SystemTime}, }; + +use super::{AuthBackend, EnrolledCredential, Result}; +use kittybox_indieauth::{AuthorizationRequest, TokenData}; use serde::de::DeserializeOwned; -use tokio::{task::spawn_blocking, io::AsyncReadExt}; +use tokio::{io::AsyncReadExt, task::spawn_blocking}; #[cfg(feature = "webauthn")] -use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication}; +use webauthn::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; const CODE_LENGTH: usize = 16; const TOKEN_LENGTH: usize = 128; @@ -29,7 +32,8 @@ impl FileBackend { } else { let mut s = String::with_capacity(filename.len()); - filename.chars() + filename + .chars() .filter(|c| c.is_alphanumeric()) .for_each(|c| s.push(c)); @@ -38,41 +42,41 @@ impl FileBackend { } #[inline] - async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>( + async fn serialize_to_file< + T: 'static + serde::ser::Serialize + Send, + B: Into<Option<&'static str>>, + >( &self, dir: &str, basename: B, length: usize, - data: T + data: T, ) -> Result<String> { let basename = basename.into(); let has_ext = basename.is_some(); - let (filename, mut file) = kittybox_util::fs::mktemp( - self.path.join(dir), basename, length - ) + let (filename, mut file) = kittybox_util::fs::mktemp(self.path.join(dir), basename, length) .await .map(|(name, file)| (name, file.try_into_std().unwrap()))?; spawn_blocking(move || serde_json::to_writer(&mut file, &data)) .await - .unwrap_or_else(|e| panic!( - "Panic while serializing {}: {}", - std::any::type_name::<T>(), - e - )) + .unwrap_or_else(|e| { + panic!( + "Panic while serializing {}: {}", + std::any::type_name::<T>(), + e + ) + }) .map(move |_| { (if has_ext { - filename - .extension() - + filename.extension() } else { - filename - .file_name() + filename.file_name() }) - .unwrap() - .to_str() - .unwrap() - .to_owned() + .unwrap() + .to_str() + .unwrap() + .to_owned() }) .map_err(|err| err.into()) } @@ -86,17 +90,15 @@ impl FileBackend { ) -> Result<Option<(PathBuf, SystemTime, T)>> where T: serde::de::DeserializeOwned + Send, - B: Into<Option<&'static str>> + B: Into<Option<&'static str>>, { let basename = basename.into(); - let path = self.path - .join(dir) - .join(format!( - "{}{}{}", - basename.unwrap_or(""), - if basename.is_none() { "" } else { "." }, - FileBackend::sanitize_for_path(filename) - )); + let path = self.path.join(dir).join(format!( + "{}{}{}", + basename.unwrap_or(""), + if basename.is_none() { "" } else { "." }, + FileBackend::sanitize_for_path(filename) + )); let data = match tokio::fs::File::open(&path).await { Ok(mut file) => { @@ -106,13 +108,15 @@ impl FileBackend { match serde_json::from_slice::<'_, T>(buf.as_slice()) { Ok(data) => data, - Err(err) => return Err(err.into()) + Err(err) => return Err(err.into()), + } + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + return Ok(None); + } else { + return Err(err); } - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - return Ok(None) - } else { - return Err(err) } }; @@ -125,7 +129,8 @@ impl FileBackend { #[tracing::instrument] fn url_to_dir(url: &url::Url) -> String { let host = url.host_str().unwrap(); - let port = url.port() + let port = url + .port() .map(|port| Cow::Owned(format!(":{}", port))) .unwrap_or(Cow::Borrowed("")); @@ -135,23 +140,26 @@ impl FileBackend { async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>( &'this self, dir: &'dir str, - prefix: &'static str + prefix: &'static str, ) -> Result<HashMap<String, T>> { let dir = self.path.join(dir); let mut hashmap = HashMap::new(); let mut readdir = match tokio::fs::read_dir(dir).await { Ok(readdir) => readdir, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - // empty hashmap - return Ok(hashmap); - } else { - return Err(err); + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + // empty hashmap + return Ok(hashmap); + } else { + return Err(err); + } } }; while let Some(entry) = readdir.next_entry().await? { // safe to unwrap; filenames are alphanumeric - let filename = entry.file_name() + let filename = entry + .file_name() .into_string() .expect("token filenames should be alphanumeric!"); if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) { @@ -166,16 +174,19 @@ impl FileBackend { Err(err) => { tracing::error!( "Error decoding token data from file {}: {}", - entry.path().display(), err + entry.path().display(), + err ); continue; } }; - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - continue - } else { - return Err(err) + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + continue; + } else { + return Err(err); + } } } } @@ -194,19 +205,27 @@ impl AuthBackend for FileBackend { path = base.join(&format!(".{}", path.path())).unwrap(); } - tracing::debug!("Initializing File auth backend: {} -> {}", orig_path, path.path()); + tracing::debug!( + "Initializing File auth backend: {} -> {}", + orig_path, + path.path() + ); Ok(Self { - path: std::path::PathBuf::from(path.path()) + path: std::path::PathBuf::from(path.path()), }) } // Authorization code management. async fn create_code(&self, data: AuthorizationRequest) -> Result<String> { - self.serialize_to_file("codes", None, CODE_LENGTH, data).await + self.serialize_to_file("codes", None, CODE_LENGTH, data) + .await } async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> { - match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? { + match self + .deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()) + .await? + { Some((path, ctime, data)) => { if let Err(err) = tokio::fs::remove_file(path).await { tracing::error!("Failed to clean up authorization code: {}", err); @@ -217,23 +236,28 @@ impl AuthBackend for FileBackend { } else { Ok(Some(data)) } - }, - None => Ok(None) + } + None => Ok(None), } } // Token management. async fn create_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await + self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data) + .await } async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "access", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { + match self + .deserialize_from_file::<TokenData, _>( + &dir, + "access", + FileBackend::sanitize_for_path(token).as_ref(), + ) + .await? + { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { @@ -243,8 +267,8 @@ impl AuthBackend for FileBackend { } else { Ok(Some(token)) } - }, - None => Ok(None) + } + None => Ok(None), } } @@ -258,25 +282,36 @@ impl AuthBackend for FileBackend { self.path .join(FileBackend::url_to_dir(website)) .join("tokens") - .join(format!("access.{}", FileBackend::sanitize_for_path(token))) - ).await { + .join(format!("access.{}", FileBackend::sanitize_for_path(token))), + ) + .await + { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result + result => result, } } // Refresh token management. async fn create_refresh_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); - self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await + self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data) + .await } - async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { + async fn get_refresh_token( + &self, + website: &url::Url, + token: &str, + ) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); - match self.deserialize_from_file::<TokenData, _>( - &dir, "refresh", - FileBackend::sanitize_for_path(token).as_ref() - ).await? { + match self + .deserialize_from_file::<TokenData, _>( + &dir, + "refresh", + FileBackend::sanitize_for_path(token).as_ref(), + ) + .await? + { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { @@ -286,8 +321,8 @@ impl AuthBackend for FileBackend { } else { Ok(Some(token)) } - }, - None => Ok(None) + } + None => Ok(None), } } @@ -301,57 +336,80 @@ impl AuthBackend for FileBackend { self.path .join(FileBackend::url_to_dir(website)) .join("tokens") - .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))) - ).await { + .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))), + ) + .await + { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), - result => result + result => result, } } // Password management. #[tracing::instrument(skip(password))] async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> { - use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}}; + use argon2::{ + password_hash::{PasswordHash, PasswordVerifier}, + Argon2, + }; - let password_filename = self.path + let password_filename = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); - tracing::debug!("Reading password for {} from {}", website, password_filename.display()); + tracing::debug!( + "Reading password for {} from {}", + website, + password_filename.display() + ); match tokio::fs::read_to_string(password_filename).await { Ok(password_hash) => { let parsed_hash = { let hash = password_hash.trim(); - #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash); - PasswordHash::new(hash) - .expect("Password hash should be valid!") + #[cfg(debug_assertions)] + tracing::debug!("Password hash: {}", hash); + PasswordHash::new(hash).expect("Password hash should be valid!") }; - Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) - }, - Err(err) => if err.kind() == std::io::ErrorKind::NotFound { - Ok(false) - } else { - Err(err) + Ok(Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .is_ok()) + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + Ok(false) + } else { + Err(err) + } } } } #[tracing::instrument(skip(password))] async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> { - use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}}; + use argon2::{ + password_hash::{rand_core::OsRng, PasswordHasher, SaltString}, + Argon2, + }; - let password_filename = self.path + let password_filename = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); - let password_hash = argon2.hash_password(password.as_bytes(), &salt) + let password_hash = argon2 + .hash_password(password.as_bytes(), &salt) .expect("Hashing a password should not error out") .to_string(); - tracing::debug!("Enrolling password for {} at {}", website, password_filename.display()); + tracing::debug!( + "Enrolling password for {} at {}", + website, + password_filename.display() + ); tokio::fs::write(password_filename, password_hash.as_bytes()).await } @@ -371,7 +429,7 @@ impl AuthBackend for FileBackend { async fn persist_registration_challenge( &self, website: &url::Url, - state: PasskeyRegistration + state: PasskeyRegistration, ) -> Result<String> { todo!() } @@ -380,7 +438,7 @@ impl AuthBackend for FileBackend { async fn retrieve_registration_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> Result<PasskeyRegistration> { todo!() } @@ -389,7 +447,7 @@ impl AuthBackend for FileBackend { async fn persist_authentication_challenge( &self, website: &url::Url, - state: PasskeyAuthentication + state: PasskeyAuthentication, ) -> Result<String> { todo!() } @@ -398,24 +456,28 @@ impl AuthBackend for FileBackend { async fn retrieve_authentication_challenge( &self, website: &url::Url, - challenge_id: &str + challenge_id: &str, ) -> Result<PasskeyAuthentication> { todo!() } #[tracing::instrument(skip(self))] - async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> { + async fn list_user_credential_types( + &self, + website: &url::Url, + ) -> Result<Vec<EnrolledCredential>> { let mut creds = vec![]; - let password_file = self.path + let password_file = self + .path .join(FileBackend::url_to_dir(website)) .join("password"); tracing::debug!("Password file for {}: {}", website, password_file.display()); - match tokio::fs::metadata(password_file) - .await - { + match tokio::fs::metadata(password_file).await { Ok(_) => creds.push(EnrolledCredential::Password), - Err(err) => if err.kind() != std::io::ErrorKind::NotFound { - return Err(err) + Err(err) => { + if err.kind() != std::io::ErrorKind::NotFound { + return Err(err); + } } } diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 00ae393..5cdbf05 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,18 +1,29 @@ -use std::marker::PhantomData; -use microformats::types::Class; -use tracing::error; -use serde::Deserialize; +use crate::database::Storage; use axum::{ - extract::{Form, FromRef, Json, Query, State}, http::StatusCode, response::{Html, IntoResponse, Response} + extract::{Form, FromRef, Json, Query, State}, + http::StatusCode, + response::{Html, IntoResponse, Response}, }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] -use axum_extra::extract::{Host, cookie::{CookieJar, Cookie}}; -use axum_extra::{headers::{authorization::Bearer, Authorization, ContentType, HeaderMapExt}, TypedHeader}; -use crate::database::Storage; +use axum_extra::extract::{ + cookie::{Cookie, CookieJar}, + Host, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization, ContentType, HeaderMapExt}, + TypedHeader, +}; use kittybox_indieauth::{ - AuthorizationRequest, AuthorizationResponse, ClientMetadata, Error, ErrorKind, GrantRequest, GrantResponse, GrantType, IntrospectionEndpointAuthMethod, Metadata, PKCEMethod, Profile, ProfileUrl, ResponseType, RevocationEndpointAuthMethod, Scope, Scopes, TokenData, TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest + AuthorizationRequest, AuthorizationResponse, ClientMetadata, Error, ErrorKind, GrantRequest, + GrantResponse, GrantType, IntrospectionEndpointAuthMethod, Metadata, PKCEMethod, Profile, + ProfileUrl, ResponseType, RevocationEndpointAuthMethod, Scope, Scopes, TokenData, + TokenIntrospectionRequest, TokenIntrospectionResponse, TokenRevocationRequest, }; +use microformats::types::Class; +use serde::Deserialize; +use std::marker::PhantomData; use std::str::FromStr; +use tracing::error; pub mod backend; #[cfg(feature = "webauthn")] @@ -41,35 +52,42 @@ impl<A: AuthBackend> std::ops::Deref for User<A> { pub enum IndieAuthResourceError { InvalidRequest, Unauthorized, - InvalidToken + InvalidToken, } impl axum::response::IntoResponse for IndieAuthResourceError { fn into_response(self) -> axum::response::Response { use IndieAuthResourceError::*; match self { - Unauthorized => ( - StatusCode::UNAUTHORIZED, - [("WWW-Authenticate", "Bearer")] - ).into_response(), + Unauthorized => { + (StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Bearer")]).into_response() + } InvalidRequest => ( StatusCode::BAD_REQUEST, - Json(&serde_json::json!({"error": "invalid_request"})) - ).into_response(), + Json(&serde_json::json!({"error": "invalid_request"})), + ) + .into_response(), InvalidToken => ( StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Bearer, error=\"invalid_token\"")], - Json(&serde_json::json!({"error": "not_authorized"})) - ).into_response() + Json(&serde_json::json!({"error": "not_authorized"})), + ) + .into_response(), } } } -impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::OptionalFromRequestParts<St> for User<A> { +impl<A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> + axum::extract::OptionalFromRequestParts<St> for User<A> +{ type Rejection = <Self as axum::extract::FromRequestParts<St>>::Rejection; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Option<Self>, Self::Rejection> { - let res = <Self as axum::extract::FromRequestParts<St>>::from_request_parts(req, state).await; + async fn from_request_parts( + req: &mut axum::http::request::Parts, + state: &St, + ) -> Result<Option<Self>, Self::Rejection> { + let res = + <Self as axum::extract::FromRequestParts<St>>::from_request_parts(req, state).await; match res { Ok(user) => Ok(Some(user)), @@ -79,14 +97,19 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext } } -impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts<St> for User<A> { +impl<A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> + axum::extract::FromRequestParts<St> for User<A> +{ type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Self, Self::Rejection> { + async fn from_request_parts( + req: &mut axum::http::request::Parts, + state: &St, + ) -> Result<Self, Self::Rejection> { let TypedHeader(Authorization(token)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) - .await - .map_err(|_| IndieAuthResourceError::Unauthorized)?; + .await + .map_err(|_| IndieAuthResourceError::Unauthorized)?; let auth = A::from_ref(state); @@ -94,10 +117,7 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext .await .map_err(|_| IndieAuthResourceError::InvalidRequest)?; - auth.get_token( - &format!("https://{host}/").parse().unwrap(), - token.token() - ) + auth.get_token(&format!("https://{host}/").parse().unwrap(), token.token()) .await .unwrap() .ok_or(IndieAuthResourceError::InvalidToken) @@ -105,9 +125,7 @@ impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::ext } } -pub async fn metadata( - Host(host): Host -) -> Metadata { +pub async fn metadata(Host(host): Host) -> Metadata { let issuer: url::Url = format!("https://{}/", host).parse().unwrap(); let indieauth: url::Url = issuer.join("/.kittybox/indieauth/").unwrap(); @@ -117,18 +135,16 @@ pub async fn metadata( token_endpoint: indieauth.join("token").unwrap(), introspection_endpoint: indieauth.join("token_status").unwrap(), introspection_endpoint_auth_methods_supported: Some(vec![ - IntrospectionEndpointAuthMethod::Bearer + IntrospectionEndpointAuthMethod::Bearer, ]), revocation_endpoint: Some(indieauth.join("revoke_token").unwrap()), - revocation_endpoint_auth_methods_supported: Some(vec![ - RevocationEndpointAuthMethod::None - ]), + revocation_endpoint_auth_methods_supported: Some(vec![RevocationEndpointAuthMethod::None]), scopes_supported: Some(vec![ Scope::Create, Scope::Update, Scope::Delete, Scope::Media, - Scope::Profile + Scope::Profile, ]), response_types_supported: Some(vec![ResponseType::Code]), grant_types_supported: Some(vec![GrantType::AuthorizationCode, GrantType::RefreshToken]), @@ -142,30 +158,42 @@ pub async fn metadata( async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Host(host): Host, - Query(request): Query<AuthorizationRequest>, + Query(mut request): Query<AuthorizationRequest>, State(db): State<D>, State(http): State<reqwest_middleware::ClientWithMiddleware>, - State(auth): State<A> + State(auth): State<A>, ) -> Response { let me: url::Url = format!("https://{host}/").parse().unwrap(); // XXX: attempt fetching OAuth application metadata - let h_app: ClientMetadata = if request.client_id.domain().unwrap() == "localhost" && me.domain().unwrap() != "localhost" { + let h_app: ClientMetadata = if request.client_id.domain().unwrap() == "localhost" + && me.domain().unwrap() != "localhost" + { // If client is localhost, but we aren't localhost, generate synthetic metadata. tracing::warn!("Client is localhost, not fetching metadata"); - let mut metadata = ClientMetadata::new(request.client_id.clone(), request.client_id.clone()).unwrap(); + let mut metadata = + ClientMetadata::new(request.client_id.clone(), request.client_id.clone()).unwrap(); metadata.client_name = Some("Your locally hosted app".to_string()); metadata } else { tracing::debug!("Sending request to {} to fetch metadata", request.client_id); - let metadata_request = http.get(request.client_id.clone()) + let metadata_request = http + .get(request.client_id.clone()) .header("Accept", "application/json, text/html"); - match metadata_request.send().await - .and_then(|res| res.error_for_status() - .map_err(reqwest_middleware::Error::Reqwest)) - { - Ok(response) if response.headers().typed_get::<ContentType>().to_owned().map(mime::Mime::from).map(|m| m.type_() == "text" && m.subtype() == "html").unwrap_or(false) => { + match metadata_request.send().await.and_then(|res| { + res.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }) { + Ok(response) + if response + .headers() + .typed_get::<ContentType>() + .to_owned() + .map(mime::Mime::from) + .map(|m| m.type_() == "text" && m.subtype() == "html") + .unwrap_or(false) => + { let url = response.url().clone(); let text = response.text().await.unwrap(); tracing::debug!("Received {} bytes in response", text.len()); @@ -173,76 +201,97 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Ok(mf2) => { if let Some(relation) = mf2.rels.items.get(&request.redirect_uri) { if !relation.rels.iter().any(|i| i == "redirect_uri") { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri provided was declared as \ - something other than redirect_uri.") - .into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri provided was declared as \ + something other than redirect_uri.", + ) + .into_response(); } } else if request.redirect_uri.origin() != request.client_id.origin() { - return (StatusCode::BAD_REQUEST, - [("Content-Type", "text/plain")], - "The redirect_uri didn't match the origin \ - and wasn't explicitly allowed. You were being tricked.") - .into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "The redirect_uri didn't match the origin \ + and wasn't explicitly allowed. You were being tricked.", + ) + .into_response(); } - - if let Some(app) = mf2.items + // Should we attempt to create synthetic metadata from an h-card? + // + // This would increase compatibility with personal websites. + if let Some(app) = mf2 + .items .iter() - .find(|&i| i.r#type.iter() - .any(|i| { + .find(|&i| { + i.r#type.iter().any(|i| { *i == Class::from_str("h-app").unwrap() || *i == Class::from_str("h-x-app").unwrap() }) - ) + }) .cloned() { // Create a synthetic metadata document. Be forgiving. let mut metadata = ClientMetadata::new( request.client_id.clone(), - app.properties.get("url") + app.properties + .get("url") .and_then(|v| v.first()) .and_then(|i| match i { - microformats::types::PropertyValue::Url(url) => Some(url.clone()), - _ => None + microformats::types::PropertyValue::Url(url) => { + Some(url.clone()) + } + _ => None, }) - .unwrap_or_else(|| request.client_id.clone()) - ).unwrap(); + .unwrap_or_else(|| request.client_id.clone()), + ) + .unwrap(); - metadata.client_name = app.properties.get("name") + metadata.client_name = app + .properties + .get("name") .and_then(|v| v.first()) .and_then(|i| match i { - microformats::types::PropertyValue::Plain(name) => Some(name.to_owned()), - _ => None + microformats::types::PropertyValue::Plain(name) => { + Some(name.to_owned()) + } + _ => None, }); metadata.redirect_uris = mf2.rels.by_rels().remove("redirect_uri"); metadata } else { - return (StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], "No h-app or JSON application metadata found.").into_response() + return ( + StatusCode::BAD_REQUEST, + [("Content-Type", "text/plain")], + "No h-app or JSON application metadata found.", + ) + .into_response(); } - }, + } Err(err) => { tracing::error!("Error parsing application metadata: {}", err); return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - "Parsing h-app metadata failed.").into_response() + "Parsing h-app metadata failed.", + ) + .into_response(); } } - }, + } Ok(response) => match response.json::<ClientMetadata>().await { - Ok(client_metadata) => { - client_metadata - }, + Ok(client_metadata) => client_metadata, Err(err) => { tracing::error!("Error parsing JSON application metadata: {}", err); return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - format!("Parsing OAuth2 JSON app metadata failed: {}", err) - ).into_response() + format!("Parsing OAuth2 JSON app metadata failed: {}", err), + ) + .into_response(); } }, Err(err) => { @@ -250,27 +299,44 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( return ( StatusCode::BAD_REQUEST, [("Content-Type", "text/plain")], - format!("Fetching app metadata failed: {}", err) - ).into_response() + format!("Fetching app metadata failed: {}", err), + ) + .into_response(); } } }; tracing::debug!("Application metadata: {:#?}", h_app); - Html(kittybox_frontend_renderer::Template { - title: "Confirm sign-in via IndieAuth", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: kittybox_frontend_renderer::AuthorizationRequestPage { - request, - credentials: auth.list_user_credential_types(&me).await.unwrap(), - user: db.get_post(me.as_str()).await.unwrap().unwrap(), - app: h_app - }.to_string(), - }.to_string()) - .into_response() + // Sanity check: some older applications don't ask for scopes when they're supposed to. + // + // Give them the profile scope at least? + if request + .scope + .as_ref() + .map(|scope: &Scopes| scope.is_empty()) + .unwrap_or(true) + { + request.scope.replace(Scopes::new(vec![Scope::Profile])); + } + + Html( + kittybox_frontend_renderer::Template { + title: "Confirm sign-in via IndieAuth", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: kittybox_frontend_renderer::AuthorizationRequestPage { + request, + credentials: auth.list_user_credential_types(&me).await.unwrap(), + user: db.get_post(me.as_str()).await.unwrap().unwrap(), + app: h_app, + } + .to_string(), + } + .to_string(), + ) + .into_response() } #[derive(Deserialize, Debug)] @@ -278,7 +344,7 @@ async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( enum Credential { Password(String), #[cfg(feature = "webauthn")] - WebAuthn(::webauthn::prelude::PublicKeyCredential) + WebAuthn(::webauthn::prelude::PublicKeyCredential), } // The IndieAuth standard doesn't prescribe a format for confirming @@ -291,7 +357,7 @@ enum Credential { #[derive(Deserialize, Debug)] struct AuthorizationConfirmation { authorization_method: Credential, - request: AuthorizationRequest + request: AuthorizationRequest, } #[tracing::instrument(skip(auth, credential))] @@ -299,18 +365,14 @@ async fn verify_credential<A: AuthBackend>( auth: &A, website: &url::Url, credential: Credential, - #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] - challenge_id: Option<&str> + #[cfg_attr(not(feature = "webauthn"), allow(unused_variables))] challenge_id: Option<&str>, ) -> std::io::Result<bool> { match credential { Credential::Password(password) => auth.verify_password(website, password).await, #[cfg(feature = "webauthn")] - Credential::WebAuthn(credential) => webauthn::verify( - auth, - website, - credential, - challenge_id.unwrap() - ).await + Credential::WebAuthn(credential) => { + webauthn::verify(auth, website, credential, challenge_id.unwrap()).await + } } } @@ -323,7 +385,8 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( ) -> Response { tracing::debug!("Received authorization confirmation from user"); #[cfg(feature = "webauthn")] - let challenge_id = cookies.get(webauthn::CHALLENGE_ID_COOKIE) + let challenge_id = cookies + .get(webauthn::CHALLENGE_ID_COOKIE) .map(|cookie| cookie.value()); #[cfg(not(feature = "webauthn"))] let challenge_id = None; @@ -331,14 +394,16 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( let website = format!("https://{}/", host).parse().unwrap(); let AuthorizationConfirmation { authorization_method: credential, - request: mut auth + request: mut auth, } = confirmation; match verify_credential(&backend, &website, credential, challenge_id).await { - Ok(verified) => if !verified { - error!("User failed verification, bailing out."); - return StatusCode::UNAUTHORIZED.into_response(); - }, + Ok(verified) => { + if !verified { + error!("User failed verification, bailing out."); + return StatusCode::UNAUTHORIZED.into_response(); + } + } Err(err) => { error!("Error while verifying credential: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -365,9 +430,14 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( let location = { let mut uri = redirect_uri; - uri.set_query(Some(&serde_urlencoded::to_string( - AuthorizationResponse { code, state, iss: website } - ).unwrap())); + uri.set_query(Some( + &serde_urlencoded::to_string(AuthorizationResponse { + code, + state, + iss: website, + }) + .unwrap(), + )); uri }; @@ -375,10 +445,11 @@ async fn authorization_endpoint_confirm<A: AuthBackend>( // DO NOT SET `StatusCode::FOUND` here! `fetch()` cannot read from // redirects, it can only follow them or choose to receive an // opaque response instead that is completely useless - (StatusCode::NO_CONTENT, - [("Location", location.as_str())], - #[cfg(feature = "webauthn")] - cookies.remove(Cookie::from(webauthn::CHALLENGE_ID_COOKIE)) + ( + StatusCode::NO_CONTENT, + [("Location", location.as_str())], + #[cfg(feature = "webauthn")] + cookies.remove(Cookie::from(webauthn::CHALLENGE_ID_COOKIE)), ) .into_response() } @@ -396,15 +467,18 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( code, client_id, redirect_uri, - code_verifier + code_verifier, } => { let request: AuthorizationRequest = match backend.get_code(&code).await { Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving auth request: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -414,51 +488,66 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } if redirect_uri != request.redirect_uri { return Error { kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() + msg: Some( + "This redirect_uri doesn't match the one the code has been sent to." + .to_string(), + ), + error_uri: None, + } + .into_response(); } if !request.code_challenge.verify(code_verifier) { return Error { kind: ErrorKind::InvalidGrant, msg: Some("The PKCE challenge failed.".to_string()), // are RFCs considered human-readable? 😝 - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response() + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6" + .parse() + .ok(), + } + .into_response(); } let me: url::Url = format!("https://{}/", host).parse().unwrap(); if request.me.unwrap() != me { return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } - let profile = if request.scope.as_ref() - .map(|s| s.has(&Scope::Profile)) - .unwrap_or_default() + let profile = if request + .scope + .as_ref() + .map(|s| s.has(&Scope::Profile)) + .unwrap_or_default() { match get_profile( db, me.as_str(), - request.scope.as_ref() + request + .scope + .as_ref() .map(|s| s.has(&Scope::Email)) - .unwrap_or_default() - ).await { + .unwrap_or_default(), + ) + .await + { Ok(profile) => { tracing::debug!("Retrieved profile: {:?}", profile); profile - }, + } Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { @@ -466,12 +555,15 @@ async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>( }; GrantResponse::ProfileUrl(ProfileUrl { me, profile }).into_response() - }, + } _ => Error { kind: ErrorKind::InvalidGrant, msg: Some("The provided grant_type is unusable on this endpoint.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code".parse().ok() - }.into_response() + error_uri: "https://indieauth.spec.indieweb.org/#redeeming-the-authorization-code" + .parse() + .ok(), + } + .into_response(), } } @@ -485,36 +577,40 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( #[inline] fn prepare_access_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { TokenData { - me, client_id, scope, + me, + client_id, + scope, exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) - .as_secs() - .into(), + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(ACCESS_TOKEN_VALIDITY)) + .as_secs() + .into(), iat: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() - .into() + .into(), } } #[inline] fn prepare_refresh_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData { TokenData { - me, client_id, scope, + me, + client_id, + scope, exp: (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) - .as_secs() - .into(), + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + + std::time::Duration::from_secs(REFRESH_TOKEN_VALIDITY)) + .as_secs() + .into(), iat: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() - .into() + .into(), } } @@ -525,15 +621,18 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( code, client_id, redirect_uri, - code_verifier + code_verifier, } => { let request: AuthorizationRequest = match backend.get_code(&code).await { Ok(Some(request)) => request, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("The provided authorization code is invalid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("The provided authorization code is invalid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving auth request: {}", err); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -542,33 +641,46 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( tracing::debug!("Retrieved authorization request: {:?}", request); - let scope = if let Some(scope) = request.scope { scope } else { + let scope = if let Some(scope) = request.scope { + scope + } else { return Error { kind: ErrorKind::InvalidScope, msg: Some("Tokens cannot be issued if no scopes are requested.".to_string()), - error_uri: "https://indieauth.spec.indieweb.org/#access-token-response".parse().ok() - }.into_response(); + error_uri: "https://indieauth.spec.indieweb.org/#access-token-response" + .parse() + .ok(), + } + .into_response(); }; if client_id != request.client_id { return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization code isn't yours.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } if redirect_uri != request.redirect_uri { return Error { kind: ErrorKind::InvalidGrant, - msg: Some("This redirect_uri doesn't match the one the code has been sent to.".to_string()), - error_uri: None - }.into_response() + msg: Some( + "This redirect_uri doesn't match the one the code has been sent to." + .to_string(), + ), + error_uri: None, + } + .into_response(); } if !request.code_challenge.verify(code_verifier) { return Error { kind: ErrorKind::InvalidGrant, msg: Some("The PKCE challenge failed.".to_string()), - error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok() - }.into_response(); + error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6" + .parse() + .ok(), + } + .into_response(); } // Note: we can trust the `request.me` value, since we set @@ -577,30 +689,32 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This authorization endpoint does not serve this user.".to_string()), - error_uri: None - }.into_response() + error_uri: None, + } + .into_response(); } let profile = if dbg!(scope.has(&Scope::Profile)) { - match get_profile( - db, - me.as_str(), - scope.has(&Scope::Email) - ).await { + match get_profile(db, me.as_str(), scope.has(&Scope::Email)).await { Ok(profile) => dbg!(profile), Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { None }; - let access_token = match backend.create_token( - prepare_access_token(me.clone(), client_id.clone(), scope.clone()) - ).await { + let access_token = match backend + .create_token(prepare_access_token( + me.clone(), + client_id.clone(), + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating access token: {}", err); @@ -608,9 +722,10 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( } }; // TODO: only create refresh token if user allows it - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(me.clone(), client_id, scope.clone()) - ).await { + let refresh_token = match backend + .create_refresh_token(prepare_refresh_token(me.clone(), client_id, scope.clone())) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating refresh token: {}", err); @@ -626,24 +741,28 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( scope: Some(scope), expires_in: Some(ACCESS_TOKEN_VALIDITY), refresh_token: Some(refresh_token), - state: None - }.into_response() - }, + state: None, + } + .into_response() + } GrantRequest::RefreshToken { refresh_token, client_id, - scope + scope, } => { let data = match backend.get_refresh_token(&me, &refresh_token).await { Ok(Some(token)) => token, - Ok(None) => return Error { - kind: ErrorKind::InvalidGrant, - msg: Some("This refresh token is not valid.".to_string()), - error_uri: None - }.into_response(), + Ok(None) => { + return Error { + kind: ErrorKind::InvalidGrant, + msg: Some("This refresh token is not valid.".to_string()), + error_uri: None, + } + .into_response() + } Err(err) => { tracing::error!("Error retrieving refresh token: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } }; @@ -651,17 +770,22 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( return Error { kind: ErrorKind::InvalidGrant, msg: Some("This refresh token is not yours.".to_string()), - error_uri: None - }.into_response(); + error_uri: None, + } + .into_response(); } let scope = if let Some(scope) = scope { if !data.scope.has_all(scope.as_ref()) { return Error { kind: ErrorKind::InvalidScope, - msg: Some("You can't request additional scopes through the refresh token grant.".to_string()), - error_uri: None - }.into_response(); + msg: Some( + "You can't request additional scopes through the refresh token grant." + .to_string(), + ), + error_uri: None, + } + .into_response(); } scope @@ -670,27 +794,27 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( data.scope }; - let profile = if scope.has(&Scope::Profile) { - match get_profile( - db, - data.me.as_str(), - scope.has(&Scope::Email) - ).await { + match get_profile(db, data.me.as_str(), scope.has(&Scope::Email)).await { Ok(profile) => profile, Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } } else { None }; - let access_token = match backend.create_token( - prepare_access_token(data.me.clone(), client_id.clone(), scope.clone()) - ).await { + let access_token = match backend + .create_token(prepare_access_token( + data.me.clone(), + client_id.clone(), + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating access token: {}", err); @@ -699,9 +823,14 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( }; let old_refresh_token = refresh_token; - let refresh_token = match backend.create_refresh_token( - prepare_refresh_token(data.me.clone(), client_id, scope.clone()) - ).await { + let refresh_token = match backend + .create_refresh_token(prepare_refresh_token( + data.me.clone(), + client_id, + scope.clone(), + )) + .await + { Ok(token) => token, Err(err) => { tracing::error!("Error creating refresh token: {}", err); @@ -721,8 +850,9 @@ async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>( scope: Some(scope), expires_in: Some(ACCESS_TOKEN_VALIDITY), refresh_token: Some(refresh_token), - state: None - }.into_response() + state: None, + } + .into_response() } } } @@ -740,26 +870,39 @@ async fn introspection_endpoint_post<A: AuthBackend>( // Check authentication first match backend.get_token(&me, auth_token.token()).await { - Ok(Some(token)) => if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); - }, - Ok(None) => return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(), + Ok(Some(token)) => { + if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + })), + ) + .into_response(); + } + } + Ok(None) => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + })), + ) + .into_response() + } Err(err) => { tracing::error!("Error retrieving token data for introspection: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } - let response: TokenIntrospectionResponse = match backend.get_token(&me, &token_request.token).await { - Ok(maybe_data) => maybe_data.into(), - Err(err) => { - tracing::error!("Error retrieving token data: {}", err); - return StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - }; + let response: TokenIntrospectionResponse = + match backend.get_token(&me, &token_request.token).await { + Ok(maybe_data) => maybe_data.into(), + Err(err) => { + tracing::error!("Error retrieving token data: {}", err); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; response.into_response() } @@ -787,7 +930,7 @@ async fn revocation_endpoint_post<A: AuthBackend>( async fn get_profile<D: Storage + 'static>( db: D, url: &str, - email: bool + email: bool, ) -> crate::database::Result<Option<Profile>> { fn get_first(v: serde_json::Value) -> Option<String> { match v { @@ -796,10 +939,10 @@ async fn get_profile<D: Storage + 'static>( match a.pop() { Some(serde_json::Value::String(s)) => Some(s), Some(serde_json::Value::Object(mut o)) => o.remove("value").and_then(get_first), - _ => None + _ => None, } - }, - _ => None + } + _ => None, } } @@ -807,15 +950,26 @@ async fn get_profile<D: Storage + 'static>( // Ruthlessly manually destructure the MF2 document to save memory let mut properties = match mf2.as_object_mut().unwrap().remove("properties") { Some(serde_json::Value::Object(props)) => props, - _ => unreachable!() + _ => unreachable!(), }; drop(mf2); let name = properties.remove("name").and_then(get_first); - let url = properties.remove("uid").and_then(get_first).and_then(|u| u.parse().ok()); - let photo = properties.remove("photo").and_then(get_first).and_then(|u| u.parse().ok()); + let url = properties + .remove("uid") + .and_then(get_first) + .and_then(|u| u.parse().ok()); + let photo = properties + .remove("photo") + .and_then(get_first) + .and_then(|u| u.parse().ok()); let email = properties.remove("name").and_then(get_first); - Profile { name, url, photo, email } + Profile { + name, + url, + photo, + email, + } })) } @@ -823,7 +977,7 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( Host(host): Host, TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>, State(backend): State<A>, - State(db): State<D> + State(db): State<D>, ) -> Response { use serde_json::json; @@ -832,14 +986,22 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( match backend.get_token(&me, auth_token.token()).await { Ok(Some(token)) => { if token.expired() { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - }))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InvalidToken + })), + ) + .into_response(); } if !token.scope.has(&Scope::Profile) { - return (StatusCode::UNAUTHORIZED, Json(json!({ - "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope - }))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope + })), + ) + .into_response(); } match get_profile(db, me.as_str(), token.scope.has(&Scope::Email)).await { @@ -847,17 +1009,19 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( Ok(None) => Json(json!({ // We do this because ResourceErrorKind is IndieAuth errors only "error": "invalid_request" - })).into_response(), + })) + .into_response(), Err(err) => { tracing::error!("Error retrieving profile from database: {}", err); StatusCode::INTERNAL_SERVER_ERROR.into_response() } } - }, + } Ok(None) => Json(json!({ "error": kittybox_indieauth::ResourceErrorKind::InvalidToken - })).into_response(), + })) + .into_response(), Err(err) => { tracing::error!("Error reading token: {}", err); @@ -871,57 +1035,51 @@ where S: Storage + FromRef<St> + 'static, A: AuthBackend + FromRef<St>, reqwest_middleware::ClientWithMiddleware: FromRef<St>, - St: Clone + Send + Sync + 'static + St: Clone + Send + Sync + 'static, { - use axum::routing::{Router, get, post}; + use axum::routing::{get, post, Router}; Router::new() .nest( "/.kittybox/indieauth", Router::new() - .route("/metadata", - get(metadata)) + .route("/metadata", get(metadata)) .route( "/auth", get(authorization_endpoint_get::<A, S>) - .post(authorization_endpoint_post::<A, S>)) - .route( - "/auth/confirm", - post(authorization_endpoint_confirm::<A>)) - .route( - "/token", - post(token_endpoint_post::<A, S>)) - .route( - "/token_status", - post(introspection_endpoint_post::<A>)) - .route( - "/revoke_token", - post(revocation_endpoint_post::<A>)) + .post(authorization_endpoint_post::<A, S>), + ) + .route("/auth/confirm", post(authorization_endpoint_confirm::<A>)) + .route("/token", post(token_endpoint_post::<A, S>)) + .route("/token_status", post(introspection_endpoint_post::<A>)) + .route("/revoke_token", post(revocation_endpoint_post::<A>)) + .route("/userinfo", get(userinfo_endpoint_get::<A, S>)) .route( - "/userinfo", - get(userinfo_endpoint_get::<A, S>)) - - .route("/webauthn/pre_register", - get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, S>, - #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) - ) + "/webauthn/pre_register", + get( + #[cfg(feature = "webauthn")] + webauthn::webauthn_pre_register::<A, S>, + #[cfg(not(feature = "webauthn"))] + || std::future::ready(axum::http::StatusCode::NOT_FOUND), + ), ) - .layer(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST - ]) - .allow_origin(tower_http::cors::Any)) + .layer( + tower_http::cors::CorsLayer::new() + .allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_origin(tower_http::cors::Any), + ), ) .route( "/.well-known/oauth-authorization-server", - get(|| std::future::ready( - (StatusCode::FOUND, - [("Location", - "/.kittybox/indieauth/metadata")] - ).into_response() - )) + get(|| { + std::future::ready( + ( + StatusCode::FOUND, + [("Location", "/.kittybox/indieauth/metadata")], + ) + .into_response(), + ) + }), ) } @@ -929,9 +1087,10 @@ where mod tests { #[test] fn test_deserialize_authorization_confirmation() { - use super::{Credential, AuthorizationConfirmation}; + use super::{AuthorizationConfirmation, Credential}; - let confirmation = serde_json::from_str::<AuthorizationConfirmation>(r#"{ + let confirmation = serde_json::from_str::<AuthorizationConfirmation>( + r#"{ "request":{ "response_type": "code", "client_id": "https://quill.p3k.io/", @@ -942,12 +1101,14 @@ mod tests { "scope": "create+media" }, "authorization_method": "swordfish" - }"#).unwrap(); + }"#, + ) + .unwrap(); match confirmation.authorization_method { Credential::Password(password) => assert_eq!(password.as_str(), "swordfish"), #[allow(unreachable_patterns)] - other => panic!("Incorrect credential: {:?}", other) + other => panic!("Incorrect credential: {:?}", other), } assert_eq!(confirmation.request.state.as_ref(), "10101010"); } diff --git a/src/indieauth/webauthn.rs b/src/indieauth/webauthn.rs index 0757e72..80d210c 100644 --- a/src/indieauth/webauthn.rs +++ b/src/indieauth/webauthn.rs @@ -1,10 +1,17 @@ use axum::{ extract::Json, + http::StatusCode, response::{IntoResponse, Response}, - http::StatusCode, Extension + Extension, +}; +use axum_extra::extract::{ + cookie::{Cookie, CookieJar}, + Host, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + TypedHeader, }; -use axum_extra::extract::{Host, cookie::{CookieJar, Cookie}}; -use axum_extra::{TypedHeader, headers::{authorization::Bearer, Authorization}}; use super::backend::AuthBackend; use crate::database::Storage; @@ -12,40 +19,33 @@ use crate::database::Storage; pub(crate) const CHALLENGE_ID_COOKIE: &str = "kittybox_webauthn_challenge_id"; macro_rules! bail { - ($msg:literal, $err:expr) => { - { - ::tracing::error!($msg, $err); - return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response() - } - } + ($msg:literal, $err:expr) => {{ + ::tracing::error!($msg, $err); + return ::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response(); + }}; } pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( Host(host): Host, Extension(db): Extension<D>, Extension(auth): Extension<A>, - cookies: CookieJar + cookies: CookieJar, ) -> Response { let uid = format!("https://{}/", host.clone()); let uid_url: url::Url = uid.parse().unwrap(); // This will not find an h-card in onboarding! let display_name = match db.get_post(&uid).await { Ok(hcard) => match hcard { - Some(mut hcard) => { - match hcard["properties"]["uid"][0].take() { - serde_json::Value::String(name) => name, - _ => String::default() - } + Some(mut hcard) => match hcard["properties"]["uid"][0].take() { + serde_json::Value::String(name) => name, + _ => String::default(), }, - None => String::default() + None => String::default(), }, - Err(err) => bail!("Error retrieving h-card: {}", err) + Err(err) => bail!("Error retrieving h-card: {}", err), }; - let webauthn = webauthn::WebauthnBuilder::new( - &host, - &uid_url - ) + let webauthn = webauthn::WebauthnBuilder::new(&host, &uid_url) .unwrap() .rp_name("Kittybox") .build() @@ -58,10 +58,10 @@ pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( webauthn::prelude::Uuid::nil(), &uid, &display_name, - Some(vec![]) + Some(vec![]), ) { Ok((challenge, state)) => (challenge, state), - Err(err) => bail!("Error generating WebAuthn registration data: {}", err) + Err(err) => bail!("Error generating WebAuthn registration data: {}", err), }; match auth.persist_registration_challenge(&uid_url, state).await { @@ -69,11 +69,12 @@ pub async fn webauthn_pre_register<A: AuthBackend, D: Storage + 'static>( cookies.add( Cookie::build((CHALLENGE_ID_COOKIE, challenge_id)) .secure(true) - .finish() + .finish(), ), - Json(challenge) - ).into_response(), - Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err) + Json(challenge), + ) + .into_response(), + Err(err) => bail!("Failed to persist WebAuthn challenge: {}", err), } } @@ -82,39 +83,36 @@ pub async fn webauthn_register<A: AuthBackend>( Json(credential): Json<webauthn::prelude::RegisterPublicKeyCredential>, // TODO determine if we can use a cookie maybe? user_credential: Option<TypedHeader<Authorization<Bearer>>>, - Extension(auth): Extension<A> + Extension(auth): Extension<A>, ) -> Response { let uid = format!("https://{}/", host.clone()); let uid_url: url::Url = uid.parse().unwrap(); let pubkeys = match auth.list_webauthn_pubkeys(&uid_url).await { Ok(pubkeys) => pubkeys, - Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err) + Err(err) => bail!("Error enumerating existing WebAuthn credentials: {}", err), }; if !pubkeys.is_empty() { if let Some(TypedHeader(Authorization(token))) = user_credential { // TODO check validity of the credential } else { - return StatusCode::UNAUTHORIZED.into_response() + return StatusCode::UNAUTHORIZED.into_response(); } } - return StatusCode::OK.into_response() + return StatusCode::OK.into_response(); } pub(crate) async fn verify<A: AuthBackend>( auth: &A, website: &url::Url, credential: webauthn::prelude::PublicKeyCredential, - challenge_id: &str + challenge_id: &str, ) -> std::io::Result<bool> { let host = website.host_str().unwrap(); - let webauthn = webauthn::WebauthnBuilder::new( - host, - website - ) + let webauthn = webauthn::WebauthnBuilder::new(host, website) .unwrap() .rp_name("Kittybox") .build() @@ -122,12 +120,14 @@ pub(crate) async fn verify<A: AuthBackend>( match webauthn.finish_passkey_authentication( &credential, - &auth.retrieve_authentication_challenge(&website, challenge_id).await? + &auth + .retrieve_authentication_challenge(&website, challenge_id) + .await?, ) { Err(err) => { tracing::error!("WebAuthn error: {}", err); Ok(false) - }, + } Ok(authentication_result) => { let counter = authentication_result.counter(); let cred_id = authentication_result.cred_id(); diff --git a/src/lib.rs b/src/lib.rs index 6d8e784..0df5e5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,22 +4,28 @@ use std::sync::Arc; use axum::extract::{FromRef, FromRequestParts, OptionalFromRequestParts}; -use axum_extra::extract::{cookie::{Cookie, Key}, SignedCookieJar}; +use axum_extra::extract::{ + cookie::{Cookie, Key}, + SignedCookieJar, +}; use database::{FileStorage, PostgresStorage, Storage}; use indieauth::backend::{AuthBackend, FileBackend as FileAuthBackend}; use kittybox_util::queue::JobQueue; -use media::storage::{MediaStore, file::FileStore as FileMediaStore}; -use tokio::{sync::{Mutex, RwLock}, task::JoinSet}; +use media::storage::{file::FileStore as FileMediaStore, MediaStore}; +use tokio::{ + sync::{Mutex, RwLock}, + task::JoinSet, +}; use webmentions::queue::PostgresJobQueue; /// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. pub mod database; pub mod frontend; +pub mod indieauth; +pub mod login; pub mod media; pub mod micropub; -pub mod indieauth; pub mod webmentions; -pub mod login; //pub mod admin; const OAUTH2_SOFTWARE_ID: &str = "6f2eee84-c22c-4c9e-b900-10d4e97273c8"; @@ -27,10 +33,10 @@ const OAUTH2_SOFTWARE_ID: &str = "6f2eee84-c22c-4c9e-b900-10d4e97273c8"; #[derive(Clone)] pub struct AppState<A, S, M, Q> where -A: AuthBackend + Sized + 'static, -S: Storage + Sized + 'static, -M: MediaStore + Sized + 'static, -Q: JobQueue<webmentions::Webmention> + Sized + A: AuthBackend + Sized + 'static, + S: Storage + Sized + 'static, + M: MediaStore + Sized + 'static, + Q: JobQueue<webmentions::Webmention> + Sized, { pub auth_backend: A, pub storage: S, @@ -39,7 +45,7 @@ Q: JobQueue<webmentions::Webmention> + Sized pub http: reqwest_middleware::ClientWithMiddleware, pub background_jobs: Arc<Mutex<JoinSet<()>>>, pub cookie_key: Key, - pub session_store: SessionStore + pub session_store: SessionStore, } pub type SessionStore = Arc<RwLock<std::collections::HashMap<uuid::Uuid, Session>>>; @@ -60,7 +66,11 @@ pub struct NoSessionError; impl axum::response::IntoResponse for NoSessionError { fn into_response(self) -> axum::response::Response { // TODO: prettier error message - (axum::http::StatusCode::UNAUTHORIZED, "You are not logged in, but this page requires a session.").into_response() + ( + axum::http::StatusCode::UNAUTHORIZED, + "You are not logged in, but this page requires a session.", + ) + .into_response() } } @@ -72,11 +82,17 @@ where { type Rejection = std::convert::Infallible; - async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Option<Self>, Self::Rejection> { - let jar = SignedCookieJar::<Key>::from_request_parts(parts, state).await.unwrap(); + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + let jar = SignedCookieJar::<Key>::from_request_parts(parts, state) + .await + .unwrap(); let session_store = SessionStore::from_ref(state).read_owned().await; - Ok(jar.get("session_id") + Ok(jar + .get("session_id") .as_ref() .map(Cookie::value_trimmed) .and_then(|id| uuid::Uuid::parse_str(id).ok()) @@ -103,7 +119,10 @@ where // have to repeat this magic invocation. impl<S, M, Q> FromRef<AppState<Self, S, M, Q>> for FileAuthBackend -where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<Self, S, M, Q>) -> Self { input.auth_backend.clone() @@ -111,7 +130,10 @@ where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> } impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for PostgresStorage -where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { input.storage.clone() @@ -119,7 +141,10 @@ where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> } impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for FileStorage -where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { input.storage.clone() @@ -128,7 +153,10 @@ where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> impl<A, S, Q> FromRef<AppState<A, S, Self, Q>> for FileMediaStore // where A: AuthBackend, S: Storage -where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, Self, Q>) -> Self { input.media_store.clone() @@ -136,7 +164,11 @@ where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Key -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.cookie_key.clone() @@ -144,7 +176,11 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for reqwest_middleware::ClientWithMiddleware -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.http.clone() @@ -152,7 +188,11 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Arc<Mutex<JoinSet<()>>> -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.background_jobs.clone() @@ -161,7 +201,10 @@ where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmen #[cfg(feature = "sqlx")] impl<A, S, M> FromRef<AppState<A, S, M, Self>> for PostgresJobQueue<webmentions::Webmention> -where A: AuthBackend, S: Storage, M: MediaStore +where + A: AuthBackend, + S: Storage, + M: MediaStore, { fn from_ref(input: &AppState<A, S, M, Self>) -> Self { input.job_queue.clone() @@ -169,127 +212,58 @@ where A: AuthBackend, S: Storage, M: MediaStore } impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for SessionStore -where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +where + A: AuthBackend, + S: Storage, + M: MediaStore, + Q: JobQueue<webmentions::Webmention>, { fn from_ref(input: &AppState<A, S, M, Q>) -> Self { input.session_store.clone() } } -pub mod companion { - use std::{collections::HashMap, sync::Arc}; - use axum::{ - extract::{Extension, Path}, - response::{IntoResponse, Response} - }; - - #[derive(Debug, Clone, Copy)] - struct Resource { - data: &'static [u8], - mime: &'static str - } - - impl IntoResponse for &Resource { - fn into_response(self) -> Response { - (axum::http::StatusCode::OK, - [("Content-Type", self.mime)], - self.data).into_response() - } - } - - // TODO replace with the "phf" crate someday - type ResourceTable = Arc<HashMap<&'static str, Resource>>; - - #[tracing::instrument] - async fn map_to_static( - Path(name): Path<String>, - Extension(resources): Extension<ResourceTable> - ) -> Response { - tracing::debug!("Searching for {} in the resource table...", name); - match resources.get(name.as_str()) { - Some(res) => res.into_response(), - None => { - #[cfg(debug_assertions)] tracing::error!("Not found"); - - (axum::http::StatusCode::NOT_FOUND, - [("Content-Type", "text/plain")], - "Not found. Sorry.".as_bytes()).into_response() - } - } - } - - pub fn router<St: Clone + Send + Sync + 'static>() -> axum::Router<St> { - let resources: ResourceTable = { - let mut map = HashMap::new(); - - macro_rules! register_resource { - ($map:ident, $prefix:expr, ($filename:literal, $mime:literal)) => {{ - $map.insert($filename, Resource { - data: include_bytes!(concat!($prefix, $filename)), - mime: $mime - }) - }}; - ($map:ident, $prefix:expr, ($filename:literal, $mime:literal), $( ($f:literal, $m:literal) ),+) => {{ - register_resource!($map, $prefix, ($filename, $mime)); - register_resource!($map, $prefix, $(($f, $m)),+); - }}; - } - - register_resource! { - map, - concat!(env!("OUT_DIR"), "/", "companion", "/"), - ("index.html", "text/html; charset=\"utf-8\""), - ("main.js", "text/javascript"), - ("micropub_api.js", "text/javascript"), - ("indieauth.js", "text/javascript"), - ("base64.js", "text/javascript"), - ("style.css", "text/css") - }; - - Arc::new(map) - }; - - axum::Router::new() - .route( - "/{filename}", - axum::routing::get(map_to_static) - .layer(Extension(resources)) - ) - } -} +pub mod companion; async fn teapot_route() -> impl axum::response::IntoResponse { use axum::http::{header, StatusCode}; - (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") + ( + StatusCode::IM_A_TEAPOT, + [(header::CONTENT_TYPE, "text/plain")], + "Sorry, can't brew coffee yet!", + ) } async fn health_check<D>( axum::extract::State(data): axum::extract::State<D>, ) -> impl axum::response::IntoResponse where - D: crate::database::Storage + D: crate::database::Storage, { (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) } pub async fn compose_kittybox<St, A, S, M, Q>() -> axum::Router<St> where -A: AuthBackend + 'static + FromRef<St>, -S: Storage + 'static + FromRef<St>, -M: MediaStore + 'static + FromRef<St>, -Q: kittybox_util::queue::JobQueue<crate::webmentions::Webmention> + FromRef<St>, -reqwest_middleware::ClientWithMiddleware: FromRef<St>, -Arc<Mutex<JoinSet<()>>>: FromRef<St>, -crate::SessionStore: FromRef<St>, -axum_extra::extract::cookie::Key: FromRef<St>, -St: Clone + Send + Sync + 'static + A: AuthBackend + 'static + FromRef<St>, + S: Storage + 'static + FromRef<St>, + M: MediaStore + 'static + FromRef<St>, + Q: kittybox_util::queue::JobQueue<crate::webmentions::Webmention> + FromRef<St>, + reqwest_middleware::ClientWithMiddleware: FromRef<St>, + Arc<Mutex<JoinSet<()>>>: FromRef<St>, + crate::SessionStore: FromRef<St>, + axum_extra::extract::cookie::Key: FromRef<St>, + St: Clone + Send + Sync + 'static, { use axum::routing::get; axum::Router::new() .route("/", get(crate::frontend::homepage::<S>)) .fallback(get(crate::frontend::catchall::<S>)) .route("/.kittybox/micropub", crate::micropub::router::<A, S, St>()) - .route("/.kittybox/onboarding", crate::frontend::onboarding::router::<St, S>()) + .route( + "/.kittybox/onboarding", + crate::frontend::onboarding::router::<St, S>(), + ) .nest("/.kittybox/media", crate::media::router::<St, A, M>()) .merge(crate::indieauth::router::<St, A, S>()) .merge(crate::webmentions::router::<St, Q>()) @@ -297,21 +271,38 @@ St: Clone + Send + Sync + 'static .nest("/.kittybox/login", crate::login::router::<St, S>()) .route( "/.kittybox/static/{*path}", - axum::routing::get(crate::frontend::statics) + axum::routing::get(crate::frontend::statics), ) .route("/.kittybox/coffee", get(teapot_route)) - .nest("/.kittybox/micropub/client", crate::companion::router::<St>()) + .nest( + "/.kittybox/micropub/client", + crate::companion::router::<St>(), + ) .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::catch_panic::CatchPanicLayer::new()) - .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ - axum::http::header::AUTHORIZATION, - axum::http::header::COOKIE, - axum::http::header::SET_COOKIE, - ])) + .layer( + tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ + axum::http::header::AUTHORIZATION, + axum::http::header::COOKIE, + axum::http::header::SET_COOKIE, + ]), + ) .layer(tower_http::set_header::SetResponseHeaderLayer::appending( axum::http::header::CONTENT_SECURITY_POLICY, - axum::http::HeaderValue::from_static( - "default-src 'self'; img-src https:; script-src 'self'; style-src 'self'; base-uri 'none'; object-src 'none'" - ) + axum::http::HeaderValue::from_static(concat!( + "default-src 'none';", // Do not allow unknown things we didn't foresee. + "img-src https:;", // Allow hotlinking images from anywhere. + "form-action 'self';", // Only allow sending forms back to us. + "media-src 'self';", // Only allow embedding media from us. + "script-src 'self';", // Only run scripts we serve. + "font-src 'self';", // Only use fonts we serve. + "style-src 'self';", // Only use styles we serve. + "base-uri 'none';", // Do not allow to change the base URI. + "object-src 'none';", // Do not allow to embed objects (Flash/ActiveX). + "connect-src 'self';", // Allow sending data back to us. (WHY IS THIS A THING OMG) + // Allow embedding the Bandcamp player for jam posts. + // TODO: perhaps make this policy customizable?… + "frame-src 'self' https://bandcamp.com/EmbeddedPlayer/;" + )), )) } diff --git a/src/login.rs b/src/login.rs index eaa787c..3038d9c 100644 --- a/src/login.rs +++ b/src/login.rs @@ -1,10 +1,25 @@ use std::{borrow::Cow, str::FromStr}; +use axum::{ + extract::{FromRef, Query, State}, + http::HeaderValue, + response::IntoResponse, + Form, +}; +use axum_extra::{ + extract::{ + cookie::{self, Cookie}, + Host, SignedCookieJar, + }, + headers::HeaderMapExt, + TypedHeader, +}; use futures_util::FutureExt; -use axum::{extract::{FromRef, Query, State}, http::HeaderValue, response::IntoResponse, Form}; -use axum_extra::{extract::{Host, cookie::{self, Cookie}, SignedCookieJar}, headers::HeaderMapExt, TypedHeader}; -use hyper::{header::{CACHE_CONTROL, LOCATION}, StatusCode}; -use kittybox_frontend_renderer::{Template, LoginPage, LogoutPage}; +use hyper::{ + header::{CACHE_CONTROL, LOCATION}, + StatusCode, +}; +use kittybox_frontend_renderer::{LoginPage, LogoutPage, Template}; use kittybox_indieauth::{AuthorizationResponse, Error, GrantType, PKCEVerifier, Scope, Scopes}; use sha2::Digest; @@ -13,14 +28,13 @@ use crate::database::Storage; /// Show a login page. async fn get<S: Storage + Send + Sync + 'static>( State(db): State<S>, - Host(host): Host + Host(host): Host, ) -> impl axum::response::IntoResponse { let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); let (blogname, channels) = tokio::join!( db.get_setting::<crate::database::settings::SiteName>(&hcard_url) - .map(Result::unwrap_or_default), - + .map(Result::unwrap_or_default), db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( @@ -34,14 +48,15 @@ async fn get<S: Storage + Send + Sync + 'static>( blog_name: blogname.as_ref(), feeds: channels, user: None, - content: LoginPage {}.to_string() - }.to_string() + content: LoginPage {}.to_string(), + } + .to_string(), ) } #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] struct LoginForm { - url: url::Url + url: url::Url, } /// Accept login and start the IndieAuth dance. @@ -60,10 +75,12 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); - let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host).parse().unwrap(); + let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host) + .parse() + .unwrap(); let redirect_uri = { let mut uri = client_id.clone(); uri.set_path("/.kittybox/login/finish"); @@ -71,11 +88,15 @@ async fn post( }; let indieauth_state = kittybox_indieauth::AuthorizationRequest { response_type: kittybox_indieauth::ResponseType::Code, - client_id, redirect_uri, + client_id, + redirect_uri, state: kittybox_indieauth::State::new(), - code_challenge: kittybox_indieauth::PKCEChallenge::new(&code_verifier, kittybox_indieauth::PKCEMethod::S256), + code_challenge: kittybox_indieauth::PKCEChallenge::new( + &code_verifier, + kittybox_indieauth::PKCEMethod::S256, + ), scope: Some(Scopes::new(vec![Scope::Profile])), - me: Some(form.url.clone()) + me: Some(form.url.clone()), }; // Fetch the user's homepage, determine their authorization endpoint @@ -89,8 +110,9 @@ async fn post( tracing::error!("Error fetching homepage: {:?}", err); return ( StatusCode::BAD_REQUEST, - format!("couldn't fetch your homepage: {}", err) - ).into_response() + format!("couldn't fetch your homepage: {}", err), + ) + .into_response(); } }; @@ -106,22 +128,27 @@ async fn post( // .collect::<Vec<axum_extra::headers::Link>>(); // // todo!("parse Link: headers") - + let body = match response.text().await { Ok(body) => match microformats::from_html(&body, form.url) { Ok(mf2) => mf2, - Err(err) => return ( - StatusCode::BAD_REQUEST, - format!("error while parsing your homepage with mf2: {}", err) - ).into_response() + Err(err) => { + return ( + StatusCode::BAD_REQUEST, + format!("error while parsing your homepage with mf2: {}", err), + ) + .into_response() + } }, - Err(err) => return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("error while fetching your homepage: {}", err) - ).into_response() + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("error while fetching your homepage: {}", err), + ) + .into_response() + } }; - let mut iss: Option<url::Url> = None; let mut authorization_endpoint = match body .rels @@ -139,10 +166,22 @@ async fn post( Ok(metadata) => { iss = Some(metadata.issuer); metadata.authorization_endpoint - }, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't parse your oauth2 metadata: {}", err)).into_response() + } + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't parse your oauth2 metadata: {}", err), + ) + .into_response() + } }, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't fetch your oauth2 metadata: {}", err)).into_response() + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't fetch your oauth2 metadata: {}", err), + ) + .into_response() + } }, None => match body .rels @@ -151,13 +190,17 @@ async fn post( .map(|v| v.as_slice()) .unwrap_or_default() .first() - .cloned() { - Some(authorization_endpoint) => authorization_endpoint, - None => return ( + .cloned() + { + Some(authorization_endpoint) => authorization_endpoint, + None => { + return ( StatusCode::BAD_REQUEST, - "no authorization endpoint was found on your homepage." - ).into_response() + "no authorization endpoint was found on your homepage.", + ) + .into_response() } + }, }; cookies = cookies.add( @@ -166,7 +209,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); if let Some(iss) = iss { @@ -176,7 +219,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); } @@ -186,7 +229,7 @@ async fn post( .expires(None) .secure(true) .http_only(true) - .build() + .build(), ); authorization_endpoint @@ -194,9 +237,12 @@ async fn post( .extend_pairs(indieauth_state.as_query_pairs().iter()); tracing::debug!("Forwarding user to {}", authorization_endpoint); - (StatusCode::FOUND, [ - ("Location", authorization_endpoint.to_string()), - ], cookies).into_response() + ( + StatusCode::FOUND, + [("Location", authorization_endpoint.to_string())], + cookies, + ) + .into_response() } /// Accept the return of the IndieAuth dance. Set a cookie for the @@ -208,7 +254,9 @@ async fn callback( State(http): State<reqwest_middleware::ClientWithMiddleware>, State(session_store): State<crate::SessionStore>, ) -> axum::response::Response { - let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host).parse().unwrap(); + let client_id: url::Url = format!("https://{}/.kittybox/login/client_metadata", host) + .parse() + .unwrap(); let redirect_uri = { let mut uri = client_id.clone(); uri.set_path("/.kittybox/login/finish"); @@ -218,7 +266,8 @@ async fn callback( let me: url::Url = cookie_jar.get("me").unwrap().value().parse().unwrap(); let code_verifier: PKCEVerifier = cookie_jar.get("code_verifier").unwrap().value().into(); - let authorization_endpoint: url::Url = cookie_jar.get("authorization_endpoint") + let authorization_endpoint: url::Url = cookie_jar + .get("authorization_endpoint") .and_then(|v| v.value().parse().ok()) .unwrap(); match cookie_jar.get("iss").and_then(|c| c.value().parse().ok()) { @@ -232,24 +281,59 @@ async fn callback( code: response.code, client_id, redirect_uri, - code_verifier, + code_verifier, }; - tracing::debug!("POSTing {:?} to authorization endpoint {}", grant_request, authorization_endpoint); - let res = match http.post(authorization_endpoint) + tracing::debug!( + "POSTing {:?} to authorization endpoint {}", + grant_request, + authorization_endpoint + ); + let res = match http + .post(authorization_endpoint) .form(&grant_request) .header(reqwest::header::ACCEPT, "application/json") .send() .await { - Ok(res) if res.status().is_success() => match res.json::<kittybox_indieauth::GrantResponse>().await { - Ok(grant) => grant, - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error parsing authorization endpoint response: {}", err)).into_response() - }, + Ok(res) if res.status().is_success() => { + match res.json::<kittybox_indieauth::GrantResponse>().await { + Ok(grant) => grant, + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error parsing authorization endpoint response: {}", err), + ) + .into_response() + } + } + } Ok(res) => match res.json::<Error>().await { - Ok(err) => return (StatusCode::BAD_REQUEST, [(CACHE_CONTROL, "no-store")], err.to_string()).into_response(), - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error parsing indieauth error: {}", err)).into_response() + Ok(err) => { + return ( + StatusCode::BAD_REQUEST, + [(CACHE_CONTROL, "no-store")], + err.to_string(), + ) + .into_response() + } + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error parsing indieauth error: {}", err), + ) + .into_response() + } + }, + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CACHE_CONTROL, "no-store")], + format!("error redeeming authorization code: {}", err), + ) + .into_response() } - Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, [(CACHE_CONTROL, "no-store")], format!("error redeeming authorization code: {}", err)).into_response() }; let profile = match res { @@ -265,19 +349,28 @@ async fn callback( let uuid = uuid::Uuid::new_v4(); session_store.write().await.insert(uuid, session); let cookies = cookie_jar - .add(Cookie::build(("session_id", uuid.to_string())) - .expires(None) - .secure(true) - .http_only(true) - .path("/") - .build() + .add( + Cookie::build(("session_id", uuid.to_string())) + .expires(None) + .secure(true) + .http_only(true) + .path("/") + .build(), ) .remove("authorization_endpoint") .remove("me") .remove("iss") .remove("code_verifier"); - (StatusCode::FOUND, [(LOCATION, HeaderValue::from_static("/")), (CACHE_CONTROL, HeaderValue::from_static("no-store"))], dbg!(cookies)).into_response() + ( + StatusCode::FOUND, + [ + (LOCATION, HeaderValue::from_static("/")), + (CACHE_CONTROL, HeaderValue::from_static("no-store")), + ], + dbg!(cookies), + ) + .into_response() } /// Show the form necessary for logout. If JS is enabled, @@ -288,32 +381,42 @@ async fn callback( /// stupid enough to execute JS and send a POST request though, that's /// on the crawler. async fn logout_page() -> impl axum::response::IntoResponse { - (StatusCode::OK, [("Content-Type", "text/html")], Template { - title: "Signing out...", - blog_name: "Kittybox", - feeds: vec![], - user: None, - content: LogoutPage {}.to_string() - }.to_string()) + ( + StatusCode::OK, + [("Content-Type", "text/html")], + Template { + title: "Signing out...", + blog_name: "Kittybox", + feeds: vec![], + user: None, + content: LogoutPage {}.to_string(), + } + .to_string(), + ) } /// Erase the necessary cookies for login and invalidate the session. async fn logout( mut cookies: SignedCookieJar, - State(session_store): State<crate::SessionStore> -) -> (StatusCode, [(&'static str, &'static str); 1], SignedCookieJar) { - if let Some(id) = cookies.get("session_id") + State(session_store): State<crate::SessionStore>, +) -> ( + StatusCode, + [(&'static str, &'static str); 1], + SignedCookieJar, +) { + if let Some(id) = cookies + .get("session_id") .and_then(|c| uuid::Uuid::parse_str(c.value_trimmed()).ok()) { session_store.write().await.remove(&id); } - cookies = cookies.remove("me") + cookies = cookies + .remove("me") .remove("iss") .remove("authorization_endpoint") .remove("code_verifier") .remove("session_id"); - (StatusCode::FOUND, [("Location", "/")], cookies) } @@ -343,7 +446,7 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>( }; if let Some(cached) = cached { if cached.precondition_passes(&etag) { - return StatusCode::NOT_MODIFIED.into_response() + return StatusCode::NOT_MODIFIED.into_response(); } } let client_uri: url::Url = format!("https://{}/", host).parse().unwrap(); @@ -356,7 +459,13 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>( let mut metadata = kittybox_indieauth::ClientMetadata::new(client_id, client_uri).unwrap(); - metadata.client_name = Some(storage.get_setting::<crate::database::settings::SiteName>(&metadata.client_uri).await.unwrap_or_default().0); + metadata.client_name = Some( + storage + .get_setting::<crate::database::settings::SiteName>(&metadata.client_uri) + .await + .unwrap_or_default() + .0, + ); metadata.grant_types = Some(vec![GrantType::AuthorizationCode]); // We don't request anything more than the profile scope. metadata.scope = Some(Scopes::new(vec![Scope::Profile])); @@ -368,15 +477,18 @@ async fn client_metadata<S: Storage + Send + Sync + 'static>( // identity providers, or json to match newest spec let mut response = metadata.into_response(); // Indicate to upstream caches this endpoint does different things depending on the Accept: header. - response.headers_mut().append("Vary", HeaderValue::from_static("Accept")); + response + .headers_mut() + .append("Vary", HeaderValue::from_static("Accept")); // Cache this metadata for an hour. - response.headers_mut().append("Cache-Control", HeaderValue::from_static("max-age=600")); + response + .headers_mut() + .append("Cache-Control", HeaderValue::from_static("max-age=600")); response.headers_mut().typed_insert(etag); response } - /// Produce a router for all of the above. pub fn router<St, S>() -> axum::routing::Router<St> where diff --git a/src/main.rs b/src/main.rs index bd3684e..984745a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,11 @@ -use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention, compose_kittybox}; -use tokio::{sync::Mutex, task::JoinSet}; +use kittybox::{ + compose_kittybox, database::Storage, indieauth::backend::AuthBackend, + media::storage::MediaStore, webmentions::Webmention, +}; use std::{env, future::IntoFuture, sync::Arc}; +use tokio::{sync::Mutex, task::JoinSet}; use tracing::error; - #[tokio::main] async fn main() { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry}; @@ -17,32 +19,28 @@ async fn main() { .with_indent_lines(true) .with_verbose_exit(true), #[cfg(not(debug_assertions))] - tracing_subscriber::fmt::layer().json() - .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())) + tracing_subscriber::fmt::layer() + .json() + .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock())), ); // In debug builds, also log to JSON, but to file. #[cfg(debug_assertions)] - let tracing_registry = tracing_registry.with( - tracing_subscriber::fmt::layer() - .json() - .with_writer({ - let instant = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap(); - move || std::fs::OpenOptions::new() + let tracing_registry = + tracing_registry.with(tracing_subscriber::fmt::layer().json().with_writer({ + let instant = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + move || { + std::fs::OpenOptions::new() .append(true) .create(true) - .open( - format!( - "{}.log.json", - instant - .as_secs_f64() - .to_string() - .replace('.', "_") - ) - ).unwrap() - }) - ); + .open(format!( + "{}.log.json", + instant.as_secs_f64().to_string().replace('.', "_") + )) + .unwrap() + } + })); tracing_registry.init(); tracing::info!("Starting the kittybox server..."); @@ -79,12 +77,15 @@ async fn main() { }); // TODO: load from environment - let cookie_key = axum_extra::extract::cookie::Key::from(&env::var("COOKIE_KEY") - .as_deref() - .map(|s| data_encoding::BASE64_MIME_PERMISSIVE.decode(s.as_bytes()) - .expect("Invalid cookie key: must be base64 encoded") - ) - .unwrap() + let cookie_key = axum_extra::extract::cookie::Key::from( + &env::var("COOKIE_KEY") + .as_deref() + .map(|s| { + data_encoding::BASE64_MIME_PERMISSIVE + .decode(s.as_bytes()) + .expect("Invalid cookie key: must be base64 encoded") + }) + .unwrap(), ); let cancellation_token = tokio_util::sync::CancellationToken::new(); @@ -93,12 +94,11 @@ async fn main() { let http: reqwest_middleware::ClientWithMiddleware = { #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); + let mut builder = reqwest::Client::builder().user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { // TODO: add a root certificate if there's an environment variable pointing at it for path in certs.split(':') { @@ -108,21 +108,19 @@ async fn main() { tracing::error!("TLS root certificate {} not found, skipping...", path); continue; } - Err(err) => panic!("Error loading TLS certificates: {}", err) + Err(err) => panic!("Error loading TLS certificates: {}", err), }; if metadata.is_dir() { let mut dir = tokio::fs::read_dir(path).await.unwrap(); while let Ok(Some(file)) = dir.next_entry().await { let pem = tokio::fs::read(file.path()).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); + builder = builder + .add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap()); } } else { let pem = tokio::fs::read(path).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); + builder = + builder.add_root_certificate(reqwest::Certificate::from_pem(&pem).unwrap()); } } } @@ -151,7 +149,7 @@ async fn main() { let job_queue_type = job_queue_uri.scheme(); macro_rules! compose_kittybox { - ($auth:ty, $store:ty, $media:ty, $queue:ty) => { { + ($auth:ty, $store:ty, $media:ty, $queue:ty) => {{ type AuthBackend = $auth; type Storage = $store; type MediaStore = $media; @@ -193,36 +191,43 @@ async fn main() { }; type St = kittybox::AppState<AuthBackend, Storage, MediaStore, JobQueue>; - let stateful_router = compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; - let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>(&state, cancellation_token.clone()); + let stateful_router = + compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; + let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>( + &state, + cancellation_token.clone(), + ); let router = stateful_router.with_state(state); (router, task) - } } + }}; } - let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) { - ("file", "file", "file", "postgres") => { - compose_kittybox!( - kittybox::indieauth::backend::fs::FileBackend, - kittybox::database::FileStorage, - kittybox::media::storage::file::FileStore, - kittybox::webmentions::queue::PostgresJobQueue<Webmention> - ) - }, - ("file", "postgres", "file", "postgres") => { - compose_kittybox!( - kittybox::indieauth::backend::fs::FileBackend, - kittybox::database::PostgresStorage, - kittybox::media::storage::file::FileStore, - kittybox::webmentions::queue::PostgresJobQueue<Webmention> - ) - }, - (_, _, _, _) => { - // TODO: refine this error. - panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI"); - } - }; + let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = + match (authstore_type, backend_type, blobstore_type, job_queue_type) { + ("file", "file", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::FileStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + } + ("file", "postgres", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::PostgresStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + } + (_, _, _, _) => { + // TODO: refine this error. + panic!( + "Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI" + ); + } + }; let mut servers: Vec<axum::serve::Serve<_, _, _>> = vec![]; @@ -238,7 +243,7 @@ async fn main() { // .serve(router.clone().into_make_service()) axum::serve( tokio::net::TcpListener::from_std(tcp).unwrap(), - router.clone() + router.clone(), ) }; @@ -246,8 +251,8 @@ async fn main() { for i in 0..(listenfd.len()) { match listenfd.take_tcp_listener(i) { Ok(Some(tcp)) => servers.push(build_hyper(tcp)), - Ok(None) => {}, - Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err) + Ok(None) => {} + Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err), } } // TODO this requires the `hyperlocal` crate @@ -302,24 +307,35 @@ async fn main() { // to get rid of an extra reference to `jobset` drop(router); // Polling streams mutates them - let mut servers_futures = Box::pin(servers.into_iter() - .map( - #[cfg(not(tokio_unstable))] |server| tokio::task::spawn( - server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) - .into_future() - ), - #[cfg(tokio_unstable)] |server| { - tokio::task::Builder::new() - .name(format!("Kittybox HTTP acceptor: {:?}", server).as_str()) - .spawn( - server.with_graceful_shutdown( - cancellation_token.clone().cancelled_owned() - ).into_future() + let mut servers_futures = Box::pin( + servers + .into_iter() + .map( + #[cfg(not(tokio_unstable))] + |server| { + tokio::task::spawn( + server + .with_graceful_shutdown(cancellation_token.clone().cancelled_owned()) + .into_future(), ) - .unwrap() - } - ) - .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), std::io::Error>>>>() + }, + #[cfg(tokio_unstable)] + |server| { + tokio::task::Builder::new() + .name(format!("Kittybox HTTP acceptor: {:?}", server).as_str()) + .spawn( + server + .with_graceful_shutdown( + cancellation_token.clone().cancelled_owned(), + ) + .into_future(), + ) + .unwrap() + }, + ) + .collect::<futures_util::stream::FuturesUnordered< + tokio::task::JoinHandle<Result<(), std::io::Error>>, + >>(), ); #[cfg(not(unix))] @@ -329,10 +345,10 @@ async fn main() { use tokio::signal::unix::{signal, SignalKind}; async move { - let mut interrupt = signal(SignalKind::interrupt()) - .expect("Failed to set up SIGINT handler"); - let mut terminate = signal(SignalKind::terminate()) - .expect("Failed to setup SIGTERM handler"); + let mut interrupt = + signal(SignalKind::interrupt()).expect("Failed to set up SIGINT handler"); + let mut terminate = + signal(SignalKind::terminate()).expect("Failed to setup SIGTERM handler"); tokio::select! { _ = terminate.recv() => {}, diff --git a/src/media/mod.rs b/src/media/mod.rs index 6f263b6..7e52414 100644 --- a/src/media/mod.rs +++ b/src/media/mod.rs @@ -1,22 +1,23 @@ +use crate::indieauth::{backend::AuthBackend, User}; use axum::{ - extract::{multipart::Multipart, FromRef, Path, State}, response::{IntoResponse, Response} + extract::{multipart::Multipart, FromRef, Path, State}, + response::{IntoResponse, Response}, }; -use axum_extra::headers::{ContentLength, HeaderMapExt, HeaderValue, IfNoneMatch}; use axum_extra::extract::Host; +use axum_extra::headers::{ContentLength, HeaderMapExt, HeaderValue, IfNoneMatch}; use axum_extra::TypedHeader; -use kittybox_util::micropub::{Error as MicropubError, ErrorKind as ErrorType}; use kittybox_indieauth::Scope; -use crate::indieauth::{backend::AuthBackend, User}; +use kittybox_util::micropub::{Error as MicropubError, ErrorKind as ErrorType}; pub mod storage; -use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; pub use storage::file::FileStore; +use storage::{ErrorKind, MediaStore, MediaStoreError, Metadata}; impl From<MediaStoreError> for MicropubError { fn from(err: MediaStoreError) -> Self { Self::new( ErrorType::InternalServerError, - format!("media store error: {}", err) + format!("media store error: {}", err), ) } } @@ -25,13 +26,14 @@ impl From<MediaStoreError> for MicropubError { pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( State(blobstore): State<S>, user: User<A>, - mut upload: Multipart + mut upload: Multipart, ) -> Response { if !user.check_scope(&Scope::Media) { return MicropubError::from_static( ErrorType::NotAuthorized, - "Interacting with the media storage requires the \"media\" scope." - ).into_response(); + "Interacting with the media storage requires the \"media\" scope.", + ) + .into_response(); } let host = user.me.authority(); let field = match upload.next_field().await { @@ -39,27 +41,31 @@ pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( Ok(None) => { return MicropubError::from_static( ErrorType::InvalidRequest, - "Send multipart/form-data with one field named file" - ).into_response(); - }, + "Send multipart/form-data with one field named file", + ) + .into_response(); + } Err(err) => { return MicropubError::new( ErrorType::InternalServerError, - format!("Error while parsing multipart/form-data: {}", err) - ).into_response(); - }, + format!("Error while parsing multipart/form-data: {}", err), + ) + .into_response(); + } }; let metadata: Metadata = (&field).into(); match blobstore.write_streaming(host, metadata, field).await { Ok(filename) => IntoResponse::into_response(( axum::http::StatusCode::CREATED, - [ - ("Location", user.me.join( - &format!(".kittybox/media/uploads/{}", filename) - ).unwrap().as_str()) - ] + [( + "Location", + user.me + .join(&format!(".kittybox/media/uploads/{}", filename)) + .unwrap() + .as_str(), + )], )), - Err(err) => MicropubError::from(err).into_response() + Err(err) => MicropubError::from(err).into_response(), } } @@ -68,7 +74,7 @@ pub(crate) async fn serve<S: MediaStore>( Host(host): Host, Path(path): Path<String>, if_none_match: Option<TypedHeader<IfNoneMatch>>, - State(blobstore): State<S> + State(blobstore): State<S>, ) -> Response { use axum::http::StatusCode; tracing::debug!("Searching for file..."); @@ -77,7 +83,9 @@ pub(crate) async fn serve<S: MediaStore>( tracing::debug!("Metadata: {:?}", metadata); let etag = if let Some(etag) = metadata.etag { - let etag = format!("\"{}\"", etag).parse::<axum_extra::headers::ETag>().unwrap(); + let etag = format!("\"{}\"", etag) + .parse::<axum_extra::headers::ETag>() + .unwrap(); if let Some(TypedHeader(if_none_match)) = if_none_match { tracing::debug!("If-None-Match: {:?}", if_none_match); @@ -85,12 +93,14 @@ pub(crate) async fn serve<S: MediaStore>( // returns 304 when it doesn't match because it // only matches when file is different if !if_none_match.precondition_passes(&etag) { - return StatusCode::NOT_MODIFIED.into_response() + return StatusCode::NOT_MODIFIED.into_response(); } } Some(etag) - } else { None }; + } else { + None + }; let mut r = Response::builder(); { @@ -98,14 +108,16 @@ pub(crate) async fn serve<S: MediaStore>( headers.insert( "Content-Type", HeaderValue::from_str( - metadata.content_type + metadata + .content_type .as_deref() - .unwrap_or("application/octet-stream") - ).unwrap() + .unwrap_or("application/octet-stream"), + ) + .unwrap(), ); headers.insert( axum::http::header::X_CONTENT_TYPE_OPTIONS, - axum::http::HeaderValue::from_static("nosniff") + axum::http::HeaderValue::from_static("nosniff"), ); if let Some(length) = metadata.length { headers.typed_insert(ContentLength(length.get().try_into().unwrap())); @@ -117,23 +129,22 @@ pub(crate) async fn serve<S: MediaStore>( r.body(axum::body::Body::from_stream(stream)) .unwrap() .into_response() - }, + } Err(err) => match err.kind() { - ErrorKind::NotFound => { - IntoResponse::into_response(StatusCode::NOT_FOUND) - }, + ErrorKind::NotFound => IntoResponse::into_response(StatusCode::NOT_FOUND), _ => { tracing::error!("{}", err); IntoResponse::into_response(StatusCode::INTERNAL_SERVER_ERROR) } - } + }, } } -pub fn router<St, A, M>() -> axum::Router<St> where +pub fn router<St, A, M>() -> axum::Router<St> +where A: AuthBackend + FromRef<St>, M: MediaStore + FromRef<St>, - St: Clone + Send + Sync + 'static + St: Clone + Send + Sync + 'static, { axum::Router::new() .route("/", axum::routing::post(upload::<M, A>)) diff --git a/src/media/storage/file.rs b/src/media/storage/file.rs index e432945..5198a4c 100644 --- a/src/media/storage/file.rs +++ b/src/media/storage/file.rs @@ -1,12 +1,12 @@ -use super::{Metadata, ErrorKind, MediaStore, MediaStoreError, Result}; -use std::{path::PathBuf, fmt::Debug}; -use tokio::fs::OpenOptions; -use tokio::io::{BufReader, BufWriter, AsyncWriteExt, AsyncSeekExt}; +use super::{ErrorKind, MediaStore, MediaStoreError, Metadata, Result}; +use futures::FutureExt; use futures::{StreamExt, TryStreamExt}; +use sha2::Digest; use std::ops::{Bound, Neg}; use std::pin::Pin; -use sha2::Digest; -use futures::FutureExt; +use std::{fmt::Debug, path::PathBuf}; +use tokio::fs::OpenOptions; +use tokio::io::{AsyncSeekExt, AsyncWriteExt, BufReader, BufWriter}; use tracing::{debug, error}; const BUF_CAPACITY: usize = 16 * 1024; @@ -22,7 +22,7 @@ impl From<tokio::io::Error> for MediaStoreError { msg: format!("file I/O error: {}", source), kind: match source.kind() { std::io::ErrorKind::NotFound => ErrorKind::NotFound, - _ => ErrorKind::Backend + _ => ErrorKind::Backend, }, source: Some(Box::new(source)), } @@ -40,7 +40,9 @@ impl FileStore { impl MediaStore for FileStore { async fn new(url: &'_ url::Url) -> Result<Self> { - Ok(Self { base: url.path().into() }) + Ok(Self { + base: url.path().into(), + }) } #[tracing::instrument(skip(self, content))] @@ -51,10 +53,17 @@ impl MediaStore for FileStore { mut content: T, ) -> Result<String> where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug + T: tokio_stream::Stream< + Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>, + > + Unpin + + Send + + Debug, { let (tempfilepath, mut tempfile) = self.mktemp().await?; - debug!("Temporary file opened for storing pending upload: {}", tempfilepath.display()); + debug!( + "Temporary file opened for storing pending upload: {}", + tempfilepath.display() + ); let mut hasher = sha2::Sha256::new(); let mut length: usize = 0; @@ -62,7 +71,7 @@ impl MediaStore for FileStore { let chunk = chunk.map_err(|err| MediaStoreError { kind: ErrorKind::Backend, source: Some(Box::new(err)), - msg: "Failed to read a data chunk".to_owned() + msg: "Failed to read a data chunk".to_owned(), })?; debug!("Read {} bytes from the stream", chunk.len()); length += chunk.len(); @@ -70,9 +79,7 @@ impl MediaStore for FileStore { { let chunk = chunk.clone(); let tempfile = &mut tempfile; - async move { - tempfile.write_all(&chunk).await - } + async move { tempfile.write_all(&chunk).await } }, { let chunk = chunk.clone(); @@ -80,7 +87,8 @@ impl MediaStore for FileStore { hasher.update(&*chunk); hasher - }).map(|r| r.unwrap()) + }) + .map(|r| r.unwrap()) } ); if let Err(err) = write_result { @@ -90,7 +98,9 @@ impl MediaStore for FileStore { // though temporary files might take up space on the hard drive // We'll clean them when maintenance time comes #[allow(unused_must_use)] - { tokio::fs::remove_file(tempfilepath).await; } + { + tokio::fs::remove_file(tempfilepath).await; + } return Err(err.into()); } hasher = _hasher; @@ -113,10 +123,17 @@ impl MediaStore for FileStore { let filepath = self.base.join(domain_str.as_str()).join(&filename); let metafilename = filename.clone() + ".json"; let metapath = self.base.join(domain_str.as_str()).join(&metafilename); - let metatemppath = self.base.join(domain_str.as_str()).join(metafilename + ".tmp"); + let metatemppath = self + .base + .join(domain_str.as_str()) + .join(metafilename + ".tmp"); metadata.length = std::num::NonZeroUsize::new(length); metadata.etag = Some(hash); - debug!("File path: {}, metadata: {}", filepath.display(), metapath.display()); + debug!( + "File path: {}, metadata: {}", + filepath.display(), + metapath.display() + ); { let parent = filepath.parent().unwrap(); tokio::fs::create_dir_all(parent).await?; @@ -126,39 +143,44 @@ impl MediaStore for FileStore { .write(true) .open(&metatemppath) .await?; - meta.write_all(&serde_json::to_vec(&metadata).unwrap()).await?; + meta.write_all(&serde_json::to_vec(&metadata).unwrap()) + .await?; tokio::fs::rename(tempfilepath, filepath).await?; tokio::fs::rename(metatemppath, metapath).await?; Ok(filename) } #[tracing::instrument(skip(self))] + #[allow(clippy::type_complexity)] async fn read_streaming( &self, domain: &str, filename: &str, - ) -> Result<(Metadata, Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>)> { + ) -> Result<( + Metadata, + Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>, + )> { debug!("Domain: {}, filename: {}", domain, filename); let path = self.base.join(domain).join(filename); debug!("Path: {}", path.display()); - let file = OpenOptions::new() - .read(true) - .open(path) - .await?; + let file = OpenOptions::new().read(true).open(path).await?; let meta = self.metadata(domain, filename).await?; - Ok((meta, Box::pin( - tokio_util::io::ReaderStream::new( - // TODO: determine if BufReader provides benefit here - // From the logs it looks like we're reading 4KiB at a time - // Buffering file contents seems to double download speed - // How to benchmark this? - BufReader::with_capacity(BUF_CAPACITY, file) - ) - // Sprinkle some salt in form of protective log wrapping - .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())) - ))) + Ok(( + meta, + Box::pin( + tokio_util::io::ReaderStream::new( + // TODO: determine if BufReader provides benefit here + // From the logs it looks like we're reading 4KiB at a time + // Buffering file contents seems to double download speed + // How to benchmark this? + BufReader::with_capacity(BUF_CAPACITY, file), + ) + // Sprinkle some salt in form of protective log wrapping + .inspect_ok(|chunk| debug!("Read {} bytes from file", chunk.len())), + ), + )) } #[tracing::instrument(skip(self))] @@ -166,12 +188,13 @@ impl MediaStore for FileStore { let metapath = self.base.join(domain).join(format!("{}.json", filename)); debug!("Metadata path: {}", metapath.display()); - let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?) - .map_err(|err| MediaStoreError { + let meta = serde_json::from_slice(&tokio::fs::read(metapath).await?).map_err(|err| { + MediaStoreError { kind: ErrorKind::Json, msg: format!("{}", err), - source: Some(Box::new(err)) - })?; + source: Some(Box::new(err)), + } + })?; Ok(meta) } @@ -181,16 +204,14 @@ impl MediaStore for FileStore { &self, domain: &str, filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> { + range: (Bound<u64>, Bound<u64>), + ) -> Result<Pin<Box<dyn tokio_stream::Stream<Item = std::io::Result<bytes::Bytes>> + Send>>> + { let path = self.base.join(format!("{}/{}", domain, filename)); let metapath = self.base.join(format!("{}/{}.json", domain, filename)); debug!("Path: {}, metadata: {}", path.display(), metapath.display()); - let mut file = OpenOptions::new() - .read(true) - .open(path) - .await?; + let mut file = OpenOptions::new().read(true).open(path).await?; let start = match range { (Bound::Included(bound), _) => { @@ -201,45 +222,52 @@ impl MediaStore for FileStore { (Bound::Unbounded, Bound::Included(bound)) => { // Seek to the end minus the bounded bytes debug!("Seeking {} bytes back from the end...", bound); - file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())).await? - }, + file.seek(std::io::SeekFrom::End(i64::try_from(bound).unwrap().neg())) + .await? + } (Bound::Unbounded, Bound::Unbounded) => 0, - (_, Bound::Excluded(_)) => unreachable!() + (_, Bound::Excluded(_)) => unreachable!(), }; - let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity(BUF_CAPACITY, file))) - .map_ok({ - let mut bytes_read = 0usize; - let len = match range { - (_, Bound::Unbounded) => None, - (Bound::Unbounded, Bound::Included(bound)) => Some(bound), - (_, Bound::Included(bound)) => Some(bound + 1 - start), - (_, Bound::Excluded(_)) => unreachable!() - }; - move |chunk| { - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - if let Some(len) = len.map(|len| len.try_into().unwrap()) { - if bytes_read > len { - if bytes_read - len > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - len); - return Some(chunk.slice(..chunk.len() - (bytes_read - len))) + let stream = Box::pin(tokio_util::io::ReaderStream::new(BufReader::with_capacity( + BUF_CAPACITY, + file, + ))) + .map_ok({ + let mut bytes_read = 0usize; + let len = match range { + (_, Bound::Unbounded) => None, + (Bound::Unbounded, Bound::Included(bound)) => Some(bound), + (_, Bound::Included(bound)) => Some(bound + 1 - start), + (_, Bound::Excluded(_)) => unreachable!(), + }; + move |chunk| { + debug!( + "Read {} bytes from file, {} in this chunk", + bytes_read, + chunk.len() + ); + bytes_read += chunk.len(); + if let Some(len) = len.map(|len| len.try_into().unwrap()) { + if bytes_read > len { + if bytes_read - len > chunk.len() { + return None; } + debug!("Truncating last {} bytes", bytes_read - len); + return Some(chunk.slice(..chunk.len() - (bytes_read - len))); } - - Some(chunk) } - }) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - // Will never panic, because the moment the stream yields - // a None, it is considered exhausted. - .map_ok(|x| x.unwrap()); - return Ok(Box::pin(stream)) - } + Some(chunk) + } + }) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + // Will never panic, because the moment the stream yields + // a None, it is considered exhausted. + .map_ok(|x| x.unwrap()); + return Ok(Box::pin(stream)); + } async fn delete(&self, domain: &str, filename: &str) -> Result<()> { let path = self.base.join(format!("{}/{}", domain, filename)); @@ -250,7 +278,7 @@ impl MediaStore for FileStore { #[cfg(test)] mod tests { - use super::{Metadata, FileStore, MediaStore}; + use super::{FileStore, MediaStore, Metadata}; use std::ops::Bound; use tokio::io::AsyncReadExt; @@ -258,10 +286,15 @@ mod tests { #[tracing_test::traced_test] async fn test_ranges() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore { base: tempdir.path().to_path_buf() }; + let store = FileStore { + base: tempdir.path().to_path_buf(), + }; let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let stream = tokio_stream::iter( + file.chunks(100) + .map(|i| Ok(bytes::Bytes::copy_from_slice(i))), + ); let metadata = Metadata { filename: Some("file.rs".to_string()), content_type: Some("text/plain".to_string()), @@ -270,28 +303,30 @@ mod tests { }; // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); + let filename = store + .write_streaming("fireburn.ru", metadata, stream) + .await + .unwrap(); tracing::debug!("Writing complete."); // Ensure the file is there - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); + let content = tokio::fs::read(tempdir.path().join("fireburn.ru").join(&filename)) + .await + .unwrap(); assert_eq!(content, file); tracing::debug!("Reading range from the start..."); // try to read range let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(0), Bound::Included(299)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + (Bound::Included(0), Bound::Included(299)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -307,10 +342,14 @@ mod tests { tracing::debug!("Reading range from the middle..."); let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - (Bound::Included(150), Bound::Included(449)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + (Bound::Included(150), Bound::Included(449)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -325,13 +364,17 @@ mod tests { tracing::debug!("Reading range from the end..."); let range = { - let stream = store.stream_range( - "fireburn.ru", &filename, - // Note: the `headers` crate parses bounds in a - // non-standard way, where unbounded start actually - // means getting things from the end... - (Bound::Unbounded, Bound::Included(300)) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &filename, + // Note: the `headers` crate parses bounds in a + // non-standard way, where unbounded start actually + // means getting things from the end... + (Bound::Unbounded, Bound::Included(300)), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -342,15 +385,19 @@ mod tests { }; assert_eq!(range.len(), 300); - assert_eq!(range.as_slice(), &file[file.len()-300..file.len()]); + assert_eq!(range.as_slice(), &file[file.len() - 300..file.len()]); tracing::debug!("Reading the whole file..."); // try to read range let range = { - let stream = store.stream_range( - "fireburn.ru", &("/".to_string() + &filename), - (Bound::Unbounded, Bound::Unbounded) - ).await.unwrap(); + let stream = store + .stream_range( + "fireburn.ru", + &("/".to_string() + &filename), + (Bound::Unbounded, Bound::Unbounded), + ) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); @@ -364,15 +411,19 @@ mod tests { assert_eq!(range.as_slice(), file); } - #[tokio::test] #[tracing_test::traced_test] async fn test_streaming_read_write() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore { base: tempdir.path().to_path_buf() }; + let store = FileStore { + base: tempdir.path().to_path_buf(), + }; let file: &[u8] = include_bytes!("./file.rs"); - let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); + let stream = tokio_stream::iter( + file.chunks(100) + .map(|i| Ok(bytes::Bytes::copy_from_slice(i))), + ); let metadata = Metadata { filename: Some("style.css".to_string()), content_type: Some("text/css".to_string()), @@ -381,27 +432,32 @@ mod tests { }; // write through the interface - let filename = store.write_streaming( - "fireburn.ru", - metadata, stream - ).await.unwrap(); - println!("{}, {}", filename, tempdir.path() - .join("fireburn.ru") - .join(&filename) - .display()); - let content = tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(&filename) - ).await.unwrap(); + let filename = store + .write_streaming("fireburn.ru", metadata, stream) + .await + .unwrap(); + println!( + "{}, {}", + filename, + tempdir.path().join("fireburn.ru").join(&filename).display() + ); + let content = tokio::fs::read(tempdir.path().join("fireburn.ru").join(&filename)) + .await + .unwrap(); assert_eq!(content, file); // check internal metadata format - let meta: Metadata = serde_json::from_slice(&tokio::fs::read( - tempdir.path() - .join("fireburn.ru") - .join(filename.clone() + ".json") - ).await.unwrap()).unwrap(); + let meta: Metadata = serde_json::from_slice( + &tokio::fs::read( + tempdir + .path() + .join("fireburn.ru") + .join(filename.clone() + ".json"), + ) + .await + .unwrap(), + ) + .unwrap(); assert_eq!(meta.content_type.as_deref(), Some("text/css")); assert_eq!(meta.filename.as_deref(), Some("style.css")); assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); @@ -409,10 +465,10 @@ mod tests { // read back the data using the interface let (metadata, read_back) = { - let (metadata, stream) = store.read_streaming( - "fireburn.ru", - &filename - ).await.unwrap(); + let (metadata, stream) = store + .read_streaming("fireburn.ru", &filename) + .await + .unwrap(); let mut reader = tokio_util::io::StreamReader::new(stream); let mut buf = Vec::default(); @@ -426,6 +482,5 @@ mod tests { assert_eq!(meta.filename.as_deref(), Some("style.css")); assert_eq!(meta.length.map(|i| i.get()), Some(file.len())); assert!(meta.etag.is_some()); - } } diff --git a/src/media/storage/mod.rs b/src/media/storage/mod.rs index 551b61e..5658071 100644 --- a/src/media/storage/mod.rs +++ b/src/media/storage/mod.rs @@ -1,12 +1,12 @@ use axum::extract::multipart::Field; -use tokio_stream::Stream; use bytes::Bytes; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use std::future::Future; +use std::num::NonZeroUsize; use std::ops::Bound; use std::pin::Pin; -use std::fmt::Debug; -use std::num::NonZeroUsize; +use tokio_stream::Stream; pub mod file; @@ -24,17 +24,14 @@ pub struct Metadata { impl From<&Field<'_>> for Metadata { fn from(field: &Field<'_>) -> Self { Self { - content_type: field.content_type() - .map(|i| i.to_owned()), - filename: field.file_name() - .map(|i| i.to_owned()), + content_type: field.content_type().map(|i| i.to_owned()), + filename: field.file_name().map(|i| i.to_owned()), length: None, etag: None, } } } - #[derive(Debug, Clone, Copy)] pub enum ErrorKind { Backend, @@ -95,87 +92,116 @@ pub trait MediaStore: 'static + Send + Sync + Clone { content: T, ) -> impl Future<Output = Result<String>> + Send where - T: tokio_stream::Stream<Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>> + Unpin + Send + Debug; + T: tokio_stream::Stream< + Item = std::result::Result<bytes::Bytes, axum::extract::multipart::MultipartError>, + > + Unpin + + Send + + Debug; + #[allow(clippy::type_complexity)] fn read_streaming( &self, domain: &str, filename: &str, - ) -> impl Future<Output = Result< - (Metadata, Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>) - >> + Send; + ) -> impl Future< + Output = Result<( + Metadata, + Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>, + )>, + > + Send; fn stream_range( &self, domain: &str, filename: &str, - range: (Bound<u64>, Bound<u64>) - ) -> impl Future<Output = Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>>> + Send { async move { - use futures::stream::TryStreamExt; - use tracing::debug; - let (metadata, mut stream) = self.read_streaming(domain, filename).await?; - let length = metadata.length.unwrap().get(); - - use Bound::*; - let (start, end): (usize, usize) = match range { - (Unbounded, Unbounded) => return Ok(stream), - (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), - (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), - (Included(start), Included(end)) => (start.try_into().unwrap(), end.try_into().unwrap()), - (_, _) => unreachable!() - }; - - stream = Box::pin( - stream.map_ok({ - let mut bytes_skipped = 0usize; - let mut bytes_read = 0usize; - - move |chunk| { - debug!("Skipped {}/{} bytes, chunk len {}", bytes_skipped, start, chunk.len()); - let chunk = if bytes_skipped < start { - let need_to_skip = start - bytes_skipped; - if chunk.len() < need_to_skip { - return None - } - debug!("Skipping {} bytes", need_to_skip); - bytes_skipped += need_to_skip; - - chunk.slice(need_to_skip..) - } else { - chunk - }; - - debug!("Read {} bytes from file, {} in this chunk", bytes_read, chunk.len()); - bytes_read += chunk.len(); - - if bytes_read > length { - if bytes_read - length > chunk.len() { - return None - } - debug!("Truncating last {} bytes", bytes_read - length); - return Some(chunk.slice(..chunk.len() - (bytes_read - length))) - } - - Some(chunk) + range: (Bound<u64>, Bound<u64>), + ) -> impl Future<Output = Result<Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>>> + Send + { + async move { + use futures::stream::TryStreamExt; + use tracing::debug; + let (metadata, mut stream) = self.read_streaming(domain, filename).await?; + let length = metadata.length.unwrap().get(); + + use Bound::*; + let (start, end): (usize, usize) = match range { + (Unbounded, Unbounded) => return Ok(stream), + (Included(start), Unbounded) => (start.try_into().unwrap(), length - 1), + (Unbounded, Included(end)) => (length - usize::try_from(end).unwrap(), length - 1), + (Included(start), Included(end)) => { + (start.try_into().unwrap(), end.try_into().unwrap()) } - }) - .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) - .try_take_while(|x| std::future::ready(Ok(x.is_some()))) - .map_ok(|x| x.unwrap()) - ); + (_, _) => unreachable!(), + }; + + stream = Box::pin( + stream + .map_ok({ + let mut bytes_skipped = 0usize; + let mut bytes_read = 0usize; + + move |chunk| { + debug!( + "Skipped {}/{} bytes, chunk len {}", + bytes_skipped, + start, + chunk.len() + ); + let chunk = if bytes_skipped < start { + let need_to_skip = start - bytes_skipped; + if chunk.len() < need_to_skip { + return None; + } + debug!("Skipping {} bytes", need_to_skip); + bytes_skipped += need_to_skip; + + chunk.slice(need_to_skip..) + } else { + chunk + }; + + debug!( + "Read {} bytes from file, {} in this chunk", + bytes_read, + chunk.len() + ); + bytes_read += chunk.len(); + + if bytes_read > length { + if bytes_read - length > chunk.len() { + return None; + } + debug!("Truncating last {} bytes", bytes_read - length); + return Some(chunk.slice(..chunk.len() - (bytes_read - length))); + } + + Some(chunk) + } + }) + .try_skip_while(|x| std::future::ready(Ok(x.is_none()))) + .try_take_while(|x| std::future::ready(Ok(x.is_some()))) + .map_ok(|x| x.unwrap()), + ); - Ok(stream) - } } + Ok(stream) + } + } /// Read metadata for a file. /// /// The default implementation uses the `read_streaming` method /// and drops the stream containing file content. - fn metadata(&self, domain: &str, filename: &str) -> impl Future<Output = Result<Metadata>> + Send { async move { - self.read_streaming(domain, filename) - .await - .map(|(meta, _)| meta) - } } + fn metadata( + &self, + domain: &str, + filename: &str, + ) -> impl Future<Output = Result<Metadata>> + Send { + async move { + self.read_streaming(domain, filename) + .await + .map(|(meta, _)| meta) + } + } fn delete(&self, domain: &str, filename: &str) -> impl Future<Output = Result<()>> + Send; } diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 719fbf0..5e11033 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -1,25 +1,26 @@ use std::collections::HashMap; -use url::Url; use std::sync::Arc; +use url::Url; +use util::NormalizedPost; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; use crate::indieauth::User; use crate::micropub::util::form_to_mf2_json; -use axum::extract::{FromRef, Query, State}; use axum::body::Body as BodyStream; +use axum::extract::{FromRef, Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; use axum_extra::extract::Host; use axum_extra::headers::ContentType; -use axum::response::{IntoResponse, Response}; use axum_extra::TypedHeader; -use axum::http::StatusCode; +use kittybox_indieauth::{Scope, TokenData}; +use kittybox_util::micropub::{Error as MicropubError, ErrorKind, QueryType}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::Mutex; use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; -use kittybox_indieauth::{Scope, TokenData}; -use kittybox_util::micropub::{Error as MicropubError, ErrorKind, QueryType}; #[derive(Serialize, Deserialize, Debug)] pub struct MicropubQuery { @@ -34,12 +35,12 @@ impl From<StorageError> for MicropubError { crate::database::ErrorKind::NotFound => ErrorKind::NotFound, _ => ErrorKind::InternalServerError, }, - format!("backend error: {}", err) + format!("backend error: {}", err), ) } } -mod util; +pub(crate) mod util; pub(crate) use util::normalize_mf2; #[derive(Debug)] @@ -58,7 +59,8 @@ fn populate_reply_context( array .iter() .map(|i| { - let mut item = i.as_str() + let mut item = i + .as_str() .and_then(|i| i.parse::<Url>().ok()) .and_then(|url| ctxs.get(&url)) .and_then(|ctx| ctx.mf2["items"].get(0)) @@ -68,7 +70,12 @@ fn populate_reply_context( if item.is_object() && (i != &item) { if let Some(props) = item["properties"].as_object_mut() { // Fixup the item: if it lacks a URL, add one. - if !props.get("url").and_then(serde_json::Value::as_array).map(|a| !a.is_empty()).unwrap_or(false) { + if !props + .get("url") + .and_then(serde_json::Value::as_array) + .map(|a| !a.is_empty()) + .unwrap_or(false) + { props.insert("url".to_owned(), json!([i.as_str()])); } } @@ -144,11 +151,14 @@ async fn background_processing<D: 'static + Storage>( .get("webmention") .and_then(|i| i.first().cloned()); - dbg!(Some((url.clone(), FetchedPostContext { - url, - mf2: serde_json::to_value(mf2).unwrap(), - webmention - }))) + dbg!(Some(( + url.clone(), + FetchedPostContext { + url, + mf2: serde_json::to_value(mf2).unwrap(), + webmention + } + ))) }) .collect::<HashMap<Url, FetchedPostContext>>() .await @@ -160,7 +170,11 @@ async fn background_processing<D: 'static + Storage>( }; for prop in context_props { if let Some(json) = populate_reply_context(&mf2, prop, &post_contexts) { - update.replace.as_mut().unwrap().insert(prop.to_owned(), json); + update + .replace + .as_mut() + .unwrap() + .insert(prop.to_owned(), json); } } if !update.replace.as_ref().unwrap().is_empty() { @@ -249,7 +263,7 @@ pub(crate) async fn _post<D: 'static + Storage>( if !user.check_scope(&Scope::Create) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "Not enough privileges - try acquiring the \"create\" scope." + "Not enough privileges - try acquiring the \"create\" scope.", )); } @@ -263,7 +277,7 @@ pub(crate) async fn _post<D: 'static + Storage>( { return Err(MicropubError::from_static( ErrorKind::Forbidden, - "You're posting to a website that's not yours." + "You're posting to a website that's not yours.", )); } @@ -271,7 +285,7 @@ pub(crate) async fn _post<D: 'static + Storage>( if db.post_exists(&uid).await? { return Err(MicropubError::from_static( ErrorKind::AlreadyExists, - "UID clash was detected, operation aborted." + "UID clash was detected, operation aborted.", )); } // Save the post @@ -308,13 +322,18 @@ pub(crate) async fn _post<D: 'static + Storage>( } } - let reply = - IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); + let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); #[cfg(not(tokio_unstable))] - let _ = jobset.lock().await.spawn(background_processing(db, mf2, http)); + let _ = jobset + .lock() + .await + .spawn(background_processing(db, mf2, http)); #[cfg(tokio_unstable)] - let _ = jobset.lock().await.build_task() + let _ = jobset + .lock() + .await + .build_task() .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) .spawn(background_processing(db, mf2, http)); @@ -332,7 +351,7 @@ enum ActionType { #[serde(untagged)] pub enum MicropubPropertyDeletion { Properties(Vec<String>), - Values(HashMap<String, Vec<serde_json::Value>>) + Values(HashMap<String, Vec<serde_json::Value>>), } #[derive(Serialize, Deserialize)] struct MicropubFormAction { @@ -346,7 +365,7 @@ pub struct MicropubAction { url: String, #[serde(flatten)] #[serde(skip_serializing_if = "Option::is_none")] - update: Option<MicropubUpdate> + update: Option<MicropubUpdate>, } #[derive(Serialize, Deserialize, Debug, Default)] @@ -361,39 +380,43 @@ pub struct MicropubUpdate { impl MicropubUpdate { pub fn check_validity(&self) -> Result<(), MicropubError> { if let Some(add) = &self.add { - if add.iter().map(|(k, _)| k.as_str()).any(|k| { - k.to_lowercase().as_str() == "uid" - }) { + if add + .iter() + .map(|(k, _)| k.as_str()) + .any(|k| k.to_lowercase().as_str() == "uid") + { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } if let Some(replace) = &self.replace { - if replace.iter().map(|(k, _)| k.as_str()).any(|k| { - k.to_lowercase().as_str() == "uid" - }) { + if replace + .iter() + .map(|(k, _)| k.as_str()) + .any(|k| k.to_lowercase().as_str() == "uid") + { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } let iter = match &self.delete { Some(MicropubPropertyDeletion::Properties(keys)) => { Some(Box::new(keys.iter().map(|k| k.as_str())) as Box<dyn Iterator<Item = &str>>) - }, + } Some(MicropubPropertyDeletion::Values(map)) => { Some(Box::new(map.iter().map(|(k, _)| k.as_str())) as Box<dyn Iterator<Item = &str>>) - }, + } None => None, }; if let Some(mut iter) = iter { if iter.any(|k| k.to_lowercase().as_str() == "uid") { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update cannot modify the post UID" + "Update cannot modify the post UID", )); } } @@ -411,8 +434,9 @@ impl MicropubUpdate { } else if let Some(MicropubPropertyDeletion::Values(ref delete)) = self.delete { if let Some(props) = post["properties"].as_object_mut() { for (key, values) in delete { - if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) { - prop.retain(|v| { values.iter().all(|i| i != v) }) + if let Some(prop) = props.get_mut(key).and_then(serde_json::Value::as_array_mut) + { + prop.retain(|v| values.iter().all(|i| i != v)) } } } @@ -427,7 +451,10 @@ impl MicropubUpdate { if let Some(add) = self.add { if let Some(props) = post["properties"].as_object_mut() { for (key, value) in add { - if let Some(prop) = props.get_mut(&key).and_then(serde_json::Value::as_array_mut) { + if let Some(prop) = props + .get_mut(&key) + .and_then(serde_json::Value::as_array_mut) + { prop.extend_from_slice(value.as_slice()); } else { props.insert(key, serde_json::Value::Array(value)); @@ -444,7 +471,7 @@ impl From<MicropubFormAction> for MicropubAction { Self { action: a.action, url: a.url, - update: None + update: None, } } } @@ -457,10 +484,12 @@ async fn post_action<D: Storage, A: AuthBackend>( ) -> Result<(), MicropubError> { let uri = match action.url.parse::<hyper::Uri>() { Ok(uri) => uri, - Err(err) => return Err(MicropubError::new( - ErrorKind::InvalidRequest, - format!("url parsing error: {}", err) - )) + Err(err) => { + return Err(MicropubError::new( + ErrorKind::InvalidRequest, + format!("url parsing error: {}", err), + )) + } }; if uri.authority().unwrap() @@ -474,7 +503,7 @@ async fn post_action<D: Storage, A: AuthBackend>( { return Err(MicropubError::from_static( ErrorKind::Forbidden, - "Don't tamper with others' posts!" + "Don't tamper with others' posts!", )); } @@ -483,7 +512,7 @@ async fn post_action<D: Storage, A: AuthBackend>( if !user.check_scope(&Scope::Delete) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "You need a \"delete\" scope for this." + "You need a \"delete\" scope for this.", )); } @@ -493,7 +522,7 @@ async fn post_action<D: Storage, A: AuthBackend>( if !user.check_scope(&Scope::Update) { return Err(MicropubError::from_static( ErrorKind::InvalidScope, - "You need an \"update\" scope for this." + "You need an \"update\" scope for this.", )); } @@ -502,7 +531,7 @@ async fn post_action<D: Storage, A: AuthBackend>( } else { return Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Update request is not set." + "Update request is not set.", )); }; @@ -554,7 +583,7 @@ async fn dispatch_body( } else { Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid JSON object passed." + "Invalid JSON object passed.", )) } } else if content_type == ContentType::form_url_encoded() { @@ -565,7 +594,7 @@ async fn dispatch_body( } else { Err(MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid form-encoded data. Try h=entry&content=Hello!" + "Invalid form-encoded data. Try h=entry&content=Hello!", )) } } else { @@ -591,8 +620,8 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( Err(err) => err.into_response(), }, Ok(PostBody::MF2(mf2)) => { - let (uid, mf2) = normalize_mf2(mf2, &user); - match _post(&user, uid, mf2, db, http, jobset).await { + let NormalizedPost { id, post } = normalize_mf2(mf2, &user); + match _post(&user, id, post, db, http, jobset).await { Ok(response) => response, Err(err) => err.into_response(), } @@ -604,7 +633,10 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( #[tracing::instrument(skip(db))] pub(crate) async fn query<D: Storage, A: AuthBackend>( State(db): State<D>, - query: Result<Query<MicropubQuery>, <Query<MicropubQuery> as axum::extract::FromRequestParts<()>>::Rejection>, + query: Result< + Query<MicropubQuery>, + <Query<MicropubQuery> as axum::extract::FromRequestParts<()>>::Rejection, + >, Host(host): Host, user: User<A>, ) -> axum::response::Response { @@ -615,8 +647,9 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( } else { return MicropubError::from_static( ErrorKind::InvalidRequest, - "Invalid query provided. Try ?q=config to see what you can do." - ).into_response(); + "Invalid query provided. Try ?q=config to see what you can do.", + ) + .into_response(); }; if axum::http::Uri::try_from(user.me.as_str()) @@ -629,7 +662,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::NotAuthorized, "This website doesn't belong to you.", ) - .into_response(); + .into_response(); } // TODO: consider replacing by `user.me.authority()`? @@ -643,7 +676,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InternalServerError, format!("Error fetching channels: {}", err), ) - .into_response() + .into_response() } }; @@ -653,35 +686,36 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( QueryType::Config, QueryType::Channel, QueryType::SyndicateTo, - QueryType::Category + QueryType::Category, ], channels: Some(channels), syndicate_to: None, media_endpoint: Some(user.me.join("/.kittybox/media").unwrap()), other: { let mut map = std::collections::HashMap::new(); - map.insert("kittybox_authority".to_string(), serde_json::Value::String(user.me.to_string())); + map.insert( + "kittybox_authority".to_string(), + serde_json::Value::String(user.me.to_string()), + ); map - } + }, }) - .into_response() + .into_response() } QueryType::Source => { match query.url { - Some(url) => { - match db.get_post(&url).await { - Ok(some) => match some { - Some(post) => axum::response::Json(&post).into_response(), - None => MicropubError::from_static( - ErrorKind::NotFound, - "The specified MF2 object was not found in database.", - ) - .into_response(), - }, - Err(err) => MicropubError::from(err).into_response(), - } - } + Some(url) => match db.get_post(&url).await { + Ok(some) => match some { + Some(post) => axum::response::Json(&post).into_response(), + None => MicropubError::from_static( + ErrorKind::NotFound, + "The specified MF2 object was not found in database.", + ) + .into_response(), + }, + Err(err) => MicropubError::from(err).into_response(), + }, None => { // Here, one should probably attempt to query at least the main feed and collect posts // Using a pre-made query function can't be done because it does unneeded filtering @@ -690,7 +724,7 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InvalidRequest, "Querying for post list is not implemented yet.", ) - .into_response() + .into_response() } } } @@ -700,46 +734,45 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( ErrorKind::InternalServerError, format!("error fetching channels: backend error: {}", err), ) - .into_response(), + .into_response(), }, QueryType::SyndicateTo => { axum::response::Json(json!({ "syndicate-to": [] })).into_response() - }, + } QueryType::Category => { let categories = match db.categories(user_domain).await { Ok(categories) => categories, Err(err) => { return MicropubError::new( ErrorKind::InternalServerError, - format!("error fetching categories: backend error: {}", err) - ).into_response() + format!("error fetching categories: backend error: {}", err), + ) + .into_response() } }; axum::response::Json(json!({ "categories": categories })).into_response() - }, - QueryType::Unknown(q) => return MicropubError::new( - ErrorKind::InvalidRequest, - format!("Invalid query: {}", q) - ).into_response(), + } + QueryType::Unknown(q) => { + return MicropubError::new(ErrorKind::InvalidRequest, format!("Invalid query: {}", q)) + .into_response() + } } } - pub fn router<A, S, St: Send + Sync + Clone + 'static>() -> axum::routing::MethodRouter<St> where S: Storage + FromRef<St> + 'static, A: AuthBackend + FromRef<St>, reqwest_middleware::ClientWithMiddleware: FromRef<St>, - Arc<Mutex<JoinSet<()>>>: FromRef<St> + Arc<Mutex<JoinSet<()>>>: FromRef<St>, { axum::routing::get(query::<S, A>) .post(post::<S, A>) - .layer::<_, _>(tower_http::cors::CorsLayer::new() - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST, - ]) - .allow_origin(tower_http::cors::Any)) + .layer::<_, _>( + tower_http::cors::CorsLayer::new() + .allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_origin(tower_http::cors::Any), + ) } #[cfg(test)] @@ -764,16 +797,19 @@ impl MicropubQuery { mod tests { use std::sync::Arc; - use crate::{database::Storage, micropub::MicropubError}; + use crate::{ + database::Storage, + micropub::{util::NormalizedPost, MicropubError}, + }; use bytes::Bytes; use futures::StreamExt; use serde_json::json; use tokio::sync::Mutex; use super::FetchedPostContext; - use kittybox_indieauth::{Scopes, Scope, TokenData}; use axum::extract::State; use axum_extra::extract::Host; + use kittybox_indieauth::{Scope, Scopes, TokenData}; #[test] fn test_populate_reply_context() { @@ -800,16 +836,27 @@ mod tests { } }); let fetched_ctx_url: url::Url = "https://fireburn.ru/posts/example".parse().unwrap(); - let reply_contexts = vec![(fetched_ctx_url.clone(), FetchedPostContext { - url: fetched_ctx_url.clone(), - mf2: json!({ "items": [test_ctx] }), - webmention: None, - })].into_iter().collect(); + let reply_contexts = vec![( + fetched_ctx_url.clone(), + FetchedPostContext { + url: fetched_ctx_url.clone(), + mf2: json!({ "items": [test_ctx] }), + webmention: None, + }, + )] + .into_iter() + .collect(); let like_of = super::populate_reply_context(&mf2, "like-of", &reply_contexts).unwrap(); - assert_eq!(like_of[0]["properties"]["content"], test_ctx["properties"]["content"]); - assert_eq!(like_of[0]["properties"]["url"][0].as_str().unwrap(), reply_contexts[&fetched_ctx_url].url.as_str()); + assert_eq!( + like_of[0]["properties"]["content"], + test_ctx["properties"]["content"] + ); + assert_eq!( + like_of[0]["properties"]["url"][0].as_str().unwrap(), + reply_contexts[&fetched_ctx_url].url.as_str() + ); assert_eq!(like_of[1], already_expanded_reply_ctx); assert_eq!(like_of[2], "https://fireburn.ru/posts/non-existent"); @@ -829,20 +876,21 @@ mod tests { me: "https://localhost:8080/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: Scopes::new(vec![Scope::Profile]), - iat: None, exp: None + iat: None, + exp: None, }; - let (uid, mf2) = super::normalize_mf2(post, &user); + let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let err = super::_post( - &user, uid, mf2, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap_err(); + .await + .unwrap_err(); assert_eq!(err.error, super::ErrorKind::InvalidScope); @@ -865,21 +913,27 @@ mod tests { let user = TokenData { me: "https://aaronparecki.com/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None + scope: Scopes::new(vec![ + Scope::Profile, + Scope::Create, + Scope::Update, + Scope::Media, + ]), + iat: None, + exp: None, }; - let (uid, mf2) = super::normalize_mf2(post, &user); + let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let err = super::_post( - &user, uid, mf2, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap_err(); + .await + .unwrap_err(); assert_eq!(err.error, super::ErrorKind::Forbidden); @@ -901,20 +955,21 @@ mod tests { me: "https://localhost:8080/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), scope: Scopes::new(vec![Scope::Profile, Scope::Create]), - iat: None, exp: None + iat: None, + exp: None, }; - let (uid, mf2) = super::normalize_mf2(post, &user); + let NormalizedPost { id, post } = super::normalize_mf2(post, &user); let res = super::_post( - &user, uid, mf2, db.clone(), - reqwest_middleware::ClientWithMiddleware::new( - reqwest::Client::new(), - Box::default() - ), - Arc::new(Mutex::new(tokio::task::JoinSet::new())) + &user, + id, + post, + db.clone(), + reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), Box::default()), + Arc::new(Mutex::new(tokio::task::JoinSet::new())), ) - .await - .unwrap(); + .await + .unwrap(); assert!(res.headers().contains_key("Location")); let location = res.headers().get("Location").unwrap(); @@ -937,10 +992,17 @@ mod tests { TokenData { me: "https://fireburn.ru/".parse().unwrap(), client_id: "https://kittybox.fireburn.ru/".parse().unwrap(), - scope: Scopes::new(vec![Scope::Profile, Scope::Create, Scope::Update, Scope::Media]), - iat: None, exp: None - }, std::marker::PhantomData - ) + scope: Scopes::new(vec![ + Scope::Profile, + Scope::Create, + Scope::Update, + Scope::Media, + ]), + iat: None, + exp: None, + }, + std::marker::PhantomData, + ), ) .await; @@ -953,7 +1015,10 @@ mod tests { .into_iter() .map(Result::unwrap) .by_ref() - .fold(Vec::new(), |mut a, i| { a.extend(i); a}); + .fold(Vec::new(), |mut a, i| { + a.extend(i); + a + }); let json: MicropubError = serde_json::from_slice(&body as &[u8]).unwrap(); assert_eq!(json.error, super::ErrorKind::NotAuthorized); } diff --git a/src/micropub/util.rs b/src/micropub/util.rs index 19f4953..8c5d5e9 100644 --- a/src/micropub/util.rs +++ b/src/micropub/util.rs @@ -1,7 +1,7 @@ use crate::database::Storage; -use kittybox_indieauth::TokenData; use chrono::prelude::*; use core::iter::Iterator; +use kittybox_indieauth::TokenData; use newbase60::num_to_sxg; use serde_json::json; use std::convert::TryInto; @@ -33,7 +33,12 @@ fn reset_dt(post: &mut serde_json::Value) -> DateTime<FixedOffset> { chrono::DateTime::from(curtime) } -pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> (String, serde_json::Value) { +pub struct NormalizedPost { + pub id: String, + pub post: serde_json::Value, +} + +pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> NormalizedPost { // Normalize the MF2 object here. let me = &user.me; let folder = get_folder_from_type(body["type"][0].as_str().unwrap()); @@ -137,12 +142,12 @@ pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> (String, } // If there is no explicit channels, and the post is not marked as "unlisted", // post it to one of the default channels that makes sense for the post type. - if body["properties"]["channel"][0].as_str().is_none() && (!body["properties"]["visibility"] - .as_array() - .map(|v| v.contains( - &serde_json::Value::String("unlisted".to_owned()) - )).unwrap_or(false) - ) { + if body["properties"]["channel"][0].as_str().is_none() + && (!body["properties"]["visibility"] + .as_array() + .map(|v| v.contains(&serde_json::Value::String("unlisted".to_owned()))) + .unwrap_or(false)) + { match body["type"][0].as_str() { Some("h-entry") => { // Set the channel to the main channel... @@ -176,10 +181,10 @@ pub fn normalize_mf2(mut body: serde_json::Value, user: &TokenData) -> (String, } // TODO: maybe highlight #hashtags? // Find other processing to do and insert it here - return ( - body["properties"]["uid"][0].as_str().unwrap().to_string(), - body, - ); + NormalizedPost { + id: body["properties"]["uid"][0].as_str().unwrap().to_string(), + post: body, + } } pub(crate) fn form_to_mf2_json(form: Vec<(String, String)>) -> serde_json::Value { @@ -219,7 +224,7 @@ pub(crate) async fn create_feed( _ => panic!("Tried to create an unknown default feed!"), }; - let (_, feed) = normalize_mf2( + let NormalizedPost { id: _, post: feed } = normalize_mf2( json!({ "type": ["h-feed"], "properties": { @@ -244,7 +249,7 @@ mod tests { client_id: "https://quill.p3k.io/".parse().unwrap(), scope: kittybox_indieauth::Scopes::new(vec![kittybox_indieauth::Scope::Create]), exp: Some(u64::MAX), - iat: Some(0) + iat: Some(0), } } @@ -274,12 +279,15 @@ mod tests { } }); - let (uid, normalized) = normalize_mf2( - mf2.clone(), - &token_data() - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert!( - normalized["properties"]["channel"].as_array().unwrap_or(&vec![]).is_empty(), + normalized["properties"]["channel"] + .as_array() + .unwrap_or(&vec![]) + .is_empty(), "Returned post was added to a channel despite the `unlisted` visibility" ); } @@ -295,16 +303,16 @@ mod tests { } }); - let (uid, normalized) = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["uid"][0], mf2["properties"]["uid"][0], "UID was replaced" ); assert_eq!( - normalized["properties"]["uid"][0], uid, + normalized["properties"]["uid"][0], id, "Returned post location doesn't match UID" ); } @@ -320,10 +328,10 @@ mod tests { } }); - let (_, normalized) = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["channel"], @@ -342,10 +350,10 @@ mod tests { } }); - let (_, normalized) = normalize_mf2( - mf2.clone(), - &token_data(), - ); + let NormalizedPost { + id: _, + post: normalized, + } = normalize_mf2(mf2.clone(), &token_data()); assert_eq!( normalized["properties"]["channel"][0], @@ -362,10 +370,7 @@ mod tests { } }); - let (uid, post) = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id, post } = normalize_mf2(mf2, &token_data()); assert_eq!( post["properties"]["published"] .as_array() @@ -392,11 +397,11 @@ mod tests { "Post doesn't have a single UID" ); assert_eq!( - post["properties"]["uid"][0], uid, + post["properties"]["uid"][0], id, "UID of a post and its supposed location don't match" ); assert!( - uid.starts_with("https://fireburn.ru/posts/"), + id.starts_with("https://fireburn.ru/posts/"), "The post namespace is incorrect" ); assert_eq!( @@ -427,10 +432,7 @@ mod tests { }, }); - let (_, post) = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id: _, post } = normalize_mf2(mf2, &token_data()); assert!( post["properties"]["url"] .as_array() @@ -456,12 +458,9 @@ mod tests { } }); - let (uid, post) = normalize_mf2( - mf2, - &token_data(), - ); + let NormalizedPost { id, post } = normalize_mf2(mf2, &token_data()); assert_eq!( - post["properties"]["uid"][0], uid, + post["properties"]["uid"][0], id, "UID of a post and its supposed location don't match" ); assert_eq!(post["properties"]["author"][0], "https://fireburn.ru/"); diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs index 683cc6b..380f4db 100644 --- a/src/webmentions/check.rs +++ b/src/webmentions/check.rs @@ -1,7 +1,7 @@ -use std::rc::Rc; -use microformats::types::PropertyValue; use html5ever::{self, tendril::TendrilSink}; use kittybox_util::MentionType; +use microformats::types::PropertyValue; +use std::rc::Rc; // TODO: replace. mod rcdom; @@ -17,7 +17,11 @@ pub enum Error { } #[tracing::instrument] -pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> { +pub fn check_mention( + document: impl AsRef<str> + std::fmt::Debug, + base_url: &url::Url, + link: &url::Url, +) -> Result<Option<(MentionType, serde_json::Value)>, Error> { tracing::debug!("Parsing MF2 markup..."); // First, check the document for MF2 markup let document = microformats::from_html(document.as_ref(), base_url.clone())?; @@ -29,8 +33,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url tracing::debug!("Processing item: {:?}", item); for (prop, interaction_type) in [ - ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like), - ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost) + ("in-reply-to", MentionType::Reply), + ("like-of", MentionType::Like), + ("bookmark-of", MentionType::Bookmark), + ("repost-of", MentionType::Repost), ] { if let Some(propvals) = item.properties.get(prop) { tracing::debug!("Has a u-{} property", prop); @@ -38,7 +44,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url if let PropertyValue::Url(url) = val { if url == link { tracing::debug!("URL matches! Webmention is valid"); - return Ok(Some((interaction_type, serde_json::to_value(item).unwrap()))) + return Ok(Some(( + interaction_type, + serde_json::to_value(item).unwrap(), + ))); } } } @@ -46,7 +55,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } // Process `content` tracing::debug!("Processing e-content..."); - if let Some(PropertyValue::Fragment(content)) = item.properties.get("content") + if let Some(PropertyValue::Fragment(content)) = item + .properties + .get("content") .map(Vec::as_slice) .unwrap_or_default() .first() @@ -65,7 +76,8 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url // iteration of the loop. // // Empty list means all nodes were processed. - let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = root.children.borrow().iter().cloned().collect(); + let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = + root.children.borrow().iter().cloned().collect(); while !unprocessed_nodes.is_empty() { // "Take" the list out of its memory slot, replace it with an empty list let nodes = std::mem::take(&mut unprocessed_nodes); @@ -74,15 +86,23 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url // Add children nodes to the list for the next iteration unprocessed_nodes.extend(node.children.borrow().iter().cloned()); - if let rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data { + if let rcdom::NodeData::Element { + ref name, + ref attrs, + .. + } = node.data + { // If it's not `<a>`, skip it - if name.local != *"a" { continue; } + if name.local != *"a" { + continue; + } let mut is_mention: bool = false; for attr in attrs.borrow().iter() { if attr.name.local == *"rel" { // Don't count `rel="nofollow"` links — a web crawler should ignore them // and so for purposes of driving visitors they are useless - if attr.value + if attr + .value .as_ref() .split([',', ' ']) .any(|v| v == "nofollow") @@ -92,7 +112,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } } // if it's not `<a href="...">`, skip it - if attr.name.local != *"href" { continue; } + if attr.name.local != *"href" { + continue; + } // Be forgiving in parsing URLs, and resolve them against the base URL if let Ok(url) = base_url.join(attr.value.as_ref()) { if &url == link { @@ -101,12 +123,14 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url } } if is_mention { - return Ok(Some((MentionType::Mention, serde_json::to_value(item).unwrap()))); + return Ok(Some(( + MentionType::Mention, + serde_json::to_value(item).unwrap(), + ))); } } } } - } } diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs index 91b274b..57f9a57 100644 --- a/src/webmentions/mod.rs +++ b/src/webmentions/mod.rs @@ -1,9 +1,14 @@ -use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form}; use axum::http::StatusCode; +use axum::{ + extract::{FromRef, State}, + response::{IntoResponse, Response}, + routing::post, + Form, +}; use tracing::error; -use crate::database::{Storage, StorageError}; use self::queue::JobQueue; +use crate::database::{Storage, StorageError}; pub mod queue; #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -24,40 +29,46 @@ async fn accept_webmention<Q: JobQueue<Webmention>>( Form(webmention): Form<Webmention>, ) -> Response { if let Err(err) = webmention.source.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + return (StatusCode::BAD_REQUEST, err.to_string()).into_response(); } if let Err(err) = webmention.target.parse::<url::Url>() { - return (StatusCode::BAD_REQUEST, err.to_string()).into_response() + return (StatusCode::BAD_REQUEST, err.to_string()).into_response(); } match queue.put(&webmention).await { Ok(_id) => StatusCode::ACCEPTED.into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ - ("Content-Type", "text/plain") - ], err.to_string()).into_response() + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + [("Content-Type", "text/plain")], + err.to_string(), + ) + .into_response(), } } -pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>() -> axum::Router<St> { - axum::Router::new() - .route("/.kittybox/webmention", post(accept_webmention::<Q>)) +pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>( +) -> axum::Router<St> { + axum::Router::new().route("/.kittybox/webmention", post(accept_webmention::<Q>)) } #[derive(thiserror::Error, Debug)] pub enum SupervisorError { #[error("the task was explicitly cancelled")] - Cancelled + Cancelled, } -pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; +pub type SupervisedTask = + tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>; -pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask +pub fn supervisor<E, A, F>( + mut f: F, + cancellation_token: tokio_util::sync::CancellationToken, +) -> SupervisedTask where E: std::error::Error + std::fmt::Debug + Send + 'static, A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static, - F: FnMut() -> A + Send + 'static + F: FnMut() -> A + Send + 'static, { - let supervisor_future = async move { loop { // Don't spawn the task if we are already cancelled, but @@ -65,7 +76,7 @@ where // crashed and we immediately received a cancellation // request after noticing the crashed task) if cancellation_token.is_cancelled() { - return Err(SupervisorError::Cancelled) + return Err(SupervisorError::Cancelled); } let task = tokio::task::spawn(f()); tokio::select! { @@ -87,7 +98,13 @@ where return tokio::task::spawn(supervisor_future); #[cfg(tokio_unstable)] return tokio::task::Builder::new() - .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str()) + .name( + format!( + "supervisor for background task {}", + std::any::type_name::<A>() + ) + .as_str(), + ) .spawn(supervisor_future) .unwrap(); } @@ -99,39 +116,55 @@ enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> { #[error("queue error: {0}")] Queue(#[from] Q), #[error("storage error: {0}")] - Storage(StorageError) + Storage(StorageError), } -async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest_middleware::ClientWithMiddleware) -> Result<std::convert::Infallible, Error<Q::Error>> { - use futures_util::StreamExt; +async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>( + queue: Q, + db: S, + http: reqwest_middleware::ClientWithMiddleware, +) -> Result<std::convert::Infallible, Error<Q::Error>> { use self::queue::Job; + use futures_util::StreamExt; let mut stream = queue.into_stream().await?; while let Some(item) = stream.next().await.transpose()? { let job = item.job(); let (source, target) = ( job.source.parse::<url::Url>().unwrap(), - job.target.parse::<url::Url>().unwrap() + job.target.parse::<url::Url>().unwrap(), ); let (code, text) = match http.get(source.clone()).send().await { Ok(response) => { let code = response.status(); - if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) { - error!("error processing webmention: webpage fetch returned {}", code); + if ![StatusCode::OK, StatusCode::GONE] + .iter() + .any(|i| i == &code) + { + error!( + "error processing webmention: webpage fetch returned {}", + code + ); continue; } match response.text().await { Ok(text) => (code, text), Err(err) => { - error!("error processing webmention: error fetching webpage text: {}", err); - continue + error!( + "error processing webmention: error fetching webpage text: {}", + err + ); + continue; } } } Err(err) => { - error!("error processing webmention: error requesting webpage: {}", err); - continue + error!( + "error processing webmention: error requesting webpage: {}", + err + ); + continue; } }; @@ -150,7 +183,10 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's continue; } Err(err) => { - error!("error processing webmention: error checking webmention: {}", err); + error!( + "error processing webmention: error checking webmention: {}", + err + ); continue; } }; @@ -158,31 +194,47 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's { mention["type"] = serde_json::json!(["h-cite"]); - if !mention["properties"].as_object().unwrap().contains_key("uid") { - let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned(); + if !mention["properties"] + .as_object() + .unwrap() + .contains_key("uid") + { + let url = mention["properties"]["url"][0] + .as_str() + .unwrap_or_else(|| target.as_str()) + .to_owned(); let props = mention["properties"].as_object_mut().unwrap(); - props.insert("uid".to_owned(), serde_json::Value::Array( - vec![serde_json::Value::String(url)]) + props.insert( + "uid".to_owned(), + serde_json::Value::Array(vec![serde_json::Value::String(url)]), ); } } - db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?; + db.add_or_update_webmention(target.as_str(), mention_type, mention) + .await + .map_err(Error::<Q::Error>::Storage)?; } } unreachable!() } -pub fn supervised_webmentions_task<St: Send + Sync + 'static, S: Storage + FromRef<St> + 'static, Q: JobQueue<Webmention> + FromRef<St> + 'static>( +pub fn supervised_webmentions_task< + St: Send + Sync + 'static, + S: Storage + FromRef<St> + 'static, + Q: JobQueue<Webmention> + FromRef<St> + 'static, +>( state: &St, - cancellation_token: tokio_util::sync::CancellationToken + cancellation_token: tokio_util::sync::CancellationToken, ) -> SupervisedTask -where reqwest_middleware::ClientWithMiddleware: FromRef<St> +where + reqwest_middleware::ClientWithMiddleware: FromRef<St>, { let queue = Q::from_ref(state); let storage = S::from_ref(state); let http = reqwest_middleware::ClientWithMiddleware::from_ref(state); - supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue( - queue.clone(), storage.clone(), http.clone() - ), cancellation_token) + supervisor::<Error<Q::Error>, _, _>( + move || process_webmentions_from_queue(queue.clone(), storage.clone(), http.clone()), + cancellation_token, + ) } diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs index 52bcdfa..a33de1a 100644 --- a/src/webmentions/queue.rs +++ b/src/webmentions/queue.rs @@ -6,7 +6,7 @@ use super::Webmention; static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention"); -pub use kittybox_util::queue::{JobQueue, JobItem, Job, JobStream}; +pub use kittybox_util::queue::{Job, JobItem, JobQueue, JobStream}; pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> { const DATABASE_NAME: &'static str; @@ -17,7 +17,7 @@ pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRo struct PostgresJobRow<T: PostgresJobItem> { id: Uuid, #[sqlx(flatten)] - job: T + job: T, } #[derive(Debug)] @@ -29,7 +29,6 @@ pub struct PostgresJob<T: PostgresJobItem> { runtime_handle: tokio::runtime::Handle, } - impl<T: PostgresJobItem> Drop for PostgresJob<T> { // This is an emulation of "async drop" — the struct retains a // runtime handle, which it uses to block on a future that does @@ -87,7 +86,9 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { fn job(&self) -> &Webmention { &self.job } - async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { + async fn done( + mut self, + ) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> { tracing::debug!("Deleting {} from the job queue", self.id); sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1") .bind(self.id) @@ -100,13 +101,13 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> { pub struct PostgresJobQueue<T> { db: sqlx::PgPool, - _phantom: std::marker::PhantomData<T> + _phantom: std::marker::PhantomData<T>, } impl<T> Clone for PostgresJobQueue<T> { fn clone(&self) -> Self { Self { db: self.db.clone(), - _phantom: std::marker::PhantomData + _phantom: std::marker::PhantomData, } } } @@ -120,15 +121,21 @@ impl PostgresJobQueue<Webmention> { sqlx::postgres::PgPoolOptions::new() .max_connections(50) .connect_with(options) - .await? - ).await - + .await?, + ) + .await } pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?; + db.execute(sqlx::query( + "CREATE SCHEMA IF NOT EXISTS kittybox_webmention", + )) + .await?; MIGRATOR.run(&db).await?; - Ok(Self { db, _phantom: std::marker::PhantomData }) + Ok(Self { + db, + _phantom: std::marker::PhantomData, + }) } } @@ -180,13 +187,14 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { Some(item) => return Ok(Some((item, ()))), None => { listener.lock().await.recv().await?; - continue + continue; } } } } } - }).boxed(); + }) + .boxed(); Ok(stream) } @@ -196,7 +204,7 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> { mod tests { use std::sync::Arc; - use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR}; + use super::{Job, JobQueue, PostgresJobQueue, Webmention, MIGRATOR}; use futures_util::StreamExt; #[sqlx::test(migrator = "MIGRATOR")] @@ -204,7 +212,7 @@ mod tests { async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(), }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?; @@ -236,7 +244,7 @@ mod tests { match queue.get_one().await? { Some(item) => panic!("Unexpected item {:?} returned from job queue!", item), - None => Ok(()) + None => Ok(()), } } @@ -245,7 +253,7 @@ mod tests { async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> { let test_webmention = Webmention { source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(), - target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned() + target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(), }; let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?; @@ -272,18 +280,18 @@ mod tests { } }); } - tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err(); + tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap_err(); - let future = tokio::task::spawn( - tokio::time::timeout( - std::time::Duration::from_secs(10), async move { - stream.next().await.unwrap().unwrap() - } - ) - ); + let future = tokio::task::spawn(tokio::time::timeout( + std::time::Duration::from_secs(10), + async move { stream.next().await.unwrap().unwrap() }, + )); // Let the other task drop the guard it is holding barrier.wait().await; - let mut guard = future.await + let mut guard = future + .await .expect("Timeout on fetching item") .expect("Job queue error"); assert_eq!(guard.job(), &test_webmention); diff --git a/templates-neo/Cargo.toml b/templates-neo/Cargo.toml deleted file mode 100644 index 0be4dd2..0000000 --- a/templates-neo/Cargo.toml +++ /dev/null @@ -1,34 +0,0 @@ -[package] -name = "kittybox-html" -version = "0.2.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[build-dependencies] -libflate = { workspace = true } -walkdir = { workspace = true } - -[dev-dependencies] -faker_rand = { workspace = true } -rand = { workspace = true } - -[dependencies] -axum = { workspace = true } -chrono = { workspace = true } -ellipse = { workspace = true } -html = { workspace = true } -http = { workspace = true } -include_dir = { workspace = true } -microformats = { workspace = true } -serde_json = { workspace = true } -thiserror = { workspace = true } -time = { workspace = true, features = ["formatting"] } -url = { workspace = true } - -[dependencies.kittybox-util] -version = "0.3.0" -path = "../util" -[dependencies.kittybox-indieauth] -version = "0.2.0" -path = "../indieauth" diff --git a/templates-neo/src/lib.rs b/templates-neo/src/lib.rs deleted file mode 100644 index 1ae9e03..0000000 --- a/templates-neo/src/lib.rs +++ /dev/null @@ -1,2 +0,0 @@ -#![recursion_limit = "512"] -pub mod mf2; diff --git a/templates-neo/src/main.rs b/templates-neo/src/main.rs deleted file mode 100644 index d374e3f..0000000 --- a/templates-neo/src/main.rs +++ /dev/null @@ -1,18 +0,0 @@ -#![recursion_limit = "512"] -use std::io::Write; - -use kittybox_html::mf2::Entry; - -fn main() { - let mf2 = serde_json::from_reader::<_, microformats::types::Item>(std::io::stdin()).unwrap(); - let entry = Entry::try_from(mf2).unwrap(); - - let mut article = html::content::Article::builder(); - entry.build(&mut article); - - let mut stdout = std::io::stdout().lock(); - stdout - .write_all(article.build().to_string().as_bytes()) - .unwrap(); - stdout.write_all(b"\n").unwrap(); -} diff --git a/templates-neo/src/mf2.rs b/templates-neo/src/mf2.rs deleted file mode 100644 index 3cf453f..0000000 --- a/templates-neo/src/mf2.rs +++ /dev/null @@ -1,467 +0,0 @@ -use std::{borrow::Cow, collections::HashMap}; - -use html::{ - content::builders::{ArticleBuilder, SectionBuilder}, - inline_text::Anchor, - media::builders, -}; -use microformats::types::{ - temporal::Value as Temporal, Class, Fragment, Item, KnownClass, PropertyValue, -}; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("wrong mf2 class, expected {expected:?}, got {got:?}")] - WrongClass { - expected: Vec<KnownClass>, - got: Vec<Class>, - }, - #[error("entry lacks `uid` property")] - NoUid, - #[error("unexpected type of property value: expected {expected}, got {got:?}")] - WrongValueType { - expected: &'static str, - got: PropertyValue, - }, - #[error("missing property: {0}")] - MissingProperty(&'static str), -} - -pub enum Image { - Plain(url::Url), - Accessible { src: url::Url, alt: String }, -} - -impl Image { - pub fn build( - self, - img: &mut html::media::builders::ImageBuilder, - ) -> &mut html::media::builders::ImageBuilder { - match self { - Image::Plain(url) => img.src(String::from(url)), - Image::Accessible { src, alt } => img.src(String::from(src)).alt(alt), - } - } -} - -pub struct Card { - uid: url::Url, - urls: Vec<url::Url>, - name: String, - note: Option<String>, - photo: Image, - pronouns: Vec<String>, -} - -impl TryFrom<Item> for Card { - type Error = Error; - - fn try_from(card: Item) -> Result<Self, Self::Error> { - if card.r#type.as_slice() != [Class::Known(KnownClass::Card)] { - return Err(Error::WrongClass { - expected: vec![KnownClass::Card], - got: card.r#type, - }); - } - - let mut props = card.properties; - let uid = { - let uids = props.remove("uid").ok_or(Error::NoUid)?; - if let Some(PropertyValue::Url(uid)) = uids.into_iter().take(1).next() { - uid - } else { - return Err(Error::NoUid); - } - }; - - Ok(Self { - uid, - urls: props - .remove("url") - .unwrap_or_default() - .into_iter() - .filter_map(|v| { - if let PropertyValue::Url(url) = v { - Some(url) - } else { - None - } - }) - .collect(), - name: props - .remove("name") - .unwrap_or_default() - .into_iter() - .next() - .ok_or(Error::MissingProperty("name")) - .and_then(|v| match v { - PropertyValue::Plain(plain) => Ok(plain), - other => Err(Error::WrongValueType { - expected: "string", - got: other, - }), - })?, - note: props - .remove("note") - .unwrap_or_default() - .into_iter() - .next() - .map(|v| match v { - PropertyValue::Plain(plain) => Ok(plain), - other => Err(Error::WrongValueType { - expected: "string", - got: other, - }), - }) - .transpose()?, - photo: props - .remove("photo") - .unwrap_or_default() - .into_iter() - .next() - .ok_or(Error::MissingProperty("photo")) - .and_then(|v| match v { - PropertyValue::Url(url) => Ok(Image::Plain(url)), - PropertyValue::Image(image) => match image.alt { - Some(alt) => Ok(Image::Accessible { - src: image.value, - alt, - }), - None => Ok(Image::Plain(image.value)) - }, - other => Err(Error::WrongValueType { - expected: "string", - got: other, - }), - })?, - pronouns: props - .remove("pronoun") - .unwrap_or_default() - .into_iter() - .map(|v| match v { - PropertyValue::Plain(plain) => Ok(plain), - other => Err(Error::WrongValueType { - expected: "string", - got: other, - }), - }) - .collect::<Result<Vec<String>, _>>()?, - }) - } -} - -impl Card { - pub fn build_section( - self, - section: &mut html::content::builders::SectionBuilder, - ) -> &mut html::content::builders::SectionBuilder { - section.class("mini-h-card").anchor(|a| { - a.class("larger u-author") - .href(String::from(self.uid)) - .image(move |img| self.photo.build(img).loading("lazy")) - .text(self.name) - }) - } - - pub fn build( - self, - article: &mut html::content::builders::ArticleBuilder, - ) -> &mut html::content::builders::ArticleBuilder { - let urls: Vec<_> = self.urls.into_iter().filter(|u| *u != self.uid).collect(); - - article - .class("h-card") - .image(move |builder| self.photo.build(builder)) - .heading_1(move |builder| { - builder.anchor(|builder| { - builder - .class("u-url u-uid p-name") - .href(String::from(self.uid)) - .text(self.name) - }) - }); - - if !self.pronouns.is_empty() { - article.span(move |span| { - span.text("("); - self.pronouns.into_iter().for_each(|p| { - span.text(p); - }); - span.text(")") - }); - } - - if let Some(note) = self.note { - article.paragraph(move |p| p.class("p-note").text(note)); - } - - if !urls.is_empty() { - article.paragraph(|p| p.text("Can be found elsewhere at:")); - article.unordered_list(move |ul| { - for url in urls { - let url = String::from(url); - ul.list_item(move |li| { - li.push(Anchor::builder().href(url.clone()).text(url).build()) - }); - } - - ul - }); - } - - article - } -} - -impl TryFrom<PropertyValue> for Card { - type Error = Error; - - fn try_from(v: PropertyValue) -> Result<Self, Self::Error> { - match v { - PropertyValue::Item(item) => item.try_into(), - other => Err(Error::WrongValueType { - expected: "h-card", - got: other, - }), - } - } -} - -pub struct Cite { - uid: url::Url, - url: Vec<url::Url>, - in_reply_to: Option<Vec<Citation>>, - author: Card, - published: Option<time::OffsetDateTime>, - content: Content, -} - -impl TryFrom<Item> for Cite { - type Error = Error; - - fn try_from(cite: Item) -> Result<Self, Self::Error> { - if cite.r#type.as_slice() != [Class::Known(KnownClass::Cite)] { - return Err(Error::WrongClass { - expected: vec![KnownClass::Cite], - got: cite.r#type, - }); - } - - todo!() - } -} - -pub enum Citation { - Brief(url::Url), - Full(Cite), -} - -impl TryFrom<PropertyValue> for Citation { - type Error = Error; - fn try_from(v: PropertyValue) -> Result<Self, Self::Error> { - match v { - PropertyValue::Url(url) => Ok(Self::Brief(url)), - PropertyValue::Item(item) => Ok(Self::Full(item.try_into()?)), - other => Err(Error::WrongValueType { - expected: "url or h-cite", - got: other, - }), - } - } -} - -pub struct Content(Fragment); - -impl From<Content> for html::content::Main { - fn from(content: Content) -> Self { - let mut builder = Self::builder(); - builder.class("e-content").text(content.0.html); - if let Some(lang) = content.0.lang { - builder.lang(Cow::Owned(lang)); - } - builder.build() - } -} - -pub struct Entry { - uid: url::Url, - url: Vec<url::Url>, - in_reply_to: Option<Citation>, - author: Card, - category: Vec<String>, - syndication: Vec<url::Url>, - published: time::OffsetDateTime, - content: Content, -} - -impl TryFrom<Item> for Entry { - type Error = Error; - fn try_from(entry: Item) -> Result<Self, Self::Error> { - if entry.r#type.as_slice() != [Class::Known(KnownClass::Entry)] { - return Err(Error::WrongClass { - expected: vec![KnownClass::Entry], - got: entry.r#type, - }); - } - - let mut props = entry.properties; - let uid = { - let uids = props.remove("uid").ok_or(Error::NoUid)?; - if let Some(PropertyValue::Url(uid)) = uids.into_iter().take(1).next() { - uid - } else { - return Err(Error::NoUid); - } - }; - Ok(Entry { - uid, - url: props - .remove("url") - .unwrap_or_default() - .into_iter() - .filter_map(|v| { - if let PropertyValue::Url(url) = v { - Some(url) - } else { - None - } - }) - .collect(), - in_reply_to: props - .remove("in-reply-to") - .unwrap_or_default() - .into_iter() - .next() - .map(|v| v.try_into()) - .transpose()?, - author: props - .remove("author") - .unwrap_or_default() - .into_iter() - .next() - .map(|v| v.try_into()) - .transpose()? - .ok_or(Error::MissingProperty("author"))?, - category: props - .remove("category") - .unwrap_or_default() - .into_iter() - .map(|v| match v { - PropertyValue::Plain(string) => Ok(string), - other => Err(Error::WrongValueType { - expected: "string", - got: other, - }), - }) - .collect::<Result<Vec<_>, _>>()?, - syndication: props - .remove("syndication") - .unwrap_or_default() - .into_iter() - .map(|v| match v { - PropertyValue::Url(url) => Ok(url), - other => Err(Error::WrongValueType { - expected: "link", - got: other, - }), - }) - .collect::<Result<Vec<_>, _>>()?, - published: props - .remove("published") - .unwrap_or_default() - .into_iter() - .next() - .map( - |v| -> Result<time::OffsetDateTime, Error> { - match v { - PropertyValue::Temporal(Temporal::Timestamp(ref dt)) => { - // This is incredibly sketchy. - let (date, time, offset) = ( - dt.date.to_owned().ok_or_else(|| Error::WrongValueType { - expected: "timestamp (date, time, offset)", - got: v.clone() - })?.data, - dt.time.to_owned().ok_or_else(|| Error::WrongValueType { - expected: "timestamp (date, time, offset)", - got: v.clone() - })?.data, - dt.offset.to_owned().ok_or_else(|| Error::WrongValueType { - expected: "timestamp (date, time, offset)", - got: v.clone() - })?.data, - ); - - Ok(date.with_time(time).assume_offset(offset)) - } - other => Err(Error::WrongValueType { - expected: "timestamp", - got: other, - }), - } - }, - ) - .ok_or(Error::MissingProperty("published"))??, - content: props - .remove("content") - .unwrap_or_default() - .into_iter() - .next() - .ok_or(Error::MissingProperty("content")) - .and_then(|v| match v { - PropertyValue::Fragment(fragment) => Ok(Content(fragment)), - other => Err(Error::WrongValueType { - expected: "html", - got: other, - }), - })?, - }) - } -} - -impl Entry { - pub fn build(self, article: &mut ArticleBuilder) -> &mut ArticleBuilder { - article - .class("h-entry") - .header(|header| { - header - .class("metadata") - .section(|section| self.author.build_section(section)) - .section(|section| { - section - .division(|div| { - div.anchor(|a| { - a.class("u-url u-uid").href(String::from(self.uid)).push( - html::inline_text::Time::builder() - .text( - self.published - .format(&time::format_description::well_known::Rfc2822) - .unwrap() - ) - .date_time(self.published.format(&time::format_description::well_known::Rfc3339).unwrap()) - .build(), - ) - }) - }) - .division(|div| { - div.text("Tagged").unordered_list(|ul| { - for category in self.category { - ul.list_item(|li| li.class("p-category").text(category)); - } - - ul - }) - }) - }) - }) - .main(|main| { - if let Some(lang) = self.content.0.lang { - main.lang(lang); - } - - // XXX .text() and .push() are completely equivalent - // since .text() does no escaping - main.push(self.content.0.html) - }) - .footer(|footer| footer) - } -} diff --git a/templates/Cargo.toml b/templates/Cargo.toml index 19855e6..ca56dfe 100644 --- a/templates/Cargo.toml +++ b/templates/Cargo.toml @@ -28,5 +28,5 @@ serde_json = { workspace = true } version = "0.3.0" path = "../util" [dependencies.kittybox-indieauth] -version = "0.2.0" +version = "0.3.0" path = "../indieauth" diff --git a/templates/assets/style.css b/templates/assets/style.css index 6139288..97483d4 100644 --- a/templates/assets/style.css +++ b/templates/assets/style.css @@ -175,6 +175,7 @@ article.h-entry, article.h-feed, article.h-card, article.h-event { } .webinteractions > ul.counters > li > .icon { font-size: 1.5em; + font-family: emoji; } .webinteractions > ul.counters > li { display: inline-flex; @@ -300,11 +301,13 @@ body > a#skip-to-content:focus { white-space: nowrap; width: 1px; } - +/* Extras: styles to demarcate output generated by machine learning models + * (No, LLMs and diffusion image generation models are not artificial intelligence) + */ figure.llm-quote { background: #ddd; border-left: 0.5em solid black; - border-image: repeating-linear-gradient(45deg, #000000, #000000 0.75em, #FFFF00 0.75em, #FFFF00 1.5em) 8; + border-image: repeating-linear-gradient(45deg, #000000, #333333 0.75em, #DDDD00 0.75em, #FFFF00 1.5em) 8; padding: 0.5em; padding-left: 0.75em; margin-left: 3em; @@ -319,3 +322,7 @@ figure.llm-quote > figcaption { background-color: #242424; } } +img.diffusion-model-output { + border-left: 0.5em solid black; + border-image: repeating-linear-gradient(45deg, #000000, #333333 0.75em, #DDDD00 0.75em, #FFFF00 1.5em) 8; +} diff --git a/templates/build.rs b/templates/build.rs index 5a62855..057666b 100644 --- a/templates/build.rs +++ b/templates/build.rs @@ -22,8 +22,7 @@ fn main() -> Result<(), std::io::Error> { println!("cargo:rerun-if-changed=assets/"); let assets_path = std::path::Path::new("assets"); - let mut assets = WalkDir::new(assets_path) - .into_iter(); + let mut assets = WalkDir::new(assets_path).into_iter(); while let Some(Ok(entry)) = assets.next() { eprintln!("Processing {}", entry.path().display()); let out_path = out_dir.join(entry.path().strip_prefix(assets_path).unwrap()); @@ -31,11 +30,15 @@ fn main() -> Result<(), std::io::Error> { eprintln!("Creating directory {}", &out_path.display()); if let Err(err) = std::fs::create_dir(&out_path) { if err.kind() != std::io::ErrorKind::AlreadyExists { - return Err(err) + return Err(err); } } } else { - eprintln!("Copying {} to {}", entry.path().display(), out_path.display()); + eprintln!( + "Copying {} to {}", + entry.path().display(), + out_path.display() + ); std::fs::copy(entry.path(), &out_path)?; } } @@ -43,16 +46,11 @@ fn main() -> Result<(), std::io::Error> { let walker = WalkDir::new(&out_dir) .into_iter() .map(Result::unwrap) - .filter(|e| { - e.file_type().is_file() && e.path().extension().unwrap() != "gz" - }); + .filter(|e| e.file_type().is_file() && e.path().extension().unwrap() != "gz"); for entry in walker { let normal_path = entry.path(); let gzip_path = normal_path.with_extension({ - let mut extension = normal_path - .extension() - .unwrap() - .to_owned(); + let mut extension = normal_path.extension().unwrap().to_owned(); extension.push(OsStr::new(".gz")); extension }); diff --git a/templates/src/assets.rs b/templates/src/assets.rs new file mode 100644 index 0000000..493c14d --- /dev/null +++ b/templates/src/assets.rs @@ -0,0 +1,47 @@ +use axum::extract::Path; +use axum::http::header::{CACHE_CONTROL, CONTENT_ENCODING, CONTENT_TYPE, X_CONTENT_TYPE_OPTIONS}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; + +const ASSETS: include_dir::Dir<'static> = include_dir::include_dir!("$OUT_DIR/"); +const CACHE_FOR_A_DAY: &str = "max-age=86400"; +const GZIP: &str = "gzip"; + +pub async fn statics(Path(path): Path<String>) -> Response { + let content_type: &'static str = if path.ends_with(".js") { + "application/javascript" + } else if path.ends_with(".css") { + "text/css" + } else if path.ends_with(".html") { + "text/html; charset=\"utf-8\"" + } else { + "application/octet-stream" + }; + + match ASSETS.get_file(path.clone() + ".gz") { + Some(file) => ( + StatusCode::OK, + [ + (CONTENT_TYPE, content_type), + (CONTENT_ENCODING, GZIP), + (CACHE_CONTROL, CACHE_FOR_A_DAY), + (X_CONTENT_TYPE_OPTIONS, "nosniff"), + ], + file.contents(), + ) + .into_response(), + None => match ASSETS.get_file(path) { + Some(file) => ( + StatusCode::OK, + [ + (CONTENT_TYPE, content_type), + (CACHE_CONTROL, CACHE_FOR_A_DAY), + (X_CONTENT_TYPE_OPTIONS, "nosniff"), + ], + file.contents(), + ) + .into_response(), + None => StatusCode::NOT_FOUND.into_response(), + }, + } +} diff --git a/templates/src/lib.rs b/templates/src/lib.rs index d9fe86b..0f9f7c6 100644 --- a/templates/src/lib.rs +++ b/templates/src/lib.rs @@ -7,55 +7,9 @@ pub use indieauth::AuthorizationRequestPage; mod login; pub use login::{LoginPage, LogoutPage}; mod mf2; -pub use mf2::{Entry, VCard, Feed, Food, POSTS_PER_PAGE}; - +pub use mf2::{Entry, Feed, Food, VCard, POSTS_PER_PAGE}; pub mod admin; - -pub mod assets { - use axum::response::{IntoResponse, Response}; - use axum::extract::Path; - use axum::http::StatusCode; - use axum::http::header::{CONTENT_TYPE, CONTENT_ENCODING, CACHE_CONTROL, X_CONTENT_TYPE_OPTIONS}; - - const ASSETS: include_dir::Dir<'static> = include_dir::include_dir!("$OUT_DIR/"); - const CACHE_FOR_A_DAY: &str = "max-age=86400"; - const GZIP: &str = "gzip"; - - pub async fn statics( - Path(path): Path<String> - ) -> Response { - let content_type: &'static str = if path.ends_with(".js") { - "application/javascript" - } else if path.ends_with(".css") { - "text/css" - } else if path.ends_with(".html") { - "text/html; charset=\"utf-8\"" - } else { - "application/octet-stream" - }; - - match ASSETS.get_file(path.clone() + ".gz") { - Some(file) => (StatusCode::OK, - [ - (CONTENT_TYPE, content_type), - (CONTENT_ENCODING, GZIP), - (CACHE_CONTROL, CACHE_FOR_A_DAY), - (X_CONTENT_TYPE_OPTIONS, "nosniff") - ], - file.contents()).into_response(), - None => match ASSETS.get_file(path) { - Some(file) => (StatusCode::OK, - [ - (CONTENT_TYPE, content_type), - (CACHE_CONTROL, CACHE_FOR_A_DAY), - (X_CONTENT_TYPE_OPTIONS, "nosniff") - ], - file.contents()).into_response(), - None => StatusCode::NOT_FOUND.into_response() - } - } - } -} +pub mod assets; #[cfg(test)] mod tests { @@ -107,11 +61,11 @@ mod tests { let dt = time::OffsetDateTime::now_utc() .to_offset( time::UtcOffset::from_hms( - rand::distributions::Uniform::new(-11, 12) - .sample(&mut rand::thread_rng()), + rand::distributions::Uniform::new(-11, 12).sample(&mut rand::thread_rng()), if rand::random::<bool>() { 0 } else { 30 }, - 0 - ).unwrap() + 0, + ) + .unwrap(), ) .format(&time::format_description::well_known::Rfc3339) .unwrap(); @@ -218,14 +172,15 @@ mod tests { // potentially with an offset? let offset = item.as_offset().unwrap().data; let date = item.as_date().unwrap().data; - let time = item.as_time().unwrap().data; + let time = item.as_time().unwrap().data; let dt = date.with_time(time).assume_offset(offset); let expected = time::OffsetDateTime::parse( mf2["properties"]["published"][0].as_str().unwrap(), - &time::format_description::well_known::Rfc3339 - ).unwrap(); - + &time::format_description::well_known::Rfc3339, + ) + .unwrap(); + assert_eq!(dt, expected); } else { unreachable!() @@ -235,7 +190,8 @@ mod tests { fn check_e_content(mf2: &serde_json::Value, item: &Item) { assert!(item.properties.contains_key("content")); - if let Some(PropertyValue::Fragment(content)) = item.properties.get("content").and_then(|v| v.first()) + if let Some(PropertyValue::Fragment(content)) = + item.properties.get("content").and_then(|v| v.first()) { assert_eq!( content.html, @@ -250,7 +206,11 @@ mod tests { fn test_note() { let mf2 = gen_random_post(&rand::random::<Domain>().to_string(), PostType::Note); - let html = crate::mf2::Entry { post: &mf2, from_feed: false, }.to_string(); + let html = crate::mf2::Entry { + post: &mf2, + from_feed: false, + } + .to_string(); let url: Url = mf2 .pointer("/properties/uid/0") @@ -259,7 +219,12 @@ mod tests { .unwrap(); let parsed: Document = microformats::from_html(&html, url.clone()).unwrap(); - if let Some(item) = parsed.into_iter().find(|i| i.properties.get("url").unwrap().contains(&PropertyValue::Url(url.clone()))) { + if let Some(item) = parsed.into_iter().find(|i| { + i.properties + .get("url") + .unwrap() + .contains(&PropertyValue::Url(url.clone())) + }) { let props = &item.properties; check_e_content(&mf2, &item); @@ -281,7 +246,11 @@ mod tests { #[test] fn test_article() { let mf2 = gen_random_post(&rand::random::<Domain>().to_string(), PostType::Article); - let html = crate::mf2::Entry { post: &mf2, from_feed: false, }.to_string(); + let html = crate::mf2::Entry { + post: &mf2, + from_feed: false, + } + .to_string(); let url: Url = mf2 .pointer("/properties/uid/0") .and_then(|i| i.as_str()) @@ -289,8 +258,12 @@ mod tests { .unwrap(); let parsed: Document = microformats::from_html(&html, url.clone()).unwrap(); - if let Some(item) = parsed.into_iter().find(|i| i.properties.get("url").unwrap().contains(&PropertyValue::Url(url.clone()))) { - + if let Some(item) = parsed.into_iter().find(|i| { + i.properties + .get("url") + .unwrap() + .contains(&PropertyValue::Url(url.clone())) + }) { check_e_content(&mf2, &item); check_dt_published(&mf2, &item); assert!(item.properties.contains_key("uid")); @@ -302,7 +275,9 @@ mod tests { .iter() .any(|i| i == item.properties.get("uid").and_then(|v| v.first()).unwrap())); assert!(item.properties.contains_key("name")); - if let Some(PropertyValue::Plain(name)) = item.properties.get("name").and_then(|v| v.first()) { + if let Some(PropertyValue::Plain(name)) = + item.properties.get("name").and_then(|v| v.first()) + { assert_eq!( name, mf2.pointer("/properties/name/0") @@ -338,7 +313,11 @@ mod tests { .and_then(|i| i.as_str()) .and_then(|u| u.parse().ok()) .unwrap(); - let html = crate::mf2::Entry { post: &mf2, from_feed: false, }.to_string(); + let html = crate::mf2::Entry { + post: &mf2, + from_feed: false, + } + .to_string(); let parsed: Document = microformats::from_html(&html, url.clone()).unwrap(); if let Some(item) = parsed.items.first() { diff --git a/templates/src/mf2.rs b/templates/src/mf2.rs index 787d3ed..aaac80f 100644 --- a/templates/src/mf2.rs +++ b/templates/src/mf2.rs @@ -1,3 +1,7 @@ +#![expect( + clippy::needless_lifetimes, + reason = "bug: Clippy doesn't realize the `markup` crate requires explicit lifetimes due to its idiosyncracies" +)] use ellipse::Ellipse; pub static POSTS_PER_PAGE: usize = 20; diff --git a/templates/src/templates.rs b/templates/src/templates.rs index 9b29fce..5772b4d 100644 --- a/templates/src/templates.rs +++ b/templates/src/templates.rs @@ -1,6 +1,7 @@ +#![allow(clippy::needless_lifetimes)] +use crate::{Feed, VCard}; use http::StatusCode; use kittybox_util::micropub::Channel; -use crate::{Feed, VCard}; markup::define! { Template<'a>(title: &'a str, blog_name: &'a str, feeds: Vec<Channel>, user: Option<&'a kittybox_indieauth::ProfileUrl>, content: String) { diff --git a/tower-watchdog/src/lib.rs b/tower-watchdog/src/lib.rs index 9a5c609..e0be313 100644 --- a/tower-watchdog/src/lib.rs +++ b/tower-watchdog/src/lib.rs @@ -27,22 +27,45 @@ impl<S> tower_layer::Layer<S> for WatchdogLayer { fn layer(&self, inner: S) -> Self::Service { Self::Service { pet: self.pet.clone(), - inner + inner, } } } pub struct WatchdogService<S> { pet: watchdog::Pet, - inner: S + inner: S, } -impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::Debug + 'static> tower_service::Service<Request> for WatchdogService<S> { +impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::Debug + 'static> + tower_service::Service<Request> for WatchdogService<S> +{ type Response = S::Response; type Error = S::Error; - type Future = std::pin::Pin<Box<futures::future::Then<std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), tokio::sync::mpsc::error::SendError<()>>> + Send>>, std::pin::Pin<Box<S::Future>>, Box<dyn FnOnce(Result<(), tokio::sync::mpsc::error::SendError<()>>) -> std::pin::Pin<Box<S::Future>>>>>>; + type Future = std::pin::Pin< + Box< + futures::future::Then< + std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result<(), tokio::sync::mpsc::error::SendError<()>>, + > + Send, + >, + >, + std::pin::Pin<Box<S::Future>>, + Box< + dyn FnOnce( + Result<(), tokio::sync::mpsc::error::SendError<()>>, + ) -> std::pin::Pin<Box<S::Future>>, + >, + >, + >, + >; - fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> { + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) } @@ -57,7 +80,11 @@ impl<S: tower_service::Service<Request> + Clone + 'static, Request: std::fmt::De std::mem::swap(&mut self.inner, &mut inner); let pet = self.pet.clone(); - Box::pin(pet.pet_owned().boxed().then(Box::new(move |_| Box::pin(inner.call(request))))) + Box::pin( + pet.pet_owned() + .boxed() + .then(Box::new(move |_| Box::pin(inner.call(request)))), + ) } } @@ -84,7 +111,10 @@ mod tests { for i in 100..=1_000 { if i != 1000 { assert!(mock.poll_ready().is_ready()); - let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i)).then(|()| mock.call(()))); + let request = Box::pin( + tokio::time::sleep(std::time::Duration::from_millis(i)) + .then(|()| mock.call(())), + ); tokio::select! { _ = &mut watchdog_future => panic!("Watchdog called earlier than response!"), _ = request => {}, @@ -94,7 +124,10 @@ mod tests { // We use `+ 1` here, because the watchdog behavior is // subject to a data race if a request arrives in the // same tick. - let request = Box::pin(tokio::time::sleep(std::time::Duration::from_millis(i + 1)).then(|()| mock.call(()))); + let request = Box::pin( + tokio::time::sleep(std::time::Duration::from_millis(i + 1)) + .then(|()| mock.call(())), + ); tokio::select! { _ = &mut watchdog_future => { }, diff --git a/util/src/fs.rs b/util/src/fs.rs index 6a7a5b4..ea9dadd 100644 --- a/util/src/fs.rs +++ b/util/src/fs.rs @@ -1,6 +1,6 @@ +use rand::{distributions::Alphanumeric, Rng}; use std::io::{self, Result}; use std::path::{Path, PathBuf}; -use rand::{Rng, distributions::Alphanumeric}; use tokio::fs; /// Create a temporary file named `temp.[a-zA-Z0-9]{length}` in @@ -20,7 +20,7 @@ use tokio::fs; pub async fn mktemp<T, B>(dir: T, basename: B, length: usize) -> Result<(PathBuf, fs::File)> where T: AsRef<Path>, - B: Into<Option<&'static str>> + B: Into<Option<&'static str>>, { let dir = dir.as_ref(); let basename = basename.into().unwrap_or(""); @@ -33,9 +33,9 @@ where if basename.is_empty() { "" } else { "." }, { let string = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(length) - .collect::<Vec<u8>>(); + .sample_iter(&Alphanumeric) + .take(length) + .collect::<Vec<u8>>(); String::from_utf8(string).unwrap() } )); @@ -49,8 +49,8 @@ where Ok(file) => return Ok((filename, file)), Err(err) => match err.kind() { io::ErrorKind::AlreadyExists => continue, - _ => return Err(err) - } + _ => return Err(err), + }, } } } diff --git a/util/src/lib.rs b/util/src/lib.rs index cb5f666..0c5df49 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -17,7 +17,7 @@ pub enum MentionType { Bookmark, /// A plain link without MF2 annotations. #[default] - Mention + Mention, } /// Common data-types useful in creating smart authentication systems. @@ -29,7 +29,7 @@ pub mod auth { /// used to recover from a lost passkey. Password, /// Denotes availability of one or more passkeys. - WebAuthn + WebAuthn, } } diff --git a/util/src/micropub.rs b/util/src/micropub.rs index 9d2c525..1f8008b 100644 --- a/util/src/micropub.rs +++ b/util/src/micropub.rs @@ -21,7 +21,7 @@ pub enum QueryType { Category, /// Unsupported query type // TODO: make this take a lifetime parameter for zero-copy deserialization if possible? - Unknown(std::borrow::Cow<'static, str>) + Unknown(std::borrow::Cow<'static, str>), } /// Data structure representing a Micropub channel in the ?q=channels output. @@ -42,7 +42,7 @@ pub struct SyndicationDestination { /// The syndication destination's UID, opaque to the client. pub uid: String, /// A human-friendly name. - pub name: String + pub name: String, } fn default_q_list() -> Vec<QueryType> { @@ -67,7 +67,7 @@ pub struct Config { pub media_endpoint: Option<url::Url>, /// Other unspecified keys, sometimes implementation-defined. #[serde(flatten)] - pub other: HashMap<String, serde_json::Value> + pub other: HashMap<String, serde_json::Value>, } #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] @@ -145,14 +145,17 @@ impl Error { pub const fn from_static(error: ErrorKind, error_description: &'static str) -> Self { Self { error, - error_description: Some(std::borrow::Cow::Borrowed(error_description)) + error_description: Some(std::borrow::Cow::Borrowed(error_description)), } } } impl From<ErrorKind> for Error { fn from(error: ErrorKind) -> Self { - Self { error, error_description: None } + Self { + error, + error_description: None, + } } } @@ -190,4 +193,3 @@ impl axum_core::response::IntoResponse for Error { )) } } - diff --git a/util/src/queue.rs b/util/src/queue.rs index edbec86..b32fdc5 100644 --- a/util/src/queue.rs +++ b/util/src/queue.rs @@ -1,5 +1,5 @@ -use std::future::Future; use futures_util::Stream; +use std::future::Future; use std::pin::Pin; use uuid::Uuid; @@ -44,7 +44,9 @@ pub trait JobQueue<T: JobItem>: Send + Sync + Sized + Clone + 'static { /// /// Note that one item may be returned several times if it is not /// marked as done. - fn into_stream(self) -> impl Future<Output = Result<JobStream<Self::Job, Self::Error>, Self::Error>> + Send; + fn into_stream( + self, + ) -> impl Future<Output = Result<JobStream<Self::Job, Self::Error>, Self::Error>> + Send; } /// A job description yielded from a job queue. |