Skip to content

Commit 09fff48

Browse files
committed
CASE WHEN expression with execution context and DataFusion equivalence tests
1 parent ba82921 commit 09fff48

3 files changed

Lines changed: 166 additions & 18 deletions

File tree

vortex-array/benches/expr/case_when_bench.rs

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,28 @@
44
#![allow(clippy::unwrap_used)]
55
#![allow(clippy::cast_possible_truncation)]
66

7+
use std::sync::LazyLock;
8+
79
use divan::Bencher;
810
use vortex_array::ArrayRef;
11+
use vortex_array::Canonical;
912
use vortex_array::IntoArray;
13+
use vortex_array::VortexSessionExecute;
1014
use vortex_array::arrays::StructArray;
1115
use vortex_array::expr::case_when;
1216
use vortex_array::expr::get_item;
1317
use vortex_array::expr::gt;
1418
use vortex_array::expr::lit;
1519
use vortex_array::expr::nested_case_when;
1620
use vortex_array::expr::root;
21+
use vortex_array::session::ArraySession;
1722
use vortex_array::validity::Validity;
1823
use vortex_buffer::Buffer;
1924
use vortex_dtype::FieldNames;
25+
use vortex_session::VortexSession;
26+
27+
static SESSION: LazyLock<VortexSession> =
28+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
2029

2130
fn main() {
2231
divan::main();
@@ -49,7 +58,14 @@ fn case_when_simple(bencher: Bencher, size: usize) {
4958

5059
bencher
5160
.with_inputs(|| (&expr, &array))
52-
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
61+
.bench_refs(|(expr, array)| {
62+
let mut ctx = SESSION.create_execution_ctx();
63+
array
64+
.apply(expr)
65+
.unwrap()
66+
.execute::<Canonical>(&mut ctx)
67+
.unwrap()
68+
});
5369
}
5470

5571
/// Benchmark nested CASE WHEN with multiple conditions.
@@ -69,7 +85,14 @@ fn case_when_nested_3_conditions(bencher: Bencher, size: usize) {
6985

7086
bencher
7187
.with_inputs(|| (&expr, &array))
72-
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
88+
.bench_refs(|(expr, array)| {
89+
let mut ctx = SESSION.create_execution_ctx();
90+
array
91+
.apply(expr)
92+
.unwrap()
93+
.execute::<Canonical>(&mut ctx)
94+
.unwrap()
95+
});
7396
}
7497

7598
/// Benchmark CASE WHEN where all conditions are true (short-circuit path).
@@ -86,7 +109,14 @@ fn case_when_all_true(bencher: Bencher, size: usize) {
86109

87110
bencher
88111
.with_inputs(|| (&expr, &array))
89-
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
112+
.bench_refs(|(expr, array)| {
113+
let mut ctx = SESSION.create_execution_ctx();
114+
array
115+
.apply(expr)
116+
.unwrap()
117+
.execute::<Canonical>(&mut ctx)
118+
.unwrap()
119+
});
90120
}
91121

92122
/// Benchmark CASE WHEN where all conditions are false (short-circuit path).
@@ -103,5 +133,12 @@ fn case_when_all_false(bencher: Bencher, size: usize) {
103133

104134
bencher
105135
.with_inputs(|| (&expr, &array))
106-
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
136+
.bench_refs(|(expr, array)| {
137+
let mut ctx = SESSION.create_execution_ctx();
138+
array
139+
.apply(expr)
140+
.unwrap()
141+
.execute::<Canonical>(&mut ctx)
142+
.unwrap()
143+
});
107144
}

vortex-array/src/expr/exprs/case_when.rs

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,21 @@ pub fn nested_case_when(
287287

288288
#[cfg(test)]
289289
mod tests {
290+
use std::sync::LazyLock;
291+
290292
use vortex_buffer::buffer;
291293
use vortex_dtype::DType;
292294
use vortex_dtype::Nullability;
293295
use vortex_dtype::PType;
294296
use vortex_error::VortexExpect as _;
295297
use vortex_scalar::Scalar;
298+
use vortex_session::VortexSession;
296299

297300
use super::*;
301+
use crate::Canonical;
298302
use crate::IntoArray;
299303
use crate::ToCanonical;
304+
use crate::VortexSessionExecute as _;
300305
use crate::arrays::BoolArray;
301306
use crate::arrays::PrimitiveArray;
302307
use crate::arrays::StructArray;
@@ -307,6 +312,21 @@ mod tests {
307312
use crate::expr::exprs::literal::lit;
308313
use crate::expr::exprs::root::root;
309314
use crate::expr::test_harness;
315+
use crate::session::ArraySession;
316+
317+
static SESSION: LazyLock<VortexSession> =
318+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
319+
320+
/// Helper to evaluate an expression using the apply+execute pattern
321+
fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
322+
let mut ctx = SESSION.create_execution_ctx();
323+
array
324+
.apply(expr)
325+
.unwrap()
326+
.execute::<Canonical>(&mut ctx)
327+
.unwrap()
328+
.into_array()
329+
}
310330

311331
// ==================== Serialization Tests ====================
312332

@@ -455,7 +475,7 @@ mod tests {
455475
lit(0i32),
456476
);
457477

458-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
478+
let result = evaluate_expr(&expr, &test_array).to_primitive();
459479
assert_eq!(result.as_slice::<i32>(), &[0, 0, 100, 100, 100]);
460480
}
461481

@@ -475,7 +495,7 @@ mod tests {
475495
Some(lit(0i32)),
476496
);
477497

478-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
498+
let result = evaluate_expr(&expr, &test_array).to_primitive();
479499
assert_eq!(result.as_slice::<i32>(), &[10, 0, 30, 0, 0]);
480500
}
481501

@@ -495,7 +515,7 @@ mod tests {
495515
Some(lit(0i32)),
496516
);
497517

498-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
518+
let result = evaluate_expr(&expr, &test_array).to_primitive();
499519
assert_eq!(result.as_slice::<i32>(), &[0, 0, 100, 100, 100]);
500520
}
501521

@@ -508,7 +528,7 @@ mod tests {
508528

509529
let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
510530

511-
let result = expr.evaluate(&test_array).unwrap();
531+
let result = evaluate_expr(&expr, &test_array);
512532
assert!(result.dtype().is_nullable());
513533

514534
assert_eq!(
@@ -546,7 +566,7 @@ mod tests {
546566
lit(0i32),
547567
);
548568

549-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
569+
let result = evaluate_expr(&expr, &test_array).to_primitive();
550570
assert_eq!(result.as_slice::<i32>(), &[0, 0, 0, 0, 0]);
551571
}
552572

@@ -563,15 +583,15 @@ mod tests {
563583
lit(0i32),
564584
);
565585

566-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
586+
let result = evaluate_expr(&expr, &test_array).to_primitive();
567587
assert_eq!(result.as_slice::<i32>(), &[100, 100, 100, 100, 100]);
568588
}
569589

570590
#[test]
571591
fn test_evaluate_with_literal_condition() {
572592
let test_array = buffer![1i32, 2, 3].into_array();
573593
let expr = case_when(lit(true), lit(100i32), lit(0i32));
574-
let result = expr.evaluate(&test_array).unwrap();
594+
let result = evaluate_expr(&expr, &test_array);
575595

576596
if let Some(constant) = result.as_constant() {
577597
assert_eq!(constant, Scalar::from(100i32));
@@ -594,9 +614,9 @@ mod tests {
594614
lit(false),
595615
);
596616

597-
let result = expr.evaluate(&test_array).unwrap().to_bool();
617+
let result = evaluate_expr(&expr, &test_array).to_bool();
598618
assert_eq!(
599-
result.bit_buffer().iter().collect::<Vec<_>>(),
619+
result.to_bit_buffer().iter().collect::<Vec<_>>(),
600620
vec![false, false, true, true, true]
601621
);
602622
}
@@ -612,7 +632,7 @@ mod tests {
612632

613633
let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
614634

615-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
635+
let result = evaluate_expr(&expr, &test_array).to_primitive();
616636
assert_eq!(result.as_slice::<i32>(), &[100, 0, 0, 0, 100]);
617637
}
618638

@@ -635,7 +655,7 @@ mod tests {
635655
lit(0i32),
636656
);
637657

638-
let result = expr.evaluate(&test_array).unwrap();
658+
let result = evaluate_expr(&expr, &test_array);
639659
let prim = result.to_primitive();
640660
assert_eq!(prim.as_slice::<i32>(), &[0, 0, 30, 40, 50]);
641661
}
@@ -651,12 +671,11 @@ mod tests {
651671

652672
let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
653673

654-
let result = expr.evaluate(&test_array).unwrap().to_primitive();
674+
let result = evaluate_expr(&expr, &test_array).to_primitive();
655675
assert_eq!(result.as_slice::<i32>(), &[0, 0, 0]);
656676
}
657677

658-
// Note: Direct execute tests are covered through evaluate tests above,
659-
// since evaluate() calls execute() internally.
678+
// Note: Direct execute tests are covered through apply+execute tests above.
660679

661680
// Note: The binary CASE WHEN implementation using `zip` does NOT provide
662681
// short-circuit/lazy evaluation. All child expressions are evaluated first,

vortex-datafusion/src/convert/exprs.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,4 +847,96 @@ mod tests {
847847

848848
assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
849849
}
850+
851+
/// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion
852+
/// matches the result of applying the converted Vortex expression.
853+
#[test]
854+
fn test_case_when_datafusion_vortex_equivalence() {
855+
use datafusion::arrow::array::Int32Array;
856+
use datafusion::arrow::array::RecordBatch;
857+
use datafusion_physical_expr::expressions::CaseExpr;
858+
use vortex::VortexSessionDefault;
859+
use vortex::array::ArrayRef;
860+
use vortex::array::Canonical;
861+
use vortex::array::VortexSessionExecute as _;
862+
use vortex::array::arrow::FromArrowArray;
863+
use vortex::session::VortexSession;
864+
865+
// Create test data
866+
let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
867+
let schema = Arc::new(Schema::new(vec![Field::new(
868+
"value",
869+
DataType::Int32,
870+
false,
871+
)]));
872+
let batch = RecordBatch::try_new(schema, vec![values]).unwrap();
873+
874+
// Build a DataFusion CASE expression:
875+
// CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
876+
let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
877+
let lit_10 =
878+
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
879+
let lit_5 =
880+
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
881+
let lit_100 =
882+
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
883+
let lit_50 =
884+
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
885+
let lit_0 =
886+
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;
887+
888+
// WHEN value > 10 THEN 100
889+
let when1 = Arc::new(df_expr::BinaryExpr::new(
890+
col_value.clone(),
891+
DFOperator::Gt,
892+
lit_10,
893+
)) as Arc<dyn PhysicalExpr>;
894+
// WHEN value > 5 THEN 50
895+
let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
896+
as Arc<dyn PhysicalExpr>;
897+
898+
let case_expr =
899+
CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();
900+
901+
// Apply DataFusion expression
902+
let df_result = case_expr.evaluate(&batch).unwrap();
903+
let df_array = df_result.into_array(batch.num_rows()).unwrap();
904+
905+
// Convert to Vortex expression
906+
let expr_convertor = DefaultExpressionConvertor::default();
907+
let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();
908+
909+
// Convert batch to Vortex array
910+
let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();
911+
912+
// Apply Vortex expression
913+
let session = VortexSession::default();
914+
let mut ctx = session.create_execution_ctx();
915+
let vortex_result = vortex_array
916+
.apply(&vortex_expr)
917+
.unwrap()
918+
.execute::<Canonical>(&mut ctx)
919+
.unwrap();
920+
921+
// Convert back to Arrow for comparison
922+
let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();
923+
924+
// Convert DataFusion result to Vec for comparison
925+
let df_as_arrow: Vec<i32> = df_array
926+
.as_any()
927+
.downcast_ref::<Int32Array>()
928+
.unwrap()
929+
.values()
930+
.to_vec();
931+
932+
// Compare results
933+
// Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20]
934+
// value=1: not > 10, not > 5 -> ELSE 0
935+
// value=5: not > 10, not > 5 -> ELSE 0
936+
// value=10: not > 10, > 5 -> 50
937+
// value=15: > 10 -> 100
938+
// value=20: > 10 -> 100
939+
assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
940+
assert_eq!(vortex_as_arrow, df_as_arrow);
941+
}
850942
}

0 commit comments

Comments
 (0)