about summary refs log tree commit diff
path: root/src/indieauth.rs
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2021-05-05 22:40:01 +0300
committerVika <vika@fireburn.ru>2021-05-05 22:40:01 +0300
commitcbbfca9af1f0aa9da87709f99353fd76fd6617a8 (patch)
tree232f4bfc8682255195dc1a2278e2830db21dd7bb /src/indieauth.rs
parentdd9d3ff3e9505926e72df7df679cefe960be23bd (diff)
downloadkittybox-cbbfca9af1f0aa9da87709f99353fd76fd6617a8.tar.zst
Added rudimentary caching to IndieAuth middleware
Diffstat (limited to 'src/indieauth.rs')
-rw-r--r--src/indieauth.rs110
1 files changed, 83 insertions, 27 deletions
diff --git a/src/indieauth.rs b/src/indieauth.rs
index 46d2459..4fac5e1 100644
--- a/src/indieauth.rs
+++ b/src/indieauth.rs
@@ -1,14 +1,14 @@
+use async_trait::async_trait;
 use log::{error,info};
-use std::future::Future;
-use std::pin::Pin;
 use url::Url;
 use tide::prelude::*;
 use tide::{Request, Response, Next, Result};
+use std::sync::Arc;
 
 use crate::database;
 use crate::ApplicationState;
 
-#[derive(Deserialize, Serialize, Debug, PartialEq)]
+#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
 pub struct User {
     pub me: Url,
     pub client_id: Url,
@@ -58,15 +58,55 @@ async fn get_token_data(token: String, token_endpoint: &http_types::Url, http_cl
     }
 }
 
-// TODO: Figure out how to cache these authorization values - they can potentially take a lot of processing time
-pub fn check_auth<'a, Backend>(mut req: Request<ApplicationState<Backend>>, next: Next<'a, ApplicationState<Backend>>) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>>
-where
-    Backend: database::Storage + Send + Sync + Clone
+pub struct IndieAuthMiddleware {
+    cache: Arc<retainer::Cache<String, User>>,
+    monitor_task: Option<async_std::task::JoinHandle<()>>
+}
+impl IndieAuthMiddleware {
+    /// Create a new instance of IndieAuthMiddleware.
+    ///
+    /// Note that creating a new instance automatically launches a task
+    /// to garbage-collect stale cache entries. Please do not create
+    /// instances willy-nilly because of that.
+    pub fn new() -> Self {
+        let cache: Arc<retainer::Cache<String, User>> = Arc::new(retainer::Cache::new());
+        let cache_clone = cache.clone();
+        let task = async_std::task::spawn(async move { cache_clone.monitor(4, 0.1, std::time::Duration::from_secs(30)).await });
+        Self { cache, monitor_task: Some(task) }
+    }
+}
+impl Drop for IndieAuthMiddleware {
+    fn drop(&mut self) {
+        // Cancel the task, or a VERY FUNNY thing might occur.
+        // If I understand this correctly, keeping a task active
+        // WILL keep an active reference to a value, so I'm pretty sure
+        // that something VERY FUNNY might occur whenever `cache` is dropped
+        // and its related task is not cancelled. So let's cancel it so
+        // [`cache`] can be dropped once and for all.
+
+        // First, get the ownership of a task, sneakily switching it out with None
+        // (wow, this is sneaky, didn't know Safe Rust could even do that!!!)
+        // (it is safe tho cuz None is no nullptr and dereferencing it doesn't cause unsafety)
+        // (could cause a VERY FUNNY race condition to occur though
+        //  if you tried to refer to the value in another thread!)
+        let task = std::mem::take(&mut self.monitor_task).expect("Dropped IndieAuthMiddleware TWICE? Impossible!");
+        // Then cancel the task, using another task to request cancellation.
+        // Because apparently you can't run async code from Drop...
+        // This should drop the last reference for the [`cache`],
+        // allowing it to be dropped.
+        async_std::task::spawn(async move { task.cancel() });
+    }
+}
+#[async_trait]
+impl<B> tide::Middleware<ApplicationState<B>> for IndieAuthMiddleware where
+    B: database::Storage + Send + Sync + Clone
 {
-    Box::pin(async {
+    async fn handle(&self, mut req: Request<ApplicationState<B>>, next: Next<'_, ApplicationState<B>>) -> Result {
         let header = req.header("Authorization");
         match header {
             None => {
+                // TODO: move that to the request handling functions
+                // or make a middleware that refuses to accept unauthenticated requests
                 Ok(Response::builder(401).body(json!({
                     "error": "unauthorized",
                     "error_description": "Please provide an access token."
@@ -75,31 +115,47 @@ where
             Some(value) => {
                 let endpoint = &req.state().token_endpoint;
                 let http_client = &req.state().http_client;
-                match get_token_data(value.last().to_string(), endpoint, http_client).await {
-                    (http_types::StatusCode::Ok, Some(user)) => {
-                        req.set_ext(user);
+                let token = value.last().to_string();
+                match self.cache.get(&token).await {
+                    Some(user) => {
+                        req.set_ext::<User>(user.clone());
                         Ok(next.run(req).await)
                     },
-                    (http_types::StatusCode::InternalServerError, None) => {
-                        Ok(Response::builder(500).body(json!({
-                            "error": "token_endpoint_fail",
-                            "error_description": "Token endpoint made a boo-boo and refused to answer."
-                        })).build())
-                    },
-                    (_, None) => {
-                        Ok(Response::builder(401).body(json!({
-                            "error": "unauthorized",
-                            "error_description": "The token endpoint refused to accept your token."
-                        })).build())
-                    },
-                    (_, Some(_)) => {
-                        // This shouldn't happen.
-                        panic!("The token validation function has caught rabies and returns malformed responses. Aborting.");
+                    None => match get_token_data(value.last().to_string(), endpoint, http_client).await {
+                        (http_types::StatusCode::Ok, Some(user)) => {
+                            // Note that this can run multiple requests before the value appears in the cache.
+                            // This seems to be in line with some other implementations of a function cache
+                            // (e.g. the [`cached`](https://lib.rs/crates/cached) crate and Python's `functools.lru_cache`)
+                            //
+                            // TODO: ensure the duration is no more than the token's remaining time until expiration
+                            // (in case the expiration time is defined on the token - AFAIK currently non-standard in IndieAuth)
+                            self.cache.insert(token, user.clone(), std::time::Duration::from_secs(600)).await;
+                            req.set_ext(user);
+                            Ok(next.run(req).await)
+                        },
+                        // TODO: Refactor to return Err(IndieAuthError) so downstream middleware could catch it
+                        // and present a prettier interface to the error (maybe even hiding data from the user)
+                        (http_types::StatusCode::InternalServerError, None) => {
+                            Ok(Response::builder(500).body(json!({
+                                "error": "token_endpoint_fail",
+                                "error_description": "Token endpoint made a boo-boo and refused to answer."
+                            })).build())
+                        },
+                        (_, None) => {
+                            Ok(Response::builder(401).body(json!({
+                                "error": "unauthorized",
+                                "error_description": "The token endpoint refused to accept your token."
+                            })).build())
+                        },
+                        (_, Some(_)) => {
+                            // This shouldn't happen.
+                            panic!("The token validation function has caught rabies and returns malformed responses. Aborting.");
+                        }
                     }
                 }
             }
         }
-    })
+    }
 }
 
 #[cfg(test)]