diff options
Diffstat (limited to 'src/tokenauth.rs')
-rw-r--r-- | src/tokenauth.rs | 358 |
1 files changed, 0 insertions, 358 deletions
diff --git a/src/tokenauth.rs b/src/tokenauth.rs deleted file mode 100644 index 414454a..0000000 --- a/src/tokenauth.rs +++ /dev/null @@ -1,358 +0,0 @@ -use serde::{Deserialize, Serialize}; -use url::Url; - -#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] -pub struct User { - pub me: Url, - pub client_id: Url, - scope: String, -} - -#[derive(Debug, Clone, PartialEq, Copy)] -pub enum ErrorKind { - PermissionDenied, - NotAuthorized, - TokenEndpointError, - JsonParsing, - InvalidHeader, - Other, -} - -#[derive(Deserialize, Serialize, Debug, Clone)] -pub struct TokenEndpointError { - error: String, - error_description: String, -} - -#[derive(Debug)] -pub struct IndieAuthError { - source: Option<Box<dyn std::error::Error + Send + Sync>>, - kind: ErrorKind, - msg: String, -} - -impl std::error::Error for IndieAuthError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source - .as_ref() - .map(|e| e.as_ref() as &dyn std::error::Error) - } -} - -impl std::fmt::Display for IndieAuthError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}: {}", - match self.kind { - ErrorKind::TokenEndpointError => "token endpoint returned an error: ", - ErrorKind::JsonParsing => "error while parsing token endpoint response: ", - ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", - ErrorKind::PermissionDenied => "token endpoint rejected the token: ", - ErrorKind::InvalidHeader => "authorization header parsing error: ", - ErrorKind::Other => "token endpoint communication error: ", - }, - self.msg - ) - } -} - -impl From<serde_json::Error> for IndieAuthError { - fn from(err: serde_json::Error) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::JsonParsing, - } - } -} - -impl From<reqwest::Error> for IndieAuthError { - fn from(err: reqwest::Error) -> Self { - Self { - msg: format!("{}", err), - source: Some(Box::new(err)), - kind: ErrorKind::Other, - } - } -} - -impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { - fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { - Self { - msg: format!("{:?}", err.reason()), - source: Some(Box::new(err)), - kind: ErrorKind::InvalidHeader, - } - } -} - -impl axum::response::IntoResponse for IndieAuthError { - fn into_response(self) -> axum::response::Response { - let status_code: StatusCode = match self.kind { - ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, - ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, - ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, - ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, - ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED, - ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR, - }; - - let body = serde_json::json!({ - "error": match self.kind { - ErrorKind::PermissionDenied => "forbidden", - ErrorKind::NotAuthorized => "unauthorized", - ErrorKind::TokenEndpointError => "token_endpoint_error", - ErrorKind::JsonParsing => "invalid_request", - ErrorKind::InvalidHeader => "unauthorized", - ErrorKind::Other => "unknown_error", - }, - "error_description": self.msg - }); - - (status_code, axum::response::Json(body)).into_response() - } -} - -impl User { - pub fn check_scope(&self, scope: &str) -> bool { - self.scopes().any(|i| i == scope) - } - pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> { - self.scope.split_ascii_whitespace() - } - pub fn new(me: &str, client_id: &str, scope: &str) -> Self { - Self { - me: Url::parse(me).unwrap(), - client_id: Url::parse(client_id).unwrap(), - scope: scope.to_string(), - } - } -} - -use axum::{ - extract::{Extension, FromRequest, RequestParts, TypedHeader}, - headers::{ - authorization::{Bearer, Credentials}, - Authorization, - }, - http::StatusCode, -}; - -// this newtype is required due to axum::Extension retrieving items by type -// it's based on compiler magic matching extensions by their type's hashes -#[derive(Debug, Clone)] -pub struct TokenEndpoint(pub url::Url); - -#[async_trait::async_trait] -impl<B> FromRequest<B> for User -where - B: Send, -{ - type Rejection = IndieAuthError; - - #[cfg_attr( - all(debug_assertions, not(test)), - allow(unreachable_code, unused_variables) - )] - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - // Return a fake user if we're running a debug build - // I don't wanna bother with authentication - #[cfg(all(debug_assertions, not(test)))] - return Ok(User::new( - "http://localhost:8080/", - "https://quill.p3k.io/", - "create update delete media", - )); - - let TypedHeader(Authorization(token)) = - TypedHeader::<Authorization<Bearer>>::from_request(req) - .await - .map_err(IndieAuthError::from)?; - - let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = - Extension::from_request(req).await.unwrap(); - - let Extension(http): Extension<reqwest_middleware::ClientWithMiddleware> = - Extension::from_request(req).await.unwrap(); - - match http - .get(token_endpoint) - .header("Authorization", token.encode()) - .header("Accept", "application/json") - .send() - .await - { - Ok(res) => match res.status() { - StatusCode::OK => match res.json::<serde_json::Value>().await { - Ok(json) => match serde_json::from_value::<User>(json.clone()) { - Ok(user) => Ok(user), - Err(err) => { - if let Some(false) = json["active"].as_bool() { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: "The token is not active for this user.".to_owned(), - }) - } else { - Err(IndieAuthError::from(err)) - } - } - }, - Err(err) => Err(IndieAuthError::from(err)), - }, - StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { - Ok(err) => { - if err.error == "unauthorized" { - Err(IndieAuthError { - source: None, - kind: ErrorKind::NotAuthorized, - msg: err.error_description, - }) - } else { - Err(IndieAuthError { - source: None, - kind: ErrorKind::TokenEndpointError, - msg: err.error_description, - }) - } - } - Err(err) => Err(IndieAuthError::from(err)), - }, - _ => Err(IndieAuthError { - source: None, - msg: format!("Token endpoint returned {}", res.status()), - kind: ErrorKind::TokenEndpointError, - }), - }, - Err(err) => Err(IndieAuthError::from(err)), - } - } -} - -#[cfg(test)] -mod tests { - use super::User; - use axum::{ - extract::FromRequest, - http::{Method, Request}, - }; - use wiremock::{MockServer, Mock, ResponseTemplate}; - use wiremock::matchers::{method, path, header}; - - #[test] - fn user_scopes_are_checkable() { - let user = User::new( - "https://fireburn.ru/", - "https://quill.p3k.io/", - "create update media", - ); - - assert!(user.check_scope("create")); - assert!(!user.check_scope("delete")); - } - - #[inline] - fn get_http_client() -> reqwest_middleware::ClientWithMiddleware { - reqwest_middleware::ClientWithMiddleware::new() - } - - fn request<A: Into<Option<&'static str>>>( - auth: A, - endpoint: String, - ) -> Request<()> { - let request = Request::builder().method(Method::GET); - - match auth.into() { - Some(auth) => request.header("Authorization", auth), - None => request, - } - .extension(super::TokenEndpoint(endpoint.parse().unwrap())) - .extension(get_http_client()) - .body(()) - .unwrap() - } - - #[tokio::test] - async fn test_require_token_with_token() { - let server = MockServer::start().await; - - Mock::given(path("/token")) - .and(header("Authorization", "Bearer token")) - .respond_with(ResponseTemplate::new(200) - .set_body_json(User::new( - "https://fireburn.ru/", - "https://quill.p3k.io/", - "create update media", - )) - ) - .mount(&server) - .await; - - let request = request("Bearer token", format!("{}/token", &server.uri())); - let mut parts = axum::extract::RequestParts::new(request); - let user = User::from_request(&mut parts).await.unwrap(); - - assert_eq!(user.me.as_str(), "https://fireburn.ru/") - } - - #[tokio::test] - async fn test_require_token_fake_token() { - let server = MockServer::start().await; - - Mock::given(path("/refuse_token")) - .respond_with(ResponseTemplate::new(200) - .set_body_json(serde_json::json!({"active": false})) - ) - .mount(&server) - .await; - - let request = request("Bearer token", format!("{}/refuse_token", &server.uri())); - let mut parts = axum::extract::RequestParts::new(request); - let err = User::from_request(&mut parts).await.unwrap_err(); - - assert_eq!(err.kind, super::ErrorKind::NotAuthorized) - } - - #[tokio::test] - async fn test_require_token_no_token() { - let server = MockServer::start().await; - - Mock::given(path("/should_never_be_called")) - .respond_with(ResponseTemplate::new(500)) - .expect(0) - .mount(&server) - .await; - - let request = request(None, format!("{}/should_never_be_called", &server.uri())); - let mut parts = axum::extract::RequestParts::new(request); - let err = User::from_request(&mut parts).await.unwrap_err(); - - assert_eq!(err.kind, super::ErrorKind::InvalidHeader); - } - - #[tokio::test] - async fn test_require_token_400_error_unauthorized() { - let server = MockServer::start().await; - - Mock::given(path("/refuse_token_with_400")) - .and(header("Authorization", "Bearer token")) - .respond_with(ResponseTemplate::new(400) - .set_body_json(serde_json::json!({ - "error": "unauthorized", - "error_description": "The token provided was malformed" - })) - ) - .mount(&server) - .await; - - let request = request( - "Bearer token", - format!("{}/refuse_token_with_400", &server.uri()), - ); - let mut parts = axum::extract::RequestParts::new(request); - let err = User::from_request(&mut parts).await.unwrap_err(); - - assert_eq!(err.kind, super::ErrorKind::NotAuthorized); - } -} |