about summary refs log tree commit diff
path: root/src/main.rs
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2024-08-01 20:01:12 +0300
committerVika <vika@fireburn.ru>2024-08-01 20:40:30 +0300
commit4ca0c24b1989fcd12c453d428af70f58456f7651 (patch)
tree19f480107cc6491b832a7a2d7198cee48f205b85 /src/main.rs
parent7e8e688e2e58f9c944b941e768ab7b034a348a1f (diff)
Migrate from axum::Extension to axum::extract::State
This somehow allowed me to shrink the construction phase of Kittybox
by a huge amount of code.
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs423
1 files changed, 166 insertions, 257 deletions
diff --git a/src/main.rs b/src/main.rs
index 9e541b9..d10822b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,266 +1,58 @@
-use kittybox::database::FileStorage;
+use axum::extract::FromRef;
+use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention};
+use tokio::{sync::Mutex, task::JoinSet};
 use std::{env, time::Duration, sync::Arc};
 use tracing::error;
 
-fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router {
-    match blobstore_uri.split_once(':').unwrap().0 {
-        "file" => {
-            let folder = std::path::PathBuf::from(
-                blobstore_uri.strip_prefix("file://").unwrap()
-            );
-            let blobstore = kittybox::media::storage::file::FileStore::new(folder);
 
-            kittybox::media::router::<_, _>(blobstore, auth_backend)
-        },
-        other => unimplemented!("Unsupported backend: {other}")
-    }
+async fn teapot_route() -> impl axum::response::IntoResponse {
+    use axum::http::{header, StatusCode};
+    (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!")
 }
 
-async fn compose_kittybox_with_auth<A>(
-    http: reqwest::Client,
-    auth_backend: A,
-    backend_uri: &str,
-    blobstore_uri: &str,
-    job_queue_uri: &str,
-    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>,
-    cancellation_token: &tokio_util::sync::CancellationToken
-) -> (axum::Router, kittybox::webmentions::SupervisedTask)
-where A: kittybox::indieauth::backend::AuthBackend
+async fn health_check<D>(
+    axum::extract::State(data): axum::extract::State<D>,
+) -> impl axum::response::IntoResponse
+where
+    D: kittybox::database::Storage
 {
-    match backend_uri.split_once(':').unwrap().0 {
-        "file" => {
-            let database = {
-                let folder = backend_uri.strip_prefix("file://").unwrap();
-                let path = std::path::PathBuf::from(folder);
-
-                match kittybox::database::FileStorage::new(path).await {
-                    Ok(db) => db,
-                    Err(err) => {
-                        error!("Error creating database: {:?}", err);
-                        std::process::exit(1);
-                    }
-                }
-            };
-
-            // Technically, if we don't construct the micropub router,
-            // we could use some wrapper that makes the database
-            // read-only.
-            //
-            // This would allow to exclude all code to write to the
-            // database and separate reader and writer processes of
-            // Kittybox to improve security.
-            let homepage: axum::routing::MethodRouter<_> = axum::routing::get(
-                kittybox::frontend::homepage::<FileStorage>
-            )
-                .layer(axum::Extension(database.clone()));
-            let fallback = axum::routing::get(
-                kittybox::frontend::catchall::<FileStorage>
-            )
-                .layer(axum::Extension(database.clone()));
-
-            let micropub = kittybox::micropub::router(
-                database.clone(),
-                http.clone(),
-                auth_backend.clone(),
-                Arc::clone(jobset)
-            );
-            let onboarding = kittybox::frontend::onboarding::router(
-                database.clone(), http.clone(), Arc::clone(jobset)
-            );
-
-
-            let (webmention, task) = kittybox::webmentions::router(
-                kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(),
-                database.clone(),
-                http.clone(),
-                cancellation_token.clone()
-            );
-
-            let router = axum::Router::new()
-                .route("/", homepage)
-                .fallback(fallback)
-                .route("/.kittybox/micropub", micropub)
-                .route("/.kittybox/onboarding", onboarding)
-                .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri))
-                .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone()))
-                .merge(webmention)
-                .route(
-                    "/.kittybox/health",
-                    axum::routing::get(health_check::<kittybox::database::FileStorage>)
-                        .layer(axum::Extension(database))
-                );
-
-            (router, task)
-        },
-        "redis" => unimplemented!("Redis backend is not supported."),
-        #[cfg(feature = "postgres")]
-        "postgres" => {
-            use kittybox::database::PostgresStorage;
-
-            let database = {
-                match PostgresStorage::new(backend_uri).await {
-                    Ok(db) => db,
-                    Err(err) => {
-                        error!("Error creating database: {:?}", err);
-                        std::process::exit(1);
-                    }
-                }
-            };
-
-            // Technically, if we don't construct the micropub router,
-            // we could use some wrapper that makes the database
-            // read-only.
-            //
-            // This would allow to exclude all code to write to the
-            // database and separate reader and writer processes of
-            // Kittybox to improve security.
-            let homepage: axum::routing::MethodRouter<_> = axum::routing::get(
-                kittybox::frontend::homepage::<PostgresStorage>
-            )
-                .layer(axum::Extension(database.clone()));
-            let fallback = axum::routing::get(
-                kittybox::frontend::catchall::<PostgresStorage>
-            )
-                .layer(axum::Extension(database.clone()));
-
-            let micropub = kittybox::micropub::router(
-                database.clone(),
-                http.clone(),
-                auth_backend.clone(),
-                Arc::clone(jobset)
-            );
-            let onboarding = kittybox::frontend::onboarding::router(
-                database.clone(), http.clone(), Arc::clone(jobset)
-            );
-
-            let (webmention, task) = kittybox::webmentions::router(
-                kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(),
-                database.clone(),
-                http.clone(),
-                cancellation_token.clone()
-            );
-
-            let router = axum::Router::new()
-                .route("/", homepage)
-                .fallback(fallback)
-                .route("/.kittybox/micropub", micropub)
-                .route("/.kittybox/onboarding", onboarding)
-                .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri))
-                .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone()))
-                .merge(webmention)
-                .route(
-                    "/.kittybox/health",
-                    axum::routing::get(health_check::<kittybox::database::PostgresStorage>)
-                        .layer(axum::Extension(database))
-                );
-
-            (router, task)
-        },
-        other => unimplemented!("Unsupported backend: {other}")
-    }
+    (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK"))
 }
 
-async fn compose_kittybox(
-    backend_uri: &str,
-    blobstore_uri: &str,
-    authstore_uri: &str,
-    job_queue_uri: &str,
-    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>,
-    cancellation_token: &tokio_util::sync::CancellationToken
-) -> (axum::Router, kittybox::webmentions::SupervisedTask) {
-    let http: reqwest::Client = {
-        #[allow(unused_mut)]
-        let mut builder = reqwest::Client::builder()
-            .user_agent(concat!(
-                env!("CARGO_PKG_NAME"),
-                "/",
-                env!("CARGO_PKG_VERSION")
-            ));
-        if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") {
-            // TODO: add a root certificate if there's an environment variable pointing at it
-            for path in certs.split(':') {
-                let metadata = match tokio::fs::metadata(path).await {
-                    Ok(metadata) => metadata,
-                    Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
-                        tracing::error!("TLS root certificate {} not found, skipping...", path);
-                        continue;
-                    }
-                    Err(err) => panic!("Error loading TLS certificates: {}", err)
-                };
-                if metadata.is_dir() {
-                    let mut dir = tokio::fs::read_dir(path).await.unwrap();
-                    while let Ok(Some(file)) = dir.next_entry().await {
-                        let pem = tokio::fs::read(file.path()).await.unwrap();
-                        builder = builder.add_root_certificate(
-                            reqwest::Certificate::from_pem(&pem).unwrap()
-                        );
-                    }
-                } else {
-                    let pem = tokio::fs::read(path).await.unwrap();
-                    builder = builder.add_root_certificate(
-                        reqwest::Certificate::from_pem(&pem).unwrap()
-                    );
-                }
-            }
-        }
-
-        builder.build().unwrap()
-    };
-
-    let (router, task) = match authstore_uri.split_once(':').unwrap().0 {
-        "file" => {
-            let auth_backend = {
-                let folder = authstore_uri
-                    .strip_prefix("file://")
-                    .unwrap();
-                kittybox::indieauth::backend::fs::FileBackend::new(folder)
-            };
-
-            compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, job_queue_uri, jobset, cancellation_token).await
-        }
-        other => unimplemented!("Unsupported backend: {other}")
-    };
-
-    // TODO: load from environment
-    let cookie_key = axum_extra::extract::cookie::Key::generate();
 
-    let router = router
+async fn compose_stateful_kittybox<St, A, S, M, Q>() -> axum::Router<St>
+where
+A: AuthBackend + 'static + FromRef<St>,
+S: Storage + 'static + FromRef<St>,
+M: MediaStore + 'static + FromRef<St>,
+Q: kittybox_util::queue::JobQueue<kittybox::webmentions::Webmention> + FromRef<St>,
+reqwest::Client: FromRef<St>,
+Arc<Mutex<JoinSet<()>>>: FromRef<St>,
+St: Clone + Send + Sync + 'static
+{
+    use axum::routing::get;
+    axum::Router::new()
+        .route("/", get(kittybox::frontend::homepage::<S>))
+        .fallback(get(kittybox::frontend::catchall::<S>))
+        .route("/.kittybox/micropub", kittybox::micropub::router::<A, S, St>())
+        .route("/.kittybox/onboarding", kittybox::frontend::onboarding::router::<St, S>())
+        .nest("/.kittybox/media", kittybox::media::router::<St, A, M>())
+        .merge(kittybox::indieauth::router::<St, A, S>())
+        .merge(kittybox::webmentions::router::<St, Q>())
+        .route("/.kittybox/health", get(health_check::<S>))
         .route(
             "/.kittybox/static/:path",
             axum::routing::get(kittybox::frontend::statics)
         )
-        .route("/.kittybox/coffee", teapot_route())
-        .nest("/.kittybox/micropub/client", kittybox::companion::router())
-        .nest("/.kittybox/login", kittybox::login::router(cookie_key))
+        .route("/.kittybox/coffee", get(teapot_route))
+        .nest("/.kittybox/micropub/client", kittybox::companion::router::<St>())
         .layer(tower_http::trace::TraceLayer::new_for_http())
         .layer(tower_http::catch_panic::CatchPanicLayer::new())
         .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([
             axum::http::header::AUTHORIZATION,
             axum::http::header::COOKIE,
             axum::http::header::SET_COOKIE,
-        ]));
-
-    (router, task)
-}
-
-fn teapot_route() -> axum::routing::MethodRouter {
-    axum::routing::get(|| async {
-        use axum::http::{header, StatusCode};
-        (StatusCode::IM_A_TEAPOT, [(header::CONTENT_TYPE, "text/plain")], "Sorry, can't brew coffee yet!")
-    })
-}
-
-async fn health_check</*A, B, */D>(
-    //axum::Extension(auth): axum::Extension<A>,
-    //axum::Extension(blob): axum::Extension<B>,
-    axum::Extension(data): axum::Extension<D>,
-) -> impl axum::response::IntoResponse
-where
-    //A: kittybox::indieauth::backend::AuthBackend,
-    //B: kittybox::media::storage::MediaStore,
-    D: kittybox::database::Storage
-{
-    (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK"))
+        ]))
 }
 
 #[tokio::main]
@@ -306,40 +98,158 @@ async fn main() {
 
     tracing::info!("Starting the kittybox server...");
 
-    let backend_uri: String = env::var("BACKEND_URI")
+    let backend_uri: url::Url = env::var("BACKEND_URI")
+        .as_deref()
+        .map(|s| url::Url::parse(s).expect("BACKEND_URI malformed"))
         .unwrap_or_else(|_| {
             error!("BACKEND_URI is not set, cannot find a database");
             std::process::exit(1);
         });
-    let blobstore_uri: String = env::var("BLOBSTORE_URI")
+    let blobstore_uri: url::Url = env::var("BLOBSTORE_URI")
+        .as_deref()
+        .map(|s| url::Url::parse(s).expect("BLOBSTORE_URI malformed"))
         .unwrap_or_else(|_| {
             error!("BLOBSTORE_URI is not set, can't find media store");
             std::process::exit(1);
         });
 
-    let authstore_uri: String = env::var("AUTH_STORE_URI")
+    let authstore_uri: url::Url = env::var("AUTH_STORE_URI")
+        .as_deref()
+        .map(|s| url::Url::parse(s).expect("AUTH_STORE_URI malformed"))
         .unwrap_or_else(|_| {
             error!("AUTH_STORE_URI is not set, can't find authentication store");
             std::process::exit(1);
         });
 
-    let job_queue_uri: String = env::var("JOB_QUEUE_URI")
+    let job_queue_uri: url::Url = env::var("JOB_QUEUE_URI")
+        .as_deref()
+        .map(|s| url::Url::parse(s).expect("JOB_QUEUE_URI malformed"))
         .unwrap_or_else(|_| {
             error!("JOB_QUEUE_URI is not set, can't find job queue");
             std::process::exit(1);
         });
 
+    // TODO: load from environment
+    let cookie_key = axum_extra::extract::cookie::Key::generate();
+
     let cancellation_token = tokio_util::sync::CancellationToken::new();
-    let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new()));
+    let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::<()>::new()));
+    let http: reqwest::Client = {
+        #[allow(unused_mut)]
+        let mut builder = reqwest::Client::builder()
+            .user_agent(concat!(
+                env!("CARGO_PKG_NAME"),
+                "/",
+                env!("CARGO_PKG_VERSION")
+            ));
+        if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") {
+            // TODO: add a root certificate if there's an environment variable pointing at it
+            for path in certs.split(':') {
+                let metadata = match tokio::fs::metadata(path).await {
+                    Ok(metadata) => metadata,
+                    Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
+                        tracing::error!("TLS root certificate {} not found, skipping...", path);
+                        continue;
+                    }
+                    Err(err) => panic!("Error loading TLS certificates: {}", err)
+                };
+                if metadata.is_dir() {
+                    let mut dir = tokio::fs::read_dir(path).await.unwrap();
+                    while let Ok(Some(file)) = dir.next_entry().await {
+                        let pem = tokio::fs::read(file.path()).await.unwrap();
+                        builder = builder.add_root_certificate(
+                            reqwest::Certificate::from_pem(&pem).unwrap()
+                        );
+                    }
+                } else {
+                    let pem = tokio::fs::read(path).await.unwrap();
+                    builder = builder.add_root_certificate(
+                        reqwest::Certificate::from_pem(&pem).unwrap()
+                    );
+                }
+            }
+        }
+
+        builder.build().unwrap()
+    };
+
+    let backend_type = backend_uri.scheme();
+    let blobstore_type = blobstore_uri.scheme();
+    let authstore_type = authstore_uri.scheme();
+    let job_queue_type = job_queue_uri.scheme();
+
+    macro_rules! compose_kittybox {
+        ($auth:ty, $store:ty, $media:ty, $queue:ty) => { {
+            type AuthBackend = $auth;
+            type Storage = $store;
+            type MediaStore = $media;
+            type JobQueue = $queue;
+
+            let state = kittybox::AppState {
+                auth_backend: match AuthBackend::new(&authstore_uri).await {
+                    Ok(auth) => auth,
+                    Err(err) => {
+                        error!("Error creating auth backend: {:?}", err);
+                        std::process::exit(1);
+                    }
+                },
+                storage: match Storage::new(&backend_uri).await {
+                    Ok(db) => db,
+                    Err(err) => {
+                        error!("Error creating database: {:?}", err);
+                        std::process::exit(1);
+                    }
+                },
+                media_store: match MediaStore::new(&blobstore_uri).await {
+                    Ok(media) => media,
+                    Err(err) => {
+                        error!("Error creating media store: {:?}", err);
+                        std::process::exit(1);
+                    }
+                },
+                job_queue: match JobQueue::new(&job_queue_uri).await {
+                    Ok(queue) => queue,
+                    Err(err) => {
+                        error!("Error creating webmention job queue: {:?}", err);
+                        std::process::exit(1);
+                    }
+                },
+                http,
+                background_jobs: jobset.clone(),
+                cookie_key
+            };
+
+            type St = kittybox::AppState<AuthBackend, Storage, MediaStore, JobQueue>;
+            let stateful_router = compose_stateful_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await;
+            let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>(&state, cancellation_token.clone());
+            let router = stateful_router.with_state(state);
 
-    let (router, webmentions_task) = compose_kittybox(
-        backend_uri.as_str(),
-        blobstore_uri.as_str(),
-        authstore_uri.as_str(),
-        job_queue_uri.as_str(),
-        &jobset,
-        &cancellation_token
-    ).await;
+            (router, task)
+        } }
+    }
+
+    let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) {
+        ("file", "file", "file", "postgres") => {
+            compose_kittybox!(
+                kittybox::indieauth::backend::fs::FileBackend,
+                kittybox::database::FileStorage,
+                kittybox::media::storage::file::FileStore,
+                kittybox::webmentions::queue::PostgresJobQueue<Webmention>
+            )
+        },
+        ("file", "postgres", "file", "postgres") => {
+            compose_kittybox!(
+                kittybox::indieauth::backend::fs::FileBackend,
+                kittybox::database::PostgresStorage,
+                kittybox::media::storage::file::FileStore,
+                kittybox::webmentions::queue::PostgresJobQueue<Webmention>
+            )
+        },
+        (_, _, _, _) => {
+            // TODO: refine this error.
+            panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI");
+        }
+    };
 
     let mut servers: Vec<hyper::server::Server<hyper::server::conn::AddrIncoming, _>> = vec![];
 
@@ -494,5 +404,4 @@ async fn main() {
     while (jobset.join_next().await).is_some() {}
     tracing::info!("Shutdown complete, exiting.");
     std::process::exit(exitcode);
-
 }