about summary refs log tree commit diff
path: root/kittybox-rs/src/indieauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'kittybox-rs/src/indieauth.rs')
-rw-r--r--kittybox-rs/src/indieauth.rs291
1 files changed, 291 insertions, 0 deletions
diff --git a/kittybox-rs/src/indieauth.rs b/kittybox-rs/src/indieauth.rs
new file mode 100644
index 0000000..57c0301
--- /dev/null
+++ b/kittybox-rs/src/indieauth.rs
@@ -0,0 +1,291 @@
+use url::Url;
+use serde::{Serialize, Deserialize};
+use warp::{Filter, Rejection, reject::MissingHeader};
+
+#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
+pub struct User {
+    pub me: Url,
+    pub client_id: Url,
+    scope: String,
+}
+
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum ErrorKind {
+    PermissionDenied,
+    NotAuthorized,
+    TokenEndpointError,
+    JsonParsing,
+    Other
+}
+
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct TokenEndpointError {
+    error: String,
+    error_description: String
+}
+
+#[derive(Debug)]
+pub struct IndieAuthError {
+    source: Option<Box<dyn std::error::Error + Send + Sync>>,
+    kind: ErrorKind,
+    msg: String
+}
+
+impl std::error::Error for IndieAuthError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.source.as_ref().map(|e| e.as_ref() as &dyn std::error::Error)
+    }
+}
+
+impl std::fmt::Display for IndieAuthError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match match self.kind {
+            ErrorKind::TokenEndpointError => write!(f, "token endpoint returned an error: "),
+            ErrorKind::JsonParsing => write!(f, "error while parsing token endpoint response: "),
+            ErrorKind::NotAuthorized => write!(f, "token endpoint did not recognize the token: "),
+            ErrorKind::PermissionDenied => write!(f, "token endpoint rejected the token: "),
+            ErrorKind::Other => write!(f, "token endpoint communication error: "),
+        } {
+            Ok(_) => write!(f, "{}", self.msg),
+            Err(err) => Err(err)
+        }
+    }
+}
+
+impl From<serde_json::Error> for IndieAuthError {
+    fn from(err: serde_json::Error) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::JsonParsing,
+        }
+    }
+}
+
+impl From<reqwest::Error> for IndieAuthError {
+    fn from(err: reqwest::Error) -> Self {
+        Self {
+            msg: format!("{}", err),
+            source: Some(Box::new(err)),
+            kind: ErrorKind::Other,
+        }
+    }
+}
+
+impl warp::reject::Reject for IndieAuthError {}
+
+impl User {
+    pub fn check_scope(&self, scope: &str) -> bool {
+        self.scopes().any(|i| i == scope)
+    }
+    pub fn scopes(&self) -> std::str::SplitAsciiWhitespace<'_> {
+        self.scope.split_ascii_whitespace()
+    }
+    pub fn new(me: &str, client_id: &str, scope: &str) -> Self {
+        Self {
+            me: Url::parse(me).unwrap(),
+            client_id: Url::parse(client_id).unwrap(),
+            scope: scope.to_string(),
+        }
+    }
+}
+
+pub fn require_token(token_endpoint: String, http: reqwest::Client) -> impl Filter<Extract = (User,), Error = Rejection> + Clone {
+    // It might be OK to panic here, because we're still inside the initialisation sequence for now.
+    // Proper error handling on the top of this should be used though.
+    let token_endpoint_uri = url::Url::parse(&token_endpoint)
+        .expect("Couldn't parse the token endpoint URI!");
+    warp::any()
+        .map(move || token_endpoint_uri.clone())
+        .and(warp::any().map(move || http.clone()))
+        .and(warp::header::<String>("Authorization").recover(|err: Rejection| async move {
+            if err.find::<MissingHeader>().is_some() {
+                Err(IndieAuthError {
+                    source: None,
+                    msg: "No Authorization header provided.".to_string(),
+                    kind: ErrorKind::NotAuthorized
+                }.into())
+            } else {
+                Err(err)
+            }
+        }).unify())
+        .and_then(|token_endpoint, http: reqwest::Client, token| async move {
+            use hyper::StatusCode;
+
+            match http
+                .get(token_endpoint)
+                .header("Authorization", token)
+                .header("Accept", "application/json")
+                .send()
+                .await
+            {
+                Ok(res) => match res.status() {
+                    StatusCode::OK => match res.json::<serde_json::Value>().await {
+                        Ok(json) => match serde_json::from_value::<User>(json.clone()) {
+                            Ok(user) => Ok(user),
+                            Err(err) => {
+                                if let Some(false) = json["active"].as_bool() {
+                                    Err(IndieAuthError {
+                                        source: None,
+                                        kind: ErrorKind::NotAuthorized,
+                                        msg: "The token is not active for this user.".to_owned()
+                                    }.into())
+                                } else {
+                                    Err(IndieAuthError::from(err).into())
+                                }
+                            }
+                        }
+                        Err(err) => Err(IndieAuthError::from(err).into())
+                    },
+                    StatusCode::BAD_REQUEST => {
+                        match res.json::<TokenEndpointError>().await {
+                            Ok(err) => {
+                                if err.error == "unauthorized" {
+                                    Err(IndieAuthError {
+                                        source: None,
+                                        kind: ErrorKind::NotAuthorized,
+                                        msg: err.error_description
+                                    }.into())
+                                } else {
+                                    Err(IndieAuthError {
+                                        source: None,
+                                        kind: ErrorKind::TokenEndpointError,
+                                        msg: err.error_description
+                                    }.into())
+                                }
+                            },
+                            Err(err) => Err(IndieAuthError::from(err).into())
+                        }
+                    },
+                    _ => Err(IndieAuthError {
+                        source: None,
+                        msg: format!("Token endpoint returned {}", res.status()),
+                        kind: ErrorKind::TokenEndpointError
+                    }.into())
+                },
+                Err(err) => Err(warp::reject::custom(IndieAuthError::from(err)))
+            }
+        })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{User, IndieAuthError, require_token};
+    use httpmock::prelude::*;
+    
+    #[test]
+    fn user_scopes_are_checkable() {
+        let user = User::new(
+            "https://fireburn.ru/",
+            "https://quill.p3k.io/",
+            "create update media",
+        );
+
+        assert!(user.check_scope("create"));
+        assert!(!user.check_scope("delete"));
+    }
+
+    #[inline]
+    fn get_http_client() -> reqwest::Client {
+        reqwest::Client::new()
+    }
+    
+    #[tokio::test]
+    async fn test_require_token_with_token() {
+        let server = MockServer::start_async().await;
+        server.mock_async(|when, then| {
+            when.path("/token")
+                .header("Authorization", "Bearer token");
+
+            then.status(200)
+                .header("Content-Type", "application/json")
+                .json_body(serde_json::to_value(User::new(
+                    "https://fireburn.ru/",
+                    "https://quill.p3k.io/",
+                    "create update media",
+                )).unwrap());
+        }).await;
+        
+        let filter = require_token(server.url("/token"), get_http_client());
+
+        let res: User = warp::test::request()
+            .path("/")
+            .header("Authorization", "Bearer token")
+            .filter(&filter)
+            .await
+            .unwrap();
+
+        assert_eq!(res.me.as_str(), "https://fireburn.ru/")
+    }
+
+    #[tokio::test]
+    async fn test_require_token_fake_token() {
+        let server = MockServer::start_async().await;
+        server.mock_async(|when, then| {
+            when.path("/refuse_token");
+
+            then.status(200)
+                .json_body(serde_json::json!({"active": false}));
+        }).await;
+
+        let filter = require_token(server.url("/refuse_token"), get_http_client());
+
+        let res = warp::test::request()
+            .path("/")
+            .header("Authorization", "Bearer token")
+            .filter(&filter)
+            .await
+            .unwrap_err();
+
+        let err: &IndieAuthError = res.find().unwrap();
+        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
+    }
+
+    #[tokio::test]
+    async fn test_require_token_no_token() {
+        let server = MockServer::start_async().await;
+        let mock = server.mock_async(|when, then| {
+            when.path("/should_never_be_called");
+
+            then.status(500);
+        }).await;
+        let filter = require_token(server.url("/should_never_be_called"), get_http_client());
+
+        let res = warp::test::request()
+            .path("/")
+            .filter(&filter)
+            .await
+            .unwrap_err();
+
+        let err: &IndieAuthError = res.find().unwrap();
+        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
+
+        mock.assert_hits_async(0).await;
+    }
+
+    #[tokio::test]
+    async fn test_require_token_400_error_unauthorized() {
+        let server = MockServer::start_async().await;
+        server.mock_async(|when, then| {
+            when.path("/refuse_token_with_400");
+
+            then.status(400)
+                .json_body(serde_json::json!({
+                    "error": "unauthorized",
+                    "error_description": "The token provided was malformed"
+                }));
+        }).await;
+
+        let filter = require_token(server.url("/refuse_token_with_400"), get_http_client());
+
+        let res = warp::test::request()
+            .path("/")
+            .header("Authorization", "Bearer token")
+            .filter(&filter)
+            .await
+            .unwrap_err();
+
+        let err: &IndieAuthError = res.find().unwrap();
+        assert_eq!(err.kind, super::ErrorKind::NotAuthorized);
+    }
+}