diff options
author | Vika <vika@fireburn.ru> | 2021-05-05 22:40:01 +0300 |
---|---|---|
committer | Vika <vika@fireburn.ru> | 2021-05-05 22:40:01 +0300 |
commit | cbbfca9af1f0aa9da87709f99353fd76fd6617a8 (patch) | |
tree | 232f4bfc8682255195dc1a2278e2830db21dd7bb | |
parent | dd9d3ff3e9505926e72df7df679cefe960be23bd (diff) | |
download | kittybox-cbbfca9af1f0aa9da87709f99353fd76fd6617a8.tar.zst |
Added rudimentary caching to IndieAuth middleware
-rw-r--r-- | Cargo.lock | 24 | ||||
-rw-r--r-- | Cargo.toml | 15 | ||||
-rw-r--r-- | src/indieauth.rs | 110 | ||||
-rw-r--r-- | src/lib.rs | 4 |
4 files changed, 116 insertions, 37 deletions
diff --git a/Cargo.lock b/Cargo.lock index 87d7ce7..9b165db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -296,6 +296,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91831deabf0d6d7ec49552e489aed63b7456a7a3c46cff62adad428110b0af0" [[package]] +name = "async-timer" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5fa6ed76cb2aa820707b4eb9ec46f42da9ce70b0eafab5e5e34942b38a44d5" +dependencies = [ + "libc", + "wasm-bindgen", + "winapi", +] + +[[package]] name = "async-trait" version = "0.1.50" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1223,6 +1234,7 @@ dependencies = [ "mobc-redis", "mockito", "newbase60", + "retainer", "serde", "serde_json", "serde_urlencoded", @@ -1846,6 +1858,18 @@ dependencies = [ ] [[package]] +name = "retainer" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59039dbf4a344af919780e9acdf7f9ce95deffb0152a72eca94b89d6a2bf66c0" +dependencies = [ + "async-lock", + "async-timer", + "log", + "rand 0.8.3", +] + +[[package]] name = "route-recognizer" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/Cargo.toml b/Cargo.toml index 6ce2cbe..f3b7433 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,6 @@ mockito = "0.30.0" # HTTP mocking for Rust. tempdir = "0.3.7" # A library for managing a temporary directory and deleting all contents when it's dropped. [dependencies] -# Redis driver for Rust. -#redis = { version = "0.20.0", features = ["aio", "async-std-comp"] } # Redis support for the mobc connection pool mobc-redis = { version = "0.7.0", features = ["async-std-comp"], default-features = false } # A generic serialization/deserialization framework @@ -24,18 +22,19 @@ chrono = { version = "0.4.19", features = ["serde"] } url = { version = "2.2.1", features = ["serde"] } # Async version of the Rust standard library async-std = { version = "1.9.0", features = ["attributes"] } -lazy_static = "1.4.0" # A macro for declaring lazily evaluated statics in Rust. async-trait = "0.1.50" # Type erasure for async trait methods +easy-scraper = "0.2.0" # HTML scraping library focused on ease of use env_logger = "0.8.3" # A logging implementation for `log` which is configured via an environment variable. futures = "0.3.14" # An implementation of futures and streams futures-util = "0.3.14" # Common utilities and extension traits for the futures-rs library. http-types = "2.11.0" # Common types for HTTP operations. +lazy_static = "1.4.0" # A macro for declaring lazily evaluated statics in Rust. log = "0.4.14" # A lightweight logging facade for Rust +markdown = "0.3.0" # Native Rust library for parsing Markdown and (outputting HTML) +mobc = "0.7.2" # A generic connection pool with async/await support +newbase60 = "0.1.3" # A library that implements Tantek Çelik's New Base 60 +retainer = "0.2.2" # Minimal async cache in Rust with support for key expirations serde_json = "1.0.64" # A JSON serialization file format +serde_urlencoded = "0.7.0" # `x-www-form-urlencoded` meets Serde surf = "2.2.0" # Surf the web - HTTP client framework tide = "0.16.0" # A minimal and pragmatic Rust web application framework built for rapid development -newbase60 = "0.1.3" # A library that implements Tantek Çelik's New Base 60 -markdown = "0.3.0" # Native Rust library for parsing Markdown and (outputting HTML) -easy-scraper = "0.2.0" # HTML scraping library focused on ease of use -serde_urlencoded = "0.7.0" # `x-www-form-urlencoded` meets Serde -mobc = "0.7.2" # A generic connection pool with async/await support diff --git a/src/indieauth.rs b/src/indieauth.rs index 46d2459..4fac5e1 100644 --- a/src/indieauth.rs +++ b/src/indieauth.rs @@ -1,14 +1,14 @@ +use async_trait::async_trait; use log::{error,info}; -use std::future::Future; -use std::pin::Pin; use url::Url; use tide::prelude::*; use tide::{Request, Response, Next, Result}; +use std::sync::Arc; use crate::database; use crate::ApplicationState; -#[derive(Deserialize, Serialize, Debug, PartialEq)] +#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct User { pub me: Url, pub client_id: Url, @@ -58,15 +58,55 @@ async fn get_token_data(token: String, token_endpoint: &http_types::Url, http_cl } } -// TODO: Figure out how to cache these authorization values - they can potentially take a lot of processing time -pub fn check_auth<'a, Backend>(mut req: Request<ApplicationState<Backend>>, next: Next<'a, ApplicationState<Backend>>) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> -where - Backend: database::Storage + Send + Sync + Clone +pub struct IndieAuthMiddleware { + cache: Arc<retainer::Cache<String, User>>, + monitor_task: Option<async_std::task::JoinHandle<()>> +} +impl IndieAuthMiddleware { + /// Create a new instance of IndieAuthMiddleware. + /// + /// Note that creating a new instance automatically launches a task + /// to garbage-collect stale cache entries. Please do not create + /// instances willy-nilly because of that. + pub fn new() -> Self { + let cache: Arc<retainer::Cache<String, User>> = Arc::new(retainer::Cache::new()); + let cache_clone = cache.clone(); + let task = async_std::task::spawn(async move { cache_clone.monitor(4, 0.1, std::time::Duration::from_secs(30)).await }); + Self { cache, monitor_task: Some(task) } + } +} +impl Drop for IndieAuthMiddleware { + fn drop(&mut self) { + // Cancel the task, or a VERY FUNNY thing might occur. + // If I understand this correctly, keeping a task active + // WILL keep an active reference to a value, so I'm pretty sure + // that something VERY FUNNY might occur whenever `cache` is dropped + // and its related task is not cancelled. So let's cancel it so + // [`cache`] can be dropped once and for all. + + // First, get the ownership of a task, sneakily switching it out with None + // (wow, this is sneaky, didn't know Safe Rust could even do that!!!) + // (it is safe tho cuz None is no nullptr and dereferencing it doesn't cause unsafety) + // (could cause a VERY FUNNY race condition to occur though + // if you tried to refer to the value in another thread!) + let task = std::mem::take(&mut self.monitor_task).expect("Dropped IndieAuthMiddleware TWICE? Impossible!"); + // Then cancel the task, using another task to request cancellation. + // Because apparently you can't run async code from Drop... + // This should drop the last reference for the [`cache`], + // allowing it to be dropped. + async_std::task::spawn(async move { task.cancel() }); + } +} +#[async_trait] +impl<B> tide::Middleware<ApplicationState<B>> for IndieAuthMiddleware where + B: database::Storage + Send + Sync + Clone { - Box::pin(async { + async fn handle(&self, mut req: Request<ApplicationState<B>>, next: Next<'_, ApplicationState<B>>) -> Result { let header = req.header("Authorization"); match header { None => { + // TODO: move that to the request handling functions + // or make a middleware that refuses to accept unauthenticated requests Ok(Response::builder(401).body(json!({ "error": "unauthorized", "error_description": "Please provide an access token." @@ -75,31 +115,47 @@ where Some(value) => { let endpoint = &req.state().token_endpoint; let http_client = &req.state().http_client; - match get_token_data(value.last().to_string(), endpoint, http_client).await { - (http_types::StatusCode::Ok, Some(user)) => { - req.set_ext(user); + let token = value.last().to_string(); + match self.cache.get(&token).await { + Some(user) => { + req.set_ext::<User>(user.clone()); Ok(next.run(req).await) }, - (http_types::StatusCode::InternalServerError, None) => { - Ok(Response::builder(500).body(json!({ - "error": "token_endpoint_fail", - "error_description": "Token endpoint made a boo-boo and refused to answer." - })).build()) - }, - (_, None) => { - Ok(Response::builder(401).body(json!({ - "error": "unauthorized", - "error_description": "The token endpoint refused to accept your token." - })).build()) - }, - (_, Some(_)) => { - // This shouldn't happen. - panic!("The token validation function has caught rabies and returns malformed responses. Aborting."); + None => match get_token_data(value.last().to_string(), endpoint, http_client).await { + (http_types::StatusCode::Ok, Some(user)) => { + // Note that this can run multiple requests before the value appears in the cache. + // This seems to be in line with some other implementations of a function cache + // (e.g. the [`cached`](https://lib.rs/crates/cached) crate and Python's `functools.lru_cache`) + // + // TODO: ensure the duration is no more than the token's remaining time until expiration + // (in case the expiration time is defined on the token - AFAIK currently non-standard in IndieAuth) + self.cache.insert(token, user.clone(), std::time::Duration::from_secs(600)).await; + req.set_ext(user); + Ok(next.run(req).await) + }, + // TODO: Refactor to return Err(IndieAuthError) so downstream middleware could catch it + // and present a prettier interface to the error (maybe even hiding data from the user) + (http_types::StatusCode::InternalServerError, None) => { + Ok(Response::builder(500).body(json!({ + "error": "token_endpoint_fail", + "error_description": "Token endpoint made a boo-boo and refused to answer." + })).build()) + }, + (_, None) => { + Ok(Response::builder(401).body(json!({ + "error": "unauthorized", + "error_description": "The token endpoint refused to accept your token." + })).build()) + }, + (_, Some(_)) => { + // This shouldn't happen. + panic!("The token validation function has caught rabies and returns malformed responses. Aborting."); + } } } } } - }) + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 86376d4..c422fea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ mod database; mod indieauth; mod micropub; -use crate::indieauth::check_auth; +use crate::indieauth::IndieAuthMiddleware; use crate::micropub::{get_handler,post_handler}; #[derive(Clone)] @@ -39,7 +39,7 @@ fn equip_app<Storage>(mut app: App<Storage>) -> App<Storage> where Storage: database::Storage + Send + Sync + Clone { - app.at("/micropub").with(check_auth).get(get_handler).post(post_handler); + app.at("/micropub").with(IndieAuthMiddleware::new()).get(get_handler).post(post_handler); // The Micropub client. It'll start small, but could grow into something full-featured! app.at("/micropub/client").get(|_: Request<_>| async move { Ok(Response::builder(200).body(MICROPUB_CLIENT).content_type("text/html").build()) |