about summary refs log tree commit diff
diff options
context:
space:
mode:
authorVika <vika@fireburn.ru>2024-03-25 03:33:08 +0300
committerVika <vika@fireburn.ru>2024-06-14 22:21:26 +0300
commit919bc2e9973bf57b2e2fe09ed0456fb0d07bdae9 (patch)
treea651aeb49955e9b36d83a91798dd78acc994fb1e
parent1e815637e3e15c7eb81b45b51b40253f3ec57ebb (diff)
downloadkittybox-919bc2e9973bf57b2e2fe09ed0456fb0d07bdae9.tar.zst
Prototype sanitizer for SQL
This might allow me to use SQL syntax in Kittybox's private search
interfaces, allowing for queries of incredible specificity while not
allowing to query private data or inject arbitrary SQL.
-rw-r--r--Cargo.lock12
-rw-r--r--Cargo.toml4
-rw-r--r--examples/sql.rs129
3 files changed, 145 insertions, 0 deletions
diff --git a/Cargo.lock b/Cargo.lock
index ef2030b..b0cf21b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1646,6 +1646,7 @@ dependencies = [
  "serde_urlencoded",
  "serde_variant",
  "sha2",
+ "sqlparser",
  "sqlx",
  "tempfile",
  "thiserror",
@@ -3342,6 +3343,17 @@ dependencies = [
 ]
 
 [[package]]
+name = "sqlparser"
+version = "0.44.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aaf9c7ff146298ffda83a200f8d5084f08dcee1edfc135fcc1d646a45d50ffd6"
+dependencies = [
+ "log",
+ "serde",
+ "serde_json",
+]
+
+[[package]]
 name = "sqlx"
 version = "0.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 809e68b..f32ce6d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -155,4 +155,8 @@ version = "1.0.35"
 [dependencies.sqlx]
 version = "^0.7"
 features = ["uuid", "chrono", "json", "postgres", "runtime-tokio"]
+optional = true
+[dependencies.sqlparser]
+version = "0.44.0"
+features = ["serde", "serde_json"]
 optional = true
\ No newline at end of file
diff --git a/examples/sql.rs b/examples/sql.rs
new file mode 100644
index 0000000..c28f8dc
--- /dev/null
+++ b/examples/sql.rs
@@ -0,0 +1,129 @@
+use sqlparser::ast::Expr;
+
+#[derive(Debug, thiserror::Error)]
+enum Error<'sql> {
+    #[error("subquery detected: {0}")]
+    SubqueryDetected(&'sql sqlparser::ast::Query),
+    #[error("function call detected: {0}")]
+    FunctionCallDetected(&'sql sqlparser::ast::Function),
+    #[error("sql parser error: {0}")]
+    Sql(#[from] sqlparser::parser::ParserError),
+}
+
+fn sanitize(expr: &Expr) -> Result<(), Error> {
+    match expr {
+        Expr::Identifier(_) => Ok(()),
+        Expr::CompoundIdentifier(_) => Ok(()),
+        Expr::JsonAccess { left, operator: _, right } => sanitize(left)
+            .and(sanitize(right)),
+        Expr::CompositeAccess { expr, key: _ } => sanitize(expr),
+        Expr::IsFalse(subexpr) => sanitize(subexpr),
+        Expr::IsNotFalse(subexpr) => sanitize(subexpr),
+        Expr::IsTrue(subexpr) => sanitize(subexpr),
+        Expr::IsNotTrue(subexpr) => sanitize(subexpr),
+        Expr::IsNull(subexpr) => sanitize(subexpr),
+        Expr::IsNotNull(subexpr) => sanitize(subexpr),
+        Expr::IsUnknown(subexpr) => sanitize(subexpr),
+        Expr::IsNotUnknown(subexpr) => sanitize(subexpr),
+        Expr::IsDistinctFrom(left, right) => sanitize(left)
+            .and(sanitize(right)),
+        Expr::IsNotDistinctFrom(left, right) => sanitize(left)
+            .and(sanitize(right)),
+        Expr::InList { expr, list, negated: _ } => sanitize(expr)
+            .and(list.iter().try_for_each(sanitize)),
+        Expr::InSubquery { expr: _, subquery, negated: _ } => Err(Error::SubqueryDetected(subquery.as_ref())),
+        Expr::InUnnest { expr, array_expr, negated: _ } => sanitize(expr).and(sanitize(array_expr)),
+        Expr::Between { expr, negated: _, low, high } => sanitize(expr)
+            .and(sanitize(low))
+            .and(sanitize(high)),
+        Expr::BinaryOp { left, op: _, right } => sanitize(left)
+            .and(sanitize(right)),
+        Expr::Like { negated: _, expr, pattern, escape_char: _ } => sanitize(expr)
+            .and(sanitize(pattern)),
+        Expr::ILike { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)),
+        Expr::SimilarTo { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)),
+        Expr::RLike { negated: _, expr, pattern, regexp: _ } => sanitize(expr).and(sanitize(pattern)),
+        Expr::AnyOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)),
+        Expr::AllOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)),
+        Expr::UnaryOp { op: _, expr } => sanitize(expr),
+        Expr::Convert { expr, data_type: _, charset: _, target_before_value: _ } => sanitize(expr),
+        Expr::Cast { expr, data_type: _, format: _ } => sanitize(expr),
+        Expr::TryCast { expr, data_type: _, format: _ } => sanitize(expr),
+        Expr::SafeCast { expr, data_type: _, format: _ } => sanitize(expr),
+        Expr::AtTimeZone { timestamp, time_zone: _ } => sanitize(timestamp),
+        Expr::Extract { field: _, expr } => sanitize(expr),
+        Expr::Ceil { expr, field: _ } => sanitize(expr),
+        Expr::Floor { expr, field: _ } => sanitize(expr),
+        Expr::Position { expr, r#in } => sanitize(expr).and(sanitize(r#in)),
+        Expr::Substring { expr, substring_from, substring_for, special: _ } => sanitize(expr)
+            .and(substring_from.as_deref().map(sanitize).unwrap_or(Ok(())))
+            .and(substring_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
+        Expr::Trim { expr, trim_where: _, trim_what, trim_characters } => sanitize(expr)
+            .and(trim_what.as_deref().map(sanitize).unwrap_or(Ok(())))
+            .and(
+                trim_characters
+                    .as_ref()
+                    .map(|v| v.iter())
+                    .map(|mut iter| iter.try_for_each(sanitize))
+                    .unwrap_or(Ok(()))
+            ),
+        Expr::Overlay { expr, overlay_what, overlay_from, overlay_for } => sanitize(expr)
+            .and(sanitize(overlay_what))
+            .and(sanitize(overlay_from))
+            .and(overlay_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
+        Expr::Collate { expr, collation: _ } => sanitize(expr),
+        Expr::Nested(subexpr) => sanitize(subexpr),
+        Expr::Value(_) => Ok(()),
+        Expr::IntroducedString { introducer: _, value: _ } => Ok(()),
+        Expr::TypedString { data_type: _, value: _ } => Ok(()),
+        Expr::MapAccess { column, keys } => sanitize(column).and(keys.iter().try_for_each(sanitize)),
+        Expr::Function(func) => Err(Error::FunctionCallDetected(func)),
+        Expr::AggregateExpressionWithFilter { expr, filter } => sanitize(expr).and(sanitize(filter)),
+        Expr::Case { operand, conditions, results, else_result } => conditions.iter()
+            .chain(results)
+            .chain(operand.iter().map(std::borrow::Borrow::borrow))
+            .chain(else_result.iter().map(std::borrow::Borrow::borrow))
+            .try_for_each(sanitize),
+        Expr::Exists { subquery, negated: _ } => Err(Error::SubqueryDetected(subquery)),
+        Expr::Subquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
+        Expr::ArraySubquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
+        Expr::ListAgg(agg) => sanitize(&agg.expr),
+        Expr::ArrayAgg(agg) => sanitize(&agg.expr),
+        Expr::GroupingSets(sets) => sets.iter()
+            .map(|i| i.iter())
+            .try_for_each(|mut si| si.try_for_each(sanitize)),
+        Expr::Cube(cube) => cube.iter()
+            .map(|i| i.iter())
+            .try_for_each(|mut si| si.try_for_each(sanitize)),
+        Expr::Rollup(rollup) => rollup.iter()
+            .map(|i| i.iter())
+            .try_for_each(|mut si| si.try_for_each(sanitize)),
+        Expr::Tuple(tuple) => tuple.iter().try_for_each(sanitize),
+        Expr::Struct { values, fields: _ } => values.iter().try_for_each(sanitize),
+        Expr::Named { expr, name: _ } => sanitize(expr),
+        Expr::ArrayIndex { obj, indexes } => sanitize(obj)
+            .and(indexes.iter().try_for_each(sanitize)),
+        Expr::Array(array) => array.elem.iter().try_for_each(sanitize),
+        Expr::Interval(interval) => sanitize(&interval.value),
+        Expr::MatchAgainst { .. } => Ok(()),
+        Expr::Wildcard => Ok(()),
+        Expr::QualifiedWildcard(_) => Ok(()),
+        Expr::OuterJoin(expr) => sanitize(expr),
+    }
+}
+
+fn main() -> Result<(), sqlparser::parser::ParserError> {
+    let query = std::env::args().skip(1).take(1).next().unwrap();
+    static DIALECT: sqlparser::dialect::PostgreSqlDialect = sqlparser::dialect::PostgreSqlDialect {};
+
+    let parser: sqlparser::parser::Parser<'static> = sqlparser::parser::Parser::new(&DIALECT);
+    let expr = parser.try_with_sql(&query)?.parse_expr()?;
+    match sanitize(&expr) {
+        Ok(_) => eprintln!("{0:#?}\n\n{0}", expr),
+        Err(err) => {
+            eprintln!("{}", err);
+        }
+    }
+    
+    Ok(())
+}