diff options
author | Vika <vika@fireburn.ru> | 2024-08-01 20:01:12 +0300 |
---|---|---|
committer | Vika <vika@fireburn.ru> | 2024-08-01 20:40:30 +0300 |
commit | 4ca0c24b1989fcd12c453d428af70f58456f7651 (patch) | |
tree | 19f480107cc6491b832a7a2d7198cee48f205b85 | |
parent | 7e8e688e2e58f9c944b941e768ab7b034a348a1f (diff) | |
download | kittybox-4ca0c24b1989fcd12c453d428af70f58456f7651.tar.zst |
Migrate from axum::Extension to axum::extract::State
This somehow allowed me to shrink the construction phase of Kittybox by a huge amount of code.
-rw-r--r-- | src/admin/mod.rs | 31 | ||||
-rw-r--r-- | src/frontend/mod.rs | 7 | ||||
-rw-r--r-- | src/frontend/onboarding.rs | 23 | ||||
-rw-r--r-- | src/indieauth/mod.rs | 45 | ||||
-rw-r--r-- | src/lib.rs | 118 | ||||
-rw-r--r-- | src/main.rs | 423 | ||||
-rw-r--r-- | src/media/mod.rs | 26 | ||||
-rw-r--r-- | src/micropub/mod.rs | 38 | ||||
-rw-r--r-- | src/webmentions/mod.rs | 39 |
9 files changed, 378 insertions, 372 deletions
diff --git a/src/admin/mod.rs b/src/admin/mod.rs index abc4515..7f8c9cf 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -10,9 +10,9 @@ // prevent collisions with future well-known scopes). use std::collections::HashSet; -use axum::extract::Host; +use axum::extract::{State, Host}; use axum::response::{Response, IntoResponse}; -use axum::{Extension, Form}; +use axum::Form; use axum_extra::extract::CookieJar; use hyper::StatusCode; @@ -36,19 +36,19 @@ static SESSION_STORE: std::sync::LazyLock<tokio::sync::RwLock<HashSet<uuid::Uuid async fn set_name<D: Storage + 'static>( Host(host): Host, - Extension(db): Extension<D>, + State(db): State<D>, Form(NameChange { name }): Form<NameChange> ) -> Result<(), StorageError> { db.set_setting::<SiteName>(&host, name).await } -async fn get_name<D: Storage + 'static>(Host(host): Host, Extension(db): Extension<D>) -> Result<String, StorageError> { +async fn get_name<D: Storage + 'static>(Host(host): Host, State(db): State<D>) -> Result<String, StorageError> { db.get_setting::<SiteName>(&host).await.map(|name| name.as_ref().to_owned()) } async fn change_password<A: AuthBackend>( Host(host): Host, - Extension(auth): Extension<A>, + State(auth): State<A>, Form(PasswordChange { old_password, new_password }): Form<PasswordChange> ) -> StatusCode { let website = url::Url::parse(&format!("https://{host}/")).unwrap(); @@ -84,8 +84,8 @@ impl axum::response::IntoResponse for StorageError { async fn dashboard<D: Storage + 'static, A: AuthBackend>( Host(host): Host, cookies: CookieJar, - Extension(db): Extension<D>, - Extension(auth): Extension<A> + State(db): State<D>, + State(auth): State<A> ) -> axum::response::Response { let page = kittybox_frontend_renderer::admin::AdminHome {}; @@ -94,21 +94,26 @@ async fn dashboard<D: Storage + 'static, A: AuthBackend>( } -pub fn router<D: Storage + 'static, A: AuthBackend>(db: D, auth: A) -> axum::Router { +pub fn router<St, A, S, M>() -> axum::Router<St> +where + A: AuthBackend + FromRef<St> + 'static, + S: Storage + FromRef<St> + 'static, + M: MediaStore + FromRef<St> + 'static, + Q: crate::webmentions::JobQueue<crate::webmentions::Webmention> + FromRef<St> + 'static, + axum_extra::extract::cookie::Key: FromRef<St> +{ axum::Router::new() .nest("/.kittybox/admin", axum::Router::new() // routes go here .route( "/", - axum::routing::get(dashboard::<D, A>) + axum::routing::get(dashboard::<S, A>) ) .route( "/api/settings/name", - axum::routing::post(set_name::<D>) - .get(get_name::<D>) + axum::routing::post(set_name::<S>) + .get(get_name::<S>) ) .route("/api/settings/password", axum::routing::post(change_password::<A>)) - .layer(Extension(db)) - .layer(Extension(auth)) ) } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 0292171..42e8754 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1,9 +1,8 @@ use crate::database::{Storage, StorageError}; use axum::{ - extract::{Host, Path, Query}, + extract::{Host, Query, State}, http::{StatusCode, Uri}, response::IntoResponse, - Extension, }; use futures_util::FutureExt; use serde::Deserialize; @@ -239,7 +238,7 @@ async fn get_post_from_database<S: Storage>( pub async fn homepage<D: Storage>( Host(host): Host, Query(query): Query<QueryParams>, - Extension(db): Extension<D>, + State(db): State<D>, ) -> impl IntoResponse { let user = None; // TODO authentication // This is stupid, but there is no other way. @@ -333,7 +332,7 @@ pub async fn homepage<D: Storage>( #[tracing::instrument(skip(db))] pub async fn catchall<D: Storage>( - Extension(db): Extension<D>, + State(db): State<D>, Host(host): Host, Query(query): Query<QueryParams>, uri: Uri, diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs index faf8cdd..9f3f36b 100644 --- a/src/frontend/onboarding.rs +++ b/src/frontend/onboarding.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::database::{settings, Storage}; use axum::{ - extract::{Extension, Host}, + extract::{FromRef, Host, State}, http::StatusCode, response::{Html, IntoResponse}, Json, @@ -131,10 +131,10 @@ async fn onboard<D: Storage + 'static>( } pub async fn post<D: Storage + 'static>( - Extension(db): Extension<D>, + State(db): State<D>, Host(host): Host, - Extension(http): Extension<reqwest::Client>, - Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, + State(http): State<reqwest::Client>, + State(jobset): State<Arc<Mutex<JoinSet<()>>>>, Json(data): Json<OnboardingData>, ) -> axum::response::Response { let user_uid = format!("https://{}/", host.as_str()); @@ -168,14 +168,13 @@ pub async fn post<D: Storage + 'static>( } } -pub fn router<S: Storage + 'static>( - database: S, - http: reqwest::Client, - jobset: Arc<Mutex<JoinSet<()>>>, -) -> axum::routing::MethodRouter { +pub fn router<St, S>() -> axum::routing::MethodRouter<St> +where + S: Storage + FromRef<St> + 'static, + Arc<Mutex<JoinSet<()>>>: FromRef<St>, + reqwest::Client: FromRef<St>, + St: Clone + Send + Sync + 'static, +{ axum::routing::get(get) .post(post::<S>) - .layer::<_, _, std::convert::Infallible>(axum::Extension(database)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) - .layer(axum::Extension(jobset)) } diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 26879bb..2550df0 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,12 +1,10 @@ use std::marker::PhantomData; +use microformats::types::Class; use tracing::error; use serde::Deserialize; use axum::{ - extract::{Query, Json, Host, Form}, - response::{Html, IntoResponse, Response}, - http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer}, - Extension + extract::{Form, FromRef, Host, Json, Query, State}, headers::{authorization::Bearer, Authorization}, http::StatusCode, response::{Html, IntoResponse, Response}, Extension, TypedHeader }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] use axum_extra::extract::cookie::{CookieJar, Cookie}; @@ -73,18 +71,16 @@ impl axum::response::IntoResponse for IndieAuthResourceError { } #[async_trait::async_trait] -impl <S: Send + Sync, A: AuthBackend> axum::extract::FromRequestParts<S> for User<A> { +impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts<St> for User<A> { type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> { + async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Self, Self::Rejection> { let TypedHeader(Authorization(token)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) .await .map_err(|_| IndieAuthResourceError::Unauthorized)?; - let axum::Extension(auth) = axum::Extension::<A>::from_request_parts(req, state) - .await - .unwrap(); + let auth = A::from_ref(state); let Host(host) = Host::from_request_parts(req, state) .await @@ -146,9 +142,9 @@ pub async fn metadata( async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Host(host): Host, Query(request): Query<AuthorizationRequest>, - Extension(db): Extension<D>, - Extension(http): Extension<reqwest::Client>, - Extension(auth): Extension<A> + State(db): State<D>, + State(http): State<reqwest::Client>, + State(auth): State<A> ) -> Response { let me = format!("https://{host}/").parse().unwrap(); let h_app = { @@ -802,8 +798,13 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( } } -#[must_use] -pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: reqwest::Client) -> axum::Router { +pub fn router<St, A, S>() -> axum::Router<St> +where + S: Storage + FromRef<St> + 'static, + A: AuthBackend + FromRef<St>, + reqwest::Client: FromRef<St>, + St: Clone + Send + Sync + 'static +{ use axum::routing::{Router, get, post}; Router::new() @@ -814,14 +815,14 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req get(metadata)) .route( "/auth", - get(authorization_endpoint_get::<A, D>) - .post(authorization_endpoint_post::<A, D>)) + get(authorization_endpoint_get::<A, S>) + .post(authorization_endpoint_post::<A, S>)) .route( "/auth/confirm", post(authorization_endpoint_confirm::<A>)) .route( "/token", - post(token_endpoint_post::<A, D>)) + post(token_endpoint_post::<A, S>)) .route( "/token_status", post(introspection_endpoint_post::<A>)) @@ -830,11 +831,11 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req post(revocation_endpoint_post::<A>)) .route( "/userinfo", - get(userinfo_endpoint_get::<A, D>)) + get(userinfo_endpoint_get::<A, S>)) .route("/webauthn/pre_register", get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, D>, + #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, S>, #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) ) ) @@ -844,12 +845,6 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req axum::http::Method::POST ]) .allow_origin(tower_http::cors::Any)) - .layer(Extension(backend)) - // I don't really like the fact that I have to use the whole database - // If I could, I would've designed a separate trait for getting profiles - // And made databases implement it, for example - .layer(Extension(db)) - .layer(Extension(http)) ) .route( "/.well-known/oauth-authorization-server", diff --git a/src/lib.rs b/src/lib.rs index 04b3298..1cc01c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,17 @@ #![forbid(unsafe_code)] #![warn(clippy::todo)] +use std::sync::Arc; + +use axum::extract::FromRef; +use axum_extra::extract::cookie::Key; +use database::{FileStorage, PostgresStorage, Storage}; +use indieauth::backend::{AuthBackend, FileBackend as FileAuthBackend}; +use kittybox_util::queue::{JobItem, JobQueue}; +use media::storage::{MediaStore, file::FileStore as FileMediaStore}; +use tokio::{sync::Mutex, task::JoinSet}; +use webmentions::queue::{PostgresJobItem, PostgresJobQueue}; + /// Database abstraction layer for Kittybox, allowing the CMS to work with any kind of database. pub mod database; pub mod frontend; @@ -9,6 +20,110 @@ pub mod micropub; pub mod indieauth; pub mod webmentions; pub mod login; +//pub mod admin; + +#[derive(Clone)] +pub struct AppState<A, S, M, Q> +where +A: AuthBackend + Sized + 'static, +S: Storage + Sized + 'static, +M: MediaStore + Sized + 'static, +Q: JobQueue<webmentions::Webmention> + Sized +{ + pub auth_backend: A, + pub storage: S, + pub media_store: M, + pub job_queue: Q, + pub http: reqwest::Client, + pub background_jobs: Arc<Mutex<JoinSet<()>>>, + pub cookie_key: Key +} + +// This is really regrettable, but I can't write: +// +// ```compile-error +// impl <A, S, M> FromRef<AppState<A, S, M>> for A +// where A: AuthBackend, S: Storage, M: MediaStore { +// fn from_ref(input: &AppState<A, S, M>) -> A { +// input.auth_backend.clone() +// } +// } +// ``` +// +// ...because of the orphan rule. +// +// I wonder if this would stifle external implementations. I think it +// shouldn't, because my AppState type is generic, and since the +// target type is local, the orphan rule will not kick in. You just +// have to repeat this magic invocation. + +impl<S, M, Q> FromRef<AppState<Self, S, M, Q>> for FileAuthBackend +// where S: Storage, M: MediaStore +where S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<Self, S, M, Q>) -> Self { + input.auth_backend.clone() + } +} + +impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for PostgresStorage +// where A: AuthBackend, M: MediaStore +where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { + input.storage.clone() + } +} + +impl<A, M, Q> FromRef<AppState<A, Self, M, Q>> for FileStorage +where A: AuthBackend, M: MediaStore, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<A, Self, M, Q>) -> Self { + input.storage.clone() + } +} + +impl<A, S, Q> FromRef<AppState<A, S, Self, Q>> for FileMediaStore +// where A: AuthBackend, S: Storage +where A: AuthBackend, S: Storage, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<A, S, Self, Q>) -> Self { + input.media_store.clone() + } +} + +impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Key +where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<A, S, M, Q>) -> Self { + input.cookie_key.clone() + } +} + +impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for reqwest::Client +where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<A, S, M, Q>) -> Self { + input.http.clone() + } +} + +impl<A, S, M, Q> FromRef<AppState<A, S, M, Q>> for Arc<Mutex<JoinSet<()>>> +where A: AuthBackend, S: Storage, M: MediaStore, Q: JobQueue<webmentions::Webmention> +{ + fn from_ref(input: &AppState<A, S, M, Q>) -> Self { + input.background_jobs.clone() + } +} + +#[cfg(feature = "sqlx")] +impl<A, S, M> FromRef<AppState<A, S, M, Self>> for PostgresJobQueue<webmentions::Webmention> +where A: AuthBackend, S: Storage, M: MediaStore +{ + fn from_ref(input: &AppState<A, S, M, Self>) -> Self { + input.job_queue.clone() + } +} pub mod companion { use std::{collections::HashMap, sync::Arc}; @@ -52,8 +167,7 @@ pub mod companion { } } - #[must_use] - pub fn router() -> axum::Router { + pub fn router<St: Clone + Send + Sync + 'static>() -> axum::Router<St> { let resources: ResourceTable = { let mut map = HashMap::new(); diff --git a/src/main.rs b/src/main.rs index 9e541b9..d10822b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,266 +1,58 @@ -use kittybox::database::FileStorage; +use axum::extract::FromRef; +use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention}; +use tokio::{sync::Mutex, task::JoinSet}; use std::{env, time::Duration, sync::Arc}; use tracing::error; -fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router { - match blobstore_uri.split_once(':').unwrap().0 { - "file" => { - let folder = std::path::PathBuf::from( - blobstore_uri.strip_prefix("file://").unwrap() - ); - let blobstore = kittybox::media::storage::file::FileStore::new(folder); - kittybox::media::router::<_, _>(blobstore, auth_backend) - }, - other => unimplemented!("Unsupported backend: {other}") - } +async fn teapot_route() -> impl axum::response::IntoResponse { + use axum::http::{header, StatusCode}; + (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") } -async fn compose_kittybox_with_auth<A>( - http: reqwest::Client, - auth_backend: A, - backend_uri: &str, - blobstore_uri: &str, - job_queue_uri: &str, - jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, - cancellation_token: &tokio_util::sync::CancellationToken -) -> (axum::Router, kittybox::webmentions::SupervisedTask) -where A: kittybox::indieauth::backend::AuthBackend +async fn health_check<D>( + axum::extract::State(data): axum::extract::State<D>, +) -> impl axum::response::IntoResponse +where + D: kittybox::database::Storage { - match backend_uri.split_once(':').unwrap().0 { - "file" => { - let database = { - let folder = backend_uri.strip_prefix("file://").unwrap(); - let path = std::path::PathBuf::from(folder); - - match kittybox::database::FileStorage::new(path).await { - Ok(db) => db, - Err(err) => { - error!("Error creating database: {:?}", err); - std::process::exit(1); - } - } - }; - - // Technically, if we don't construct the micropub router, - // we could use some wrapper that makes the database - // read-only. - // - // This would allow to exclude all code to write to the - // database and separate reader and writer processes of - // Kittybox to improve security. - let homepage: axum::routing::MethodRouter<_> = axum::routing::get( - kittybox::frontend::homepage::<FileStorage> - ) - .layer(axum::Extension(database.clone())); - let fallback = axum::routing::get( - kittybox::frontend::catchall::<FileStorage> - ) - .layer(axum::Extension(database.clone())); - - let micropub = kittybox::micropub::router( - database.clone(), - http.clone(), - auth_backend.clone(), - Arc::clone(jobset) - ); - let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone(), Arc::clone(jobset) - ); - - - let (webmention, task) = kittybox::webmentions::router( - kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), - database.clone(), - http.clone(), - cancellation_token.clone() - ); - - let router = axum::Router::new() - .route("/", homepage) - .fallback(fallback) - .route("/.kittybox/micropub", micropub) - .route("/.kittybox/onboarding", onboarding) - .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) - .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) - .merge(webmention) - .route( - "/.kittybox/health", - axum::routing::get(health_check::<kittybox::database::FileStorage>) - .layer(axum::Extension(database)) - ); - - (router, task) - }, - "redis" => unimplemented!("Redis backend is not supported."), - #[cfg(feature = "postgres")] - "postgres" => { - use kittybox::database::PostgresStorage; - - let database = { - match PostgresStorage::new(backend_uri).await { - Ok(db) => db, - Err(err) => { - error!("Error creating database: {:?}", err); - std::process::exit(1); - } - } - }; - - // Technically, if we don't construct the micropub router, - // we could use some wrapper that makes the database - // read-only. - // - // This would allow to exclude all code to write to the - // database and separate reader and writer processes of - // Kittybox to improve security. - let homepage: axum::routing::MethodRouter<_> = axum::routing::get( - kittybox::frontend::homepage::<PostgresStorage> - ) - .layer(axum::Extension(database.clone())); - let fallback = axum::routing::get( - kittybox::frontend::catchall::<PostgresStorage> - ) - .layer(axum::Extension(database.clone())); - - let micropub = kittybox::micropub::router( - database.clone(), - http.clone(), - auth_backend.clone(), - Arc::clone(jobset) - ); - let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone(), Arc::clone(jobset) - ); - - let (webmention, task) = kittybox::webmentions::router( - kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(), - database.clone(), - http.clone(), - cancellation_token.clone() - ); - - let router = axum::Router::new() - .route("/", homepage) - .fallback(fallback) - .route("/.kittybox/micropub", micropub) - .route("/.kittybox/onboarding", onboarding) - .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri)) - .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone())) - .merge(webmention) - .route( - "/.kittybox/health", - axum::routing::get(health_check::<kittybox::database::PostgresStorage>) - .layer(axum::Extension(database)) - ); - - (router, task) - }, - other => unimplemented!("Unsupported backend: {other}") - } + (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) } -async fn compose_kittybox( - backend_uri: &str, - blobstore_uri: &str, - authstore_uri: &str, - job_queue_uri: &str, - jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, - cancellation_token: &tokio_util::sync::CancellationToken -) -> (axum::Router, kittybox::webmentions::SupervisedTask) { - let http: reqwest::Client = { - #[allow(unused_mut)] - let mut builder = reqwest::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); - if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { - // TODO: add a root certificate if there's an environment variable pointing at it - for path in certs.split(':') { - let metadata = match tokio::fs::metadata(path).await { - Ok(metadata) => metadata, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - tracing::error!("TLS root certificate {} not found, skipping...", path); - continue; - } - Err(err) => panic!("Error loading TLS certificates: {}", err) - }; - if metadata.is_dir() { - let mut dir = tokio::fs::read_dir(path).await.unwrap(); - while let Ok(Some(file)) = dir.next_entry().await { - let pem = tokio::fs::read(file.path()).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); - } - } else { - let pem = tokio::fs::read(path).await.unwrap(); - builder = builder.add_root_certificate( - reqwest::Certificate::from_pem(&pem).unwrap() - ); - } - } - } - - builder.build().unwrap() - }; - - let (router, task) = match authstore_uri.split_once(':').unwrap().0 { - "file" => { - let auth_backend = { - let folder = authstore_uri - .strip_prefix("file://") - .unwrap(); - kittybox::indieauth::backend::fs::FileBackend::new(folder) - }; - - compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, job_queue_uri, jobset, cancellation_token).await - } - other => unimplemented!("Unsupported backend: {other}") - }; - - // TODO: load from environment - let cookie_key = axum_extra::extract::cookie::Key::generate(); - let router = router +async fn compose_stateful_kittybox<St, A, S, M, Q>() -> axum::Router<St> +where +A: AuthBackend + 'static + FromRef<St>, +S: Storage + 'static + FromRef<St>, +M: MediaStore + 'static + FromRef<St>, +Q: kittybox_util::queue::JobQueue<kittybox::webmentions::Webmention> + FromRef<St>, +reqwest::Client: FromRef<St>, +Arc<Mutex<JoinSet<()>>>: FromRef<St>, +St: Clone + Send + Sync + 'static +{ + use axum::routing::get; + axum::Router::new() + .route("/", get(kittybox::frontend::homepage::<S>)) + .fallback(get(kittybox::frontend::catchall::<S>)) + .route("/.kittybox/micropub", kittybox::micropub::router::<A, S, St>()) + .route("/.kittybox/onboarding", kittybox::frontend::onboarding::router::<St, S>()) + .nest("/.kittybox/media", kittybox::media::router::<St, A, M>()) + .merge(kittybox::indieauth::router::<St, A, S>()) + .merge(kittybox::webmentions::router::<St, Q>()) + .route("/.kittybox/health", get(health_check::<S>)) .route( "/.kittybox/static/:path", axum::routing::get(kittybox::frontend::statics) ) - .route("/.kittybox/coffee", teapot_route()) - .nest("/.kittybox/micropub/client", kittybox::companion::router()) - .nest("/.kittybox/login", kittybox::login::router(cookie_key)) + .route("/.kittybox/coffee", get(teapot_route)) + .nest("/.kittybox/micropub/client", kittybox::companion::router::<St>()) .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::catch_panic::CatchPanicLayer::new()) .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([ axum::http::header::AUTHORIZATION, axum::http::header::COOKIE, axum::http::header::SET_COOKIE, - ])); - - (router, task) -} - -fn teapot_route() -> axum::routing::MethodRouter { - axum::routing::get(|| async { - use axum::http::{header, StatusCode}; - (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!") - }) -} - -async fn health_check</*A, B, */D>( - //axum::Extension(auth): axum::Extension<A>, - //axum::Extension(blob): axum::Extension<B>, - axum::Extension(data): axum::Extension<D>, -) -> impl axum::response::IntoResponse -where - //A: kittybox::indieauth::backend::AuthBackend, - //B: kittybox::media::storage::MediaStore, - D: kittybox::database::Storage -{ - (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK")) + ])) } #[tokio::main] @@ -306,40 +98,158 @@ async fn main() { tracing::info!("Starting the kittybox server..."); - let backend_uri: String = env::var("BACKEND_URI") + let backend_uri: url::Url = env::var("BACKEND_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("BACKEND_URI malformed")) .unwrap_or_else(|_| { error!("BACKEND_URI is not set, cannot find a database"); std::process::exit(1); }); - let blobstore_uri: String = env::var("BLOBSTORE_URI") + let blobstore_uri: url::Url = env::var("BLOBSTORE_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("BLOBSTORE_URI malformed")) .unwrap_or_else(|_| { error!("BLOBSTORE_URI is not set, can't find media store"); std::process::exit(1); }); - let authstore_uri: String = env::var("AUTH_STORE_URI") + let authstore_uri: url::Url = env::var("AUTH_STORE_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("AUTH_STORE_URI malformed")) .unwrap_or_else(|_| { error!("AUTH_STORE_URI is not set, can't find authentication store"); std::process::exit(1); }); - let job_queue_uri: String = env::var("JOB_QUEUE_URI") + let job_queue_uri: url::Url = env::var("JOB_QUEUE_URI") + .as_deref() + .map(|s| url::Url::parse(s).expect("JOB_QUEUE_URI malformed")) .unwrap_or_else(|_| { error!("JOB_QUEUE_URI is not set, can't find job queue"); std::process::exit(1); }); + // TODO: load from environment + let cookie_key = axum_extra::extract::cookie::Key::generate(); + let cancellation_token = tokio_util::sync::CancellationToken::new(); - let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new())); + let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::<()>::new())); + let http: reqwest::Client = { + #[allow(unused_mut)] + let mut builder = reqwest::Client::builder() + .user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION") + )); + if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") { + // TODO: add a root certificate if there's an environment variable pointing at it + for path in certs.split(':') { + let metadata = match tokio::fs::metadata(path).await { + Ok(metadata) => metadata, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + tracing::error!("TLS root certificate {} not found, skipping...", path); + continue; + } + Err(err) => panic!("Error loading TLS certificates: {}", err) + }; + if metadata.is_dir() { + let mut dir = tokio::fs::read_dir(path).await.unwrap(); + while let Ok(Some(file)) = dir.next_entry().await { + let pem = tokio::fs::read(file.path()).await.unwrap(); + builder = builder.add_root_certificate( + reqwest::Certificate::from_pem(&pem).unwrap() + ); + } + } else { + let pem = tokio::fs::read(path).await.unwrap(); + builder = builder.add_root_certificate( + reqwest::Certificate::from_pem(&pem).unwrap() + ); + } + } + } + + builder.build().unwrap() + }; + + let backend_type = backend_uri.scheme(); + let blobstore_type = blobstore_uri.scheme(); + let authstore_type = authstore_uri.scheme(); + let job_queue_type = job_queue_uri.scheme(); + + macro_rules! compose_kittybox { + ($auth:ty, $store:ty, $media:ty, $queue:ty) => { { + type AuthBackend = $auth; + type Storage = $store; + type MediaStore = $media; + type JobQueue = $queue; + + let state = kittybox::AppState { + auth_backend: match AuthBackend::new(&authstore_uri).await { + Ok(auth) => auth, + Err(err) => { + error!("Error creating auth backend: {:?}", err); + std::process::exit(1); + } + }, + storage: match Storage::new(&backend_uri).await { + Ok(db) => db, + Err(err) => { + error!("Error creating database: {:?}", err); + std::process::exit(1); + } + }, + media_store: match MediaStore::new(&blobstore_uri).await { + Ok(media) => media, + Err(err) => { + error!("Error creating media store: {:?}", err); + std::process::exit(1); + } + }, + job_queue: match JobQueue::new(&job_queue_uri).await { + Ok(queue) => queue, + Err(err) => { + error!("Error creating webmention job queue: {:?}", err); + std::process::exit(1); + } + }, + http, + background_jobs: jobset.clone(), + cookie_key + }; + + type St = kittybox::AppState<AuthBackend, Storage, MediaStore, JobQueue>; + let stateful_router = compose_stateful_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await; + let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>(&state, cancellation_token.clone()); + let router = stateful_router.with_state(state); - let (router, webmentions_task) = compose_kittybox( - backend_uri.as_str(), - blobstore_uri.as_str(), - authstore_uri.as_str(), - job_queue_uri.as_str(), - &jobset, - &cancellation_token - ).await; + (router, task) + } } + } + + let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) { + ("file", "file", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::FileStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + }, + ("file", "postgres", "file", "postgres") => { + compose_kittybox!( + kittybox::indieauth::backend::fs::FileBackend, + kittybox::database::PostgresStorage, + kittybox::media::storage::file::FileStore, + kittybox::webmentions::queue::PostgresJobQueue<Webmention> + ) + }, + (_, _, _, _) => { + // TODO: refine this error. + panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI"); + } + }; let mut servers: Vec<hyper::server::Server<hyper::server::conn::AddrIncoming, _>> = vec![]; @@ -494,5 +404,4 @@ async fn main() { while (jobset.join_next().await).is_some() {} tracing::info!("Shutdown complete, exiting."); std::process::exit(exitcode); - } diff --git a/src/media/mod.rs b/src/media/mod.rs index 71f875e..47f456a 100644 --- a/src/media/mod.rs +++ b/src/media/mod.rs @@ -1,14 +1,9 @@ -use std::convert::TryFrom; - use axum::{ - extract::{Extension, Host, multipart::Multipart, Path}, - response::{IntoResponse, Response}, - headers::{Header, HeaderValue, IfNoneMatch, HeaderMapExt}, - TypedHeader, + extract::{multipart::Multipart, FromRef, Host, Path, State}, headers::{HeaderMapExt, HeaderValue, IfNoneMatch}, response::{IntoResponse, Response}, TypedHeader }; use kittybox_util::error::{MicropubError, ErrorType}; use kittybox_indieauth::Scope; -use crate::indieauth::{User, backend::AuthBackend}; +use crate::indieauth::{backend::AuthBackend, User}; pub mod storage; use storage::{MediaStore, MediaStoreError, Metadata, ErrorKind}; @@ -25,7 +20,7 @@ impl From<MediaStoreError> for MicropubError { #[tracing::instrument(skip(blobstore))] pub(crate) async fn upload<S: MediaStore, A: AuthBackend>( - Extension(blobstore): Extension<S>, + State(blobstore): State<S>, user: User<A>, mut upload: Multipart ) -> Response { @@ -70,7 +65,7 @@ pub(crate) async fn serve<S: MediaStore>( Host(host): Host, Path(path): Path<String>, if_none_match: Option<TypedHeader<IfNoneMatch>>, - Extension(blobstore): Extension<S> + State(blobstore): State<S> ) -> Response { use axum::http::StatusCode; tracing::debug!("Searching for file..."); @@ -131,11 +126,12 @@ pub(crate) async fn serve<S: MediaStore>( } } -#[must_use] -pub fn router<S: MediaStore, A: AuthBackend>(blobstore: S, auth: A) -> axum::Router { +pub fn router<St, A, M>() -> axum::Router<St> where + A: AuthBackend + FromRef<St>, + M: MediaStore + FromRef<St>, + St: Clone + Send + Sync + 'static +{ axum::Router::new() - .route("/", axum::routing::post(upload::<S, A>)) - .route("/uploads/*file", axum::routing::get(serve::<S>)) - .layer(axum::Extension(blobstore)) - .layer(axum::Extension(auth)) + .route("/", axum::routing::post(upload::<M, A>)) + .route("/uploads/*file", axum::routing::get(serve::<M>)) } diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 624c239..fc5dd10 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -5,12 +5,13 @@ use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; use crate::indieauth::User; +use crate::media::storage::MediaStore; use crate::micropub::util::form_to_mf2_json; -use axum::extract::{BodyStream, Query, Host}; +use axum::extract::{BodyStream, FromRef, Host, Query, State}; use axum::headers::ContentType; use axum::response::{IntoResponse, Response}; use axum::TypedHeader; -use axum::{http::StatusCode, Extension}; +use axum::http::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::Mutex; @@ -515,9 +516,9 @@ async fn dispatch_body( #[tracing::instrument(skip(db, http))] pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( - Extension(db): Extension<D>, - Extension(http): Extension<reqwest::Client>, - Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, + State(db): State<D>, + State(http): State<reqwest::Client>, + State(jobset): State<Arc<Mutex<JoinSet<()>>>>, TypedHeader(content_type): TypedHeader<ContentType>, user: User<A>, body: BodyStream, @@ -540,7 +541,7 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( #[tracing::instrument(skip(db))] pub(crate) async fn query<D: Storage, A: AuthBackend>( - Extension(db): Extension<D>, + State(db): State<D>, query: Option<Query<MicropubQuery>>, Host(host): Host, user: User<A>, @@ -662,16 +663,13 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( } } -#[must_use] -pub fn router<S, A>( - storage: S, - http: reqwest::Client, - auth: A, - jobset: Arc<Mutex<JoinSet<()>>> -) -> axum::routing::MethodRouter + +pub fn router<A, S, St: Send + Sync + Clone + 'static>() -> axum::routing::MethodRouter<St> where - S: Storage + 'static, - A: AuthBackend + S: Storage + FromRef<St> + 'static, + A: AuthBackend + FromRef<St>, + reqwest::Client: FromRef<St>, + Arc<Mutex<JoinSet<()>>>: FromRef<St> { axum::routing::get(query::<S, A>) .post(post::<S, A>) @@ -680,11 +678,7 @@ where axum::http::Method::GET, axum::http::Method::POST, ]) - .allow_origin(tower_http::cors::Any)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) - .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) + .allow_origin(tower_http::cors::Any)) } #[cfg(test)] @@ -716,7 +710,7 @@ mod tests { use super::FetchedPostContext; use kittybox_indieauth::{Scopes, Scope, TokenData}; - use axum::extract::Host; + use axum::extract::{Host, State}; #[test] fn test_populate_reply_context() { @@ -850,7 +844,7 @@ mod tests { #[tokio::test] async fn test_query_foreign_url() { let mut res = super::query( - axum::Extension(crate::database::MemoryStorage::default()), + State(crate::database::MemoryStorage::default()), Some(axum::extract::Query(super::MicropubQuery::source( "https://aaronparecki.com/feeds/main", ))), diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs index 3e9b094..d5a617e 100644 --- a/src/webmentions/mod.rs +++ b/src/webmentions/mod.rs @@ -1,4 +1,4 @@ -use axum::{Form, response::{IntoResponse, Response}, Extension}; +use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form}; use axum::http::StatusCode; use tracing::error; @@ -20,7 +20,7 @@ impl queue::PostgresJobItem for Webmention { } async fn accept_webmention<Q: JobQueue<Webmention>>( - Extension(queue): Extension<Q>, + State(queue): State<Q>, Form(webmention): Form<Webmention>, ) -> Response { if let Err(err) = webmention.source.parse::<url::Url>() { @@ -31,27 +31,16 @@ async fn accept_webmention<Q: JobQueue<Webmention>>( } match queue.put(&webmention).await { - Ok(id) => StatusCode::ACCEPTED.into_response(), + Ok(_id) => StatusCode::ACCEPTED.into_response(), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [ ("Content-Type", "text/plain") ], err.to_string()).into_response() } } -pub fn router<Q: JobQueue<Webmention>, S: Storage + 'static>( - queue: Q, db: S, http: reqwest::Client, - cancellation_token: tokio_util::sync::CancellationToken -) -> (axum::Router, SupervisedTask) { - // Automatically spawn a background task to handle webmentions - let bgtask_handle = supervised_webmentions_task(queue.clone(), db, http, cancellation_token); - - let router = axum::Router::new() - .route("/.kittybox/webmention", - axum::routing::post(accept_webmention::<Q>) - ) - .layer(Extension(queue)); - - (router, bgtask_handle) +pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>() -> axum::Router<St> { + axum::Router::new() + .route("/.kittybox/webmention", post(accept_webmention::<Q>)) } #[derive(thiserror::Error, Debug)] @@ -184,10 +173,16 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's unreachable!() } -fn supervised_webmentions_task<Q: JobQueue<Webmention>, S: Storage + 'static>( - queue: Q, db: S, - http: reqwest::Client, +pub fn supervised_webmentions_task<St: Send + Sync + 'static, S: Storage + FromRef<St> + 'static, Q: JobQueue<Webmention> + FromRef<St> + 'static>( + state: &St, cancellation_token: tokio_util::sync::CancellationToken -) -> SupervisedTask { - supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(queue.clone(), db.clone(), http.clone()), cancellation_token) +) -> SupervisedTask +where reqwest::Client: FromRef<St> +{ + let queue = Q::from_ref(state); + let storage = S::from_ref(state); + let http = reqwest::Client::from_ref(state); + supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue( + queue.clone(), storage.clone(), http.clone() + ), cancellation_token) } |