about summary refs log tree commit diff
path: root/src/frontend/mod.rs
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2024-07-09 22:43:21 +0300
committerVika <vika@fireburn.ru>2024-07-09 22:44:01 +0300
commit2e9c292bb989ffff2c99aa2a6062962c913b3586 (patch)
tree9c148d9e8fcbd7756ab8d27ae110075beea8e615 /src/frontend/mod.rs
parent644e19aa08b2629d4b69281e14d702f0b9673687 (diff)
downloadkittybox-2e9c292bb989ffff2c99aa2a6062962c913b3586.tar.zst
database: use Url to represent user authorities
This makes the interface more consistent and resistant to misuse.
Diffstat (limited to 'src/frontend/mod.rs')
-rw-r--r--src/frontend/mod.rs52
1 files changed, 28 insertions, 24 deletions
diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs
index 7a43532..0292171 100644
--- a/src/frontend/mod.rs
+++ b/src/frontend/mod.rs
@@ -87,7 +87,7 @@ impl std::fmt::Display for FrontendError {
 #[tracing::instrument(skip(post), fields(post = %post))]
 pub fn filter_post(
     mut post: serde_json::Value,
-    user: Option<&str>,
+    user: Option<&url::Url>,
 ) -> Option<serde_json::Value> {
     if post["properties"]["deleted"][0].is_string() {
         tracing::debug!("Deleted post; returning tombstone instead");
@@ -108,7 +108,9 @@ pub fn filter_post(
                 serde_json::Value::String(ref author) => author.as_str(),
                 mf2 => mf2["properties"]["uid"][0].as_str().unwrap()
             }
-        }).collect::<Vec<&str>>();
+        })
+        .map(|i| i.parse().unwrap())
+        .collect::<Vec<url::Url>>();
     let visibility = post["properties"]["visibility"][0]
         .as_str()
         .unwrap_or("public");
@@ -118,12 +120,12 @@ pub fn filter_post(
             .as_array()
             .unwrap_or(&empty_vec)
             .iter()
-            .map(|i| i.as_str().unwrap()));
+            .map(|i| i.as_str().unwrap().parse().unwrap()));
 
         audience
     };
     tracing::debug!("post audience = {:?}", audience);
-    if (visibility == "private" && !audience.iter().any(|i| Some(*i) == user))
+    if (visibility == "private" && !audience.iter().any(|i| Some(i) == user))
         || (visibility == "protected" && user.is_none())
     {
         return None;
@@ -137,8 +139,8 @@ pub fn filter_post(
             .as_array()
             .unwrap_or(&empty_vec)
             .iter()
-            .map(|i| i.as_str().unwrap());
-        if (location_visibility == "private" && !author.any(|i| Some(i) == user))
+            .map(|i| i.as_str().unwrap().parse().unwrap());
+        if (location_visibility == "private" && !author.any(|i| Some(&i) == user))
             || (location_visibility == "protected" && user.is_none())
         {
             post["properties"]
@@ -184,10 +186,10 @@ async fn get_post_from_database<S: Storage>(
     db: &S,
     url: &str,
     after: Option<String>,
-    user: &Option<String>,
+    user: Option<&url::Url>,
 ) -> std::result::Result<(serde_json::Value, Option<String>), FrontendError> {
     match db
-        .read_feed_with_cursor(url, after.as_deref(), POSTS_PER_PAGE, user.as_deref())
+        .read_feed_with_cursor(url, after.as_deref(), POSTS_PER_PAGE, user)
         .await
     {
         Ok(result) => match result {
@@ -240,12 +242,13 @@ pub async fn homepage<D: Storage>(
     Extension(db): Extension<D>,
 ) -> impl IntoResponse {
     let user = None; // TODO authentication
-    let path = format!("https://{}/", host);
+    // This is stupid, but there is no other way.
+    let hcard_url: url::Url = format!("https://{}/", host).parse().unwrap();
     let feed_path = format!("https://{}/feeds/main", host);
 
     match tokio::try_join!(
-        get_post_from_database(&db, &path, None, &user),
-        get_post_from_database(&db, &feed_path, query.after, &user)
+        get_post_from_database(&db, &hcard_url.as_str(), None, user.as_ref()),
+        get_post_from_database(&db, &feed_path, query.after, user.as_ref())
     ) {
         Ok(((hcard, _), (hfeed, cursor))) => {
             // Here, we know those operations can't really fail
@@ -254,13 +257,13 @@ pub async fn homepage<D: Storage>(
             //
             // btw is it more efficient to fetch these in parallel?
             let (blogname, webring, channels) = tokio::join!(
-                db.get_setting::<crate::database::settings::SiteName>(&host)
+                db.get_setting::<crate::database::settings::SiteName>(&hcard_url)
                 .map(Result::unwrap_or_default),
 
-                db.get_setting::<crate::database::settings::Webring>(&host)
+                db.get_setting::<crate::database::settings::Webring>(&hcard_url)
                 .map(Result::unwrap_or_default),
 
-                db.get_channels(&host).map(|i| i.unwrap_or_default())
+                db.get_channels(&hcard_url).map(|i| i.unwrap_or_default())
             );
             // Render the homepage
             (
@@ -273,7 +276,7 @@ pub async fn homepage<D: Storage>(
                     title: blogname.as_ref(),
                     blog_name: blogname.as_ref(),
                     feeds: channels,
-                    user,
+                    user: user.as_ref().map(url::Url::to_string),
                     content: MainPage {
                         feed: &hfeed,
                         card: &hcard,
@@ -298,10 +301,10 @@ pub async fn homepage<D: Storage>(
                 error!("Error while fetching h-card and/or h-feed: {}", err);
                 // Return the error
                 let (blogname, channels) = tokio::join!(
-                    db.get_setting::<crate::database::settings::SiteName>(&host)
+                    db.get_setting::<crate::database::settings::SiteName>(&hcard_url)
                     .map(Result::unwrap_or_default),
 
-                    db.get_channels(&host).map(|i| i.unwrap_or_default())
+                    db.get_channels(&hcard_url).map(|i| i.unwrap_or_default())
                 );
 
                 (
@@ -314,7 +317,7 @@ pub async fn homepage<D: Storage>(
                         title: blogname.as_ref(),
                         blog_name: blogname.as_ref(),
                         feeds: channels,
-                        user,
+                        user: user.as_ref().map(url::Url::to_string),
                         content: ErrorPage {
                             code: err.code(),
                             msg: Some(err.msg().to_string()),
@@ -335,13 +338,14 @@ pub async fn catchall<D: Storage>(
     Query(query): Query<QueryParams>,
     uri: Uri,
 ) -> impl IntoResponse {
-    let user = None; // TODO authentication
-    let path = url::Url::parse(&format!("https://{}/", host))
-        .unwrap()
+    let user: Option<url::Url> = None; // TODO authentication
+    let host = url::Url::parse(&format!("https://{}/", host)).unwrap();
+    let path = host
+        .clone()
         .join(uri.path())
         .unwrap();
 
-    match get_post_from_database(&db, path.as_str(), query.after, &user).await {
+    match get_post_from_database(&db, path.as_str(), query.after, user.as_ref()).await {
         Ok((post, cursor)) => {
             let (blogname, channels) = tokio::join!(
                 db.get_setting::<crate::database::settings::SiteName>(&host)
@@ -360,7 +364,7 @@ pub async fn catchall<D: Storage>(
                     title: blogname.as_ref(),
                     blog_name: blogname.as_ref(),
                     feeds: channels,
-                    user,
+                    user: user.as_ref().map(url::Url::to_string),
                     content: match post.pointer("/type/0").and_then(|i| i.as_str()) {
                         Some("h-entry") => Entry { post: &post }.to_string(),
                         Some("h-feed") => Feed { feed: &post, cursor: cursor.as_deref() }.to_string(),
@@ -390,7 +394,7 @@ pub async fn catchall<D: Storage>(
                     title: blogname.as_ref(),
                     blog_name: blogname.as_ref(),
                     feeds: channels,
-                    user,
+                    user: user.as_ref().map(url::Url::to_string),
                     content: ErrorPage {
                         code: err.code(),
                         msg: Some(err.msg().to_owned()),