@@ -8,6 +8,7 @@ use vortex_dtype::DType;
88use vortex_error:: VortexResult ;
99use vortex_error:: vortex_bail;
1010use vortex_error:: vortex_err;
11+ use vortex_scalar:: Scalar ;
1112use vortex_session:: VortexSession ;
1213
1314use crate :: ArrayRef ;
@@ -105,6 +106,13 @@ impl VTable for ListContains {
105106 . try_into ( )
106107 . map_err ( |_| vortex_err ! ( "Wrong number of arguments for ListContains expression" ) ) ?;
107108
109+ if let Some ( list_scalar) = list_array. as_constant ( )
110+ && let Some ( value_scalar) = value_array. as_constant ( )
111+ {
112+ let result = compute_contains_scalar ( & list_scalar, & value_scalar) ?;
113+ return Ok ( ExecutionResult :: constant ( result, args. row_count ) ) ;
114+ }
115+
108116 compute_list_contains ( list_array. as_ref ( ) , value_array. as_ref ( ) ) ?. execute ( args. ctx )
109117 }
110118
@@ -158,6 +166,23 @@ impl VTable for ListContains {
158166 }
159167}
160168
169+ fn compute_contains_scalar ( list : & Scalar , needle : & Scalar ) -> VortexResult < Scalar > {
170+ let nullability = list. dtype ( ) . nullability ( ) | needle. dtype ( ) . nullability ( ) ;
171+
172+ // Handle null list or null needle
173+ if list. is_null ( ) || needle. is_null ( ) {
174+ return Ok ( Scalar :: null ( DType :: Bool ( nullability) ) ) ;
175+ }
176+
177+ let list_scalar = list. as_list ( ) ;
178+ let elements = list_scalar
179+ . elements ( )
180+ . ok_or_else ( || vortex_err ! ( "Expected non-null list" ) ) ?;
181+
182+ let contains = elements. iter ( ) . any ( |elem| elem == needle) ;
183+ Ok ( Scalar :: bool ( contains, nullability) )
184+ }
185+
161186/// Creates an expression that checks if a value is contained in a list.
162187///
163188/// Returns a boolean array indicating whether the value appears in each list.
@@ -379,4 +404,32 @@ mod tests {
379404 let expr2 = list_contains ( root ( ) , lit ( 42 ) ) ;
380405 assert_eq ! ( expr2. to_string( ) , "contains($, 42i32)" ) ;
381406 }
407+
408+ #[ test]
409+ pub fn test_constant_scalars ( ) {
410+ let arr = test_array ( ) ;
411+
412+ // Both list and needle are constants - should use scalar optimization
413+ let list_scalar = Scalar :: list (
414+ Arc :: new ( DType :: Primitive ( I32 , Nullability :: NonNullable ) ) ,
415+ vec ! [ 1 . into( ) , 2 . into( ) , 3 . into( ) ] ,
416+ Nullability :: NonNullable ,
417+ ) ;
418+
419+ // Test contains true
420+ let expr = list_contains ( lit ( list_scalar. clone ( ) ) , lit ( 2i32 ) ) ;
421+ let result = arr. apply ( & expr) . unwrap ( ) ;
422+ assert_eq ! (
423+ result. scalar_at( 0 ) . unwrap( ) ,
424+ Scalar :: bool ( true , Nullability :: NonNullable )
425+ ) ;
426+
427+ // Test contains false
428+ let expr = list_contains ( lit ( list_scalar) , lit ( 42i32 ) ) ;
429+ let result = arr. apply ( & expr) . unwrap ( ) ;
430+ assert_eq ! (
431+ result. scalar_at( 0 ) . unwrap( ) ,
432+ Scalar :: bool ( false , Nullability :: NonNullable )
433+ ) ;
434+ }
382435}
0 commit comments