Skip to content

Commit 49eceb8

Browse files
committed
PR comments
1 parent a917a38 commit 49eceb8

3 files changed

Lines changed: 158 additions & 45 deletions

File tree

vortex-array/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ name = "expr_large_struct_pack"
129129
path = "benches/expr/large_struct_pack.rs"
130130
harness = false
131131

132+
[[bench]]
133+
name = "expr_case_when"
134+
path = "benches/expr/case_when_bench.rs"
135+
harness = false
136+
132137
[[bench]]
133138
name = "chunked_dict_builder"
134139
harness = false
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#![allow(clippy::unwrap_used)]
5+
#![allow(clippy::cast_possible_truncation)]
6+
7+
use divan::Bencher;
8+
use vortex_array::ArrayRef;
9+
use vortex_array::IntoArray;
10+
use vortex_array::arrays::StructArray;
11+
use vortex_array::expr::case_when;
12+
use vortex_array::expr::get_item;
13+
use vortex_array::expr::gt;
14+
use vortex_array::expr::lit;
15+
use vortex_array::expr::nested_case_when;
16+
use vortex_array::expr::root;
17+
use vortex_array::validity::Validity;
18+
use vortex_buffer::Buffer;
19+
use vortex_dtype::FieldNames;
20+
21+
fn main() {
22+
divan::main();
23+
}
24+
25+
fn make_struct_array(size: usize) -> ArrayRef {
26+
let data: Buffer<i32> = (0..size as i32).collect();
27+
let field = data.into_array();
28+
StructArray::try_new(
29+
FieldNames::from(["value"]),
30+
vec![field],
31+
size,
32+
Validity::NonNullable,
33+
)
34+
.unwrap()
35+
.into_array()
36+
}
37+
38+
/// Benchmark a simple binary CASE WHEN with varying array sizes.
39+
#[divan::bench(args = [1000, 10000, 100000])]
40+
fn case_when_simple(bencher: Bencher, size: usize) {
41+
let array = make_struct_array(size);
42+
43+
// CASE WHEN value > 500 THEN 100 ELSE 0 END
44+
let expr = case_when(
45+
gt(get_item("value", root()), lit(500i32)),
46+
lit(100i32),
47+
lit(0i32),
48+
);
49+
50+
bencher
51+
.with_inputs(|| (&expr, &array))
52+
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
53+
}
54+
55+
/// Benchmark nested CASE WHEN with multiple conditions.
56+
#[divan::bench(args = [1000, 10000, 100000])]
57+
fn case_when_nested_3_conditions(bencher: Bencher, size: usize) {
58+
let array = make_struct_array(size);
59+
60+
// CASE WHEN value > 750 THEN 3 WHEN value > 500 THEN 2 WHEN value > 250 THEN 1 ELSE 0 END
61+
let expr = nested_case_when(
62+
vec![
63+
(gt(get_item("value", root()), lit(750i32)), lit(3i32)),
64+
(gt(get_item("value", root()), lit(500i32)), lit(2i32)),
65+
(gt(get_item("value", root()), lit(250i32)), lit(1i32)),
66+
],
67+
Some(lit(0i32)),
68+
);
69+
70+
bencher
71+
.with_inputs(|| (&expr, &array))
72+
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
73+
}
74+
75+
/// Benchmark CASE WHEN where all conditions are true (short-circuit path).
76+
#[divan::bench(args = [1000, 10000, 100000])]
77+
fn case_when_all_true(bencher: Bencher, size: usize) {
78+
let array = make_struct_array(size);
79+
80+
// CASE WHEN value >= 0 THEN 100 ELSE 0 END (always true for our data)
81+
let expr = case_when(
82+
gt(get_item("value", root()), lit(-1i32)),
83+
lit(100i32),
84+
lit(0i32),
85+
);
86+
87+
bencher
88+
.with_inputs(|| (&expr, &array))
89+
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
90+
}
91+
92+
/// Benchmark CASE WHEN where all conditions are false (short-circuit path).
93+
#[divan::bench(args = [1000, 10000, 100000])]
94+
fn case_when_all_false(bencher: Bencher, size: usize) {
95+
let array = make_struct_array(size);
96+
97+
// CASE WHEN value > 1000000 THEN 100 ELSE 0 END (always false for our data)
98+
let expr = case_when(
99+
gt(get_item("value", root()), lit(1_000_000i32)),
100+
lit(100i32),
101+
lit(0i32),
102+
);
103+
104+
bencher
105+
.with_inputs(|| (&expr, &array))
106+
.bench_refs(|(expr, array)| expr.evaluate(array).unwrap());
107+
}

vortex-array/src/expr/exprs/case_when.rs

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ use std::hash::Hash;
1717

1818
use prost::Message;
1919
use vortex_dtype::DType;
20-
use vortex_error::VortexExpect;
2120
use vortex_error::VortexResult;
2221
use vortex_error::vortex_bail;
22+
use vortex_error::vortex_panic;
2323
use vortex_proto::expr as pb;
2424
use vortex_scalar::Scalar;
2525

26+
use crate::ArrayRef;
2627
use crate::IntoArray;
2728
use crate::arrays::BoolArray;
2829
use crate::arrays::ConstantArray;
@@ -132,17 +133,33 @@ impl VTable for CaseWhen {
132133

133134
fn execute(
134135
&self,
135-
options: &Self::Options,
136-
mut args: ExecutionArgs,
136+
_options: &Self::Options,
137+
args: ExecutionArgs,
137138
) -> VortexResult<ExecutionResult> {
138139
let row_count = args.row_count;
139140

140-
// Extract inputs: condition, then_value, else_value (optional)
141-
let condition = args.inputs.remove(0);
142-
let then_value = args.inputs.remove(0);
141+
// Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value]
142+
let (condition, then_value, else_value) = match args.inputs.len() {
143+
2 => {
144+
let [condition, then_value]: [ArrayRef; 2] = args
145+
.inputs
146+
.try_into()
147+
.map_err(|_| vortex_error::vortex_err!("Expected 2 inputs"))?;
148+
(condition, then_value, None)
149+
}
150+
3 => {
151+
let [condition, then_value, else_value]: [ArrayRef; 3] = args
152+
.inputs
153+
.try_into()
154+
.map_err(|_| vortex_error::vortex_err!("Expected 3 inputs"))?;
155+
(condition, then_value, Some(else_value))
156+
}
157+
n => vortex_bail!("CaseWhen expects 2 or 3 inputs, got {}", n),
158+
};
143159

144160
// Execute condition to get a BoolArray
145161
let cond_bool = condition.execute::<BoolArray>(args.ctx)?;
162+
// SQL semantics: NULL condition is treated as FALSE (i.e., we take the ELSE branch)
146163
let mask = cond_bool.to_mask_fill_null_false();
147164

148165
// Short-circuit: all true -> just return THEN value
@@ -152,27 +169,24 @@ impl VTable for CaseWhen {
152169

153170
// Short-circuit: all false -> return ELSE value or NULL
154171
if mask.all_false() {
155-
return if options.has_else {
156-
let else_value = args.inputs.remove(0);
157-
else_value.execute::<ExecutionResult>(args.ctx)
158-
} else {
159-
// Create NULL constant of appropriate type
160-
let then_dtype = then_value.dtype().as_nullable();
161-
Ok(ExecutionResult::constant(
162-
Scalar::null(then_dtype),
163-
row_count,
164-
))
172+
return match else_value {
173+
Some(else_value) => else_value.execute::<ExecutionResult>(args.ctx),
174+
None => {
175+
// Create NULL constant of appropriate type
176+
let then_dtype = then_value.dtype().as_nullable();
177+
Ok(ExecutionResult::constant(
178+
Scalar::null(then_dtype),
179+
row_count,
180+
))
181+
}
165182
};
166183
}
167184

168-
// Get else value for zip
169-
let else_value = if options.has_else {
170-
args.inputs.pop().vortex_expect("Missing else input")
171-
} else {
172-
// Create NULL constant array for the else branch
185+
// Get else value for zip (create NULL constant if no else clause)
186+
let else_value = else_value.unwrap_or_else(|| {
173187
let then_dtype = then_value.dtype().as_nullable();
174188
ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
175-
};
189+
});
176190

177191
// Use zip to select: where mask is true, take then_value; else take else_value
178192
let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?;
@@ -259,29 +273,16 @@ pub fn nested_case_when(
259273
"nested_case_when requires at least one when/then pair"
260274
);
261275

262-
// Build from right to left (innermost first)
263-
// Using fold to avoid expect/unwrap
264-
let pairs: Vec<_> = when_then_pairs.into_iter().rev().collect();
265-
let first_pair = &pairs[0]; // Safe: assert guarantees non-empty
266-
let remaining = &pairs[1..];
267-
268-
// Build innermost expression
269-
let mut result = if let Some(ref else_expr) = else_value {
270-
case_when(
271-
first_pair.0.clone(),
272-
first_pair.1.clone(),
273-
else_expr.clone(),
274-
)
275-
} else {
276-
case_when_no_else(first_pair.0.clone(), first_pair.1.clone())
277-
};
278-
279-
// Wrap with remaining pairs
280-
for (condition, then_value) in remaining {
281-
result = case_when(condition.clone(), then_value.clone(), result);
282-
}
283-
284-
result
276+
// Build from right to left (innermost first) using rfold
277+
when_then_pairs
278+
.into_iter()
279+
.rfold(else_value, |acc, (condition, then_value)| {
280+
Some(match acc {
281+
Some(else_expr) => case_when(condition, then_value, else_expr),
282+
None => case_when_no_else(condition, then_value),
283+
})
284+
})
285+
.unwrap_or_else(|| vortex_panic!("rfold on non-empty iterator always produces Some"))
285286
}
286287

287288
#[cfg(test)]

0 commit comments

Comments
 (0)