about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2024-08-01 19:48:37 +0300
committerVika <vika@fireburn.ru>2024-08-01 20:40:00 +0300
commit7e8e688e2e58f9c944b941e768ab7b034a348a1f (patch)
tree1068469c6b9b97bac407038276fd8971b2101e48 /src
parent57a9c3c7e520714928904fc7e2ff3d62ac2b2467 (diff)
downloadkittybox-7e8e688e2e58f9c944b941e768ab7b034a348a1f.tar.zst
treewide: create a common method for state initialization
Now the database objects can be uniformly created from a URI. They can
also optionally do sanity checks and one-time initialization.
Diffstat (limited to 'src')
-rw-r--r--src/database/file/mod.rs14
-rw-r--r--src/database/memory.rs21
-rw-r--r--src/database/mod.rs10
-rw-r--r--src/database/postgres/mod.rs30
-rw-r--r--src/indieauth/backend.rs2
-rw-r--r--src/indieauth/backend/fs.rs12
-rw-r--r--src/media/storage/file.rs11
-rw-r--r--src/media/storage/mod.rs2
-rw-r--r--src/micropub/mod.rs8
-rw-r--r--src/webmentions/queue.rs8
10 files changed, 53 insertions, 65 deletions
diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs
index f6715e1..ba8201f 100644
--- a/src/database/file/mod.rs
+++ b/src/database/file/mod.rs
@@ -197,15 +197,7 @@ fn modify_post(post: &serde_json::Value, update: MicropubUpdate) -> Result<serde
 /// A backend using a folder with JSON files as a backing store.
 /// Uses symbolic links to represent a many-to-one mapping of URLs to a post.
 pub struct FileStorage {
-    root_dir: PathBuf,
-}
-
-impl FileStorage {
-    /// Create a new storage wrapping a folder specified by root_dir.
-    pub async fn new(root_dir: PathBuf) -> Result<Self> {
-        // TODO check if the dir is writable
-        Ok(Self { root_dir })
-    }
+    pub(super) root_dir: PathBuf,
 }
 
 async fn hydrate_author<S: Storage>(
@@ -255,6 +247,10 @@ async fn hydrate_author<S: Storage>(
 
 #[async_trait]
 impl Storage for FileStorage {
+    async fn new(url: &'_ url::Url) -> Result<Self> {
+        // TODO: sanity check
+        Ok(Self { root_dir: PathBuf::from(url.path()) })
+    }
     #[tracing::instrument(skip(self))]
     async fn categories(&self, url: &str) -> Result<Vec<String>> {
         // This requires an expensive scan through the entire
diff --git a/src/database/memory.rs b/src/database/memory.rs
index 56caeec..be37fed 100644
--- a/src/database/memory.rs
+++ b/src/database/memory.rs
@@ -8,7 +8,7 @@ use tokio::sync::RwLock;
 
 use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError};
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct MemoryStorage {
     pub mapping: Arc<RwLock<HashMap<String, serde_json::Value>>>,
     pub channels: Arc<RwLock<HashMap<url::Url, Vec<String>>>>,
@@ -16,6 +16,10 @@ pub struct MemoryStorage {
 
 #[async_trait]
 impl Storage for MemoryStorage {
+    async fn new(_url: &url::Url) -> Result<Self> {
+        Ok(Self::default())
+    }
+
     async fn categories(&self, _url: &str) -> Result<Vec<String>> {
         unimplemented!()
     }
@@ -231,18 +235,3 @@ impl Storage for MemoryStorage {
     }
 
 }
-
-impl Default for MemoryStorage {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl MemoryStorage {
-    pub fn new() -> Self {
-        Self {
-            mapping: Arc::new(RwLock::new(HashMap::new())),
-            channels: Arc::new(RwLock::new(HashMap::new())),
-        }
-    }
-}
diff --git a/src/database/mod.rs b/src/database/mod.rs
index f48b4a9..c256867 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -215,6 +215,8 @@ pub type Result<T> = std::result::Result<T, StorageError>;
 /// or lock the database so that write conflicts or reading half-written data should not occur.
 #[async_trait]
 pub trait Storage: std::fmt::Debug + Clone + Send + Sync {
+    /// Initialize Self from a URL, possibly performing initialization.
+    async fn new(url: &'_ url::Url) -> Result<Self>;
     /// Return the list of categories used in blog posts of a specified blog.
     async fn categories(&self, url: &str) -> Result<Vec<String>>;
 
@@ -759,11 +761,9 @@ mod tests {
             #[tracing_test::traced_test]
             async fn $func_name() {
                 let tempdir = tempfile::tempdir().expect("Failed to create tempdir");
-                let backend = super::super::FileStorage::new(
-                    tempdir.path().to_path_buf()
-                )
-                    .await
-                    .unwrap();
+                let backend = super::super::FileStorage {
+                    root_dir: tempdir.path().to_path_buf()
+                };
                 super::$func_name(backend).await
             }
         };
diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs
index 7813045..0ebaffb 100644
--- a/src/database/postgres/mod.rs
+++ b/src/database/postgres/mod.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;
 use std::str::FromStr;
 
 use kittybox_util::{MicropubChannel, MentionType};
-use sqlx::{PgPool, Executor};
+use sqlx::{ConnectOptions, Executor, PgPool};
 use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion};
 
 use super::settings::Setting;
@@ -36,6 +36,17 @@ pub struct PostgresStorage {
 }
 
 impl PostgresStorage {
+    /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`],
+    /// running appropriate migrations.
+    pub async fn from_pool(db: sqlx::PgPool) -> Result<Self> {
+        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?;
+        MIGRATOR.run(&db).await?;
+        Ok(Self { db })
+    }
+}
+
+#[async_trait::async_trait]
+impl Storage for PostgresStorage {
     /// Construct a new [`PostgresStorage`] from an URI string and run
     /// migrations on the database.
     ///
@@ -43,9 +54,9 @@ impl PostgresStorage {
     /// password from the file at the specified path. If, instead,
     /// the `PGPASS` environment variable is present, read the
     /// password from it.
-    pub async fn new(uri: &str) -> Result<Self> {
-        tracing::debug!("Postgres URL: {uri}");
-        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
+    async fn new(url: &'_ url::Url) -> Result<Self> {
+        tracing::debug!("Postgres URL: {url}");
+        let mut options = sqlx::postgres::PgConnectOptions::from_url(url)?
             .options([("search_path", "kittybox")]);
         if let Ok(password_file) = std::env::var("PGPASS_FILE") {
             let password = tokio::fs::read_to_string(password_file).await.unwrap();
@@ -62,17 +73,6 @@ impl PostgresStorage {
 
     }
 
-    /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`],
-    /// running appropriate migrations.
-    pub async fn from_pool(db: sqlx::PgPool) -> Result<Self> {
-        db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?;
-        MIGRATOR.run(&db).await?;
-        Ok(Self { db })
-    }
-}
-
-#[async_trait::async_trait]
-impl Storage for PostgresStorage {
     #[tracing::instrument(skip(self))]
     async fn categories(&self, url: &str) -> Result<Vec<String>> {
         sqlx::query_scalar::<_, String>("
diff --git a/src/indieauth/backend.rs b/src/indieauth/backend.rs
index 534bcfb..5814dc2 100644
--- a/src/indieauth/backend.rs
+++ b/src/indieauth/backend.rs
@@ -11,6 +11,8 @@ pub use fs::FileBackend;
 
 #[async_trait::async_trait]
 pub trait AuthBackend: Clone + Send + Sync + 'static {
+    /// Initialize self from URL, possibly performing initialization.
+    async fn new(url: &'_ url::Url) -> Result<Self>;
     // Authorization code management.
     /// Create a one-time OAuth2 authorization code for the passed
     /// authorization request, and save it for later retrieval.
diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs
index 80c3703..5e97ae5 100644
--- a/src/indieauth/backend/fs.rs
+++ b/src/indieauth/backend/fs.rs
@@ -20,12 +20,6 @@ pub struct FileBackend {
 }
 
 impl FileBackend {
-    pub fn new<T: Into<PathBuf>>(path: T) -> Self {
-        Self {
-            path: path.into()
-        }
-    }
-
     /// Sanitize a filename, leaving only alphanumeric characters.
     ///
     /// Doesn't allocate a new string unless non-alphanumeric
@@ -193,6 +187,12 @@ impl FileBackend {
 
 #[async_trait]
 impl AuthBackend for FileBackend {
+    async fn new(path: &'_ url::Url) -> Result<Self> {
+        Ok(Self {
+            path: std::path::PathBuf::from(path.path())
+        })
+    }
+
     // Authorization code management.
     async fn create_code(&self, data: AuthorizationRequest) -> Result<String> {
         self.serialize_to_file("codes", None, CODE_LENGTH, data).await
diff --git a/src/media/storage/file.rs b/src/media/storage/file.rs
index a910eca..7250a6b 100644
--- a/src/media/storage/file.rs
+++ b/src/media/storage/file.rs
@@ -31,10 +31,6 @@ impl From<tokio::io::Error> for MediaStoreError {
 }
 
 impl FileStore {
-    pub fn new<T: Into<PathBuf>>(base: T) -> Self {
-        Self { base: base.into() }
-    }
-
     async fn mktemp(&self) -> Result<(PathBuf, BufWriter<tokio::fs::File>)> {
         kittybox_util::fs::mktemp(&self.base, "temp", 16)
             .await
@@ -45,6 +41,9 @@ impl FileStore {
 
 #[async_trait]
 impl MediaStore for FileStore {
+    async fn new(url: &'_ url::Url) -> Result<Self> {
+        Ok(Self { base: url.path().into() })
+    }
 
     #[tracing::instrument(skip(self, content))]
     async fn write_streaming<T>(
@@ -261,7 +260,7 @@ mod tests {
     #[tracing_test::traced_test]
     async fn test_ranges() {
         let tempdir = tempfile::tempdir().expect("Failed to create tempdir");
-        let store = FileStore::new(tempdir.path());
+        let store = FileStore { base: tempdir.path().to_path_buf() };
 
         let file: &[u8] = include_bytes!("./file.rs");
         let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i))));
@@ -372,7 +371,7 @@ mod tests {
     #[tracing_test::traced_test]
     async fn test_streaming_read_write() {
         let tempdir = tempfile::tempdir().expect("Failed to create tempdir");
-        let store = FileStore::new(tempdir.path());
+        let store = FileStore { base: tempdir.path().to_path_buf() };
 
         let file: &[u8] = include_bytes!("./file.rs");
         let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i))));
diff --git a/src/media/storage/mod.rs b/src/media/storage/mod.rs
index 020999c..38410e6 100644
--- a/src/media/storage/mod.rs
+++ b/src/media/storage/mod.rs
@@ -86,6 +86,8 @@ pub type Result<T> = std::result::Result<T, MediaStoreError>;
 
 #[async_trait]
 pub trait MediaStore: 'static + Send + Sync + Clone {
+    // Initialize self from a URL, possibly performing asynchronous initialization.
+    async fn new(url: &'_ url::Url) -> Result<Self>;
     async fn write_streaming<T>(
         &self,
         domain: &str,
diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs
index 74f53a0..624c239 100644
--- a/src/micropub/mod.rs
+++ b/src/micropub/mod.rs
@@ -760,7 +760,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_post_reject_scope() {
-        let db = crate::database::MemoryStorage::new();
+        let db = crate::database::MemoryStorage::default();
 
         let post = json!({
             "type": ["h-entry"],
@@ -788,7 +788,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_post_reject_different_user() {
-        let db = crate::database::MemoryStorage::new();
+        let db = crate::database::MemoryStorage::default();
 
         let post = json!({
             "type": ["h-entry"],
@@ -818,7 +818,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_post_mf2() {
-        let db = crate::database::MemoryStorage::new();
+        let db = crate::database::MemoryStorage::default();
 
         let post = json!({
             "type": ["h-entry"],
@@ -850,7 +850,7 @@ mod tests {
     #[tokio::test]
     async fn test_query_foreign_url() {
         let mut res = super::query(
-            axum::Extension(crate::database::MemoryStorage::new()),
+            axum::Extension(crate::database::MemoryStorage::default()),
             Some(axum::extract::Query(super::MicropubQuery::source(
                 "https://aaronparecki.com/feeds/main",
             ))),
diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs
index b811e71..af1387f 100644
--- a/src/webmentions/queue.rs
+++ b/src/webmentions/queue.rs
@@ -1,7 +1,7 @@
-use std::{pin::Pin, str::FromStr};
+use std::pin::Pin;
 
 use futures_util::{Stream, StreamExt};
-use sqlx::{postgres::PgListener, Executor};
+use sqlx::{postgres::PgListener, ConnectOptions, Executor};
 use uuid::Uuid;
 
 use super::Webmention;
@@ -115,8 +115,8 @@ impl<T> Clone for PostgresJobQueue<T> {
 }
 
 impl PostgresJobQueue<Webmention> {
-    pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
-        let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)?
+    pub async fn new(uri: &url::Url) -> Result<Self, sqlx::Error> {
+        let mut options = sqlx::postgres::PgConnectOptions::from_url(uri)?
             .options([("search_path", "kittybox_webmention")]);
         if let Ok(password_file) = std::env::var("PGPASS_FILE") {
             let password = tokio::fs::read_to_string(password_file).await.unwrap();