diff options
Diffstat (limited to 'src/indieauth.rs')
-rw-r--r-- | src/indieauth.rs | 110 |
1 files changed, 83 insertions, 27 deletions
diff --git a/src/indieauth.rs b/src/indieauth.rs index 46d2459..4fac5e1 100644 --- a/src/indieauth.rs +++ b/src/indieauth.rs @@ -1,14 +1,14 @@ +use async_trait::async_trait; use log::{error,info}; -use std::future::Future; -use std::pin::Pin; use url::Url; use tide::prelude::*; use tide::{Request, Response, Next, Result}; +use std::sync::Arc; use crate::database; use crate::ApplicationState; -#[derive(Deserialize, Serialize, Debug, PartialEq)] +#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { pub me: Url, pub client_id: Url, @@ -58,15 +58,55 @@ async fn get_token_data(token: String, token_endpoint: &http_types::Url, http_cl } } -// TODO: Figure out how to cache these authorization values - they can potentially take a lot of processing time -pub fn check_auth<'a, Backend>(mut req: Request<ApplicationState<Backend>>, next: Next<'a, ApplicationState<Backend>>) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> -where - Backend: database::Storage + Send + Sync + Clone +pub struct IndieAuthMiddleware { + cache: Arc<retainer::Cache<String, User>>, + monitor_task: Option<async_std::task::JoinHandle<()>> +} +impl IndieAuthMiddleware { + /// Create a new instance of IndieAuthMiddleware. + /// + /// Note that creating a new instance automatically launches a task + /// to garbage-collect stale cache entries. Please do not create + /// instances willy-nilly because of that. + pub fn new() -> Self { + let cache: Arc<retainer::Cache<String, User>> = Arc::new(retainer::Cache::new()); + let cache_clone = cache.clone(); + let task = async_std::task::spawn(async move { cache_clone.monitor(4, 0.1, std::time::Duration::from_secs(30)).await }); + Self { cache, monitor_task: Some(task) } + } +} +impl Drop for IndieAuthMiddleware { + fn drop(&mut self) { + // Cancel the task, or a VERY FUNNY thing might occur. + // If I understand this correctly, keeping a task active + // WILL keep an active reference to a value, so I'm pretty sure + // that something VERY FUNNY might occur whenever `cache` is dropped + // and its related task is not cancelled. So let's cancel it so + // [`cache`] can be dropped once and for all. + + // First, get the ownership of a task, sneakily switching it out with None + // (wow, this is sneaky, didn't know Safe Rust could even do that!!!) + // (it is safe tho cuz None is no nullptr and dereferencing it doesn't cause unsafety) + // (could cause a VERY FUNNY race condition to occur though + // if you tried to refer to the value in another thread!) + let task = std::mem::take(&mut self.monitor_task).expect("Dropped IndieAuthMiddleware TWICE? Impossible!"); + // Then cancel the task, using another task to request cancellation. + // Because apparently you can't run async code from Drop... + // This should drop the last reference for the [`cache`], + // allowing it to be dropped. + async_std::task::spawn(async move { task.cancel() }); + } +} +#[async_trait] +impl<B> tide::Middleware<ApplicationState<B>> for IndieAuthMiddleware where + B: database::Storage + Send + Sync + Clone { - Box::pin(async { + async fn handle(&self, mut req: Request<ApplicationState<B>>, next: Next<'_, ApplicationState<B>>) -> Result { let header = req.header("Authorization"); match header { None => { + // TODO: move that to the request handling functions + // or make a middleware that refuses to accept unauthenticated requests Ok(Response::builder(401).body(json!({ "error": "unauthorized", "error_description": "Please provide an access token." @@ -75,31 +115,47 @@ where Some(value) => { let endpoint = &req.state().token_endpoint; let http_client = &req.state().http_client; - match get_token_data(value.last().to_string(), endpoint, http_client).await { - (http_types::StatusCode::Ok, Some(user)) => { - req.set_ext(user); + let token = value.last().to_string(); + match self.cache.get(&token).await { + Some(user) => { + req.set_ext::<User>(user.clone()); Ok(next.run(req).await) }, - (http_types::StatusCode::InternalServerError, None) => { - Ok(Response::builder(500).body(json!({ - "error": "token_endpoint_fail", - "error_description": "Token endpoint made a boo-boo and refused to answer." - })).build()) - }, - (_, None) => { - Ok(Response::builder(401).body(json!({ - "error": "unauthorized", - "error_description": "The token endpoint refused to accept your token." - })).build()) - }, - (_, Some(_)) => { - // This shouldn't happen. - panic!("The token validation function has caught rabies and returns malformed responses. Aborting."); + None => match get_token_data(value.last().to_string(), endpoint, http_client).await { + (http_types::StatusCode::Ok, Some(user)) => { + // Note that this can run multiple requests before the value appears in the cache. + // This seems to be in line with some other implementations of a function cache + // (e.g. the [`cached`](https://lib.rs/crates/cached) crate and Python's `functools.lru_cache`) + // + // TODO: ensure the duration is no more than the token's remaining time until expiration + // (in case the expiration time is defined on the token - AFAIK currently non-standard in IndieAuth) + self.cache.insert(token, user.clone(), std::time::Duration::from_secs(600)).await; + req.set_ext(user); + Ok(next.run(req).await) + }, + // TODO: Refactor to return Err(IndieAuthError) so downstream middleware could catch it + // and present a prettier interface to the error (maybe even hiding data from the user) + (http_types::StatusCode::InternalServerError, None) => { + Ok(Response::builder(500).body(json!({ + "error": "token_endpoint_fail", + "error_description": "Token endpoint made a boo-boo and refused to answer." + })).build()) + }, + (_, None) => { + Ok(Response::builder(401).body(json!({ + "error": "unauthorized", + "error_description": "The token endpoint refused to accept your token." + })).build()) + }, + (_, Some(_)) => { + // This shouldn't happen. + panic!("The token validation function has caught rabies and returns malformed responses. Aborting."); + } } } } } - }) + } } #[cfg(test)] |