about summary refs log blame commit diff
path: root/src/main.rs
blob: 9e541b99067a9b4bec722e13298974cb50e61391 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
                                    
                                          
 





                                                                                                                   
 



                                                                    
 


                                       
                        


                                                               














                                                                          
 

















                                                                              
                                     
                                                                    
                                                                  
              







                                                                                                  




                                                                                                         
                                  


                                                                                       

                          
                                                                     

                                                    
 



























                                                                              
                                     
                                                                    
                                                                  
              





                                                                                                  
                                            




                                                                                                         
                                  


                                                                                           
                  
                          
          


                                                               


                          
                        
                                                               
                                                            
                                                            
                                 































                                                                                                


                                
                                                                         





                                                                          
 
                                                                                                                                       

                                                               
 

                                                                  
                       




                                                                          
                                                                      
                                                             

                                                                             

                                           
                  
 
 




                                                                                                          
 










                                                                  
 

                                                                                                 
 
                                              
                                            
                                    


                                                   
                                                   
























                                                                                       
 
                                                      
 








                                                                       
 

                                                                                 
                                  
           

                                                                     

                                  
                                                                        
                                                                                
 
                                                      
                               
                               
                               
                
                           
            
                                                                                               
 
                                                                     
                                                              
                                           
 


                                                                   
      




                                                            
                                                                                       
         
































                                                                                           












                                                             

                                                   






                                                                                           
                                                                                              



                                                                        
                             













                                                                                                               
 






                                           
                                        







                                                               
 
             
         



                                                           
                                
                                                            
                                        
 
             
         

                                                                 





                                                          





                                                                      
 
use kittybox::database::FileStorage;
use std::{env, time::Duration, sync::Arc};
use tracing::error;

fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router {
    match blobstore_uri.split_once(':').unwrap().0 {
        "file" => {
            let folder = std::path::PathBuf::from(
                blobstore_uri.strip_prefix("file://").unwrap()
            );
            let blobstore = kittybox::media::storage::file::FileStore::new(folder);

            kittybox::media::router::<_, _>(blobstore, auth_backend)
        },
        other => unimplemented!("Unsupported backend: {other}")
    }
}

async fn compose_kittybox_with_auth<A>(
    http: reqwest::Client,
    auth_backend: A,
    backend_uri: &str,
    blobstore_uri: &str,
    job_queue_uri: &str,
    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>,
    cancellation_token: &tokio_util::sync::CancellationToken
) -> (axum::Router, kittybox::webmentions::SupervisedTask)
where A: kittybox::indieauth::backend::AuthBackend
{
    match backend_uri.split_once(':').unwrap().0 {
        "file" => {
            let database = {
                let folder = backend_uri.strip_prefix("file://").unwrap();
                let path = std::path::PathBuf::from(folder);

                match kittybox::database::FileStorage::new(path).await {
                    Ok(db) => db,
                    Err(err) => {
                        error!("Error creating database: {:?}", err);
                        std::process::exit(1);
                    }
                }
            };

            // Technically, if we don't construct the micropub router,
            // we could use some wrapper that makes the database
            // read-only.
            //
            // This would allow to exclude all code to write to the
            // database and separate reader and writer processes of
            // Kittybox to improve security.
            let homepage: axum::routing::MethodRouter<_> = axum::routing::get(
                kittybox::frontend::homepage::<FileStorage>
            )
                .layer(axum::Extension(database.clone()));
            let fallback = axum::routing::get(
                kittybox::frontend::catchall::<FileStorage>
            )
                .layer(axum::Extension(database.clone()));

            let micropub = kittybox::micropub::router(
                database.clone(),
                http.clone(),
                auth_backend.clone(),
                Arc::clone(jobset)
            );
            let onboarding = kittybox::frontend::onboarding::router(
                database.clone(), http.clone(), Arc::clone(jobset)
            );


            let (webmention, task) = kittybox::webmentions::router(
                kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(),
                database.clone(),
                http.clone(),
                cancellation_token.clone()
            );

            let router = axum::Router::new()
                .route("/", homepage)
                .fallback(fallback)
                .route("/.kittybox/micropub", micropub)
                .route("/.kittybox/onboarding", onboarding)
                .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri))
                .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone()))
                .merge(webmention)
                .route(
                    "/.kittybox/health",
                    axum::routing::get(health_check::<kittybox::database::FileStorage>)
                        .layer(axum::Extension(database))
                );

            (router, task)
        },
        "redis" => unimplemented!("Redis backend is not supported."),
        #[cfg(feature = "postgres")]
        "postgres" => {
            use kittybox::database::PostgresStorage;

            let database = {
                match PostgresStorage::new(backend_uri).await {
                    Ok(db) => db,
                    Err(err) => {
                        error!("Error creating database: {:?}", err);
                        std::process::exit(1);
                    }
                }
            };

            // Technically, if we don't construct the micropub router,
            // we could use some wrapper that makes the database
            // read-only.
            //
            // This would allow to exclude all code to write to the
            // database and separate reader and writer processes of
            // Kittybox to improve security.
            let homepage: axum::routing::MethodRouter<_> = axum::routing::get(
                kittybox::frontend::homepage::<PostgresStorage>
            )
                .layer(axum::Extension(database.clone()));
            let fallback = axum::routing::get(
                kittybox::frontend::catchall::<PostgresStorage>
            )
                .layer(axum::Extension(database.clone()));

            let micropub = kittybox::micropub::router(
                database.clone(),
                http.clone(),
                auth_backend.clone(),
                Arc::clone(jobset)
            );
            let onboarding = kittybox::frontend::onboarding::router(
                database.clone(), http.clone(), Arc::clone(jobset)
            );

            let (webmention, task) = kittybox::webmentions::router(
                kittybox::webmentions::queue::PostgresJobQueue::new(job_queue_uri).await.unwrap(),
                database.clone(),
                http.clone(),
                cancellation_token.clone()
            );

            let router = axum::Router::new()
                .route("/", homepage)
                .fallback(fallback)
                .route("/.kittybox/micropub", micropub)
                .route("/.kittybox/onboarding", onboarding)
                .nest("/.kittybox/media", init_media(auth_backend.clone(), blobstore_uri))
                .merge(kittybox::indieauth::router(auth_backend.clone(), database.clone(), http.clone()))
                .merge(webmention)
                .route(
                    "/.kittybox/health",
                    axum::routing::get(health_check::<kittybox::database::PostgresStorage>)
                        .layer(axum::Extension(database))
                );

            (router, task)
        },
        other => unimplemented!("Unsupported backend: {other}")
    }
}

async fn compose_kittybox(
    backend_uri: &str,
    blobstore_uri: &str,
    authstore_uri: &str,
    job_queue_uri: &str,
    jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>,
    cancellation_token: &tokio_util::sync::CancellationToken
) -> (axum::Router, kittybox::webmentions::SupervisedTask) {
    let http: reqwest::Client = {
        #[allow(unused_mut)]
        let mut builder = reqwest::Client::builder()
            .user_agent(concat!(
                env!("CARGO_PKG_NAME"),
                "/",
                env!("CARGO_PKG_VERSION")
            ));
        if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") {
            // TODO: add a root certificate if there's an environment variable pointing at it
            for path in certs.split(':') {
                let metadata = match tokio::fs::metadata(path).await {
                    Ok(metadata) => metadata,
                    Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                        tracing::error!("TLS root certificate {} not found, skipping...", path);
                        continue;
                    }
                    Err(err) => panic!("Error loading TLS certificates: {}", err)
                };
                if metadata.is_dir() {
                    let mut dir = tokio::fs::read_dir(path).await.unwrap();
                    while let Ok(Some(file)) = dir.next_entry().await {
                        let pem = tokio::fs::read(file.path()).await.unwrap();
                        builder = builder.add_root_certificate(
                            reqwest::Certificate::from_pem(&pem).unwrap()
                        );
                    }
                } else {
                    let pem = tokio::fs::read(path).await.unwrap();
                    builder = builder.add_root_certificate(
                        reqwest::Certificate::from_pem(&pem).unwrap()
                    );
                }
            }
        }

        builder.build().unwrap()
    };

    let (router, task) = match authstore_uri.split_once(':').unwrap().0 {
        "file" => {
            let auth_backend = {
                let folder = authstore_uri
                    .strip_prefix("file://")
                    .unwrap();
                kittybox::indieauth::backend::fs::FileBackend::new(folder)
            };

            compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, job_queue_uri, jobset, cancellation_token).await
        }
        other => unimplemented!("Unsupported backend: {other}")
    };

    // TODO: load from environment
    let cookie_key = axum_extra::extract::cookie::Key::generate();

    let router = router
        .route(
            "/.kittybox/static/:path",
            axum::routing::get(kittybox::frontend::statics)
        )
        .route("/.kittybox/coffee", teapot_route())
        .nest("/.kittybox/micropub/client", kittybox::companion::router())
        .nest("/.kittybox/login", kittybox::login::router(cookie_key))
        .layer(tower_http::trace::TraceLayer::new_for_http())
        .layer(tower_http::catch_panic::CatchPanicLayer::new())
        .layer(tower_http::sensitive_headers::SetSensitiveHeadersLayer::new([
            axum::http::header::AUTHORIZATION,
            axum::http::header::COOKIE,
            axum::http::header::SET_COOKIE,
        ]));

    (router, task)
}

/// A `GET` route that always answers `418 I'm a teapot` with a short
/// plain-text apology.
fn teapot_route() -> axum::routing::MethodRouter {
    async fn refuse_to_brew() -> impl axum::response::IntoResponse {
        let status = axum::http::StatusCode::IM_A_TEAPOT;
        let headers = [(axum::http::header::CONTENT_TYPE, "text/plain")];
        let body = "Sorry, can't brew coffee yet!";

        (status, headers, body)
    }

    axum::routing::get(refuse_to_brew)
}

/// Health-check handler: always responds `200 OK` with the body `OK`.
///
/// The handler requires an `axum::Extension<D>` carrying the storage backend.
/// According to axum's docs, the extractor rejects the request if that
/// extension is not attached, so a success at least shows the route is
/// wired to a database. The database itself is not queried.
///
/// TODO: optionally also extract the auth backend
/// (`A: kittybox::indieauth::backend::AuthBackend`) and media store
/// (`B: kittybox::media::storage::MediaStore`) so they are covered too.
async fn health_check<D>(
    // The value is unused; extracting it only asserts the extension exists.
    axum::Extension(_): axum::Extension<D>,
) -> impl axum::response::IntoResponse
where
    D: kittybox::database::Storage
{
    (axum::http::StatusCode::OK, std::borrow::Cow::Borrowed("OK"))
}

#[tokio::main]
async fn main() {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};

    let tracing_registry = Registry::default()
        .with(EnvFilter::from_default_env())
        .with(
            #[cfg(debug_assertions)]
            tracing_tree::HierarchicalLayer::new(2)
                .with_bracketed_fields(true)
                .with_indent_lines(true)
                .with_verbose_exit(true),
            #[cfg(not(debug_assertions))]
            tracing_subscriber::fmt::layer().json()
                .with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock()))
        );
    // In debug builds, also log to JSON, but to file.
    #[cfg(debug_assertions)]
    let tracing_registry = tracing_registry.with(
        tracing_subscriber::fmt::layer()
            .json()
            .with_writer({
                let instant = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap();
                move || std::fs::OpenOptions::new()
                    .append(true)
                    .create(true)
                    .open(
                        format!(
                            "{}.log.json",
                            instant
                            .as_secs_f64()
                            .to_string()
                            .replace('.', "_")
                        )
                    ).unwrap()
            })
    );
    tracing_registry.init();

    tracing::info!("Starting the kittybox server...");

    let backend_uri: String = env::var("BACKEND_URI")
        .unwrap_or_else(|_| {
            error!("BACKEND_URI is not set, cannot find a database");
            std::process::exit(1);
        });
    let blobstore_uri: String = env::var("BLOBSTORE_URI")
        .unwrap_or_else(|_| {
            error!("BLOBSTORE_URI is not set, can't find media store");
            std::process::exit(1);
        });

    let authstore_uri: String = env::var("AUTH_STORE_URI")
        .unwrap_or_else(|_| {
            error!("AUTH_STORE_URI is not set, can't find authentication store");
            std::process::exit(1);
        });

    let job_queue_uri: String = env::var("JOB_QUEUE_URI")
        .unwrap_or_else(|_| {
            error!("JOB_QUEUE_URI is not set, can't find job queue");
            std::process::exit(1);
        });

    let cancellation_token = tokio_util::sync::CancellationToken::new();
    let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new()));

    let (router, webmentions_task) = compose_kittybox(
        backend_uri.as_str(),
        blobstore_uri.as_str(),
        authstore_uri.as_str(),
        job_queue_uri.as_str(),
        &jobset,
        &cancellation_token
    ).await;

    let mut servers: Vec<hyper::server::Server<hyper::server::conn::AddrIncoming, _>> = vec![];

    let build_hyper = |tcp: std::net::TcpListener| {
        tracing::info!("Listening on {}", tcp.local_addr().unwrap());
        // Set the socket to non-blocking so tokio can poll it
        // properly -- this is the async magic!
        tcp.set_nonblocking(true).unwrap();

        hyper::server::Server::from_tcp(tcp).unwrap()
            // Otherwise Chrome keeps connections open for too long
            .tcp_keepalive(Some(Duration::from_secs(30 * 60)))
            .serve(router.clone().into_make_service())
    };

    let mut listenfd = listenfd::ListenFd::from_env();
    for i in 0..(listenfd.len()) {
        match listenfd.take_tcp_listener(i) {
            Ok(Some(tcp)) => servers.push(build_hyper(tcp)),
            Ok(None) => {},
            Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err)
        }
    }
    // TODO this requires the `hyperlocal` crate
    //#[rustfmt::skip]
    /*#[cfg(unix)] {
        let build_hyper_unix = |unix: std::os::unix::net::UnixListener| {
            {
                use std::os::linux::net::SocketAddrExt;

                let local_addr = unix.local_addr().unwrap();
                if let Some(pathname) = local_addr.as_pathname() {
                    tracing::info!("Listening on unix:{}", pathname.display());
                } else if let Some(name) = {
                    #[cfg(linux)]
                    local_addr.as_abstract_name();
                    #[cfg(not(linux))]
                    None::<&[u8]>
                } {
                    tracing::info!("Listening on unix:@{}", String::from_utf8_lossy(name));
                } else {
                    tracing::info!("Listening on unnamed unix socket");
                }
            }
            unix.set_nonblocking(true).unwrap();

            hyper::server::Server::builder(unix)
                .serve(router.clone().into_make_service())
        };
        for i in 0..(listenfd.len()) {
            match listenfd.take_unix_listener(i) {
                Ok(Some(unix)) => servers.push(build_hyper_unix(unix)),
                Ok(None) => {},
                Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err)
            }
        }
    }*/
    if servers.is_empty() {
        servers.push(build_hyper({
            let listen_addr = env::var("SERVE_AT")
                .ok()
                .unwrap_or_else(|| "[::]:8080".to_string())
                .parse::<std::net::SocketAddr>()
                .unwrap_or_else(|e| {
                    error!("Cannot parse SERVE_AT: {}", e);
                    std::process::exit(1);
                });

            std::net::TcpListener::bind(listen_addr).unwrap()
        }))
    }
    // Drop the remaining copy of the router
    // to get rid of an extra reference to `jobset`
    drop(router);
    // Polling streams mutates them
    let mut servers_futures = Box::pin(servers.into_iter()
        .map(
            #[cfg(not(tokio_unstable))] |server| tokio::task::spawn(
                server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned())
            ),
            #[cfg(tokio_unstable)] |server| {
                tokio::task::Builder::new()
                    .name(format!("Kittybox HTTP acceptor: {}", server.local_addr()).as_str())
                    .spawn(
                        server.with_graceful_shutdown(
                            cancellation_token.clone().cancelled_owned()
                        )
                    )
                    .unwrap()
            }
        )
        .collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), hyper::Error>>>>()
    );

    #[cfg(not(unix))]
    let shutdown_signal = tokio::signal::ctrl_c();
    #[cfg(unix)]
    let shutdown_signal = {
        use tokio::signal::unix::{signal, SignalKind};

        async move {
            let mut interrupt = signal(SignalKind::interrupt())
                .expect("Failed to set up SIGINT handler");
            let mut terminate = signal(SignalKind::terminate())
                .expect("Failed to setup SIGTERM handler");

            tokio::select! {
                _ = terminate.recv() => {},
                _ = interrupt.recv() => {},
            }
        }
    };
    use futures_util::stream::StreamExt;

    let exitcode: i32 = tokio::select! {
        // Poll the servers stream for errors.
        // If any error out, shut down the entire operation
        //
        // We do this because there might not be a good way
        // to recover from some errors without external help
        Some(Err(e)) = servers_futures.next() => {
            tracing::error!("Error in HTTP server: {}", e);
            tracing::error!("Shutting down because of error.");
            cancellation_token.cancel();

            1
        }
        _ = cancellation_token.cancelled() => {
            tracing::info!("Signal caught from watchdog.");

            0
        }
        _ = shutdown_signal => {
            tracing::info!("Shutdown requested by signal.");
            cancellation_token.cancel();

            0
        }
    };

    tracing::info!("Waiting for unfinished background tasks...");

    let _ = tokio::join!(
        webmentions_task,
        Box::pin(futures_util::future::join_all(
            servers_futures.iter_mut().collect::<Vec<_>>()
        )),
    );
    let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset)
        .expect("Dangling jobset references present")
        .into_inner();
    while (jobset.join_next().await).is_some() {}
    tracing::info!("Shutdown complete, exiting.");
    std::process::exit(exitcode);

}