use serde::{Deserialize, Serialize}; use url::Url; #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { pub me: Url, pub client_id: Url, scope: String, } #[derive(Debug, Clone, PartialEq, Copy)] pub enum ErrorKind { PermissionDenied, NotAuthorized, TokenEndpointError, JsonParsing, InvalidHeader, Other, } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct TokenEndpointError { error: String, error_description: String, } #[derive(Debug)] pub struct IndieAuthError { source: Option<Box<dyn std::error::Error + Send + Sync>>, kind: ErrorKind, msg: String, } impl std::error::Error for IndieAuthError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { self.source .as_ref() .map(|e| e.as_ref() as &dyn std::error::Error) } } impl std::fmt::Display for IndieAuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{}: {}", match self.kind { ErrorKind::TokenEndpointError => "token endpoint returned an error: ", ErrorKind::JsonParsing => "error while parsing token endpoint response: ", ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ", ErrorKind::PermissionDenied => "token endpoint rejected the token: ", ErrorKind::InvalidHeader => "authorization header parsing error: ", ErrorKind::Other => "token endpoint communication error: ", }, self.msg ) } } impl From<serde_json::Error> for IndieAuthError { fn from(err: serde_json::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::JsonParsing, } } } impl From<reqwest::Error> for IndieAuthError { fn from(err: reqwest::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Other, } } } impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError { fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self { Self { msg: format!("{:?}", err.reason()), source: Some(Box::new(err)), kind: ErrorKind::InvalidHeader, } } } impl axum::response::IntoResponse for IndieAuthError { fn into_response(self) -> axum::response::Response { let status_code: StatusCode = match self.kind { ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED, ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR, ErrorKind::JsonParsing => StatusCode::BAD_REQUEST, ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED, ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR, }; let body = serde_json::json!({ "error": match self.kind { ErrorKind::PermissionDenied => "forbidden", ErrorKind::NotAuthorized => "unauthorized", ErrorKind::TokenEndpointError => "token_endpoint_error", ErrorKind::JsonParsing => "invalid_request", ErrorKind::InvalidHeader => "unauthorized", ErrorKind::Other => "unknown_error", }, "error_description": self.msg }); (status_code, axum::response::Json(body)).into_response() } } impl User { pub fn check_scope(&self, scope: &str) -> bool { self.scopes().any(|i| i == scope) } pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> { self.scope.split_ascii_whitespace() } pub fn new(me: &str, client_id: &str, scope: &str) -> Self { Self { me: Url::parse(me).unwrap(), client_id: Url::parse(client_id).unwrap(), scope: scope.to_string(), } } } use axum::{ extract::{Extension, FromRequest, RequestParts, TypedHeader}, headers::{ authorization::{Bearer, Credentials}, Authorization, }, http::StatusCode, }; // this newtype is required due to axum::Extension retrieving items by type // it's based on compiler magic matching extensions by their type's hashes #[derive(Debug, Clone)] pub struct TokenEndpoint(pub url::Url); #[async_trait::async_trait] impl<B> FromRequest<B> for User where B: Send, { type Rejection = IndieAuthError; #[cfg_attr( all(debug_assertions, not(test)), allow(unreachable_code, unused_variables) )] async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { // Return a fake user if we're running a debug build // I don't wanna bother with authentication #[cfg(all(debug_assertions, not(test)))] return Ok(User::new( "http://localhost:8080/", "https://quill.p3k.io/", "create update delete media", )); let TypedHeader(Authorization(token)) = TypedHeader::<Authorization<Bearer>>::from_request(req) .await .map_err(IndieAuthError::from)?; let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> = Extension::from_request(req).await.unwrap(); let Extension(http): Extension<reqwest::Client> = Extension::from_request(req).await.unwrap(); match http .get(token_endpoint) .header("Authorization", token.encode()) .header("Accept", "application/json") .send() .await { Ok(res) => match res.status() { StatusCode::OK => match res.json::<serde_json::Value>().await { Ok(json) => match serde_json::from_value::<User>(json.clone()) { Ok(user) => Ok(user), Err(err) => { if let Some(false) = json["active"].as_bool() { Err(IndieAuthError { source: None, kind: ErrorKind::NotAuthorized, msg: "The token is not active for this user.".to_owned(), }) } else { Err(IndieAuthError::from(err)) } } }, Err(err) => Err(IndieAuthError::from(err)), }, StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await { Ok(err) => { if err.error == "unauthorized" { Err(IndieAuthError { source: None, kind: ErrorKind::NotAuthorized, msg: err.error_description, }) } else { Err(IndieAuthError { source: None, kind: ErrorKind::TokenEndpointError, msg: err.error_description, }) } } Err(err) => Err(IndieAuthError::from(err)), }, _ => Err(IndieAuthError { source: None, msg: format!("Token endpoint returned {}", res.status()), kind: ErrorKind::TokenEndpointError, }), }, Err(err) => Err(IndieAuthError::from(err)), } } } #[cfg(test)] mod tests { use super::User; use axum::{ extract::FromRequest, http::{Method, Request}, }; use httpmock::prelude::*; #[test] fn user_scopes_are_checkable() { let user = User::new( "https://fireburn.ru/", "https://quill.p3k.io/", "create update media", ); assert!(user.check_scope("create")); assert!(!user.check_scope("delete")); } #[inline] fn get_http_client() -> reqwest::Client { reqwest::Client::new() } fn request<A: Into<Option<&'static str>>, T: TryInto<url::Url> + std::fmt::Debug>( auth: A, endpoint: T, ) -> Request<()> where <T as std::convert::TryInto<url::Url>>::Error: std::fmt::Debug, { let request = Request::builder().method(Method::GET); match auth.into() { Some(auth) => request.header("Authorization", auth), None => request, } .extension(super::TokenEndpoint(endpoint.try_into().unwrap())) .extension(get_http_client()) .body(()) .unwrap() } #[tokio::test] async fn test_require_token_with_token() { let server = MockServer::start_async().await; server .mock_async(|when, then| { when.path("/token").header("Authorization", "Bearer token"); then.status(200) .header("Content-Type", "application/json") .json_body( serde_json::to_value(User::new( "https://fireburn.ru/", "https://quill.p3k.io/", "create update media", )) .unwrap(), ); }) .await; let request = request("Bearer token", server.url("/token").as_str()); let mut parts = axum::extract::RequestParts::new(request); let user = User::from_request(&mut parts).await.unwrap(); assert_eq!(user.me.as_str(), "https://fireburn.ru/") } #[tokio::test] async fn test_require_token_fake_token() { let server = MockServer::start_async().await; server .mock_async(|when, then| { when.path("/refuse_token"); then.status(200) .json_body(serde_json::json!({"active": false})); }) .await; let request = request("Bearer token", server.url("/refuse_token").as_str()); let mut parts = axum::extract::RequestParts::new(request); let err = User::from_request(&mut parts).await.unwrap_err(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized) } #[tokio::test] async fn test_require_token_no_token() { let server = MockServer::start_async().await; let mock = server .mock_async(|when, then| { when.path("/should_never_be_called"); then.status(500); }) .await; let request = request(None, server.url("/should_never_be_called").as_str()); let mut parts = axum::extract::RequestParts::new(request); let err = User::from_request(&mut parts).await.unwrap_err(); assert_eq!(err.kind, super::ErrorKind::InvalidHeader); mock.assert_hits_async(0).await; } #[tokio::test] async fn test_require_token_400_error_unauthorized() { let server = MockServer::start_async().await; server .mock_async(|when, then| { when.path("/refuse_token_with_400"); then.status(400).json_body(serde_json::json!({ "error": "unauthorized", "error_description": "The token provided was malformed" })); }) .await; let request = request( "Bearer token", server.url("/refuse_token_with_400").as_str(), ); let mut parts = axum::extract::RequestParts::new(request); let err = User::from_request(&mut parts).await.unwrap_err(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized); } }