about summary refs log blame commit diff
path: root/kittybox-rs/src/tokenauth.rs
blob: 244a045db3af815e19eb066ef6d51e7c683aeefc (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                                    
             
 
                                                          

                       
                  
 




                                        
                  



                                               
                              




                                                             
                


                                                                    

                                                          



                                                                        











                                                                                           











                                                 
                                              






                                        


































                                                                               
 





                                                                


                                                                
                                     


         



















                                                                           


                                                 





                                                                                       
                                         


                                                                   
                      


























                                                                                             
                             
                         
                      














                                                                                         
                         
                                                               
                  







                                                                             


            

                             
                                
      
                                                       
 
                                    



                                    


                                             
 
             
                                             
     
 
                                              
                
                         




                                                                
                                                                   



                                     
                                              










                                                         
                   
                                                                                  


                                                                  


                                              





                                                                              
                   
 
                                                                                         
                                                                    
 
                                                             

                  
                                            
                                               
 


                                                     
                   
 
                                                                                         
                                                                    
 
                                                              

                  
                                                          









                                                                                     
                   
 
                              
                                                               

                                                                    
 
                                                              
 
use serde::{Deserialize, Serialize};
use url::Url;

/// A user authenticated through the IndieAuth token endpoint.
///
/// This is the shape the token endpoint's JSON verification response is
/// deserialized into; no `deny_unknown_fields` is set, so extra keys in that
/// response are ignored.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct User {
    /// The `me` URL reported by the token endpoint.
    pub me: Url,
    /// The `client_id` URL reported by the token endpoint.
    pub client_id: Url,
    /// Raw scope string; entries are separated by ASCII whitespace
    /// (see `User::scopes`).
    scope: String,
}

/// Category of an [`IndieAuthError`]; it selects the HTTP status and the
/// machine-readable `error` code used when the error becomes a response.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ErrorKind {
    /// Rendered as HTTP 403 with `"error": "forbidden"`.
    PermissionDenied,
    /// The token endpoint reported the token as inactive, or answered with an
    /// `unauthorized` error. Rendered as HTTP 401.
    NotAuthorized,
    /// The token endpoint returned some other error or an unexpected HTTP
    /// status. Rendered as HTTP 500.
    TokenEndpointError,
    /// A JSON value could not be (de)serialized into the expected shape.
    /// Rendered as HTTP 400 with `"error": "invalid_request"`.
    JsonParsing,
    /// The `Authorization` header was missing or not a valid Bearer
    /// credential. Rendered as HTTP 401.
    InvalidHeader,
    /// Anything else, e.g. an HTTP transport error while contacting the token
    /// endpoint. Rendered as HTTP 500.
    Other,
}

/// JSON error body the token endpoint sends along with an HTTP 400 response.
///
/// Field names mirror the JSON keys exactly.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TokenEndpointError {
    /// Machine-readable error code; `"unauthorized"` is handled specially by
    /// the `User` extractor.
    error: String,
    /// Human-readable explanation, forwarded as the rejection message.
    error_description: String,
}

/// Failure raised while authenticating a request with a bearer token.
///
/// It implements `axum::response::IntoResponse`, so it serves directly as the
/// rejection type of the `User` extractor.
#[derive(Debug)]
pub struct IndieAuthError {
    /// The wrapped lower-level error, when there is one.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
    /// Failure category; decides the response status and error code.
    kind: ErrorKind,
    /// Human-readable detail, also sent as `error_description`.
    msg: String,
}

impl std::error::Error for IndieAuthError {
    /// Exposes the wrapped error (if any) so error-chain walkers can see it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Deref the box, then drop the `Send + Sync` auto traits to match the
        // trait's return type.
        self.source
            .as_deref()
            .map(|inner| inner as &(dyn std::error::Error + 'static))
    }
}

impl std::fmt::Display for IndieAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}: {}",
            match self.kind {
                ErrorKind::TokenEndpointError => "token endpoint returned an error: ",
                ErrorKind::JsonParsing => "error while parsing token endpoint response: ",
                ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ",
                ErrorKind::PermissionDenied => "token endpoint rejected the token: ",
                ErrorKind::InvalidHeader => "authorization header parsing error: ",
                ErrorKind::Other => "token endpoint communication error: ",
            },
            self.msg
        )
    }
}

impl From<serde_json::Error> for IndieAuthError {
    /// Wraps a JSON (de)serialization failure as [`ErrorKind::JsonParsing`],
    /// keeping the original error as the source.
    fn from(err: serde_json::Error) -> Self {
        let msg = err.to_string();
        Self {
            kind: ErrorKind::JsonParsing,
            msg,
            source: Some(Box::new(err)),
        }
    }
}

impl From<reqwest::Error> for IndieAuthError {
    /// Wraps an HTTP client failure as [`ErrorKind::Other`], keeping the
    /// original error as the source.
    fn from(err: reqwest::Error) -> Self {
        let msg = err.to_string();
        Self {
            kind: ErrorKind::Other,
            msg,
            source: Some(Box::new(err)),
        }
    }
}

impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError {
    /// Converts a failed `Authorization` header extraction into
    /// [`ErrorKind::InvalidHeader`]; the message is the `Debug` form of the
    /// rejection reason.
    fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self {
        let msg = format!("{:?}", err.reason());
        Self {
            kind: ErrorKind::InvalidHeader,
            msg,
            source: Some(Box::new(err)),
        }
    }
}

impl axum::response::IntoResponse for IndieAuthError {
    /// Renders the error as a JSON body of the form
    /// `{"error": <code>, "error_description": <msg>}`, with the HTTP status
    /// and `<code>` both chosen by the error kind.
    fn into_response(self) -> axum::response::Response {
        let (status_code, error_code): (StatusCode, &str) = match self.kind {
            ErrorKind::PermissionDenied => (StatusCode::FORBIDDEN, "forbidden"),
            ErrorKind::NotAuthorized => (StatusCode::UNAUTHORIZED, "unauthorized"),
            ErrorKind::TokenEndpointError => {
                (StatusCode::INTERNAL_SERVER_ERROR, "token_endpoint_error")
            }
            ErrorKind::JsonParsing => (StatusCode::BAD_REQUEST, "invalid_request"),
            ErrorKind::InvalidHeader => (StatusCode::UNAUTHORIZED, "unauthorized"),
            ErrorKind::Other => (StatusCode::INTERNAL_SERVER_ERROR, "unknown_error"),
        };

        let payload = serde_json::json!({
            "error": error_code,
            "error_description": self.msg
        });

        (status_code, axum::response::Json(payload)).into_response()
    }
}

impl User {
    /// Returns `true` if `scope` exactly matches one of the granted scopes.
    ///
    /// The comparison is case-sensitive against each entry yielded by
    /// [`User::scopes`].
    pub fn check_scope(&self, scope: &str) -> bool {
        self.scopes().any(|granted| granted == scope)
    }

    /// Iterates over the granted scopes, split on ASCII whitespace.
    pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> {
        self.scope.split_ascii_whitespace()
    }

    /// Builds a user from string forms of its fields.
    ///
    /// # Panics
    ///
    /// Panics if `me` or `client_id` cannot be parsed by `Url::parse`. The
    /// callers visible in this file pass hard-coded literals (the debug-build
    /// fake user and tests), where a bad URL is a programming error.
    pub fn new(me: &str, client_id: &str, scope: &str) -> Self {
        Self {
            me: Url::parse(me).expect("User::new: `me` must be a valid URL"),
            client_id: Url::parse(client_id)
                .expect("User::new: `client_id` must be a valid URL"),
            scope: scope.to_string(),
        }
    }
}

use axum::{
    extract::{Extension, FromRequest, RequestParts, TypedHeader},
    headers::{
        authorization::{Bearer, Credentials},
        Authorization,
    },
    http::StatusCode,
};

/// URL of the token endpoint that bearer tokens are verified against.
///
/// This is a newtype, not a bare `url::Url`, because axum's `Extension`
/// stores and retrieves values by their type. A plain `Url` extension would
/// be indistinguishable from any other `Url` placed in the request extensions.
#[derive(Clone, Debug)]
pub struct TokenEndpoint(pub url::Url);

#[async_trait::async_trait]
impl<B> FromRequest<B> for User
where
    B: Send,
{
    type Rejection = IndieAuthError;

    #[cfg_attr(
        all(debug_assertions, not(test)),
        allow(unreachable_code, unused_variables)
    )]
    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
        // Return a fake user if we're running a debug build
        // I don't wanna bother with authentication
        #[cfg(all(debug_assertions, not(test)))]
        return Ok(User::new(
            "http://localhost:8080/",
            "https://quill.p3k.io/",
            "create update delete media",
        ));

        let TypedHeader(Authorization(token)) =
            TypedHeader::<Authorization<Bearer>>::from_request(req)
                .await
                .map_err(IndieAuthError::from)?;

        let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> =
            Extension::from_request(req).await.unwrap();

        let Extension(http): Extension<reqwest::Client> =
            Extension::from_request(req).await.unwrap();

        match http
            .get(token_endpoint)
            .header("Authorization", token.encode())
            .header("Accept", "application/json")
            .send()
            .await
        {
            Ok(res) => match res.status() {
                StatusCode::OK => match res.json::<serde_json::Value>().await {
                    Ok(json) => match serde_json::from_value::<User>(json.clone()) {
                        Ok(user) => Ok(user),
                        Err(err) => {
                            if let Some(false) = json["active"].as_bool() {
                                Err(IndieAuthError {
                                    source: None,
                                    kind: ErrorKind::NotAuthorized,
                                    msg: "The token is not active for this user.".to_owned(),
                                })
                            } else {
                                Err(IndieAuthError::from(err))
                            }
                        }
                    },
                    Err(err) => Err(IndieAuthError::from(err)),
                },
                StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await {
                    Ok(err) => {
                        if err.error == "unauthorized" {
                            Err(IndieAuthError {
                                source: None,
                                kind: ErrorKind::NotAuthorized,
                                msg: err.error_description,
                            })
                        } else {
                            Err(IndieAuthError {
                                source: None,
                                kind: ErrorKind::TokenEndpointError,
                                msg: err.error_description,
                            })
                        }
                    }
                    Err(err) => Err(IndieAuthError::from(err)),
                },
                _ => Err(IndieAuthError {
                    source: None,
                    msg: format!("Token endpoint returned {}", res.status()),
                    kind: ErrorKind::TokenEndpointError,
                }),
            },
            Err(err) => Err(IndieAuthError::from(err)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::User;
    use axum::{
        extract::FromRequest,
        http::{Method, Request},
    };
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test]
    fn user_scopes_are_checkable() {
        let user = User::new(
            "https://fireburn.ru/",
            "https://quill.p3k.io/",
            "create update media",
        );

        assert!(user.check_scope("create"));
        assert!(!user.check_scope("delete"));
    }

    /// HTTP client installed as a request extension for the extractor.
    fn get_http_client() -> reqwest::Client {
        reqwest::Client::new()
    }

    /// Builds a GET request carrying the extensions `User::from_request` needs.
    ///
    /// `auth`, when present, is sent verbatim as the `Authorization` header;
    /// `endpoint` becomes the configured `TokenEndpoint`.
    fn request<A: Into<Option<&'static str>>>(
        auth: A,
        endpoint: String,
    ) -> Request<()> {
        let request = Request::builder().method(Method::GET);

        match auth.into() {
            Some(auth) => request.header("Authorization", auth),
            None => request,
        }
        .extension(super::TokenEndpoint(endpoint.parse().unwrap()))
        .extension(get_http_client())
        .body(())
        .unwrap()
    }

    #[tokio::test]
    async fn test_require_token_with_token() {
        let server = MockServer::start().await;

        // Also pins that the token is verified with a GET request.
        Mock::given(method("GET"))
            .and(path("/token"))
            .and(header("Authorization", "Bearer token"))
            .respond_with(ResponseTemplate::new(200)
                          .set_body_json(User::new(
                              "https://fireburn.ru/",
                              "https://quill.p3k.io/",
                              "create update media",
                          ))
            )
            .mount(&server)
            .await;

        let request = request("Bearer token", format!("{}/token", &server.uri()));
        let mut parts = axum::extract::RequestParts::new(request);
        let user = User::from_request(&mut parts).await.unwrap();

        assert_eq!(user.me.as_str(), "https://fireburn.ru/")
    }

    #[tokio::test]
    async fn test_require_token_fake_token() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/refuse_token"))
            .respond_with(ResponseTemplate::new(200)
                          .set_body_json(serde_json::json!({"active": false}))
            )
            .mount(&server)
            .await;

        let request = request("Bearer token", format!("{}/refuse_token", &server.uri()));
        let mut parts = axum::extract::RequestParts::new(request);
        let err = User::from_request(&mut parts).await.unwrap_err();

        assert_eq!(err.kind, super::ErrorKind::NotAuthorized)
    }

    #[tokio::test]
    async fn test_require_token_no_token() {
        let server = MockServer::start().await;

        // No method matcher: the endpoint must not be hit by any request.
        Mock::given(path("/should_never_be_called"))
            .respond_with(ResponseTemplate::new(500))
            .expect(0)
            .mount(&server)
            .await;

        let request = request(None, format!("{}/should_never_be_called", &server.uri()));
        let mut parts = axum::extract::RequestParts::new(request);
        let err = User::from_request(&mut parts).await.unwrap_err();

        assert_eq!(err.kind, super::ErrorKind::InvalidHeader);
    }

    #[tokio::test]
    async fn test_require_token_400_error_unauthorized() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/refuse_token_with_400"))
            .and(header("Authorization", "Bearer token"))
            .respond_with(ResponseTemplate::new(400)
                          .set_body_json(serde_json::json!({
                              "error": "unauthorized",
                              "error_description": "The token provided was malformed"
                          }))
            )
            .mount(&server)
            .await;

        let request = request(
            "Bearer token",
            format!("{}/refuse_token_with_400", &server.uri()),
        );
        let mut parts = axum::extract::RequestParts::new(request);
        let err = User::from_request(&mut parts).await.unwrap_err();

        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
    }

    #[tokio::test]
    async fn test_require_token_400_error_other() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/token_endpoint_error"))
            .respond_with(ResponseTemplate::new(400)
                          .set_body_json(serde_json::json!({
                              "error": "invalid_request",
                              "error_description": "The request was missing a parameter"
                          }))
            )
            .mount(&server)
            .await;

        let request = request(
            "Bearer token",
            format!("{}/token_endpoint_error", &server.uri()),
        );
        let mut parts = axum::extract::RequestParts::new(request);
        let err = User::from_request(&mut parts).await.unwrap_err();

        assert_eq!(err.kind, super::ErrorKind::TokenEndpointError);
    }

    #[tokio::test]
    async fn test_require_token_unexpected_status() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/unavailable"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&server)
            .await;

        let request = request("Bearer token", format!("{}/unavailable", &server.uri()));
        let mut parts = axum::extract::RequestParts::new(request);
        let err = User::from_request(&mut parts).await.unwrap_err();

        assert_eq!(err.kind, super::ErrorKind::TokenEndpointError);
    }
}