about summary refs log blame commit diff
path: root/src/indieauth/backend/fs.rs
blob: f74fbbcd0019bf50a7d9149f6f09a0134b227e81 (plain) (tree)
1
2
3
4
5
6
7
8
9

                                                                                          



                                                    
                            




                                                                               
                       



                        











































                                                                                                         
 























































                                                                       
                          




























































                                                                                 
                                  
                                                      






                                                                                            



                                                       











































































































                                                                                                               
                                          





                                                                                           
                                                                                                 
















                                                                                                
                                          




                                                                                            
 




                                                                            
                                                                                                 


                                                                           
                                


                                                                                            
                                



                                                                                       
                                






                                            
                                






                                             
                                






                                              
                                






                                               
                                      
                                                                                                       



                                                                                      
                  
                                                              



                                                                       
                                    





                                                                   
use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}};

use super::{AuthBackend, Result, EnrolledCredential};
use kittybox_indieauth::{
    AuthorizationRequest, TokenData
};
use serde::de::DeserializeOwned;
use tokio::{task::spawn_blocking, io::AsyncReadExt};
#[cfg(feature = "webauthn")]
use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication};

/// Length passed to `mktemp` when creating an authorization code file.
const CODE_LENGTH: usize = 16;
/// Length passed to `mktemp` when creating an access or refresh token file.
const TOKEN_LENGTH: usize = 128;
/// Maximum age of an authorization code file before `get_code` treats it
/// as absent (ten minutes).
const CODE_DURATION: Duration = Duration::from_secs(10 * 60);

/// Filesystem-backed [`AuthBackend`] implementation.
///
/// Every piece of state (authorization codes, tokens, password hashes) is
/// stored as a file somewhere below `path`.
#[derive(Debug, Clone)]
pub struct FileBackend {
    /// Root directory under which all backend files are stored.
    path: PathBuf,
}

impl FileBackend {
    /// Sanitize a filename, leaving only alphanumeric characters.
    ///
    /// Doesn't allocate a new string unless non-alphanumeric
    /// characters are encountered.
    fn sanitize_for_path(filename: &'_ str) -> Cow<'_, str> {
        if filename.chars().all(char::is_alphanumeric) {
            Cow::Borrowed(filename)
        } else {
            let mut s = String::with_capacity(filename.len());

            filename.chars()
                .filter(|c| c.is_alphanumeric())
                .for_each(|c| s.push(c));

            Cow::Owned(s)
        }
    }

    #[inline]
    async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>(
        &self,
        dir: &str,
        basename: B,
        length: usize,
        data: T
    ) -> Result<String> {
        let basename = basename.into();
        let has_ext = basename.is_some();
        let (filename, mut file) = kittybox_util::fs::mktemp(
            self.path.join(dir), basename, length
        )
            .await
            .map(|(name, file)| (name, file.try_into_std().unwrap()))?;

        spawn_blocking(move || serde_json::to_writer(&mut file, &data))
            .await
            .unwrap_or_else(|e| panic!(
                "Panic while serializing {}: {}",
                std::any::type_name::<T>(),
                e
            ))
            .map(move |_| {
                (if has_ext {
                    filename
                        .extension()

                } else {
                    filename
                        .file_name()
                })
                    .unwrap()
                    .to_str()
                    .unwrap()
                    .to_owned()
            })
            .map_err(|err| err.into())
    }

    #[inline]
    async fn deserialize_from_file<'filename, 'this: 'filename, T, B>(
        &'this self,
        dir: &'filename str,
        basename: B,
        filename: &'filename str,
    ) -> Result<Option<(PathBuf, SystemTime, T)>>
    where
        T: serde::de::DeserializeOwned + Send,
        B: Into<Option<&'static str>>
    {
        let basename = basename.into();
        let path = self.path
            .join(dir)
            .join(format!(
                "{}{}{}",
                basename.unwrap_or(""),
                if basename.is_none() { "" } else { "." },
                FileBackend::sanitize_for_path(filename)
            ));

        let data = match tokio::fs::File::open(&path).await {
            Ok(mut file) => {
                let mut buf = Vec::new();

                file.read_to_end(&mut buf).await?;

                match serde_json::from_slice::<'_, T>(buf.as_slice()) {
                    Ok(data) => data,
                    Err(err) => return Err(err.into())
                }
            },
            Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
                return Ok(None)
            } else {
                return Err(err)
            }
        };

        let ctime = tokio::fs::metadata(&path).await?.created()?;

        Ok(Some((path, ctime, data)))
    }

    #[inline]
    #[tracing::instrument]
    fn url_to_dir(url: &url::Url) -> String {
        let host = url.host_str().unwrap();
        let port = url.port()
            .map(|port| Cow::Owned(format!(":{}", port)))
            .unwrap_or(Cow::Borrowed(""));

        format!("{}{}", host, port)
    }

    async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>(
        &'this self,
        dir: &'dir str,
        prefix: &'static str
    ) -> Result<HashMap<String, T>> {
        let dir = self.path.join(dir);

        let mut hashmap = HashMap::new();
        let mut readdir = match tokio::fs::read_dir(dir).await {
            Ok(readdir) => readdir,
            Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
                // empty hashmap
                return Ok(hashmap);
            } else {
                return Err(err);
            }
        };
        while let Some(entry) = readdir.next_entry().await? {
            // safe to unwrap; filenames are alphanumeric
            let filename = entry.file_name()
                .into_string()
                .expect("token filenames should be alphanumeric!");
            if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) {
                match tokio::fs::File::open(entry.path()).await {
                    Ok(mut file) => {
                        let mut buf = Vec::new();

                        file.read_to_end(&mut buf).await?;

                        match serde_json::from_slice::<'_, T>(buf.as_slice()) {
                            Ok(data) => hashmap.insert(token.to_string(), data),
                            Err(err) => {
                                tracing::error!(
                                    "Error decoding token data from file {}: {}",
                                    entry.path().display(), err
                                );
                                continue;
                            }
                        };
                    },
                    Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
                        continue
                    } else {
                        return Err(err)
                    }
                }
            }
        }

        Ok(hashmap)
    }
}

impl AuthBackend for FileBackend {
    /// Create a backend rooted at the directory described by `path`.
    ///
    /// A URL whose host is `.` is resolved relative to the current working
    /// directory. The URL path is percent-decoded via
    /// [`url::Url::to_file_path`], so directories containing spaces or
    /// non-ASCII characters are found correctly. If that conversion fails,
    /// the raw URL path is used, as before.
    async fn new(path: &'_ url::Url) -> Result<Self> {
        let orig_path = path;
        let path = if orig_path.host_str() == Some(".") {
            let base = url::Url::from_directory_path(std::env::current_dir()?)
                .expect("current_dir() should return an absolute path");
            base.join(&format!(".{}", orig_path.path()))
                .expect("joining a relative path onto a directory URL should not fail")
        } else {
            orig_path.clone()
        };
        let fs_path = path
            .to_file_path()
            .unwrap_or_else(|()| PathBuf::from(path.path()));
        tracing::debug!("Initializing File auth backend: {} -> {}", orig_path, fs_path.display());
        Ok(Self {
            path: fs_path
        })
    }

    // Authorization code management.

    /// Store an authorization request and return its newly generated code.
    async fn create_code(&self, data: AuthorizationRequest) -> Result<String> {
        self.serialize_to_file("codes", None, CODE_LENGTH, data).await
    }

    /// Look up and consume an authorization code.
    ///
    /// The code file is deleted whenever it is found, so each code can be
    /// redeemed at most once. Codes older than `CODE_DURATION` yield `None`.
    async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> {
        // `deserialize_from_file` sanitizes the name itself.
        match self.deserialize_from_file("codes", None, code).await? {
            Some((path, ctime, data)) => {
                if let Err(err) = tokio::fs::remove_file(path).await {
                    tracing::error!("Failed to clean up authorization code: {}", err);
                }
                // A creation time in the future (clock drift) makes `elapsed()`
                // fail; treat that as an age of zero.
                if ctime.elapsed().unwrap_or(Duration::ZERO) > CODE_DURATION {
                    Ok(None)
                } else {
                    Ok(Some(data))
                }
            },
            None => Ok(None)
        }
    }

    // Token management.

    /// Persist an access token under its owner's (`data.me`) directory and
    /// return the generated token string.
    async fn create_token(&self, data: TokenData) -> Result<String> {
        let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me));
        self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await
    }

    /// Fetch an access token; expired tokens are deleted and reported as `None`.
    async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> {
        let dir = format!("{}/tokens", FileBackend::url_to_dir(website));
        match self.deserialize_from_file::<TokenData, _>(&dir, "access", token).await? {
            Some((path, _, token)) => {
                if token.expired() {
                    if let Err(err) = tokio::fs::remove_file(path).await {
                        tracing::error!("Failed to remove expired token: {}", err);
                    }
                    Ok(None)
                } else {
                    Ok(Some(token))
                }
            },
            None => Ok(None)
        }
    }

    /// List every stored access token for `website`, keyed by token string.
    async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> {
        let dir = format!("{}/tokens", FileBackend::url_to_dir(website));
        self.list_files(&dir, "access").await
    }

    /// Delete an access token; revoking an unknown token is not an error.
    async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()> {
        match tokio::fs::remove_file(
            self.path
                .join(FileBackend::url_to_dir(website))
                .join("tokens")
                .join(format!("access.{}", FileBackend::sanitize_for_path(token)))
        ).await {
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            result => result
        }
    }

    // Refresh token management.

    /// Persist a refresh token under its owner's (`data.me`) directory and
    /// return the generated token string.
    async fn create_refresh_token(&self, data: TokenData) -> Result<String> {
        let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me));
        self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await
    }

    /// Fetch a refresh token; expired tokens are deleted and reported as `None`.
    async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> {
        let dir = format!("{}/tokens", FileBackend::url_to_dir(website));
        match self.deserialize_from_file::<TokenData, _>(&dir, "refresh", token).await? {
            Some((path, _, token)) => {
                if token.expired() {
                    if let Err(err) = tokio::fs::remove_file(path).await {
                        tracing::error!("Failed to remove expired token: {}", err);
                    }
                    Ok(None)
                } else {
                    Ok(Some(token))
                }
            },
            None => Ok(None)
        }
    }

    /// List every stored refresh token for `website`, keyed by token string.
    async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> {
        let dir = format!("{}/tokens", FileBackend::url_to_dir(website));
        self.list_files(&dir, "refresh").await
    }

    /// Delete a refresh token; revoking an unknown token is not an error.
    async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()> {
        match tokio::fs::remove_file(
            self.path
                .join(FileBackend::url_to_dir(website))
                .join("tokens")
                .join(format!("refresh.{}", FileBackend::sanitize_for_path(token)))
        ).await {
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            result => result
        }
    }

    // Password management.

    /// Check `password` against the Argon2 hash stored for `website`.
    ///
    /// Returns `Ok(false)` when no password is enrolled. Verification is
    /// CPU-intensive, so it runs on the blocking thread pool rather than on
    /// the async executor.
    ///
    /// # Errors
    ///
    /// I/O errors other than `NotFound`; `InvalidData` if the stored hash
    /// cannot be parsed; `Other` if the blocking task fails.
    #[tracing::instrument(skip(password))]
    async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> {
        use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}};

        let password_filename = self.path
            .join(FileBackend::url_to_dir(website))
            .join("password");

        tracing::debug!("Reading password for {} from {}", website, password_filename.display());

        let password_hash = match tokio::fs::read_to_string(password_filename).await {
            Ok(password_hash) => password_hash,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(false),
            Err(err) => return Err(err),
        };

        spawn_blocking(move || -> Result<bool> {
            // A corrupt password file is reported as an error instead of
            // panicking the request task.
            let parsed_hash = PasswordHash::new(password_hash.trim()).map_err(|err| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Stored password hash is invalid: {}", err),
                )
            })?;
            Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
        })
            .await
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?
    }

    /// Hash `password` with Argon2 and a fresh random salt, then store it as
    /// the password for `website`, replacing any previous one.
    ///
    /// The per-site directory is created if it does not exist yet. Hashing
    /// runs on the blocking thread pool.
    #[tracing::instrument(skip(password))]
    async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> {
        use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}};

        let password_filename = self.path
            .join(FileBackend::url_to_dir(website))
            .join("password");

        let password_hash = spawn_blocking(move || {
            let salt = SaltString::generate(&mut OsRng);
            Argon2::default()
                .hash_password(password.as_bytes(), &salt)
                .expect("Hashing a password should not error out")
                .to_string()
        })
            .await
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;

        if let Some(parent) = password_filename.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        tracing::debug!("Enrolling password for {} at {}", website, password_filename.display());
        tokio::fs::write(password_filename, password_hash.as_bytes()).await
    }

    // WebAuthn credential management.

    /// Not implemented yet for the filesystem backend.
    #[cfg(feature = "webauthn")]
    async fn enroll_webauthn(&self, website: &url::Url, credential: Passkey) -> Result<()> {
        todo!()
    }

    /// Stub: always reports that no passkeys are enrolled.
    #[cfg(feature = "webauthn")]
    async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<Passkey>> {
        // TODO stub!
        Ok(vec![])
    }

    /// Not implemented yet for the filesystem backend.
    #[cfg(feature = "webauthn")]
    async fn persist_registration_challenge(
        &self,
        website: &url::Url,
        state: PasskeyRegistration
    ) -> Result<String> {
        todo!()
    }

    /// Not implemented yet for the filesystem backend.
    #[cfg(feature = "webauthn")]
    async fn retrieve_registration_challenge(
        &self,
        website: &url::Url,
        challenge_id: &str
    ) -> Result<PasskeyRegistration> {
        todo!()
    }

    /// Not implemented yet for the filesystem backend.
    #[cfg(feature = "webauthn")]
    async fn persist_authentication_challenge(
        &self,
        website: &url::Url,
        state: PasskeyAuthentication
    ) -> Result<String> {
        todo!()
    }

    /// Not implemented yet for the filesystem backend.
    #[cfg(feature = "webauthn")]
    async fn retrieve_authentication_challenge(
        &self,
        website: &url::Url,
        challenge_id: &str
    ) -> Result<PasskeyAuthentication> {
        todo!()
    }

    /// Report which credential kinds are enrolled for `website`.
    ///
    /// A password counts as enrolled if the password file exists. With the
    /// `webauthn` feature, WebAuthn counts if any passkeys are listed.
    #[tracing::instrument(skip(self))]
    async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> {
        let mut creds = vec![];
        let password_file = self.path
            .join(FileBackend::url_to_dir(website))
            .join("password");
        tracing::debug!("Password file for {}: {}", website, password_file.display());
        match tokio::fs::metadata(password_file)
            .await
        {
            Ok(_) => creds.push(EnrolledCredential::Password),
            Err(err) => if err.kind() != std::io::ErrorKind::NotFound {
                return Err(err)
            }
        }

        #[cfg(feature = "webauthn")]
        if !self.list_webauthn_pubkeys(website).await?.is_empty() {
            creds.push(EnrolledCredential::WebAuthn);
        }

        Ok(creds)
    }
}