diff options
Diffstat (limited to 'kittybox-rs/src')
-rw-r--r-- | kittybox-rs/src/frontend/onboarding.rs | 18 | ||||
-rw-r--r-- | kittybox-rs/src/main.rs | 54 | ||||
-rw-r--r-- | kittybox-rs/src/micropub/mod.rs | 27 |
3 files changed, 71 insertions, 28 deletions
diff --git a/kittybox-rs/src/frontend/onboarding.rs b/kittybox-rs/src/frontend/onboarding.rs index d5cde02..e44e866 100644 --- a/kittybox-rs/src/frontend/onboarding.rs +++ b/kittybox-rs/src/frontend/onboarding.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::database::{settings, Storage}; use axum::{ extract::{Extension, Host}, @@ -7,6 +9,7 @@ use axum::{ }; use kittybox_frontend_renderer::{ErrorPage, OnboardingPage, Template}; use serde::Deserialize; +use tokio::{task::JoinSet, sync::Mutex}; use tracing::{debug, error}; use super::FrontendError; @@ -51,6 +54,7 @@ async fn onboard<D: Storage + 'static>( user_uid: url::Url, data: OnboardingData, http: reqwest::Client, + jobset: Arc<Mutex<JoinSet<()>>>, ) -> Result<(), FrontendError> { // Create a user to pass to the backend // At this point the site belongs to nobody, so it is safe to do @@ -115,7 +119,7 @@ async fn onboard<D: Storage + 'static>( } let (uid, post) = crate::micropub::normalize_mf2(data.first_post, &user); tracing::debug!("Posting first post {}...", uid); - crate::micropub::_post(&user, uid, post, db, http) + crate::micropub::_post(&user, uid, post, db, http, jobset) .await .map_err(|e| FrontendError { msg: "Error while posting the first post".to_string(), @@ -130,6 +134,7 @@ pub async fn post<D: Storage + 'static>( Extension(db): Extension<D>, Host(host): Host, Extension(http): Extension<reqwest::Client>, + Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, Json(data): Json<OnboardingData>, ) -> axum::response::Response { let user_uid = format!("https://{}/", host.as_str()); @@ -137,7 +142,7 @@ pub async fn post<D: Storage + 'static>( if db.post_exists(&user_uid).await.unwrap() { IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])) } else { - match onboard(db, user_uid.parse().unwrap(), data, http).await { + match onboard(db, user_uid.parse().unwrap(), data, http, jobset).await { Ok(()) => IntoResponse::into_response((StatusCode::FOUND, [("Location", "/")])), Err(err) => { error!("Onboarding error: {}", err); @@ -163,9 +168,14 @@ pub async fn post<D: Storage + 'static>( } } -pub fn router<S: Storage + 'static>(database: S, http: reqwest::Client) -> axum::routing::MethodRouter { +pub fn router<S: Storage + 'static>( + database: S, + http: reqwest::Client, + jobset: Arc<Mutex<JoinSet<()>>>, +) -> axum::routing::MethodRouter { axum::routing::get(get) .post(post::<S>) .layer::<_, _, std::convert::Infallible>(axum::Extension(database)) - .layer(axum::Extension(http)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) + .layer(axum::Extension(jobset)) } diff --git a/kittybox-rs/src/main.rs b/kittybox-rs/src/main.rs index cc11f3a..d96a8fb 100644 --- a/kittybox-rs/src/main.rs +++ b/kittybox-rs/src/main.rs @@ -1,6 +1,6 @@ use kittybox::database::FileStorage; -use std::{env, time::Duration}; -use tracing::{debug, error, info}; +use std::{env, time::Duration, sync::Arc}; +use tracing::error; fn init_media<A: kittybox::indieauth::backend::AuthBackend>(auth_backend: A, blobstore_uri: &str) -> axum::Router { match blobstore_uri.split_once(':').unwrap().0 { @@ -20,7 +20,8 @@ async fn compose_kittybox_with_auth<A>( http: reqwest::Client, auth_backend: A, backend_uri: &str, - blobstore_uri: &str + blobstore_uri: &str, + jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>> ) -> axum::Router where A: kittybox::indieauth::backend::AuthBackend { @@ -58,10 +59,11 @@ where A: kittybox::indieauth::backend::AuthBackend let micropub = kittybox::micropub::router( database.clone(), http.clone(), - auth_backend.clone() + auth_backend.clone(), + Arc::clone(jobset) ); let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone() + database.clone(), http.clone(), Arc::clone(jobset) ); axum::Router::new() @@ -111,13 +113,14 @@ where A: kittybox::indieauth::backend::AuthBackend let micropub = kittybox::micropub::router( database.clone(), http.clone(), - auth_backend.clone() + auth_backend.clone(), + Arc::clone(jobset) ); let onboarding = kittybox::frontend::onboarding::router( - database.clone(), http.clone() + database.clone(), http.clone(), Arc::clone(jobset) ); - axum::Router::new() + let router = axum::Router::new() .route("/", homepage) .fallback(fallback) .route("/.kittybox/micropub", micropub) @@ -128,7 +131,9 @@ where A: kittybox::indieauth::backend::AuthBackend "/.kittybox/health", axum::routing::get(health_check::<kittybox::database::PostgresStorage>) .layer(axum::Extension(database)) - ) + ); + + router }, other => unimplemented!("Unsupported backend: {other}") } @@ -138,6 +143,7 @@ async fn compose_kittybox( backend_uri: &str, blobstore_uri: &str, authstore_uri: &str, + jobset: &Arc<tokio::sync::Mutex<tokio::task::JoinSet<()>>>, cancellation_token: &tokio_util::sync::CancellationToken ) -> axum::Router { let http: reqwest::Client = { @@ -162,7 +168,7 @@ async fn compose_kittybox( kittybox::indieauth::backend::fs::FileBackend::new(folder) }; - compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri).await + compose_kittybox_with_auth(http, auth_backend, backend_uri, blobstore_uri, jobset).await } other => unimplemented!("Unsupported backend: {other}") }; @@ -214,7 +220,7 @@ async fn main() { .init(); //let _ = tracing_log::LogTracer::init(); - info!("Starting the kittybox server..."); + tracing::info!("Starting the kittybox server..."); let backend_uri: String = env::var("BACKEND_URI") .unwrap_or_else(|_| { @@ -237,11 +243,13 @@ async fn main() { }); let cancellation_token = tokio_util::sync::CancellationToken::new(); + let jobset = Arc::new(tokio::sync::Mutex::new(tokio::task::JoinSet::new())); let router = compose_kittybox( backend_uri.as_str(), blobstore_uri.as_str(), authstore_uri.as_str(), + &jobset, &cancellation_token ).await; @@ -315,7 +323,9 @@ async fn main() { std::net::TcpListener::bind(listen_addr).unwrap() })) } - + // Drop the remaining copy of the router + // to get rid of an extra reference to `jobset` + drop(router); // Polling streams mutates them let mut servers_futures = Box::pin(servers.into_iter() .map( @@ -356,7 +366,7 @@ async fn main() { }; use futures_util::stream::StreamExt; - tokio::select! { + let exitcode: i32 = tokio::select! { // Poll the servers stream for errors. // If any error out, shut down the entire operation // @@ -370,17 +380,25 @@ async fn main() { servers_futures.iter_mut().collect::<Vec<_>>() )).await; - std::process::exit(1); + 1 } _ = shutdown_signal => { - info!("Shutdown requested by signal."); + tracing::info!("Shutdown requested by signal."); cancellation_token.cancel(); let _ = Box::pin(futures_util::future::join_all( servers_futures.iter_mut().collect::<Vec<_>>() )).await; - info!("Shutdown complete, exiting."); - std::process::exit(0); + 0 } - } + }; + + tracing::info!("Waiting for unfinished background tasks..."); + let mut jobset: tokio::task::JoinSet<()> = Arc::try_unwrap(jobset) + .expect("Dangling jobset references present") + .into_inner(); + while (jobset.join_next().await).is_some() {} + tracing::info!("Shutdown complete, exiting."); + std::process::exit(exitcode); + } diff --git a/kittybox-rs/src/micropub/mod.rs b/kittybox-rs/src/micropub/mod.rs index 04bf0a5..02eee6e 100644 --- a/kittybox-rs/src/micropub/mod.rs +++ b/kittybox-rs/src/micropub/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::database::{MicropubChannel, Storage, StorageError}; use crate::indieauth::backend::AuthBackend; @@ -11,6 +12,8 @@ use axum::TypedHeader; use axum::{http::StatusCode, Extension}; use serde::{Deserialize, Serialize}; use serde_json::json; +use tokio::sync::Mutex; +use tokio::task::JoinSet; use tracing::{debug, error, info, warn}; use kittybox_indieauth::{Scope, TokenData}; use kittybox_util::{MicropubError, ErrorType}; @@ -228,6 +231,7 @@ pub(crate) async fn _post<D: 'static + Storage>( mf2: serde_json::Value, db: D, http: reqwest::Client, + jobset: Arc<Mutex<JoinSet<()>>>, ) -> Result<Response, MicropubError> { // Here, we have the following guarantees: // - The MF2-JSON document is normalized (guaranteed by normalize_mf2) @@ -313,7 +317,12 @@ pub(crate) async fn _post<D: 'static + Storage>( let reply = IntoResponse::into_response((StatusCode::ACCEPTED, [("Location", uid.as_str())])); - tokio::task::spawn(background_processing(db, mf2, http)); + #[cfg(not(tokio_unstable))] + jobset.lock().await.spawn(background_processing(db, mf2, http)); + #[cfg(tokio_unstable)] + jobset.lock().await.build_task() + .name(format!("Kittybox background processing for post {}", uid.as_str()).as_str()) + .spawn(background_processing(db, mf2, http)); Ok(reply) } @@ -497,6 +506,7 @@ async fn dispatch_body( pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( Extension(db): Extension<D>, Extension(http): Extension<reqwest::Client>, + Extension(jobset): Extension<Arc<Mutex<JoinSet<()>>>>, TypedHeader(content_type): TypedHeader<ContentType>, user: User<A>, body: BodyStream, @@ -508,7 +518,7 @@ pub(crate) async fn post<D: Storage + 'static, A: AuthBackend>( }, Ok(PostBody::MF2(mf2)) => { let (uid, mf2) = normalize_mf2(mf2, &user); - match _post(&user, uid, mf2, db, http).await { + match _post(&user, uid, mf2, db, http, jobset).await { Ok(response) => response, Err(err) => err.into_response(), } @@ -631,7 +641,8 @@ pub(crate) async fn query<D: Storage, A: AuthBackend>( pub fn router<S, A>( storage: S, http: reqwest::Client, - auth: A + auth: A, + jobset: Arc<Mutex<JoinSet<()>>> ) -> axum::routing::MethodRouter where S: Storage + 'static, @@ -648,6 +659,7 @@ where .layer::<_, _, std::convert::Infallible>(axum::Extension(storage)) .layer::<_, _, std::convert::Infallible>(axum::Extension(http)) .layer::<_, _, std::convert::Infallible>(axum::Extension(auth)) + .layer::<_, _, std::convert::Infallible>(axum::Extension(jobset)) } #[cfg(test)] @@ -670,9 +682,12 @@ impl MicropubQuery { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{database::Storage, micropub::MicropubError}; use hyper::body::HttpBody; use serde_json::json; + use tokio::sync::Mutex; use super::FetchedPostContext; use kittybox_indieauth::{Scopes, Scope, TokenData}; @@ -733,7 +748,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -763,7 +778,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let err = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap_err(); @@ -791,7 +806,7 @@ mod tests { }; let (uid, mf2) = super::normalize_mf2(post, &user); - let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new()) + let res = super::_post(&user, uid, mf2, db.clone(), reqwest::Client::new(), Arc::new(Mutex::new(tokio::task::JoinSet::new()))) .await .unwrap(); |