about summary refs log tree commit diff
path: root/src/database/postgres/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/postgres/mod.rs')
-rw-r--r--src/database/postgres/mod.rs21
1 files changed, 15 insertions, 6 deletions
diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs
index 27ec51c..b705eed 100644
--- a/src/database/postgres/mod.rs
+++ b/src/database/postgres/mod.rs
@@ -189,11 +189,13 @@ WHERE
         txn.commit().await.map_err(Into::into)
     }
 
-    #[tracing::instrument(skip(self))]
-    async fn update_post(&self, url: &str, update: MicropubUpdate) -> Result<()> {
+    #[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))]
+    async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
+        &self, url: &str, f: F
+    ) -> Result<(serde_json::Value, serde_json::Value)> {
         tracing::debug!("Updating post {}", url);
         let mut txn = self.db.begin().await?;
-        let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
+        let (uid, old_post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
             .bind(url)
             .fetch_optional(&mut *txn)
             .await?
@@ -202,15 +204,22 @@ WHERE
                 "The specified post wasn't found in the database."
             ))?;
 
-        update.apply(&mut post);
+        let new_post = {
+            let mut post = old_post.clone();
+            tokio::task::block_in_place(|| f(&mut post));
+
+            post
+        };
 
         sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1")
             .bind(uid)
-            .bind(post)
+            .bind(&new_post)
             .execute(&mut *txn)
             .await?;
 
-        txn.commit().await.map_err(Into::into)
+        txn.commit().await?;
+
+        Ok((old_post, new_post))
     }
 
     #[tracing::instrument(skip(self))]