about summary refs log blame commit diff
path: root/src/indieauth.rs
blob: 57c03019de7977467439d3509a5e493fec94b49e (plain) (tree)
1
2
3
4
5
6
7
8
9
             
                                    
                                                     
 
                                                          

                       
                  
 



















































                                                                                                  
                                              








                                               





                                                                


                                                                
                                     


         
                                                                                                                                  
                                                                                                    
                                                             


                                                          









                                                                                          
                                                                             
                                  








                                                                                        
                                                 




                                                                                                


                                                                         
                         
                                                                         
                      
                                                
                                                                      














                                                                             
                         

                                             
                                                                                 


                                                                                
             
          


            
                                                     
                             
                                    



                                    


                                             
 
             
                                             




















































                                                                                   




















                                                                                             






















                                                                                            
 
use url::Url;
use serde::{Serialize, Deserialize};
use warp::{Filter, Rejection, reject::MissingHeader};

/// The identity a token endpoint vouches for when it accepts a token.
///
/// Deserialized straight from the token endpoint's JSON response.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct User {
    /// The user's profile URL.
    pub me: Url,
    /// The client application the token was issued to.
    pub client_id: Url,
    /// Whitespace-separated list of granted scopes; query via [`User::scopes`] /
    /// [`User::check_scope`].
    scope: String,
}

/// Category of an [`IndieAuthError`]; selects the message prefix used by its `Display` impl.
///
/// A plain fieldless enum, so it is `Copy` and can be compared, hashed and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// The token endpoint rejected the token ("token endpoint rejected the token").
    /// NOTE(review): not constructed by the code shown here.
    PermissionDenied,
    /// The request could not be authenticated, e.g. no `Authorization` header, an
    /// inactive token, or the endpoint reporting `unauthorized`.
    NotAuthorized,
    /// The token endpoint answered with an error body or an unexpected HTTP status.
    TokenEndpointError,
    /// The token endpoint's response could not be deserialized (built from `serde_json::Error`).
    JsonParsing,
    /// Any other failure while talking to the token endpoint (built from `reqwest::Error`).
    Other,
}

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct TokenEndpointError {
    error: String,
    error_description: String
}

/// Error raised while authenticating a request against the token endpoint.
///
/// It is also a warp rejection, so the authentication filter can surface it directly.
#[derive(Debug)]
pub struct IndieAuthError {
    /// Lower-level error that caused this one, if there was one.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
    /// Category of the failure; picks the prefix shown by `Display`.
    kind: ErrorKind,
    /// Detail text printed after the category prefix.
    msg: String,
}

impl std::error::Error for IndieAuthError {
    /// Exposes the wrapped lower-level error (if any) so callers can walk the cause chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // `as_deref` borrows the boxed error; the cast drops the `Send + Sync` markers
        // to match the trait's return type.
        self.source
            .as_deref()
            .map(|inner| inner as &(dyn std::error::Error + 'static))
    }
}

impl std::fmt::Display for IndieAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match match self.kind {
            ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "),
            ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "),
            ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "),
            ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "),
            ErrorKind::Other => write!(f, "token endpoint communication error: "),
        } {
            Ok(_) => write!(f, "{}", self.msg),
            Err(err) => Err(err)
        }
    }
}

impl From<serde_json::Error> for IndieAuthError {
    fn from(err: serde_json::Error) -> Self {
        Self {
            msg: format!("{}", err),
            source: Some(Box::new(err)),
            kind: ErrorKind::JsonParsing,
        }
    }
}

impl From<reqwest::Error> for IndieAuthError {
    fn from(err: reqwest::Error) -> Self {
        Self {
            msg: format!("{}", err),
            source: Some(Box::new(err)),
            kind: ErrorKind::Other,
        }
    }
}

impl warp::reject::Reject for IndieAuthError {}

impl User {
    /// Returns `true` when `scope` exactly matches one of the granted scopes.
    pub fn check_scope(&self, scope: &str) -> bool {
        self.scopes().any(|granted| granted == scope)
    }

    /// Iterates over the granted scopes, split on ASCII whitespace.
    pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> {
        self.scope.split_ascii_whitespace()
    }

    /// Builds a user from string parts; `scope` is a whitespace-separated scope list.
    ///
    /// # Panics
    ///
    /// Panics if `me` or `client_id` is not a valid URL.
    pub fn new(me: &str, client_id: &str, scope: &str) -> Self {
        let me = Url::parse(me).unwrap();
        let client_id = Url::parse(client_id).unwrap();
        Self {
            me,
            client_id,
            scope: scope.to_owned(),
        }
    }
}

pub fn require_token(token_endpoint: String, http: reqwest::Client) -> impl Filter<Extract = (User,), Error = Rejection> + Clone {
    // It might be OK to panic here, because we're still inside the initialisation sequence for now.
    // Proper error handling on the top of this should be used though.
    let token_endpoint_uri = url::Url::parse(&token_endpoint)
        .expect("Couldn't parse the token endpoint URI!");
    warp::any()
        .map(move || token_endpoint_uri.clone())
        .and(warp::any().map(move || http.clone()))
        .and(warp::header::<String>("Authorization").recover(|err: Rejection| async move {
            if err.find::<MissingHeader>().is_some() {
                Err(IndieAuthError {
                    source: None,
                    msg: "No Authorization header provided.".to_string(),
                    kind: ErrorKind::NotAuthorized
                }.into())
            } else {
                Err(err)
            }
        }).unify())
        .and_then(|token_endpoint, http: reqwest::Client, token| async move {
            use hyper::StatusCode;

            match http
                .get(token_endpoint)
                .header("Authorization", token)
                .header("Accept", "application/json")
                .send()
                .await
            {
                Ok(res) => match res.status() {
                    StatusCode::OK => match res.json::<serde_json::Value>().await {
                        Ok(json) => match serde_json::from_value::<User>(json.clone()) {
                            Ok(user) => Ok(user),
                            Err(err) => {
                                if let Some(false) = json["active"].as_bool() {
                                    Err(IndieAuthError {
                                        source: None,
                                        kind: ErrorKind::NotAuthorized,
                                        msg: "The token is not active for this user.".to_owned()
                                    }.into())
                                } else {
                                    Err(IndieAuthError::from(err).into())
                                }
                            }
                        }
                        Err(err) => Err(IndieAuthError::from(err).into())
                    },
                    StatusCode::BAD_REQUEST => {
                        match res.json::<TokenEndpointError>().await {
                            Ok(err) => {
                                if err.error == "unauthorized" {
                                    Err(IndieAuthError {
                                        source: None,
                                        kind: ErrorKind::NotAuthorized,
                                        msg: err.error_description
                                    }.into())
                                } else {
                                    Err(IndieAuthError {
                                        source: None,
                                        kind: ErrorKind::TokenEndpointError,
                                        msg: err.error_description
                                    }.into())
                                }
                            },
                            Err(err) => Err(IndieAuthError::from(err).into())
                        }
                    },
                    _ => Err(IndieAuthError {
                        source: None,
                        msg: format!("Token endpoint returned {}", res.status()),
                        kind: ErrorKind::TokenEndpointError
                    }.into())
                },
                Err(err) => Err(warp::reject::custom(IndieAuthError::from(err)))
            }
        })
}

#[cfg(test)]
mod tests {
    use super::{require_token, ErrorKind, IndieAuthError, User};
    use httpmock::prelude::*;

    /// Builds a fresh HTTP client for a single test.
    fn http_client() -> reqwest::Client {
        reqwest::Client::new()
    }

    /// `check_scope` matches whole, whitespace-separated scope names.
    #[test]
    fn user_scopes_are_checkable() {
        let user = User::new(
            "https://fireburn.ru/",
            "https://quill.p3k.io/",
            "create update media",
        );

        assert!(user.check_scope("create"));
        assert!(!user.check_scope("delete"));
    }

    /// A token the endpoint accepts yields the user described in its response.
    #[tokio::test]
    async fn test_require_token_with_token() {
        let mock_server = MockServer::start_async().await;
        mock_server.mock_async(|when, then| {
            when.path("/token")
                .header("Authorization", "Bearer token");

            then.status(200)
                .header("Content-Type", "application/json")
                .json_body(serde_json::to_value(User::new(
                    "https://fireburn.ru/",
                    "https://quill.p3k.io/",
                    "create update media",
                )).unwrap());
        }).await;

        let auth_filter = require_token(mock_server.url("/token"), http_client());

        let user: User = warp::test::request()
            .path("/")
            .header("Authorization", "Bearer token")
            .filter(&auth_filter)
            .await
            .unwrap();

        assert_eq!(user.me.as_str(), "https://fireburn.ru/")
    }

    /// An `"active": false` answer is turned into a `NotAuthorized` rejection.
    #[tokio::test]
    async fn test_require_token_fake_token() {
        let mock_server = MockServer::start_async().await;
        mock_server.mock_async(|when, then| {
            when.path("/refuse_token");

            then.status(200)
                .json_body(serde_json::json!({"active": false}));
        }).await;

        let auth_filter = require_token(mock_server.url("/refuse_token"), http_client());

        let rejection = warp::test::request()
            .path("/")
            .header("Authorization", "Bearer token")
            .filter(&auth_filter)
            .await
            .unwrap_err();

        let auth_err: &IndieAuthError = rejection.find().unwrap();
        assert_eq!(auth_err.kind, ErrorKind::NotAuthorized);
    }

    /// A missing `Authorization` header is rejected without contacting the endpoint.
    #[tokio::test]
    async fn test_require_token_no_token() {
        let mock_server = MockServer::start_async().await;
        let endpoint_mock = mock_server.mock_async(|when, then| {
            when.path("/should_never_be_called");

            then.status(500);
        }).await;
        let auth_filter = require_token(mock_server.url("/should_never_be_called"), http_client());

        let rejection = warp::test::request()
            .path("/")
            .filter(&auth_filter)
            .await
            .unwrap_err();

        let auth_err: &IndieAuthError = rejection.find().unwrap();
        assert_eq!(auth_err.kind, ErrorKind::NotAuthorized);

        endpoint_mock.assert_hits_async(0).await;
    }

    /// A `400` with `"error": "unauthorized"` maps to `NotAuthorized`.
    #[tokio::test]
    async fn test_require_token_400_error_unauthorized() {
        let mock_server = MockServer::start_async().await;
        mock_server.mock_async(|when, then| {
            when.path("/refuse_token_with_400");

            then.status(400)
                .json_body(serde_json::json!({
                    "error": "unauthorized",
                    "error_description": "The token provided was malformed"
                }));
        }).await;

        let auth_filter = require_token(mock_server.url("/refuse_token_with_400"), http_client());

        let rejection = warp::test::request()
            .path("/")
            .header("Authorization", "Bearer token")
            .filter(&auth_filter)
            .await
            .unwrap_err();

        let auth_err: &IndieAuthError = rejection.find().unwrap();
        assert_eq!(auth_err.kind, ErrorKind::NotAuthorized);
    }
}