From 0617663b249f9ca488e5de652108b17d67fbaf45 Mon Sep 17 00:00:00 2001
From: Vika
Date: Sat, 29 Jul 2023 21:59:56 +0300
Subject: Moved the entire Kittybox tree into the root

---
 src/webmentions/check.rs | 113 ++++++++++++++++++
 src/webmentions/mod.rs   | 195 ++++++++++++++++++++++++++++++
 src/webmentions/queue.rs | 303 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 611 insertions(+)
 create mode 100644 src/webmentions/check.rs
 create mode 100644 src/webmentions/mod.rs
 create mode 100644 src/webmentions/queue.rs

diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs
new file mode 100644
index 0000000..f7322f7
--- /dev/null
+++ b/src/webmentions/check.rs
@@ -0,0 +1,113 @@
+use std::{cell::RefCell, rc::Rc};
+use microformats::{types::PropertyValue, html5ever::{self, tendril::TendrilSink}};
+use kittybox_util::MentionType;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("microformats error: {0}")]
+    Microformats(#[from] microformats::Error),
+    // #[error("json error: {0}")]
+    // Json(#[from] serde_json::Error),
+    #[error("url parse error: {0}")]
+    UrlParse(#[from] url::ParseError),
+}
+
+#[tracing::instrument]
+pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> {
+    tracing::debug!("Parsing MF2 markup...");
+    // First, check the document for MF2 markup
+    let document = microformats::from_html(document.as_ref(), base_url.clone())?;
+
+    // Get an iterator of all items
+    let items_iter = document.items.iter()
+        .map(AsRef::as_ref)
+        .map(RefCell::borrow);
+
+    for item in items_iter {
+        tracing::debug!("Processing item: {:?}", item);
+
+        let props = item.properties.borrow();
+        for (prop, interaction_type) in [
+            ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like),
+            ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost)
+        ] {
+            if let Some(propvals) = props.get(prop) {
+                tracing::debug!("Has a u-{} property", prop);
+                for val in propvals {
+                    if let PropertyValue::Url(url) = val {
+                        if url == link {
+                            tracing::debug!("URL matches! Webmention is valid");
+                            return Ok(Some((interaction_type, serde_json::to_value(&*item).unwrap())))
+                        }
+                    }
+                }
+            }
+        }
+        // Process `content`
+        tracing::debug!("Processing e-content...");
+        if let Some(PropertyValue::Fragment(content)) = props.get("content")
+            .map(Vec::as_slice)
+            .unwrap_or_default()
+            .first()
+        {
+            tracing::debug!("Parsing HTML data...");
+            let root = html5ever::parse_document(html5ever::rcdom::RcDom::default(), Default::default())
+                .from_utf8()
+                .one(content.html.to_owned().as_bytes())
+                .document;
+
+            // This is a trick to unwrap recursion into a loop
+            //
+            // A list of unprocessed nodes is made. Then, in each
+            // iteration, the list is "taken" and replaced with an
+            // empty list, which is populated with nodes for the next
+            // iteration of the loop.
+            //
+            // Empty list means all nodes were processed.
+            let mut unprocessed_nodes: Vec<Rc<html5ever::rcdom::Node>> = root.children.borrow().iter().cloned().collect();
+            while !unprocessed_nodes.is_empty() {
+                // "Take" the list out of its memory slot, replace it with an empty list
+                let nodes = std::mem::take(&mut unprocessed_nodes);
+                tracing::debug!("Processing list of {} nodes", nodes.len());
+                'nodes_loop: for node in nodes.into_iter() {
+                    // Add children nodes to the list for the next iteration
+                    unprocessed_nodes.extend(node.children.borrow().iter().cloned());
+
+                    if let html5ever::rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data {
+                        // If it's not `<a>`, skip it
+                        if name.local != *"a" { continue; }
+                        let mut is_mention: bool = false;
+                        for attr in attrs.borrow().iter() {
+                            if attr.name.local == *"rel" {
+                                // Don't count `rel="nofollow"` links — a web crawler should ignore them
+                                // and so for purposes of driving visitors they are useless
+                                if attr.value
+                                    .as_ref()
+                                    .split([',', ' '])
+                                    .any(|v| v == "nofollow")
+                                {
+                                    // Skip the entire node.
+                                    continue 'nodes_loop;
+                                }
+                            }
+                            // if it's not `<a href="...">`, skip it
+                            if attr.name.local != *"href" { continue; }
+                            // Be forgiving in parsing URLs, and resolve them against the base URL
+                            if let Ok(url) = base_url.join(attr.value.as_ref()) {
+                                if &url == link {
+                                    is_mention = true;
+                                }
+                            }
+                        }
+                        if is_mention {
+                            return Ok(Some((MentionType::Mention, serde_json::to_value(&*item).unwrap())));
+                        }
+                    }
+                }
+            }
+
+        }
+    }
+
+    Ok(None)
+}
diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs
new file mode 100644
index 0000000..95ea870
--- /dev/null
+++ b/src/webmentions/mod.rs
@@ -0,0 +1,195 @@
+use axum::{Form, response::{IntoResponse, Response}, Extension};
+use axum::http::StatusCode;
+use tracing::error;
+
+use crate::database::{Storage, StorageError};
+use self::queue::JobQueue;
+pub mod queue;
+
+#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
+#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))]
+pub struct Webmention {
+    source: String,
+    target: String,
+}
+
+impl queue::JobItem for Webmention {}
+impl queue::PostgresJobItem for Webmention {
+    const DATABASE_NAME: &'static str = "kittybox_webmention.incoming_webmention_queue";
+    const NOTIFICATION_CHANNEL: &'static str = "incoming_webmention";
+}
+
+async fn accept_webmention<Q: JobQueue<Webmention>>(
+    Extension(queue): Extension<Q>,
+    Form(webmention): Form<Webmention>,
+) -> Response {
+    if let Err(err) = webmention.source.parse::<url::Url>() {
+        return (StatusCode::BAD_REQUEST, err.to_string()).into_response()
+    }
+    if let Err(err) = webmention.target.parse::<url::Url>() {
+        return (StatusCode::BAD_REQUEST, err.to_string()).into_response()
+    }
+
+    match queue.put(&webmention).await {
+        Ok(id) => (StatusCode::ACCEPTED, [
+            ("Location", format!("/.kittybox/webmention/{id}"))
+        ]).into_response(),
+        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [
+            ("Content-Type", "text/plain")
+        ], err.to_string()).into_response()
+    }
+}
+
+pub fn router<Q: JobQueue<Webmention>, S: Storage + 'static>(
+    queue: Q, db: S, http: reqwest::Client,
+    cancellation_token: tokio_util::sync::CancellationToken
+) -> (axum::Router, SupervisedTask) {
+    // Automatically spawn a background task to handle webmentions
+    let bgtask_handle = supervised_webmentions_task(queue.clone(), db, http, cancellation_token);
+
+    let router = axum::Router::new()
+        .route("/.kittybox/webmention",
+            axum::routing::post(accept_webmention::<Q>)
+        )
+        .layer(Extension(queue));
+
+    (router, bgtask_handle)
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum SupervisorError {
+    #[error("the task was explicitly cancelled")]
+    Cancelled
+}
+
+pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>;
+
+pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask
+where
+    E: std::error::Error + std::fmt::Debug + Send + 'static,
+    A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static,
+    F: FnMut() -> A + Send + 'static
+{
+
+    let supervisor_future = async move {
+        loop {
+            // Don't spawn the task if we are already cancelled, but
+            // have somehow missed it (probably because the task
+            // crashed and we immediately received a cancellation
+            // request after noticing the crashed task)
+            if cancellation_token.is_cancelled() {
+                return Err(SupervisorError::Cancelled)
+            }
+            let task = tokio::task::spawn(f());
+            tokio::select! {
+                _ = cancellation_token.cancelled() => {
+                    tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>());
+                    return Err(SupervisorError::Cancelled)
+                }
+                task_result = task => match task_result {
+                    Err(e) => tracing::error!("background task {:?} exited unexpectedly: {}", std::any::type_name::<A>(), e),
+                    Ok(Err(e)) => tracing::error!("background task {:?} returned error: {}", std::any::type_name::<A>(), e),
+                    Ok(Ok(_)) => unreachable!("task's Ok is Infallible")
+                }
+            }
+            tracing::debug!("Sleeping for a little while to back-off...");
+            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+        }
+    };
+    #[cfg(not(tokio_unstable))]
+    return tokio::task::spawn(supervisor_future);
+    #[cfg(tokio_unstable)]
+    return tokio::task::Builder::new()
+        .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str())
+        .spawn(supervisor_future)
+        .unwrap();
+}
+
+mod check;
+
+#[derive(thiserror::Error, Debug)]
+enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> {
+    #[error("queue error: {0}")]
+    Queue(#[from] Q),
+    #[error("storage error: {0}")]
+    Storage(StorageError)
+}
+
+async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest::Client) -> Result<std::convert::Infallible, Error<Q::Error>> {
+    use futures_util::StreamExt;
+    use self::queue::Job;
+
+    let mut stream = queue.into_stream().await?;
+    while let Some(item) = stream.next().await.transpose()? {
+        let job = item.job();
+        let (source, target) = (
+            job.source.parse::<url::Url>().unwrap(),
+            job.target.parse::<url::Url>().unwrap()
+        );
+
+        let (code, text) = match http.get(source.clone()).send().await {
+            Ok(response) => {
+                let code = response.status();
+                if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) {
+                    error!("error processing webmention: webpage fetch returned {}", code);
+                    continue;
+                }
+                match response.text().await {
+                    Ok(text) => (code, text),
+                    Err(err) => {
+                        error!("error processing webmention: error fetching webpage text: {}", err);
+                        continue
+                    }
+                }
+            }
+            Err(err) => {
+                error!("error processing webmention: error requesting webpage: {}", err);
+                continue
+            }
+        };
+
+        if code == StatusCode::GONE {
+            todo!("removing webmentions is not implemented yet");
+            // db.remove_webmention(target.as_str(), source.as_str()).await.map_err(Error::<Q::Error>::Storage)?;
+        } else {
+            // Verify webmention
+            let (mention_type, mut mention) = match tokio::task::block_in_place({
+                || check::check_mention(text, &source, &target)
+            }) {
+                Ok(Some(mention_type)) => mention_type,
+                Ok(None) => {
+                    error!("webmention {} -> {} invalid, rejecting", source, target);
+                    item.done().await?;
+                    continue;
+                }
+                Err(err) => {
+                    error!("error processing webmention: error checking webmention: {}", err);
+                    continue;
+                }
+            };
+
+            {
+                mention["type"] = serde_json::json!(["h-cite"]);
+
+                if !mention["properties"].as_object().unwrap().contains_key("uid") {
+                    let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned();
+                    let props = mention["properties"].as_object_mut().unwrap();
+                    props.insert("uid".to_owned(), serde_json::Value::Array(
+                        vec![serde_json::Value::String(url)])
+                    );
+                }
+            }
+
+            db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?;
+        }
+    }
+    unreachable!()
+}
+
+fn supervised_webmentions_task<Q: JobQueue<Webmention>, S: Storage + 'static>(
+    queue: Q, db: S,
+    http: reqwest::Client,
+    cancellation_token: tokio_util::sync::CancellationToken
+) -> SupervisedTask {
+    supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(queue.clone(), db.clone(), http.clone()), cancellation_token)
+}
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs
new file mode 100644
index 0000000..b811e71
--- /dev/null
+++ b/src/webmentions/queue.rs
@@ -0,0 +1,303 @@
+use std::{pin::Pin, str::FromStr};
+
+use futures_util::{Stream, StreamExt};
+use sqlx::{postgres::PgListener, Executor};
+use uuid::Uuid;
+
+use super::Webmention;
+
+static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention");
+
+pub use kittybox_util::queue::{JobQueue, JobItem, Job};
+
+pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
+    const DATABASE_NAME: &'static str;
+    const NOTIFICATION_CHANNEL: &'static str;
+}
+
+#[derive(sqlx::FromRow)]
+struct PostgresJobRow<T: PostgresJobItem> {
+    id: Uuid,
+    #[sqlx(flatten)]
+    job: T
+}
+
+#[derive(Debug)]
+pub struct PostgresJob<T: PostgresJobItem> {
+    id: Uuid,
+    job: T,
+    // This will normally always be Some, except on drop
+    txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
+    runtime_handle: tokio::runtime::Handle,
+}
+
+
+impl<T: PostgresJobItem> Drop for PostgresJob<T> {
+    // This is an emulation of "async drop" — the struct retains a
+    // runtime handle, which it uses to spawn a future that does
+    // the actual cleanup.
+    //
+    // Of course, this is not portable between runtimes, but I don't
+    // care about that, since Kittybox is designed to work within the
+    // Tokio ecosystem.
+    fn drop(&mut self) {
+        if let Some(mut txn) = self.txn.take() {
+            tracing::error!("Job {:?} failed, incrementing attempts...", &self);
+            let id = self.id;
+            self.runtime_handle.spawn(async move {
+                tracing::debug!("Constructing query to increment attempts for job {}...", id);
+                // UPDATE "T::DATABASE_NAME" SET attempts = attempts + 1 WHERE id = $1
+                sqlx::query_builder::QueryBuilder::new("UPDATE ")
+                    // This is safe from a SQL injection standpoint, since it is a constant.
+                    .push(T::DATABASE_NAME)
+                    .push(" SET attempts = attempts + 1")
+                    .push(" WHERE id = ")
+                    .push_bind(id)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                sqlx::query_builder::QueryBuilder::new("NOTIFY ")
+                    .push(T::NOTIFICATION_CHANNEL)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                txn.commit().await.unwrap();
+            });
+        }
+    }
+}
+
+#[cfg(test)]
+impl<T: PostgresJobItem> PostgresJob<T> {
+    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
+        sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ")
+            .push(T::DATABASE_NAME)
+            .push(" WHERE id = ")
+            .push_bind(self.id)
+            .build_query_as::<(i32,)>()
+            // It's safe to unwrap here, because we "take" the txn only on drop or commit,
+            // where it's passed by value, not by reference.
+            .fetch_one(self.txn.as_deref_mut().unwrap())
+            .await
+            .map(|(i,)| i as usize)
+    }
+}
+
+#[async_trait::async_trait]
+impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
+    fn job(&self) -> &Webmention {
+        &self.job
+    }
+    async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
+        tracing::debug!("Deleting {} from the job queue", self.id);
+        sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1")
+            .bind(self.id)
+            .execute(self.txn.as_deref_mut().unwrap())
+            .await?;
+
+        self.txn.take().unwrap().commit().await
+    }
+}
+
+pub struct PostgresJobQueue<T> {
+    db: sqlx::PgPool,
+    _phantom: std::marker::PhantomData<T>
+}
+impl<T> Clone for PostgresJobQueue<T> {
+    fn clone(&self) -> Self {
+        Self {
+            db: self.db.clone(),
+            _phantom: std::marker::PhantomData
+        }
+    }
+}
+
+impl PostgresJobQueue<Webmention> {
+    pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
+        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
+            .options([("search_path", "kittybox_webmention")]);
+        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
+            let password = tokio::fs::read_to_string(password_file).await.unwrap();
+            options = options.password(&password);
+        } else if let Ok(password) = std::env::var("PGPASS") {
+            options = options.password(&password)
+        }
+        Self::from_pool(
+            sqlx::postgres::PgPoolOptions::new()
+                .max_connections(50)
+                .connect_with(options)
+                .await?
+        ).await
+
+    }
+
+    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> {
+        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?;
+        MIGRATOR.run(&db).await?;
+        Ok(Self { db, _phantom: std::marker::PhantomData })
+    }
+}
+
+#[async_trait::async_trait]
+impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
+    type Job = PostgresJob<Webmention>;
+    type Error = sqlx::Error;
+
+    async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> {
+        let mut txn = self.db.begin().await?;
+
+        match sqlx::query_as::<_, PostgresJobRow<Webmention>>(
+            "SELECT id, source, target FROM kittybox_webmention.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1"
+        )
+            .fetch_optional(&mut *txn)
+            .await?
+        {
+            Some(job_row) => {
+                return Ok(Some(Self::Job {
+                    id: job_row.id,
+                    job: job_row.job,
+                    txn: Some(txn),
+                    runtime_handle: tokio::runtime::Handle::current(),
+                }))
+            },
+            None => Ok(None)
+        }
+    }
+
+    async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> {
+        sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox_webmention.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id")
+            .bind(item.source.as_str())
+            .bind(item.target.as_str())
+            .fetch_one(&self.db)
+            .await
+    }
+
+    async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> {
+        let mut listener = PgListener::connect_with(&self.db).await?;
+        listener.listen("incoming_webmention").await?;
+
+        let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), {
+            let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener));
+            move |_| {
+                let queue = self.clone();
+                let listener = listener.clone();
+                async move {
+                    loop {
+                        match queue.get_one().await? {
+                            Some(item) => return Ok(Some((item, ()))),
+                            None => {
+                                listener.lock().await.recv().await?;
+                                continue
+                            }
+                        }
+                    }
+                }
+            }
+        }).boxed();
+
+        Ok(stream)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR};
+    use futures_util::StreamExt;
+
+    #[sqlx::test(migrator = "MIGRATOR")]
+    #[tracing_test::traced_test]
+    async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        {
+            let mut job_description = queue.get_one().await?.unwrap();
+            assert_eq!(job_description.job(), &test_webmention);
+            assert_eq!(job_description.attempts().await?, 0);
+        }
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 1);
+            if let Some(item) = queue.get_one().await? {
+                panic!("Unexpected item {:?} returned from job queue!", item)
+            };
+        }
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 2);
+            guard.done().await?;
+        }
+
+        match queue.get_one().await? {
+            Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
+            None => Ok(())
+        }
+    }
+
+    #[sqlx::test(migrator = "MIGRATOR")]
+    #[tracing_test::traced_test]
+    async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        // Synchronisation barrier that will be useful later
+        let barrier = Arc::new(tokio::sync::Barrier::new(2));
+        {
+            // Get one job guard from a queue
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 0);
+
+            tokio::task::spawn({
+                let barrier = barrier.clone();
+                async move {
+                    // Wait for the signal to drop the guard!
+                    barrier.wait().await;
+
+                    drop(guard)
+                }
+            });
+        }
+        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err();
+
+        let future = tokio::task::spawn(
+            tokio::time::timeout(
+                std::time::Duration::from_secs(10), async move {
+                    stream.next().await.unwrap().unwrap()
+                }
+            )
+        );
+        // Let the other task drop the guard it is holding
+        barrier.wait().await;
+        let mut guard = future.await
+            .expect("Timeout on fetching item")
+            .expect("Job queue error");
+        assert_eq!(guard.job(), &test_webmention);
+        assert_eq!(guard.attempts().await?, 1);
+
+        Ok(())
+    }
+}
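A note on the supervisor in mod.rs: it wraps a fallible background loop, restarts it with a five-second back-off every time it crashes or returns an error, and only exits when the cancellation token fires. A minimal smoke test of that behaviour might look like the sketch below. It is not part of the patch: it assumes it lives inside this crate with `supervisor` and `SupervisorError` in scope, and the always-failing closure and its `std::io::Error` payload are arbitrary stand-ins.

    use std::convert::Infallible;

    use crate::webmentions::{supervisor, SupervisorError};

    // Hypothetical smoke test, not part of the patch: the supervised
    // closure always fails, so the supervisor keeps restarting it with
    // a 5-second back-off until the token cancels it.
    #[tokio::main]
    async fn main() {
        let token = tokio_util::sync::CancellationToken::new();
        let handle = supervisor::<std::io::Error, _, _>(
            || async {
                Err::<Infallible, _>(std::io::Error::new(std::io::ErrorKind::Other, "flaky"))
            },
            token.clone(),
        );
        // Let it crash and restart at least once, then shut it down.
        tokio::time::sleep(std::time::Duration::from_secs(6)).await;
        token.cancel();
        assert!(matches!(handle.await.unwrap(), Err(SupervisorError::Cancelled)));
    }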
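The queue's locking discipline in queue.rs rests on `SELECT ... FOR UPDATE SKIP LOCKED` inside an open transaction: a claimed row stays locked for exactly as long as the `PostgresJob` guard holds its transaction, and concurrent workers silently skip it. Stripped of the `PostgresJobItem` machinery, the claim step reduces to the following sketch; `claim_one` is a hypothetical helper, not part of the patch, but the table and SQL are the ones `get_one()` uses.

    use uuid::Uuid;

    // Hypothetical, trimmed-down version of the claim step in get_one().
    async fn claim_one(
        pool: &sqlx::PgPool,
    ) -> Result<Option<(sqlx::Transaction<'static, sqlx::Postgres>, Uuid)>, sqlx::Error> {
        let mut txn = pool.begin().await?;
        let row: Option<(Uuid,)> = sqlx::query_as(
            "SELECT id FROM kittybox_webmention.incoming_webmention_queue \
             WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1"
        )
            .fetch_optional(&mut *txn)
            .await?;
        // Committing or rolling back `txn` releases the row lock;
        // holding it keeps the job invisible to other workers.
        Ok(row.map(|(id,)| (txn, id)))
    }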
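Consuming the queue through `into_stream()` mirrors what `process_webmentions_from_queue` does: the stream yields a locked job guard, `done()` deletes the row and commits, and merely dropping the guard rolls the claim back after bumping `attempts` and firing a NOTIFY. A hypothetical consumer, again assuming the crate's own imports:

    use futures_util::StreamExt;

    use crate::webmentions::Webmention;
    use crate::webmentions::queue::{Job, JobQueue, PostgresJobQueue};

    // Hypothetical consumer loop, not part of the patch.
    async fn consume(queue: PostgresJobQueue<Webmention>) -> Result<(), sqlx::Error> {
        let mut stream = queue.into_stream().await?;
        while let Some(item) = stream.next().await.transpose()? {
            tracing::info!("got job: {:?}", item.job());
            // done() deletes the row and commits the claiming transaction;
            // dropping `item` instead would increment `attempts` and re-queue it.
            item.done().await?;
        }
        // into_stream() parks on LISTEN when the table is empty,
        // so the stream never terminates on its own.
        unreachable!()
    }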
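Finally, wiring: `router()` returns both the axum `Router` carrying the `/.kittybox/webmention` endpoint and the handle of the supervised background worker. A hypothetical sketch of how the pieces might be assembled, assuming an axum 0.6-style `axum::Server`, a placeholder connection URI, and some `Storage` implementation `storage` (all stand-ins not defined by this patch):

    use crate::database::Storage;
    use crate::webmentions::{router, Webmention};
    use crate::webmentions::queue::PostgresJobQueue;

    // Hypothetical wiring, not part of the patch.
    async fn serve<S: Storage + 'static>(storage: S) -> Result<(), Box<dyn std::error::Error>> {
        let queue = PostgresJobQueue::<Webmention>::new("postgres://localhost/kittybox").await?;
        let token = tokio_util::sync::CancellationToken::new();
        let (router, _worker) = router(queue, storage, reqwest::Client::new(), token.clone());

        axum::Server::bind(&"127.0.0.1:8080".parse()?)
            .serve(router.into_make_service())
            .await?;

        // Only reached if the server stops: cancel the background worker too.
        token.cancel();
        Ok(())
    }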