diff options
Diffstat (limited to 'kittybox-rs/src/indieauth.rs')
-rw-r--r-- | kittybox-rs/src/indieauth.rs | 381 |
1 files changed, 227 insertions, 154 deletions
diff --git a/kittybox-rs/src/indieauth.rs b/kittybox-rs/src/indieauth.rs index 57c0301..63de859 100644 --- a/kittybox-rs/src/indieauth.rs +++ b/kittybox-rs/src/indieauth.rs @@ -1,6 +1,5 @@ +use serde::{Deserialize, Serialize}; use url::Url; -use serde::{Serialize, Deserialize}; -use warp::{Filter, Rejection, reject::MissingHeader}; #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { @@ -15,40 +14,46 @@ pub enum ErrorKind { NotAuthorized, TokenEndpointError, JsonParsing, - Other + InvalidHeader, + Other, } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct TokenEndpointError { error: String, - error_description: String + error_description: String, } #[derive(Debug)] pub struct IndieAuthError { source: Option<Box<dyn std::error::Error + Send + Sync>>, kind: ErrorKind, - msg: String + msg: String, } impl std::error::Error for IndieAuthError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &dyn std::error::Error) + self.source + .as_ref() + .map(|e| e.as_ref() as &dyn std::error::Error) } } impl std::fmt::Display for IndieAuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match match self.kind { - ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "), - ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "), - ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "), - ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "), - ErrorKind::Other => write!(f, "token endpoint communication error: "), - } { - Ok(_) => write!(f, "{}", self.msg), - Err(err) => Err(err) - } + write!( + f, + "{}: {}", + match self.kind { + ErrorKind::TokenEndpointError => "token endpoint returned an error: ", + ErrorKind::JsonParsing => "error while parsing token endpoint response: ", + ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", + ErrorKind::PermissionDenied => "token endpoint rejected the token: ", + ErrorKind::InvalidHeader => "authorization header parsing error: ", + ErrorKind::Other => "token endpoint communication error: ", + }, + self.msg + ) } } @@ -72,7 +77,42 @@ impl From<reqwest::Error> for IndieAuthError { } } -impl warp::reject::Reject for IndieAuthError {} +impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { + fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { + Self { + msg: format!("{:?}", err.reason()), + source: Some(Box::new(err)), + kind: ErrorKind::InvalidHeader, + } + } +} + +impl axum::response::IntoResponse for IndieAuthError { + fn into_response(self) -> axum::response::Response { + let status_code: StatusCode = match self.kind { + ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, + ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, + ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, + ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED, + ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR, + }; + + let body = serde_json::json!({ + "error": match self.kind { + ErrorKind::PermissionDenied => "forbidden", + ErrorKind::NotAuthorized => "unauthorized", + ErrorKind::TokenEndpointError => "token_endpoint_error", + ErrorKind::JsonParsing => "invalid_request", + ErrorKind::InvalidHeader => "unauthorized", + ErrorKind::Other => "unknown_error", + }, + "error_description": self.msg + }); + + (status_code, axum::response::Json(body)).into_response() + } +} impl User { pub fn check_scope(&self, scope: &str) -> bool { @@ -90,89 +130,112 @@ impl User { } } -pub fn require_token(token_endpoint: String, http: reqwest::Client) -> impl Filter<Extract = (User,), Error = Rejection> + Clone { - // It might be OK to panic here, because we're still inside the initialisation sequence for now. - // Proper error handling on the top of this should be used though. - let token_endpoint_uri = url::Url::parse(&token_endpoint) - .expect("Couldn't parse the token endpoint URI!"); - warp::any() - .map(move || token_endpoint_uri.clone()) - .and(warp::any().map(move || http.clone())) - .and(warp::header::<String>("Authorization").recover(|err: Rejection| async move { - if err.find::<MissingHeader>().is_some() { - Err(IndieAuthError { - source: None, - msg: "No Authorization header provided.".to_string(), - kind: ErrorKind::NotAuthorized - }.into()) - } else { - Err(err) - } - }).unify()) - .and_then(|token_endpoint, http: reqwest::Client, token| async move { - use hyper::StatusCode; - - match http - .get(token_endpoint) - .header("Authorization", token) - .header("Accept", "application/json") - .send() +use axum::{ + extract::{Extension, FromRequest, RequestParts, TypedHeader}, + headers::{ + authorization::{Bearer, Credentials}, + Authorization, + }, + http::StatusCode, +}; + +// this newtype is required due to axum::Extension retrieving items by type +// it's based on compiler magic matching extensions by their type's hashes +#[derive(Debug, Clone)] +pub struct TokenEndpoint(pub url::Url); + +#[async_trait::async_trait] +impl<B> FromRequest<B> for User +where + B: Send, +{ + type Rejection = IndieAuthError; + + #[cfg_attr(all(debug_assertions, not(test)), allow(unreachable_code, unused_variables))] + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + // Return a fake user if we're running a debug build + // I don't wanna bother with authentication + #[cfg(all(debug_assertions, not(test)))] + return Ok(User::new( + "http://localhost:8080/", + "https://quill.p3k.io/", + "create update delete media" + )); + + let TypedHeader(Authorization(token)) = + TypedHeader::<Authorization<Bearer>>::from_request(req) .await - { - Ok(res) => match res.status() { - StatusCode::OK => match res.json::<serde_json::Value>().await { - Ok(json) => match serde_json::from_value::<User>(json.clone()) { - Ok(user) => Ok(user), - Err(err) => { - if let Some(false) = json["active"].as_bool() { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: "The token is not active for this user.".to_owned() - }.into()) - } else { - Err(IndieAuthError::from(err).into()) - } + .map_err(IndieAuthError::from)?; + + let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = + Extension::from_request(req).await.unwrap(); + + let Extension(http): Extension<reqwest::Client> = + Extension::from_request(req).await.unwrap(); + + match http + .get(token_endpoint) + .header("Authorization", token.encode()) + .header("Accept", "application/json") + .send() + .await + { + Ok(res) => match res.status() { + StatusCode::OK => match res.json::<serde_json::Value>().await { + Ok(json) => match serde_json::from_value::<User>(json.clone()) { + Ok(user) => Ok(user), + Err(err) => { + if let Some(false) = json["active"].as_bool() { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: "The token is not active for this user.".to_owned(), + }) + } else { + Err(IndieAuthError::from(err)) } } - Err(err) => Err(IndieAuthError::from(err).into()) }, - StatusCode::BAD_REQUEST => { - match res.json::<TokenEndpointError>().await { - Ok(err) => { - if err.error == "unauthorized" { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: err.error_description - }.into()) - } else { - Err(IndieAuthError { - source: None, - kind: ErrorKind::TokenEndpointError, - msg: err.error_description - }.into()) - } - }, - Err(err) => Err(IndieAuthError::from(err).into()) + Err(err) => Err(IndieAuthError::from(err)), + }, + StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { + Ok(err) => { + if err.error == "unauthorized" { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: err.error_description, + }) + } else { + Err(IndieAuthError { + source: None, + kind: ErrorKind::TokenEndpointError, + msg: err.error_description, + }) } - }, - _ => Err(IndieAuthError { - source: None, - msg: format!("Token endpoint returned {}", res.status()), - kind: ErrorKind::TokenEndpointError - }.into()) + } + Err(err) => Err(IndieAuthError::from(err)), }, - Err(err) => Err(warp::reject::custom(IndieAuthError::from(err))) - } - }) + _ => Err(IndieAuthError { + source: None, + msg: format!("Token endpoint returned {}", res.status()), + kind: ErrorKind::TokenEndpointError, + }), + }, + Err(err) => Err(IndieAuthError::from(err)), + } + } } #[cfg(test)] mod tests { - use super::{User, IndieAuthError, require_token}; + use super::User; + use axum::{ + extract::FromRequest, + http::{Method, Request} + }; use httpmock::prelude::*; - + #[test] fn user_scopes_are_checkable() { let user = User::new( @@ -189,76 +252,88 @@ mod tests { fn get_http_client() -> reqwest::Client { reqwest::Client::new() } - + + fn request<A: Into<Option<&'static str>>, T: TryInto<url::Url> + std::fmt::Debug>( + auth: A, + endpoint: T, + ) -> Request<()> + where + <T as std::convert::TryInto<url::Url>>::Error: std::fmt::Debug, + { + let request = Request::builder().method(Method::GET); + + match auth.into() { + Some(auth) => request.header("Authorization", auth), + None => request, + } + .extension(super::TokenEndpoint(endpoint.try_into().unwrap())) + .extension(get_http_client()) + .body(()) + .unwrap() + } + #[tokio::test] async fn test_require_token_with_token() { let server = MockServer::start_async().await; - server.mock_async(|when, then| { - when.path("/token") - .header("Authorization", "Bearer token"); - - then.status(200) - .header("Content-Type", "application/json") - .json_body(serde_json::to_value(User::new( - "https://fireburn.ru/", - "https://quill.p3k.io/", - "create update media", - )).unwrap()); - }).await; - - let filter = require_token(server.url("/token"), get_http_client()); - - let res: User = warp::test::request() - .path("/") - .header("Authorization", "Bearer token") - .filter(&filter) - .await - .unwrap(); - - assert_eq!(res.me.as_str(), "https://fireburn.ru/") + server + .mock_async(|when, then| { + when.path("/token").header("Authorization", "Bearer token"); + + then.status(200) + .header("Content-Type", "application/json") + .json_body( + serde_json::to_value(User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + )) + .unwrap(), + ); + }) + .await; + + let request = request("Bearer token", server.url("/token").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let user = User::from_request(&mut parts).await.unwrap(); + + assert_eq!(user.me.as_str(), "https://fireburn.ru/") } #[tokio::test] async fn test_require_token_fake_token() { let server = MockServer::start_async().await; - server.mock_async(|when, then| { - when.path("/refuse_token"); - - then.status(200) - .json_body(serde_json::json!({"active": false})); - }).await; + server + .mock_async(|when, then| { + when.path("/refuse_token"); - let filter = require_token(server.url("/refuse_token"), get_http_client()); + then.status(200) + .json_body(serde_json::json!({"active": false})); + }) + .await; - let res = warp::test::request() - .path("/") - .header("Authorization", "Bearer token") - .filter(&filter) - .await - .unwrap_err(); + let request = request("Bearer token", server.url("/refuse_token").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); - let err: &IndieAuthError = res.find().unwrap(); - assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + assert_eq!(err.kind, super::ErrorKind::NotAuthorized) } #[tokio::test] async fn test_require_token_no_token() { let server = MockServer::start_async().await; - let mock = server.mock_async(|when, then| { - when.path("/should_never_be_called"); + let mock = server + .mock_async(|when, then| { + when.path("/should_never_be_called"); - then.status(500); - }).await; - let filter = require_token(server.url("/should_never_be_called"), get_http_client()); + then.status(500); + }) + .await; - let res = warp::test::request() - .path("/") - .filter(&filter) - .await - .unwrap_err(); + let request = request(None, server.url("/should_never_be_called").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); - let err: &IndieAuthError = res.find().unwrap(); - assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + assert_eq!(err.kind, super::ErrorKind::InvalidHeader); mock.assert_hits_async(0).await; } @@ -266,26 +341,24 @@ mod tests { #[tokio::test] async fn test_require_token_400_error_unauthorized() { let server = MockServer::start_async().await; - server.mock_async(|when, then| { - when.path("/refuse_token_with_400"); + server + .mock_async(|when, then| { + when.path("/refuse_token_with_400"); - then.status(400) - .json_body(serde_json::json!({ + then.status(400).json_body(serde_json::json!({ "error": "unauthorized", "error_description": "The token provided was malformed" })); - }).await; - - let filter = require_token(server.url("/refuse_token_with_400"), get_http_client()); + }) + .await; - let res = warp::test::request() - .path("/") - .header("Authorization", "Bearer token") - .filter(&filter) - .await - .unwrap_err(); + let request = request( + "Bearer token", + server.url("/refuse_token_with_400").as_str(), + ); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); - let err: &IndieAuthError = res.find().unwrap(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized); } } |