about summary refs log tree commit diff
path: root/src/webmentions
diff options
context:
space:
mode:
Diffstat (limited to 'src/webmentions')
-rw-r--r--src/webmentions/check.rs113
-rw-r--r--src/webmentions/mod.rs195
-rw-r--r--src/webmentions/queue.rs303
3 files changed, 611 insertions, 0 deletions
diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs
new file mode 100644
index 0000000..f7322f7
--- /dev/null
+++ b/src/webmentions/check.rs
@@ -0,0 +1,113 @@
+use std::{cell::RefCell, rc::Rc};
+use microformats::{types::PropertyValue, html5ever::{self, tendril::TendrilSink}};
+use kittybox_util::MentionType;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("microformats error: {0}")]
+    Microformats(#[from] microformats::Error),
+    // #[error("json error: {0}")]
+    // Json(#[from] serde_json::Error),
+    #[error("url parse error: {0}")]
+    UrlParse(#[from] url::ParseError),
+}
+
+#[tracing::instrument]
+pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> {
+    tracing::debug!("Parsing MF2 markup...");
+    // First, check the document for MF2 markup
+    let document = microformats::from_html(document.as_ref(), base_url.clone())?;
+
+    // Get an iterator of all items
+    let items_iter = document.items.iter()
+        .map(AsRef::as_ref)
+        .map(RefCell::borrow);
+
+    for item in items_iter {
+        tracing::debug!("Processing item: {:?}", item);
+
+        let props = item.properties.borrow();
+        for (prop, interaction_type) in [
+            ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like),
+            ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost)
+        ] {
+            if let Some(propvals) = props.get(prop) {
+                tracing::debug!("Has a u-{} property", prop);
+                for val in propvals {
+                    if let PropertyValue::Url(url) = val {
+                        if url == link {
+                            tracing::debug!("URL matches! Webmention is valid");
+                            return Ok(Some((interaction_type, serde_json::to_value(&*item).unwrap())))
+                        }
+                    }
+                }
+            }
+        }
+        // Process `content`
+        tracing::debug!("Processing e-content...");
+        if let Some(PropertyValue::Fragment(content)) = props.get("content")
+            .map(Vec::as_slice)
+            .unwrap_or_default()
+            .first()
+        {
+            tracing::debug!("Parsing HTML data...");
+            let root = html5ever::parse_document(html5ever::rcdom::RcDom::default(), Default::default())
+                .from_utf8()
+                .one(content.html.to_owned().as_bytes())
+                .document;
+
+            // This is a trick to unwrap recursion into a loop
+            //
+            // A list of unprocessed node is made. Then, in each
+            // iteration, the list is "taken" and replaced with an
+            // empty list, which is populated with nodes for the next
+            // iteration of the loop.
+            //
+            // Empty list means all nodes were processed.
+            let mut unprocessed_nodes: Vec<Rc<html5ever::rcdom::Node>> = root.children.borrow().iter().cloned().collect();
+            while !unprocessed_nodes.is_empty() {
+                // "Take" the list out of its memory slot, replace it with an empty list
+                let nodes = std::mem::take(&mut unprocessed_nodes);
+                tracing::debug!("Processing list of {} nodes", nodes.len());
+                'nodes_loop: for node in nodes.into_iter() {
+                    // Add children nodes to the list for the next iteration
+                    unprocessed_nodes.extend(node.children.borrow().iter().cloned());
+
+                    if let html5ever::rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data {
+                        // If it's not `<a>`, skip it
+                        if name.local != *"a" { continue; }
+                        let mut is_mention: bool = false;
+                        for attr in attrs.borrow().iter() {
+                            if attr.name.local == *"rel" {
+                                // Don't count `rel="nofollow"` links — a web crawler should ignore them
+                                // and so for purposes of driving visitors they are useless
+                                if attr.value
+                                    .as_ref()
+                                    .split([',', ' '])
+                                    .any(|v| v == "nofollow")
+                                {
+                                    // Skip the entire node.
+                                    continue 'nodes_loop;
+                                }
+                            }
+                            // if it's not `<a href="...">`, skip it 
+                            if attr.name.local != *"href" { continue; }
+                            // Be forgiving in parsing URLs, and resolve them against the base URL
+                            if let Ok(url) = base_url.join(attr.value.as_ref()) {
+                                if &url == link {
+                                    is_mention = true;
+                                }
+                            }
+                        }
+                        if is_mention {
+                            return Ok(Some((MentionType::Mention, serde_json::to_value(&*item).unwrap())));
+                        }
+                    }
+                }
+            }
+            
+        }
+    }
+
+    Ok(None)
+}
diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs
new file mode 100644
index 0000000..95ea870
--- /dev/null
+++ b/src/webmentions/mod.rs
@@ -0,0 +1,195 @@
+use axum::{Form, response::{IntoResponse, Response}, Extension};
+use axum::http::StatusCode;
+use tracing::error;
+
+use crate::database::{Storage, StorageError};
+use self::queue::JobQueue;
+pub mod queue;
+
+#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
+#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))]
+pub struct Webmention {
+    source: String,
+    target: String,
+}
+
+impl queue::JobItem for Webmention {}
+impl queue::PostgresJobItem for Webmention {
+    const DATABASE_NAME: &'static str = "kittybox_webmention.incoming_webmention_queue";
+    const NOTIFICATION_CHANNEL: &'static str = "incoming_webmention";
+}
+
+async fn accept_webmention<Q: JobQueue<Webmention>>(
+    Extension(queue): Extension<Q>,
+    Form(webmention): Form<Webmention>,
+) -> Response {
+    if let Err(err) = webmention.source.parse::<url::Url>() {
+        return (StatusCode::BAD_REQUEST, err.to_string()).into_response()
+    }
+    if let Err(err) = webmention.target.parse::<url::Url>() {
+        return (StatusCode::BAD_REQUEST, err.to_string()).into_response()
+    }
+
+    match queue.put(&webmention).await {
+        Ok(id) => (StatusCode::ACCEPTED, [
+            ("Location", format!("/.kittybox/webmention/{id}"))
+        ]).into_response(),
+        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [
+            ("Content-Type", "text/plain")
+        ], err.to_string()).into_response()
+    }
+}
+
+pub fn router<Q: JobQueue<Webmention>, S: Storage + 'static>(
+    queue: Q, db: S, http: reqwest::Client,
+    cancellation_token: tokio_util::sync::CancellationToken
+) -> (axum::Router, SupervisedTask) {
+    // Automatically spawn a background task to handle webmentions
+    let bgtask_handle = supervised_webmentions_task(queue.clone(), db, http, cancellation_token);
+
+    let router = axum::Router::new()
+        .route("/.kittybox/webmention",
+            axum::routing::post(accept_webmention::<Q>)
+        )
+        .layer(Extension(queue));
+
+    (router, bgtask_handle)
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum SupervisorError {
+    #[error("the task was explicitly cancelled")]
+    Cancelled
+}
+
+pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>;
+
+pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask
+where
+    E: std::error::Error + std::fmt::Debug + Send + 'static,
+    A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static,
+    F: FnMut() -> A + Send + 'static
+{
+
+    let supervisor_future = async move {
+        loop {
+            // Don't spawn the task if we are already cancelled, but
+            // have somehow missed it (probably because the task
+            // crashed and we immediately received a cancellation
+            // request after noticing the crashed task)
+            if cancellation_token.is_cancelled() {
+                return Err(SupervisorError::Cancelled)
+            }
+            let task = tokio::task::spawn(f());
+            tokio::select! {
+                _ = cancellation_token.cancelled() => {
+                    tracing::info!("Shutdown of background task {:?} requested.", std::any::type_name::<A>());
+                    return Err(SupervisorError::Cancelled)
+                }
+                task_result = task => match task_result {
+                    Err(e) => tracing::error!("background task {:?} exited unexpectedly: {}", std::any::type_name::<A>(), e),
+                    Ok(Err(e)) => tracing::error!("background task {:?} returned error: {}", std::any::type_name::<A>(), e),
+                    Ok(Ok(_)) => unreachable!("task's Ok is Infallible")
+                }
+            }
+            tracing::debug!("Sleeping for a little while to back-off...");
+            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+        }
+    };
+    #[cfg(not(tokio_unstable))]
+    return tokio::task::spawn(supervisor_future);
+    #[cfg(tokio_unstable)]
+    return tokio::task::Builder::new()
+        .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str())
+        .spawn(supervisor_future)
+        .unwrap();
+}
+
+mod check;
+
+#[derive(thiserror::Error, Debug)]
+enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> {
+    #[error("queue error: {0}")]
+    Queue(#[from] Q),
+    #[error("storage error: {0}")]
+    Storage(StorageError)
+}
+
+async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest::Client) -> Result<std::convert::Infallible, Error<Q::Error>> {
+    use futures_util::StreamExt;
+    use self::queue::Job;
+
+    let mut stream = queue.into_stream().await?;
+    while let Some(item) = stream.next().await.transpose()? {
+        let job = item.job();
+        let (source, target) = (
+            job.source.parse::<url::Url>().unwrap(),
+            job.target.parse::<url::Url>().unwrap()
+        );
+
+        let (code, text) = match http.get(source.clone()).send().await {
+            Ok(response) => {
+                let code = response.status();
+                if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) {
+                    error!("error processing webmention: webpage fetch returned {}", code);
+                    continue;
+                }
+                match response.text().await {
+                    Ok(text) => (code, text),
+                    Err(err) => {
+                        error!("error processing webmention: error fetching webpage text: {}", err);
+                        continue
+                    }
+                }
+            }
+            Err(err) => {
+                error!("error processing webmention: error requesting webpage: {}", err);
+                continue
+            }
+        };
+
+        if code == StatusCode::GONE {
+            todo!("removing webmentions is not implemented yet");
+            // db.remove_webmention(target.as_str(), source.as_str()).await.map_err(Error::<Q::Error>::Storage)?;
+        } else {
+            // Verify webmention
+            let (mention_type, mut mention) = match tokio::task::block_in_place({
+                || check::check_mention(text, &source, &target)
+            }) {
+                Ok(Some(mention_type)) => mention_type,
+                Ok(None) => {
+                    error!("webmention {} -> {} invalid, rejecting", source, target);
+                    item.done().await?;
+                    continue;
+                }
+                Err(err) => {
+                    error!("error processing webmention: error checking webmention: {}", err);
+                    continue;
+                }
+            };
+
+            {
+                mention["type"] = serde_json::json!(["h-cite"]);
+
+                if !mention["properties"].as_object().unwrap().contains_key("uid") {
+                    let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned();
+                    let props = mention["properties"].as_object_mut().unwrap();
+                    props.insert("uid".to_owned(), serde_json::Value::Array(
+                        vec![serde_json::Value::String(url)])
+                    );
+                }
+            }
+
+            db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?;
+        }
+    }
+    unreachable!()
+}
+
+fn supervised_webmentions_task<Q: JobQueue<Webmention>, S: Storage + 'static>(
+    queue: Q, db: S,
+    http: reqwest::Client,
+    cancellation_token: tokio_util::sync::CancellationToken
+) -> SupervisedTask {
+    supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(queue.clone(), db.clone(), http.clone()), cancellation_token)
+}
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs
new file mode 100644
index 0000000..b811e71
--- /dev/null
+++ b/src/webmentions/queue.rs
@@ -0,0 +1,303 @@
+use std::{pin::Pin, str::FromStr};
+
+use futures_util::{Stream, StreamExt};
+use sqlx::{postgres::PgListener, Executor};
+use uuid::Uuid;
+
+use super::Webmention;
+
+static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention");
+
+pub use kittybox_util::queue::{JobQueue, JobItem, Job};
+
+pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
+    const DATABASE_NAME: &'static str;
+    const NOTIFICATION_CHANNEL: &'static str;
+}
+
+#[derive(sqlx::FromRow)]
+struct PostgresJobRow<T: PostgresJobItem> {
+    id: Uuid,
+    #[sqlx(flatten)]
+    job: T
+}
+
+#[derive(Debug)]
+pub struct PostgresJob<T: PostgresJobItem> {
+    id: Uuid,
+    job: T,
+    // This will normally always be Some, except on drop
+    txn: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
+    runtime_handle: tokio::runtime::Handle,
+}
+
+
+impl<T: PostgresJobItem> Drop for PostgresJob<T> {
+    // This is an emulation of "async drop" — the struct retains a
+    // runtime handle, which it uses to block on a future that does
+    // the actual cleanup.
+    //
+    // Of course, this is not portable between runtimes, but I don't
+    // care about that, since Kittybox is designed to work within the
+    // Tokio ecosystem.
+    fn drop(&mut self) {
+        tracing::error!("Job {:?} failed, incrementing attempts...", &self);
+        if let Some(mut txn) = self.txn.take() {
+            let id = self.id;
+            self.runtime_handle.spawn(async move {
+                tracing::debug!("Constructing query to increment attempts for job {}...", id);
+                // UPDATE "T::DATABASE_NAME" WHERE id = $1 SET attempts = attempts + 1
+                sqlx::query_builder::QueryBuilder::new("UPDATE ")
+                    // This is safe from a SQL injection standpoint, since it is a constant.
+                    .push(T::DATABASE_NAME)
+                    .push(" SET attempts = attempts + 1")
+                    .push(" WHERE id = ")
+                    .push_bind(id)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                sqlx::query_builder::QueryBuilder::new("NOTIFY ")
+                    .push(T::NOTIFICATION_CHANNEL)
+                    .build()
+                    .execute(&mut *txn)
+                    .await
+                    .unwrap();
+                txn.commit().await.unwrap();
+            });
+        }
+    }
+}
+
+#[cfg(test)]
+impl<T: PostgresJobItem> PostgresJob<T> {
+    async fn attempts(&mut self) -> Result<usize, sqlx::Error> {
+        sqlx::query_builder::QueryBuilder::new("SELECT attempts FROM ")
+            .push(T::DATABASE_NAME)
+            .push(" WHERE id = ")
+            .push_bind(self.id)
+            .build_query_as::<(i32,)>()
+            // It's safe to unwrap here, because we "take" the txn only on drop or commit,
+            // where it's passed by value, not by reference.
+            .fetch_one(self.txn.as_deref_mut().unwrap())
+            .await
+            .map(|(i,)| i as usize)
+    }
+}
+
+#[async_trait::async_trait]
+impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
+    fn job(&self) -> &Webmention {
+        &self.job
+    }
+    async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
+        tracing::debug!("Deleting {} from the job queue", self.id);
+        sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1")
+            .bind(self.id)
+            .execute(self.txn.as_deref_mut().unwrap())
+            .await?;
+
+        self.txn.take().unwrap().commit().await
+    }
+}
+
+pub struct PostgresJobQueue<T> {
+    db: sqlx::PgPool,
+    _phantom: std::marker::PhantomData<T>
+}
+impl<T> Clone for PostgresJobQueue<T> {
+    fn clone(&self) -> Self {
+        Self {
+            db: self.db.clone(),
+            _phantom: std::marker::PhantomData
+        }
+    }
+}
+
+impl PostgresJobQueue<Webmention> {
+    pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
+        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
+            .options([("search_path", "kittybox_webmention")]);
+        if let Ok(password_file) = std::env::var("PGPASS_FILE") {
+            let password = tokio::fs::read_to_string(password_file).await.unwrap();
+            options = options.password(&password);
+        } else if let Ok(password) = std::env::var("PGPASS") {
+            options = options.password(&password)
+        }
+        Self::from_pool(
+            sqlx::postgres::PgPoolOptions::new()
+                .max_connections(50)
+                .connect_with(options)
+                .await?
+        ).await
+
+    }
+
+    pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> {
+        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?;
+        MIGRATOR.run(&db).await?;
+        Ok(Self { db, _phantom: std::marker::PhantomData })
+    }
+}
+
+#[async_trait::async_trait]
+impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
+    type Job = PostgresJob<Webmention>;
+    type Error = sqlx::Error;
+
+    async fn get_one(&self) -> Result<Option<Self::Job>, Self::Error> {
+        let mut txn = self.db.begin().await?;
+
+        match sqlx::query_as::<_, PostgresJobRow<Webmention>>(
+            "SELECT id, source, target FROM kittybox_webmention.incoming_webmention_queue WHERE attempts < 5 FOR UPDATE SKIP LOCKED LIMIT 1"
+        )
+            .fetch_optional(&mut *txn)
+            .await?
+        {
+            Some(job_row) => {
+                return Ok(Some(Self::Job {
+                    id: job_row.id,
+                    job: job_row.job,
+                    txn: Some(txn),
+                    runtime_handle: tokio::runtime::Handle::current(),
+                }))
+            },
+            None => Ok(None)
+        }
+    }
+
+    async fn put(&self, item: &Webmention) -> Result<Uuid, Self::Error> {
+        sqlx::query_scalar::<_, Uuid>("INSERT INTO kittybox_webmention.incoming_webmention_queue (source, target) VALUES ($1, $2) RETURNING id")
+            .bind(item.source.as_str())
+            .bind(item.target.as_str())
+            .fetch_one(&self.db)
+            .await
+    }
+
+    async fn into_stream(self) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>>, Self::Error> {
+        let mut listener = PgListener::connect_with(&self.db).await?;
+        listener.listen("incoming_webmention").await?;
+
+        let stream: Pin<Box<dyn Stream<Item = Result<Self::Job, Self::Error>> + Send>> = futures_util::stream::try_unfold((), {
+            let listener = std::sync::Arc::new(tokio::sync::Mutex::new(listener));
+            move |_| {
+                let queue = self.clone();
+                let listener = listener.clone();
+                async move {
+                    loop {
+                        match queue.get_one().await? {
+                            Some(item) => return Ok(Some((item, ()))),
+                            None => {
+                                listener.lock().await.recv().await?;
+                                continue
+                            }
+                        }
+                    }
+                }
+            }
+        }).boxed();
+
+        Ok(stream)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR};
+    use futures_util::StreamExt;
+
+    #[sqlx::test(migrator = "MIGRATOR")]
+    #[tracing_test::traced_test]
+    async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        {
+            let mut job_description = queue.get_one().await?.unwrap();
+            assert_eq!(job_description.job(), &test_webmention);
+            assert_eq!(job_description.attempts().await?, 0);
+        }
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 1);
+            if let Some(item) = queue.get_one().await? {
+                panic!("Unexpected item {:?} returned from job queue!", item)
+            };
+        }
+
+        {
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 2);
+            guard.done().await?;
+        }
+
+        match queue.get_one().await? {
+            Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
+            None => Ok(())
+        }
+    }
+
+    #[sqlx::test(migrator = "MIGRATOR")]
+    #[tracing_test::traced_test]
+    async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+        let test_webmention = Webmention {
+            source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+        };
+
+        let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?;
+        tracing::debug!("Putting webmention into queue");
+        queue.put(&test_webmention).await?;
+        tracing::debug!("Creating a stream");
+        let mut stream = queue.clone().into_stream().await?;
+
+        // Synchronisation barrier that will be useful later
+        let barrier = Arc::new(tokio::sync::Barrier::new(2));
+        {
+            // Get one job guard from a queue
+            let mut guard = stream.next().await.transpose()?.unwrap();
+            assert_eq!(guard.job(), &test_webmention);
+            assert_eq!(guard.attempts().await?, 0);
+
+            tokio::task::spawn({
+                let barrier = barrier.clone();
+                async move {
+                    // Wait for the signal to drop the guard!
+                    barrier.wait().await;
+
+                    drop(guard)
+                }
+            });
+        }
+        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err();
+
+        let future = tokio::task::spawn(
+            tokio::time::timeout(
+                std::time::Duration::from_secs(10), async move {
+                    stream.next().await.unwrap().unwrap()
+                }
+            )
+        );
+        // Let the other task drop the guard it is holding
+        barrier.wait().await;
+        let mut guard = future.await
+            .expect("Timeout on fetching item")
+            .expect("Job queue error");
+        assert_eq!(guard.job(), &test_webmention);
+        assert_eq!(guard.attempts().await?, 1);
+
+        Ok(())
+    }
+}