use serde::{Deserialize, Serialize};
use url::Url;
/// Identity attached to a bearer token, as reported by the token endpoint.
///
/// Instances are normally produced by the `FromRequest` extractor, which
/// deserializes the token endpoint's JSON answer into this struct.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct User {
    /// The `me` URL the token endpoint associated with the token.
    pub me: Url,
    /// The `client_id` URL the token endpoint associated with the token.
    pub client_id: Url,
    /// Granted scopes as one ASCII-whitespace-separated string; query it
    /// through `User::scopes` / `User::check_scope`.
    scope: String,
}
/// Category of an [`IndieAuthError`]; selects both the human-readable prefix
/// and the HTTP status/error code sent back to the client.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so kinds can be compared
/// totally and used as keys in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// The token endpoint rejected the token.
    PermissionDenied,
    /// The token endpoint did not recognize the token or reported it inactive.
    NotAuthorized,
    /// The token endpoint answered with an error body or an unexpected status.
    TokenEndpointError,
    /// A token endpoint response could not be parsed into the expected shape.
    JsonParsing,
    /// The request's authorization header was missing or could not be parsed.
    InvalidHeader,
    /// Any other failure while talking to the token endpoint.
    Other,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct TokenEndpointError {
error: String,
error_description: String,
}
/// Error produced while authenticating a request against the token endpoint.
///
/// It renders as text via `Display` and as a JSON HTTP response via
/// `IntoResponse`, both driven by `kind` and `msg`.
#[derive(Debug)]
pub struct IndieAuthError {
    /// Lower-level error that caused this one, exposed through
    /// `std::error::Error::source`; `None` for errors raised directly.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
    /// Category of the failure.
    kind: ErrorKind,
    /// Detail message; appended to the `Display` output and sent to clients
    /// as `error_description`.
    msg: String,
}
/// Exposes the wrapped lower-level error, when there is one, through the
/// standard error `source` chain.
impl std::error::Error for IndieAuthError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner = self.source.as_deref()?;
        // Drop the `Send + Sync` auto-trait bounds to match the trait signature.
        Some(inner as &(dyn std::error::Error + 'static))
    }
}
impl std::fmt::Display for IndieAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}: {}",
match self.kind {
ErrorKind::TokenEndpointError => "token endpoint returned an error: ",
ErrorKind::JsonParsing => "error while parsing token endpoint response: ",
ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ",
ErrorKind::PermissionDenied => "token endpoint rejected the token: ",
ErrorKind::InvalidHeader => "authorization header parsing error: ",
ErrorKind::Other => "token endpoint communication error: ",
},
self.msg
)
}
}
impl From<serde_json::Error> for IndieAuthError {
fn from(err: serde_json::Error) -> Self {
Self {
msg: format!("{}", err),
source: Some(Box::new(err)),
kind: ErrorKind::JsonParsing,
}
}
}
impl From<reqwest::Error> for IndieAuthError {
fn from(err: reqwest::Error) -> Self {
Self {
msg: format!("{}", err),
source: Some(Box::new(err)),
kind: ErrorKind::Other,
}
}
}
impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError {
    /// Classifies a typed-header extraction failure as
    /// [`ErrorKind::InvalidHeader`]. The message is the `Debug` form of the
    /// rejection's reason, and the rejection itself is kept as the source.
    fn from(rejection: axum::extract::rejection::TypedHeaderRejection) -> Self {
        let msg = format!("{:?}", rejection.reason());
        Self {
            source: Some(Box::new(rejection)),
            kind: ErrorKind::InvalidHeader,
            msg,
        }
    }
}
/// Turns the error into a JSON HTTP response of the form
/// `{"error": <code>, "error_description": <msg>}` with a status chosen by kind.
impl axum::response::IntoResponse for IndieAuthError {
    fn into_response(self) -> axum::response::Response {
        // One table maps each kind to both its HTTP status and its error code,
        // so the two can never drift apart.
        let (status_code, error_code) = match self.kind {
            ErrorKind::PermissionDenied => (StatusCode::FORBIDDEN, "forbidden"),
            ErrorKind::NotAuthorized => (StatusCode::UNAUTHORIZED, "unauthorized"),
            ErrorKind::TokenEndpointError => {
                (StatusCode::INTERNAL_SERVER_ERROR, "token_endpoint_error")
            }
            ErrorKind::JsonParsing => (StatusCode::BAD_REQUEST, "invalid_request"),
            ErrorKind::InvalidHeader => (StatusCode::UNAUTHORIZED, "unauthorized"),
            ErrorKind::Other => (StatusCode::INTERNAL_SERVER_ERROR, "unknown_error"),
        };
        let payload = axum::response::Json(serde_json::json!({
            "error": error_code,
            "error_description": self.msg
        }));
        (status_code, payload).into_response()
    }
}
impl User {
    /// Builds a user from string URLs and a whitespace-separated scope list.
    ///
    /// # Panics
    ///
    /// Panics if `me` or `client_id` cannot be parsed by [`Url::parse`]
    /// (`me` is parsed first).
    pub fn new(me: &str, client_id: &str, scope: &str) -> Self {
        let me = Url::parse(me).unwrap();
        let client_id = Url::parse(client_id).unwrap();
        Self {
            me,
            client_id,
            scope: scope.to_owned(),
        }
    }

    /// Iterates over the individual scopes granted to this token, splitting
    /// the stored scope string on ASCII whitespace.
    pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> {
        self.scope.split_ascii_whitespace()
    }

    /// Returns `true` when `scope` exactly matches one of the granted scopes.
    pub fn check_scope(&self, scope: &str) -> bool {
        let mut granted = self.scopes();
        granted.any(|candidate| candidate == scope)
    }
}
use axum::{
extract::{Extension, FromRequest, RequestParts, TypedHeader},
headers::{
authorization::{Bearer, Credentials},
Authorization,
},
http::StatusCode,
};
/// URL of the token endpoint used to verify bearer tokens, stored in the
/// request extensions.
///
/// `axum::Extension` looks values up by their Rust type. Wrapping the
/// [`url::Url`] in a dedicated newtype gives it its own lookup key, so it
/// cannot be confused with any other `Url` placed in the extensions.
#[derive(Clone, Debug)]
pub struct TokenEndpoint(pub url::Url);
#[async_trait::async_trait]
impl<B> FromRequest<B> for User
where
B: Send,
{
type Rejection = IndieAuthError;
#[cfg_attr(
all(debug_assertions, not(test)),
allow(unreachable_code, unused_variables)
)]
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
// Return a fake user if we're running a debug build
// I don't wanna bother with authentication
#[cfg(all(debug_assertions, not(test)))]
return Ok(User::new(
"http://localhost:8080/",
"https://quill.p3k.io/",
"create update delete media",
));
let TypedHeader(Authorization(token)) =
TypedHeader::<Authorization<Bearer>>::from_request(req)
.await
.map_err(IndieAuthError::from)?;
let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> =
Extension::from_request(req).await.unwrap();
let Extension(http): Extension<reqwest_middleware::ClientWithMiddleware> =
Extension::from_request(req).await.unwrap();
match http
.get(token_endpoint)
.header("Authorization", token.encode())
.header("Accept", "application/json")
.send()
.await
{
Ok(res) => match res.status() {
StatusCode::OK => match res.json::<serde_json::Value>().await {
Ok(json) => match serde_json::from_value::<User>(json.clone()) {
Ok(user) => Ok(user),
Err(err) => {
if let Some(false) = json["active"].as_bool() {
Err(IndieAuthError {
source: None,
kind: ErrorKind::NotAuthorized,
msg: "The token is not active for this user.".to_owned(),
})
} else {
Err(IndieAuthError::from(err))
}
}
},
Err(err) => Err(IndieAuthError::from(err)),
},
StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await {
Ok(err) => {
if err.error == "unauthorized" {
Err(IndieAuthError {
source: None,
kind: ErrorKind::NotAuthorized,
msg: err.error_description,
})
} else {
Err(IndieAuthError {
source: None,
kind: ErrorKind::TokenEndpointError,
msg: err.error_description,
})
}
}
Err(err) => Err(IndieAuthError::from(err)),
},
_ => Err(IndieAuthError {
source: None,
msg: format!("Token endpoint returned {}", res.status()),
kind: ErrorKind::TokenEndpointError,
}),
},
Err(err) => Err(IndieAuthError::from(err)),
}
}
}
// Tests for `User` scope checking and for the `FromRequest` extractor.
// The extractor's debug-build shortcut is compiled out under `cfg(test)`, so
// these tests exercise real token verification against a wiremock server
// standing in for the token endpoint.
#[cfg(test)]
mod tests {
use super::User;
use axum::{
extract::FromRequest,
http::{Method, Request},
};
use wiremock::{MockServer, Mock, ResponseTemplate};
// NOTE(review): `method` is imported but not used by any test below.
use wiremock::matchers::{method, path, header};
// `check_scope` matches whole scopes from the whitespace-separated list.
#[test]
fn user_scopes_are_checkable() {
let user = User::new(
"https://fireburn.ru/",
"https://quill.p3k.io/",
"create update media",
);
assert!(user.check_scope("create"));
assert!(!user.check_scope("delete"));
}
// HTTP client placed in the request extensions for the extractor to use.
#[inline]
fn get_http_client() -> reqwest_middleware::ClientWithMiddleware {
reqwest_middleware::ClientWithMiddleware::new()
}
// Builds a GET request with an optional raw `Authorization` header value.
// It also attaches the two extensions `User::from_request` reads:
// - the `TokenEndpoint` URL, parsed from `endpoint` (panics if unparseable);
// - the HTTP client.
fn request<A: Into<Option<&'static str>>>(
auth: A,
endpoint: String,
) -> Request<()> {
let request = Request::builder().method(Method::GET);
match auth.into() {
Some(auth) => request.header("Authorization", auth),
None => request,
}
.extension(super::TokenEndpoint(endpoint.parse().unwrap()))
.extension(get_http_client())
.body(())
.unwrap()
}
// A 200 response carrying a serialized `User` yields that user. The mock only
// matches when the bearer token is forwarded to the endpoint unchanged.
#[tokio::test]
async fn test_require_token_with_token() {
let server = MockServer::start().await;
Mock::given(path("/token"))
.and(header("Authorization", "Bearer token"))
.respond_with(ResponseTemplate::new(200)
.set_body_json(User::new(
"https://fireburn.ru/",
"https://quill.p3k.io/",
"create update media",
))
)
.mount(&server)
.await;
let request = request("Bearer token", format!("{}/token", &server.uri()));
let mut parts = axum::extract::RequestParts::new(request);
let user = User::from_request(&mut parts).await.unwrap();
assert_eq!(user.me.as_str(), "https://fireburn.ru/")
}
// A 200 response of `{"active": false}` must be reported as `NotAuthorized`.
#[tokio::test]
async fn test_require_token_fake_token() {
let server = MockServer::start().await;
Mock::given(path("/refuse_token"))
.respond_with(ResponseTemplate::new(200)
.set_body_json(serde_json::json!({"active": false}))
)
.mount(&server)
.await;
let request = request("Bearer token", format!("{}/refuse_token", &server.uri()));
let mut parts = axum::extract::RequestParts::new(request);
let err = User::from_request(&mut parts).await.unwrap_err();
assert_eq!(err.kind, super::ErrorKind::NotAuthorized)
}
// Without an `Authorization` header the extractor fails with `InvalidHeader`
// before contacting the endpoint. `expect(0)` asks wiremock to verify that
// this mock received no requests.
#[tokio::test]
async fn test_require_token_no_token() {
let server = MockServer::start().await;
Mock::given(path("/should_never_be_called"))
.respond_with(ResponseTemplate::new(500))
.expect(0)
.mount(&server)
.await;
let request = request(None, format!("{}/should_never_be_called", &server.uri()));
let mut parts = axum::extract::RequestParts::new(request);
let err = User::from_request(&mut parts).await.unwrap_err();
assert_eq!(err.kind, super::ErrorKind::InvalidHeader);
}
// A 400 response whose `error` is "unauthorized" maps to `NotAuthorized`.
#[tokio::test]
async fn test_require_token_400_error_unauthorized() {
let server = MockServer::start().await;
Mock::given(path("/refuse_token_with_400"))
.and(header("Authorization", "Bearer token"))
.respond_with(ResponseTemplate::new(400)
.set_body_json(serde_json::json!({
"error": "unauthorized",
"error_description": "The token provided was malformed"
}))
)
.mount(&server)
.await;
let request = request(
"Bearer token",
format!("{}/refuse_token_with_400", &server.uri()),
);
let mut parts = axum::extract::RequestParts::new(request);
let err = User::from_request(&mut parts).await.unwrap_err();
assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
}
}