about summary refs log tree commit diff
path: root/src/tokenauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tokenauth.rs')
-rw-r--r--src/tokenauth.rs358
1 files changed, 358 insertions, 0 deletions
diff --git a/src/tokenauth.rs b/src/tokenauth.rs
new file mode 100644
index 0000000..244a045
--- /dev/null
+++ b/src/tokenauth.rs
@@ -0,0 +1,358 @@
+use serde::{Deserialize, Serialize};
+use url::Url;
+
+#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
+pub struct User {
+    pub me: Url,
+    pub client_id: Url,
+    scope: String,
+}
+
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum ErrorKind {
+    PermissionDenied,
+    NotAuthorized,
+    TokenEndpointError,
+    JsonParsing,
+    InvalidHeader,
+    Other,
+}
+
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct TokenEndpointError {
+    error: String,
+    error_description: String,
+}
+
+#[derive(Debug)]
+pub struct IndieAuthError {
+    source: Option<Box<dyn std::error::Error + Send + Sync>>,
+    kind: ErrorKind,
+    msg: String,
+}
+
+impl std::error::Error for IndieAuthError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.source
+            .as_ref()
+            .map(|e| e.as_ref() as &dyn std::error::Error)
+    }
+}
+
+impl std::fmt::Display for IndieAuthError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}: {}",
+            match self.kind {
+                ErrorKind::TokenEndpointError => "token endpoint returned an error: ",
+                ErrorKind::JsonParsing => "error while parsing token endpoint response: ",
+                ErrorKind::NotAuthorized => "token endpoint did not recognize the token: ",
+                ErrorKind::PermissionDenied => "token endpoint rejected the token: ",
+                ErrorKind::InvalidHeader => "authorization header parsing error: ",
+                ErrorKind::Other => "token endpoint communication error: ",
+            },
+            self.msg
+        )
+    }
+}
+
+impl From<serde_json::Error> for IndieAuthError {
+    fn from(err: serde_json::Error) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::JsonParsing,
+        }
+    }
+}
+
+impl From<reqwest::Error> for IndieAuthError {
+    fn from(err: reqwest::Error) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::Other,
+        }
+    }
+}
+
+impl From<axum::extract::rejection::TypedHeaderRejection> for IndieAuthError {
+    fn from(err: axum::extract::rejection::TypedHeaderRejection) -> Self {
+        Self {
+            msg: format!("{:?}", err.reason()),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::InvalidHeader,
+        }
+    }
+}
+
+impl axum::response::IntoResponse for IndieAuthError {
+    fn into_response(self) -> axum::response::Response {
+        let status_code: StatusCode = match self.kind {
+            ErrorKind::PermissionDenied => StatusCode::FORBIDDEN,
+            ErrorKind::NotAuthorized => StatusCode::UNAUTHORIZED,
+            ErrorKind::TokenEndpointError => StatusCode::INTERNAL_SERVER_ERROR,
+            ErrorKind::JsonParsing => StatusCode::BAD_REQUEST,
+            ErrorKind::InvalidHeader => StatusCode::UNAUTHORIZED,
+            ErrorKind::Other => StatusCode::INTERNAL_SERVER_ERROR,
+        };
+
+        let body = serde_json::json!({
+            "error": match self.kind {
+                ErrorKind::PermissionDenied => "forbidden",
+                ErrorKind::NotAuthorized => "unauthorized",
+                ErrorKind::TokenEndpointError => "token_endpoint_error",
+                ErrorKind::JsonParsing => "invalid_request",
+                ErrorKind::InvalidHeader => "unauthorized",
+                ErrorKind::Other => "unknown_error",
+            },
+            "error_description": self.msg
+        });
+
+        (status_code, axum::response::Json(body)).into_response()
+    }
+}
+
+impl User {
+    pub fn check_scope(&self, scope: &str) -> bool {
+        self.scopes().any(|i| i == scope)
+    }
+    pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> {
+        self.scope.split_ascii_whitespace()
+    }
+    pub fn new(me: &str, client_id: &str, scope: &str) -> Self {
+        Self {
+            me: Url::parse(me).unwrap(),
+            client_id: Url::parse(client_id).unwrap(),
+            scope: scope.to_string(),
+        }
+    }
+}
+
+use axum::{
+    extract::{Extension, FromRequest, RequestParts, TypedHeader},
+    headers::{
+        authorization::{Bearer, Credentials},
+        Authorization,
+    },
+    http::StatusCode,
+};
+
+// this newtype is required due to axum::Extension retrieving items by type
+// it's based on compiler magic matching extensions by their type's hashes
+#[derive(Debug, Clone)]
+pub struct TokenEndpoint(pub url::Url);
+
+#[async_trait::async_trait]
+impl<B> FromRequest<B> for User
+where
+    B: Send,
+{
+    type Rejection = IndieAuthError;
+
+    #[cfg_attr(
+        all(debug_assertions, not(test)),
+        allow(unreachable_code, unused_variables)
+    )]
+    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+        // Return a fake user if we're running a debug build
+        // I don't wanna bother with authentication
+        #[cfg(all(debug_assertions, not(test)))]
+        return Ok(User::new(
+            "http://localhost:8080/",
+            "https://quill.p3k.io/",
+            "create update delete media",
+        ));
+
+        let TypedHeader(Authorization(token)) =
+            TypedHeader::<Authorization<Bearer>>::from_request(req)
+                .await
+                .map_err(IndieAuthError::from)?;
+
+        let Extension(TokenEndpoint(token_endpoint)): Extension<TokenEndpoint> =
+            Extension::from_request(req).await.unwrap();
+
+        let Extension(http): Extension<reqwest::Client> =
+            Extension::from_request(req).await.unwrap();
+
+        match http
+            .get(token_endpoint)
+            .header("Authorization", token.encode())
+            .header("Accept", "application/json")
+            .send()
+            .await
+        {
+            Ok(res) => match res.status() {
+                StatusCode::OK => match res.json::<serde_json::Value>().await {
+                    Ok(json) => match serde_json::from_value::<User>(json.clone()) {
+                        Ok(user) => Ok(user),
+                        Err(err) => {
+                            if let Some(false) = json["active"].as_bool() {
+                                Err(IndieAuthError {
+                                    source: None,
+                                    kind: ErrorKind::NotAuthorized,
+                                    msg: "The token is not active for this user.".to_owned(),
+                                })
+                            } else {
+                                Err(IndieAuthError::from(err))
+                            }
+                        }
+                    },
+                    Err(err) => Err(IndieAuthError::from(err)),
+                },
+                StatusCode::BAD_REQUEST => match res.json::<TokenEndpointError>().await {
+                    Ok(err) => {
+                        if err.error == "unauthorized" {
+                            Err(IndieAuthError {
+                                source: None,
+                                kind: ErrorKind::NotAuthorized,
+                                msg: err.error_description,
+                            })
+                        } else {
+                            Err(IndieAuthError {
+                                source: None,
+                                kind: ErrorKind::TokenEndpointError,
+                                msg: err.error_description,
+                            })
+                        }
+                    }
+                    Err(err) => Err(IndieAuthError::from(err)),
+                },
+                _ => Err(IndieAuthError {
+                    source: None,
+                    msg: format!("Token endpoint returned {}", res.status()),
+                    kind: ErrorKind::TokenEndpointError,
+                }),
+            },
+            Err(err) => Err(IndieAuthError::from(err)),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::User;
+    use axum::{
+        extract::FromRequest,
+        http::{Method, Request},
+    };
+    use wiremock::{MockServer, Mock, ResponseTemplate};
+    use wiremock::matchers::{method, path, header};
+
+    #[test]
+    fn user_scopes_are_checkable() {
+        let user = User::new(
+            "https://fireburn.ru/",
+            "https://quill.p3k.io/",
+            "create update media",
+        );
+
+        assert!(user.check_scope("create"));
+        assert!(!user.check_scope("delete"));
+    }
+
+    #[inline]
+    fn get_http_client() -> reqwest::Client {
+        reqwest::Client::new()
+    }
+
+    fn request<A: Into<Option<&'static str>>>(
+        auth: A,
+        endpoint: String,
+    ) -> Request<()> {
+        let request = Request::builder().method(Method::GET);
+
+        match auth.into() {
+            Some(auth) => request.header("Authorization", auth),
+            None => request,
+        }
+        .extension(super::TokenEndpoint(endpoint.parse().unwrap()))
+        .extension(get_http_client())
+        .body(())
+        .unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_require_token_with_token() {
+        let server = MockServer::start().await;
+
+        Mock::given(path("/token"))
+            .and(header("Authorization", "Bearer token"))
+            .respond_with(ResponseTemplate::new(200)
+                          .set_body_json(User::new(
+                              "https://fireburn.ru/",
+                              "https://quill.p3k.io/",
+                              "create update media",
+                          ))
+            )
+            .mount(&server)
+            .await;
+
+        let request = request("Bearer token", format!("{}/token", &server.uri()));
+        let mut parts = axum::extract::RequestParts::new(request);
+        let user = User::from_request(&mut parts).await.unwrap();
+
+        assert_eq!(user.me.as_str(), "https://fireburn.ru/")
+    }
+
+    #[tokio::test]
+    async fn test_require_token_fake_token() {
+        let server = MockServer::start().await;
+
+        Mock::given(path("/refuse_token"))
+            .respond_with(ResponseTemplate::new(200)
+                          .set_body_json(serde_json::json!({"active": false}))
+            )
+            .mount(&server)
+            .await;
+
+        let request = request("Bearer token", format!("{}/refuse_token", &server.uri()));
+        let mut parts = axum::extract::RequestParts::new(request);
+        let err = User::from_request(&mut parts).await.unwrap_err();
+
+        assert_eq!(err.kind, super::ErrorKind::NotAuthorized)
+    }
+
+    #[tokio::test]
+    async fn test_require_token_no_token() {
+        let server = MockServer::start().await;
+
+        Mock::given(path("/should_never_be_called"))
+            .respond_with(ResponseTemplate::new(500))
+            .expect(0)
+            .mount(&server)
+            .await;
+
+        let request = request(None, format!("{}/should_never_be_called", &server.uri()));
+        let mut parts = axum::extract::RequestParts::new(request);
+        let err = User::from_request(&mut parts).await.unwrap_err();
+
+        assert_eq!(err.kind, super::ErrorKind::InvalidHeader);
+    }
+
+    #[tokio::test]
+    async fn test_require_token_400_error_unauthorized() {
+        let server = MockServer::start().await;
+
+        Mock::given(path("/refuse_token_with_400"))
+            .and(header("Authorization", "Bearer token"))
+            .respond_with(ResponseTemplate::new(400)
+                          .set_body_json(serde_json::json!({
+                              "error": "unauthorized",
+                              "error_description": "The token provided was malformed"
+                          }))
+            )
+            .mount(&server)
+            .await;
+
+        let request = request(
+            "Bearer token",
+            format!("{}/refuse_token_with_400", &server.uri()),
+        );
+        let mut parts = axum::extract::RequestParts::new(request);
+        let err = User::from_request(&mut parts).await.unwrap_err();
+
+        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
+    }
+}