use async_trait::async_trait;
use futures::stream;
use futures_util::FutureExt;
use futures_util::StreamExt;
use lazy_static::lazy_static;
use log::error;
use mobc::Pool;
use mobc_redis::redis;
use mobc_redis::redis::AsyncCommands;
use mobc_redis::RedisConnectionManager;
use serde_json::json;
use crate::database::{ErrorKind, MicropubChannel, Result, Storage, StorageError};
use crate::indieauth::User;
/// Lua scripts that this backend runs on the Redis server.
struct RedisScripts {
    /// Script invoked by `update_post` with the `posts` key, a post URL and
    /// a serialized update.
    edit_post: redis::Script,
}
impl From<mobc_redis::redis::RedisError> for StorageError {
fn from(err: mobc_redis::redis::RedisError) -> Self {
Self {
msg: format!("{}", err),
source: Some(Box::new(err)),
kind: ErrorKind::Backend,
}
}
}
impl From<mobc::Error<mobc_redis::redis::RedisError>> for StorageError {
fn from(err: mobc::Error<mobc_redis::redis::RedisError>) -> Self {
Self {
msg: format!("{}", err),
source: Some(Box::new(err)),
kind: ErrorKind::Backend,
}
}
}
lazy_static! {
    /// Lazily constructed set of Lua scripts; the script text is embedded
    /// into the binary at compile time.
    static ref SCRIPTS: RedisScripts = {
        let edit_post = redis::Script::new(include_str!("./edit_post.lua"));
        RedisScripts { edit_post }
    };
}
#[derive(Clone)]
pub struct RedisStorage {
// note to future Vika:
// mobc::Pool is actually a fancy name for an Arc
// around a shared connection pool with a manager
// which makes it safe to implement [`Clone`] and
// not worry about new pools being suddenly made
//
// stop worrying and start coding, you dum-dum
redis: mobc::Pool<RedisConnectionManager>,
}
/// Applies visibility rules to `post` on behalf of `user`.
///
/// Returns `None` when the post must not be shown at all, otherwise the
/// (possibly reduced) post:
///
/// * If `properties.deleted[0]` is a string, only a tombstone holding the
///   post's `type` and `deleted` property is returned.
/// * A post whose `properties.visibility[0]` is `"private"` is returned only
///   when `user` is listed in its `author` or `audience` property.
/// * A post whose visibility is `"protected"` is returned only when `user`
///   is `Some`.
/// * Any other (or missing) visibility is treated as public.
/// * If the post has a `location` array and `location-visibility[0]` is
///   `"private"` (the default when missing), `location` is removed unless
///   `user` is listed in `author`.
///
/// Entries of `author` / `audience` that are not strings never match a user;
/// they are ignored instead of causing a panic.
fn filter_post(mut post: serde_json::Value, user: &'_ Option<String>) -> Option<serde_json::Value> {
    /// Returns `true` when `values` is an array that contains `user` as a string.
    fn lists_user(values: &serde_json::Value, user: Option<&str>) -> bool {
        match (values.as_array(), user) {
            (Some(values), Some(user)) => values.iter().any(|value| value.as_str() == Some(user)),
            _ => false,
        }
    }

    if post["properties"]["deleted"][0].is_string() {
        return Some(json!({
            "type": post["type"],
            "properties": {
                "deleted": post["properties"]["deleted"]
            }
        }));
    }

    let viewer = user.as_deref();

    let properties = &post["properties"];
    let hidden = match properties["visibility"][0].as_str().unwrap_or("public") {
        "private" => {
            !(lists_user(&properties["author"], viewer)
                || lists_user(&properties["audience"], viewer))
        }
        "protected" => viewer.is_none(),
        _ => false,
    };
    if hidden {
        return None;
    }

    // The location is private unless the post explicitly says otherwise.
    let strip_location = post["properties"]["location"].is_array()
        && post["properties"]["location-visibility"][0]
            .as_str()
            .unwrap_or("private")
            == "private"
        && !lists_user(&post["properties"]["author"], viewer);
    if strip_location {
        if let Some(properties) = post
            .get_mut("properties")
            .and_then(serde_json::Value::as_object_mut)
        {
            properties.remove("location");
        }
    }

    Some(post)
}
#[async_trait]
impl Storage for RedisStorage {
/// Reads field `setting` from the Redis hash `settings_<user>`.
///
/// # Errors
/// Returns a [`StorageError`] if no connection can be obtained or the
/// Redis command fails.
async fn get_setting<'a>(&self, setting: &'a str, user: &'a str) -> Result<String> {
    let mut conn = self.redis.get().await?;
    let settings_key = format!("settings_{}", user);
    let value = conn
        .hget::<String, &str, String>(settings_key, setting)
        .await?;
    Ok(value)
}
/// Stores `value` under field `setting` of the Redis hash `settings_<user>`.
///
/// # Errors
/// Returns a [`StorageError`] if no connection can be obtained or the
/// Redis command fails.
async fn set_setting<'a>(&self, setting: &'a str, user: &'a str, value: &'a str) -> Result<()> {
    let mut conn = self.redis.get().await?;
    let settings_key = format!("settings_{}", user);
    conn.hset::<String, &str, &str, ()>(settings_key, setting, value)
        .await?;
    Ok(())
}
/// Removes the entry stored under `url` from the `posts` hash.
///
/// # Errors
/// Returns a [`StorageError`] if no connection can be obtained or the
/// Redis command fails.
async fn delete_post<'a>(&self, url: &'a str) -> Result<()> {
    let mut conn = self.redis.get().await?;
    conn.hdel::<&str, &str, ()>("posts", url).await?;
    Ok(())
}
/// Reports whether the `posts` hash has an entry under `url`.
///
/// # Errors
/// Returns a [`StorageError`] if no connection can be obtained or the
/// Redis command fails.
async fn post_exists(&self, url: &str) -> Result<bool> {
    let mut conn = self.redis.get().await?;
    let exists = conn.hexists::<&str, &str, bool>("posts", url).await?;
    Ok(exists)
}
/// Loads the post stored under `url` in the `posts` hash.
///
/// If the stored entry has a string `see_other` field, the entry stored
/// under that URL is returned instead (a single hop; the second entry is
/// returned as-is). Returns `Ok(None)` when the looked-up entry is missing.
///
/// # Errors
/// Returns a [`StorageError`] if no connection can be obtained, a Redis
/// command fails, or a stored entry is not valid JSON.
async fn get_post(&self, url: &str) -> Result<Option<serde_json::Value>> {
    let mut conn = self.redis.get().await?;

    let stored = conn
        .hget::<&str, &str, Option<String>>("posts", url)
        .await?;
    let raw = match stored {
        Some(raw) => raw,
        None => return Ok(None),
    };
    let parsed = serde_json::from_str::<serde_json::Value>(&raw)?;

    if let Some(new_url) = parsed["see_other"].as_str() {
        let redirected = conn
            .hget::<&str, &str, Option<String>>("posts", new_url)
            .await?;
        return Ok(match redirected {
            Some(raw) => Some(serde_json::from_str::<serde_json::Value>(&raw)?),
            None => None,
        });
    }

    Ok(Some(parsed))
}
async fn get_channels(&self, user: &User) -> Result<Vec<MicropubChannel>> {
let mut conn = self.redis.get().await?;
let channels = conn
.smembers::<String, Vec<String>>("channels_".to_string() + user.me.as_str())
.await?;
// TODO: use streams here instead of this weird thing... how did I even write this?!
Ok(futures_util::future::join_all(
channels
.iter()
.map(|channel| {
self.get_post(channel).map(|result| result.unwrap()).map(
|post: Option<serde_json::Value>| {
post.map(|post| MicropubChannel {
uid: post["properties"]["uid"][0].as_str().unwrap().to_string(),
name: post["properties"]["name"][0].as_str().unwrap().to_string(),
})
},
)
})
.collect::<Vec<_>>(),
)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>())
}
/// Stores `post` in the `posts` hash under its `properties.uid[0]`.
///
/// In addition:
///
/// * every string in `properties.url` that differs from the UID and starts
///   with `user` gets a `{"see_other": <uid>}` entry pointing at the post
///   (non-string entries are ignored);
/// * if `type` contains `"h-feed"`, the UID is added to the set
///   `channels_<properties.author[0]>`.
///
/// All validation happens before anything is written, so a rejected post
/// leaves the database untouched.
///
/// # Errors
/// Returns `ErrorKind::BadRequest` if the post has no string UID, or if it
/// is a feed without a string author. Returns a backend [`StorageError`] if
/// no connection can be obtained or a Redis command fails.
async fn put_post<'a>(&self, post: &'a serde_json::Value, user: &'a str) -> Result<()> {
    let key = post["properties"]["uid"][0]
        .as_str()
        .ok_or_else(|| StorageError::new(ErrorKind::BadRequest, "post doesn't have a UID"))?;

    // A missing or malformed `type` simply means "not a feed".
    let is_feed = post["type"]
        .as_array()
        .map_or(false, |types| types.iter().any(|i| i == "h-feed"));
    let channels_key = if is_feed {
        let author = post["properties"]["author"][0].as_str().ok_or_else(|| {
            StorageError::new(ErrorKind::BadRequest, "feed doesn't have an author")
        })?;
        Some("channels_".to_string() + author)
    } else {
        None
    };

    let mut conn = self.redis.get().await?;
    conn.hset::<&str, &str, String, ()>("posts", key, post.to_string())
        .await?;

    if let Some(urls) = post["properties"]["url"].as_array() {
        for url in urls.iter().filter_map(|i| i.as_str()) {
            if url != key && url.starts_with(user) {
                conn.hset::<&str, &str, String, ()>(
                    "posts",
                    url,
                    json!({ "see_other": key }).to_string(),
                )
                .await?;
            }
        }
    }

    if let Some(channels_key) = channels_key {
        // This is a feed. Add it to the channels set (a no-op if already present).
        conn.sadd::<String, &str, ()>(channels_key, key).await?;
    }
    Ok(())
}
async fn read_feed_with_limit<'a>(
&self,
url: &'a str,
after: &'a Option<String>,
limit: usize,
user: &'a Option<String>,
) -> Result<Option<serde_json::Value>> {
let mut conn = self.redis.get().await?;
let mut feed;
match conn
.hget::<&str, &str, Option<String>>(&"posts", url)
.await?
{
Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
None => return Ok(None),
}
if feed["see_other"].is_string() {
match conn
.hget::<&str, &str, Option<String>>(&"posts", feed["see_other"].as_str().unwrap())
.await?
{
Some(post) => feed = serde_json::from_str::<serde_json::Value>(&post)?,
None => return Ok(None),
}
}
if let Some(post) = filter_post(feed, user) {
feed = post
} else {
return Err(StorageError::new(
ErrorKind::PermissionDenied,
"specified user cannot access this post",
));
}
if feed["children"].is_array() {
let children = feed["children"].as_array().unwrap();
let posts_iter: Box<dyn std::iter::Iterator<Item = String> + Send>;
// TODO: refactor this to apply the skip on the &mut iterator
if let Some(after) = after {
posts_iter = Box::new(
children
.iter()
.map(|i| i.as_str().unwrap().to_string())
.skip_while(move |i| i != after)
.skip(1),
);
} else {
posts_iter = Box::new(children.iter().map(|i| i.as_str().unwrap().to_string()));
}
let posts = stream::iter(posts_iter)
.map(|url| async move {
match self.redis.get().await {
Ok(mut conn) => {
match conn.hget::<&str, &str, Option<String>>("posts", &url).await {
Ok(post) => match post {
Some(post) => {
match serde_json::from_str::<serde_json::Value>(&post) {
Ok(post) => Some(post),
Err(err) => {
let err = StorageError::from(err);
error!("{}", err);
panic!("{}", err)
}
}
}
// Happens because of a broken link (result of an improper deletion?)
None => None,
},
Err(err) => {
let err = StorageError::from(err);
error!("{}", err);
panic!("{}", err)
}
}
}
// TODO: Instead of causing a panic, investigate how can you fail the whole stream
// Somehow fuse it maybe?
Err(err) => {
let err = StorageError::from(err);
error!("{}", err);
panic!("{}", err)
}
}
})
// TODO: determine the optimal value for this buffer
// It will probably depend on how often can you encounter a private post on the page
// It shouldn't be too large, or we'll start fetching too many posts from the database
// It MUST NOT be larger than the typical page size
// It MUST NOT be a significant amount of the connection pool size
.buffered(std::cmp::min(3, limit))
// Hack to unwrap the Option and sieve out broken links
// Broken links return None, and Stream::filter_map skips all Nones.
.filter_map(|post: Option<serde_json::Value>| async move { post })
.filter_map(|post| async move { filter_post(post, user) })
.take(limit);
// TODO: Instead of catching panics, find a way to make the whole stream fail with Result<Vec<serde_json::Value>>
match std::panic::AssertUnwindSafe(posts.collect::<Vec<serde_json::Value>>())
.catch_unwind()
.await
{
Ok(posts) => feed["children"] = json!(posts),
Err(_) => {
return Err(StorageError::new(
ErrorKind::Other,
"Unknown error encountered while assembling feed, see logs for more info",
))
}
}
}
return Ok(Some(feed));
}
/// Applies `update` to the post stored under `url` by running the
/// `edit_post` Lua script against the `posts` hash.
///
/// If the entry under `url` has a string `see_other` field, the update is
/// applied to the entry it points to instead.
///
/// # Errors
/// * `ErrorKind::NotFound` if there is no entry under `url`.
/// * A [`StorageError`] if no connection can be obtained, a Redis command
///   or the script fails, or the stored entry is not valid JSON.
async fn update_post<'a>(&self, url: &'a str, update: serde_json::Value) -> Result<()> {
    let mut conn = self.redis.get().await?;

    // A single lookup both checks for existence and fetches the entry, so the
    // post cannot disappear between the check and the read.
    let raw = conn
        .hget::<&str, &str, Option<String>>("posts", url)
        .await?
        .ok_or_else(|| {
            StorageError::new(ErrorKind::NotFound, "can't edit a non-existent post")
        })?;
    let post: serde_json::Value = serde_json::from_str(&raw)?;
    let target = post["see_other"].as_str().unwrap_or(url);

    Ok(SCRIPTS
        .edit_post
        .key("posts")
        .arg(target)
        .arg(update.to_string())
        .invoke_async::<_, ()>(&mut conn as &mut redis::aio::Connection)
        .await?)
}
}
impl RedisStorage {
    /// Creates a new [`RedisStorage`] that stores its data in the Redis
    /// server at `redis_uri`, using a pool of at most 20 open connections.
    ///
    /// # Errors
    /// Returns a [`StorageError`] if `redis::Client::open` rejects `redis_uri`.
    pub async fn new(redis_uri: String) -> Result<Self> {
        let client = redis::Client::open(redis_uri)?;
        let manager = RedisConnectionManager::new(client);
        let redis = Pool::builder().max_open(20).build(manager);
        Ok(Self { redis })
    }
}
#[cfg(test)]
pub mod tests {
    use mobc_redis::redis;
    use std::process;
    use std::time::Duration;

    /// A throwaway `redis-server` process listening on a Unix socket inside
    /// a temporary directory. The server is killed when the value is dropped.
    pub struct RedisInstance {
        // We just need to hold on to it so it won't get dropped and remove the socket
        _tempdir: tempdir::TempDir,
        uri: String,
        child: std::process::Child,
    }

    impl Drop for RedisInstance {
        fn drop(&mut self) {
            // Never panic here: this may run while a failing test is already
            // unwinding, and a second panic would abort the whole test run.
            match self.child.kill() {
                Ok(()) => {
                    // Reap the process so it does not linger as a zombie and is
                    // really gone before the temporary directory is removed.
                    let _ = self.child.wait();
                }
                Err(err) => {
                    eprintln!("Failed to kill the child: {}", err);
                    // Non-blocking attempt to reap it in case it has already exited.
                    let _ = self.child.try_wait();
                }
            }
        }
    }

    impl RedisInstance {
        /// Connection URI of this instance.
        pub fn uri(&self) -> &str {
            &self.uri
        }
    }

    /// Spawns a `redis-server` on a fresh Unix socket and waits until it
    /// accepts connections.
    ///
    /// # Panics
    /// Panics if the temporary directory or the server cannot be created, if
    /// connecting fails with anything other than a connection refusal, or if
    /// the server still refuses connections after `MAX_RETRIES` attempts.
    pub async fn get_redis_instance() -> RedisInstance {
        let tempdir = tempdir::TempDir::new("redis").expect("failed to create tempdir");
        let socket = tempdir.path().join("redis.sock");
        let uri = format!("redis+unix:///{}", socket.to_str().unwrap());
        let redis_child = process::Command::new("redis-server")
            .current_dir(&tempdir)
            .arg("--port")
            .arg("0")
            .arg("--unixsocket")
            .arg(&socket)
            .stdout(process::Stdio::null())
            .stderr(process::Stdio::null())
            .spawn()
            .expect("Failed to spawn Redis");
        println!("{}", uri);

        // Build the guard right away so the server is killed even if one of
        // the panics below fires.
        let instance = RedisInstance {
            uri,
            child: redis_child,
            _tempdir: tempdir,
        };

        // There should be a slight delay, we need to wait for Redis to spin up
        let client = redis::Client::open(instance.uri.clone()).unwrap();
        let millisecond = Duration::from_millis(1);
        let mut retries: usize = 0;
        const MAX_RETRIES: usize = 60 * 1000/*ms*/;
        while let Err(err) = client.get_connection() {
            if err.is_connection_refusal() {
                async_std::task::sleep(millisecond).await;
                retries += 1;
                if retries > MAX_RETRIES {
                    panic!("Timeout waiting for Redis, last error: {}", err);
                }
            } else {
                panic!("Could not connect: {}", err);
            }
        }

        instance
    }
}