use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}}; use super::{AuthBackend, Result, EnrolledCredential}; use async_trait::async_trait; use kittybox_indieauth::{ AuthorizationRequest, TokenData }; use serde::de::DeserializeOwned; use tokio::{task::spawn_blocking, io::AsyncReadExt}; #[cfg(feature = "webauthn")] use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication}; const CODE_LENGTH: usize = 16; const TOKEN_LENGTH: usize = 128; const CODE_DURATION: std::time::Duration = std::time::Duration::from_secs(600); #[derive(Clone, Debug)] pub struct FileBackend { path: PathBuf, } impl FileBackend { /// Sanitize a filename, leaving only alphanumeric characters. /// /// Doesn't allocate a new string unless non-alphanumeric /// characters are encountered. fn sanitize_for_path(filename: &'_ str) -> Cow<'_, str> { if filename.chars().all(char::is_alphanumeric) { Cow::Borrowed(filename) } else { let mut s = String::with_capacity(filename.len()); filename.chars() .filter(|c| c.is_alphanumeric()) .for_each(|c| s.push(c)); Cow::Owned(s) } } #[inline] async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>( &self, dir: &str, basename: B, length: usize, data: T ) -> Result<String> { let basename = basename.into(); let has_ext = basename.is_some(); let (filename, mut file) = kittybox_util::fs::mktemp( self.path.join(dir), basename, length ) .await .map(|(name, file)| (name, file.try_into_std().unwrap()))?; spawn_blocking(move || serde_json::to_writer(&mut file, &data)) .await .unwrap_or_else(|e| panic!( "Panic while serializing {}: {}", std::any::type_name::<T>(), e )) .map(move |_| { (if has_ext { filename .extension() } else { filename .file_name() }) .unwrap() .to_str() .unwrap() .to_owned() }) .map_err(|err| err.into()) } #[inline] async fn deserialize_from_file<'filename, 'this: 'filename, T, B>( &'this self, dir: &'filename str, basename: B, filename: &'filename str, ) -> Result<Option<(PathBuf, SystemTime, T)>> where T: serde::de::DeserializeOwned + Send, B: Into<Option<&'static str>> { let basename = basename.into(); let path = self.path .join(dir) .join(format!( "{}{}{}", basename.unwrap_or(""), if basename.is_none() { "" } else { "." }, FileBackend::sanitize_for_path(filename) )); let data = match tokio::fs::File::open(&path).await { Ok(mut file) => { let mut buf = Vec::new(); file.read_to_end(&mut buf).await?; match serde_json::from_slice::<'_, T>(buf.as_slice()) { Ok(data) => data, Err(err) => return Err(err.into()) } }, Err(err) => if err.kind() == std::io::ErrorKind::NotFound { return Ok(None) } else { return Err(err) } }; let ctime = tokio::fs::metadata(&path).await?.created()?; Ok(Some((path, ctime, data))) } #[inline] fn url_to_dir(url: &url::Url) -> String { let host = url.host_str().unwrap(); let port = url.port() .map(|port| Cow::Owned(format!(":{}", port))) .unwrap_or(Cow::Borrowed("")); format!("{}{}", host, port) } async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>( &'this self, dir: &'dir str, prefix: &'static str ) -> Result<HashMap<String, T>> { let dir = self.path.join(dir); let mut hashmap = HashMap::new(); let mut readdir = match tokio::fs::read_dir(dir).await { Ok(readdir) => readdir, Err(err) => if err.kind() == std::io::ErrorKind::NotFound { // empty hashmap return Ok(hashmap); } else { return Err(err); } }; while let Some(entry) = readdir.next_entry().await? { // safe to unwrap; filenames are alphanumeric let filename = entry.file_name() .into_string() .expect("token filenames should be alphanumeric!"); if let Some(token) = filename.strip_prefix(&format!("{}.", prefix)) { match tokio::fs::File::open(entry.path()).await { Ok(mut file) => { let mut buf = Vec::new(); file.read_to_end(&mut buf).await?; match serde_json::from_slice::<'_, T>(buf.as_slice()) { Ok(data) => hashmap.insert(token.to_string(), data), Err(err) => { tracing::error!( "Error decoding token data from file {}: {}", entry.path().display(), err ); continue; } }; }, Err(err) => if err.kind() == std::io::ErrorKind::NotFound { continue } else { return Err(err) } } } } Ok(hashmap) } } #[async_trait] impl AuthBackend for FileBackend { async fn new(path: &'_ url::Url) -> Result<Self> { Ok(Self { path: std::path::PathBuf::from(path.path()) }) } // Authorization code management. async fn create_code(&self, data: AuthorizationRequest) -> Result<String> { self.serialize_to_file("codes", None, CODE_LENGTH, data).await } async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> { match self.deserialize_from_file("codes", None, FileBackend::sanitize_for_path(code).as_ref()).await? { Some((path, ctime, data)) => { if let Err(err) = tokio::fs::remove_file(path).await { tracing::error!("Failed to clean up authorization code: {}", err); } // Err on the safe side in case of clock drift if ctime.elapsed().unwrap_or(Duration::ZERO) > CODE_DURATION { Ok(None) } else { Ok(Some(data)) } }, None => Ok(None) } } // Token management. async fn create_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); self.serialize_to_file(&dir, "access", TOKEN_LENGTH, data).await } async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); match self.deserialize_from_file::<TokenData, _>( &dir, "access", FileBackend::sanitize_for_path(token).as_ref() ).await? { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { tracing::error!("Failed to remove expired token: {}", err); } Ok(None) } else { Ok(Some(token)) } }, None => Ok(None) } } async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); self.list_files(&dir, "access").await } async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()> { match tokio::fs::remove_file( self.path .join(FileBackend::url_to_dir(website)) .join("tokens") .join(format!("access.{}", FileBackend::sanitize_for_path(token))) ).await { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), result => result } } // Refresh token management. async fn create_refresh_token(&self, data: TokenData) -> Result<String> { let dir = format!("{}/tokens", FileBackend::url_to_dir(&data.me)); self.serialize_to_file(&dir, "refresh", TOKEN_LENGTH, data).await } async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); match self.deserialize_from_file::<TokenData, _>( &dir, "refresh", FileBackend::sanitize_for_path(token).as_ref() ).await? { Some((path, _, token)) => { if token.expired() { if let Err(err) = tokio::fs::remove_file(path).await { tracing::error!("Failed to remove expired token: {}", err); } Ok(None) } else { Ok(Some(token)) } }, None => Ok(None) } } async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> { let dir = format!("{}/tokens", FileBackend::url_to_dir(website)); self.list_files(&dir, "refresh").await } async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()> { match tokio::fs::remove_file( self.path .join(FileBackend::url_to_dir(website)) .join("tokens") .join(format!("refresh.{}", FileBackend::sanitize_for_path(token))) ).await { Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()), result => result } } // Password management. #[tracing::instrument(skip(password))] async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> { use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}}; let password_filename = self.path .join(FileBackend::url_to_dir(website)) .join("password"); tracing::debug!("Reading password for {} from {}", website, password_filename.display()); match tokio::fs::read_to_string(password_filename).await { Ok(password_hash) => { let parsed_hash = { let hash = password_hash.trim(); #[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash); PasswordHash::new(hash) .expect("Password hash should be valid!") }; Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) }, Err(err) => if err.kind() == std::io::ErrorKind::NotFound { Ok(false) } else { Err(err) } } } #[tracing::instrument(skip(password))] async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> { use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}}; let password_filename = self.path .join(FileBackend::url_to_dir(website)) .join("password"); let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = argon2.hash_password(password.as_bytes(), &salt) .expect("Hashing a password should not error out") .to_string(); tracing::debug!("Enrolling password for {} at {}", website, password_filename.display()); tokio::fs::write(password_filename, password_hash.as_bytes()).await } // WebAuthn credential management. #[cfg(feature = "webauthn")] async fn enroll_webauthn(&self, website: &url::Url, credential: Passkey) -> Result<()> { todo!() } #[cfg(feature = "webauthn")] async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<Passkey>> { // TODO stub! Ok(vec![]) } #[cfg(feature = "webauthn")] async fn persist_registration_challenge( &self, website: &url::Url, state: PasskeyRegistration ) -> Result<String> { todo!() } #[cfg(feature = "webauthn")] async fn retrieve_registration_challenge( &self, website: &url::Url, challenge_id: &str ) -> Result<PasskeyRegistration> { todo!() } #[cfg(feature = "webauthn")] async fn persist_authentication_challenge( &self, website: &url::Url, state: PasskeyAuthentication ) -> Result<String> { todo!() } #[cfg(feature = "webauthn")] async fn retrieve_authentication_challenge( &self, website: &url::Url, challenge_id: &str ) -> Result<PasskeyAuthentication> { todo!() } async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> { let mut creds = vec![]; match tokio::fs::metadata(self.path .join(FileBackend::url_to_dir(website)) .join("password")) .await { Ok(_) => creds.push(EnrolledCredential::Password), Err(err) => if err.kind() != std::io::ErrorKind::NotFound { return Err(err) } } #[cfg(feature = "webauthn")] if !self.list_webauthn_pubkeys(website).await?.is_empty() { creds.push(EnrolledCredential::WebAuthn); } Ok(creds) } }