about summary refs log blame commit diff
path: root/src/indieauth.rs
blob: 7a2a07ec43d8ddd7d4dacd95a529ad80b0282ba7 (plain) (tree)
1
2
3
4
5
6
7
8
                             
                        
                       
                     
                        
                                            


                            
                                                          

                       
                  







                                                                


                                                                
                                     


         
                                        










                                             


                                                      


                                                                                    
                                                   
                     
                                 



                                                                                













                                                                             
                                
                                                                                   
                                              
                                                          








                                                                                         



                                                                    


                                                                                                                                                                                                


                                     














                                                                                              
                                                                      


                                                                            
                                                                   

              

                                                                     
 
                                            








                                                  

                                            



                                              

                                                 
                                                                                       





                                                                              
                            
                                                           


                                                          
                                               




























                                                                                                                                 


                     
     





                                    



                                    


                                             
 
use async_trait::async_trait;
#[allow(unused_imports)]
use log::{error, info};
use std::sync::Arc;
use tide::prelude::*;
#[allow(unused_imports)]
use tide::{Next, Request, Response, Result};
use url::Url;

use crate::database;
use crate::ApplicationState;

/// A user whose access token was validated by the token endpoint.
///
/// Instances are deserialized from the token endpoint's JSON reply and then
/// attached to the request as an extension, so handlers can check what the
/// caller is allowed to do.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct User {
    /// Identity URL reported by the token endpoint for the validated user.
    pub me: Url,
    /// Client identifier URL reported by the token endpoint.
    pub client_id: Url,
    /// Granted scopes as a single whitespace-separated string.
    scope: String,
}

impl User {
    /// Returns `true` when `scope` appears as a whole word in the granted scopes.
    pub fn check_scope(&self, scope: &str) -> bool {
        self.scopes().any(|granted| scope == granted)
    }

    /// Iterates over the granted scopes, split on ASCII whitespace.
    pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> {
        self.scope.as_str().split_ascii_whitespace()
    }

    /// Builds a `User` from string parts.
    ///
    /// # Panics
    ///
    /// Panics if `me` or `client_id` cannot be parsed by [`Url::parse`]
    /// (for example, a relative URL).
    pub fn new(me: &str, client_id: &str, scope: &str) -> Self {
        let me = Url::parse(me).unwrap();
        let client_id = Url::parse(client_id).unwrap();
        Self {
            me,
            client_id,
            scope: String::from(scope),
        }
    }
}

/// Asks the token endpoint which user `token` belongs to.
///
/// `token` is sent verbatim as the `Authorization` header, and JSON is
/// requested via `Accept`.
///
/// Return value:
/// - `(200, Some(user))` when the endpoint answered 200 and its body
///   deserialized into a [`User`];
/// - `(InternalServerError, None)` on a connection error or an unparseable
///   200 body;
/// - `(status, None)` with the endpoint's own status for any other non-200 reply.
#[cfg(any(not(debug_assertions), test))]
async fn get_token_data(
    token: String,
    token_endpoint: &http_types::Url,
    http_client: &surf::Client,
) -> (http_types::StatusCode, Option<User>) {
    let request = http_client
        .get(token_endpoint)
        .header("Authorization", token)
        .header("Accept", "application/json");

    let mut resp = match request.send().await {
        Ok(resp) => resp,
        Err(err) => {
            error!("Token endpoint connection error: {}", err);
            return (http_types::StatusCode::InternalServerError, None);
        }
    };

    let status = resp.status();
    if status != 200 {
        error!("Token endpoint returned non-200: {}", status);
        return (status, None);
    }

    match resp.body_json::<User>().await {
        Ok(user) => {
            info!(
                "Token endpoint request successful. Validated user: {}",
                user.me
            );
            (status, Some(user))
        }
        Err(err) => {
            error!(
                "Token endpoint parsing error (HTTP status {}): {}",
                status, err
            );
            (http_types::StatusCode::InternalServerError, None)
        }
    }
}

/// Tide middleware that authenticates requests against an IndieAuth token endpoint.
///
/// Successfully validated tokens are remembered in an expiring in-memory
/// cache, so repeated requests with the same token skip the endpoint round-trip.
pub struct IndieAuthMiddleware {
    // Only read by the non-debug `handle`; debug builds therefore see it as unused.
    #[allow(dead_code)]
    cache: Arc<retainer::Cache<String, User>>,
    /// Background task that purges expired cache entries; taken and cancelled on drop.
    monitor_task: Option<async_std::task::JoinHandle<()>>,
}
impl IndieAuthMiddleware {
    /// Creates the middleware together with its token cache.
    ///
    /// Every call spawns a background task that periodically purges stale
    /// cache entries. Build one instance and share it rather than creating
    /// new ones freely.
    ///
    /// In debug builds (outside of tests) this also logs a loud warning,
    /// because the debug `handle` never contacts the token endpoint.
    pub fn new() -> Self {
        let cache = Arc::new(retainer::Cache::<String, User>::new());

        // The purge task needs its own strong reference to the cache.
        // NOTE(review): per retainer's `monitor`, the arguments are presumably:
        // sample 4 entries per pass, repeat while more than 10% of the sample
        // was expired, run every 30 seconds — confirm against the crate docs.
        let purge_handle = Arc::clone(&cache);
        let monitor_task = async_std::task::spawn(async move {
            purge_handle
                .monitor(4, 0.1, std::time::Duration::from_secs(30))
                .await
        });

        #[cfg(all(debug_assertions, not(test)))]
        error!("ATTENTION: You are running in debug mode. NO REQUESTS TO TOKEN ENDPOINT WILL BE MADE. YOU WILL BE PROCEEDING WITH DEBUG USER CREDENTIALS. DO NOT RUN LIKE THIS IN PRODUCTION.");

        Self {
            cache,
            monitor_task: Some(monitor_task),
        }
    }
}
impl Drop for IndieAuthMiddleware {
    /// Cancels the cache-purging task spawned in [`IndieAuthMiddleware::new`].
    ///
    /// That task captured its own clone of the `Arc` around the cache. While it
    /// keeps running, the cache can never be freed. Dropping an async-std
    /// `JoinHandle` only detaches the task, so it has to be cancelled
    /// explicitly.
    fn drop(&mut self) {
        // `Option::take` moves the handle out and leaves `None` behind. Using
        // `if let` instead of `expect` guarantees `drop` never panics: a panic
        // raised while the thread is already unwinding aborts the process.
        if let Some(task) = self.monitor_task.take() {
            // `JoinHandle::cancel` is async and `drop` cannot `.await`, so the
            // cancellation is requested from a short-lived detached task. Once
            // it completes, the task's clone of the cache `Arc` is released.
            async_std::task::spawn(async move { task.cancel().await });
        }
    }
}
#[async_trait]
impl<B> tide::Middleware<ApplicationState<B>> for IndieAuthMiddleware
where
    B: database::Storage + Send + Sync + Clone,
{
    /// Debug-build stand-in: attaches a fixed debug [`User`] to every request
    /// without contacting the token endpoint.
    #[cfg(all(not(test), debug_assertions))]
    async fn handle(
        &self,
        mut req: Request<ApplicationState<B>>,
        next: Next<'_, ApplicationState<B>>,
    ) -> Result {
        req.set_ext(User::new(
            "http://localhost:8080/",
            "https://curl.haxx.se/",
            "create update delete undelete media",
        ));
        Ok(next.run(req).await)
    }

    /// Validates the request's `Authorization` header against the token endpoint.
    ///
    /// On success, the resolved [`User`] is attached as a request extension and
    /// the rest of the middleware chain runs. Successful validations are cached
    /// for 600 seconds, keyed by the raw header value.
    ///
    /// Responses produced without running the chain:
    /// - `401 unauthorized` when the header is missing or the token endpoint
    ///   rejected the token with a non-5xx status;
    /// - `500 token_endpoint_fail` when the token endpoint was unreachable,
    ///   answered with any 5xx status, or returned an unparseable body.
    ///
    /// # Panics
    ///
    /// Panics if `get_token_data` reports a user together with a non-200
    /// status, which it never does.
    #[cfg(any(not(debug_assertions), test))]
    async fn handle(
        &self,
        mut req: Request<ApplicationState<B>>,
        next: Next<'_, ApplicationState<B>>,
    ) -> Result {
        let token = match req.header("Authorization") {
            Some(values) => values.last().to_string(),
            None => {
                // TODO: move that to the request handling functions
                // or make a middleware that refuses to accept unauthenticated requests
                return Ok(Response::builder(401)
                    .body(json!({
                        "error": "unauthorized",
                        "error_description": "Please provide an access token."
                    }))
                    .build());
            }
        };

        // Copy the user out of the cache right away. The guard returned by
        // `Cache::get` is dropped at the end of this statement instead of staying
        // alive (possibly holding the cache's lock) while the rest of the
        // middleware chain handles the request.
        let cached_user: Option<User> = self
            .cache
            .get(&token)
            .await
            .map(|entry| User::clone(&entry));
        if let Some(user) = cached_user {
            req.set_ext::<User>(user);
            return Ok(next.run(req).await);
        }

        let endpoint = &req.state().token_endpoint;
        let http_client = &req.state().http_client;
        match get_token_data(token.clone(), endpoint, http_client).await {
            (http_types::StatusCode::Ok, Some(user)) => {
                // Note that this can run multiple requests before the value appears in the cache.
                // This seems to be in line with some other implementations of a function cache
                // (e.g. the [`cached`](https://lib.rs/crates/cached) crate and Python's `functools.lru_cache`)
                //
                // TODO: ensure the duration is no more than the token's remaining time until expiration
                // (in case the expiration time is defined on the token - AFAIK currently non-standard in IndieAuth)
                self.cache
                    .insert(token, user.clone(), std::time::Duration::from_secs(600))
                    .await;
                req.set_ext(user);
                Ok(next.run(req).await)
            }
            // TODO: Refactor to return Err(IndieAuthError) so downstream middleware could catch it
            // and present a prettier interface to the error (maybe even hiding data from the user)
            //
            // Any 5xx (not just 500) means the endpoint failed, not that the token
            // is bad. Reporting 401 here would make clients discard valid tokens
            // during a token endpoint outage.
            (status, None) if status.is_server_error() => Ok(Response::builder(500)
                .body(json!({
                    "error": "token_endpoint_fail",
                    "error_description": "Token endpoint made a boo-boo and refused to answer."
                }))
                .build()),
            (_, None) => Ok(Response::builder(401)
                .body(json!({
                    "error": "unauthorized",
                    "error_description": "The token endpoint refused to accept your token."
                }))
                .build()),
            (_, Some(_)) => {
                // This shouldn't happen.
                panic!("The token validation function has caught rabies and returns malformed responses. Aborting.");
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn user_scopes_are_checkable() {
        let user = User::new(
            "https://fireburn.ru/",
            "https://quill.p3k.io/",
            "create update media",
        );

        assert!(user.check_scope("create"));
        assert!(!user.check_scope("delete"));
    }

    /// Scope strings may use any ASCII whitespace as a separator. Only whole
    /// scope names must match, never substrings or the empty string.
    #[test]
    fn user_scopes_tolerate_irregular_whitespace() {
        let user = User::new(
            "https://fireburn.ru/",
            "https://quill.p3k.io/",
            "  create\tupdate \n media ",
        );

        assert_eq!(
            user.scopes().collect::<Vec<_>>(),
            vec!["create", "update", "media"]
        );
        assert!(user.check_scope("media"));
        assert!(!user.check_scope("creat"));
        assert!(!user.check_scope("create update"));
        assert!(!user.check_scope(""));
    }
}