From 50e7dad8774210262c78caac5a19d08af8960c5c Mon Sep 17 00:00:00 2001 From: Scott Schenkein Date: Fri, 29 May 2026 22:53:40 -0400 Subject: [PATCH 1/2] fix: rebalance deep AND/OR chains to avoid protobuf recursion limit A left-deep chain of N associative boolean operands serializes to a proto nested N levels deep. With N > protobuf's default recursion limit (100), the message overflows when the serialized plan is re-parsed -- on the JVM via Operator.parseFrom (findShuffleScanIndices / explain) and in the Rust prost decoder -- failing an otherwise-supported query. Comet evaluates AND/OR vectorially (both sides always evaluated, no row-level short-circuit), so the chains are fully associative. Flatten each chain and rebuild it as a balanced O(log n) tree before serialization; this is semantically identical and only changes the proto's shape. Adds QueryPlanSerde.flattenAssociative + createBalancedBinaryExpr and routes CometAnd / CometOr through them. Closes #4526 Co-Authored-By: Claude Opus 4.7 --- .../apache/comet/serde/QueryPlanSerde.scala | 61 +++++++++++++++++++ .../org/apache/comet/serde/predicates.scala | 21 +++++-- .../apache/comet/CometExpressionSuite.scala | 21 +++++++ 3 files changed, 97 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 0bdc02a790..82e4663e5a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -832,6 +832,67 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { } } + /** + * Serialize an associative boolean chain (`And` / `Or`) as a BALANCED `BinaryExpr` tree of + * depth `O(log n)` instead of the natural left-deep `O(n)`. A query with many ANDed/ORed + * predicates otherwise builds a proto nested deeper than protobuf's default recursion limit + * (100), which overflows when the serialized plan is re-parsed -- on the JVM + * (`OperatorOuterClass.Operator.parseFrom`, e.g. `findShuffleScanIndices` / explain) and in the + * Rust prost decoder. Comet evaluates `And`/`Or` vectorially (both sides always evaluated, no + * row-level short-circuit), so rebalancing the associative chain is semantically identical -- + * it only changes the proto's shape. + * + * `operands` are the flattened leaves of the chain (see [[flattenAssociative]]); `wrap` tags + * each combined `BinaryExpr` as `And` or `Or`. + */ + def createBalancedBinaryExpr( + expr: Expression, + operands: Seq[Expression], + inputs: Seq[Attribute], + binding: Boolean, + wrap: ( + ExprOuterClass.Expr.Builder, + ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val protos = operands.map(exprToProtoInternal(_, inputs, binding)) + if (protos.exists(_.isEmpty)) { + withFallbackReason(expr, operands: _*) + None + } else { + val leaves = protos.map(_.get).toIndexedSeq + def build(slice: IndexedSeq[ExprOuterClass.Expr]): ExprOuterClass.Expr = { + if (slice.length == 1) slice.head + else { + val mid = slice.length / 2 + val inner = ExprOuterClass.BinaryExpr + .newBuilder() + .setLeft(build(slice.slice(0, mid))) + .setRight(build(slice.slice(mid, slice.length))) + .build() + wrap(ExprOuterClass.Expr.newBuilder(), inner).build() + } + } + Some(build(leaves)) + } + } + + /** + * Flatten an associative binary chain into its leaf operands. `matches` identifies the same + * operator (e.g. `case _: And => true`) and `children` extracts its two operands. Used to + * rebalance deep `And`/`Or` chains before serialization (see [[createBalancedBinaryExpr]]). + */ + def flattenAssociative( + expr: Expression, + matches: Expression => Boolean, + children: Expression => (Expression, Expression)): Seq[Expression] = { + if (matches(expr)) { + val (l, r) = children(expr) + flattenAssociative(l, matches, children) ++ flattenAssociative(r, matches, children) + } else { + Seq(expr) + } + } + def scalarFunctionExprToProtoWithReturnType( funcName: String, returnType: DataType, diff --git a/spark/src/main/scala/org/apache/comet/serde/predicates.scala b/spark/src/main/scala/org/apache/comet/serde/predicates.scala index 7abe40823e..63b64fbcf2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/predicates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/predicates.scala @@ -69,10 +69,16 @@ object CometAnd extends CometExpressionSerde[And] { expr: And, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + // Rebalance the (associative) AND chain so deep `a AND b AND ...` predicates produce a + // shallow proto instead of a left-deep one that overflows protobuf's recursion limit when + // the plan is re-parsed (see createBalancedBinaryExpr). + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: And => true; case _ => false }, + { case a: And => (a.left, a.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setAnd(binaryExpr)) @@ -84,10 +90,13 @@ object CometOr extends CometExpressionSerde[Or] { expr: Or, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: Or => true; case _ => false }, + { case o: Or => (o.left, o.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setOr(binaryExpr)) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 819b1ba051..6cbe5f71f6 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3096,4 +3096,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("deep AND/OR predicate chains do not overflow the protobuf recursion limit") { + // A left-deep chain of N associative boolean operands serializes to a proto nested N + // levels deep. With N > protobuf's default recursion limit (100), the message overflows + // when the serialized plan is re-parsed (JVM Operator.parseFrom and the Rust prost + // decoder), failing an otherwise-supported query. Comet evaluates AND/OR vectorially with + // no short-circuit, so the chain is fully associative and safe to rebalance. + val n = 200 + withParquetTable((0 until 100).map(i => (i, i.toLong)), "tbl") { + // Project the chains as boolean columns rather than filtering: a top-level filter AND is + // split by Spark's splitConjunctivePredicates into many shallow pushed predicates, which + // would hide the deep-nesting. A projected expression survives intact. Distinct literals + // keep the optimizer from folding the chain; `>`/`<` (not `=`) keeps OptimizeIn from + // collapsing the OR chain into a single In. + val andChain = (1 to n).map(i => col("_1") > lit(-i)).reduce(_ && _) + checkSparkAnswerAndOperator(spark.table("tbl").select(andChain.as("a"))) + + val orChain = (1 to n).map(i => col("_1") < lit(i)).reduce(_ || _) + checkSparkAnswerAndOperator(spark.table("tbl").select(orChain.as("o"))) + } + } + } From 52aea508d9ad0ec219cef90e6435cae5fb15cf48 Mon Sep 17 00:00:00 2001 From: Scott Schenkein Date: Sat, 30 May 2026 17:52:36 -0400 Subject: [PATCH 2/2] refactor: flatten AND/OR chains iteratively; test null operands and OR WHERE Address review feedback on the deep-chain rebalancing PR: - flattenAssociative now uses an explicit work stack and an accumulating buffer instead of recursion. The chains that trigger this are left-deep and O(n) deep, so the prior recursive walk could itself overflow the JVM stack and the `++` accumulation was O(n^2). - The recursion-limit test now mixes a nullable column into the chains so the rebalanced tree is exercised under SQL three-valued logic, and adds a deep OR in a WHERE clause -- a common trigger that, unlike a top-level AND, Spark does not split and so stays deeply nested. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../apache/comet/serde/QueryPlanSerde.scala | 31 ++++++++++++++----- .../apache/comet/CometExpressionSuite.scala | 22 ++++++++++--- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 82e4663e5a..12e710d44e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -21,6 +21,7 @@ package org.apache.comet.serde import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import org.apache.spark.internal.Logging @@ -877,20 +878,34 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { } /** - * Flatten an associative binary chain into its leaf operands. `matches` identifies the same - * operator (e.g. `case _: And => true`) and `children` extracts its two operands. Used to - * rebalance deep `And`/`Or` chains before serialization (see [[createBalancedBinaryExpr]]). + * Flatten an associative binary chain into its leaf operands, in left-to-right order. `matches` + * identifies the same operator (e.g. `case _: And => true`) and `children` extracts its two + * operands. Used to rebalance deep `And`/`Or` chains before serialization (see + * [[createBalancedBinaryExpr]]). + * + * Implemented with an explicit work stack and an accumulating buffer rather than recursion: the + * chains that trigger this are left-deep and `O(n)` deep, so a recursive walk could itself + * overflow the JVM stack, and `++`-accumulating the results would be `O(n^2)`. */ def flattenAssociative( expr: Expression, matches: Expression => Boolean, children: Expression => (Expression, Expression)): Seq[Expression] = { - if (matches(expr)) { - val (l, r) = children(expr) - flattenAssociative(l, matches, children) ++ flattenAssociative(r, matches, children) - } else { - Seq(expr) + val operands = ArrayBuffer.empty[Expression] + var stack: List[Expression] = expr :: Nil + while (stack.nonEmpty) { + val current = stack.head + stack = stack.tail + if (matches(current)) { + val (l, r) = children(current) + // Push right before left so the left subtree is popped (and emitted) first, preserving + // the original left-to-right operand order. + stack = l :: r :: stack + } else { + operands += current + } } + operands.toSeq } def scalarFunctionExprToProtoWithReturnType( diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 6cbe5f71f6..8dc84dc9e7 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -24,7 +24,7 @@ import java.time.{Duration, Period} import scala.util.Random import org.apache.hadoop.fs.Path -import org.apache.spark.sql.{CometTestBase, DataFrame, Row} +import org.apache.spark.sql.{Column, CometTestBase, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps import org.apache.spark.sql.comet.CometProjectExec @@ -3103,17 +3103,31 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // decoder), failing an otherwise-supported query. Comet evaluates AND/OR vectorially with // no short-circuit, so the chain is fully associative and safe to rebalance. val n = 200 - withParquetTable((0 until 100).map(i => (i, i.toLong)), "tbl") { + // `_2` is nullable (every 7th row is null) so the rebalanced chain is exercised under SQL + // three-valued logic, not just true/false operands. + withParquetTable( + (0 until 100).map(i => (i, if (i % 7 == 0) None else Some(i.toLong))), + "tbl") { + // Build a chain that mixes the non-nullable `_1` with the nullable `_2` so null operands + // flow through the rebalanced tree. + def operand(i: Int): Column = + if (i % 2 == 0) col("_2") > lit(-i) else col("_1") > lit(-i) + // Project the chains as boolean columns rather than filtering: a top-level filter AND is // split by Spark's splitConjunctivePredicates into many shallow pushed predicates, which // would hide the deep-nesting. A projected expression survives intact. Distinct literals // keep the optimizer from folding the chain; `>`/`<` (not `=`) keeps OptimizeIn from // collapsing the OR chain into a single In. - val andChain = (1 to n).map(i => col("_1") > lit(-i)).reduce(_ && _) + val andChain = (1 to n).map(operand).reduce(_ && _) checkSparkAnswerAndOperator(spark.table("tbl").select(andChain.as("a"))) - val orChain = (1 to n).map(i => col("_1") < lit(i)).reduce(_ || _) + val orChain = (1 to n).map(i => col("_1") < lit(i) || col("_2") < lit(i)).reduce(_ || _) checkSparkAnswerAndOperator(spark.table("tbl").select(orChain.as("o"))) + + // A deep OR is a common real-world WHERE clause and, unlike a top-level AND, is NOT split + // by Spark -- it stays intact as a single deeply-nested predicate, so exercise that path + // directly. + checkSparkAnswerAndOperator(spark.table("tbl").where(orChain)) } }