From 7e8e688e2e58f9c944b941e768ab7b034a348a1f Mon Sep 17 00:00:00 2001 From: Vika Date: Thu, 1 Aug 2024 19:48:37 +0300 Subject: treewide: create a common method for state initialization Now the database objects can be uniformly created from a URI. They can also optionally do sanity checks and one-time initialization. --- src/database/file/mod.rs | 14 +++++--------- src/database/memory.rs | 21 +++++---------------- src/database/mod.rs | 10 +++++----- src/database/postgres/mod.rs | 30 +++++++++++++++--------------- src/indieauth/backend.rs | 2 ++ src/indieauth/backend/fs.rs | 12 ++++++------ src/media/storage/file.rs | 11 +++++------ src/media/storage/mod.rs | 2 ++ src/micropub/mod.rs | 8 ++++---- src/webmentions/queue.rs | 8 ++++---- 10 files changed, 53 insertions(+), 65 deletions(-) diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs index f6715e1..ba8201f 100644 --- a/src/database/file/mod.rs +++ b/src/database/file/mod.rs @@ -197,15 +197,7 @@ fn modify_post(post: &serde_json::Value, update: MicropubUpdate) -> Result Result { - // TODO check if the dir is writable - Ok(Self { root_dir }) - } + pub(super) root_dir: PathBuf, } async fn hydrate_author( @@ -255,6 +247,10 @@ async fn hydrate_author( #[async_trait] impl Storage for FileStorage { + async fn new(url: &'_ url::Url) -> Result { + // TODO: sanity check + Ok(Self { root_dir: PathBuf::from(url.path()) }) + } #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result> { // This requires an expensive scan through the entire diff --git a/src/database/memory.rs b/src/database/memory.rs index 56caeec..be37fed 100644 --- a/src/database/memory.rs +++ b/src/database/memory.rs @@ -8,7 +8,7 @@ use tokio::sync::RwLock; use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, StorageError}; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct MemoryStorage { pub mapping: Arc>>, pub channels: Arc>>>, @@ -16,6 +16,10 @@ pub struct MemoryStorage { #[async_trait] impl Storage for MemoryStorage { + async fn new(_url: &url::Url) -> Result { + Ok(Self::default()) + } + async fn categories(&self, _url: &str) -> Result> { unimplemented!() } @@ -231,18 +235,3 @@ impl Storage for MemoryStorage { } } - -impl Default for MemoryStorage { - fn default() -> Self { - Self::new() - } -} - -impl MemoryStorage { - pub fn new() -> Self { - Self { - mapping: Arc::new(RwLock::new(HashMap::new())), - channels: Arc::new(RwLock::new(HashMap::new())), - } - } -} diff --git a/src/database/mod.rs b/src/database/mod.rs index f48b4a9..c256867 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -215,6 +215,8 @@ pub type Result = std::result::Result; /// or lock the database so that write conflicts or reading half-written data should not occur. #[async_trait] pub trait Storage: std::fmt::Debug + Clone + Send + Sync { + /// Initialize Self from a URL, possibly performing initialization. + async fn new(url: &'_ url::Url) -> Result; /// Return the list of categories used in blog posts of a specified blog. async fn categories(&self, url: &str) -> Result>; @@ -759,11 +761,9 @@ mod tests { #[tracing_test::traced_test] async fn $func_name() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let backend = super::super::FileStorage::new( - tempdir.path().to_path_buf() - ) - .await - .unwrap(); + let backend = super::super::FileStorage { + root_dir: tempdir.path().to_path_buf() + }; super::$func_name(backend).await } }; diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs index 7813045..0ebaffb 100644 --- a/src/database/postgres/mod.rs +++ b/src/database/postgres/mod.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; use std::str::FromStr; use kittybox_util::{MicropubChannel, MentionType}; -use sqlx::{PgPool, Executor}; +use sqlx::{ConnectOptions, Executor, PgPool}; use crate::micropub::{MicropubUpdate, MicropubPropertyDeletion}; use super::settings::Setting; @@ -36,6 +36,17 @@ pub struct PostgresStorage { } impl PostgresStorage { + /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], + /// running appropriate migrations. + pub async fn from_pool(db: sqlx::PgPool) -> Result { + db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; + MIGRATOR.run(&db).await?; + Ok(Self { db }) + } +} + +#[async_trait::async_trait] +impl Storage for PostgresStorage { /// Construct a new [`PostgresStorage`] from an URI string and run /// migrations on the database. /// @@ -43,9 +54,9 @@ impl PostgresStorage { /// password from the file at the specified path. If, instead, /// the `PGPASS` environment variable is present, read the /// password from it. - pub async fn new(uri: &str) -> Result { - tracing::debug!("Postgres URL: {uri}"); - let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? + async fn new(url: &'_ url::Url) -> Result { + tracing::debug!("Postgres URL: {url}"); + let mut options = sqlx::postgres::PgConnectOptions::from_url(url)? .options([("search_path", "kittybox")]); if let Ok(password_file) = std::env::var("PGPASS_FILE") { let password = tokio::fs::read_to_string(password_file).await.unwrap(); @@ -62,17 +73,6 @@ impl PostgresStorage { } - /// Construct a [`PostgresStorage`] from a [`sqlx::PgPool`], - /// running appropriate migrations. - pub async fn from_pool(db: sqlx::PgPool) -> Result { - db.execute(sqlx::query("CREATE SCHEMA IF NOT EXISTS kittybox")).await?; - MIGRATOR.run(&db).await?; - Ok(Self { db }) - } -} - -#[async_trait::async_trait] -impl Storage for PostgresStorage { #[tracing::instrument(skip(self))] async fn categories(&self, url: &str) -> Result> { sqlx::query_scalar::<_, String>(" diff --git a/src/indieauth/backend.rs b/src/indieauth/backend.rs index 534bcfb..5814dc2 100644 --- a/src/indieauth/backend.rs +++ b/src/indieauth/backend.rs @@ -11,6 +11,8 @@ pub use fs::FileBackend; #[async_trait::async_trait] pub trait AuthBackend: Clone + Send + Sync + 'static { + /// Initialize self from URL, possibly performing initialization. + async fn new(url: &'_ url::Url) -> Result; // Authorization code management. /// Create a one-time OAuth2 authorization code for the passed /// authorization request, and save it for later retrieval. diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs index 80c3703..5e97ae5 100644 --- a/src/indieauth/backend/fs.rs +++ b/src/indieauth/backend/fs.rs @@ -20,12 +20,6 @@ pub struct FileBackend { } impl FileBackend { - pub fn new>(path: T) -> Self { - Self { - path: path.into() - } - } - /// Sanitize a filename, leaving only alphanumeric characters. /// /// Doesn't allocate a new string unless non-alphanumeric @@ -193,6 +187,12 @@ impl FileBackend { #[async_trait] impl AuthBackend for FileBackend { + async fn new(path: &'_ url::Url) -> Result { + Ok(Self { + path: std::path::PathBuf::from(path.path()) + }) + } + // Authorization code management. async fn create_code(&self, data: AuthorizationRequest) -> Result { self.serialize_to_file("codes", None, CODE_LENGTH, data).await diff --git a/src/media/storage/file.rs b/src/media/storage/file.rs index a910eca..7250a6b 100644 --- a/src/media/storage/file.rs +++ b/src/media/storage/file.rs @@ -31,10 +31,6 @@ impl From for MediaStoreError { } impl FileStore { - pub fn new>(base: T) -> Self { - Self { base: base.into() } - } - async fn mktemp(&self) -> Result<(PathBuf, BufWriter)> { kittybox_util::fs::mktemp(&self.base, "temp", 16) .await @@ -45,6 +41,9 @@ impl FileStore { #[async_trait] impl MediaStore for FileStore { + async fn new(url: &'_ url::Url) -> Result { + Ok(Self { base: url.path().into() }) + } #[tracing::instrument(skip(self, content))] async fn write_streaming( @@ -261,7 +260,7 @@ mod tests { #[tracing_test::traced_test] async fn test_ranges() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore::new(tempdir.path()); + let store = FileStore { base: tempdir.path().to_path_buf() }; let file: &[u8] = include_bytes!("./file.rs"); let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); @@ -372,7 +371,7 @@ mod tests { #[tracing_test::traced_test] async fn test_streaming_read_write() { let tempdir = tempfile::tempdir().expect("Failed to create tempdir"); - let store = FileStore::new(tempdir.path()); + let store = FileStore { base: tempdir.path().to_path_buf() }; let file: &[u8] = include_bytes!("./file.rs"); let stream = tokio_stream::iter(file.chunks(100).map(|i| Ok(bytes::Bytes::copy_from_slice(i)))); diff --git a/src/media/storage/mod.rs b/src/media/storage/mod.rs index 020999c..38410e6 100644 --- a/src/media/storage/mod.rs +++ b/src/media/storage/mod.rs @@ -86,6 +86,8 @@ pub type Result = std::result::Result; #[async_trait] pub trait MediaStore: 'static + Send + Sync + Clone { + // Initialize self from a URL, possibly performing asynchronous initialization. + async fn new(url: &'_ url::Url) -> Result; async fn write_streaming( &self, domain: &str, diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 74f53a0..624c239 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -760,7 +760,7 @@ mod tests { #[tokio::test] async fn test_post_reject_scope() { - let db = crate::database::MemoryStorage::new(); + let db = crate::database::MemoryStorage::default(); let post = json!({ "type": ["h-entry"], @@ -788,7 +788,7 @@ mod tests { #[tokio::test] async fn test_post_reject_different_user() { - let db = crate::database::MemoryStorage::new(); + let db = crate::database::MemoryStorage::default(); let post = json!({ "type": ["h-entry"], @@ -818,7 +818,7 @@ mod tests { #[tokio::test] async fn test_post_mf2() { - let db = crate::database::MemoryStorage::new(); + let db = crate::database::MemoryStorage::default(); let post = json!({ "type": ["h-entry"], @@ -850,7 +850,7 @@ mod tests { #[tokio::test] async fn test_query_foreign_url() { let mut res = super::query( - axum::Extension(crate::database::MemoryStorage::new()), + axum::Extension(crate::database::MemoryStorage::default()), Some(axum::extract::Query(super::MicropubQuery::source( "https://aaronparecki.com/feeds/main", ))), diff --git a/src/webmentions/queue.rs b/src/webmentions/queue.rs index b811e71..af1387f 100644 --- a/src/webmentions/queue.rs +++ b/src/webmentions/queue.rs @@ -1,7 +1,7 @@ -use std::{pin::Pin, str::FromStr}; +use std::pin::Pin; use futures_util::{Stream, StreamExt}; -use sqlx::{postgres::PgListener, Executor}; +use sqlx::{postgres::PgListener, ConnectOptions, Executor}; use uuid::Uuid; use super::Webmention; @@ -115,8 +115,8 @@ impl Clone for PostgresJobQueue { } impl PostgresJobQueue { - pub async fn new(uri: &str) -> Result { - let mut options = sqlx::postgres::PgConnectOptions::from_str(uri)? + pub async fn new(uri: &url::Url) -> Result { + let mut options = sqlx::postgres::PgConnectOptions::from_url(uri)? .options([("search_path", "kittybox_webmention")]); if let Ok(password_file) = std::env::var("PGPASS_FILE") { let password = tokio::fs::read_to_string(password_file).await.unwrap(); -- cgit 1.4.1