Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 75 additions & 2 deletions core/src/sql/sql_provider_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use bigdecimal::{num_bigint::BigInt, BigDecimal};
use datafusion::{
logical_expr::{Cast, Expr},
logical_expr::{Cast, Expr, Operator},
scalar::ScalarValue,
sql::unparser::dialect::{
DefaultDialect, Dialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
Expand Down Expand Up @@ -64,7 +64,12 @@ pub fn to_sql_with_engine(expr: &Expr, engine: Option<Engine>) -> Result<String>
}
}

Ok(format!("{} {} {}", left, binary_expr.op, right))
match binary_expr.op {
Operator::And | Operator::Or => {
Ok(format!("({}) {} ({})", left, binary_expr.op, right))
}
_ => Ok(format!("{} {} {}", left, binary_expr.op, right)),
}
}
Expr::Column(name) => match engine {
Some(Engine::Spark | Engine::ODBC) => Ok(format!("{name}")),
Expand Down Expand Up @@ -536,6 +541,74 @@ mod tests {
Ok(())
}

fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
Expr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
left: Box::new(left),
op,
right: Box::new(right),
})
}

fn int(v: i32) -> Expr {
Expr::Literal(ScalarValue::Int32(Some(v)), None)
}

#[test]
fn test_and_or_binary_exprs_are_parenthesized() -> Result<()> {
// AND operands are wrapped so the tree shape is preserved in SQL
let and_expr = binary(col("a"), Operator::And, col("b"));
assert_eq!(to_sql(&and_expr)?, "(\"a\") AND (\"b\")");

// OR operands are wrapped for the same reason
let or_expr = binary(col("a"), Operator::Or, col("b"));
assert_eq!(to_sql(&or_expr)?, "(\"a\") OR (\"b\")");

Ok(())
}

#[test]
fn test_non_boolean_binary_exprs_not_parenthesized() -> Result<()> {
// Comparison and arithmetic operators must NOT add extra parens
let eq_expr = binary(col("k1"), Operator::Eq, int(1));
assert_eq!(to_sql(&eq_expr)?, "\"k1\" = 1");

let lt_expr = binary(col("v"), Operator::Lt, int(42));
assert_eq!(to_sql(&lt_expr)?, "\"v\" < 42");

let plus_expr = binary(col("x"), Operator::Plus, int(5));
assert_eq!(to_sql(&plus_expr)?, "\"x\" + 5");

Ok(())
}

#[test]
fn test_composite_key_or_chain_parenthesized() -> Result<()> {
// Simulates a two-row composite-key DELETE:
// (k1 = 1 AND k2 = 2) OR (k1 = 2 AND k2 = 4)
//
// Without parens this serialises as a flat chain that DuckDB's optimizer
// reconstructs as a left-recursive tree of depth N, causing a stack overflow
// for large N. With parens the structure is explicit.
let row1 = binary(
binary(col("k1"), Operator::Eq, int(1)),
Operator::And,
binary(col("k2"), Operator::Eq, int(2)),
);
let row2 = binary(
binary(col("k1"), Operator::Eq, int(2)),
Operator::And,
binary(col("k2"), Operator::Eq, int(4)),
);
let or_chain = binary(row1, Operator::Or, row2);

assert_eq!(
to_sql(&or_chain)?,
"((\"k1\" = 1) AND (\"k2\" = 2)) OR ((\"k1\" = 2) AND (\"k2\" = 4))"
);

Ok(())
}

#[test]
fn test_expr_timestamp_scalar_value_to_sql() -> Result<()> {
let expr = Expr::Literal(
Expand Down
105 changes: 105 additions & 0 deletions core/tests/duckdb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,108 @@ mod multipart_table_reference {
assert_eq!(vals.values().to_vec(), vec![10, 20]);
}
}

/// Tests that `to_sql_with_engine` emits a parenthesized WHERE clause for composite-key
/// deletes, preventing a stack overflow in DuckDB's optimizer on the 512 KiB worker thread.
///
/// Without the fix, `to_sql_with_engine` emits a flat unparenthesized OR chain:
/// `"k1" = 1 AND "k2" = 2 OR "k1" = 2 AND "k2" = 4 OR … OR "k1" = N AND "k2" = 2N`
///
/// DuckDB's parser reconstructs this as a left-recursive tree of depth N.
/// ExpressionRewriter / DistributivityRule recurse N levels deep → stack overflow.
///
/// With the fix, AND/OR operands are wrapped in parentheses, so the depth stays O(log N).
mod composite_delete_stack_overflow {
use datafusion::logical_expr::{BinaryExpr, Expr, Operator};
use datafusion::prelude::col;
use datafusion::scalar::ScalarValue;
use datafusion_table_providers::sql::sql_provider_datafusion::expr::{
to_sql_with_engine, Engine,
};
use duckdb::Connection;

const N: usize = 100;

fn setup(conn: &Connection) {
conn.execute_batch("CREATE TABLE test (k1 INTEGER, k2 INTEGER, val INTEGER);")
.unwrap();
conn.execute_batch(&format!(
"INSERT INTO test SELECT i, i*2, i FROM generate_series(1, {N}) t(i);"
))
.unwrap();
}

fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left),
op,
right: Box::new(right),
})
}

fn int32(v: i32) -> Expr {
Expr::Literal(ScalarValue::Int32(Some(v)), None)
}

fn reduce_or(conds: &[Expr]) -> Expr {
if conds.len() == 1 {
return conds[0].clone();
}
let mid = conds.len() / 2;
binary(
reduce_or(&conds[..mid]),
Operator::Or,
reduce_or(&conds[mid..]),
)
}

/// Verifies that `to_sql_with_engine` emits a parenthesized WHERE clause that DuckDB
/// can execute on a 512 KiB worker thread stack.
///
/// Failure modes:
/// - Without the parenthesization fix: the assertion on SQL structure fails
/// immediately (clean test failure, no process crash).
/// - With correct SQL but a DuckDB regression: `execute_batch` panics inside the
/// spawned thread, which surfaces as a `join` error.
#[test]
fn composite_delete_via_to_sql_with_engine() {
let conds: Vec<Expr> = (1..=N)
.map(|i| {
binary(
binary(col("k1"), Operator::Eq, int32(i as i32)),
Operator::And,
binary(col("k2"), Operator::Eq, int32((i * 2) as i32)),
)
})
.collect();

let where_clause = to_sql_with_engine(&reduce_or(&conds), Some(Engine::DuckDB))
.expect("to_sql_with_engine failed");

// Without the fix: `"k1" = 1 AND "k2" = 2 OR …` — no parens, fails here.
// With the fix: `(("k1" = 1) AND ("k2" = 2)) OR …`
assert!(
where_clause.starts_with('('),
"WHERE clause must be parenthesized to avoid DuckDB stack overflow:\n{where_clause}"
);

let sql = format!("DELETE FROM test WHERE {where_clause}");

let handle = std::thread::Builder::new()
.stack_size(512 * 1024) // 512 KiB — matches production refresh-worker stack
.name("refresh-worker".to_string())
.spawn(move || {
let conn = Connection::open_in_memory().unwrap();
setup(&conn);
conn.execute_batch(&sql).unwrap();

let remaining: i64 = conn
.query_row("SELECT COUNT(*) FROM test", [], |r| r.get(0))
.unwrap();
assert_eq!(remaining, 0, "all {N} rows should be deleted");
})
.unwrap();

handle.join().unwrap();
}
}
Loading