use url::Url; use serde::{Serialize, Deserialize}; use warp::{Filter, Rejection, reject::MissingHeader}; #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { pub me: Url, pub client_id: Url, scope: String, } #[derive(Debug, Clone, PartialEq, Copy)] pub enum ErrorKind { PermissionDenied, NotAuthorized, TokenEndpointError, JsonParsing, Other } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct TokenEndpointError { error: String, error_description: String } #[derive(Debug)] pub struct IndieAuthError { source: Option>, kind: ErrorKind, msg: String } impl std::error::Error for IndieAuthError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { self.source.as_ref().map(|e| e.as_ref() as &dyn std::error::Error) } } impl std::fmt::Display for IndieAuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match match self.kind { ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "), ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "), ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "), ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "), ErrorKind::Other => write!(f, "token endpoint communication error: "), } { Ok(_) => write!(f, "{}", self.msg), Err(err) => Err(err) } } } impl From for IndieAuthError { fn from(err: serde_json::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::JsonParsing, } } } impl From for IndieAuthError { fn from(err: hyper::Error) -> Self { Self { msg: format!("{}", err), source: Some(Box::new(err)), kind: ErrorKind::Other, } } } impl warp::reject::Reject for IndieAuthError {} impl User { pub fn check_scope(&self, scope: &str) -> bool { self.scopes().any(|i| i == scope) } pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> { self.scope.split_ascii_whitespace() } pub fn new(me: &str, client_id: &str, scope: &str) -> Self { Self { me: Url::parse(me).unwrap(), client_id: Url::parse(client_id).unwrap(), scope: scope.to_string(), } } } pub fn require_token(token_endpoint: String, http: hyper::Client) -> impl Filter + Clone where T: hyper::client::connect::Connect + Clone + Send + Sync + 'static { // It might be OK to panic here, because we're still inside the initialisation sequence for now. // Proper error handling on the top of this should be used though. let token_endpoint_uri = hyper::Uri::try_from(&token_endpoint) .expect("Couldn't parse the token endpoint URI!"); warp::any() .map(move || token_endpoint_uri.clone()) .and(warp::any().map(move || http.clone())) .and(warp::header::("Authorization").recover(|err: Rejection| async move { if err.find::().is_some() { Err(IndieAuthError { source: None, msg: "No Authorization header provided.".to_string(), kind: ErrorKind::NotAuthorized }.into()) } else { Err(err) } }).unify()) .and_then(|token_endpoint, http: hyper::Client, token| async move { let request = hyper::Request::builder() .method(hyper::Method::GET) .uri(token_endpoint) .header("Authorization", token) .header("Accept", "application/json") .body(hyper::Body::from("")) // TODO is it acceptable to panic here? .unwrap(); use hyper::StatusCode; match http.request(request).await { Ok(mut res) => match res.status() { StatusCode::OK => { use hyper::body::HttpBody; use bytes::BufMut; let mut buf: Vec = Vec::default(); while let Some(chunk) = res.body_mut().data().await { if let Err(err) = chunk { return Err(IndieAuthError::from(err).into()); } buf.put(chunk.unwrap()); } match serde_json::from_slice(&buf) { Ok(user) => Ok(user), Err(err) => { if let Ok(json) = serde_json::from_slice::(&buf) { if Some(false) == json["active"].as_bool() { Err(IndieAuthError { source: None, kind: ErrorKind::NotAuthorized, msg: "The token endpoint deemed the token as not \"active\".".to_string() }.into()) } else { Err(IndieAuthError::from(err).into()) } } else { Err(IndieAuthError::from(err).into()) } } } }, StatusCode::BAD_REQUEST => { use hyper::body::HttpBody; use bytes::BufMut; let mut buf: Vec = Vec::default(); while let Some(chunk) = res.body_mut().data().await { if let Err(err) = chunk { return Err(IndieAuthError::from(err).into()); } buf.put(chunk.unwrap()); } match serde_json::from_slice::(&buf) { Ok(err) => { if err.error == "unauthorized" { Err(IndieAuthError { source: None, kind: ErrorKind::NotAuthorized, msg: err.error_description }.into()) } else { Err(IndieAuthError { source: None, kind: ErrorKind::TokenEndpointError, msg: err.error_description }.into()) } }, Err(err) => Err(IndieAuthError::from(err).into()) } }, _ => Err(IndieAuthError { source: None, msg: format!("Token endpoint returned {}", res.status()).to_string(), kind: ErrorKind::TokenEndpointError }.into()) }, Err(err) => Err(warp::reject::custom(IndieAuthError::from(err))) } }) } #[cfg(test)] mod tests { use super::{User, IndieAuthError, require_token}; use httpmock::prelude::*; #[test] fn user_scopes_are_checkable() { let user = User::new( "https://fireburn.ru/", "https://quill.p3k.io/", "create update media", ); assert!(user.check_scope("create")); assert!(!user.check_scope("delete")); } fn get_http_client() -> hyper::Client { let builder = hyper::Client::builder(); let https = hyper_rustls::HttpsConnectorBuilder::new() .with_webpki_roots() .https_or_http() .enable_http1() .enable_http2() .build(); builder.build(https) } #[tokio::test] async fn test_require_token_with_token() { let server = MockServer::start_async().await; server.mock_async(|when, then| { when.path("/token") .header("Authorization", "Bearer token"); then.status(200) .header("Content-Type", "application/json") .json_body(serde_json::to_value(User::new( "https://fireburn.ru/", "https://quill.p3k.io/", "create update media", )).unwrap()); }).await; let filter = require_token(server.url("/token"), get_http_client()); let res: User = warp::test::request() .path("/") .header("Authorization", "Bearer token") .filter(&filter) .await .unwrap(); assert_eq!(res.me.as_str(), "https://fireburn.ru/") } #[tokio::test] async fn test_require_token_fake_token() { let server = MockServer::start_async().await; server.mock_async(|when, then| { when.path("/refuse_token"); then.status(200) .json_body(serde_json::json!({"active": false})); }).await; let filter = require_token(server.url("/refuse_token"), get_http_client()); let res = warp::test::request() .path("/") .header("Authorization", "Bearer token") .filter(&filter) .await .unwrap_err(); let err: &IndieAuthError = res.find().unwrap(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized); } #[tokio::test] async fn test_require_token_no_token() { let server = MockServer::start_async().await; let mock = server.mock_async(|when, then| { when.path("/should_never_be_called"); then.status(500); }).await; let filter = require_token(server.url("/should_never_be_called"), get_http_client()); let res = warp::test::request() .path("/") .filter(&filter) .await .unwrap_err(); let err: &IndieAuthError = res.find().unwrap(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized); mock.assert_hits_async(0).await; } #[tokio::test] async fn test_require_token_400_error_unauthorized() { let server = MockServer::start_async().await; server.mock_async(|when, then| { when.path("/refuse_token_with_400"); then.status(400) .json_body(serde_json::json!({ "error": "unauthorized", "error_description": "The token provided was malformed" })); }).await; let filter = require_token(server.url("/refuse_token_with_400"), get_http_client()); let res = warp::test::request() .path("/") .header("Authorization", "Bearer token") .filter(&filter) .await .unwrap_err(); let err: &IndieAuthError = res.find().unwrap(); assert_eq!(err.kind, super::ErrorKind::NotAuthorized); } }