Skip to content

Commit 8efd26e

Browse files
blagininclaude
andauthored
chore: list contains over scalars (vortex-data#6242)
Signed-off-by: blaginin <github@blaginin.me> Co-authored-by: Claude <claude@anthropic.com>
1 parent d8ee003 commit 8efd26e

1 file changed

Lines changed: 53 additions & 0 deletions

File tree

vortex-array/src/expr/exprs/list_contains.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use vortex_dtype::DType;
88
use vortex_error::VortexResult;
99
use vortex_error::vortex_bail;
1010
use vortex_error::vortex_err;
11+
use vortex_scalar::Scalar;
1112
use vortex_session::VortexSession;
1213

1314
use crate::ArrayRef;
@@ -105,6 +106,13 @@ impl VTable for ListContains {
105106
.try_into()
106107
.map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
107108

109+
if let Some(list_scalar) = list_array.as_constant()
110+
&& let Some(value_scalar) = value_array.as_constant()
111+
{
112+
let result = compute_contains_scalar(&list_scalar, &value_scalar)?;
113+
return Ok(ExecutionResult::constant(result, args.row_count));
114+
}
115+
108116
compute_list_contains(list_array.as_ref(), value_array.as_ref())?.execute(args.ctx)
109117
}
110118

@@ -158,6 +166,23 @@ impl VTable for ListContains {
158166
}
159167
}
160168

169+
fn compute_contains_scalar(list: &Scalar, needle: &Scalar) -> VortexResult<Scalar> {
170+
let nullability = list.dtype().nullability() | needle.dtype().nullability();
171+
172+
// Handle null list or null needle
173+
if list.is_null() || needle.is_null() {
174+
return Ok(Scalar::null(DType::Bool(nullability)));
175+
}
176+
177+
let list_scalar = list.as_list();
178+
let elements = list_scalar
179+
.elements()
180+
.ok_or_else(|| vortex_err!("Expected non-null list"))?;
181+
182+
let contains = elements.iter().any(|elem| elem == needle);
183+
Ok(Scalar::bool(contains, nullability))
184+
}
185+
161186
/// Creates an expression that checks if a value is contained in a list.
162187
///
163188
/// Returns a boolean array indicating whether the value appears in each list.
@@ -379,4 +404,32 @@ mod tests {
379404
let expr2 = list_contains(root(), lit(42));
380405
assert_eq!(expr2.to_string(), "contains($, 42i32)");
381406
}
407+
408+
#[test]
409+
pub fn test_constant_scalars() {
410+
let arr = test_array();
411+
412+
// Both list and needle are constants - should use scalar optimization
413+
let list_scalar = Scalar::list(
414+
Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
415+
vec![1.into(), 2.into(), 3.into()],
416+
Nullability::NonNullable,
417+
);
418+
419+
// Test contains true
420+
let expr = list_contains(lit(list_scalar.clone()), lit(2i32));
421+
let result = arr.apply(&expr).unwrap();
422+
assert_eq!(
423+
result.scalar_at(0).unwrap(),
424+
Scalar::bool(true, Nullability::NonNullable)
425+
);
426+
427+
// Test contains false
428+
let expr = list_contains(lit(list_scalar), lit(42i32));
429+
let result = arr.apply(&expr).unwrap();
430+
assert_eq!(
431+
result.scalar_at(0).unwrap(),
432+
Scalar::bool(false, Nullability::NonNullable)
433+
);
434+
}
382435
}

0 commit comments

Comments
 (0)