Skip to content

Commit 0a97ebb

Browse files
authored
refactor: update CASE WHEN options to use a single num_children field (#22)
1 parent da5f802 commit 0a97ebb

4 files changed

Lines changed: 38 additions & 29 deletions

File tree

vortex-array/benches/expr/case_when_bench.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ fn make_struct_array(size: usize) -> ArrayRef {
4444
}
4545

4646
/// Benchmark a simple binary CASE WHEN with varying array sizes.
47-
#[divan::bench(args = [10000, 100000, 1000000])]
47+
#[divan::bench(args = [1000, 10000, 100000])]
4848
fn case_when_simple(bencher: Bencher, size: usize) {
4949
let array = make_struct_array(size);
5050

@@ -94,7 +94,7 @@ fn case_when_nary_3_conditions(bencher: Bencher, size: usize) {
9494
}
9595

9696
/// Benchmark CASE WHEN where all conditions are true (short-circuit path).
97-
#[divan::bench(args = [10000, 100000, 1000000])]
97+
#[divan::bench(args = [1000, 10000, 100000])]
9898
fn case_when_all_true(bencher: Bencher, size: usize) {
9999
let array = make_struct_array(size);
100100

@@ -117,7 +117,7 @@ fn case_when_all_true(bencher: Bencher, size: usize) {
117117
}
118118

119119
/// Benchmark CASE WHEN where all conditions are false (short-circuit path).
120-
#[divan::bench(args = [10000, 100000, 1000000])]
120+
#[divan::bench(args = [1000, 10000, 100000])]
121121
fn case_when_all_false(bencher: Bencher, size: usize) {
122122
let array = make_struct_array(size);
123123

@@ -181,7 +181,7 @@ fn case_when_nary_10_conditions(bencher: Bencher, size: usize) {
181181
}
182182

183183
/// Benchmark n-ary CASE WHEN with 100 conditions.
184-
#[divan::bench(args = [10000, 100000, 1000000])]
184+
#[divan::bench(args = [1000, 10000, 100000])]
185185
fn case_when_nary_100_conditions(bencher: Bencher, size: usize) {
186186
use vortex_array::expr::Expression;
187187

vortex-array/src/expr/exprs/case_when.rs

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ use crate::compute::zip;
3737
use crate::expr::Arity;
3838
use crate::expr::ChildName;
3939
use crate::expr::ExecutionArgs;
40-
use crate::expr::ExecutionResult;
4140
use crate::expr::ExprId;
4241
use crate::expr::VTable;
4342
use crate::expr::VTableExt;
@@ -79,12 +78,10 @@ impl VTable for CaseWhen {
7978
}
8079

8180
fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
81+
let num_children =
82+
options.num_when_then_pairs * 2 + if options.has_else { 1 } else { 0 };
8283
Ok(Some(
83-
pb::CaseWhenOpts {
84-
num_when_then_pairs: options.num_when_then_pairs,
85-
has_else: options.has_else,
86-
}
87-
.encode_to_vec(),
84+
pb::CaseWhenOpts { num_children }.encode_to_vec(),
8885
))
8986
}
9087

@@ -95,8 +92,8 @@ impl VTable for CaseWhen {
9592
) -> VortexResult<Self::Options> {
9693
let opts = pb::CaseWhenOpts::decode(metadata)?;
9794
Ok(CaseWhenOptions {
98-
num_when_then_pairs: opts.num_when_then_pairs,
99-
has_else: opts.has_else,
95+
num_when_then_pairs: opts.num_children / 2,
96+
has_else: opts.num_children % 2 == 1,
10097
})
10198
}
10299

@@ -156,6 +153,18 @@ impl VTable for CaseWhen {
156153
// The return dtype is based on the first THEN expression (index 1)
157154
let then_dtype = &arg_dtypes[1];
158155

156+
// All THEN (and ELSE) value dtypes must match
157+
debug_assert!(
158+
(0..options.num_when_then_pairs as usize).all(|i| {
159+
let idx = i * 2 + 1;
160+
&arg_dtypes[idx] == then_dtype
161+
}),
162+
"All THEN expression dtypes must match, got {:?}",
163+
(0..options.num_when_then_pairs as usize)
164+
.map(|i| &arg_dtypes[i * 2 + 1])
165+
.collect::<Vec<_>>()
166+
);
167+
159168
// If there's no ELSE, the result is always nullable (unmatched rows are NULL)
160169
if !options.has_else {
161170
Ok(then_dtype.as_nullable())
@@ -168,7 +177,7 @@ impl VTable for CaseWhen {
168177
&self,
169178
options: &Self::Options,
170179
args: ExecutionArgs,
171-
) -> VortexResult<ExecutionResult> {
180+
) -> VortexResult<ArrayRef> {
172181
let row_count = args.row_count;
173182
let num_pairs = options.num_when_then_pairs as usize;
174183

@@ -222,7 +231,7 @@ impl VTable for CaseWhen {
222231
result = zip(then_value.as_ref(), result.as_ref(), &mask)?;
223232
}
224233

225-
result.execute::<ExecutionResult>(args.ctx)
234+
Ok(result)
226235
}
227236

228237
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
@@ -236,7 +245,7 @@ impl VTable for CaseWhen {
236245
}
237246

238247
/// Efficient implementation for binary CASE WHEN (single when/then pair)
239-
fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult<ExecutionResult> {
248+
fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult<ArrayRef> {
240249
let row_count = args.row_count;
241250

242251
// Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value]
@@ -265,20 +274,17 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul
265274

266275
// Short-circuit: all true -> just return THEN value
267276
if mask.all_true() {
268-
return then_value.execute::<ExecutionResult>(args.ctx);
277+
return Ok(then_value);
269278
}
270279

271280
// Short-circuit: all false -> return ELSE value or NULL
272281
if mask.all_false() {
273282
return match else_value {
274-
Some(else_value) => else_value.execute::<ExecutionResult>(args.ctx),
283+
Some(else_value) => Ok(else_value),
275284
None => {
276285
// Create NULL constant of appropriate type
277286
let then_dtype = then_value.dtype().as_nullable();
278-
Ok(ExecutionResult::constant(
279-
Scalar::null(then_dtype),
280-
row_count,
281-
))
287+
Ok(ConstantArray::new(Scalar::null(then_dtype), row_count).into_array())
282288
}
283289
};
284290
}
@@ -290,9 +296,7 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul
290296
});
291297

292298
// Use zip to select: where mask is true, take then_value; else take else_value
293-
let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?;
294-
295-
result.execute::<ExecutionResult>(args.ctx)
299+
zip(then_value.as_ref(), else_value.as_ref(), &mask)
296300
}
297301

298302
/// Creates an N-ary CASE WHEN expression from a flat list of children.

vortex-proto/proto/expr.proto

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ message SelectOpts {
8282
}
8383

8484
// Options for `vortex.case_when`
85+
// Encodes num_when_then_pairs and has_else into a single u32 (num_children).
86+
// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0)
87+
// has_else = num_children % 2 == 1
88+
// num_when_then_pairs = num_children / 2
8589
message CaseWhenOpts {
86-
uint32 num_when_then_pairs = 1;
87-
bool has_else = 2;
90+
uint32 num_children = 1;
8891
}

vortex-proto/src/generated/vortex.expr.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,12 @@ pub mod select_opts {
146146
}
147147
}
148148
/// Options for `vortex.case_when`
149+
/// Encodes num_when_then_pairs and has_else into a single u32 (num_children).
150+
/// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0)
151+
/// has_else = num_children % 2 == 1
152+
/// num_when_then_pairs = num_children / 2
149153
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
150154
pub struct CaseWhenOpts {
151155
#[prost(uint32, tag = "1")]
152-
pub num_when_then_pairs: u32,
153-
#[prost(bool, tag = "2")]
154-
pub has_else: bool,
156+
pub num_children: u32,
155157
}

0 commit comments

Comments
 (0)