about summary refs log tree commit diff
path: root/src/indieauth
diff options
context:
space:
mode:
Diffstat (limited to 'src/indieauth')
-rw-r--r--src/indieauth/mod.rs45
1 files changed, 20 insertions, 25 deletions
diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs
index 26879bb..2550df0 100644
--- a/src/indieauth/mod.rs
+++ b/src/indieauth/mod.rs
@@ -1,12 +1,10 @@
 use std::marker::PhantomData;
 
+use microformats::types::Class;
 use tracing::error;
 use serde::Deserialize;
 use axum::{
-    extract::{Query, Json, Host, Form},
-    response::{Html, IntoResponse, Response},
-    http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer},
-    Extension
+    extract::{Form, FromRef, Host, Json, Query, State}, headers::{authorization::Bearer, Authorization}, http::StatusCode, response::{Html, IntoResponse, Response}, Extension, TypedHeader
 };
 #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))]
 use axum_extra::extract::cookie::{CookieJar, Cookie};
@@ -73,18 +71,16 @@ impl axum::response::IntoResponse for IndieAuthResourceError {
 }
 
 #[async_trait::async_trait]
-impl <S: Send + Sync, A: AuthBackend> axum::extract::FromRequestParts<S> for User<A> {
+impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts<St> for User<A> {
     type Rejection = IndieAuthResourceError;
 
-    async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
+    async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Self, Self::Rejection> {
         let TypedHeader(Authorization(token)) =
             TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state)
             .await
             .map_err(|_| IndieAuthResourceError::Unauthorized)?;
 
-        let axum::Extension(auth) = axum::Extension::<A>::from_request_parts(req, state)
-            .await
-            .unwrap();
+        let auth = A::from_ref(state);
 
         let Host(host) = Host::from_request_parts(req, state)
             .await
@@ -146,9 +142,9 @@ pub async fn metadata(
 async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>(
     Host(host): Host,
     Query(request): Query<AuthorizationRequest>,
-    Extension(db): Extension<D>,
-    Extension(http): Extension<reqwest::Client>,
-    Extension(auth): Extension<A>
+    State(db): State<D>,
+    State(http): State<reqwest::Client>,
+    State(auth): State<A>
 ) -> Response {
     let me = format!("https://{host}/").parse().unwrap();
     let h_app = {
@@ -802,8 +798,13 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>(
     }
 }
 
-#[must_use]
-pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: reqwest::Client) -> axum::Router {
+pub fn router<St, A, S>() -> axum::Router<St>
+where
+    S: Storage + FromRef<St> + 'static,
+    A: AuthBackend + FromRef<St>,
+    reqwest::Client: FromRef<St>,
+    St: Clone + Send + Sync + 'static
+{
     use axum::routing::{Router, get, post};
 
     Router::new()
@@ -814,14 +815,14 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req
                        get(metadata))
                 .route(
                     "/auth",
-                    get(authorization_endpoint_get::<A, D>)
-                        .post(authorization_endpoint_post::<A, D>))
+                    get(authorization_endpoint_get::<A, S>)
+                        .post(authorization_endpoint_post::<A, S>))
                 .route(
                     "/auth/confirm",
                     post(authorization_endpoint_confirm::<A>))
                 .route(
                     "/token",
-                    post(token_endpoint_post::<A, D>))
+                    post(token_endpoint_post::<A, S>))
                 .route(
                     "/token_status",
                     post(introspection_endpoint_post::<A>))
@@ -830,11 +831,11 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req
                     post(revocation_endpoint_post::<A>))
                 .route(
                     "/userinfo",
-                    get(userinfo_endpoint_get::<A, D>))
+                    get(userinfo_endpoint_get::<A, S>))
 
                 .route("/webauthn/pre_register",
                        get(
-                           #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, D>,
+                           #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, S>,
                            #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND)
                        )
                 )
@@ -844,12 +845,6 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req
                            axum::http::Method::POST
                        ])
                        .allow_origin(tower_http::cors::Any))
-                .layer(Extension(backend))
-            // I don't really like the fact that I have to use the whole database
-            // If I could, I would've designed a separate trait for getting profiles
-            // And made databases implement it, for example
-                .layer(Extension(db))
-                .layer(Extension(http))
         )
         .route(
             "/.well-known/oauth-authorization-server",