about summary refs log tree commit diff
path: root/kittybox-rs/src/indieauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/indieauth.rs')
-rw-r--r--kittybox-rs/src/indieauth.rs381
1 files changed, 227 insertions, 154 deletions
diff --git a/kittybox-rs/src/indieauth.rs b/kittybox-rs/src/indieauth.rs
index 57c0301..63de859 100644
--- a/kittybox-rs/src/indieauth.rs
+++ b/kittybox-rs/src/indieauth.rs
@@ -1,6 +1,5 @@
+use serde::{Deserialize, Serialize};
 use url::Url;
-use serde::{Serialize, Deserialize};
-use warp::{Filter, Rejection, reject::MissingHeader};
 
 #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
 pub struct User {
@@ -15,40 +14,46 @@ pub enum ErrorKind {
     NotAuthorized,
     TokenEndpointError,
     JsonParsing,
-    Other
+    InvalidHeader,
+    Other,
 }
 
 #[derive(Deserialize, Serialize, Debug, Clone)]
 pub struct TokenEndpointError {
     error: String,
-    error_description: String
+    error_description: String,
 }
 
 #[derive(Debug)]
 pub struct IndieAuthError {
     source: Option<Box<dyn std::error::Error + Send + Sync>>,
     kind: ErrorKind,
-    msg: String
+    msg: String,
 }
 
 impl std::error::Error for IndieAuthError {
     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
-        self.source.as_ref().map(|e| e.as_ref() as &dyn std::error::Error)
+        self.source
+            .as_ref()
+            .map(|e| e.as_ref() as &dyn std::error::Error)
     }
 }
 
 impl std::fmt::Display for IndieAuthError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match match self.kind {
-            ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "),
-            ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "),
-            ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "),
-            ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "),
-            ErrorKind::Other => write!(f, "token endpoint communication error: "),
-        } {
-            Ok(_) => write!(f, "{}", self.msg),
-            Err(err) => Err(err)
-        }
+        write!(
+            f,
+            "{}: {}",
+            match self.kind {
+                ErrorKind::TokenEndpointError => "token endpoint returned an error: ",
+                ErrorKind::JsonParsing => "error while parsing token endpoint response: ",
+                ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ",
+                ErrorKind::PermissionDenied => "token endpoint rejected the token: ",
+                ErrorKind::InvalidHeader => "authorization header parsing error: ",
+                ErrorKind::Other => "token endpoint communication error: ",
+            },
+            self.msg
+        )
     }
 }
 
@@ -72,7 +77,42 @@ impl From<reqwest::Error> for IndieAuthError {
     }
 }
 
-impl warp::reject::Reject for IndieAuthError {}
+impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError {
+    fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self {
+        Self {
+            msg: format!("{:?}", err.reason()),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::InvalidHeader,
+        }
+    }
+}
+
+impl axum::response::IntoResponse for IndieAuthError {
+    fn into_response(self) -> axum::response::Response {
+        let status_code: StatusCode = match self.kind {
+            ErrorKind::PermissionDenied => StatusCode::FORBIDDEN,
+            ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED,
+            ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR,
+            ErrorKind::JsonParsing => StatusCode::BAD_REQUEST,
+            ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED,
+            ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR,
+        };
+
+        let body = serde_json::json!({
+            "error": match self.kind {
+                ErrorKind::PermissionDenied => "forbidden",
+                ErrorKind::NotAuthorized => "unauthorized",
+                ErrorKind::TokenEndpointError => "token_endpoint_error",
+                ErrorKind::JsonParsing => "invalid_request",
+                ErrorKind::InvalidHeader => "unauthorized",
+                ErrorKind::Other => "unknown_error",
+            },
+            "error_description": self.msg
+        });
+
+        (status_code, axum::response::Json(body)).into_response()
+    }
+}
 
 impl User {
     pub fn check_scope(&self, scope: &str) -> bool {
@@ -90,89 +130,112 @@ impl User {
     }
 }
 
-pub fn require_token(token_endpoint: String, http: reqwest::Client) -> impl Filter<Extract = (User,), Error = Rejection> + Clone {
-    // It might be OK to panic here, because we're still inside the initialisation sequence for now.
-    // Proper error handling on the top of this should be used though.
-    let token_endpoint_uri = url::Url::parse(&token_endpoint)
-        .expect("Couldn't parse the token endpoint URI!");
-    warp::any()
-        .map(move || token_endpoint_uri.clone())
-        .and(warp::any().map(move || http.clone()))
-        .and(warp::header::<String>("Authorization").recover(|err: Rejection| async move {
-            if err.find::<MissingHeader>().is_some() {
-                Err(IndieAuthError {
-                    source: None,
-                    msg: "No Authorization header provided.".to_string(),
-                    kind: ErrorKind::NotAuthorized
-                }.into())
-            } else {
-                Err(err)
-            }
-        }).unify())
-        .and_then(|token_endpoint, http: reqwest::Client, token| async move {
-            use hyper::StatusCode;
-
-            match http
-                .get(token_endpoint)
-                .header("Authorization", token)
-                .header("Accept", "application/json")
-                .send()
+use axum::{
+    extract::{Extension, FromRequest, RequestParts, TypedHeader},
+    headers::{
+        authorization::{Bearer, Credentials},
+        Authorization,
+    },
+    http::StatusCode,
+};
+
+// this newtype is required due to axum::Extension retrieving items by type
+// it's based on compiler magic matching extensions by their type's hashes
+#[derive(Debug, Clone)]
+pub struct TokenEndpoint(pub url::Url);
+
+#[async_trait::async_trait]
+impl<B> FromRequest<B> for User
+where
+    B: Send,
+{
+    type Rejection = IndieAuthError;
+
+    #[cfg_attr(all(debug_assertions, not(test)), allow(unreachable_code, unused_variables))]
+    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+        // Return a fake user if we're running a debug build
+        // I don't wanna bother with authentication
+        #[cfg(all(debug_assertions, not(test)))]
+        return Ok(User::new(
+            "http://localhost:8080/",
+            "https://quill.p3k.io/",
+            "create update delete media"
+        ));
+
+        let TypedHeader(Authorization(token)) =
+            TypedHeader::<Authorization<Bearer>>::from_request(req)
                 .await
-            {
-                Ok(res) => match res.status() {
-                    StatusCode::OK => match res.json::<serde_json::Value>().await {
-                        Ok(json) => match serde_json::from_value::<User>(json.clone()) {
-                            Ok(user) => Ok(user),
-                            Err(err) => {
-                                if let Some(false) = json["active"].as_bool() {
-                                    Err(IndieAuthError {
-                                        source: None,
-                                        kind: ErrorKind::NotAuthorized,
-                                        msg: "The token is not active for this user.".to_owned()
-                                    }.into())
-                                } else {
-                                    Err(IndieAuthError::from(err).into())
-                                }
+                .map_err(IndieAuthError::from)?;
+
+        let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> =
+            Extension::from_request(req).await.unwrap();
+
+        let Extension(http): Extension<reqwest::Client> =
+            Extension::from_request(req).await.unwrap();
+
+        match http
+            .get(token_endpoint)
+            .header("Authorization", token.encode())
+            .header("Accept", "application/json")
+            .send()
+            .await
+        {
+            Ok(res) => match res.status() {
+                StatusCode::OK => match res.json::<serde_json::Value>().await {
+                    Ok(json) => match serde_json::from_value::<User>(json.clone()) {
+                        Ok(user) => Ok(user),
+                        Err(err) => {
+                            if let Some(false) = json["active"].as_bool() {
+                                Err(IndieAuthError {
+                                    source: None,
+                                    kind: ErrorKind::NotAuthorized,
+                                    msg: "The token is not active for this user.".to_owned(),
+                                })
+                            } else {
+                                Err(IndieAuthError::from(err))
                             }
                         }
-                        Err(err) => Err(IndieAuthError::from(err).into())
                     },
-                    StatusCode::BAD_REQUEST => {
-                        match res.json::<TokenEndpointError>().await {
-                            Ok(err) => {
-                                if err.error == "unauthorized" {
-                                    Err(IndieAuthError {
-                                        source: None,
-                                        kind: ErrorKind::NotAuthorized,
-                                        msg: err.error_description
-                                    }.into())
-                                } else {
-                                    Err(IndieAuthError {
-                                        source: None,
-                                        kind: ErrorKind::TokenEndpointError,
-                                        msg: err.error_description
-                                    }.into())
-                                }
-                            },
-                            Err(err) => Err(IndieAuthError::from(err).into())
+                    Err(err) => Err(IndieAuthError::from(err)),
+                },
+                StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await {
+                    Ok(err) => {
+                        if err.error == "unauthorized" {
+                            Err(IndieAuthError {
+                                source: None,
+                                kind: ErrorKind::NotAuthorized,
+                                msg: err.error_description,
+                            })
+                        } else {
+                            Err(IndieAuthError {
+                                source: None,
+                                kind: ErrorKind::TokenEndpointError,
+                                msg: err.error_description,
+                            })
                         }
-                    },
-                    _ => Err(IndieAuthError {
-                        source: None,
-                        msg: format!("Token endpoint returned {}", res.status()),
-                        kind: ErrorKind::TokenEndpointError
-                    }.into())
+                    }
+                    Err(err) => Err(IndieAuthError::from(err)),
                 },
-                Err(err) => Err(warp::reject::custom(IndieAuthError::from(err)))
-            }
-        })
+                _ => Err(IndieAuthError {
+                    source: None,
+                    msg: format!("Token endpoint returned {}", res.status()),
+                    kind: ErrorKind::TokenEndpointError,
+                }),
+            },
+            Err(err) => Err(IndieAuthError::from(err)),
+        }
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    use super::{User, IndieAuthError, require_token};
+    use super::User;
+    use axum::{
+        extract::FromRequest,
+        http::{Method, Request}
+    };
     use httpmock::prelude::*;
-    
+
     #[test]
     fn user_scopes_are_checkable() {
         let user = User::new(
@@ -189,76 +252,88 @@ mod tests {
     fn get_http_client() -> reqwest::Client {
         reqwest::Client::new()
     }
-    
+
+    fn request<A: Into<Option<&'static str>>, T: TryInto<url::Url> + std::fmt::Debug>(
+        auth: A,
+        endpoint: T,
+    ) -> Request<()>
+    where
+        <T as std::convert::TryInto<url::Url>>::Error: std::fmt::Debug,
+    {
+        let request = Request::builder().method(Method::GET);
+
+        match auth.into() {
+            Some(auth) => request.header("Authorization", auth),
+            None => request,
+        }
+        .extension(super::TokenEndpoint(endpoint.try_into().unwrap()))
+        .extension(get_http_client())
+        .body(())
+        .unwrap()
+    }
+
     #[tokio::test]
     async fn test_require_token_with_token() {
         let server = MockServer::start_async().await;
-        server.mock_async(|when, then| {
-            when.path("/token")
-                .header("Authorization", "Bearer token");
-
-            then.status(200)
-                .header("Content-Type", "application/json")
-                .json_body(serde_json::to_value(User::new(
-                    "https://fireburn.ru/",
-                    "https://quill.p3k.io/",
-                    "create update media",
-                )).unwrap());
-        }).await;
-        
-        let filter = require_token(server.url("/token"), get_http_client());
-
-        let res: User = warp::test::request()
-            .path("/")
-            .header("Authorization", "Bearer token")
-            .filter(&filter)
-            .await
-            .unwrap();
-
-        assert_eq!(res.me.as_str(), "https://fireburn.ru/")
+        server
+            .mock_async(|when, then| {
+                when.path("/token").header("Authorization", "Bearer token");
+
+                then.status(200)
+                    .header("Content-Type", "application/json")
+                    .json_body(
+                        serde_json::to_value(User::new(
+                            "https://fireburn.ru/",
+                            "https://quill.p3k.io/",
+                            "create update media",
+                        ))
+                        .unwrap(),
+                    );
+            })
+            .await;
+
+        let request = request("Bearer token", server.url("/token").as_str());
+        let mut parts = axum::extract::RequestParts::new(request);
+        let user = User::from_request(&mut parts).await.unwrap();
+
+        assert_eq!(user.me.as_str(), "https://fireburn.ru/")
     }
 
     #[tokio::test]
     async fn test_require_token_fake_token() {
         let server = MockServer::start_async().await;
-        server.mock_async(|when, then| {
-            when.path("/refuse_token");
-
-            then.status(200)
-                .json_body(serde_json::json!({"active": false}));
-        }).await;
+        server
+            .mock_async(|when, then| {
+                when.path("/refuse_token");
 
-        let filter = require_token(server.url("/refuse_token"), get_http_client());
+                then.status(200)
+                    .json_body(serde_json::json!({"active": false}));
+            })
+            .await;
 
-        let res = warp::test::request()
-            .path("/")
-            .header("Authorization", "Bearer token")
-            .filter(&filter)
-            .await
-            .unwrap_err();
+        let request = request("Bearer token", server.url("/refuse_token").as_str());
+        let mut parts = axum::extract::RequestParts::new(request);
+        let err = User::from_request(&mut parts).await.unwrap_err();
 
-        let err: &IndieAuthError = res.find().unwrap();
-        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
+        assert_eq!(err.kind, super::ErrorKind::NotAuthorized)
     }
 
     #[tokio::test]
     async fn test_require_token_no_token() {
         let server = MockServer::start_async().await;
-        let mock = server.mock_async(|when, then| {
-            when.path("/should_never_be_called");
+        let mock = server
+            .mock_async(|when, then| {
+                when.path("/should_never_be_called");
 
-            then.status(500);
-        }).await;
-        let filter = require_token(server.url("/should_never_be_called"), get_http_client());
+                then.status(500);
+            })
+            .await;
 
-        let res = warp::test::request()
-            .path("/")
-            .filter(&filter)
-            .await
-            .unwrap_err();
+        let request = request(None, server.url("/should_never_be_called").as_str());
+        let mut parts = axum::extract::RequestParts::new(request);
+        let err = User::from_request(&mut parts).await.unwrap_err();
 
-        let err: &IndieAuthError = res.find().unwrap();
-        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
+        assert_eq!(err.kind, super::ErrorKind::InvalidHeader);
 
         mock.assert_hits_async(0).await;
     }
@@ -266,26 +341,24 @@ mod tests {
     #[tokio::test]
     async fn test_require_token_400_error_unauthorized() {
         let server = MockServer::start_async().await;
-        server.mock_async(|when, then| {
-            when.path("/refuse_token_with_400");
+        server
+            .mock_async(|when, then| {
+                when.path("/refuse_token_with_400");
 
-            then.status(400)
-                .json_body(serde_json::json!({
+                then.status(400).json_body(serde_json::json!({
                     "error": "unauthorized",
                     "error_description": "The token provided was malformed"
                 }));
-        }).await;
-
-        let filter = require_token(server.url("/refuse_token_with_400"), get_http_client());
+            })
+            .await;
 
-        let res = warp::test::request()
-            .path("/")
-            .header("Authorization", "Bearer token")
-            .filter(&filter)
-            .await
-            .unwrap_err();
+        let request = request(
+            "Bearer token",
+            server.url("/refuse_token_with_400").as_str(),
+        );
+        let mut parts = axum::extract::RequestParts::new(request);
+        let err = User::from_request(&mut parts).await.unwrap_err();
 
-        let err: &IndieAuthError = res.find().unwrap();
         assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
     }
 }