diff options
-rw-r--r-- | Cargo.lock | 12 | ||||
-rw-r--r-- | Cargo.toml | 4 | ||||
-rw-r--r-- | examples/sql.rs | 129 |
3 files changed, 145 insertions, 0 deletions
diff --git a/Cargo.lock b/Cargo.lock index ef2030b..b0cf21b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1646,6 +1646,7 @@ dependencies = [ "serde_urlencoded", "serde_variant", "sha2", + "sqlparser", "sqlx", "tempfile", "thiserror", @@ -3342,6 +3343,17 @@ dependencies = [ ] [[package]] +name = "sqlparser" +version = "0.44.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaf9c7ff146298ffda83a200f8d5084f08dcee1edfc135fcc1d646a45d50ffd6" +dependencies = [ + "log", + "serde", + "serde_json", +] + +[[package]] name = "sqlx" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/Cargo.toml b/Cargo.toml index 809e68b..f32ce6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -155,4 +155,8 @@ version = "1.0.35" [dependencies.sqlx] version = "^0.7" features = ["uuid", "chrono", "json", "postgres", "runtime-tokio"] +optional = true +[dependencies.sqlparser] +version = "0.44.0" +features = ["serde", "serde_json"] optional = true \ No newline at end of file diff --git a/examples/sql.rs b/examples/sql.rs new file mode 100644 index 0000000..c28f8dc --- /dev/null +++ b/examples/sql.rs @@ -0,0 +1,129 @@ +use sqlparser::ast::Expr; + +#[derive(Debug, thiserror::Error)] +enum Error<'sql> { + #[error("subquery detected: {0}")] + SubqueryDetected(&'sql sqlparser::ast::Query), + #[error("function call detected: {0}")] + FunctionCallDetected(&'sql sqlparser::ast::Function), + #[error("sql parser error: {0}")] + Sql(#[from] sqlparser::parser::ParserError), +} + +fn sanitize(expr: &Expr) -> Result<(), Error> { + match expr { + Expr::Identifier(_) => Ok(()), + Expr::CompoundIdentifier(_) => Ok(()), + Expr::JsonAccess { left, operator: _, right } => sanitize(left) + .and(sanitize(right)), + Expr::CompositeAccess { expr, key: _ } => sanitize(expr), + Expr::IsFalse(subexpr) => sanitize(subexpr), + Expr::IsNotFalse(subexpr) => sanitize(subexpr), + Expr::IsTrue(subexpr) => sanitize(subexpr), + Expr::IsNotTrue(subexpr) => sanitize(subexpr), + Expr::IsNull(subexpr) => sanitize(subexpr), + Expr::IsNotNull(subexpr) => sanitize(subexpr), + Expr::IsUnknown(subexpr) => sanitize(subexpr), + Expr::IsNotUnknown(subexpr) => sanitize(subexpr), + Expr::IsDistinctFrom(left, right) => sanitize(left) + .and(sanitize(right)), + Expr::IsNotDistinctFrom(left, right) => sanitize(left) + .and(sanitize(right)), + Expr::InList { expr, list, negated: _ } => sanitize(expr) + .and(list.iter().try_for_each(sanitize)), + Expr::InSubquery { expr: _, subquery, negated: _ } => Err(Error::SubqueryDetected(subquery.as_ref())), + Expr::InUnnest { expr, array_expr, negated: _ } => sanitize(expr).and(sanitize(array_expr)), + Expr::Between { expr, negated: _, low, high } => sanitize(expr) + .and(sanitize(low)) + .and(sanitize(high)), + Expr::BinaryOp { left, op: _, right } => sanitize(left) + .and(sanitize(right)), + Expr::Like { negated: _, expr, pattern, escape_char: _ } => sanitize(expr) + .and(sanitize(pattern)), + Expr::ILike { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)), + Expr::SimilarTo { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)), + Expr::RLike { negated: _, expr, pattern, regexp: _ } => sanitize(expr).and(sanitize(pattern)), + Expr::AnyOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)), + Expr::AllOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)), + Expr::UnaryOp { op: _, expr } => sanitize(expr), + Expr::Convert { expr, data_type: _, charset: _, target_before_value: _ } => sanitize(expr), + Expr::Cast { expr, data_type: _, format: _ } => sanitize(expr), + Expr::TryCast { expr, data_type: _, format: _ } => sanitize(expr), + Expr::SafeCast { expr, data_type: _, format: _ } => sanitize(expr), + Expr::AtTimeZone { timestamp, time_zone: _ } => sanitize(timestamp), + Expr::Extract { field: _, expr } => sanitize(expr), + Expr::Ceil { expr, field: _ } => sanitize(expr), + Expr::Floor { expr, field: _ } => sanitize(expr), + Expr::Position { expr, r#in } => sanitize(expr).and(sanitize(r#in)), + Expr::Substring { expr, substring_from, substring_for, special: _ } => sanitize(expr) + .and(substring_from.as_deref().map(sanitize).unwrap_or(Ok(()))) + .and(substring_for.as_deref().map(sanitize).unwrap_or(Ok(()))), + Expr::Trim { expr, trim_where: _, trim_what, trim_characters } => sanitize(expr) + .and(trim_what.as_deref().map(sanitize).unwrap_or(Ok(()))) + .and( + trim_characters + .as_ref() + .map(|v| v.iter()) + .map(|mut iter| iter.try_for_each(sanitize)) + .unwrap_or(Ok(())) + ), + Expr::Overlay { expr, overlay_what, overlay_from, overlay_for } => sanitize(expr) + .and(sanitize(overlay_what)) + .and(sanitize(overlay_from)) + .and(overlay_for.as_deref().map(sanitize).unwrap_or(Ok(()))), + Expr::Collate { expr, collation: _ } => sanitize(expr), + Expr::Nested(subexpr) => sanitize(subexpr), + Expr::Value(_) => Ok(()), + Expr::IntroducedString { introducer: _, value: _ } => Ok(()), + Expr::TypedString { data_type: _, value: _ } => Ok(()), + Expr::MapAccess { column, keys } => sanitize(column).and(keys.iter().try_for_each(sanitize)), + Expr::Function(func) => Err(Error::FunctionCallDetected(func)), + Expr::AggregateExpressionWithFilter { expr, filter } => sanitize(expr).and(sanitize(filter)), + Expr::Case { operand, conditions, results, else_result } => conditions.iter() + .chain(results) + .chain(operand.iter().map(std::borrow::Borrow::borrow)) + .chain(else_result.iter().map(std::borrow::Borrow::borrow)) + .try_for_each(sanitize), + Expr::Exists { subquery, negated: _ } => Err(Error::SubqueryDetected(subquery)), + Expr::Subquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())), + Expr::ArraySubquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())), + Expr::ListAgg(agg) => sanitize(&agg.expr), + Expr::ArrayAgg(agg) => sanitize(&agg.expr), + Expr::GroupingSets(sets) => sets.iter() + .map(|i| i.iter()) + .try_for_each(|mut si| si.try_for_each(sanitize)), + Expr::Cube(cube) => cube.iter() + .map(|i| i.iter()) + .try_for_each(|mut si| si.try_for_each(sanitize)), + Expr::Rollup(rollup) => rollup.iter() + .map(|i| i.iter()) + .try_for_each(|mut si| si.try_for_each(sanitize)), + Expr::Tuple(tuple) => tuple.iter().try_for_each(sanitize), + Expr::Struct { values, fields: _ } => values.iter().try_for_each(sanitize), + Expr::Named { expr, name: _ } => sanitize(expr), + Expr::ArrayIndex { obj, indexes } => sanitize(obj) + .and(indexes.iter().try_for_each(sanitize)), + Expr::Array(array) => array.elem.iter().try_for_each(sanitize), + Expr::Interval(interval) => sanitize(&interval.value), + Expr::MatchAgainst { .. } => Ok(()), + Expr::Wildcard => Ok(()), + Expr::QualifiedWildcard(_) => Ok(()), + Expr::OuterJoin(expr) => sanitize(expr), + } +} + +fn main() -> Result<(), sqlparser::parser::ParserError> { + let query = std::env::args().skip(1).take(1).next().unwrap(); + static DIALECT: sqlparser::dialect::PostgreSqlDialect = sqlparser::dialect::PostgreSqlDialect {}; + + let parser: sqlparser::parser::Parser<'static> = sqlparser::parser::Parser::new(&DIALECT); + let expr = parser.try_with_sql(&query)?.parse_expr()?; + match sanitize(&expr) { + Ok(_) => eprintln!("{0:#?}\n\n{0}", expr), + Err(err) => { + eprintln!("{}", err); + } + } + + Ok(()) +} |