use std::{path::PathBuf, collections::HashMap, borrow::Cow, time::{SystemTime, Duration}};
use super::{AuthBackend, Result, EnrolledCredential};
use async_trait::async_trait;
use kittybox_indieauth::{
AuthorizationRequest, TokenData
};
use serde::de::DeserializeOwned;
use tokio::{task::spawn_blocking, io::AsyncReadExt};
#[cfg(feature = "webauthn")]
use webauthn::prelude::{Passkey, PasskeyRegistration, PasskeyAuthentication};
/// Length handed to the temp-file helper when `create_code` names a new
/// authorization code file.
const CODE_LENGTH: usize = 16;
/// Length handed to the temp-file helper when an access or refresh token
/// file is created.
const TOKEN_LENGTH: usize = 128;
/// Maximum age (600 seconds) an authorization code file may reach before
/// `get_code` refuses to return it.
const CODE_DURATION: std::time::Duration = std::time::Duration::from_secs(600);
/// An [`AuthBackend`] implementation that keeps its data in files on disk.
///
/// Authorization codes, tokens and the password hash are all read from and
/// written to locations below `path`.
#[derive(Clone, Debug)]
pub struct FileBackend {
/// Root directory of the store, taken from the path component of the URL
/// passed to `new`.
path: PathBuf,
}
impl FileBackend {
/// Strip every non-alphanumeric character from `filename`.
///
/// Input that is already clean is returned as a borrow; a new `String`
/// is only built when at least one character has to be dropped.
fn sanitize_for_path(filename: &'_ str) -> Cow<'_, str> {
    let needs_cleaning = filename.chars().any(|ch| !ch.is_alphanumeric());
    if !needs_cleaning {
        return Cow::Borrowed(filename);
    }

    let mut cleaned = String::with_capacity(filename.len());
    cleaned.extend(filename.chars().filter(|ch| ch.is_alphanumeric()));
    Cow::Owned(cleaned)
}
/// Serialize `data` as JSON into a new file below `self.path/dir`.
///
/// The file and its name are obtained from `kittybox_util::fs::mktemp`,
/// which receives the target directory, `basename` and `length`.
///
/// Returns the identifying part of the generated name: the extension
/// when a `basename` was supplied, otherwise the whole file name.
///
/// # Errors
///
/// Fails if the file cannot be created, or if encoding or writing the
/// JSON (including the final flush) fails.
///
/// # Panics
///
/// Panics if the tokio file cannot be converted into a std file, if the
/// blocking serialization task cannot be joined, or if the generated
/// name lacks the expected component or is not valid UTF-8.
#[inline]
async fn serialize_to_file<T: 'static + serde::ser::Serialize + Send, B: Into<Option<&'static str>>>(
    &self,
    dir: &str,
    basename: B,
    length: usize,
    data: T
) -> Result<String> {
    let basename = basename.into();
    let has_ext = basename.is_some();
    let (filename, file) = kittybox_util::fs::mktemp(
        self.path.join(dir), basename, length
    )
        .await
        .map(|(name, file)| (name, file.try_into_std().unwrap()))?;

    spawn_blocking(move || -> Result<()> {
        use std::io::Write;

        // serde_json emits many small writes; buffering avoids issuing
        // one syscall per JSON fragment on the raw file handle.
        let mut writer = std::io::BufWriter::new(file);
        serde_json::to_writer(&mut writer, &data)?;
        // Flush explicitly: dropping a BufWriter swallows write errors.
        writer.flush()
    })
        .await
        .unwrap_or_else(|e| panic!(
            "Panic while serializing {}: {}",
            std::any::type_name::<T>(),
            e
        ))
        .map(move |_| {
            // With a basename the name looks like `<basename>.<id>`, so the
            // id is the extension; without one the id is the whole name.
            let id = if has_ext {
                filename.extension()
            } else {
                filename.file_name()
            };
            id.unwrap()
                .to_str()
                .unwrap()
                .to_owned()
        })
}
/// Read and JSON-decode a single file stored below `self.path/dir`.
///
/// The file name is `<basename>.<filename>` when a `basename` is given,
/// or just `<filename>` otherwise. `filename` is passed through
/// [`FileBackend::sanitize_for_path`] before use.
///
/// Returns `Ok(None)` when no such file exists (or when `filename`
/// contains no alphanumeric characters at all). Otherwise returns the
/// file's path, its creation time and the decoded value.
///
/// # Errors
///
/// Fails on any I/O error other than "not found", and when the file's
/// contents cannot be decoded as `T`.
#[inline]
async fn deserialize_from_file<'filename, 'this: 'filename, T, B>(
    &'this self,
    dir: &'filename str,
    basename: B,
    filename: &'filename str,
) -> Result<Option<(PathBuf, SystemTime, T)>>
where
    T: serde::de::DeserializeOwned + Send,
    B: Into<Option<&'static str>>
{
    let sanitized = FileBackend::sanitize_for_path(filename);
    // Joining an empty name would make `path` refer to the directory
    // itself, turning a bogus lookup into an I/O error instead of a miss.
    if sanitized.is_empty() {
        return Ok(None);
    }

    let file_name = match basename.into() {
        Some(prefix) => format!("{}.{}", prefix, sanitized),
        None => sanitized.into_owned(),
    };
    let path = self.path.join(dir).join(file_name);

    let mut file = match tokio::fs::File::open(&path).await {
        Ok(file) => file,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
        Err(err) => return Err(err),
    };

    let mut buf = Vec::new();
    file.read_to_end(&mut buf).await?;
    let data = serde_json::from_slice::<T>(&buf)?;

    // Ask the open handle rather than the path, so a concurrent removal
    // of the file cannot fail the lookup after the data was already read.
    let metadata = file.metadata().await?;
    // Creation time is not available on every platform/filesystem; the
    // modification time is the closest substitute.
    let ctime = metadata.created().or_else(|_| metadata.modified())?;

    Ok(Some((path, ctime, data)))
}
#[inline]
fn url_to_dir(url: &url::Url) -> String {
let host = url.host_str().unwrap();
let port = url.port()
.map(|port| Cow::Owned(format!(":{}", port)))
.unwrap_or(Cow::Borrowed(""));
format!("{}{}", host, port)
}
/// Decode every `<prefix>.<name>` file found in `self.path/dir`.
///
/// Returns a map from `<name>` to the decoded value. A missing directory
/// yields an empty map. Entries are skipped (not treated as errors) when
/// their name is not valid UTF-8, does not start with `<prefix>.`, has
/// disappeared by the time it is opened, or cannot be decoded as `T`;
/// decoding failures are logged.
///
/// # Errors
///
/// Fails on any other I/O error while listing the directory or reading
/// a file.
async fn list_files<'dir, 'this: 'dir, T: DeserializeOwned + Send>(
    &'this self,
    dir: &'dir str,
    prefix: &'static str
) -> Result<HashMap<String, T>> {
    let dir = self.path.join(dir);
    let mut hashmap = HashMap::new();
    let mut readdir = match tokio::fs::read_dir(dir).await {
        Ok(readdir) => readdir,
        // No directory means nothing has been stored yet.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(hashmap),
        Err(err) => return Err(err),
    };

    // Built once instead of on every directory entry.
    let prefix = format!("{}.", prefix);

    while let Some(entry) = readdir.next_entry().await? {
        // A stray file with a non-UTF-8 name must not take the whole
        // listing down; skip it.
        let filename = match entry.file_name().into_string() {
            Ok(filename) => filename,
            Err(_) => {
                tracing::warn!(
                    "Skipping file with a non-UTF-8 name: {}",
                    entry.path().display()
                );
                continue;
            }
        };
        let token = match filename.strip_prefix(prefix.as_str()) {
            Some(token) => token,
            None => continue,
        };

        let mut file = match tokio::fs::File::open(entry.path()).await {
            Ok(file) => file,
            // Removed between listing and opening.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => continue,
            Err(err) => return Err(err),
        };
        let mut buf = Vec::new();
        file.read_to_end(&mut buf).await?;

        match serde_json::from_slice::<T>(&buf) {
            Ok(data) => {
                hashmap.insert(token.to_string(), data);
            }
            Err(err) => {
                tracing::error!(
                    "Error decoding token data from file {}: {}",
                    entry.path().display(), err
                );
            }
        }
    }
    Ok(hashmap)
}
}
#[async_trait]
impl AuthBackend for FileBackend {
async fn new(path: &'_ url::Url) -> Result<Self> {
Ok(Self {
path: std::path::PathBuf::from(path.path())
})
}
// Authorization code management.
/// Persist an authorization request under `codes/` and return the
/// generated code that identifies it.
async fn create_code(&self, data: AuthorizationRequest) -> Result<String> {
    let code = self
        .serialize_to_file("codes", None, CODE_LENGTH, data)
        .await?;
    Ok(code)
}
/// Redeem an authorization code.
///
/// The code's file is removed as soon as it has been read, whether or
/// not the code is still valid. Returns `Ok(None)` when the code is
/// unknown or older than `CODE_DURATION`.
async fn get_code(&self, code: &str) -> Result<Option<AuthorizationRequest>> {
    let name = FileBackend::sanitize_for_path(code);
    let stored = self
        .deserialize_from_file::<AuthorizationRequest, _>("codes", None, &name)
        .await?;
    let (path, ctime, data) = match stored {
        Some(entry) => entry,
        None => return Ok(None),
    };

    // Removal failing is only logged; the code is still evaluated below.
    if let Err(err) = tokio::fs::remove_file(path).await {
        tracing::error!("Failed to clean up authorization code: {}", err);
    }

    // Err on the safe side in case of clock drift: a creation time in
    // the future counts as zero elapsed time.
    let age = ctime.elapsed().unwrap_or(Duration::ZERO);
    Ok((age <= CODE_DURATION).then(|| data))
}
// Token management.
/// Persist an access token for the site named by `data.me` and return
/// the generated token string.
async fn create_token(&self, data: TokenData) -> Result<String> {
    let site_dir = FileBackend::url_to_dir(&data.me);
    let tokens_dir = format!("{}/tokens", site_dir);
    self.serialize_to_file(&tokens_dir, "access", TOKEN_LENGTH, data).await
}
/// Look up an access token belonging to `website`.
///
/// Returns `Ok(None)` when the token is unknown or has expired; an
/// expired token's file is removed (a failed removal is only logged).
async fn get_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> {
    let tokens_dir = format!("{}/tokens", FileBackend::url_to_dir(website));
    let name = FileBackend::sanitize_for_path(token);
    let stored = self
        .deserialize_from_file::<TokenData, _>(&tokens_dir, "access", &name)
        .await?;
    let (path, _, data) = match stored {
        Some(entry) => entry,
        None => return Ok(None),
    };

    if !data.expired() {
        return Ok(Some(data));
    }

    if let Err(err) = tokio::fs::remove_file(path).await {
        tracing::error!("Failed to remove expired token: {}", err);
    }
    Ok(None)
}
/// List every stored access token for `website`, keyed by token string.
async fn list_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> {
    let site_dir = FileBackend::url_to_dir(website);
    let tokens_dir = format!("{}/tokens", site_dir);
    self.list_files(&tokens_dir, "access").await
}
async fn revoke_token(&self, website: &url::Url, token: &str) -> Result<()> {
match tokio::fs::remove_file(
self.path
.join(FileBackend::url_to_dir(website))
.join("tokens")
.join(format!("access.{}", FileBackend::sanitize_for_path(token)))
).await {
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
result => result
}
}
// Refresh token management.
/// Persist a refresh token for the site named by `data.me` and return
/// the generated token string.
async fn create_refresh_token(&self, data: TokenData) -> Result<String> {
    let site_dir = FileBackend::url_to_dir(&data.me);
    let tokens_dir = format!("{}/tokens", site_dir);
    self.serialize_to_file(&tokens_dir, "refresh", TOKEN_LENGTH, data).await
}
/// Look up a refresh token belonging to `website`.
///
/// Returns `Ok(None)` when the token is unknown or has expired; an
/// expired token's file is removed (a failed removal is only logged).
async fn get_refresh_token(&self, website: &url::Url, token: &str) -> Result<Option<TokenData>> {
    let tokens_dir = format!("{}/tokens", FileBackend::url_to_dir(website));
    let name = FileBackend::sanitize_for_path(token);
    let stored = self
        .deserialize_from_file::<TokenData, _>(&tokens_dir, "refresh", &name)
        .await?;
    let (path, _, data) = match stored {
        Some(entry) => entry,
        None => return Ok(None),
    };

    if !data.expired() {
        return Ok(Some(data));
    }

    if let Err(err) = tokio::fs::remove_file(path).await {
        tracing::error!("Failed to remove expired token: {}", err);
    }
    Ok(None)
}
/// List every stored refresh token for `website`, keyed by token string.
async fn list_refresh_tokens(&self, website: &url::Url) -> Result<HashMap<String, TokenData>> {
    let site_dir = FileBackend::url_to_dir(website);
    let tokens_dir = format!("{}/tokens", site_dir);
    self.list_files(&tokens_dir, "refresh").await
}
async fn revoke_refresh_token(&self, website: &url::Url, token: &str) -> Result<()> {
match tokio::fs::remove_file(
self.path
.join(FileBackend::url_to_dir(website))
.join("tokens")
.join(format!("refresh.{}", FileBackend::sanitize_for_path(token)))
).await {
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
result => result
}
}
// Password management.
#[tracing::instrument(skip(password))]
async fn verify_password(&self, website: &url::Url, password: String) -> Result<bool> {
use argon2::{Argon2, password_hash::{PasswordHash, PasswordVerifier}};
let password_filename = self.path
.join(FileBackend::url_to_dir(website))
.join("password");
tracing::debug!("Reading password for {} from {}", website, password_filename.display());
match tokio::fs::read_to_string(password_filename).await {
Ok(password_hash) => {
let parsed_hash = {
let hash = password_hash.trim();
#[cfg(debug_assertions)] tracing::debug!("Password hash: {}", hash);
PasswordHash::new(hash)
.expect("Password hash should be valid!")
};
Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
},
Err(err) => if err.kind() == std::io::ErrorKind::NotFound {
Ok(false)
} else {
Err(err)
}
}
}
#[tracing::instrument(skip(password))]
async fn enroll_password(&self, website: &url::Url, password: String) -> Result<()> {
use argon2::{Argon2, password_hash::{rand_core::OsRng, PasswordHasher, SaltString}};
let password_filename = self.path
.join(FileBackend::url_to_dir(website))
.join("password");
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2.hash_password(password.as_bytes(), &salt)
.expect("Hashing a password should not error out")
.to_string();
tracing::debug!("Enrolling password for {} at {}", website, password_filename.display());
tokio::fs::write(password_filename, password_hash.as_bytes()).await
}
// WebAuthn credential management.
/// Placeholder for enrolling a WebAuthn passkey for `website`.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics.
#[cfg(feature = "webauthn")]
async fn enroll_webauthn(&self, website: &url::Url, credential: Passkey) -> Result<()> {
todo!()
}
/// List the WebAuthn passkeys enrolled for `website`.
///
/// Currently a stub: it ignores `website` and always returns an empty
/// list.
#[cfg(feature = "webauthn")]
async fn list_webauthn_pubkeys(&self, website: &url::Url) -> Result<Vec<Passkey>> {
// TODO stub!
Ok(vec![])
}
/// Placeholder for storing an in-progress passkey registration state
/// and returning an identifier for it.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics.
#[cfg(feature = "webauthn")]
async fn persist_registration_challenge(
&self,
website: &url::Url,
state: PasskeyRegistration
) -> Result<String> {
todo!()
}
/// Placeholder for loading a passkey registration state by
/// `challenge_id`.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics.
#[cfg(feature = "webauthn")]
async fn retrieve_registration_challenge(
&self,
website: &url::Url,
challenge_id: &str
) -> Result<PasskeyRegistration> {
todo!()
}
/// Placeholder for storing an in-progress passkey authentication state
/// and returning an identifier for it.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics.
#[cfg(feature = "webauthn")]
async fn persist_authentication_challenge(
&self,
website: &url::Url,
state: PasskeyAuthentication
) -> Result<String> {
todo!()
}
/// Placeholder for loading a passkey authentication state by
/// `challenge_id`.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics.
#[cfg(feature = "webauthn")]
async fn retrieve_authentication_challenge(
&self,
website: &url::Url,
challenge_id: &str
) -> Result<PasskeyAuthentication> {
todo!()
}
/// Report which kinds of credentials are enrolled for `website`.
///
/// A password counts as enrolled when the site's `password` file
/// exists. With the `webauthn` feature enabled, WebAuthn counts as
/// enrolled when at least one passkey is listed.
async fn list_user_credential_types(&self, website: &url::Url) -> Result<Vec<EnrolledCredential>> {
    let password_file = self.path
        .join(FileBackend::url_to_dir(website))
        .join("password");

    let mut creds = Vec::new();
    match tokio::fs::metadata(password_file).await {
        Ok(_) => creds.push(EnrolledCredential::Password),
        // A missing file simply means no password is enrolled.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
        Err(err) => return Err(err),
    }

    #[cfg(feature = "webauthn")]
    if !self.list_webauthn_pubkeys(website).await?.is_empty() {
        creds.push(EnrolledCredential::WebAuthn);
    }

    Ok(creds)
}
}