about summary refs log tree commit diff
path: root/src/indieauth/backend
diff options
context:
space:
mode:
Diffstat (limited to 'src/indieauth/backend')
-rw-r--r--src/indieauth/backend/fs.rs282
1 files changed, 172 insertions, 110 deletions
diff --git a/src/indieauth/backend/fs.rs b/src/indieauth/backend/fs.rs
index f74fbbc..26466fe 100644
--- a/src/indieauth/backend/fs.rs
+++ b/src/indieauth/backend/fs.rs
@@ -1,13 +1,16 @@
-use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}};
-
-use super::{AuthBackend, Result, EnrolledCredential};
-use kittybox_indieauth::{
-    AuthorizationRequest, TokenData
+use std::{
+    borrow::Cow,
+    collections::HashMap,
+    path::PathBuf,
+    time::{Duration, SystemTime},
 };
+
+use super::{AuthBackend, EnrolledCredential, Result};
+use kittybox_indieauth::{AuthorizationRequest, TokenData};
 use serde::de::DeserializeOwned;
-use tokio::{task::spawn_blocking, io::AsyncReadExt};
+use tokio::{io::AsyncReadExt, task::spawn_blocking};
 #[cfg(feature = "webauthn")]
-use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication};
+use webauthn::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration};
 
 const CODE_LENGTH: usize = 16;
 const TOKEN_LENGTH: usize = 128;
@@ -29,7 +32,8 @@ impl FileBackend {
         } else {
             let mut s = String::with_capacity(filename.len());
 
-            filename.chars()
+            filename
+                .chars()
                 .filter(|c| c.is_alphanumeric())
                 .for_each(|c| s.push(c));
 
@@ -38,41 +42,41 @@ impl FileBackend {
     }
 
     #[inline]
-    async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>(
+    async fn serialize_to_file<
+        T: 'static + serde::ser::Serialize + Send,
+        B: Into<Option<&'static str>>,
+    >(
         &self,
         dir: &str,
         basename: B,
         length: usize,
-        data: T
+        data: T,
     ) -> Result<String> {
         let basename = basename.into();
         let has_ext = basename.is_some();
-        let (filename, mut file) = kittybox_util::fs::mktemp(
-            self.path.join(dir), basename, length
-        )
+        let (filename, mut file) = kittybox_util::fs::mktemp(self.path.join(dir), basename, length)
             .await
             .map(|(name, file)| (name, file.try_into_std().unwrap()))?;
 
         spawn_blocking(move || serde_json::to_writer(&mut file, &data))
             .await
-            .unwrap_or_else(|e| panic!(
-                "Panic while serializing {}: {}",
-                std::any::type_name::<T>(),
-                e
-            ))
+            .unwrap_or_else(|e| {
+                panic!(
+                    "Panic while serializing {}: {}",
+                    std::any::type_name::<T>(),
+                    e
+                )
+            })
             .map(move |_| {
                 (if has_ext {
-                    filename
-                        .extension()
-
+                    filename.extension()
                 } else {
-                    filename
-                        .file_name()
+                    filename.file_name()
                 })
-                    .unwrap()
-                    .to_str()
-                    .unwrap()
-                    .to_owned()
+                .unwrap()
+                .to_str()
+                .unwrap()
+                .to_owned()
             })
             .map_err(|err| err.into())
     }
@@ -86,17 +90,15 @@ impl FileBackend {
     ) -> Result<Option<(PathBuf, SystemTime, T)>>
     where
         T: serde::de::DeserializeOwned + Send,
-        B: Into<Option<&'static str>>
+        B: Into<Option<&'static str>>,
     {
         let basename = basename.into();
-        let path = self.path
-            .join(dir)
-            .join(format!(
-                "{}{}{}",
-                basename.unwrap_or(""),
-                if basename.is_none() { "" } else { "." },
-                FileBackend::sanitize_for_path(filename)
-            ));
+        let path = self.path.join(dir).join(format!(
+            "{}{}{}",
+            basename.unwrap_or(""),
+            if basename.is_none() { "" } else { "." },
+            FileBackend::sanitize_for_path(filename)
+        ));
 
         let data = match tokio::fs::File::open(&path).await {
             Ok(mut file) => {
@@ -106,13 +108,15 @@ impl FileBackend {
 
                 match serde_json::from_slice::<'_, T>(buf.as_slice()) {
                     Ok(data) => data,
-                    Err(err) => return Err(err.into())
+                    Err(err) => return Err(err.into()),
+                }
+            }
+            Err(err) => {
+                if err.kind() == std::io::ErrorKind::NotFound {
+                    return Ok(None);
+                } else {
+                    return Err(err);
                 }
-            },
-            Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
-                return Ok(None)
-            } else {
-                return Err(err)
             }
         };
 
@@ -125,7 +129,8 @@ impl FileBackend {
     #[tracing::instrument]
     fn url_to_dir(url: &url::Url) -> String {
         let host = url.host_str().unwrap();
-        let port = url.port()
+        let port = url
+            .port()
             .map(|port| Cow::Owned(format!(":{}", port)))
             .unwrap_or(Cow::Borrowed(""));
 
@@ -135,23 +140,26 @@ impl FileBackend {
     async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>(
         &'this self,
         dir: &'dir str,
-        prefix: &'static str
+        prefix: &'static str,
     ) -> Result<HashMap<String, T>> {
         let dir = self.path.join(dir);
 
         let mut hashmap = HashMap::new();
         let mut readdir = match tokio::fs::read_dir(dir).await {
             Ok(readdir) => readdir,
-            Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
-                // empty hashmap
-                return Ok(hashmap);
-            } else {
-                return Err(err);
+            Err(err) => {
+                if err.kind() == std::io::ErrorKind::NotFound {
+                    // empty hashmap
+                    return Ok(hashmap);
+                } else {
+                    return Err(err);
+                }
             }
         };
         while let Some(entry) = readdir.next_entry().await? {
             // safe to unwrap; filenames are alphanumeric
-            let filename = entry.file_name()
+            let filename = entry
+                .file_name()
                 .into_string()
                 .expect("token filenames should be alphanumeric!");
             if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) {
@@ -166,16 +174,19 @@ impl FileBackend {
                             Err(err) => {
                                 tracing::error!(
                                     "Error decoding token data from file {}: {}",
-                                    entry.path().display(), err
+                                    entry.path().display(),
+                                    err
                                 );
                                 continue;
                             }
                         };
-                    },
-                    Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
-                        continue
-                    } else {
-                        return Err(err)
+                    }
+                    Err(err) => {
+                        if err.kind() == std::io::ErrorKind::NotFound {
+                            continue;
+                        } else {
+                            return Err(err);
+                        }
                     }
                 }
             }
@@ -194,19 +205,27 @@ impl AuthBackend for FileBackend {
 
             path = base.join(&format!(".{}", path.path())).unwrap();
         }
-        tracing::debug!("Initializing File auth backend: {} -> {}", orig_path, path.path());
+        tracing::debug!(
+            "Initializing File auth backend: {} -> {}",
+            orig_path,
+            path.path()
+        );
         Ok(Self {
-            path: std::path::PathBuf::from(path.path())
+            path: std::path::PathBuf::from(path.path()),
         })
     }
 
     // Authorization code management.
     async fn create_code(&self, data: AuthorizationRequest) -> Result<String> {
-        self.serialize_to_file("codes", None, CODE_LENGTH, data).await
+        self.serialize_to_file("codes", None, CODE_LENGTH, data)
+            .await
     }
 
     async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> {
-        match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? {
+        match self
+            .deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref())
+            .await?
+        {
             Some((path, ctime, data)) => {
                 if let Err(err) = tokio::fs::remove_file(path).await {
                     tracing::error!("Failed to clean up authorization code: {}", err);
@@ -217,23 +236,28 @@ impl AuthBackend for FileBackend {
                 } else {
                     Ok(Some(data))
                 }
-            },
-            None => Ok(None)
+            }
+            None => Ok(None),
         }
     }
 
     // Token management.
     async fn create_token(&self, data: TokenData) -> Result<String> {
         let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me));
-        self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await
+        self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data)
+            .await
     }
 
     async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> {
         let dir = format!("{}/tokens", FileBackend::url_to_dir(website));
-        match self.deserialize_from_file::<TokenData, _>(
-            &dir, "access",
-            FileBackend::sanitize_for_path(token).as_ref()
-        ).await? {
+        match self
+            .deserialize_from_file::<TokenData, _>(
+                &dir,
+                "access",
+                FileBackend::sanitize_for_path(token).as_ref(),
+            )
+            .await?
+        {
             Some((path, _, token)) => {
                 if token.expired() {
                     if let Err(err) = tokio::fs::remove_file(path).await {
@@ -243,8 +267,8 @@ impl AuthBackend for FileBackend {
                 } else {
                     Ok(Some(token))
                 }
-            },
-            None => Ok(None)
+            }
+            None => Ok(None),
         }
     }
 
@@ -258,25 +282,36 @@ impl AuthBackend for FileBackend {
             self.path
                 .join(FileBackend::url_to_dir(website))
                 .join("tokens")
-                .join(format!("access.{}", FileBackend::sanitize_for_path(token)))
-        ).await {
+                .join(format!("access.{}", FileBackend::sanitize_for_path(token))),
+        )
+        .await
+        {
             Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
-            result => result
+            result => result,
         }
     }
 
     // Refresh token management.
     async fn create_refresh_token(&self, data: TokenData) -> Result<String> {
         let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me));
-        self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await
+        self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data)
+            .await
     }
 
-    async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> {
+    async fn get_refresh_token(
+        &self,
+        website: &url::Url,
+        token: &str,
+    ) -> Result<Option<TokenData>> {
         let dir = format!("{}/tokens", FileBackend::url_to_dir(website));
-        match self.deserialize_from_file::<TokenData, _>(
-            &dir, "refresh",
-            FileBackend::sanitize_for_path(token).as_ref()
-        ).await? {
+        match self
+            .deserialize_from_file::<TokenData, _>(
+                &dir,
+                "refresh",
+                FileBackend::sanitize_for_path(token).as_ref(),
+            )
+            .await?
+        {
             Some((path, _, token)) => {
                 if token.expired() {
                     if let Err(err) = tokio::fs::remove_file(path).await {
@@ -286,8 +321,8 @@ impl AuthBackend for FileBackend {
                 } else {
                     Ok(Some(token))
                 }
-            },
-            None => Ok(None)
+            }
+            None => Ok(None),
         }
     }
 
@@ -301,57 +336,80 @@ impl AuthBackend for FileBackend {
             self.path
                 .join(FileBackend::url_to_dir(website))
                 .join("tokens")
-                .join(format!("refresh.{}", FileBackend::sanitize_for_path(token)))
-        ).await {
+                .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))),
+        )
+        .await
+        {
             Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
-            result => result
+            result => result,
         }
     }
 
     // Password management.
     #[tracing::instrument(skip(password))]
     async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> {
-        use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}};
+        use argon2::{
+            password_hash::{PasswordHash, PasswordVerifier},
+            Argon2,
+        };
 
-        let password_filename = self.path
+        let password_filename = self
+            .path
             .join(FileBackend::url_to_dir(website))
             .join("password");
 
-        tracing::debug!("Reading password for {} from {}", website, password_filename.display());
+        tracing::debug!(
+            "Reading password for {} from {}",
+            website,
+            password_filename.display()
+        );
 
         match tokio::fs::read_to_string(password_filename).await {
             Ok(password_hash) => {
                 let parsed_hash = {
                     let hash = password_hash.trim();
-                    #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash);
-                    PasswordHash::new(hash)
-                        .expect("Password hash should be valid!")
+                    #[cfg(debug_assertions)]
+                    tracing::debug!("Password hash: {}", hash);
+                    PasswordHash::new(hash).expect("Password hash should be valid!")
                 };
-                Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
-            },
-            Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
-                Ok(false)
-            } else {
-                Err(err)
+                Ok(Argon2::default()
+                    .verify_password(password.as_bytes(), &parsed_hash)
+                    .is_ok())
+            }
+            Err(err) => {
+                if err.kind() == std::io::ErrorKind::NotFound {
+                    Ok(false)
+                } else {
+                    Err(err)
+                }
             }
         }
     }
 
     #[tracing::instrument(skip(password))]
     async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> {
-        use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}};
+        use argon2::{
+            password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
+            Argon2,
+        };
 
-        let password_filename = self.path
+        let password_filename = self
+            .path
             .join(FileBackend::url_to_dir(website))
             .join("password");
 
         let salt = SaltString::generate(&mut OsRng);
         let argon2 = Argon2::default();
-        let password_hash = argon2.hash_password(password.as_bytes(), &salt)
+        let password_hash = argon2
+            .hash_password(password.as_bytes(), &salt)
             .expect("Hashing a password should not error out")
             .to_string();
 
-        tracing::debug!("Enrolling password for {} at {}", website, password_filename.display());
+        tracing::debug!(
+            "Enrolling password for {} at {}",
+            website,
+            password_filename.display()
+        );
         tokio::fs::write(password_filename, password_hash.as_bytes()).await
     }
 
@@ -371,7 +429,7 @@ impl AuthBackend for FileBackend {
     async fn persist_registration_challenge(
         &self,
         website: &url::Url,
-        state: PasskeyRegistration
+        state: PasskeyRegistration,
     ) -> Result<String> {
         todo!()
     }
@@ -380,7 +438,7 @@ impl AuthBackend for FileBackend {
     async fn retrieve_registration_challenge(
         &self,
         website: &url::Url,
-        challenge_id: &str
+        challenge_id: &str,
     ) -> Result<PasskeyRegistration> {
         todo!()
     }
@@ -389,7 +447,7 @@ impl AuthBackend for FileBackend {
     async fn persist_authentication_challenge(
         &self,
         website: &url::Url,
-        state: PasskeyAuthentication
+        state: PasskeyAuthentication,
     ) -> Result<String> {
         todo!()
     }
@@ -398,24 +456,28 @@ impl AuthBackend for FileBackend {
     async fn retrieve_authentication_challenge(
         &self,
         website: &url::Url,
-        challenge_id: &str
+        challenge_id: &str,
     ) -> Result<PasskeyAuthentication> {
         todo!()
     }
 
     #[tracing::instrument(skip(self))]
-    async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> {
+    async fn list_user_credential_types(
+        &self,
+        website: &url::Url,
+    ) -> Result<Vec<EnrolledCredential>> {
         let mut creds = vec![];
-        let password_file = self.path
+        let password_file = self
+            .path
             .join(FileBackend::url_to_dir(website))
             .join("password");
         tracing::debug!("Password file for {}: {}", website, password_file.display());
-        match tokio::fs::metadata(password_file)
-            .await
-        {
+        match tokio::fs::metadata(password_file).await {
             Ok(_) => creds.push(EnrolledCredential::Password),
-            Err(err) => if err.kind() != std::io::ErrorKind::NotFound {
-                return Err(err)
+            Err(err) => {
+                if err.kind() != std::io::ErrorKind::NotFound {
+                    return Err(err);
+                }
             }
         }