From 4ca0c24b1989fcd12c453d428af70f58456f7651 Mon Sep 17 00:00:00 2001 From: Vika Date: Thu, 1 Aug 2024 20:01:12 +0300 Subject: Migrate from axum::Extension to axum::extract::State This somehow allowed me to shrink the construction phase of Kittybox by a huge amount of code. --- src/admin/mod.rs | 31 ++-- src/frontend/mod.rs | 7 +- src/frontend/onboarding.rs | 23 ++- src/indieauth/mod.rs | 45 +++-- src/lib.rs | 118 ++++++++++++- src/main.rs | 423 ++++++++++++++++++--------------------------- src/media/mod.rs | 26 ++- src/micropub/mod.rs | 38 ++-- src/webmentions/mod.rs | 39 ++--- 9 files changed, 378 insertions(+), 372 deletions(-) diff --git a/src/admin/mod.rs b/src/admin/mod.rs index abc4515..7f8c9cf 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -10,9 +10,9 @@ // prevent collisions with future well-known scopes). use std::collections::HashSet; -use axum::extract::Host; +use axum::extract::{State, Host}; use axum::response::{Response, IntoResponse}; -use axum::{Extension, Form}; +use axum::Form; use axum_extra::extract::CookieJar; use hyper::StatusCode; @@ -36,19 +36,19 @@ static SESSION_STORE: std::sync::LazyLock( Host(host): Host, - Extension(db): Extension, + State(db): State, Form(NameChange { name }): Form ) -> Result<(), StorageError> { db.set_setting::(&host, name).await } -async fn get_name(Host(host): Host, Extension(db): Extension) -> Result { +async fn get_name(Host(host): Host, State(db): State) -> Result { db.get_setting::(&host).await.map(|name| name.as_ref().to_owned()) } async fn change_password( Host(host): Host, - Extension(auth): Extension, + State(auth): State, Form(PasswordChange { old_password, new_password }): Form ) -> StatusCode { let website = url::Url::parse(&format!("https://{host}/")).unwrap(); @@ -84,8 +84,8 @@ impl axum::response::IntoResponse for StorageError { async fn dashboard( Host(host): Host, cookies: CookieJar, - Extension(db): Extension, - Extension(auth): Extension + State(db): State, + State(auth): State ) -> axum::response::Response { let page = kittybox_frontend_renderer::admin::AdminHome {}; @@ -94,21 +94,26 @@ async fn dashboard( } -pub fn router(db: D, auth: A) -> axum::Router { +pub fn router() -> axum::Router +where + A: AuthBackend + FromRef + 'static, + S: Storage + FromRef + 'static, + M: MediaStore + FromRef + 'static, + Q: crate::webmentions::JobQueue + FromRef + 'static, + axum_extra::extract::cookie::Key: FromRef +{ axum::Router::new() .nest("/.kittybox/admin", axum::Router::new() // routes go here .route( "/", - axum::routing::get(dashboard::) + axum::routing::get(dashboard::) ) .route( "/api/settings/name", - axum::routing::post(set_name::) - .get(get_name::) + axum::routing::post(set_name::) + .get(get_name::) ) .route("/api/settings/password", axum::routing::post(change_password::)) - .layer(Extension(db)) - .layer(Extension(auth)) ) } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 0292171..42e8754 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1,9 +1,8 @@ use crate::database::{Storage, StorageError}; use axum::{ - extract::{Host, Path, Query}, + extract::{Host, Query, State}, http::{StatusCode, Uri}, response::IntoResponse, - Extension, }; use futures_util::FutureExt; use serde::Deserialize; @@ -239,7 +238,7 @@ async fn get_post_from_database( pub async fn homepage( Host(host): Host, Query(query): Query, - Extension(db): Extension, + State(db): State, ) -> impl IntoResponse { let user = None; // TODO authentication // This is stupid, but there is no other way. @@ -333,7 +332,7 @@ pub async fn homepage( #[tracing::instrument(skip(db))] pub async fn catchall( - Extension(db): Extension, + State(db): State, Host(host): Host, Query(query): Query, uri: Uri, diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs index faf8cdd..9f3f36b 100644 --- a/src/frontend/onboarding.rs +++ b/src/frontend/onboarding.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::database::{settings, Storage}; use axum::{ - extract::{Extension, Host}, + extract::{FromRef, Host, State}, http::StatusCode, response::{Html, IntoResponse}, Json, @@ -131,10 +131,10 @@ async fn onboard( } pub async fn post( - Extension(db): Extension, + State(db): State, Host(host): Host, - Extension(http): Extension, - Extension(jobset): Extension>>>, + State(http): State, + State(jobset): State>>>, Json(data): Json, ) -> axum::response::Response { let user_uid = format!("https://{}/", host.as_str()); @@ -168,14 +168,13 @@ pub async fn post( } } -pub fn router( - database: S, - http: reqwest::Client, - jobset: Arc>>, -) -> axum::routing::MethodRouter { +pub fn router() -> axum::routing::MethodRouter +where + S: Storage + FromRef + 'static, + Arc>>: FromRef, + reqwest::Client: FromRef, + St: Clone + Send + Sync + 'static, +{ axum::routing::get(get) .post(post::) - .layer::<_, _, std::convert::Infallible>(axum::Extension(database)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) - .layer(axum::Extension(jobset)) } diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 26879bb..2550df0 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,12 +1,10 @@ use std::marker::PhantomData; +use microformats::types::Class; use tracing::error; use serde::Deserialize; use axum::{ - extract::{Query, Json, Host, Form}, - response::{Html, IntoResponse, Response}, - http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer}, - Extension + extract::{Form, FromRef, Host, Json, Query, State}, headers::{authorization::Bearer, Authorization}, http::StatusCode, response::{Html, IntoResponse, Response}, Extension, TypedHeader }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] use axum_extra::extract::cookie::{CookieJar, Cookie}; @@ -73,18 +71,16 @@ impl axum::response::IntoResponse for IndieAuthResourceError { } #[async_trait::async_trait] -impl axum::extract::FromRequestParts for User { +impl , St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts for User { type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result { + async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result { let TypedHeader(Authorization(token)) = TypedHeader::>::from_request_parts(req, state) .await .map_err(|_| IndieAuthResourceError::Unauthorized)?; - let axum::Extension(auth) = axum::Extension::::from_request_parts(req, state) - .await - .unwrap(); + let auth = A::from_ref(state); let Host(host) = Host::from_request_parts(req, state) .await @@ -146,9 +142,9 @@ pub async fn metadata( async fn authorization_endpoint_get( Host(host): Host, Query(request): Query, - Extension(db): Extension, - Extension(http): Extension, - Extension(auth): Extension + State(db): State, + State(http): State, + State(auth): State ) -> Response { let me = format!("https://{host}/").parse().unwrap(); let h_app = { @@ -802,8 +798,13 @@ async fn userinfo_endpoint_get( } } -#[must_use] -pub fn router(backend: A, db: D, http: reqwest::Client) -> axum::Router { +pub fn router() -> axum::Router +where + S: Storage + FromRef + 'static, + A: AuthBackend + FromRef, + reqwest::Client: FromRef, + St: Clone + Send + Sync + 'static +{ use axum::routing::{Router, get, post}; Router::new() @@ -814,14 +815,14 @@ pub fn router(backend: A, db: D, http: req get(metadata)) .route( "/auth", - get(authorization_endpoint_get::) - .post(authorization_endpoint_post::)) + get(authorization_endpoint_get::) + .post(authorization_endpoint_post::)) .route( "/auth/confirm", post(authorization_endpoint_confirm::)) .route( "/token", - post(token_endpoint_post::)) + post(token_endpoint_post::)) .route( "/token_status", post(introspection_endpoint_post::)) @@ -830,11 +831,11 @@ pub fn router(backend: A, db: D, http: req post(revocation_endpoint_post::)) .route( "/userinfo", - get(userinfo_endpoint_get::)) + get(userinfo_endpoint_get::)) .route("/webauthn/pre_register", get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::, + #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::, #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) ) ) @@ -844,12 +845,6 @@ pub fn router(backend: A, db: D, http: req axum::http::Method::POST ]) .allow_origin(tower_http::cors::Any)) - .layer(Extension(backend)) - // I don't really like the fact that I have to use the whole database - // If I could, I would've designed a separate trait for getting profiles - // And made databases implement it, for example - .layer(Extension(db)) - .layer(Extension(http)) ) .route( "/.well-known/oauth-authorization-server", diff --git a/src/lib.rs b/src/lib.rs index 04b3298..1cc01c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,17 @@ #![forbid(unsafe_code)] #![warn(clippy::todo)] +use std::sync::Arc; + +use axum::extract::FromRef; +use axum_extra::extract::cookie::Key; +use database::{FileStorage, PostgresStorage, Storage}; +use indieauth::backend::{AuthBackend, FileBackend as FileAuthBackend}; +use kittybox_util::queue::{JobItem, JobQueue}; +use media::storage::{MediaStore, file::FileStore as FileMediaStore}; +use tokio::{sync::Mutex, task::JoinSet}; +use webmentions::queue::{PostgresJobItem, PostgresJobQueue}; + /// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. pub mod database; pub mod frontend; @@ -9,6 +20,110 @@ pub mod micropub; pub mod indieauth; pub mod webmentions; pub mod login; +//pub mod admin; + +#[derive(Clone)] +pub struct AppState +where +A: AuthBackend + Sized + 'static, +S: Storage + Sized + 'static, +M: MediaStore + Sized + 'static, +Q: JobQueue + Sized +{ + pub auth_backend: A, + pub storage: S, + pub media_store: M, + pub job_queue: Q, + pub http: reqwest::Client, + pub background_jobs: Arc>>, + pub cookie_key: Key +} + +// This is really regrettable, but I can't write: +// +// ```compile-error +// impl FromRef> for A +// where A: AuthBackend, S: Storage, M: MediaStore { +// fn from_ref(input: &AppState) -> A { +// input.auth_backend.clone() +// } +// } +// ``` +// +// ...because of the orphan rule. +// +// I wonder if this would stifle external implementations. I think it +// shouldn't, because my AppState type is generic, and since the +// target type is local, the orphan rule will not kick in. You just +// have to repeat this magic invocation. + +impl FromRef> for FileAuthBackend +// where S: Storage, M: MediaStore +where S: Storage, M: MediaStore, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.auth_backend.clone() + } +} + +impl FromRef> for PostgresStorage +// where A: AuthBackend, M: MediaStore +where A: AuthBackend, M: MediaStore, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.storage.clone() + } +} + +impl FromRef> for FileStorage +where A: AuthBackend, M: MediaStore, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.storage.clone() + } +} + +impl FromRef> for FileMediaStore +// where A: AuthBackend, S: Storage +where A: AuthBackend, S: Storage, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.media_store.clone() + } +} + +impl FromRef> for Key +where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.cookie_key.clone() + } +} + +impl FromRef> for reqwest::Client +where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.http.clone() + } +} + +impl FromRef> for Arc>> +where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue +{ + fn from_ref(input: &AppState) -> Self { + input.background_jobs.clone() + } +} + +#[cfg(feature = "sqlx")] +impl FromRef> for PostgresJobQueue +where A: AuthBackend, S: Storage, M: MediaStore +{ + fn from_ref(input: &AppState) -> Self { + input.job_queue.clone() + } +} pub mod companion { use std::{collections::HashMap, sync::Arc}; @@ -52,8 +167,7 @@ pub mod companion { } } - #[must_use] - pub fn router() -> axum::Router { + pub fn router() -> axum::Router { let resources: ResourceTable = { let mut map = HashMap::new(); diff --git a/src/main.rs b/src/main.rs index 9e541b9..d10822b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,266 +1,58 @@ -use kittybox::database::FileStorage; +use axum::extract::FromRef; +use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention}; +use tokio::{sync::Mutex, task::JoinSet}; use std::{env, time::Duration, sync::Arc}; use tracing::error; -fn init_media(auth_backend: A, blobstore_uri: &str) -> axum::Router { - match blobstore_uri.split_once(':').unwrap().0 { - "file" => { - let folder = std::path::PathBuf::from( - blobstore_uri.strip_prefix("file://").unwrap() - ); - let blobstore = kittybox::media::storage::file::FileStore::new(folder); - kittybox::media::router::<_, _>(blobstore, auth_backend) - }, - other => unimplemented!("Unsupported backend: {other}") - } +async fn teapot_route() -> impl axum::response::IntoResponse { + use axum::http::{header, StatusCode}; + (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") } -async fn compose_kittybox_with_auth( - http: reqwest::Client, - auth_backend: A, - backend_uri: &str, - blobstore_uri: &str, - job_queue_uri: &str, - jobset: &Arc>>, - cancellation_token: &tokio_util::sync::CancellationToken -) -> (axum::Router, kittybox::webmentions::SupervisedTask) -where A: kittybox::indieauth::backend::AuthBackend +async fn health_check( + axum::extract::State(data): axum::extract::State, +) -> impl axum::response::IntoResponse +where + D: kittybox::database::Storage { - match backend_uri.split_once(':').unwrap().0 { - "file" => { - let database = { - let folder = backend_uri.strip_prefix("file://").unwrap(); - let path = std::path::PathBuf::from(folder); - - match kittybox::database::FileStorage::new(path).await { - Ok(db) => db, - Err(err) => { - error!("Error creating database: {:?}", err); - std::process::exit(1); - } - } - }; - - // Technically, if we don't construct the micropub router, - // we could use some wrapper that makes the database - // read-only. - // - // This would allow to exclude all code to write to the - // database and separate reader and writer processes of - // Kittybox to improve security. - let homepage: axum::routing::MethodRouter<_> = axum::routing::get( - kittybox::frontend::homepage:: - ) - .layer(axum::Extension(database.clone())); - let fallback = axum::routing::get( - kittybox::frontend::catchall:: - ) - .layer(axum::Extension(database.clone())); - - let micropub = kittybox::micropub::router( - database.clone(), - http.clone(), - auth_backend.clone(), - Arc::clone(jobset) - ); - let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone(), Arc::clone(jobset) - ); - - - let (webmention, task) = kittybox::webmentions::router( - kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), - database.clone(), - http.clone(), - cancellation_token.clone() - ); - - let router = axum::Router::new() - .route("/", homepage) - .fallback(fallback) - .route("/.kittybox/micropub", micropub) - .route("/.kittybox/onboarding", onboarding) - .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) - .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) - .merge(webmention) - .route( - "/.kittybox/health", - axum::routing::get(health_check::) - .layer(axum::Extension(database)) - ); - - (router, task) - }, - "redis" => unimplemented!("Redis backend is not supported."), - #[cfg(feature = "postgres")] - "postgres" => { - use kittybox::database::PostgresStorage; - - let database = { - match PostgresStorage::new(backend_uri).await { - Ok(db) => db, - Err(err) => { - error!("Error creating database: {:?}", err); - std::process::exit(1); - } - } - }; - - // Technically, if we don't construct the micropub router, - // we could use some wrapper that makes the database - // read-only. - // - // This would allow to exclude all code to write to the - // database and separate reader and writer processes of - // Kittybox to improve security. - let homepage: axum::routing::MethodRouter<_> = axum::routing::get( - kittybox::frontend::homepage:: - ) - .layer(axum::Extension(database.clone())); - let fallback = axum::routing::get( - kittybox::frontend::catchall:: - ) - .layer(axum::Extension(database.clone())); - - let micropub = kittybox::micropub::router( - database.clone(), - http.clone(), - auth_backend.clone(), - Arc::clone(jobset) - ); - let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone(), Arc::clone(jobset) - ); - - let (webmention, task) = kittybox::webmentions::router( - kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), - database.clone(), - http.clone(), - cancellation_token.clone() - ); - - let router = axum::Router::new() - .route("/", homepage) - .fallback(fallback) - .route("/.kittybox/micropub", micropub) - .route("/.kittybox/onboarding", onboarding) - .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) - .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) - .merge(webmention) - .route( - "/.kittybox/health", - axum::routing::get(health_check::) - .layer(axum::Extension(database)) - ); - - (router, task) - }, - other => unimplemented!("Unsupported backend: {other}") - } + (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) } -async fn compose_kittybox( - backend_uri: &str, - blobstore_uri: &str, - authstore_uri: &str, - job_queue_uri: &str, - jobset: &Arc>>, - cancellation_token: &tokio_util::sync::CancellationToken -) -> (axum::Router, kittybox::webmentions::SupervisedTask) { - let http: reqwest::Client = { - #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); - if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { - // TODO: add a root certificate if there's an environment variable pointing at it - for path in certs.split(':') { - let metadata = match tokio::fs::metadata(path).await { - Ok(metadata) => metadata, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - tracing::error!("TLS root certificate {} not found, skipping...", path); - continue; - } - Err(err) => panic!("Error loading TLS certificates: {}", err) - }; - if metadata.is_dir() { - let mut dir = tokio::fs::read_dir(path).await.unwrap(); - while let Ok(Some(file)) = dir.next_entry().await { - let pem = tokio::fs::read(file.path()).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); - } - } else { - let pem = tokio::fs::read(path).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); - } - } - } - - builder.build().unwrap() - }; - - let (router, task) = match authstore_uri.split_once(':').unwrap().0 { - "file" => { - let auth_backend = { - let folder = authstore_uri - .strip_prefix("file://") - .unwrap(); - kittybox::indieauth::backend::fs::FileBackend::new(folder) - }; - - compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, job_queue_uri, jobset, cancellation_token).await - } - other => unimplemented!("Unsupported backend: {other}") - }; - - // TODO: load from environment - let cookie_key = axum_extra::extract::cookie::Key::generate(); - let router = router +async fn compose_stateful_kittybox() -> axum::Router +where +A: AuthBackend + 'static + FromRef, +S: Storage + 'static + FromRef, +M: MediaStore + 'static + FromRef, +Q: kittybox_util::queue::JobQueue + FromRef, +reqwest::Client: FromRef, +Arc>>: FromRef, +St: Clone + Send + Sync + 'static +{ + use axum::routing::get; + axum::Router::new() + .route("/", get(kittybox::frontend::homepage::)) + .fallback(get(kittybox::frontend::catchall::)) + .route("/.kittybox/micropub", kittybox::micropub::router::()) + .route("/.kittybox/onboarding", kittybox::frontend::onboarding::router::()) + .nest("/.kittybox/media", kittybox::media::router::()) + .merge(kittybox::indieauth::router::()) + .merge(kittybox::webmentions::router::()) + .route("/.kittybox/health", get(health_check::)) .route( "/.kittybox/static/:path", axum::routing::get(kittybox::frontend::statics) ) - .route("/.kittybox/coffee", teapot_route()) - .nest("/.kittybox/micropub/client", kittybox::companion::router()) - .nest("/.kittybox/login", kittybox::login::router(cookie_key)) + .route("/.kittybox/coffee", get(teapot_route)) + .nest("/.kittybox/micropub/client", kittybox::companion::router::()) .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::catch_panic::CatchPanicLayer::new()) .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ axum::http::header::AUTHORIZATION, axum::http::header::COOKIE, axum::http::header::SET_COOKIE, - ])); - - (router, task) -} - -fn teapot_route() -> axum::routing::MethodRouter { - axum::routing::get(|| async { - use axum::http::{header, StatusCode}; - (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") - }) -} - -async fn health_check( - //axum::Extension(auth): axum::Extension, - //axum::Extension(blob): axum::Extension, - axum::Extension(data): axum::Extension, -) -> impl axum::response::IntoResponse -where - //A: kittybox::indieauth::backend::AuthBackend, - //B: kittybox::media::storage::MediaStore, - D: kittybox::database::Storage -{ - (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) + ])) } #[tokio::main] @@ -306,40 +98,158 @@ async fn main() { tracing::info!("Starting the kittybox server..."); - let backend_uri: String = env::var("BACKEND_URI") + let backend_uri: url::Url = env::var("BACKEND_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("BACKEND_URI malformed")) .unwrap_or_else(|_| { error!("BACKEND_URI is not set, cannot find a database"); std::process::exit(1); }); - let blobstore_uri: String = env::var("BLOBSTORE_URI") + let blobstore_uri: url::Url = env::var("BLOBSTORE_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("BLOBSTORE_URI malformed")) .unwrap_or_else(|_| { error!("BLOBSTORE_URI is not set, can't find media store"); std::process::exit(1); }); - let authstore_uri: String = env::var("AUTH_STORE_URI") + let authstore_uri: url::Url = env::var("AUTH_STORE_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("AUTH_STORE_URI malformed")) .unwrap_or_else(|_| { error!("AUTH_STORE_URI is not set, can't find authentication store"); std::process::exit(1); }); - let job_queue_uri: String = env::var("JOB_QUEUE_URI") + let job_queue_uri: url::Url = env::var("JOB_QUEUE_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("JOB_QUEUE_URI malformed")) .unwrap_or_else(|_| { error!("JOB_QUEUE_URI is not set, can't find job queue"); std::process::exit(1); }); + // TODO: load from environment + let cookie_key = axum_extra::extract::cookie::Key::generate(); + let cancellation_token = tokio_util::sync::CancellationToken::new(); - let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new())); + let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::<()>::new())); + let http: reqwest::Client = { + #[allow(unused_mut)] + let mut builder = reqwest::Client::builder() + .user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); + if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { + // TODO: add a root certificate if there's an environment variable pointing at it + for path in certs.split(':') { + let metadata = match tokio::fs::metadata(path).await { + Ok(metadata) => metadata, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + tracing::error!("TLS root certificate {} not found, skipping...", path); + continue; + } + Err(err) => panic!("Error loading TLS certificates: {}", err) + }; + if metadata.is_dir() { + let mut dir = tokio::fs::read_dir(path).await.unwrap(); + while let Ok(Some(file)) = dir.next_entry().await { + let pem = tokio::fs::read(file.path()).await.unwrap(); + builder = builder.add_root_certificate( + reqwest::Certificate::from_pem(&pem).unwrap() + ); + } + } else { + let pem = tokio::fs::read(path).await.unwrap(); + builder = builder.add_root_certificate( + reqwest::Certificate::from_pem(&pem).unwrap() + ); + } + } + } + + builder.build().unwrap() + }; + + let backend_type = backend_uri.scheme(); + let blobstore_type = blobstore_uri.scheme(); + let authstore_type = authstore_uri.scheme(); + let job_queue_type = job_queue_uri.scheme(); + + macro_rules! compose_kittybox { + ($auth:ty, $store:ty, $media:ty, $queue:ty) => { { + type AuthBackend = $auth; + type Storage = $store; + type MediaStore = $media; + type JobQueue = $queue; + + let state = kittybox::AppState { + auth_backend: match AuthBackend::new(&authstore_uri).await { + Ok(auth) => auth, + Err(err) => { + error!("Error creating auth backend: {:?}", err); + std::process::exit(1); + } + }, + storage: match Storage::new(&backend_uri).await { + Ok(db) => db, + Err(err) => { + error!("Error creating database: {:?}", err); + std::process::exit(1); + } + }, + media_store: match MediaStore::new(&blobstore_uri).await { + Ok(media) => media, + Err(err) => { + error!("Error creating media store: {:?}", err); + std::process::exit(1); + } + }, + job_queue: match JobQueue::new(&job_queue_uri).await { + Ok(queue) => queue, + Err(err) => { + error!("Error creating webmention job queue: {:?}", err); + std::process::exit(1); + } + }, + http, + background_jobs: jobset.clone(), + cookie_key + }; + + type St = kittybox::AppState; + let stateful_router = compose_stateful_kittybox::().await; + let task = kittybox::webmentions::supervised_webmentions_task::(&state, cancellation_token.clone()); + let router = stateful_router.with_state(state); - let (router, webmentions_task) = compose_kittybox( - backend_uri.as_str(), - blobstore_uri.as_str(), - authstore_uri.as_str(), - job_queue_uri.as_str(), - &jobset, - &cancellation_token - ).await; + (router, task) + } } + } + + let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) { + ("file", "file", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::FileStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue + ) + }, + ("file", "postgres", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::PostgresStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue + ) + }, + (_, _, _, _) => { + // TODO: refine this error. + panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI"); + } + }; let mut servers: Vec> = vec![]; @@ -494,5 +404,4 @@ async fn main() { while (jobset.join_next().await).is_some() {} tracing::info!("Shutdown complete, exiting."); std::process::exit(exitcode); - } diff --git a/src/media/mod.rs b/src/media/mod.rs index 71f875e..47f456a 100644 --- a/src/media/mod.rs +++ b/src/media/mod.rs @@ -1,14 +1,9 @@ -use std::convert::TryFrom; - use axum::{ - extract::{Extension, Host, multipart::Multipart, Path}, - response::{IntoResponse, Response}, - headers::{Header, HeaderValue, IfNoneMatch, HeaderMapExt}, - TypedHeader, + extract::{multipart::Multipart, FromRef, Host, Path, State}, headers::{HeaderMapExt, HeaderValue, IfNoneMatch}, response::{IntoResponse, Response}, TypedHeader }; use kittybox_util::error::{MicropubError, ErrorType}; use kittybox_indieauth::Scope; -use crate::indieauth::{User, backend::AuthBackend}; +use crate::indieauth::{backend::AuthBackend, User}; pub mod storage; use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; @@ -25,7 +20,7 @@ impl From for MicropubError { #[tracing::instrument(skip(blobstore))] pub(crate) async fn upload( - Extension(blobstore): Extension, + State(blobstore): State, user: User, mut upload: Multipart ) -> Response { @@ -70,7 +65,7 @@ pub(crate) async fn serve( Host(host): Host, Path(path): Path, if_none_match: Option>, - Extension(blobstore): Extension + State(blobstore): State ) -> Response { use axum::http::StatusCode; tracing::debug!("Searching for file..."); @@ -131,11 +126,12 @@ pub(crate) async fn serve( } } -#[must_use] -pub fn router(blobstore: S, auth: A) -> axum::Router { +pub fn router() -> axum::Router where + A: AuthBackend + FromRef, + M: MediaStore + FromRef, + St: Clone + Send + Sync + 'static +{ axum::Router::new() - .route("/", axum::routing::post(upload::)) - .route("/uploads/*file", axum::routing::get(serve::)) - .layer(axum::Extension(blobstore)) - .layer(axum::Extension(auth)) + .route("/", axum::routing::post(upload::)) + .route("/uploads/*file", axum::routing::get(serve::)) } diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 624c239..fc5dd10 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -5,12 +5,13 @@ use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; use crate::indieauth::User; +use crate::media::storage::MediaStore; use crate::micropub::util::form_to_mf2_json; -use axum::extract::{BodyStream, Query, Host}; +use axum::extract::{BodyStream, FromRef, Host, Query, State}; use axum::headers::ContentType; use axum::response::{IntoResponse, Response}; use axum::TypedHeader; -use axum::{http::StatusCode, Extension}; +use axum::http::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::Mutex; @@ -515,9 +516,9 @@ async fn dispatch_body( #[tracing::instrument(skip(db, http))] pub(crate) async fn post( - Extension(db): Extension, - Extension(http): Extension, - Extension(jobset): Extension>>>, + State(db): State, + State(http): State, + State(jobset): State>>>, TypedHeader(content_type): TypedHeader, user: User, body: BodyStream, @@ -540,7 +541,7 @@ pub(crate) async fn post( #[tracing::instrument(skip(db))] pub(crate) async fn query( - Extension(db): Extension, + State(db): State, query: Option>, Host(host): Host, user: User, @@ -662,16 +663,13 @@ pub(crate) async fn query( } } -#[must_use] -pub fn router( - storage: S, - http: reqwest::Client, - auth: A, - jobset: Arc>> -) -> axum::routing::MethodRouter + +pub fn router() -> axum::routing::MethodRouter where - S: Storage + 'static, - A: AuthBackend + S: Storage + FromRef + 'static, + A: AuthBackend + FromRef, + reqwest::Client: FromRef, + Arc>>: FromRef { axum::routing::get(query::) .post(post::) @@ -680,11 +678,7 @@ where axum::http::Method::GET, axum::http::Method::POST, ]) - .allow_origin(tower_http::cors::Any)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) + .allow_origin(tower_http::cors::Any)) } #[cfg(test)] @@ -716,7 +710,7 @@ mod tests { use super::FetchedPostContext; use kittybox_indieauth::{Scopes, Scope, TokenData}; - use axum::extract::Host; + use axum::extract::{Host, State}; #[test] fn test_populate_reply_context() { @@ -850,7 +844,7 @@ mod tests { #[tokio::test] async fn test_query_foreign_url() { let mut res = super::query( - axum::Extension(crate::database::MemoryStorage::default()), + State(crate::database::MemoryStorage::default()), Some(axum::extract::Query(super::MicropubQuery::source( "https://aaronparecki.com/feeds/main", ))), diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs index 3e9b094..d5a617e 100644 --- a/src/webmentions/mod.rs +++ b/src/webmentions/mod.rs @@ -1,4 +1,4 @@ -use axum::{Form, response::{IntoResponse, Response}, Extension}; +use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form}; use axum::http::StatusCode; use tracing::error; @@ -20,7 +20,7 @@ impl queue::PostgresJobItem for Webmention { } async fn accept_webmention>( - Extension(queue): Extension, + State(queue): State, Form(webmention): Form, ) -> Response { if let Err(err) = webmention.source.parse::() { @@ -31,27 +31,16 @@ async fn accept_webmention>( } match queue.put(&webmention).await { - Ok(id) => StatusCode::ACCEPTED.into_response(), + Ok(_id) => StatusCode::ACCEPTED.into_response(), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ ("Content-Type", "text/plain") ], err.to_string()).into_response() } } -pub fn router, S: Storage + 'static>( - queue: Q, db: S, http: reqwest::Client, - cancellation_token: tokio_util::sync::CancellationToken -) -> (axum::Router, SupervisedTask) { - // Automatically spawn a background task to handle webmentions - let bgtask_handle = supervised_webmentions_task(queue.clone(), db, http, cancellation_token); - - let router = axum::Router::new() - .route("/.kittybox/webmention", - axum::routing::post(accept_webmention::) - ) - .layer(Extension(queue)); - - (router, bgtask_handle) +pub fn router + FromRef>() -> axum::Router { + axum::Router::new() + .route("/.kittybox/webmention", post(accept_webmention::)) } #[derive(thiserror::Error, Debug)] @@ -184,10 +173,16 @@ async fn process_webmentions_from_queue, S: Storage + 's unreachable!() } -fn supervised_webmentions_task, S: Storage + 'static>( - queue: Q, db: S, - http: reqwest::Client, +pub fn supervised_webmentions_task + 'static, Q: JobQueue + FromRef + 'static>( + state: &St, cancellation_token: tokio_util::sync::CancellationToken -) -> SupervisedTask { - supervisor::, _, _>(move || process_webmentions_from_queue(queue.clone(), db.clone(), http.clone()), cancellation_token) +) -> SupervisedTask +where reqwest::Client: FromRef +{ + let queue = Q::from_ref(state); + let storage = S::from_ref(state); + let http = reqwest::Client::from_ref(state); + supervisor::, _, _>(move || process_webmentions_from_queue( + queue.clone(), storage.clone(), http.clone() + ), cancellation_token) } -- cgit 1.4.1