about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/database/file/mod.rs7
-rw-r--r--src/database/memory.rs7
-rw-r--r--src/database/mod.rs26
-rw-r--r--src/database/postgres/mod.rs21
4 files changed, 54 insertions, 7 deletions
diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs
index cf7380f..6343f1f 100644
--- a/src/database/file/mod.rs
+++ b/src/database/file/mod.rs
@@ -484,6 +484,13 @@ impl Storage for FileStorage {
         Ok(())
     }
 
+    #[tracing::instrument(skip(self, f), fields(f = std::any::type_name::<F>()))]
+    async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
+        &self, url: &str, f: F
+    ) -> Result<(serde_json::Value, serde_json::Value)> {
+        todo!("update_with is not yet implemented due to special requirements of the file backend")
+    }
+
     #[tracing::instrument(skip(self))]
     async fn get_channels(&self, user: &url::Url) -> Result<Vec<super::MicropubChannel>> {
         let mut path = relative_path::RelativePathBuf::new();
diff --git a/src/database/memory.rs b/src/database/memory.rs
index a4ffc7b..f799f2c 100644
--- a/src/database/memory.rs
+++ b/src/database/memory.rs
@@ -232,4 +232,11 @@ impl Storage for MemoryStorage {
         todo!()
     }
 
+    #[allow(unused_variables)]
+    async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
+        &self, url: &str, f: F
+    ) -> Result<(serde_json::Value, serde_json::Value)> {
+        todo!()
+    }
+
 }
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 058fc0c..0993715 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -251,7 +251,31 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync {
     /// each other's changes or simply corrupting something. Rejecting
     /// is allowed in case of concurrent updates if waiting for a lock
     /// cannot be done.
-    fn update_post(&self, url: &str, update: MicropubUpdate) -> impl Future<Output = Result<()>> + Send;
+    ///
+    /// Default implementation calls [`Storage::update_with`] and uses
+    /// [`update.apply`][MicropubUpdate::apply] to update the post.
+    fn update_post(&self, url: &str, update: MicropubUpdate) -> impl Future<Output = Result<()>> + Send {
+        let fut = self.update_with(url, |post| {
+            update.apply(post);
+        });
+
+        // The old interface didn't return anything, the new interface
+        // returns the old and new post. Adapt accordingly.
+        futures::TryFutureExt::map_ok(fut, |(_old, _new)| ())
+    }
+
+    /// Modify a post using an arbitrary closure.
+    ///
+    /// Note to implementors: the update operation MUST be atomic and
+    /// SHOULD lock the database to prevent two clients overwriting
+    /// each other's changes or simply corrupting something. Rejecting
+    /// is allowed in case of concurrent updates if waiting for a lock
+    /// cannot be done.
+    ///
+    /// Returns old post and the new post after editing.
+    fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
+        &self, url: &str, f: F
+    ) -> impl Future<Output = Result<(serde_json::Value, serde_json::Value)>> + Send;
 
     /// Get a list of channels available for the user represented by
     /// the `user` domain to write to.
diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs
index 27ec51c..b705eed 100644
--- a/src/database/postgres/mod.rs
+++ b/src/database/postgres/mod.rs
@@ -189,11 +189,13 @@ WHERE
         txn.commit().await.map_err(Into::into)
     }
 
-    #[tracing::instrument(skip(self))]
-    async fn update_post(&self, url: &str, update: MicropubUpdate) -> Result<()> {
+    #[tracing::instrument(skip(self), fields(f = std::any::type_name::<F>()))]
+    async fn update_with<F: FnOnce(&mut serde_json::Value) + Send>(
+        &self, url: &str, f: F
+    ) -> Result<(serde_json::Value, serde_json::Value)> {
         tracing::debug!("Updating post {}", url);
         let mut txn = self.db.begin().await?;
-        let (uid, mut post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
+        let (uid, old_post) = sqlx::query_as::<_, (String, serde_json::Value)>("SELECT uid, mf2 FROM kittybox.mf2_json WHERE uid = $1 OR mf2['properties']['url'] ? $1 FOR UPDATE")
             .bind(url)
             .fetch_optional(&mut *txn)
             .await?
@@ -202,15 +204,22 @@ WHERE
                 "The specified post wasn't found in the database."
             ))?;
 
-        update.apply(&mut post);
+        let new_post = {
+            let mut post = old_post.clone();
+            tokio::task::block_in_place(|| f(&mut post));
+
+            post
+        };
 
         sqlx::query("UPDATE kittybox.mf2_json SET mf2 = $2 WHERE uid = $1")
             .bind(uid)
-            .bind(post)
+            .bind(&new_post)
             .execute(&mut *txn)
             .await?;
 
-        txn.commit().await.map_err(Into::into)
+        txn.commit().await?;
+
+        Ok((old_post, new_post))
     }
 
     #[tracing::instrument(skip(self))]