diff options
Diffstat (limited to 'src/indieauth/mod.rs')
-rw-r--r-- | src/indieauth/mod.rs | 45 |
1 files changed, 20 insertions, 25 deletions
diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 26879bb..2550df0 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,12 +1,10 @@ use std::marker::PhantomData; +use microformats::types::Class; use tracing::error; use serde::Deserialize; use axum::{ - extract::{Query, Json, Host, Form}, - response::{Html, IntoResponse, Response}, - http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer}, - Extension + extract::{Form, FromRef, Host, Json, Query, State}, headers::{authorization::Bearer, Authorization}, http::StatusCode, response::{Html, IntoResponse, Response}, Extension, TypedHeader }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] use axum_extra::extract::cookie::{CookieJar, Cookie}; @@ -73,18 +71,16 @@ impl axum::response::IntoResponse for IndieAuthResourceError { } #[async_trait::async_trait] -impl <S: Send + Sync, A: AuthBackend> axum::extract::FromRequestParts<S> for User<A> { +impl <A: AuthBackend + FromRef<St>, St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts<St> for User<A> { type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> { + async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result<Self, Self::Rejection> { let TypedHeader(Authorization(token)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(req, state) .await .map_err(|_| IndieAuthResourceError::Unauthorized)?; - let axum::Extension(auth) = axum::Extension::<A>::from_request_parts(req, state) - .await - .unwrap(); + let auth = A::from_ref(state); let Host(host) = Host::from_request_parts(req, state) .await @@ -146,9 +142,9 @@ pub async fn metadata( async fn authorization_endpoint_get<A: AuthBackend, D: Storage + 'static>( Host(host): Host, Query(request): Query<AuthorizationRequest>, - Extension(db): Extension<D>, - Extension(http): Extension<reqwest::Client>, - Extension(auth): Extension<A> + State(db): State<D>, + State(http): State<reqwest::Client>, + State(auth): State<A> ) -> Response { let me = format!("https://{host}/").parse().unwrap(); let h_app = { @@ -802,8 +798,13 @@ async fn userinfo_endpoint_get<A: AuthBackend, D: Storage + 'static>( } } -#[must_use] -pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: reqwest::Client) -> axum::Router { +pub fn router<St, A, S>() -> axum::Router<St> +where + S: Storage + FromRef<St> + 'static, + A: AuthBackend + FromRef<St>, + reqwest::Client: FromRef<St>, + St: Clone + Send + Sync + 'static +{ use axum::routing::{Router, get, post}; Router::new() @@ -814,14 +815,14 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req get(metadata)) .route( "/auth", - get(authorization_endpoint_get::<A, D>) - .post(authorization_endpoint_post::<A, D>)) + get(authorization_endpoint_get::<A, S>) + .post(authorization_endpoint_post::<A, S>)) .route( "/auth/confirm", post(authorization_endpoint_confirm::<A>)) .route( "/token", - post(token_endpoint_post::<A, D>)) + post(token_endpoint_post::<A, S>)) .route( "/token_status", post(introspection_endpoint_post::<A>)) @@ -830,11 +831,11 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req post(revocation_endpoint_post::<A>)) .route( "/userinfo", - get(userinfo_endpoint_get::<A, D>)) + get(userinfo_endpoint_get::<A, S>)) .route("/webauthn/pre_register", get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, D>, + #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::<A, S>, #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) ) ) @@ -844,12 +845,6 @@ pub fn router<A: AuthBackend, D: Storage + 'static>(backend: A, db: D, http: req axum::http::Method::POST ]) .allow_origin(tower_http::cors::Any)) - .layer(Extension(backend)) - // I don't really like the fact that I have to use the whole database - // If I could, I would've designed a separate trait for getting profiles - // And made databases implement it, for example - .layer(Extension(db)) - .layer(Extension(http)) ) .route( "/.well-known/oauth-authorization-server", |