about summary refs log tree commit diff
path: root/kittybox-rs/src/indieauth/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/indieauth/mod.rs')
-rw-r--r--kittybox-rs/src/indieauth/mod.rs175
1 files changed, 150 insertions, 25 deletions
diff --git a/kittybox-rs/src/indieauth/mod.rs b/kittybox-rs/src/indieauth/mod.rs
index 70b909a..b7d4597 100644
--- a/kittybox-rs/src/indieauth/mod.rs
+++ b/kittybox-rs/src/indieauth/mod.rs
@@ -4,6 +4,7 @@ use axum::{
     http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer},
     Extension
 };
+use crate::database::Storage;
 use kittybox_indieauth::{
     Metadata, IntrospectionEndpointAuthMethod, RevocationEndpointAuthMethod,
     Scope, Scopes, PKCEMethod, Error, ErrorKind,
@@ -18,6 +19,8 @@ use backend::AuthBackend;
 
 const ACCESS_TOKEN_VALIDITY: u64 = 7 * 24 * 60 * 60; // 7 days
 const REFRESH_TOKEN_VALIDITY: u64 = ACCESS_TOKEN_VALIDITY / 7 * 60; // 60 days
+/// Internal scope for accessing the token introspection endpoint.
+const KITTYBOX_TOKEN_STATUS: &str = "kittybox:token_status";
 
 pub async fn metadata(
     Host(host): Host
@@ -71,16 +74,20 @@ async fn authorization_endpoint_get(
     }.to_string())
 }
 
-async fn authorization_endpoint_post<A: AuthBackend>(
+async fn authorization_endpoint_post<A: AuthBackend, D: Storage + 'static>(
     Host(host): Host,
     Form(auth): Form<RequestMaybeAuthorizationEndpoint>,
-    Extension(backend): Extension<A>
+    Extension(backend): Extension<A>,
+    Extension(db): Extension<D>
 ) -> Response {
     use RequestMaybeAuthorizationEndpoint::*;
     match auth {
         Authorization(auth) => {
+            // Cloning these two values, because we can't destructure
+            // the AuthorizationRequest - we need it for the code
             let state = auth.state.clone();
             let redirect_uri = auth.redirect_uri.clone();
+
             let code = match backend.create_code(auth).await {
                 Ok(code) => code,
                 Err(err) => {
@@ -89,7 +96,7 @@ async fn authorization_endpoint_post<A: AuthBackend>(
                 }
             };
 
-            let redirect_uri = {
+            let location = {
                 let mut uri = redirect_uri;
                 uri.set_query(Some(&serde_urlencoded::to_string(
                     AuthorizationResponse {
@@ -102,7 +109,7 @@ async fn authorization_endpoint_post<A: AuthBackend>(
             };
 
             (StatusCode::FOUND,
-             [("Location", redirect_uri.as_str())]
+             [("Location", location.as_str())]
             )
                 .into_response()
         },
@@ -142,15 +149,28 @@ async fn authorization_endpoint_post<A: AuthBackend>(
                         error_uri: "https://datatracker.ietf.org/doc/html/rfc7636#section-4.6".parse().ok()
                     }.into_response()
                 }
-                let profile = if request.scope
+                let me: url::Url = format!("https://{}/", host).parse().unwrap();
+                let profile = if request.scope.as_ref()
                     .map(|s| s.has(&Scope::Profile))
                     .unwrap_or_default()
                 {
-                    Some(todo!())
+                    match get_profile(
+                        db,
+                        me.as_str(),
+                        request.scope.as_ref()
+                            .map(|s| s.has(&Scope::Email))
+                            .unwrap_or_default()
+                    ).await {
+                        Ok(profile) => profile,
+                        Err(err) => {
+                            tracing::error!("Error retrieving profile from database: {}", err);
+
+                            return StatusCode::INTERNAL_SERVER_ERROR.into_response()
+                        }
+                    }
                 } else {
                     None
                 };
-                let me = format!("https://{}/", host).parse().unwrap();
 
                 GrantResponse::ProfileUrl { me, profile }.into_response()
             },
@@ -163,10 +183,11 @@ async fn authorization_endpoint_post<A: AuthBackend>(
     }
 }
 
-async fn token_endpoint_post<A: AuthBackend>(
+async fn token_endpoint_post<A: AuthBackend, D: Storage + 'static>(
     Host(host): Host,
     Form(grant): Form<GrantRequest>,
-    Extension(backend): Extension<A>
+    Extension(backend): Extension<A>,
+    Extension(db): Extension<D>
 ) -> Response {
     #[inline]
     fn prepare_access_token(me: url::Url, client_id: url::Url, scope: Scopes) -> TokenData {
@@ -252,7 +273,18 @@ async fn token_endpoint_post<A: AuthBackend>(
             }
 
             let profile = if scope.has(&Scope::Profile) {
-                Some(todo!())
+                match get_profile(
+                    db,
+                    me.as_str(),
+                    scope.has(&Scope::Email)
+                ).await {
+                    Ok(profile) => profile,
+                    Err(err) => {
+                        tracing::error!("Error retrieving profile from database: {}", err);
+
+                        return StatusCode::INTERNAL_SERVER_ERROR.into_response()
+                    }
+                }
             } else {
                 None
             };
@@ -324,7 +356,18 @@ async fn token_endpoint_post<A: AuthBackend>(
 
 
             let profile = if scope.has(&Scope::Profile) {
-                Some(todo!())
+                match get_profile(
+                    db,
+                    data.me.as_str(),
+                    scope.has(&Scope::Email)
+                ).await {
+                    Ok(profile) => profile,
+                    Err(err) => {
+                        tracing::error!("Error retrieving profile from database: {}", err);
+
+                        return StatusCode::INTERNAL_SERVER_ERROR.into_response()
+                    }
+                }
             } else {
                 None
             };
@@ -366,11 +409,26 @@ async fn token_endpoint_post<A: AuthBackend>(
 }
 
 async fn introspection_endpoint_post<A: AuthBackend>(
-    Host(host): Host,
     Form(token_request): Form<TokenIntrospectionRequest>,
     TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>,
     Extension(backend): Extension<A>
 ) -> Response {
+    use serde_json::json;
+    // Check authentication first
+    match backend.get_token(auth_token.token()).await {
+        Ok(Some(token)) => if !token.scope.has(&Scope::custom(KITTYBOX_TOKEN_STATUS)) {
+            return (StatusCode::UNAUTHORIZED, Json(json!({
+                "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope
+            }))).into_response();
+        },
+        Ok(None) => return (StatusCode::UNAUTHORIZED, Json(json!({
+            "error": kittybox_indieauth::ResourceErrorKind::InvalidToken
+        }))).into_response(),
+        Err(err) => {
+            tracing::error!("Error retrieving token data for introspection: {}", err);
+            return StatusCode::INTERNAL_SERVER_ERROR.into_response()
+        }
+    }
     let response: TokenIntrospectionResponse = match backend.get_token(&token_request.token).await {
         Ok(maybe_data) => maybe_data.into(),
         Err(err) => {
@@ -383,7 +441,6 @@ async fn introspection_endpoint_post<A: AuthBackend>(
 }
 
 async fn revocation_endpoint_post<A: AuthBackend>(
-    Host(host): Host,
     Form(revocation): Form<TokenRevocationRequest>,
     Extension(backend): Extension<A>
 ) -> impl IntoResponse {
@@ -399,20 +456,84 @@ async fn revocation_endpoint_post<A: AuthBackend>(
     }
 }
 
-async fn userinfo_endpoint_get<A: AuthBackend>(
+async fn get_profile<D: Storage + 'static>(
+    db: D,
+    url: &str,
+    email: bool
+) -> crate::database::Result<Option<Profile>> {
+    Ok(db.get_post(url).await?.map(|mut mf2| {
+        // Ruthlessly manually destructure the MF2 document to save memory
+        let name = match mf2["properties"]["name"][0].take() {
+            serde_json::Value::String(s) => Some(s),
+            _ => None
+        };
+        let url = match mf2["properties"]["uid"][0].take() {
+            serde_json::Value::String(s) => s.parse().ok(),
+            _ => None
+        };
+        let photo = match mf2["properties"]["photo"][0].take() {
+            serde_json::Value::String(s) => s.parse().ok(),
+            _ => None
+        };
+        let email = if email {
+            match mf2["properties"]["email"][0].take() {
+                serde_json::Value::String(s) => Some(s),
+                _ => None
+            }
+        } else {
+            None
+        };
+
+        Profile { name, url, photo, email }
+    }))
+}
+
+async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>(
     Host(host): Host,
     TypedHeader(Authorization(auth_token)): TypedHeader<Authorization<Bearer>>,
-    Extension(backend): Extension<A>
+    Extension(backend): Extension<A>,
+    Extension(db): Extension<D>
 ) -> Response {
-    Profile {
-        name: todo!(),
-        url: todo!(),
-        photo: todo!(),
-        email: Some(todo!())
-    }.into_response()
+    use serde_json::json;
+
+    match backend.get_token(auth_token.token()).await {
+        Ok(Some(token)) => {
+            if token.expired() {
+                return (StatusCode::UNAUTHORIZED, Json(json!({
+                    "error": kittybox_indieauth::ResourceErrorKind::InvalidToken
+                }))).into_response();
+            }
+            if !token.scope.has(&Scope::Profile) {
+                return (StatusCode::UNAUTHORIZED, Json(json!({
+                    "error": kittybox_indieauth::ResourceErrorKind::InsufficientScope
+                }))).into_response();
+            }
+
+            match get_profile(db, &format!("https://{}/", host), token.scope.has(&Scope::Email)).await {
+                Ok(Some(profile)) => profile.into_response(),
+                Ok(None) => Json(json!({
+                    // We do this because ResourceErrorKind is IndieAuth errors only
+                    "error": "invalid_request"
+                })).into_response(),
+                Err(err) => {
+                    tracing::error!("Error retrieving profile from database: {}", err);
+
+                    StatusCode::INTERNAL_SERVER_ERROR.into_response()
+                }
+            }
+        },
+        Ok(None) => Json(json!({
+            "error": kittybox_indieauth::ResourceErrorKind::InvalidToken
+        })).into_response(),
+        Err(err) => {
+            tracing::error!("Error reading token: {}", err);
+
+            StatusCode::INTERNAL_SERVER_ERROR.into_response()
+        }
+    }
 }
 
-pub fn router<A: AuthBackend>(backend: A) -> axum::Router {
+pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D) -> axum::Router {
     use axum::routing::{Router, get, post};
 
     Router::new()
@@ -422,10 +543,10 @@ pub fn router<A: AuthBackend>(backend: A) -> axum::Router {
                 .route(
                     "/auth",
                     get(authorization_endpoint_get)
-                        .post(authorization_endpoint_post::<A>))
+                        .post(authorization_endpoint_post::<A, D>))
                 .route(
                     "/token",
-                    post(token_endpoint_post::<A>))
+                    post(token_endpoint_post::<A, D>))
                 .route(
                     "/token_status",
                     post(introspection_endpoint_post::<A>))
@@ -434,7 +555,7 @@ pub fn router<A: AuthBackend>(backend: A) -> axum::Router {
                     post(revocation_endpoint_post::<A>))
                 .route(
                     "/userinfo",
-                    get(userinfo_endpoint_get::<A>))
+                    get(userinfo_endpoint_get::<A, D>))
                 .layer(tower_http::cors::CorsLayer::new()
                        .allow_methods([
                            axum::http::Method::GET,
@@ -442,6 +563,10 @@ pub fn router<A: AuthBackend>(backend: A) -> axum::Router {
                        ])
                        .allow_origin(tower_http::cors::Any))
                 .layer(Extension(backend))
+            // I don't really like the fact that I have to use the whole database
+            // If I could, I would've designed a separate trait for getting profiles
+            // And made databases implement it, for example
+                .layer(Extension(db))
         )
         .route(
             "/.well-known/oauth-authorization-server",