about summary refs log tree commit diff
diff options
context:
space:
mode:
author: Vika <vika@fireburn.ru> 2021-05-05 22:40:01 +0300
committer: Vika <vika@fireburn.ru> 2021-05-05 22:40:01 +0300
commit: cbbfca9af1f0aa9da87709f99353fd76fd6617a8 (patch)
tree: 232f4bfc8682255195dc1a2278e2830db21dd7bb
parent: dd9d3ff3e9505926e72df7df679cefe960be23bd (diff)
download: kittybox-cbbfca9af1f0aa9da87709f99353fd76fd6617a8.tar.zst
Added rudimentary caching to IndieAuth middleware
-rw-r--r-- Cargo.lock 24
-rw-r--r-- Cargo.toml 15
-rw-r--r-- src/indieauth.rs 110
-rw-r--r-- src/lib.rs 4
4 files changed, 116 insertions, 37 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 87d7ce7..9b165db 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -296,6 +296,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e91831deabf0d6d7ec49552e489aed63b7456a7a3c46cff62adad428110b0af0"
 
 [[package]]
+name = "async-timer"
+version = "0.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba5fa6ed76cb2aa820707b4eb9ec46f42da9ce70b0eafab5e5e34942b38a44d5"
+dependencies = [
+ "libc",
+ "wasm-bindgen",
+ "winapi",
+]
+
+[[package]]
 name = "async-trait"
 version = "0.1.50"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1223,6 +1234,7 @@ dependencies = [
  "mobc-redis",
  "mockito",
  "newbase60",
+ "retainer",
  "serde",
  "serde_json",
  "serde_urlencoded",
@@ -1846,6 +1858,18 @@ dependencies = [
 ]
 
 [[package]]
+name = "retainer"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59039dbf4a344af919780e9acdf7f9ce95deffb0152a72eca94b89d6a2bf66c0"
+dependencies = [
+ "async-lock",
+ "async-timer",
+ "log",
+ "rand 0.8.3",
+]
+
+[[package]]
 name = "route-recognizer"
 version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 6ce2cbe..f3b7433 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,8 +12,6 @@ mockito = "0.30.0"          # HTTP mocking for Rust.
 tempdir = "0.3.7"           # A library for managing a temporary directory and deleting all contents when it's dropped.
 
 [dependencies]
-# Redis driver for Rust.
-#redis = { version = "0.20.0", features = ["aio", "async-std-comp"] }
 # Redis support for the mobc connection pool
 mobc-redis = { version = "0.7.0", features = ["async-std-comp"], default-features = false }
 # A generic serialization/deserialization framework
@@ -24,18 +22,19 @@ chrono = { version = "0.4.19", features = ["serde"] }
 url = { version = "2.2.1", features = ["serde"] }
 # Async version of the Rust standard library
 async-std = { version = "1.9.0", features = ["attributes"] }
-lazy_static = "1.4.0"       # A macro for declaring lazily evaluated statics in Rust.
 async-trait = "0.1.50"      # Type erasure for async trait methods
+easy-scraper = "0.2.0"      # HTML scraping library focused on ease of use
 env_logger = "0.8.3"        # A logging implementation for `log` which is configured via an environment variable.
 futures = "0.3.14"          # An implementation of futures and streams
 futures-util = "0.3.14"     # Common utilities and extension traits for the futures-rs library.
 http-types = "2.11.0"       # Common types for HTTP operations.
+lazy_static = "1.4.0"       # A macro for declaring lazily evaluated statics in Rust.
 log = "0.4.14"              # A lightweight logging facade for Rust
+markdown = "0.3.0"          # Native Rust library for parsing Markdown and (outputting HTML)
+mobc = "0.7.2"              # A generic connection pool with async/await support
+newbase60 = "0.1.3"         # A library that implements Tantek Çelik's New Base 60
+retainer = "0.2.2"          # Minimal async cache in Rust with support for key expirations
 serde_json = "1.0.64"       # A JSON serialization file format
+serde_urlencoded = "0.7.0"  # `x-www-form-urlencoded` meets Serde
 surf = "2.2.0"              # Surf the web - HTTP client framework
 tide = "0.16.0"             # A minimal and pragmatic Rust web application framework built for rapid development
-newbase60 = "0.1.3"         # A library that implements Tantek Çelik's New Base 60
-markdown = "0.3.0"          # Native Rust library for parsing Markdown and (outputting HTML)
-easy-scraper = "0.2.0"      # HTML scraping library focused on ease of use
-serde_urlencoded = "0.7.0"  # `x-www-form-urlencoded` meets Serde
-mobc = "0.7.2"              # A generic connection pool with async/await support
diff --git a/src/indieauth.rs b/src/indieauth.rs
index 46d2459..4fac5e1 100644
--- a/src/indieauth.rs
+++ b/src/indieauth.rs
@@ -1,14 +1,14 @@
+use async_trait::async_trait;
 use log::{error,info};
-use std::future::Future;
-use std::pin::Pin;
 use url::Url;
 use tide::prelude::*;
 use tide::{Request, Response, Next, Result};
+use std::sync::Arc;
 
 use crate::database;
 use crate::ApplicationState;
 
-#[derive(Deserialize, Serialize, Debug, PartialEq)]
+#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
 pub struct User {
     pub me: Url,
     pub client_id: Url,
@@ -58,15 +58,55 @@ async fn get_token_data(token: String, token_endpoint: &http_types::Url, http_cl
     }
 }
 
-// TODO: Figure out how to cache these authorization values - they can potentially take a lot of processing time
-pub fn check_auth<'a, Backend>(mut req: Request<ApplicationState<Backend>>, next: Next<'a, ApplicationState<Backend>>) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>>
-where
-    Backend: database::Storage + Send + Sync + Clone
+pub struct IndieAuthMiddleware {
+    cache: Arc<retainer::Cache<String, User>>, // verified users, keyed by the `Authorization` header value
+    monitor_task: Option<async_std::task::JoinHandle<()>> // handle of the task spawned in `new`; taken out again in `drop`
+}
+impl IndieAuthMiddleware {
+    /// Create a new instance of IndieAuthMiddleware with a fresh cache.
+    ///
+    /// Every call also spawns a background task that runs the cache's
+    /// `monitor` loop to garbage-collect stale cache entries, so construct
+    /// one instance and reuse it rather than creating them freely.
+    pub fn new() -> Self {
+        let cache: Arc<retainer::Cache<String, User>> = Arc::new(retainer::Cache::new());
+        let cache_clone = cache.clone(); // second handle to the same cache, moved into the monitor task
+        let task = async_std::task::spawn(async move { cache_clone.monitor(4, 0.1, std::time::Duration::from_secs(30)).await }); // NOTE(review): arguments are presumably (sample, threshold, frequency) - confirm against retainer docs
+        Self { cache, monitor_task: Some(task) }
+    }
+}
+impl Drop for IndieAuthMiddleware {
+    /// Stops the cache-monitoring task when the middleware goes away.
+    fn drop(&mut self) {
+        // The monitor task owns a clone of the `Arc` around the cache, so the
+        // cache cannot be freed while that task is still alive. Cancelling the
+        // task here releases that last reference.
+        //
+        // `Option::take` moves the handle out of `self` and leaves `None`
+        // behind; `drop` only has `&mut self`, so the handle cannot be moved
+        // out directly. A missing handle is ignored rather than panicking,
+        // because a panic inside `drop` during unwinding aborts the process.
+        if let Some(task) = self.monitor_task.take() {
+            // `drop` is synchronous, so the cancellation runs in a helper task.
+            // `JoinHandle::cancel` returns a future that does nothing until it
+            // is polled: without `.await` the handle would merely be dropped,
+            // which detaches the monitor task and leaves it running forever.
+            async_std::task::spawn(async move {
+                task.cancel().await;
+            });
+        }
+    }
+}
+#[async_trait]
+impl<B> tide::Middleware<ApplicationState<B>> for IndieAuthMiddleware where
+    B: database::Storage + Send + Sync + Clone
 {
-    Box::pin(async {
+    async fn handle(&self, mut req: Request<ApplicationState<B>>, next: Next<'_, ApplicationState<B>>) -> Result {
         let header = req.header("Authorization");
         match header {
             None => {
+                // TODO: move that to the request handling functions
+                // or make a middleware that refuses to accept unauthenticated requests
                 Ok(Response::builder(401).body(json!({
                     "error": "unauthorized",
                     "error_description": "Please provide an access token."
@@ -75,31 +115,47 @@ where
             Some(value) => {
                 let endpoint = &req.state().token_endpoint;
                 let http_client = &req.state().http_client;
-                match get_token_data(value.last().to_string(), endpoint, http_client).await {
-                    (http_types::StatusCode::Ok, Some(user)) => {
-                        req.set_ext(user);
+                let token = value.last().to_string();
+                match self.cache.get(&token).await {
+                    Some(user) => {
+                        req.set_ext::<User>(user.clone());
                         Ok(next.run(req).await)
                     },
-                    (http_types::StatusCode::InternalServerError, None) => {
-                        Ok(Response::builder(500).body(json!({
-                            "error": "token_endpoint_fail",
-                            "error_description": "Token endpoint made a boo-boo and refused to answer."
-                        })).build())
-                    },
-                    (_, None) => {
-                        Ok(Response::builder(401).body(json!({
-                            "error": "unauthorized",
-                            "error_description": "The token endpoint refused to accept your token."
-                        })).build())
-                    },
-                    (_, Some(_)) => {
-                        // This shouldn't happen.
-                        panic!("The token validation function has caught rabies and returns malformed responses. Aborting.");
+                    None => match get_token_data(value.last().to_string(), endpoint, http_client).await {
+                        (http_types::StatusCode::Ok, Some(user)) => {
+                            // Note that this can run multiple requests before the value appears in the cache.
+                            // This seems to be in line with some other implementations of a function cache
+                            // (e.g. the [`cached`](https://lib.rs/crates/cached) crate and Python's `functools.lru_cache`)
+                            //
+                            // TODO: ensure the duration is no more than the token's remaining time until expiration
+                            // (in case the expiration time is defined on the token - AFAIK currently non-standard in IndieAuth)
+                            self.cache.insert(token, user.clone(), std::time::Duration::from_secs(600)).await;
+                            req.set_ext(user);
+                            Ok(next.run(req).await)
+                        },
+                        // TODO: Refactor to return Err(IndieAuthError) so downstream middleware could catch it
+                        // and present a prettier interface to the error (maybe even hiding data from the user)
+                        (http_types::StatusCode::InternalServerError, None) => {
+                            Ok(Response::builder(500).body(json!({
+                                "error": "token_endpoint_fail",
+                                "error_description": "Token endpoint made a boo-boo and refused to answer."
+                            })).build())
+                        },
+                        (_, None) => {
+                            Ok(Response::builder(401).body(json!({
+                                "error": "unauthorized",
+                                "error_description": "The token endpoint refused to accept your token."
+                            })).build())
+                        },
+                        (_, Some(_)) => {
+                            // This shouldn't happen.
+                            panic!("The token validation function has caught rabies and returns malformed responses. Aborting.");
+                        }
                     }
                 }
             }
         }
-    })
+    }
 }
 
 #[cfg(test)]
diff --git a/src/lib.rs b/src/lib.rs
index 86376d4..c422fea 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,7 +8,7 @@ mod database;
 mod indieauth;
 mod micropub;
 
-use crate::indieauth::check_auth;
+use crate::indieauth::IndieAuthMiddleware;
 use crate::micropub::{get_handler,post_handler};
 
 #[derive(Clone)]
@@ -39,7 +39,7 @@ fn equip_app<Storage>(mut app: App<Storage>) -> App<Storage>
 where
     Storage: database::Storage + Send + Sync + Clone
 {
-    app.at("/micropub").with(check_auth).get(get_handler).post(post_handler);
+    app.at("/micropub").with(IndieAuthMiddleware::new()).get(get_handler).post(post_handler);
     // The Micropub client. It'll start small, but could grow into something full-featured!
     app.at("/micropub/client").get(|_: Request<_>| async move {
         Ok(Response::builder(200).body(MICROPUB_CLIENT).content_type("text/html").build())