From 4ca0c24b1989fcd12c453d428af70f58456f7651 Mon Sep 17 00:00:00 2001 From: Vika Date: Thu, 1 Aug 2024 20:01:12 +0300 Subject: Migrate from axum::Extension to axum::extract::State This somehow allowed me to shrink the construction phase of Kittybox by a huge amount of code. --- src/indieauth/mod.rs | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) (limited to 'src/indieauth/mod.rs') diff --git a/src/indieauth/mod.rs b/src/indieauth/mod.rs index 26879bb..2550df0 100644 --- a/src/indieauth/mod.rs +++ b/src/indieauth/mod.rs @@ -1,12 +1,10 @@ use std::marker::PhantomData; +use microformats::types::Class; use tracing::error; use serde::Deserialize; use axum::{ - extract::{Query, Json, Host, Form}, - response::{Html, IntoResponse, Response}, - http::StatusCode, TypedHeader, headers::{Authorization, authorization::Bearer}, - Extension + extract::{Form, FromRef, Host, Json, Query, State}, headers::{authorization::Bearer, Authorization}, http::StatusCode, response::{Html, IntoResponse, Response}, Extension, TypedHeader }; #[cfg_attr(not(feature = "webauthn"), allow(unused_imports))] use axum_extra::extract::cookie::{CookieJar, Cookie}; @@ -73,18 +71,16 @@ impl axum::response::IntoResponse for IndieAuthResourceError { } #[async_trait::async_trait] -impl axum::extract::FromRequestParts for User { +impl , St: Clone + Send + Sync + 'static> axum::extract::FromRequestParts for User { type Rejection = IndieAuthResourceError; - async fn from_request_parts(req: &mut axum::http::request::Parts, state: &S) -> Result { + async fn from_request_parts(req: &mut axum::http::request::Parts, state: &St) -> Result { let TypedHeader(Authorization(token)) = TypedHeader::>::from_request_parts(req, state) .await .map_err(|_| IndieAuthResourceError::Unauthorized)?; - let axum::Extension(auth) = axum::Extension::::from_request_parts(req, state) - .await - .unwrap(); + let auth = A::from_ref(state); let Host(host) = Host::from_request_parts(req, state) .await @@ -146,9 +142,9 @@ pub async fn metadata( async fn authorization_endpoint_get( Host(host): Host, Query(request): Query, - Extension(db): Extension, - Extension(http): Extension, - Extension(auth): Extension + State(db): State, + State(http): State, + State(auth): State ) -> Response { let me = format!("https://{host}/").parse().unwrap(); let h_app = { @@ -802,8 +798,13 @@ async fn userinfo_endpoint_get( } } -#[must_use] -pub fn router(backend: A, db: D, http: reqwest::Client) -> axum::Router { +pub fn router() -> axum::Router +where + S: Storage + FromRef + 'static, + A: AuthBackend + FromRef, + reqwest::Client: FromRef, + St: Clone + Send + Sync + 'static +{ use axum::routing::{Router, get, post}; Router::new() @@ -814,14 +815,14 @@ pub fn router(backend: A, db: D, http: req get(metadata)) .route( "/auth", - get(authorization_endpoint_get::) - .post(authorization_endpoint_post::)) + get(authorization_endpoint_get::) + .post(authorization_endpoint_post::)) .route( "/auth/confirm", post(authorization_endpoint_confirm::)) .route( "/token", - post(token_endpoint_post::)) + post(token_endpoint_post::)) .route( "/token_status", post(introspection_endpoint_post::)) @@ -830,11 +831,11 @@ pub fn router(backend: A, db: D, http: req post(revocation_endpoint_post::)) .route( "/userinfo", - get(userinfo_endpoint_get::)) + get(userinfo_endpoint_get::)) .route("/webauthn/pre_register", get( - #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::, + #[cfg(feature = "webauthn")] webauthn::webauthn_pre_register::, #[cfg(not(feature = "webauthn"))] || std::future::ready(axum::http::StatusCode::NOT_FOUND) ) ) @@ -844,12 +845,6 @@ pub fn router(backend: A, db: D, http: req axum::http::Method::POST ]) .allow_origin(tower_http::cors::Any)) - .layer(Extension(backend)) - // I don't really like the fact that I have to use the whole database - // If I could, I would've designed a separate trait for getting profiles - // And made databases implement it, for example - .layer(Extension(db)) - .layer(Extension(http)) ) .route( "/.well-known/oauth-authorization-server", -- cgit 1.4.1