Skip to content

Commit 0ca37e1

Browse files
authored
feat: add support for posexplode and posexplode_outer (#4270)
1 parent 64b5ac3 commit 0ca37e1

9 files changed

Lines changed: 449 additions & 38 deletions

File tree

docs/source/user-guide/latest/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ not supported by Comet will fall back to regular Spark execution.
3030
| ExpandExec | Yes | |
3131
| FileSourceScanExec | Yes | Supports Parquet files. See the [Comet Compatibility Guide] for more information. |
3232
| FilterExec | Yes | |
33-
| GenerateExec | Yes | Supports `explode` generator only. |
33+
| GenerateExec | Yes | Supports `explode` and `posexplode` generators (arrays only, `_outer` variants are incompatible). |
3434
| GlobalLimitExec | Yes | |
3535
| HashAggregateExec | Yes | |
3636
| InsertIntoHadoopFsRelationCommand | No | Experimental support for native Parquet writes. Disabled by default. |

docs/source/user-guide/latest/understanding-comet-plans.md

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,21 +158,21 @@ by role. Names match what is shown in the plan output.
158158
These run natively in DataFusion. When several appear consecutively in a plan,
159159
they execute as a single fused native block.
160160

161-
| Node | Spark equivalent |
162-
| ---------------------------- | ---------------------------------------------- |
163-
| `CometProject` | `ProjectExec` |
164-
| `CometFilter` | `FilterExec` |
165-
| `CometSort` | `SortExec` |
166-
| `CometLocalLimit` | `LocalLimitExec` |
167-
| `CometGlobalLimit` | `GlobalLimitExec` |
168-
| `CometExpand` | `ExpandExec` |
169-
| `CometExplode` | `GenerateExec` (for `explode` only) |
170-
| `CometHashAggregate` | `HashAggregateExec`, `ObjectHashAggregateExec` |
171-
| `CometHashJoin` | `ShuffledHashJoinExec` |
172-
| `CometBroadcastHashJoin` | `BroadcastHashJoinExec` |
173-
| `CometSortMergeJoin` | `SortMergeJoinExec` |
174-
| `CometWindow` | `WindowExec` |
175-
| `CometTakeOrderedAndProject` | `TakeOrderedAndProjectExec` |
161+
| Node | Spark equivalent |
162+
| ---------------------------- | ----------------------------------------------- |
163+
| `CometProject` | `ProjectExec` |
164+
| `CometFilter` | `FilterExec` |
165+
| `CometSort` | `SortExec` |
166+
| `CometLocalLimit` | `LocalLimitExec` |
167+
| `CometGlobalLimit` | `GlobalLimitExec` |
168+
| `CometExpand` | `ExpandExec` |
169+
| `CometExplode` | `GenerateExec` (for `explode` and `posexplode`) |
170+
| `CometHashAggregate` | `HashAggregateExec`, `ObjectHashAggregateExec` |
171+
| `CometHashJoin` | `ShuffledHashJoinExec` |
172+
| `CometBroadcastHashJoin` | `BroadcastHashJoinExec` |
173+
| `CometSortMergeJoin` | `SortMergeJoinExec` |
174+
| `CometWindow` | `WindowExec` |
175+
| `CometTakeOrderedAndProject` | `TakeOrderedAndProjectExec` |
176176

177177
### JVM-Side Operators
178178

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::fmt::{Display, Formatter};
20+
use std::hash::{Hash, Hasher};
21+
use std::sync::Arc;
22+
23+
use arrow::array::{Array, ArrayRef, Int32Array, ListArray, RecordBatch};
24+
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
25+
use datafusion::common::{exec_err, Result as DataFusionResult};
26+
use datafusion::physical_expr::PhysicalExpr;
27+
use datafusion::physical_plan::ColumnarValue;
28+
29+
/// A `PhysicalExpr` that takes a `List<T>` input and produces a `List<Int32>` where each row's
30+
/// values are `[0, 1, ..., len - 1]`. Offsets and the null bitmap are inherited from the input,
31+
/// so when the resulting list is unnested in parallel with the original list it produces the
32+
/// `pos` column expected by Spark's `posexplode`.
33+
#[derive(Debug, Clone)]
34+
pub struct ListPositionsExpr {
35+
child: Arc<dyn PhysicalExpr>,
36+
field: FieldRef,
37+
}
38+
39+
impl ListPositionsExpr {
40+
pub fn new(child: Arc<dyn PhysicalExpr>) -> Self {
41+
let field = Arc::new(Field::new(
42+
"item",
43+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
44+
true,
45+
));
46+
Self { child, field }
47+
}
48+
}
49+
50+
impl Display for ListPositionsExpr {
51+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52+
write!(f, "list_positions({})", self.child)
53+
}
54+
}
55+
56+
impl PartialEq for ListPositionsExpr {
57+
fn eq(&self, other: &Self) -> bool {
58+
self.child.eq(&other.child)
59+
}
60+
}
61+
62+
impl Eq for ListPositionsExpr {}
63+
64+
impl Hash for ListPositionsExpr {
65+
fn hash<H: Hasher>(&self, state: &mut H) {
66+
self.child.hash(state);
67+
}
68+
}
69+
70+
impl PhysicalExpr for ListPositionsExpr {
71+
fn as_any(&self) -> &dyn Any {
72+
self
73+
}
74+
75+
fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76+
Display::fmt(self, f)
77+
}
78+
79+
fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
80+
Ok(self.field.data_type().clone())
81+
}
82+
83+
fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
84+
Ok(true)
85+
}
86+
87+
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
88+
let value = self.child.evaluate(batch)?;
89+
let array = value.into_array(batch.num_rows())?;
90+
91+
let list = match array.as_any().downcast_ref::<ListArray>() {
92+
Some(list) => list,
93+
None => {
94+
return exec_err!(
95+
"ListPositionsExpr expected List input, got {}",
96+
array.data_type()
97+
);
98+
}
99+
};
100+
101+
let offsets = list.offsets();
102+
let total_len = *offsets.last().unwrap() as usize;
103+
104+
let mut values: Vec<i32> = Vec::with_capacity(total_len);
105+
for window in offsets.windows(2) {
106+
let start = window[0];
107+
let end = window[1];
108+
for i in 0..(end - start) {
109+
values.push(i);
110+
}
111+
}
112+
113+
let element_field = Arc::new(Field::new("item", DataType::Int32, true));
114+
let result = ListArray::new(
115+
element_field,
116+
offsets.clone(),
117+
Arc::new(Int32Array::from(values)),
118+
list.nulls().cloned(),
119+
);
120+
121+
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
122+
}
123+
124+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
125+
vec![&self.child]
126+
}
127+
128+
fn with_new_children(
129+
self: Arc<Self>,
130+
children: Vec<Arc<dyn PhysicalExpr>>,
131+
) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
132+
if children.len() != 1 {
133+
return exec_err!(
134+
"ListPositionsExpr expects exactly 1 child, got {}",
135+
children.len()
136+
);
137+
}
138+
Ok(Arc::new(ListPositionsExpr::new(Arc::clone(&children[0]))))
139+
}
140+
}

native/core/src/execution/expressions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
pub mod arithmetic;
2121
pub mod bitwise;
2222
pub mod comparison;
23+
pub mod list_positions;
2324
pub mod logical;
2425
pub mod nullcheck;
2526
pub mod partition;

native/core/src/execution/planner.rs

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod operator_registry;
2424
use crate::execution::operators::init_csv_datasource_exec;
2525
use crate::execution::operators::IcebergScanExec;
2626
use crate::execution::{
27+
expressions::list_positions::ListPositionsExpr,
2728
expressions::subquery::Subquery,
2829
operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec},
2930
planner::expression_registry::ExpressionRegistry,
@@ -1656,12 +1657,8 @@ impl PhysicalPlanner {
16561657
.map(|expr| self.create_expr(expr, child.schema()))
16571658
.collect::<Result<Vec<_>, _>>()?;
16581659

1659-
// For UnnestExec, we need to add a projection to put the columns in the right order:
1660-
// 1. First add all projection columns
1661-
// 2. Then add the array column to be exploded
1662-
// Then UnnestExec will unnest the last column
1663-
1664-
// Use return_field() to get the proper column names from the expressions
1660+
// For posexplode, a parallel List<Int32> positions column is added before the
1661+
// array column so UnnestExec can unnest both in parallel.
16651662
let child_schema = child.schema();
16661663
let mut project_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = projections
16671664
.iter()
@@ -1674,34 +1671,44 @@ impl PhysicalPlanner {
16741671
})
16751672
.collect();
16761673

1677-
// Add the array column as the last column
16781674
let array_field = child_expr
16791675
.return_field(&child_schema)
16801676
.expect("Failed to get field from array expression");
16811677
let array_col_name = array_field.name().to_string();
1678+
1679+
if explode.position {
1680+
let positions_expr: Arc<dyn PhysicalExpr> =
1681+
Arc::new(ListPositionsExpr::new(Arc::clone(&child_expr)));
1682+
project_exprs.push((positions_expr, "pos".to_string()));
1683+
}
16821684
project_exprs.push((Arc::clone(&child_expr), array_col_name.clone()));
16831685

1684-
// Create a projection to arrange columns as needed
16851686
let project_exec = Arc::new(ProjectionExec::try_new(
16861687
project_exprs,
16871688
Arc::clone(&child.native_plan),
16881689
)?);
16891690

1690-
// Get the input schema from the projection
16911691
let project_schema = project_exec.schema();
16921692

16931693
// Build the output schema for UnnestExec
1694-
// The output schema replaces the list column with its element type
16951694
let mut output_fields: Vec<Field> = Vec::new();
16961695

16971696
// Add all projection columns (non-array columns)
16981697
for i in 0..projections.len() {
16991698
output_fields.push(project_schema.field(i).clone());
17001699
}
17011700

1702-
// Add the unnested array element field
1701+
let array_input_index = if explode.position {
1702+
// With outer=true, UnnestExec preserves rows whose array is empty or NULL
1703+
// and emits a NULL position for them, so pos must be nullable in that case.
1704+
output_fields.push(Field::new("pos", DataType::Int32, explode.outer));
1705+
projections.len() + 1
1706+
} else {
1707+
projections.len()
1708+
};
1709+
17031710
// Extract the element type from the list/array type
1704-
let array_field = project_schema.field(projections.len());
1711+
let array_field = project_schema.field(array_input_index);
17051712
let element_type = match array_field.data_type() {
17061713
DataType::List(field) => field.data_type().clone(),
17071714
dt => {
@@ -1712,8 +1719,6 @@ impl PhysicalPlanner {
17121719
}
17131720
};
17141721

1715-
// The output column has the same name as the input array column
1716-
// but with the element type instead of the list type
17171722
output_fields.push(Field::new(
17181723
array_field.name(),
17191724
element_type,
@@ -1722,12 +1727,17 @@ impl PhysicalPlanner {
17221727

17231728
let output_schema = Arc::new(Schema::new(output_fields));
17241729

1725-
// Use UnnestExec to explode the last column (the array column)
1726-
// ListUnnest specifies which column to unnest and the depth (1 for single level)
1727-
let list_unnest = ListUnnest {
1728-
index_in_input_schema: projections.len(), // Index of the array column to unnest
1729-
depth: 1, // Unnest one level (explode single array)
1730-
};
1730+
let mut list_unnests = Vec::with_capacity(2);
1731+
if explode.position {
1732+
list_unnests.push(ListUnnest {
1733+
index_in_input_schema: projections.len(),
1734+
depth: 1,
1735+
});
1736+
}
1737+
list_unnests.push(ListUnnest {
1738+
index_in_input_schema: array_input_index,
1739+
depth: 1,
1740+
});
17311741

17321742
let unnest_options = UnnestOptions {
17331743
preserve_nulls: explode.outer,
@@ -1736,7 +1746,7 @@ impl PhysicalPlanner {
17361746

17371747
let unnest_exec = Arc::new(UnnestExec::new(
17381748
project_exec,
1739-
vec![list_unnest],
1749+
list_unnests,
17401750
vec![], // No struct columns to unnest
17411751
output_schema,
17421752
unnest_options,

native/proto/src/proto/operator.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,8 @@ message Explode {
361361
bool outer = 2;
362362
// Expressions for other columns to project alongside the exploded values
363363
repeated spark.spark_expression.Expr project_list = 3;
364+
// Whether to emit a position column alongside the exploded values (posexplode)
365+
bool position = 4;
364366
}
365367

366368
message HashJoin {

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,8 @@ object CometExplodeExec extends CometOperatorSerde[GenerateExec] {
12181218
if (op.generator.children.length != 1) {
12191219
return Unsupported(Some("generators with multiple inputs are not supported"))
12201220
}
1221-
if (op.generator.nodeName.toLowerCase(Locale.ROOT) != "explode") {
1221+
val nodeName = op.generator.nodeName.toLowerCase(Locale.ROOT)
1222+
if (nodeName != "explode" && nodeName != "posexplode") {
12221223
return Unsupported(Some(s"Unsupported generator: ${op.generator.nodeName}"))
12231224
}
12241225
if (op.outer) {
@@ -1262,10 +1263,13 @@ object CometExplodeExec extends CometOperatorSerde[GenerateExec] {
12621263
return None
12631264
}
12641265

1266+
val isPosExplode = op.generator.nodeName.toLowerCase(Locale.ROOT) == "posexplode"
1267+
12651268
val explodeBuilder = OperatorOuterClass.Explode
12661269
.newBuilder()
12671270
.setChild(childExprProto.get)
12681271
.setOuter(op.outer)
1272+
.setPosition(isPosExplode)
12691273
.addAllProjectList(projectExprs.map(_.get).asJava)
12701274

12711275
Some(builder.setExplode(explodeBuilder).build())

0 commit comments

Comments
 (0)