diff options
Diffstat (limited to 'kittybox-rs/src/tokenauth.rs')
-rw-r--r-- | kittybox-rs/src/tokenauth.rs | 367 |
1 files changed, 367 insertions, 0 deletions
diff --git a/kittybox-rs/src/tokenauth.rs b/kittybox-rs/src/tokenauth.rs new file mode 100644 index 0000000..103f514 --- /dev/null +++ b/kittybox-rs/src/tokenauth.rs @@ -0,0 +1,367 @@ +use serde::{Deserialize, Serialize}; +use url::Url; + +#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] +pub struct User { + pub me: Url, + pub client_id: Url, + scope: String, +} + +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ErrorKind { + PermissionDenied, + NotAuthorized, + TokenEndpointError, + JsonParsing, + InvalidHeader, + Other, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct TokenEndpointError { + error: String, + error_description: String, +} + +#[derive(Debug)] +pub struct IndieAuthError { + source: Option<Box<dyn std::error::Error + Send + Sync>>, + kind: ErrorKind, + msg: String, +} + +impl std::error::Error for IndieAuthError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|e| e.as_ref() as &dyn std::error::Error) + } +} + +impl std::fmt::Display for IndieAuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {}", + match self.kind { + ErrorKind::TokenEndpointError => "token endpoint returned an error: ", + ErrorKind::JsonParsing => "error while parsing token endpoint response: ", + ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", + ErrorKind::PermissionDenied => "token endpoint rejected the token: ", + ErrorKind::InvalidHeader => "authorization header parsing error: ", + ErrorKind::Other => "token endpoint communication error: ", + }, + self.msg + ) + } +} + +impl From<serde_json::Error> for IndieAuthError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::JsonParsing, + } + } +} + +impl From<reqwest::Error> for IndieAuthError { + fn from(err: reqwest::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Other, + } + } +} + +impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { + fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { + Self { + msg: format!("{:?}", err.reason()), + source: Some(Box::new(err)), + kind: ErrorKind::InvalidHeader, + } + } +} + +impl axum::response::IntoResponse for IndieAuthError { + fn into_response(self) -> axum::response::Response { + let status_code: StatusCode = match self.kind { + ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, + ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, + ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, + ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED, + ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR, + }; + + let body = serde_json::json!({ + "error": match self.kind { + ErrorKind::PermissionDenied => "forbidden", + ErrorKind::NotAuthorized => "unauthorized", + ErrorKind::TokenEndpointError => "token_endpoint_error", + ErrorKind::JsonParsing => "invalid_request", + ErrorKind::InvalidHeader => "unauthorized", + ErrorKind::Other => "unknown_error", + }, + "error_description": self.msg + }); + + (status_code, axum::response::Json(body)).into_response() + } +} + +impl User { + pub fn check_scope(&self, scope: &str) -> bool { + self.scopes().any(|i| i == scope) + } + pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> { + self.scope.split_ascii_whitespace() + } + pub fn new(me: &str, client_id: &str, scope: &str) -> Self { + Self { + me: Url::parse(me).unwrap(), + client_id: Url::parse(client_id).unwrap(), + scope: scope.to_string(), + } + } +} + +use axum::{ + extract::{Extension, FromRequest, RequestParts, TypedHeader}, + headers::{ + authorization::{Bearer, Credentials}, + Authorization, + }, + http::StatusCode, +}; + +// this newtype is required due to axum::Extension retrieving items by type +// it's based on compiler magic matching extensions by their type's hashes +#[derive(Debug, Clone)] +pub struct TokenEndpoint(pub url::Url); + +#[async_trait::async_trait] +impl<B> FromRequest<B> for User +where + B: Send, +{ + type Rejection = IndieAuthError; + + #[cfg_attr( + all(debug_assertions, not(test)), + allow(unreachable_code, unused_variables) + )] + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + // Return a fake user if we're running a debug build + // I don't wanna bother with authentication + #[cfg(all(debug_assertions, not(test)))] + return Ok(User::new( + "http://localhost:8080/", + "https://quill.p3k.io/", + "create update delete media", + )); + + let TypedHeader(Authorization(token)) = + TypedHeader::<Authorization<Bearer>>::from_request(req) + .await + .map_err(IndieAuthError::from)?; + + let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = + Extension::from_request(req).await.unwrap(); + + let Extension(http): Extension<reqwest::Client> = + Extension::from_request(req).await.unwrap(); + + match http + .get(token_endpoint) + .header("Authorization", token.encode()) + .header("Accept", "application/json") + .send() + .await + { + Ok(res) => match res.status() { + StatusCode::OK => match res.json::<serde_json::Value>().await { + Ok(json) => match serde_json::from_value::<User>(json.clone()) { + Ok(user) => Ok(user), + Err(err) => { + if let Some(false) = json["active"].as_bool() { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: "The token is not active for this user.".to_owned(), + }) + } else { + Err(IndieAuthError::from(err)) + } + } + }, + Err(err) => Err(IndieAuthError::from(err)), + }, + StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { + Ok(err) => { + if err.error == "unauthorized" { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: err.error_description, + }) + } else { + Err(IndieAuthError { + source: None, + kind: ErrorKind::TokenEndpointError, + msg: err.error_description, + }) + } + } + Err(err) => Err(IndieAuthError::from(err)), + }, + _ => Err(IndieAuthError { + source: None, + msg: format!("Token endpoint returned {}", res.status()), + kind: ErrorKind::TokenEndpointError, + }), + }, + Err(err) => Err(IndieAuthError::from(err)), + } + } +} + +#[cfg(test)] +mod tests { + use super::User; + use axum::{ + extract::FromRequest, + http::{Method, Request}, + }; + use httpmock::prelude::*; + + #[test] + fn user_scopes_are_checkable() { + let user = User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + ); + + assert!(user.check_scope("create")); + assert!(!user.check_scope("delete")); + } + + #[inline] + fn get_http_client() -> reqwest::Client { + reqwest::Client::new() + } + + fn request<A: Into<Option<&'static str>>, T: TryInto<url::Url> + std::fmt::Debug>( + auth: A, + endpoint: T, + ) -> Request<()> + where + <T as std::convert::TryInto<url::Url>>::Error: std::fmt::Debug, + { + let request = Request::builder().method(Method::GET); + + match auth.into() { + Some(auth) => request.header("Authorization", auth), + None => request, + } + .extension(super::TokenEndpoint(endpoint.try_into().unwrap())) + .extension(get_http_client()) + .body(()) + .unwrap() + } + + #[tokio::test] + async fn test_require_token_with_token() { + let server = MockServer::start_async().await; + server + .mock_async(|when, then| { + when.path("/token").header("Authorization", "Bearer token"); + + then.status(200) + .header("Content-Type", "application/json") + .json_body( + serde_json::to_value(User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + )) + .unwrap(), + ); + }) + .await; + + let request = request("Bearer token", server.url("/token").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let user = User::from_request(&mut parts).await.unwrap(); + + assert_eq!(user.me.as_str(), "https://fireburn.ru/") + } + + #[tokio::test] + async fn test_require_token_fake_token() { + let server = MockServer::start_async().await; + server + .mock_async(|when, then| { + when.path("/refuse_token"); + + then.status(200) + .json_body(serde_json::json!({"active": false})); + }) + .await; + + let request = request("Bearer token", server.url("/refuse_token").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); + + assert_eq!(err.kind, super::ErrorKind::NotAuthorized) + } + + #[tokio::test] + async fn test_require_token_no_token() { + let server = MockServer::start_async().await; + let mock = server + .mock_async(|when, then| { + when.path("/should_never_be_called"); + + then.status(500); + }) + .await; + + let request = request(None, server.url("/should_never_be_called").as_str()); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); + + assert_eq!(err.kind, super::ErrorKind::InvalidHeader); + + mock.assert_hits_async(0).await; + } + + #[tokio::test] + async fn test_require_token_400_error_unauthorized() { + let server = MockServer::start_async().await; + server + .mock_async(|when, then| { + when.path("/refuse_token_with_400"); + + then.status(400).json_body(serde_json::json!({ + "error": "unauthorized", + "error_description": "The token provided was malformed" + })); + }) + .await; + + let request = request( + "Bearer token", + server.url("/refuse_token_with_400").as_str(), + ); + let mut parts = axum::extract::RequestParts::new(request); + let err = User::from_request(&mut parts).await.unwrap_err(); + + assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + } +} |