diff options
author | Vika <vika@fireburn.ru> | 2022-02-21 07:34:55 +0300 |
---|---|---|
committer | Vika <vika@fireburn.ru> | 2022-02-21 07:34:55 +0300 |
commit | 2e54681af0bc76ed22ce43f8126a029e600ece93 (patch) | |
tree | d00626394908e9e4ec1c6169757ccf0194ba10bd /src/indieauth.rs | |
parent | 9e4c4551a786830bf34d74c4ef111a8ed292fa9f (diff) | |
download | kittybox-2e54681af0bc76ed22ce43f8126a029e600ece93.tar.zst |
Add a module for IndieAuth bearer token auth
require_token() uses a token endpoint URI and an HTTP client to query the token endpoint and return a User object if the user was authorized, or rejecting with IndieAuthError if not. It is recommended to use recover() and catch the IndieAuthError at the application level to show a "not authorized" error message to the user. This function is more intended for API consumption, but is general enough to permit using in other scenarios. TODO: make a variant that returns Option<User> instead of rejecting
Diffstat (limited to 'src/indieauth.rs')
-rw-r--r-- | src/indieauth.rs | 446 |
1 files changed, 245 insertions, 201 deletions
diff --git a/src/indieauth.rs b/src/indieauth.rs index f8f862b..305452a 100644 --- a/src/indieauth.rs +++ b/src/indieauth.rs @@ -1,14 +1,6 @@ -use async_trait::async_trait; -#[allow(unused_imports)] -use log::{error, info}; -use std::sync::Arc; -use tide::prelude::*; -#[allow(unused_imports)] -use tide::{Next, Request, Response, Result}; use url::Url; - -use crate::database; -use crate::ApplicationState; +use serde::{Serialize, Deserialize}; +use warp::{Filter, Rejection}; #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { @@ -17,6 +9,71 @@ pub struct User { scope: String, } +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ErrorKind { + PermissionDenied, + NotAuthorized, + TokenEndpointError, + JsonParsing, + Other +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct TokenEndpointError { + error: String, + error_description: String +} + +#[derive(Debug)] +pub struct IndieAuthError { + source: Option<Box<dyn std::error::Error + Send + Sync>>, + kind: ErrorKind, + msg: String +} + +impl std::error::Error for IndieAuthError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &dyn std::error::Error) + } +} + +impl std::fmt::Display for IndieAuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match match self.kind { + ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "), + ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "), + ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "), + ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "), + ErrorKind::Other => write!(f, "token endpoint communication error: "), + } { + Ok(_) => write!(f, "{}", self.msg), + Err(err) => Err(err) + } + } +} + +impl From<serde_json::Error> for IndieAuthError { + fn from(err: serde_json::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::JsonParsing, + } + } +} + +impl From<hyper::Error> for IndieAuthError { + fn from(err: hyper::Error) -> Self { + Self { + msg: format!("{}", err), + source: Some(Box::new(err)), + kind: ErrorKind::Other, + } + } +} + +impl warp::reject::Reject for IndieAuthError {} + impl User { pub fn check_scope(&self, scope: &str) -> bool { self.scopes().any(|i| i == scope) @@ -33,207 +90,106 @@ impl User { } } -#[cfg(any(not(debug_assertions), test))] -async fn get_token_data( - token: String, - token_endpoint: &http_types::Url, - http_client: &surf::Client, -) -> (http_types::StatusCode, Option<User>) { - match http_client - .get(token_endpoint) - .header("Authorization", token) - .header("Accept", "application/json") - .send() - .await - { - Ok(mut resp) => { - if resp.status() == 200 { - match resp.body_json::<User>().await { - Ok(user) => { - info!( - "Token endpoint request successful. Validated user: {}", - user.me - ); - (resp.status(), Some(user)) - } - Err(err) => { - error!( - "Token endpoint parsing error (HTTP status {}): {}", - resp.status(), - err - ); - (http_types::StatusCode::InternalServerError, None) - } - } - } else { - error!("Token endpoint returned non-200: {}", resp.status()); - (resp.status(), None) - } - } - Err(err) => { - error!("Token endpoint connection error: {}", err); - (http_types::StatusCode::InternalServerError, None) - } - } -} +// TODO: consider making this a generic +type HttpClient = hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector<hyper::client::connect::dns::GaiResolver>>, hyper::Body>; -pub struct IndieAuthMiddleware { - #[allow(dead_code)] // it's not really dead since it's only dead in debug scope - cache: Arc<retainer::Cache<String, User>>, - monitor_task: Option<async_std::task::JoinHandle<()>>, -} -impl IndieAuthMiddleware { - /// Create a new instance of IndieAuthMiddleware. - /// - /// Note that creating a new instance automatically launches a task - /// to garbage-collect stale cache entries. Please do not create - /// instances willy-nilly because of that. - pub fn new() -> Self { - let cache: Arc<retainer::Cache<String, User>> = Arc::new(retainer::Cache::new()); - let cache_clone = cache.clone(); - let task = async_std::task::spawn(async move { - cache_clone - .monitor(4, 0.1, std::time::Duration::from_secs(30)) - .await - }); - - #[cfg(all(debug_assertions, not(test)))] - error!("ATTENTION: You are running in debug mode. NO REQUESTS TO TOKEN ENDPOINT WILL BE MADE. YOU WILL BE PROCEEDING WITH DEBUG USER CREDENTIALS. DO NOT RUN LIKE THIS IN PRODUCTION."); +pub fn require_token(token_endpoint: String, http: HttpClient) -> impl Filter<Extract = (User,), Error = Rejection> { + // It might be OK to panic here, because we're still inside the initialisation sequence for now. + // Proper error handling on the top of this should be used though. + let token_endpoint_uri = hyper::Uri::try_from(&token_endpoint) + .expect("Couldn't parse the token endpoint URI!"); + warp::any() + .map(move || token_endpoint_uri.clone()) + .and(warp::any().map(move || http.clone())) + .and(warp::header::<String>("Authorization")) + .and_then(|token_endpoint, http: HttpClient, token| async move { + let request = hyper::Request::builder() + .method(hyper::Method::GET) + .uri(token_endpoint) + .header("Authorization", token) + .header("Accept", "application/json") + .body(hyper::Body::from("")) + // TODO is it acceptable to panic here? + .unwrap(); - Self { - cache, - monitor_task: Some(task), - } - } -} -impl Drop for IndieAuthMiddleware { - fn drop(&mut self) { - // Cancel the task, or a VERY FUNNY thing might occur. - // If I understand this correctly, keeping a task active - // WILL keep an active reference to a value, so I'm pretty sure - // that something VERY FUNNY might occur whenever `cache` is dropped - // and its related task is not cancelled. So let's cancel it so - // [`cache`] can be dropped once and for all. - - // First, get the ownership of a task, sneakily switching it out with None - // (wow, this is sneaky, didn't know Safe Rust could even do that!!!) - // (it is safe tho cuz None is no nullptr and dereferencing it doesn't cause unsafety) - // (could cause a VERY FUNNY race condition to occur though - // if you tried to refer to the value in another thread!) - let task = std::mem::take(&mut self.monitor_task) - .expect("Dropped IndieAuthMiddleware TWICE? Impossible!"); - // Then cancel the task, using another task to request cancellation. - // Because apparently you can't run async code from Drop... - // This should drop the last reference for the [`cache`], - // allowing it to be dropped. - async_std::task::spawn(async move { task.cancel().await }); - } -} -#[async_trait] -impl<B> tide::Middleware<ApplicationState<B>> for IndieAuthMiddleware -where - B: database::Storage + Send + Sync + Clone, -{ - #[cfg(all(not(test), debug_assertions))] - async fn handle( - &self, - mut req: Request<ApplicationState<B>>, - next: Next<'_, ApplicationState<B>>, - ) -> Result { - req.set_ext(User::new( - "https://localhost:8080/", - "https://curl.haxx.se/", - "create update delete undelete media", - )); - Ok(next.run(req).await) - } - #[cfg(any(not(debug_assertions), test))] - async fn handle( - &self, - mut req: Request<ApplicationState<B>>, - next: Next<'_, ApplicationState<B>>, - ) -> Result { - let header = req.header("Authorization"); - match header { - None => { - // TODO: move that to the request handling functions - // or make a middleware that refuses to accept unauthenticated requests - Ok(Response::builder(401) - .body(json!({ - "error": "unauthorized", - "error_description": "Please provide an access token." - })) - .build()) - } - Some(value) => { - match &req.state().internal_token { - Some(token) => { - if token - == &value - .last() - .to_string() - .split(' ') - .skip(1) - .collect::<String>() - { - req.set_ext::<User>(User::new( - "", // no user ID here - "https://kittybox.fireburn.ru/", - "update delete undelete media kittybox_internal:do_what_thou_wilt", - )); - return Ok(next.run(req).await); + use hyper::StatusCode; + + match http.request(request).await { + Ok(mut res) => match res.status() { + StatusCode::OK => { + use hyper::body::HttpBody; + use bytes::BufMut; + let mut buf: Vec<u8> = Vec::default(); + while let Some(chunk) = res.body_mut().data().await { + if let Err(err) = chunk { + return Err(IndieAuthError::from(err).into()); + } + buf.put(chunk.unwrap()); + } + match serde_json::from_slice(&buf) { + Ok(user) => Ok(user), + Err(err) => { + if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&buf) { + if Some(false) == json["active"].as_bool() { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: "The token endpoint deemed the token as not \"active\".".to_string() + }.into()) + } else { + Err(IndieAuthError::from(err).into()) + } + } else { + Err(IndieAuthError::from(err).into()) + } + } } - } - None => {} - } - let endpoint = &req.state().token_endpoint; - let http_client = &req.state().http_client; - let token = value.last().to_string(); - match self.cache.get(&token).await { - Some(user) => { - req.set_ext::<User>(user.clone()); - Ok(next.run(req).await) }, - None => match get_token_data(value.last().to_string(), endpoint, http_client).await { - (http_types::StatusCode::Ok, Some(user)) => { - // Note that this can run multiple requests before the value appears in the cache. - // This seems to be in line with some other implementations of a function cache - // (e.g. the [`cached`](https://lib.rs/crates/cached) crate and Python's `functools.lru_cache`) - // - // TODO: ensure the duration is no more than the token's remaining time until expiration - // (in case the expiration time is defined on the token - AFAIK currently non-standard in IndieAuth) - self.cache.insert(token, user.clone(), std::time::Duration::from_secs(600)).await; - req.set_ext(user); - Ok(next.run(req).await) - }, - // TODO: Refactor to return Err(IndieAuthError) so downstream middleware could catch it - // and present a prettier interface to the error (maybe even hiding data from the user) - (http_types::StatusCode::InternalServerError, None) => { - Ok(Response::builder(500).body(json!({ - "error": "token_endpoint_fail", - "error_description": "Token endpoint made a boo-boo and refused to answer." - })).build()) - }, - (_, None) => { - Ok(Response::builder(401).body(json!({ - "error": "unauthorized", - "error_description": "The token endpoint refused to accept your token." - })).build()) - }, - (_, Some(_)) => { - // This shouldn't happen. - panic!("The token validation function has caught rabies and returns malformed responses. Aborting."); + StatusCode::BAD_REQUEST => { + use hyper::body::HttpBody; + use bytes::BufMut; + let mut buf: Vec<u8> = Vec::default(); + while let Some(chunk) = res.body_mut().data().await { + if let Err(err) = chunk { + return Err(IndieAuthError::from(err).into()); + } + buf.put(chunk.unwrap()); + } + match serde_json::from_slice::<TokenEndpointError>(&buf) { + Ok(err) => { + if err.error == "unauthorized" { + Err(IndieAuthError { + source: None, + kind: ErrorKind::NotAuthorized, + msg: err.error_description + }.into()) + } else { + Err(IndieAuthError { + source: None, + kind: ErrorKind::TokenEndpointError, + msg: err.error_description + }.into()) + } + }, + Err(err) => Err(IndieAuthError::from(err).into()) } - } - } + }, + _ => Err(IndieAuthError { + source: None, + msg: format!("Token endpoint returned {}", res.status()).to_string(), + kind: ErrorKind::TokenEndpointError + }.into()) + }, + Err(err) => Err(warp::reject::custom(IndieAuthError::from(err))) } - } - } + }) } #[cfg(test)] mod tests { - use super::*; + use super::{HttpClient, User, IndieAuthError, require_token}; + use httpmock::prelude::*; + #[test] fn user_scopes_are_checkable() { let user = User::new( @@ -245,4 +201,92 @@ mod tests { assert!(user.check_scope("create")); assert!(!user.check_scope("delete")); } + + fn get_http_client() -> HttpClient { + let builder = hyper::Client::builder(); + let https = hyper_rustls::HttpsConnectorBuilder::new() + .with_webpki_roots() + .https_or_http() + .enable_http1() + .enable_http2() + .build(); + builder.build(https) + } + + #[tokio::test] + async fn test_require_token_with_token() { + let server = MockServer::start_async().await; + server.mock_async(|when, then| { + when.path("/token") + .header("Authorization", "Bearer token"); + + then.status(200) + .header("Content-Type", "application/json") + .json_body(serde_json::to_value(User::new( + "https://fireburn.ru/", + "https://quill.p3k.io/", + "create update media", + )).unwrap()); + }).await; + + let filter = require_token(server.url("/token"), get_http_client()); + + let res: User = warp::test::request() + .path("/") + .header("Authorization", "Bearer token") + .filter(&filter) + .await + .unwrap(); + + assert_eq!(res.me.as_str(), "https://fireburn.ru/") + } + + #[tokio::test] + async fn test_require_token_fake_token() { + let server = MockServer::start_async().await; + server.mock_async(|when, then| { + when.path("/refuse_token"); + + then.status(200) + .json_body(serde_json::json!({"active": false})); + }).await; + + let filter = require_token(server.url("/refuse_token"), get_http_client()); + + let res = warp::test::request() + .path("/") + .header("Authorization", "Bearer token") + .filter(&filter) + .await + .unwrap_err(); + + let err: &IndieAuthError = res.find().unwrap(); + assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + } + + #[tokio::test] + async fn test_require_token_400_error_unauthorized() { + let server = MockServer::start_async().await; + server.mock_async(|when, then| { + when.path("/refuse_token_with_400"); + + then.status(400) + .json_body(serde_json::json!({ + "error": "unauthorized", + "error_description": "The token provided was malformed" + })); + }).await; + + let filter = require_token(server.url("/refuse_token_with_400"), get_http_client()); + + let res = warp::test::request() + .path("/") + .header("Authorization", "Bearer token") + .filter(&filter) + .await + .unwrap_err(); + + let err: &IndieAuthError = res.find().unwrap(); + assert_eq!(err.kind, super::ErrorKind::NotAuthorized); + } } |