@@ -17,12 +17,13 @@ use std::hash::Hash;
1717
1818use prost:: Message ;
1919use vortex_dtype:: DType ;
20- use vortex_error:: VortexExpect ;
2120use vortex_error:: VortexResult ;
2221use vortex_error:: vortex_bail;
22+ use vortex_error:: vortex_panic;
2323use vortex_proto:: expr as pb;
2424use vortex_scalar:: Scalar ;
2525
26+ use crate :: ArrayRef ;
2627use crate :: IntoArray ;
2728use crate :: arrays:: BoolArray ;
2829use crate :: arrays:: ConstantArray ;
@@ -132,17 +133,33 @@ impl VTable for CaseWhen {
132133
133134 fn execute (
134135 & self ,
135- options : & Self :: Options ,
136- mut args : ExecutionArgs ,
136+ _options : & Self :: Options ,
137+ args : ExecutionArgs ,
137138 ) -> VortexResult < ExecutionResult > {
138139 let row_count = args. row_count ;
139140
140- // Extract inputs: condition, then_value, else_value (optional)
141- let condition = args. inputs . remove ( 0 ) ;
142- let then_value = args. inputs . remove ( 0 ) ;
141+ // Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value]
142+ let ( condition, then_value, else_value) = match args. inputs . len ( ) {
143+ 2 => {
144+ let [ condition, then_value] : [ ArrayRef ; 2 ] = args
145+ . inputs
146+ . try_into ( )
147+ . map_err ( |_| vortex_error:: vortex_err!( "Expected 2 inputs" ) ) ?;
148+ ( condition, then_value, None )
149+ }
150+ 3 => {
151+ let [ condition, then_value, else_value] : [ ArrayRef ; 3 ] = args
152+ . inputs
153+ . try_into ( )
154+ . map_err ( |_| vortex_error:: vortex_err!( "Expected 3 inputs" ) ) ?;
155+ ( condition, then_value, Some ( else_value) )
156+ }
157+ n => vortex_bail ! ( "CaseWhen expects 2 or 3 inputs, got {}" , n) ,
158+ } ;
143159
144160 // Execute condition to get a BoolArray
145161 let cond_bool = condition. execute :: < BoolArray > ( args. ctx ) ?;
162+ // SQL semantics: NULL condition is treated as FALSE (i.e., we take the ELSE branch)
146163 let mask = cond_bool. to_mask_fill_null_false ( ) ;
147164
148165 // Short-circuit: all true -> just return THEN value
@@ -152,27 +169,24 @@ impl VTable for CaseWhen {
152169
153170 // Short-circuit: all false -> return ELSE value or NULL
154171 if mask. all_false ( ) {
155- return if options . has_else {
156- let else_value = args . inputs . remove ( 0 ) ;
157- else_value . execute :: < ExecutionResult > ( args . ctx )
158- } else {
159- // Create NULL constant of appropriate type
160- let then_dtype = then_value . dtype ( ) . as_nullable ( ) ;
161- Ok ( ExecutionResult :: constant (
162- Scalar :: null ( then_dtype ) ,
163- row_count ,
164- ) )
172+ return match else_value {
173+ Some ( else_value) => else_value . execute :: < ExecutionResult > ( args . ctx ) ,
174+ None => {
175+ // Create NULL constant of appropriate type
176+ let then_dtype = then_value . dtype ( ) . as_nullable ( ) ;
177+ Ok ( ExecutionResult :: constant (
178+ Scalar :: null ( then_dtype ) ,
179+ row_count ,
180+ ) )
181+ }
165182 } ;
166183 }
167184
168- // Get else value for zip
169- let else_value = if options. has_else {
170- args. inputs . pop ( ) . vortex_expect ( "Missing else input" )
171- } else {
172- // Create NULL constant array for the else branch
185+ // Get else value for zip (create NULL constant if no else clause)
186+ let else_value = else_value. unwrap_or_else ( || {
173187 let then_dtype = then_value. dtype ( ) . as_nullable ( ) ;
174188 ConstantArray :: new ( Scalar :: null ( then_dtype) , row_count) . into_array ( )
175- } ;
189+ } ) ;
176190
177191 // Use zip to select: where mask is true, take then_value; else take else_value
178192 let result = zip ( then_value. as_ref ( ) , else_value. as_ref ( ) , & mask) ?;
@@ -259,29 +273,16 @@ pub fn nested_case_when(
259273 "nested_case_when requires at least one when/then pair"
260274 ) ;
261275
262- // Build from right to left (innermost first)
263- // Using fold to avoid expect/unwrap
264- let pairs: Vec < _ > = when_then_pairs. into_iter ( ) . rev ( ) . collect ( ) ;
265- let first_pair = & pairs[ 0 ] ; // Safe: assert guarantees non-empty
266- let remaining = & pairs[ 1 ..] ;
267-
268- // Build innermost expression
269- let mut result = if let Some ( ref else_expr) = else_value {
270- case_when (
271- first_pair. 0 . clone ( ) ,
272- first_pair. 1 . clone ( ) ,
273- else_expr. clone ( ) ,
274- )
275- } else {
276- case_when_no_else ( first_pair. 0 . clone ( ) , first_pair. 1 . clone ( ) )
277- } ;
278-
279- // Wrap with remaining pairs
280- for ( condition, then_value) in remaining {
281- result = case_when ( condition. clone ( ) , then_value. clone ( ) , result) ;
282- }
283-
284- result
276+ // Build from right to left (innermost first) using rfold
277+ when_then_pairs
278+ . into_iter ( )
279+ . rfold ( else_value, |acc, ( condition, then_value) | {
280+ Some ( match acc {
281+ Some ( else_expr) => case_when ( condition, then_value, else_expr) ,
282+ None => case_when_no_else ( condition, then_value) ,
283+ } )
284+ } )
285+ . unwrap_or_else ( || vortex_panic ! ( "rfold on non-empty iterator always produces Some" ) )
285286}
286287
287288#[ cfg( test) ]
0 commit comments