diff options
author | Vika <vika@fireburn.ru> | 2022-07-07 00:32:33 +0300 |
---|---|---|
committer | Vika <vika@fireburn.ru> | 2022-07-07 00:36:39 +0300 |
commit | 7f23ec84bc05c236c1bf40c2f0d72412af711516 (patch) | |
tree | f0ba64804fffce29a8f04e5b6c76f9863de81dd2 /kittybox-rs/src/indieauth.rs | |
parent | 5cfac54aa4fb3c207ea2cbbeccd4571fa204a09b (diff) | |
download | kittybox-7f23ec84bc05c236c1bf40c2f0d72412af711516.tar.zst |
treewide: rewrite using Axum
Axum has streaming bodies and allows to write simpler code. It also helps enforce stronger types and looks much more neat. This allows me to progress on the media endpoint and add streaming reads and writes to the MediaStore trait. Metrics are temporarily not implemented. Everything else was preserved, and the tests still pass, after adjusting for new calling conventions. TODO: create method routers for protocol endpoints
Diffstat (limited to 'kittybox-rs/src/indieauth.rs')
-rw-r--r-- | kittybox-rs/src/indieauth.rs | 381 |
1 files changed, 227 insertions, 154 deletions
diff --git a/kittybox-rs/src/indieauth.rs b/kittybox-rs/src/indieauth.rs index 57c0301..63de859 100644 --- a/kittybox-rs/src/indieauth.rs +++ b/kittybox-rs/src/indieauth.rs @@ -1,6 +1,5 @@ +use serde::{Deserialize, Serialize}; use url::Url; -use serde::{Serialize, Deserialize}; -use warp::{Filter, Rejection, reject::MissingHeader}; #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { @@ -15,40 +14,46 @@ pub enum ErrorKind { NotAuthorized, TokenEndpointError, JsonParsing, - Other + InvalidHeader, + Other, } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct TokenEndpointError { error: String, - error_description: String + error_description: String, } #[derive(Debug)] pub struct IndieAuthError { source: Option<Box<dyn std::error::Error + Send + Sync>>, kind: ErrorKind, - msg: String + msg: String, } impl std::error::Error for IndieAuthError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &dyn std::error::Error) + self.source + .as_ref() + .map(|e| e.as_ref() as &dyn std::error::Error) } } impl std::fmt::Display for IndieAuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match match self.kind { - ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "), - ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "), - ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "), - ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "), - ErrorKind::Other => write!(f, "token endpoint communication error: "), - } { - Ok(_) => write!(f, "{}", self.msg), - Err(err) => Err(err) - } + write!( + f, + "{}: {}", + match self.kind { + ErrorKind::TokenEndpointError => "token endpoint returned an error: ", + ErrorKind::JsonParsing => "error while parsing token endpoint response: ", + ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", + ErrorKind::PermissionDenied => "token endpoint rejected the token: ", + ErrorKind::InvalidHeader => "authorization header parsing error: ", + ErrorKind::Other => "token endpoint communication error: ", + }, + self.msg + ) } } @@ -72,7 +77,42 @@ impl From<reqwest::Error> for IndieAuthError { } } -impl warp::reject::Reject for IndieAuthError {} +impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { + fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { + Self { + msg: format!("{:?}", err.reason()), + source: Some(Box::new(err)), + kind: ErrorKind::InvalidHeader, + } + } +} + +impl axum::response::IntoResponse for IndieAuthError { + fn into_response(self) -> axum::response::Response { + let status_code: StatusCode = match self.kind { + ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, + ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, + ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, + ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED, + ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR, + }; + + let body = serde_json::json!({ + "error": match self.kind { + ErrorKind::PermissionDenied => "forbidden", + ErrorKind::NotAuthorized => "unauthorized", + ErrorKind::TokenEndpointError => "token_endpoint_error", + ErrorKind::JsonParsing => "invalid_request", + ErrorKind::InvalidHeader => "unauthorized", + ErrorKind::Other => "unknown_error", + }, + "error_description": self.msg + }); + + (status_code, axum::response::Json(body)).into_response() + } +} impl User { pub fn check_scope(&self, scope: &str) -> bool { @@ -90,89 +130,112 @@ impl User { } } -pub fn require_token(token_endpoint: String, http: reqwest::Client) -> impl Filter<Extract = (User,), Error = Rejection> + Clone { - // It might be OK to panic here, because we're still inside the initialisation sequence for now. - // Proper error handling on the top of this should be used though. - let token_endpoint_uri = url::Url::parse(&token_endpoint) - .expect("Couldn't parse the token endpoint URI!"); - warp::any() - .map(move || token_endpoint_uri.clone()) - .and(warp::any().map(move || http.clone())) - .and(warp::header::<String>("Authorization").recover(|err: Rejection| async move { - if err.find::<MissingHeader>().is_some() { - Err(IndieAuthError { - source: None, - msg: "No Authorization header provided.".to_string(), - kind: ErrorKind::NotAuthorized - }.into()) - } else { - Err(err) - } - }).unify()) - .and_then(|token_endpoint, http: reqwest::Client, token| async move { - use hyper::StatusCode; - - match http - .get(token_endpoint) - .header("Authorization", token) - .header("Accept", "application/json") - .send() +use axum::{ + extract::{Extension, FromRequest, RequestParts, TypedHeader}, + headers::{ + authorization::{Bearer, Credentials}, + Authorization, + }, + http::StatusCode, +}; + +// this newtype is required due to axum::Extension retrieving items by type +// it's based on compiler magic matching extensions by their type's hashes +#[derive(Debug, Clone)] +pub struct TokenEndpoint(pub url::Url); + +#[async_trait::async_trait] +impl<B> FromRequest<B> for User +where + B: Send, +{ + type Rejection = IndieAuthError; + + #[cfg_attr(all(debug_assertions, not(test)), allow(unreachable_code, unused_variables))] + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + // Return a fake user if we're running a debug build + // I don't wanna bother with authentication + #[cfg(all(debug_assertions, not(test)))] + return Ok(User::new( + "http://localhost:8080/", + "https://quill.p3k.io/", + "create update delete media" + )); + + let TypedHeader(Authorization(token)) = + TypedHeader::<Authorization<Bearer>>::from_request(req) .await - { - Ok(res) => match res.status() { - StatusCode::OK => match res.json::<serde_json::Value>().await { - Ok(json) => match serde_json::from_value::<User>(json.clone()) { - Ok(user) => Ok(user), - Err(err) => { - if let Some(false) = json["active"].as_bool() { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: "The token is not active for this user.".to_owned() - }.into()) - } else { - Err(IndieAuthError::from(err).into()) - } + .map_err(IndieAuthError::from)?; + + let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = + Extension::from_request(req).await.unwrap(); + + let Extension(http): Extension<reqwest::Client> = + Extension::from_request(req).await.unwrap(); + + match http + .get(token_endpoint) + .header("Authorization", token.encode()) + .header("Accept", "application/json") + .send() + .await + { + Ok(res) => match res.status() { + StatusCode::OK => match res.json::<serde_json::Value>().await { + Ok(json) => match serde_json::from_value::<User>(json.clone()) { + Ok(user) => Ok(user), + Err(err) => { + if let Some(false) = json["active"].as_bool() { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: "The token is not active for this user.".to_owned(), + }) + } else { + Err(IndieAuthError::from(err)) } } - Err(err) => Err(IndieAuthError::from(err).into()) }, - StatusCode::BAD_REQUEST => { - match res.json::<TokenEndpointError>().await { - Ok(err) => { - if err.error == "unauthorized" { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: err.error_description - }.into()) - } else { - Err(IndieAuthError { - source: None, - kind: ErrorKind::TokenEndpointError, - msg: err.error_description - }.into()) - } - }, - Err(err) => Err(IndieAuthError::from(err).into()) + Err(err) => Err(IndieAuthError::from(err)), + }, + StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { + Ok(err) => { + if err.error == "unauthorized" { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: err.error_description, + }) + } else { + Err(IndieAuthError { + source: None, + kind: ErrorKind::TokenEndpointError, + msg: err.error_description, + }) } - }, - _ => Err(IndieAuthError { - source: None, - msg: format!("Token endpoint returned {}", res.status()), - kind: ErrorKind::TokenEndpointError - }.into()) + } + Err(err) => Err(IndieAuthError::from(err)), }, - Err(err) => Err(warp::reject::custom(IndieAuthError::from(err))) - } - }) + _ => Err(IndieAuthError { + source: None, + msg: format!("Token endpoint returned {}", res.status()), + kind: ErrorKind::TokenEndpointError, + }), + }, + Err(err) => Err(IndieAuthError::from(err)), + } + } } #[cfg(test)] mod tests { - use super::{User, IndieAuthError, require_token}; + use super::User; + use axum::{ + extract::FromRequest, + http::{Method, Request} + }; use httpmock::prelude::*; - + #[test] fn user_scopes_are_checkable() { let user = User::new( @@ -189,76 +252,88 @@ mod tests { fn get_http_client() -> reqwest::Client { reqwest::Client::new() } - + + fn request<A: Into<Option<&'static str>>, T: TryInto<url::Url> + std::fmt::Debug>( + auth: A, + endpoint: T, + ) -> Request<()> + where + <T as std::convert::TryInto<url::Url>>::Error: std::fmt::Debug, + { + let request = Request::builder().method(Method::GET); + + match auth.into() { + Some(auth) => request.header("Authorization", auth), + None => request, + } + .extension(super::TokenEndpoint(endpoint.try_into().unwrap())) + .extension(get_http_client()) + .body(()) + .unwrap() + } + #[tokio::test] async fn test_require_token_with_token() { let server = MockServer::start_async().await; - server.mock_async(|when, then| { - when.path("/token") - .header("Authorization", "Bearer token"); - - then.status(200) - .header("Content-Type", "application/json") - .json_body(serde_json::to_value(User::new( - "https://fireburn.ru/", - "https://quill.p3k.io/", - "create update media", - )).unwrap()); - }).await; - - let filter = require_token(server.url("/token"), get_http_client()); - - let res: User = warp::test::request() - .path("/") - .header("Authorization", "Bearer token") - .filter(&filter) - .await - .unwrap(); - - assert_eq!(res.me.as_str(), "https://fireburn.ru/") + server + .mock_async(|when, then| { + when.path("/token").header("Authorization", "Bearer token"); + + then.status(200) + .header("Content-Type", "application/json") + .json_body( + serde_json::to_value(User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + )) + .unwrap(), + ); + }) + .await; + + let request = request("Bearer token", server.url("/token").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let user = User::from_request(&mut parts).await.unwrap(); + + assert_eq!(user.me.as_str(), "https://fireburn.ru/") } #[tokio::test] async fn test_require_token_fake_token() { let server = MockServer::start_async().await; - server.mock_async(|when, then| { - when.path("/refuse_token"); - - then.status(200) - .json_body(serde_json::json!({"active": false})); - }).await; + server + .mock_async(|when, then| { + when.path("/refuse_token"); - let filter = require_token(server.url("/refuse_token"), get_http_client()); + then.status(200) + .json_body(serde_json::json!({"active": false})); + }) + .await; - let res = warp::test::request() - .path("/") - .header("Authorization", "Bearer token") - .filter(&filter) - .await - .unwrap_err(); + let request = request("Bearer token", server.url("/refuse_token").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); - let err: &IndieAuthError = res.find().unwrap(); - assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + assert_eq!(err.kind, super::ErrorKind::NotAuthorized) } #[tokio::test] async fn test_require_token_no_token() { let server = MockServer::start_async().await; - let mock = server.mock_async(|when, then| { - when.path("/should_never_be_called"); + let mock = server + .mock_async(|when, then| { + when.path("/should_never_be_called"); - then.status(500); - }).await; - let filter = require_token(server.url("/should_never_be_called"), get_http_client()); + then.status(500); + }) + .await; - let res = warp::test::request() - .path("/") - .filter(&filter) - .await - .unwrap_err(); + let request = request(None, server.url("/should_never_be_called").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); - let err: &IndieAuthError = res.find().unwrap(); - assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + assert_eq!(err.kind, super::ErrorKind::InvalidHeader); mock.assert_hits_async(0).await; } @@ -266,26 +341,24 @@ mod tests { #[tokio::test] async fn test_require_token_400_error_unauthorized() { let server = MockServer::start_async().await; - server.mock_async(|when, then| { - when.path("/refuse_token_with_400"); + server + .mock_async(|when, then| { + when.path("/refuse_token_with_400"); - then.status(400) - .json_body(serde_json::json!({ + then.status(400).json_body(serde_json::json!({ "error": "unauthorized", "error_description": "The token provided was malformed" })); - }).await; - - let filter = require_token(server.url("/refuse_token_with_400"), get_http_client()); + }) + .await; - let res = warp::test::request() - .path("/") - .header("Authorization", "Bearer token") - .filter(&filter) - .await - .unwrap_err(); + let request = request( + "Bearer token", + server.url("/refuse_token_with_400").as_str(), + ); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); - let err: &IndieAuthError = res.find().unwrap(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized); } } |