use base64::Engine;
use kittybox::{database::Storage, indieauth::backend::AuthBackend, media::storage::MediaStore, webmentions::Webmention, compose_kittybox};
use tokio::{sync::Mutex, task::JoinSet};
use std::{env, future::IntoFuture, sync::Arc};
use tracing::error;
#[tokio::main]
async fn main() {
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};
let tracing_registry = Registry::default()
.with(EnvFilter::from_default_env())
.with(
#[cfg(debug_assertions)]
tracing_tree::HierarchicalLayer::new(2)
.with_bracketed_fields(true)
.with_indent_lines(true)
.with_verbose_exit(true),
#[cfg(not(debug_assertions))]
tracing_subscriber::fmt::layer().json()
.with_ansi(std::io::IsTerminal::is_terminal(&std::io::stdout().lock()))
);
// In debug builds, also log to JSON, but to file.
#[cfg(debug_assertions)]
let tracing_registry = tracing_registry.with(
tracing_subscriber::fmt::layer()
.json()
.with_writer({
let instant = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap();
move || std::fs::OpenOptions::new()
.append(true)
.create(true)
.open(
format!(
"{}.log.json",
instant
.as_secs_f64()
.to_string()
.replace('.', "_")
)
).unwrap()
})
);
tracing_registry.init();
tracing::info!("Starting the kittybox server...");
let backend_uri: url::Url = env::var("BACKEND_URI")
.as_deref()
.map(|s| url::Url::parse(s).expect("BACKEND_URI malformed"))
.unwrap_or_else(|_| {
error!("BACKEND_URI is not set, cannot find a database");
std::process::exit(1);
});
let blobstore_uri: url::Url = env::var("BLOBSTORE_URI")
.as_deref()
.map(|s| url::Url::parse(s).expect("BLOBSTORE_URI malformed"))
.unwrap_or_else(|_| {
error!("BLOBSTORE_URI is not set, can't find media store");
std::process::exit(1);
});
let authstore_uri: url::Url = env::var("AUTH_STORE_URI")
.as_deref()
.map(|s| url::Url::parse(s).expect("AUTH_STORE_URI malformed"))
.unwrap_or_else(|_| {
error!("AUTH_STORE_URI is not set, can't find authentication store");
std::process::exit(1);
});
let job_queue_uri: url::Url = env::var("JOB_QUEUE_URI")
.as_deref()
.map(|s| url::Url::parse(s).expect("JOB_QUEUE_URI malformed"))
.unwrap_or_else(|_| {
error!("JOB_QUEUE_URI is not set, can't find job queue");
std::process::exit(1);
});
// TODO: load from environment
let cookie_key = axum_extra::extract::cookie::Key::from(&env::var("COOKIE_KEY")
.as_deref()
.map(|s| base64::prelude::BASE64_STANDARD.decode(s.as_bytes())
.expect("Invalid cookie key: must be base64 encoded")
)
.unwrap()
);
let cancellation_token = tokio_util::sync::CancellationToken::new();
let jobset: Arc<Mutex<JoinSet<()>>> = Default::default();
let session_store: kittybox::SessionStore = Default::default();
let http: reqwest_middleware::ClientWithMiddleware = {
#[allow(unused_mut)]
let mut builder = reqwest::Client::builder()
.user_agent(concat!(
env!("CARGO_PKG_NAME"),
"/",
env!("CARGO_PKG_VERSION")
));
if let Ok(certs) = std::env::var("KITTYBOX_CUSTOM_PKI_ROOTS") {
// TODO: add a root certificate if there's an environment variable pointing at it
for path in certs.split(':') {
let metadata = match tokio::fs::metadata(path).await {
Ok(metadata) => metadata,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
tracing::error!("TLS root certificate {} not found, skipping...", path);
continue;
}
Err(err) => panic!("Error loading TLS certificates: {}", err)
};
if metadata.is_dir() {
let mut dir = tokio::fs::read_dir(path).await.unwrap();
while let Ok(Some(file)) = dir.next_entry().await {
let pem = tokio::fs::read(file.path()).await.unwrap();
builder = builder.add_root_certificate(
reqwest::Certificate::from_pem(&pem).unwrap()
);
}
} else {
let pem = tokio::fs::read(path).await.unwrap();
builder = builder.add_root_certificate(
reqwest::Certificate::from_pem(&pem).unwrap()
);
}
}
}
// This only works on debug builds. Don't get any funny thoughts.
#[cfg(debug_assertions)]
if std::env::var("KITTYBOX_DANGER_INSECURE_TLS")
.map(|y| y == "1")
.unwrap_or(false)
{
builder = builder.danger_accept_invalid_certs(true);
}
reqwest_middleware::ClientBuilder::new(builder.build().unwrap())
.with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
mode: http_cache_reqwest::CacheMode::Default,
manager: http_cache_reqwest::MokaManager::default(),
options: http_cache_reqwest::HttpCacheOptions::default(),
}))
.build()
};
let backend_type = backend_uri.scheme();
let blobstore_type = blobstore_uri.scheme();
let authstore_type = authstore_uri.scheme();
let job_queue_type = job_queue_uri.scheme();
macro_rules! compose_kittybox {
($auth:ty, $store:ty, $media:ty, $queue:ty) => { {
type AuthBackend = $auth;
type Storage = $store;
type MediaStore = $media;
type JobQueue = $queue;
let state = kittybox::AppState {
auth_backend: match AuthBackend::new(&authstore_uri).await {
Ok(auth) => auth,
Err(err) => {
error!("Error creating auth backend: {:?}", err);
std::process::exit(1);
}
},
storage: match Storage::new(&backend_uri).await {
Ok(db) => db,
Err(err) => {
error!("Error creating database: {:?}", err);
std::process::exit(1);
}
},
media_store: match MediaStore::new(&blobstore_uri).await {
Ok(media) => media,
Err(err) => {
error!("Error creating media store: {:?}", err);
std::process::exit(1);
}
},
job_queue: match JobQueue::new(&job_queue_uri).await {
Ok(queue) => queue,
Err(err) => {
error!("Error creating webmention job queue: {:?}", err);
std::process::exit(1);
}
},
http,
background_jobs: jobset.clone(),
cookie_key,
session_store,
};
type St = kittybox::AppState<AuthBackend, Storage, MediaStore, JobQueue>;
let stateful_router = compose_kittybox::<St, AuthBackend, Storage, MediaStore, JobQueue>().await;
let task = kittybox::webmentions::supervised_webmentions_task::<St, Storage, JobQueue>(&state, cancellation_token.clone());
let router = stateful_router.with_state(state);
(router, task)
} }
}
let (router, webmentions_task): (axum::Router<()>, kittybox::webmentions::SupervisedTask) = match (authstore_type, backend_type, blobstore_type, job_queue_type) {
("file", "file", "file", "postgres") => {
compose_kittybox!(
kittybox::indieauth::backend::fs::FileBackend,
kittybox::database::FileStorage,
kittybox::media::storage::file::FileStore,
kittybox::webmentions::queue::PostgresJobQueue<Webmention>
)
},
("file", "postgres", "file", "postgres") => {
compose_kittybox!(
kittybox::indieauth::backend::fs::FileBackend,
kittybox::database::PostgresStorage,
kittybox::media::storage::file::FileStore,
kittybox::webmentions::queue::PostgresJobQueue<Webmention>
)
},
(_, _, _, _) => {
// TODO: refine this error.
panic!("Invalid type for AUTH_STORE_URI, BACKEND_URI, BLOBSTORE_URI or JOB_QUEUE_URI");
}
};
let mut servers: Vec<axum::serve::Serve<_, _>> = vec![];
let build_hyper = |tcp: std::net::TcpListener| {
tracing::info!("Listening on {}", tcp.local_addr().unwrap());
// Set the socket to non-blocking so tokio can poll it
// properly -- this is the async magic!
tcp.set_nonblocking(true).unwrap();
//hyper::server::Server::from_tcp(tcp).unwrap()
// // Otherwise Chrome keeps connections open for too long
// .tcp_keepalive(Some(Duration::from_secs(30 * 60)))
// .serve(router.clone().into_make_service())
axum::serve(
tokio::net::TcpListener::from_std(tcp).unwrap(),
router.clone()
)
};
let mut listenfd = listenfd::ListenFd::from_env();
for i in 0..(listenfd.len()) {
match listenfd.take_tcp_listener(i) {
Ok(Some(tcp)) => servers.push(build_hyper(tcp)),
Ok(None) => {},
Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err)
}
}
// TODO this requires the `hyperlocal` crate
//#[rustfmt::skip]
/*#[cfg(unix)] {
let build_hyper_unix = |unix: std::os::unix::net::UnixListener| {
{
use std::os::linux::net::SocketAddrExt;
let local_addr = unix.local_addr().unwrap();
if let Some(pathname) = local_addr.as_pathname() {
tracing::info!("Listening on unix:{}", pathname.display());
} else if let Some(name) = {
#[cfg(linux)]
local_addr.as_abstract_name();
#[cfg(not(linux))]
None::<&[u8]>
} {
tracing::info!("Listening on unix:@{}", String::from_utf8_lossy(name));
} else {
tracing::info!("Listening on unnamed unix socket");
}
}
unix.set_nonblocking(true).unwrap();
hyper::server::Server::builder(unix)
.serve(router.clone().into_make_service())
};
for i in 0..(listenfd.len()) {
match listenfd.take_unix_listener(i) {
Ok(Some(unix)) => servers.push(build_hyper_unix(unix)),
Ok(None) => {},
Err(err) => tracing::error!("Error binding to socket in fd {}: {}", i, err)
}
}
}*/
if servers.is_empty() {
servers.push(build_hyper({
let listen_addr = env::var("SERVE_AT")
.ok()
.unwrap_or_else(|| "[::]:8080".to_string())
.parse::<std::net::SocketAddr>()
.unwrap_or_else(|e| {
error!("Cannot parse SERVE_AT: {}", e);
std::process::exit(1);
});
std::net::TcpListener::bind(listen_addr).unwrap()
}))
}
// Drop the remaining copy of the router
// to get rid of an extra reference to `jobset`
drop(router);
// Polling streams mutates them
let mut servers_futures = Box::pin(servers.into_iter()
.map(
#[cfg(not(tokio_unstable))] |server| tokio::task::spawn(
server.with_graceful_shutdown(cancellation_token.clone().cancelled_owned())
.into_future()
),
#[cfg(tokio_unstable)] |server| {
tokio::task::Builder::new()
.name(format!("Kittybox HTTP acceptor: {:?}", server).as_str())
.spawn(
server.with_graceful_shutdown(
cancellation_token.clone().cancelled_owned()
).into_future()
)
.unwrap()
}
)
.collect::<futures_util::stream::FuturesUnordered<tokio::task::JoinHandle<Result<(), std::io::Error>>>>()
);
#[cfg(not(unix))]
let shutdown_signal = tokio::signal::ctrl_c();
#[cfg(unix)]
let shutdown_signal = {
use tokio::signal::unix::{signal, SignalKind};
async move {
let mut interrupt = signal(SignalKind::interrupt())
.expect("Failed to set up SIGINT handler");
let mut terminate = signal(SignalKind::terminate())
.expect("Failed to setup SIGTERM handler");
tokio::select! {
_ = terminate.recv() => {},
_ = interrupt.recv() => {},
}
}
};
use futures_util::stream::StreamExt;
let exitcode: i32 = tokio::select! {
// Poll the servers stream for errors.
// If any error out, shut down the entire operation
//
// We do this because there might not be a good way
// to recover from some errors without external help
Some(Err(e)) = servers_futures.next() => {
tracing::error!("Error in HTTP server: {}", e);
tracing::error!("Shutting down because of error.");
cancellation_token.cancel();
1
}
_ = cancellation_token.cancelled() => {
tracing::info!("Signal caught from watchdog.");
0
}
_ = shutdown_signal => {
tracing::info!("Shutdown requested by signal.");
cancellation_token.cancel();
0
}
};
tracing::info!("Waiting for unfinished background tasks...");
let _ = tokio::join!(
webmentions_task,
Box::pin(futures_util::future::join_all(
servers_futures.iter_mut().collect::<Vec<_>>()
)),
);
let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset)
.expect("Dangling jobset references present")
.into_inner();
while (jobset.join_next().await).is_some() {}
tracing::info!("Shutdown complete, exiting.");
std::process::exit(exitcode);
}