use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form};
use axum::http::StatusCode;
use tracing::error;
use crate::database::{Storage, StorageError};
use self::queue::JobQueue;
pub mod queue;
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))]
pub struct Webmention {
source: String,
target: String,
}
impl queue::JobItem for Webmention {}
impl queue::PostgresJobItem for Webmention {
const DATABASE_NAME: &'static str = "kittybox_webmention.incoming_webmention_queue";
const NOTIFICATION_CHANNEL: &'static str = "incoming_webmention";
}
/// Axum handler for incoming Webmention requests (`POST /.kittybox/webmention`).
///
/// Performs the synchronous part of Webmention request verification:
/// - `source` and `target` must both parse as URLs;
/// - both must use the `http` or `https` scheme (the spec requires rejecting
///   schemes the receiver does not support);
/// - `source` and `target` must not be the same URL (required by the spec).
///
/// Valid requests are pushed onto the job queue `Q`; the mention itself is
/// verified asynchronously by the background worker.
///
/// Responses:
/// - `202 Accepted` when the webmention was queued;
/// - `400 Bad Request` with a plain-text reason when validation fails;
/// - `500 Internal Server Error` with a plain-text reason when queueing fails.
async fn accept_webmention<Q: JobQueue<Webmention>>(
    State(queue): State<Q>,
    Form(webmention): Form<Webmention>,
) -> Response {
    let source = match webmention.source.parse::<url::Url>() {
        Ok(url) => url,
        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()).into_response(),
    };
    let target = match webmention.target.parse::<url::Url>() {
        Ok(url) => url,
        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()).into_response(),
    };

    // `source` is later fetched with an HTTP client, so only HTTP(S) URLs
    // can ever be verified; reject anything else up front.
    for url in [&source, &target] {
        if !matches!(url.scheme(), "http" | "https") {
            return (
                StatusCode::BAD_REQUEST,
                format!("unsupported URL scheme: {}", url.scheme()),
            )
                .into_response();
        }
    }

    // Compare the parsed (normalized) URLs so trivially different spellings
    // of the same URL are also rejected.
    if source == target {
        return (
            StatusCode::BAD_REQUEST,
            "source and target must be different URLs",
        )
            .into_response();
    }

    match queue.put(&webmention).await {
        Ok(_id) => StatusCode::ACCEPTED.into_response(),
        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [
            ("Content-Type", "text/plain")
        ], err.to_string()).into_response()
    }
}
/// Builds the router exposing the Webmention endpoint.
///
/// `POST /.kittybox/webmention` is handled by `accept_webmention`, which
/// enqueues accepted requests on the queue `Q` extracted from the state `St`.
pub fn router<St, Q>() -> axum::Router<St>
where
    St: Clone + Send + Sync + 'static,
    Q: JobQueue<Webmention> + FromRef<St>,
{
    let endpoint = post(accept_webmention::<Q>);
    axum::Router::new().route("/.kittybox/webmention", endpoint)
}
#[derive(thiserror::Error, Debug)]
pub enum SupervisorError {
#[error("the task was explicitly cancelled")]
Cancelled
}
pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>;
/// Spawns a task that keeps a background job running until cancelled.
///
/// `f` is called to build a fresh future every time the job has to be
/// (re)started, and each future is spawned as its own task. The job's `Ok`
/// type is `Infallible`, so it can only stop by returning an error or by
/// panicking. Either outcome is logged, and the job is restarted after a
/// five-second back-off.
///
/// When `cancellation_token` is cancelled, the currently running job task is
/// aborted and awaited. The returned handle then resolves to
/// `Err(SupervisorError::Cancelled)`. Cancellation is also honoured during the
/// back-off sleep.
///
/// # Panics
///
/// With `tokio_unstable`, panics if the named supervisor task cannot be spawned.
pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask
where
    E: std::error::Error + std::fmt::Debug + Send + 'static,
    A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static,
    F: FnMut() -> A + Send + 'static
{
    /// Delay between a job failure and its restart.
    const BACKOFF: std::time::Duration = std::time::Duration::from_secs(5);

    let supervisor_future = async move {
        loop {
            // Don't spawn the task if we are already cancelled, but
            // have somehow missed it (e.g. the token was cancelled
            // before the supervisor started).
            if cancellation_token.is_cancelled() {
                return Err(SupervisorError::Cancelled)
            }
            let mut task = tokio::task::spawn(f());
            tokio::select! {
                _ = cancellation_token.cancelled() => {
                    tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>());
                    // Dropping a JoinHandle only detaches the task, which would leave
                    // the job running after shutdown. Abort it explicitly and wait for
                    // it to wind down; the resulting JoinError is expected, so ignore it.
                    task.abort();
                    let _ = task.await;
                    return Err(SupervisorError::Cancelled)
                }
                // Poll by reference so `task` is still usable in the other branch.
                task_result = &mut task => match task_result {
                    Err(e) => tracing::error!("background task {:?} exited unexpectedly: {}", std::any::type_name::<A>(), e),
                    Ok(Err(e)) => tracing::error!("background task {:?} returned error: {}", std::any::type_name::<A>(), e),
                    Ok(Ok(never)) => match never {}
                }
            }
            tracing::debug!("Sleeping for a little while to back-off...");
            // Keep listening for cancellation while backing off, so shutdown
            // is not delayed by the restart timer.
            tokio::select! {
                _ = cancellation_token.cancelled() => {
                    tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>());
                    return Err(SupervisorError::Cancelled)
                }
                _ = tokio::time::sleep(BACKOFF) => {}
            }
        }
    };
    #[cfg(not(tokio_unstable))]
    return tokio::task::spawn(supervisor_future);
    #[cfg(tokio_unstable)]
    return tokio::task::Builder::new()
        .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str())
        .spawn(supervisor_future)
        .unwrap();
}
pub mod check;
#[derive(thiserror::Error, Debug)]
enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> {
#[error("queue error: {0}")]
Queue(#[from] Q),
#[error("storage error: {0}")]
Storage(StorageError)
}
async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest_middleware::ClientWithMiddleware) -> Result<std::convert::Infallible, Error<Q::Error>> {
use futures_util::StreamExt;
use self::queue::Job;
let mut stream = queue.into_stream().await?;
while let Some(item) = stream.next().await.transpose()? {
let job = item.job();
let (source, target) = (
job.source.parse::<url::Url>().unwrap(),
job.target.parse::<url::Url>().unwrap()
);
let (code, text) = match http.get(source.clone()).send().await {
Ok(response) => {
let code = response.status();
if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) {
error!("error processing webmention: webpage fetch returned {}", code);
continue;
}
match response.text().await {
Ok(text) => (code, text),
Err(err) => {
error!("error processing webmention: error fetching webpage text: {}", err);
continue
}
}
}
Err(err) => {
error!("error processing webmention: error requesting webpage: {}", err);
continue
}
};
if code == StatusCode::GONE {
todo!("removing webmentions is not implemented yet");
// db.remove_webmention(target.as_str(), source.as_str()).await.map_err(Error::<Q::Error>::Storage)?;
} else {
// Verify webmention
let (mention_type, mut mention) = match tokio::task::block_in_place({
|| check::check_mention(text, &source, &target)
}) {
Ok(Some(mention_type)) => mention_type,
Ok(None) => {
error!("webmention {} -> {} invalid, rejecting", source, target);
item.done().await?;
continue;
}
Err(err) => {
error!("error processing webmention: error checking webmention: {}", err);
continue;
}
};
{
mention["type"] = serde_json::json!(["h-cite"]);
if !mention["properties"].as_object().unwrap().contains_key("uid") {
let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned();
let props = mention["properties"].as_object_mut().unwrap();
props.insert("uid".to_owned(), serde_json::Value::Array(
vec![serde_json::Value::String(url)])
);
}
}
db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?;
}
}
unreachable!()
}
pub fn supervised_webmentions_task<St: Send + Sync + 'static, S: Storage + FromRef<St> + 'static, Q: JobQueue<Webmention> + FromRef<St> + 'static>(
state: &St,
cancellation_token: tokio_util::sync::CancellationToken
) -> SupervisedTask
where reqwest_middleware::ClientWithMiddleware: FromRef<St>
{
let queue = Q::from_ref(state);
let storage = S::from_ref(state);
let http = reqwest_middleware::ClientWithMiddleware::from_ref(state);
supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(
queue.clone(), storage.clone(), http.clone()
), cancellation_token)
}