From 2e9c292bb989ffff2c99aa2a6062962c913b3586 Mon Sep 17 00:00:00 2001 From: Vika Date: Tue, 9 Jul 2024 22:43:21 +0300 Subject: database: use Url to represent user authorities This makes the interface more consistent and resistant to misuse. --- src/database/file/mod.rs | 32 ++++++++++++++------------- src/database/memory.rs | 23 ++++++++------------ src/database/mod.rs | 34 ++++++++++++++--------------- src/database/postgres/mod.rs | 27 ++++++++++++----------- src/frontend/mod.rs | 52 ++++++++++++++++++++++++-------------------- src/frontend/onboarding.rs | 8 +++---- src/micropub/mod.rs | 6 ++--- src/micropub/util.rs | 2 +- 8 files changed, 93 insertions(+), 91 deletions(-) diff --git a/src/database/file/mod.rs b/src/database/file/mod.rs index 46660ab..f6715e1 100644 --- a/src/database/file/mod.rs +++ b/src/database/file/mod.rs @@ -210,7 +210,8 @@ impl FileStorage { async fn hydrate_author( feed: &mut serde_json::Value, - user: &'_ Option, + // Unused? + user: Option<&url::Url>, storage: &S, ) { let url = feed["properties"]["uid"][0] @@ -226,6 +227,7 @@ async fn hydrate_author( let author_list: Vec = stream::iter(author.iter()) .then(|i| async move { if let Some(i) = i.as_str() { + // BUG: Use `user` to sanitize? match storage.get_post(i).await { Ok(post) => match post { Some(post) => post, @@ -319,7 +321,7 @@ impl Storage for FileStorage { } #[tracing::instrument(skip(self))] - async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()> { let key = post["properties"]["uid"][0] .as_str() .expect("Tried to save a post without UID"); @@ -358,7 +360,7 @@ impl Storage for FileStorage { .unwrap_or_default() ) }; - if url != key && url_domain == user { + if url != key && url_domain == user.authority() { let link = url_to_path(&self.root_dir, url); debug!("Creating a symlink at {:?}", link); let orig = path.clone(); @@ -386,7 +388,7 @@ impl Storage for FileStorage { // Add the h-feed to the channel list let path = { let mut path = relative_path::RelativePathBuf::new(); - path.push(user); + path.push(user.authority()); path.push("channels"); path.to_path(&self.root_dir) @@ -487,9 +489,9 @@ impl Storage for FileStorage { } #[tracing::instrument(skip(self))] - async fn get_channels(&self, user: &'_ str) -> Result> { + async fn get_channels(&self, user: &url::Url) -> Result> { let mut path = relative_path::RelativePathBuf::new(); - path.push(user); + path.push(user.authority()); path.push("channels"); let path = path.to_path(&self.root_dir); @@ -521,13 +523,13 @@ impl Storage for FileStorage { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&'_ str> + user: Option<&url::Url> ) -> Result)>> { Ok(self.read_feed_with_limit( url, - &cursor.map(|v| v.to_owned()), + cursor, limit, - &user.map(|v| v.to_owned()) + user ).await? .map(|feed| { tracing::debug!("Feed: {:#}", serde_json::Value::Array( @@ -555,9 +557,9 @@ impl Storage for FileStorage { async fn read_feed_with_limit( &self, url: &'_ str, - after: &'_ Option, + after: Option<&str>, limit: usize, - user: &'_ Option, + user: Option<&url::Url>, ) -> Result> { if let Some(mut feed) = self.get_post(url).await? { if feed["children"].is_array() { @@ -627,10 +629,10 @@ impl Storage for FileStorage { } #[tracing::instrument(skip(self))] - async fn get_setting, 'a>(&self, user: &'_ str) -> Result { + async fn get_setting, 'a>(&self, user: &url::Url) -> Result { debug!("User for getting settings: {}", user); let mut path = relative_path::RelativePathBuf::new(); - path.push(user); + path.push(user.authority()); path.push("settings"); let path = path.to_path(&self.root_dir); @@ -648,9 +650,9 @@ impl Storage for FileStorage { } #[tracing::instrument(skip(self))] - async fn set_setting + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { + async fn set_setting + 'a, 'a>(&self, user: &'a url::Url, value: S::Data) -> Result<()> { let mut path = relative_path::RelativePathBuf::new(); - path.push(user); + path.push(user.authority()); path.push("settings"); let path = path.to_path(&self.root_dir); diff --git a/src/database/memory.rs b/src/database/memory.rs index 564f451..56caeec 100644 --- a/src/database/memory.rs +++ b/src/database/memory.rs @@ -11,7 +11,7 @@ use crate::database::{ErrorKind, MicropubChannel, Result, settings, Storage, Sto #[derive(Clone, Debug)] pub struct MemoryStorage { pub mapping: Arc>>, - pub channels: Arc>>>, + pub channels: Arc>>>, } #[async_trait] @@ -45,7 +45,7 @@ impl Storage for MemoryStorage { } } - async fn put_post(&self, post: &'_ serde_json::Value, _user: &'_ str) -> Result<()> { + async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()> { let mapping = &mut self.mapping.write().await; let key: &str = match post["properties"]["uid"][0].as_str() { Some(uid) => uid, @@ -80,12 +80,7 @@ impl Storage for MemoryStorage { self.channels .write() .await - .entry( - post["properties"]["author"][0] - .as_str() - .unwrap() - .to_string(), - ) + .entry(user.clone()) .or_insert_with(Vec::new) .push(key.to_string()) } @@ -165,7 +160,7 @@ impl Storage for MemoryStorage { Ok(()) } - async fn get_channels(&self, user: &'_ str) -> Result> { + async fn get_channels(&self, user: &url::Url) -> Result> { match self.channels.read().await.get(user) { Some(channels) => Ok(futures_util::future::join_all( channels @@ -197,9 +192,9 @@ impl Storage for MemoryStorage { async fn read_feed_with_limit( &self, url: &'_ str, - after: &'_ Option, + after: Option<&str>, limit: usize, - user: &'_ Option, + user: Option<&url::Url>, ) -> Result> { todo!() } @@ -210,7 +205,7 @@ impl Storage for MemoryStorage { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&'_ str> + user: Option<&url::Url> ) -> Result)>> { todo!() } @@ -221,12 +216,12 @@ impl Storage for MemoryStorage { } #[allow(unused_variables)] - async fn get_setting, 'a>(&'_ self, user: &'_ str) -> Result { + async fn get_setting, 'a>(&'_ self, user: &url::Url) -> Result { todo!() } #[allow(unused_variables)] - async fn set_setting + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { + async fn set_setting + 'a, 'a>(&self, user: &'a url::Url, value: S::Data) -> Result<()> { todo!() } diff --git a/src/database/mod.rs b/src/database/mod.rs index a6a3b46..f48b4a9 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -227,7 +227,7 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// Save a post to the database as an MF2-JSON structure. /// /// Note that the `post` object MUST have `post["properties"]["uid"][0]` defined. - async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()>; + async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()>; /// Add post to feed. Some database implementations might have optimized ways to do this. #[tracing::instrument(skip(self))] @@ -258,7 +258,7 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { /// Get a list of channels available for the user represented by /// the `user` domain to write to. - async fn get_channels(&self, user: &'_ str) -> Result>; + async fn get_channels(&self, user: &url::Url) -> Result>; /// Fetch a feed at `url` and return an h-feed object containing /// `limit` posts after a post by url `after`, filtering the content @@ -279,9 +279,9 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { async fn read_feed_with_limit( &self, url: &'_ str, - after: &'_ Option, + after: Option<&str>, limit: usize, - user: &'_ Option, + user: Option<&url::Url>, ) -> Result>; /// Fetch a feed at `url` and return an h-feed object containing @@ -307,17 +307,17 @@ pub trait Storage: std::fmt::Debug + Clone + Send + Sync { url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&'_ str> + user: Option<&url::Url> ) -> Result)>>; /// Deletes a post from the database irreversibly. Must be idempotent. async fn delete_post(&self, url: &'_ str) -> Result<()>; /// Gets a setting from the setting store and passes the result. - async fn get_setting, 'a>(&'_ self, user: &'_ str) -> Result; + async fn get_setting, 'a>(&'_ self, user: &url::Url) -> Result; /// Commits a setting to the setting store. - async fn set_setting + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()>; + async fn set_setting + 'a, 'a>(&self, user: &'a url::Url, value: S::Data) -> Result<()>; /// Add (or update) a webmention on a certian post. /// @@ -359,7 +359,7 @@ mod tests { // Reading and writing backend - .put_post(&post, "fireburn.ru") + .put_post(&post, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); if let Some(returned_post) = backend.get_post(&key).await.unwrap() { @@ -423,7 +423,7 @@ mod tests { // Reading and writing backend - .put_post(&post, "fireburn.ru") + .put_post(&post, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); @@ -482,10 +482,10 @@ mod tests { "children": [] }); backend - .put_post(&feed, "fireburn.ru") + .put_post(&feed, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); - let chans = backend.get_channels("fireburn.ru").await.unwrap(); + let chans = backend.get_channels(&"https://fireburn.ru/".parse().unwrap()).await.unwrap(); assert_eq!(chans.len(), 1); assert_eq!( chans[0], @@ -499,14 +499,14 @@ mod tests { async fn test_settings(backend: Backend) { backend .set_setting::( - "https://fireburn.ru/", + &"https://fireburn.ru/".parse().unwrap(), "Vika's Hideout".to_owned() ) .await .unwrap(); assert_eq!( backend - .get_setting::("https://fireburn.ru/") + .get_setting::(&"https://fireburn.ru/".parse().unwrap()) .await .unwrap() .as_ref(), @@ -594,13 +594,13 @@ mod tests { let key = feed["properties"]["uid"][0].as_str().unwrap(); backend - .put_post(&feed, "fireburn.ru") + .put_post(&feed, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); for (i, post) in posts.iter().rev().enumerate() { backend - .put_post(post, "fireburn.ru") + .put_post(post, &"https://fireburn.ru/".parse().unwrap()) .await .unwrap(); backend.add_to_feed(key, post["properties"]["uid"][0].as_str().unwrap()).await.unwrap(); @@ -699,7 +699,7 @@ mod tests { async fn test_webmention_addition(db: Backend) { let post = gen_random_post("fireburn.ru"); - db.put_post(&post, "fireburn.ru").await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); const TYPE: MentionType = MentionType::Reply; let target = post["properties"]["uid"][0].as_str().unwrap(); @@ -732,7 +732,7 @@ mod tests { post }; - db.put_post(&post, "fireburn.ru").await.unwrap(); + db.put_post(&post, &"https://fireburn.ru/".parse().unwrap()).await.unwrap(); for i in post["properties"]["url"].as_array().unwrap() { let (read_post, _) = db.read_feed_with_cursor(i.as_str().unwrap(), None, 20, None).await.unwrap().unwrap(); diff --git a/src/database/postgres/mod.rs b/src/database/postgres/mod.rs index 71c4d58..7813045 100644 --- a/src/database/postgres/mod.rs +++ b/src/database/postgres/mod.rs @@ -1,4 +1,3 @@ -#![allow(unused_variables)] use std::borrow::Cow; use std::str::FromStr; @@ -111,11 +110,11 @@ WHERE } #[tracing::instrument(skip(self))] - async fn put_post(&self, post: &'_ serde_json::Value, user: &'_ str) -> Result<()> { + async fn put_post(&self, post: &'_ serde_json::Value, user: &url::Url) -> Result<()> { tracing::debug!("New post: {}", post); sqlx::query("INSERT INTO kittybox.mf2_json (uid, mf2, owner) VALUES ($1 #>> '{properties,uid,0}', $1, $2)") .bind(post) - .bind(user) + .bind(user.authority()) .execute(&self.db) .await .map(|_| ()) @@ -247,14 +246,14 @@ WHERE } #[tracing::instrument(skip(self))] - async fn get_channels(&self, user: &'_ str) -> Result> { + async fn get_channels(&self, user: &url::Url) -> Result> { /*sqlx::query_as::<_, MicropubChannel>("SELECT name, uid FROM kittybox.channels WHERE owner = $1") .bind(user) .fetch_all(&self.db) .await .map_err(|err| err.into())*/ sqlx::query_as::<_, MicropubChannel>(r#"SELECT mf2 #>> '{properties,name,0}' as name, uid FROM kittybox.mf2_json WHERE '["h-feed"]'::jsonb @> mf2['type'] AND owner = $1"#) - .bind(user) + .bind(user.authority()) .fetch_all(&self.db) .await .map_err(|err| err.into()) @@ -264,10 +263,12 @@ WHERE async fn read_feed_with_limit( &self, url: &'_ str, - after: &'_ Option, + after: Option<&str>, limit: usize, - user: &'_ Option, + // BUG: this doesn't seem to be used?! + user: Option<&url::Url>, ) -> Result> { + unimplemented!("read_feed_with_limit is insecure and deprecated"); let mut feed = match sqlx::query_as::<_, (serde_json::Value,)>(" SELECT jsonb_set( mf2, @@ -331,7 +332,7 @@ ORDER BY mf2 #>> '{properties,published,0}' DESC url: &'_ str, cursor: Option<&'_ str>, limit: usize, - user: Option<&'_ str> + user: Option<&url::Url> ) -> Result)>> { let mut txn = self.db.begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") @@ -384,7 +385,7 @@ LIMIT $2" ) .bind(url) .bind(limit as i64) - .bind(user) + .bind(user.map(url::Url::to_string)) .bind(cursor) .fetch_all(&mut *txn) .await @@ -405,9 +406,9 @@ LIMIT $2" } #[tracing::instrument(skip(self))] - async fn get_setting, 'a>(&'_ self, user: &'_ str) -> Result { + async fn get_setting, 'a>(&'_ self, user: &url::Url) -> Result { match sqlx::query_as::<_, (serde_json::Value,)>("SELECT kittybox.get_setting($1, $2)") - .bind(user) + .bind(user.authority()) .bind(S::ID) .fetch_one(&self.db) .await @@ -418,9 +419,9 @@ LIMIT $2" } #[tracing::instrument(skip(self))] - async fn set_setting + 'a, 'a>(&self, user: &'a str, value: S::Data) -> Result<()> { + async fn set_setting + 'a, 'a>(&self, user: &'a url::Url, value: S::Data) -> Result<()> { sqlx::query("SELECT kittybox.set_setting($1, $2, $3)") - .bind(user) + .bind(user.authority()) .bind(S::ID) .bind(serde_json::to_value(S::new(value)).unwrap()) .execute(&self.db) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 7a43532..0292171 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -87,7 +87,7 @@ impl std::fmt::Display for FrontendError { #[tracing::instrument(skip(post), fields(post = %post))] pub fn filter_post( mut post: serde_json::Value, - user: Option<&str>, + user: Option<&url::Url>, ) -> Option { if post["properties"]["deleted"][0].is_string() { tracing::debug!("Deleted post; returning tombstone instead"); @@ -108,7 +108,9 @@ pub fn filter_post( serde_json::Value::String(ref author) => author.as_str(), mf2 => mf2["properties"]["uid"][0].as_str().unwrap() } - }).collect::>(); + }) + .map(|i| i.parse().unwrap()) + .collect::>(); let visibility = post["properties"]["visibility"][0] .as_str() .unwrap_or("public"); @@ -118,12 +120,12 @@ pub fn filter_post( .as_array() .unwrap_or(&empty_vec) .iter() - .map(|i| i.as_str().unwrap())); + .map(|i| i.as_str().unwrap().parse().unwrap())); audience }; tracing::debug!("post audience = {:?}", audience); - if (visibility == "private" && !audience.iter().any(|i| Some(*i) == user)) + if (visibility == "private" && !audience.iter().any(|i| Some(i) == user)) || (visibility == "protected" && user.is_none()) { return None; @@ -137,8 +139,8 @@ pub fn filter_post( .as_array() .unwrap_or(&empty_vec) .iter() - .map(|i| i.as_str().unwrap()); - if (location_visibility == "private" && !author.any(|i| Some(i) == user)) + .map(|i| i.as_str().unwrap().parse().unwrap()); + if (location_visibility == "private" && !author.any(|i| Some(&i) == user)) || (location_visibility == "protected" && user.is_none()) { post["properties"] @@ -184,10 +186,10 @@ async fn get_post_from_database( db: &S, url: &str, after: Option, - user: &Option, + user: Option<&url::Url>, ) -> std::result::Result<(serde_json::Value, Option), FrontendError> { match db - .read_feed_with_cursor(url, after.as_deref(), POSTS_PER_PAGE, user.as_deref()) + .read_feed_with_cursor(url, after.as_deref(), POSTS_PER_PAGE, user) .await { Ok(result) => match result { @@ -240,12 +242,13 @@ pub async fn homepage( Extension(db): Extension, ) -> impl IntoResponse { let user = None; // TODO authentication - let path = format!("https://{}/", host); + // This is stupid, but there is no other way. + let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap(); let feed_path = format!("https://{}/feeds/main", host); match tokio::try_join!( - get_post_from_database(&db, &path, None, &user), - get_post_from_database(&db, &feed_path, query.after, &user) + get_post_from_database(&db, &hcard_url.as_str(), None, user.as_ref()), + get_post_from_database(&db, &feed_path, query.after, user.as_ref()) ) { Ok(((hcard, _), (hfeed, cursor))) => { // Here, we know those operations can't really fail @@ -254,13 +257,13 @@ pub async fn homepage( // // btw is it more efficient to fetch these in parallel? let (blogname, webring, channels) = tokio::join!( - db.get_setting::(&host) + db.get_setting::(&hcard_url) .map(Result::unwrap_or_default), - db.get_setting::(&host) + db.get_setting::(&hcard_url) .map(Result::unwrap_or_default), - db.get_channels(&host).map(|i| i.unwrap_or_default()) + db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); // Render the homepage ( @@ -273,7 +276,7 @@ pub async fn homepage( title: blogname.as_ref(), blog_name: blogname.as_ref(), feeds: channels, - user, + user: user.as_ref().map(url::Url::to_string), content: MainPage { feed: &hfeed, card: &hcard, @@ -298,10 +301,10 @@ pub async fn homepage( error!("Error while fetching h-card and/or h-feed: {}", err); // Return the error let (blogname, channels) = tokio::join!( - db.get_setting::(&host) + db.get_setting::(&hcard_url) .map(Result::unwrap_or_default), - db.get_channels(&host).map(|i| i.unwrap_or_default()) + db.get_channels(&hcard_url).map(|i| i.unwrap_or_default()) ); ( @@ -314,7 +317,7 @@ pub async fn homepage( title: blogname.as_ref(), blog_name: blogname.as_ref(), feeds: channels, - user, + user: user.as_ref().map(url::Url::to_string), content: ErrorPage { code: err.code(), msg: Some(err.msg().to_string()), @@ -335,13 +338,14 @@ pub async fn catchall( Query(query): Query, uri: Uri, ) -> impl IntoResponse { - let user = None; // TODO authentication - let path = url::Url::parse(&format!("https://{}/", host)) - .unwrap() + let user: Option = None; // TODO authentication + let host = url::Url::parse(&format!("https://{}/", host)).unwrap(); + let path = host + .clone() .join(uri.path()) .unwrap(); - match get_post_from_database(&db, path.as_str(), query.after, &user).await { + match get_post_from_database(&db, path.as_str(), query.after, user.as_ref()).await { Ok((post, cursor)) => { let (blogname, channels) = tokio::join!( db.get_setting::(&host) @@ -360,7 +364,7 @@ pub async fn catchall( title: blogname.as_ref(), blog_name: blogname.as_ref(), feeds: channels, - user, + user: user.as_ref().map(url::Url::to_string), content: match post.pointer("/type/0").and_then(|i| i.as_str()) { Some("h-entry") => Entry { post: &post }.to_string(), Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(), @@ -390,7 +394,7 @@ pub async fn catchall( title: blogname.as_ref(), blog_name: blogname.as_ref(), feeds: channels, - user, + user: user.as_ref().map(url::Url::to_string), content: ErrorPage { code: err.code(), msg: Some(err.msg().to_owned()), diff --git a/src/frontend/onboarding.rs b/src/frontend/onboarding.rs index e44e866..faf8cdd 100644 --- a/src/frontend/onboarding.rs +++ b/src/frontend/onboarding.rs @@ -82,11 +82,11 @@ async fn onboard( .map(|port| format!(":{}", port)) .unwrap_or_default() ); - db.set_setting::(&user_domain, data.blog_name.to_owned()) + db.set_setting::(&user.me, data.blog_name.to_owned()) .await .map_err(FrontendError::from)?; - db.set_setting::(&user_domain, false) + db.set_setting::(&user.me, false) .await .map_err(FrontendError::from)?; @@ -95,7 +95,7 @@ async fn onboard( hcard["properties"]["uid"] = serde_json::json!([&user_uid]); crate::micropub::normalize_mf2(hcard, &user) }; - db.put_post(&hcard, user_domain.as_str()) + db.put_post(&hcard, &user.me) .await .map_err(FrontendError::from)?; @@ -113,7 +113,7 @@ async fn onboard( &user, ); - db.put_post(&feed, user_uid.as_str()) + db.put_post(&feed, &user.me) .await .map_err(FrontendError::from)?; } diff --git a/src/micropub/mod.rs b/src/micropub/mod.rs index 8f7ff90..74f53a0 100644 --- a/src/micropub/mod.rs +++ b/src/micropub/mod.rs @@ -292,7 +292,7 @@ pub(crate) async fn _post( ); // Save the post tracing::debug!("Saving post to database..."); - db.put_post(&mf2, &user_domain).await?; + db.put_post(&mf2, &user.me).await?; let mut channels = mf2["properties"]["channel"] .as_array() @@ -579,7 +579,7 @@ pub(crate) async fn query( ); match query.q { QueryType::Config => { - let channels: Vec = match db.get_channels(&user_domain).await { + let channels: Vec = match db.get_channels(&user.me).await { Ok(chans) => chans, Err(err) => { return MicropubError::new( @@ -636,7 +636,7 @@ pub(crate) async fn query( } } } - QueryType::Channel => match db.get_channels(&user_domain).await { + QueryType::Channel => match db.get_channels(&user.me).await { Ok(chans) => axum::response::Json(json!({ "channels": chans })).into_response(), Err(err) => MicropubError::new( ErrorType::InternalServerError, diff --git a/src/micropub/util.rs b/src/micropub/util.rs index b6a045d..0633ce9 100644 --- a/src/micropub/util.rs +++ b/src/micropub/util.rs @@ -212,7 +212,7 @@ pub(crate) async fn create_feed( }), user, ); - storage.put_post(&feed, user.me.as_str()).await?; + storage.put_post(&feed, &user.me).await?; storage.add_to_feed(channel, uid).await } -- cgit 1.4.1