@@ -37,7 +37,6 @@ use crate::compute::zip;
3737use crate :: expr:: Arity ;
3838use crate :: expr:: ChildName ;
3939use crate :: expr:: ExecutionArgs ;
40- use crate :: expr:: ExecutionResult ;
4140use crate :: expr:: ExprId ;
4241use crate :: expr:: VTable ;
4342use crate :: expr:: VTableExt ;
@@ -79,12 +78,10 @@ impl VTable for CaseWhen {
7978 }
8079
8180 fn serialize ( & self , options : & Self :: Options ) -> VortexResult < Option < Vec < u8 > > > {
81+ let num_children =
82+ options. num_when_then_pairs * 2 + if options. has_else { 1 } else { 0 } ;
8283 Ok ( Some (
83- pb:: CaseWhenOpts {
84- num_when_then_pairs : options. num_when_then_pairs ,
85- has_else : options. has_else ,
86- }
87- . encode_to_vec ( ) ,
84+ pb:: CaseWhenOpts { num_children } . encode_to_vec ( ) ,
8885 ) )
8986 }
9087
@@ -95,8 +92,8 @@ impl VTable for CaseWhen {
9592 ) -> VortexResult < Self :: Options > {
9693 let opts = pb:: CaseWhenOpts :: decode ( metadata) ?;
9794 Ok ( CaseWhenOptions {
98- num_when_then_pairs : opts. num_when_then_pairs ,
99- has_else : opts. has_else ,
95+ num_when_then_pairs : opts. num_children / 2 ,
96+ has_else : opts. num_children % 2 == 1 ,
10097 } )
10198 }
10299
@@ -156,6 +153,18 @@ impl VTable for CaseWhen {
156153 // The return dtype is based on the first THEN expression (index 1)
157154 let then_dtype = & arg_dtypes[ 1 ] ;
158155
156+ // All THEN (and ELSE) value dtypes must match
157+ debug_assert ! (
158+ ( 0 ..options. num_when_then_pairs as usize ) . all( |i| {
159+ let idx = i * 2 + 1 ;
160+ & arg_dtypes[ idx] == then_dtype
161+ } ) ,
162+ "All THEN expression dtypes must match, got {:?}" ,
163+ ( 0 ..options. num_when_then_pairs as usize )
164+ . map( |i| & arg_dtypes[ i * 2 + 1 ] )
165+ . collect:: <Vec <_>>( )
166+ ) ;
167+
159168 // If there's no ELSE, the result is always nullable (unmatched rows are NULL)
160169 if !options. has_else {
161170 Ok ( then_dtype. as_nullable ( ) )
@@ -168,7 +177,7 @@ impl VTable for CaseWhen {
168177 & self ,
169178 options : & Self :: Options ,
170179 args : ExecutionArgs ,
171- ) -> VortexResult < ExecutionResult > {
180+ ) -> VortexResult < ArrayRef > {
172181 let row_count = args. row_count ;
173182 let num_pairs = options. num_when_then_pairs as usize ;
174183
@@ -222,7 +231,7 @@ impl VTable for CaseWhen {
222231 result = zip ( then_value. as_ref ( ) , result. as_ref ( ) , & mask) ?;
223232 }
224233
225- result . execute :: < ExecutionResult > ( args . ctx )
234+ Ok ( result )
226235 }
227236
228237 fn is_null_sensitive ( & self , _options : & Self :: Options ) -> bool {
@@ -236,7 +245,7 @@ impl VTable for CaseWhen {
236245}
237246
238247/// Efficient implementation for binary CASE WHEN (single when/then pair)
239- fn execute_binary_case_when ( _has_else : bool , args : ExecutionArgs ) -> VortexResult < ExecutionResult > {
248+ fn execute_binary_case_when ( _has_else : bool , args : ExecutionArgs ) -> VortexResult < ArrayRef > {
240249 let row_count = args. row_count ;
241250
242251 // Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value]
@@ -265,20 +274,17 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul
265274
266275 // Short-circuit: all true -> just return THEN value
267276 if mask. all_true ( ) {
268- return then_value . execute :: < ExecutionResult > ( args . ctx ) ;
277+ return Ok ( then_value ) ;
269278 }
270279
271280 // Short-circuit: all false -> return ELSE value or NULL
272281 if mask. all_false ( ) {
273282 return match else_value {
274- Some ( else_value) => else_value . execute :: < ExecutionResult > ( args . ctx ) ,
283+ Some ( else_value) => Ok ( else_value ) ,
275284 None => {
276285 // Create NULL constant of appropriate type
277286 let then_dtype = then_value. dtype ( ) . as_nullable ( ) ;
278- Ok ( ExecutionResult :: constant (
279- Scalar :: null ( then_dtype) ,
280- row_count,
281- ) )
287+ Ok ( ConstantArray :: new ( Scalar :: null ( then_dtype) , row_count) . into_array ( ) )
282288 }
283289 } ;
284290 }
@@ -290,9 +296,7 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul
290296 } ) ;
291297
292298 // Use zip to select: where mask is true, take then_value; else take else_value
293- let result = zip ( then_value. as_ref ( ) , else_value. as_ref ( ) , & mask) ?;
294-
295- result. execute :: < ExecutionResult > ( args. ctx )
299+ zip ( then_value. as_ref ( ) , else_value. as_ref ( ) , & mask)
296300}
297301
298302/// Creates an N-ary CASE WHEN expression from a flat list of children.
0 commit comments