about summary refs log tree commit diff
path: root/src/webmentions
diff options
context:
space:
mode:
Diffstat (limited to 'src/webmentions')
-rw-r--r--src/webmentions/check.rs52
-rw-r--r--src/webmentions/mod.rs130
-rw-r--r--src/webmentions/queue.rs60
3 files changed, 163 insertions, 79 deletions
diff --git a/src/webmentions/check.rs b/src/webmentions/check.rs
index 683cc6b..380f4db 100644
--- a/src/webmentions/check.rs
+++ b/src/webmentions/check.rs
@@ -1,7 +1,7 @@
-use std::rc::Rc;
-use microformats::types::PropertyValue;
 use html5ever::{self, tendril::TendrilSink};
 use kittybox_util::MentionType;
+use microformats::types::PropertyValue;
+use std::rc::Rc;
 
 // TODO: replace.
 mod rcdom;
@@ -17,7 +17,11 @@ pub enum Error {
 }
 
 #[tracing::instrument]
-pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url::Url, link: &url::Url) -> Result<Option<(MentionType, serde_json::Value)>, Error> {
+pub fn check_mention(
+    document: impl AsRef<str> + std::fmt::Debug,
+    base_url: &url::Url,
+    link: &url::Url,
+) -> Result<Option<(MentionType, serde_json::Value)>, Error> {
     tracing::debug!("Parsing MF2 markup...");
     // First, check the document for MF2 markup
     let document = microformats::from_html(document.as_ref(), base_url.clone())?;
@@ -29,8 +33,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
         tracing::debug!("Processing item: {:?}", item);
 
         for (prop, interaction_type) in [
-            ("in-reply-to", MentionType::Reply), ("like-of", MentionType::Like),
-            ("bookmark-of", MentionType::Bookmark), ("repost-of", MentionType::Repost)
+            ("in-reply-to", MentionType::Reply),
+            ("like-of", MentionType::Like),
+            ("bookmark-of", MentionType::Bookmark),
+            ("repost-of", MentionType::Repost),
         ] {
             if let Some(propvals) = item.properties.get(prop) {
                 tracing::debug!("Has a u-{} property", prop);
@@ -38,7 +44,10 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
                     if let PropertyValue::Url(url) = val {
                         if url == link {
                             tracing::debug!("URL matches! Webmention is valid");
-                            return Ok(Some((interaction_type, serde_json::to_value(item).unwrap())))
+                            return Ok(Some((
+                                interaction_type,
+                                serde_json::to_value(item).unwrap(),
+                            )));
                         }
                     }
                 }
@@ -46,7 +55,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
         }
         // Process `content`
         tracing::debug!("Processing e-content...");
-        if let Some(PropertyValue::Fragment(content)) = item.properties.get("content")
+        if let Some(PropertyValue::Fragment(content)) = item
+            .properties
+            .get("content")
             .map(Vec::as_slice)
             .unwrap_or_default()
             .first()
@@ -65,7 +76,8 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
             // iteration of the loop.
             //
             // Empty list means all nodes were processed.
-            let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> = root.children.borrow().iter().cloned().collect();
+            let mut unprocessed_nodes: Vec<Rc<rcdom::Node>> =
+                root.children.borrow().iter().cloned().collect();
             while !unprocessed_nodes.is_empty() {
                 // "Take" the list out of its memory slot, replace it with an empty list
                 let nodes = std::mem::take(&mut unprocessed_nodes);
@@ -74,15 +86,23 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
                     // Add children nodes to the list for the next iteration
                     unprocessed_nodes.extend(node.children.borrow().iter().cloned());
 
-                    if let rcdom::NodeData::Element { ref name, ref attrs, .. } = node.data {
+                    if let rcdom::NodeData::Element {
+                        ref name,
+                        ref attrs,
+                        ..
+                    } = node.data
+                    {
                         // If it's not `<a>`, skip it
-                        if name.local != *"a" { continue; }
+                        if name.local != *"a" {
+                            continue;
+                        }
                         let mut is_mention: bool = false;
                         for attr in attrs.borrow().iter() {
                             if attr.name.local == *"rel" {
                                 // Don't count `rel="nofollow"` links — a web crawler should ignore them
                                 // and so for purposes of driving visitors they are useless
-                                if attr.value
+                                if attr
+                                    .value
                                     .as_ref()
                                     .split([',', ' '])
                                     .any(|v| v == "nofollow")
@@ -92,7 +112,9 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
                                 }
                             }
                             // if it's not `<a href="...">`, skip it
-                            if attr.name.local != *"href" { continue; }
+                            if attr.name.local != *"href" {
+                                continue;
+                            }
                             // Be forgiving in parsing URLs, and resolve them against the base URL
                             if let Ok(url) = base_url.join(attr.value.as_ref()) {
                                 if &url == link {
@@ -101,12 +123,14 @@ pub fn check_mention(document: impl AsRef<str> + std::fmt::Debug, base_url: &url
                             }
                         }
                         if is_mention {
-                            return Ok(Some((MentionType::Mention, serde_json::to_value(item).unwrap())));
+                            return Ok(Some((
+                                MentionType::Mention,
+                                serde_json::to_value(item).unwrap(),
+                            )));
                         }
                     }
                 }
             }
-
         }
     }
 
diff --git a/src/webmentions/mod.rs b/src/webmentions/mod.rs
index 91b274b..57f9a57 100644
--- a/src/webmentions/mod.rs
+++ b/src/webmentions/mod.rs
@@ -1,9 +1,14 @@
-use axum::{extract::{FromRef, State}, response::{IntoResponse, Response}, routing::post, Form};
 use axum::http::StatusCode;
+use axum::{
+    extract::{FromRef, State},
+    response::{IntoResponse, Response},
+    routing::post,
+    Form,
+};
 use tracing::error;
 
-use crate::database::{Storage, StorageError};
 use self::queue::JobQueue;
+use crate::database::{Storage, StorageError};
 pub mod queue;
 
 #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -24,40 +29,46 @@ async fn accept_webmention<Q: JobQueue<Webmention>>(
     Form(webmention): Form<Webmention>,
 ) -> Response {
     if let Err(err) = webmention.source.parse::<url::Url>() {
-        return (StatusCode::BAD_REQUEST, err.to_string()).into_response()
+        return (StatusCode::BAD_REQUEST, err.to_string()).into_response();
     }
     if let Err(err) = webmention.target.parse::<url::Url>() {
-        return (StatusCode::BAD_REQUEST, err.to_string()).into_response()
+        return (StatusCode::BAD_REQUEST, err.to_string()).into_response();
     }
 
     match queue.put(&webmention).await {
         Ok(_id) => StatusCode::ACCEPTED.into_response(),
-        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, [
-            ("Content-Type", "text/plain")
-        ], err.to_string()).into_response()
+        Err(err) => (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            [("Content-Type", "text/plain")],
+            err.to_string(),
+        )
+            .into_response(),
     }
 }
 
-pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>() -> axum::Router<St> {
-    axum::Router::new()
-        .route("/.kittybox/webmention", post(accept_webmention::<Q>))
+pub fn router<St: Clone + Send + Sync + 'static, Q: JobQueue<Webmention> + FromRef<St>>(
+) -> axum::Router<St> {
+    axum::Router::new().route("/.kittybox/webmention", post(accept_webmention::<Q>))
 }
 
 #[derive(thiserror::Error, Debug)]
 pub enum SupervisorError {
     #[error("the task was explicitly cancelled")]
-    Cancelled
+    Cancelled,
 }
 
-pub type SupervisedTask = tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>;
+pub type SupervisedTask =
+    tokio::task::JoinHandle<Result<std::convert::Infallible, SupervisorError>>;
 
-pub fn supervisor<E, A, F>(mut f: F, cancellation_token: tokio_util::sync::CancellationToken) -> SupervisedTask
+pub fn supervisor<E, A, F>(
+    mut f: F,
+    cancellation_token: tokio_util::sync::CancellationToken,
+) -> SupervisedTask
 where
     E: std::error::Error + std::fmt::Debug + Send + 'static,
     A: std::future::Future<Output = Result<std::convert::Infallible, E>> + Send + 'static,
-    F: FnMut() -> A + Send + 'static
+    F: FnMut() -> A + Send + 'static,
 {
-
     let supervisor_future = async move {
         loop {
             // Don't spawn the task if we are already cancelled, but
@@ -65,7 +76,7 @@ where
             // crashed and we immediately received a cancellation
             // request after noticing the crashed task)
             if cancellation_token.is_cancelled() {
-                return Err(SupervisorError::Cancelled)
+                return Err(SupervisorError::Cancelled);
             }
             let task = tokio::task::spawn(f());
             tokio::select! {
@@ -87,7 +98,13 @@ where
     return tokio::task::spawn(supervisor_future);
     #[cfg(tokio_unstable)]
     return tokio::task::Builder::new()
-        .name(format!("supervisor for background task {}", std::any::type_name::<A>()).as_str())
+        .name(
+            format!(
+                "supervisor for background task {}",
+                std::any::type_name::<A>()
+            )
+            .as_str(),
+        )
         .spawn(supervisor_future)
         .unwrap();
 }
@@ -99,39 +116,55 @@ enum Error<Q: std::error::Error + std::fmt::Debug + Send + 'static> {
     #[error("queue error: {0}")]
     Queue(#[from] Q),
     #[error("storage error: {0}")]
-    Storage(StorageError)
+    Storage(StorageError),
 }
 
-async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(queue: Q, db: S, http: reqwest_middleware::ClientWithMiddleware) -> Result<std::convert::Infallible, Error<Q::Error>> {
-    use futures_util::StreamExt;
+async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 'static>(
+    queue: Q,
+    db: S,
+    http: reqwest_middleware::ClientWithMiddleware,
+) -> Result<std::convert::Infallible, Error<Q::Error>> {
     use self::queue::Job;
+    use futures_util::StreamExt;
 
     let mut stream = queue.into_stream().await?;
     while let Some(item) = stream.next().await.transpose()? {
         let job = item.job();
         let (source, target) = (
             job.source.parse::<url::Url>().unwrap(),
-            job.target.parse::<url::Url>().unwrap()
+            job.target.parse::<url::Url>().unwrap(),
         );
 
         let (code, text) = match http.get(source.clone()).send().await {
             Ok(response) => {
                 let code = response.status();
-                if ![StatusCode::OK, StatusCode::GONE].iter().any(|i| i == &code) {
-                    error!("error processing webmention: webpage fetch returned {}", code);
+                if ![StatusCode::OK, StatusCode::GONE]
+                    .iter()
+                    .any(|i| i == &code)
+                {
+                    error!(
+                        "error processing webmention: webpage fetch returned {}",
+                        code
+                    );
                     continue;
                 }
                 match response.text().await {
                     Ok(text) => (code, text),
                     Err(err) => {
-                        error!("error processing webmention: error fetching webpage text: {}", err);
-                        continue
+                        error!(
+                            "error processing webmention: error fetching webpage text: {}",
+                            err
+                        );
+                        continue;
                     }
                 }
             }
             Err(err) => {
-                error!("error processing webmention: error requesting webpage: {}", err);
-                continue
+                error!(
+                    "error processing webmention: error requesting webpage: {}",
+                    err
+                );
+                continue;
             }
         };
 
@@ -150,7 +183,10 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's
                     continue;
                 }
                 Err(err) => {
-                    error!("error processing webmention: error checking webmention: {}", err);
+                    error!(
+                        "error processing webmention: error checking webmention: {}",
+                        err
+                    );
                     continue;
                 }
             };
@@ -158,31 +194,47 @@ async fn process_webmentions_from_queue<Q: JobQueue<Webmention>, S: Storage + 's
             {
                 mention["type"] = serde_json::json!(["h-cite"]);
 
-                if !mention["properties"].as_object().unwrap().contains_key("uid") {
-                    let url = mention["properties"]["url"][0].as_str().unwrap_or_else(|| target.as_str()).to_owned();
+                if !mention["properties"]
+                    .as_object()
+                    .unwrap()
+                    .contains_key("uid")
+                {
+                    let url = mention["properties"]["url"][0]
+                        .as_str()
+                        .unwrap_or_else(|| target.as_str())
+                        .to_owned();
                     let props = mention["properties"].as_object_mut().unwrap();
-                    props.insert("uid".to_owned(), serde_json::Value::Array(
-                        vec![serde_json::Value::String(url)])
+                    props.insert(
+                        "uid".to_owned(),
+                        serde_json::Value::Array(vec![serde_json::Value::String(url)]),
                     );
                 }
             }
 
-            db.add_or_update_webmention(target.as_str(), mention_type, mention).await.map_err(Error::<Q::Error>::Storage)?;
+            db.add_or_update_webmention(target.as_str(), mention_type, mention)
+                .await
+                .map_err(Error::<Q::Error>::Storage)?;
         }
     }
     unreachable!()
 }
 
-pub fn supervised_webmentions_task<St: Send + Sync + 'static, S: Storage + FromRef<St> + 'static, Q: JobQueue<Webmention> + FromRef<St> + 'static>(
+pub fn supervised_webmentions_task<
+    St: Send + Sync + 'static,
+    S: Storage + FromRef<St> + 'static,
+    Q: JobQueue<Webmention> + FromRef<St> + 'static,
+>(
     state: &St,
-    cancellation_token: tokio_util::sync::CancellationToken
+    cancellation_token: tokio_util::sync::CancellationToken,
 ) -> SupervisedTask
-where reqwest_middleware::ClientWithMiddleware: FromRef<St>
+where
+    reqwest_middleware::ClientWithMiddleware: FromRef<St>,
 {
     let queue = Q::from_ref(state);
     let storage = S::from_ref(state);
     let http = reqwest_middleware::ClientWithMiddleware::from_ref(state);
-    supervisor::<Error<Q::Error>, _, _>(move || process_webmentions_from_queue(
-        queue.clone(), storage.clone(), http.clone()
-    ), cancellation_token)
+    supervisor::<Error<Q::Error>, _, _>(
+        move || process_webmentions_from_queue(queue.clone(), storage.clone(), http.clone()),
+        cancellation_token,
+    )
 }
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs
index 52bcdfa..a33de1a 100644
--- a/src/webmentions/queue.rs
+++ b/src/webmentions/queue.rs
@@ -6,7 +6,7 @@ use super::Webmention;
 
 static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/webmention");
 
-pub use kittybox_util::queue::{JobQueue, JobItem, Job, JobStream};
+pub use kittybox_util::queue::{Job, JobItem, JobQueue, JobStream};
 
 pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRow> {
     const DATABASE_NAME: &'static str;
@@ -17,7 +17,7 @@ pub trait PostgresJobItem: JobItem + sqlx::FromRow<'static, sqlx::postgres::PgRo
 struct PostgresJobRow<T: PostgresJobItem> {
     id: Uuid,
     #[sqlx(flatten)]
-    job: T
+    job: T,
 }
 
 #[derive(Debug)]
@@ -29,7 +29,6 @@ pub struct PostgresJob<T: PostgresJobItem> {
     runtime_handle: tokio::runtime::Handle,
 }
 
-
 impl<T: PostgresJobItem> Drop for PostgresJob<T> {
     // This is an emulation of "async drop" — the struct retains a
     // runtime handle, which it uses to block on a future that does
@@ -87,7 +86,9 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
     fn job(&self) -> &Webmention {
         &self.job
     }
-    async fn done(mut self) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
+    async fn done(
+        mut self,
+    ) -> Result<(), <PostgresJobQueue<Webmention> as JobQueue<Webmention>>::Error> {
         tracing::debug!("Deleting {} from the job queue", self.id);
         sqlx::query("DELETE FROM kittybox_webmention.incoming_webmention_queue WHERE id = $1")
             .bind(self.id)
@@ -100,13 +101,13 @@ impl Job<Webmention, PostgresJobQueue<Webmention>> for PostgresJob<Webmention> {
 
 pub struct PostgresJobQueue<T> {
     db: sqlx::PgPool,
-    _phantom: std::marker::PhantomData<T>
+    _phantom: std::marker::PhantomData<T>,
 }
 impl<T> Clone for PostgresJobQueue<T> {
     fn clone(&self) -> Self {
         Self {
             db: self.db.clone(),
-            _phantom: std::marker::PhantomData
+            _phantom: std::marker::PhantomData,
         }
     }
 }
@@ -120,15 +121,21 @@ impl PostgresJobQueue<Webmention> {
             sqlx::postgres::PgPoolOptions::new()
                 .max_connections(50)
                 .connect_with(options)
-                .await?
-        ).await
-
+                .await?,
+        )
+        .await
     }
 
     pub(crate) async fn from_pool(db: sqlx::PgPool) -> Result<Self, sqlx::Error> {
-        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox_webmention")).await?;
+        db.execute(sqlx::query(
+            "CREATE SCHEMA IF NOT EXISTS kittybox_webmention",
+        ))
+        .await?;
         MIGRATOR.run(&db).await?;
-        Ok(Self { db, _phantom: std::marker::PhantomData })
+        Ok(Self {
+            db,
+            _phantom: std::marker::PhantomData,
+        })
     }
 }
 
@@ -180,13 +187,14 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
                             Some(item) => return Ok(Some((item, ()))),
                             None => {
                                 listener.lock().await.recv().await?;
-                                continue
+                                continue;
                             }
                         }
                     }
                 }
             }
-        }).boxed();
+        })
+        .boxed();
 
         Ok(stream)
     }
@@ -196,7 +204,7 @@ impl JobQueue<Webmention> for PostgresJobQueue<Webmention> {
 mod tests {
     use std::sync::Arc;
 
-    use super::{Webmention, PostgresJobQueue, Job, JobQueue, MIGRATOR};
+    use super::{Job, JobQueue, PostgresJobQueue, Webmention, MIGRATOR};
     use futures_util::StreamExt;
 
     #[sqlx::test(migrator = "MIGRATOR")]
@@ -204,7 +212,7 @@ mod tests {
     async fn test_webmention_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
         let test_webmention = Webmention {
             source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
-            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(),
         };
 
         let queue = PostgresJobQueue::<Webmention>::from_pool(pool).await?;
@@ -236,7 +244,7 @@ mod tests {
 
         match queue.get_one().await? {
             Some(item) => panic!("Unexpected item {:?} returned from job queue!", item),
-            None => Ok(())
+            None => Ok(()),
         }
     }
 
@@ -245,7 +253,7 @@ mod tests {
     async fn test_no_hangups_in_queue(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
         let test_webmention = Webmention {
             source: "https://fireburn.ru/posts/lorem-ipsum".to_owned(),
-            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned()
+            target: "https://aaronparecki.com/posts/dolor-sit-amet".to_owned(),
         };
 
         let queue = PostgresJobQueue::<Webmention>::from_pool(pool.clone()).await?;
@@ -272,18 +280,18 @@ mod tests {
                 }
             });
         }
-        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await.unwrap_err();
+        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next())
+            .await
+            .unwrap_err();
 
-        let future = tokio::task::spawn(
-            tokio::time::timeout(
-                std::time::Duration::from_secs(10), async move {
-                    stream.next().await.unwrap().unwrap()
-                }
-            )
-        );
+        let future = tokio::task::spawn(tokio::time::timeout(
+            std::time::Duration::from_secs(10),
+            async move { stream.next().await.unwrap().unwrap() },
+        ));
         // Let the other task drop the guard it is holding
         barrier.wait().await;
-        let mut guard = future.await
+        let mut guard = future
+            .await
             .expect("Timeout on fetching item")
             .expect("Job queue error");
         assert_eq!(guard.job(), &test_webmention);