#![cfg(feature = "sqlparser")]
use sqlparser::ast::Expr;

/// Reasons an expression can be rejected, or fail to parse.
///
/// The rejection variants borrow the offending AST node from the expression
/// that was inspected, which is why the type carries a lifetime.
#[derive(Debug, thiserror::Error)]
enum Error<'a> {
    /// A subquery node (`IN (subquery)`, `EXISTS`, scalar or `ARRAY` subquery).
    #[error("subquery detected: {0}")]
    SubqueryDetected(&'a sqlparser::ast::Query),

    /// A generic function-call node.
    #[error("function call detected: {0}")]
    FunctionCallDetected(&'a sqlparser::ast::Function),

    /// Wrapper around the `sqlparser` parse error (convertible via `From`).
    #[error("sql parser error: {0}")]
    Sql(#[from] sqlparser::parser::ParserError),
}

fn sanitize(expr: &Expr) -> Result<(), Error> {
    match expr {
        Expr::Identifier(_) => Ok(()),
        Expr::CompoundIdentifier(_) => Ok(()),
        Expr::JsonAccess {
            left,
            operator: _,
            right,
        } => sanitize(left).and(sanitize(right)),
        Expr::CompositeAccess { expr, key: _ } => sanitize(expr),
        Expr::IsFalse(subexpr) => sanitize(subexpr),
        Expr::IsNotFalse(subexpr) => sanitize(subexpr),
        Expr::IsTrue(subexpr) => sanitize(subexpr),
        Expr::IsNotTrue(subexpr) => sanitize(subexpr),
        Expr::IsNull(subexpr) => sanitize(subexpr),
        Expr::IsNotNull(subexpr) => sanitize(subexpr),
        Expr::IsUnknown(subexpr) => sanitize(subexpr),
        Expr::IsNotUnknown(subexpr) => sanitize(subexpr),
        Expr::IsDistinctFrom(left, right) => sanitize(left).and(sanitize(right)),
        Expr::IsNotDistinctFrom(left, right) => sanitize(left).and(sanitize(right)),
        Expr::InList {
            expr,
            list,
            negated: _,
        } => sanitize(expr).and(list.iter().try_for_each(sanitize)),
        Expr::InSubquery {
            expr: _,
            subquery,
            negated: _,
        } => Err(Error::SubqueryDetected(subquery.as_ref())),
        Expr::InUnnest {
            expr,
            array_expr,
            negated: _,
        } => sanitize(expr).and(sanitize(array_expr)),
        Expr::Between {
            expr,
            negated: _,
            low,
            high,
        } => sanitize(expr).and(sanitize(low)).and(sanitize(high)),
        Expr::BinaryOp { left, op: _, right } => sanitize(left).and(sanitize(right)),
        Expr::Like {
            negated: _,
            expr,
            pattern,
            escape_char: _,
        } => sanitize(expr).and(sanitize(pattern)),
        Expr::ILike {
            negated: _,
            expr,
            pattern,
            escape_char: _,
        } => sanitize(expr).and(sanitize(pattern)),
        Expr::SimilarTo {
            negated: _,
            expr,
            pattern,
            escape_char: _,
        } => sanitize(expr).and(sanitize(pattern)),
        Expr::RLike {
            negated: _,
            expr,
            pattern,
            regexp: _,
        } => sanitize(expr).and(sanitize(pattern)),
        Expr::AnyOp {
            left,
            compare_op: _,
            right,
        } => sanitize(left).and(sanitize(right)),
        Expr::AllOp {
            left,
            compare_op: _,
            right,
        } => sanitize(left).and(sanitize(right)),
        Expr::UnaryOp { op: _, expr } => sanitize(expr),
        Expr::Convert {
            expr,
            data_type: _,
            charset: _,
            target_before_value: _,
        } => sanitize(expr),
        Expr::Cast {
            expr,
            data_type: _,
            format: _,
        } => sanitize(expr),
        Expr::TryCast {
            expr,
            data_type: _,
            format: _,
        } => sanitize(expr),
        Expr::SafeCast {
            expr,
            data_type: _,
            format: _,
        } => sanitize(expr),
        Expr::AtTimeZone {
            timestamp,
            time_zone: _,
        } => sanitize(timestamp),
        Expr::Extract { field: _, expr } => sanitize(expr),
        Expr::Ceil { expr, field: _ } => sanitize(expr),
        Expr::Floor { expr, field: _ } => sanitize(expr),
        Expr::Position { expr, r#in } => sanitize(expr).and(sanitize(r#in)),
        Expr::Substring {
            expr,
            substring_from,
            substring_for,
            special: _,
        } => sanitize(expr)
            .and(substring_from.as_deref().map(sanitize).unwrap_or(Ok(())))
            .and(substring_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
        Expr::Trim {
            expr,
            trim_where: _,
            trim_what,
            trim_characters,
        } => sanitize(expr)
            .and(trim_what.as_deref().map(sanitize).unwrap_or(Ok(())))
            .and(
                trim_characters
                    .as_ref()
                    .map(|v| v.iter())
                    .map(|mut iter| iter.try_for_each(sanitize))
                    .unwrap_or(Ok(())),
            ),
        Expr::Overlay {
            expr,
            overlay_what,
            overlay_from,
            overlay_for,
        } => sanitize(expr)
            .and(sanitize(overlay_what))
            .and(sanitize(overlay_from))
            .and(overlay_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
        Expr::Collate { expr, collation: _ } => sanitize(expr),
        Expr::Nested(subexpr) => sanitize(subexpr),
        Expr::Value(_) => Ok(()),
        Expr::IntroducedString {
            introducer: _,
            value: _,
        } => Ok(()),
        Expr::TypedString {
            data_type: _,
            value: _,
        } => Ok(()),
        Expr::MapAccess { column, keys } => {
            sanitize(column).and(keys.iter().try_for_each(sanitize))
        }
        Expr::Function(func) => Err(Error::FunctionCallDetected(func)),
        Expr::AggregateExpressionWithFilter { expr, filter } => {
            sanitize(expr).and(sanitize(filter))
        }
        Expr::Case {
            operand,
            conditions,
            results,
            else_result,
        } => conditions
            .iter()
            .chain(results)
            .chain(operand.iter().map(std::borrow::Borrow::borrow))
            .chain(else_result.iter().map(std::borrow::Borrow::borrow))
            .try_for_each(sanitize),
        Expr::Exists {
            subquery,
            negated: _,
        } => Err(Error::SubqueryDetected(subquery)),
        Expr::Subquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
        Expr::ArraySubquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
        Expr::ListAgg(agg) => sanitize(&agg.expr),
        Expr::ArrayAgg(agg) => sanitize(&agg.expr),
        Expr::GroupingSets(sets) => sets
            .iter()
            .map(|i| i.iter())
            .try_for_each(|mut si| si.try_for_each(sanitize)),
        Expr::Cube(cube) => cube
            .iter()
            .map(|i| i.iter())
            .try_for_each(|mut si| si.try_for_each(sanitize)),
        Expr::Rollup(rollup) => rollup
            .iter()
            .map(|i| i.iter())
            .try_for_each(|mut si| si.try_for_each(sanitize)),
        Expr::Tuple(tuple) => tuple.iter().try_for_each(sanitize),
        Expr::Struct { values, fields: _ } => values.iter().try_for_each(sanitize),
        Expr::Named { expr, name: _ } => sanitize(expr),
        Expr::ArrayIndex { obj, indexes } => {
            sanitize(obj).and(indexes.iter().try_for_each(sanitize))
        }
        Expr::Array(array) => array.elem.iter().try_for_each(sanitize),
        Expr::Interval(interval) => sanitize(&interval.value),
        Expr::MatchAgainst { .. } => Ok(()),
        Expr::Wildcard => Ok(()),
        Expr::QualifiedWildcard(_) => Ok(()),
        Expr::OuterJoin(expr) => sanitize(expr),
    }
}

/// Command-line entry point.
///
/// Parses the first command-line argument as a single SQL expression in the
/// PostgreSQL dialect and runs [`sanitize`] on it. On success the expression's
/// pretty-printed AST and its re-rendered SQL are written to stderr.
///
/// Exit status:
/// - `0`: the expression was accepted;
/// - `1`: `sanitize` rejected the expression (the reason is printed to stderr);
/// - `2`: no expression argument was supplied (a usage line is printed).
///
/// # Errors
/// Returns the parser error if the input cannot be tokenized or parsed as an
/// expression, or if anything other than whitespace follows the expression.
fn main() -> Result<(), sqlparser::parser::ParserError> {
    let mut args = std::env::args();
    let program = args.next().unwrap_or_else(|| String::from("sanitize"));
    // Missing input is a user error, not a bug: report it instead of panicking.
    let query = match args.next() {
        Some(query) => query,
        None => {
            eprintln!("usage: {} <sql-expression>", program);
            std::process::exit(2);
        }
    };

    static DIALECT: sqlparser::dialect::PostgreSqlDialect =
        sqlparser::dialect::PostgreSqlDialect {};

    let mut parser: sqlparser::parser::Parser<'static> =
        sqlparser::parser::Parser::new(&DIALECT).try_with_sql(&query)?;
    let expr = parser.parse_expr()?;
    // `parse_expr` stops at the first token that cannot continue the
    // expression and leaves the rest unread. Require end of input so that
    // e.g. `a) OR (SELECT 1` is not reported as the harmless expression `a`.
    parser.expect_token(&sqlparser::tokenizer::Token::EOF)?;

    match sanitize(&expr) {
        Ok(_) => eprintln!("{0:#?}\n\n{0}", expr),
        Err(err) => {
            eprintln!("{}", err);
            // Signal rejection to callers/scripts via a non-zero exit status.
            std::process::exit(1);
        }
    }

    Ok(())
}