Skip to content

Commit cf25ad4

Browse files
[FEAT]: allow is_in to take in Vec<Expr> instead of Expr (Eventual-Inc#3294)
closes Eventual-Inc#3140
1 parent 0709691 commit cf25ad4

File tree

11 files changed

+116
-40
lines changed

11 files changed

+116
-40
lines changed

daft/daft/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ class PyExpr:
10871087
def is_null(self) -> PyExpr: ...
10881088
def not_null(self) -> PyExpr: ...
10891089
def fill_null(self, fill_value: PyExpr) -> PyExpr: ...
1090-
def is_in(self, other: PyExpr) -> PyExpr: ...
1090+
def is_in(self, other: list[PyExpr]) -> PyExpr: ...
10911091
def between(self, lower: PyExpr, upper: PyExpr) -> PyExpr: ...
10921092
def name(self) -> str: ...
10931093
def to_field(self, schema: PySchema) -> PyField: ...

daft/expressions/expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1137,7 +1137,7 @@ def is_in(self, other: Any) -> Expression:
11371137
series = item_to_series("items", other)
11381138
other = Expression._to_expression(series)
11391139

1140-
expr = self._expr.is_in(other._expr)
1140+
expr = self._expr.is_in([other._expr])
11411141
return Expression._from_pyexpr(expr)
11421142

11431143
def between(self, lower: Any, upper: Any) -> Expression:

src/daft-core/src/series/ops/concat.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use common_error::{DaftError, DaftResult};
2+
use daft_schema::dtype::DataType;
23

34
use crate::{
45
series::{IntoSeries, Series},
@@ -15,18 +16,25 @@ impl Series {
1516
)),
1617
[single_series] => Ok((*single_series).clone()),
1718
[first, rest @ ..] => {
19+
let mut series = vec![(*first).clone()];
20+
1821
let first_dtype = first.data_type();
1922
for s in rest {
20-
if first_dtype != s.data_type() {
23+
if s.data_type() == &DataType::Null {
24+
let s = Self::full_null("name", first_dtype, s.len());
25+
series.push(s);
26+
} else if first_dtype != s.data_type() {
2127
return Err(DaftError::TypeError(format!(
2228
"Series concat requires all data types to match. Found mismatched types. All types: {:?}",
2329
all_types
2430
)));
31+
} else {
32+
series.push((*s).clone());
2533
}
2634
}
2735

2836
with_match_daft_types!(first_dtype, |$T| {
29-
let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::<DaftResult<Vec<_>>>()?;
37+
let downcasted = series.iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::<DaftResult<Vec<_>>>()?;
3038
Ok(<$T as DaftDataType>::ArrayType::concat(downcasted.as_slice())?.into_series())
3139
})
3240
}

src/daft-dsl/src/expr/mod.rs

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use super::functions::FunctionExpr;
2525
use crate::{
2626
functions::{
2727
binary_op_display_without_formatter, function_display_without_formatter,
28-
function_semantic_id,
28+
function_semantic_id, is_in_display_without_formatter,
2929
python::PythonUDF,
3030
scalar_function_semantic_id,
3131
sketch::{HashableVecPercentiles, SketchExpr},
@@ -130,8 +130,8 @@ pub enum Expr {
130130
#[display("fill_null({_0}, {_1})")]
131131
FillNull(ExprRef, ExprRef),
132132

133-
#[display("{_0} in {_1}")]
134-
IsIn(ExprRef, ExprRef),
133+
#[display("{}", is_in_display_without_formatter(_0, _1)?)]
134+
IsIn(ExprRef, Vec<ExprRef>),
135135

136136
#[display("{_0} in [{_1},{_2}]")]
137137
Between(ExprRef, ExprRef, ExprRef),
@@ -603,7 +603,7 @@ impl Expr {
603603
Self::FillNull(self, fill_value).into()
604604
}
605605

606-
pub fn is_in(self: ExprRef, items: ExprRef) -> ExprRef {
606+
pub fn is_in(self: ExprRef, items: Vec<ExprRef>) -> ExprRef {
607607
Self::IsIn(self, items).into()
608608
}
609609

@@ -679,7 +679,10 @@ impl Expr {
679679
}
680680
Self::IsIn(expr, items) => {
681681
let child_id = expr.semantic_id(schema);
682-
let items_id = items.semantic_id(schema);
682+
let items_id = items.iter().fold(String::new(), |acc, item| {
683+
format!("{},{}", acc, item.semantic_id(schema))
684+
});
685+
683686
FieldID::new(format!("{child_id}.is_in({items_id})"))
684687
}
685688
Self::Between(expr, lower, upper) => {
@@ -741,7 +744,9 @@ impl Expr {
741744
Self::BinaryOp { left, right, .. } => {
742745
vec![left.clone(), right.clone()]
743746
}
744-
Self::IsIn(expr, items) => vec![expr.clone(), items.clone()],
747+
Self::IsIn(expr, items) => std::iter::once(expr.clone())
748+
.chain(items.iter().cloned())
749+
.collect::<Vec<_>>(),
745750
Self::Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()],
746751
Self::IfElse {
747752
if_true,
@@ -788,10 +793,18 @@ impl Expr {
788793
left: children.first().expect("Should have 1 child").clone(),
789794
right: children.get(1).expect("Should have 2 child").clone(),
790795
},
791-
Self::IsIn(..) => Self::IsIn(
792-
children.first().expect("Should have 1 child").clone(),
793-
children.get(1).expect("Should have 2 child").clone(),
794-
),
796+
Self::IsIn(_, old_children) => {
797+
assert_eq!(
798+
children.len(),
799+
old_children.len() + 1,
800+
"Should have same number of children"
801+
);
802+
let mut children_iter = children.into_iter();
803+
let expr = children_iter.next().expect("Should have 1 child");
804+
let items = children_iter.collect();
805+
806+
Self::IsIn(expr, items)
807+
}
795808
Self::Between(..) => Self::Between(
796809
children.first().expect("Should have 1 child").clone(),
797810
children.get(1).expect("Should have 2 child").clone(),
@@ -865,10 +878,28 @@ impl Expr {
865878
}
866879
Self::IsIn(left, right) => {
867880
let left_field = left.to_field(schema)?;
868-
let right_field = right.to_field(schema)?;
881+
882+
let first_right_field = right
883+
.first()
884+
.expect("Should have at least 1 child")
885+
.to_field(schema)?;
886+
let all_same_type = right.iter().all(|expr| {
887+
let field = expr.to_field(schema).unwrap();
888+
// allow nulls to be compared with anything
889+
if field.dtype == DataType::Null || first_right_field.dtype == DataType::Null {
890+
return true;
891+
}
892+
field.dtype == first_right_field.dtype
893+
});
894+
if !all_same_type {
895+
return Err(DaftError::TypeError(format!(
896+
"Expected all arguments to be of the same type, but received {first_right_field} and others",
897+
)));
898+
}
899+
869900
let (result_type, _intermediate, _comp_type) =
870901
InferDataType::from(&left_field.dtype)
871-
.membership_op(&InferDataType::from(&right_field.dtype))?;
902+
.membership_op(&InferDataType::from(&first_right_field.dtype))?;
872903
Ok(Field::new(left_field.name.as_str(), result_type))
873904
}
874905
Self::Between(value, lower, upper) => {

src/daft-dsl/src/functions/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,22 @@ pub fn function_display_without_formatter(
106106
Ok(f)
107107
}
108108

109+
pub fn is_in_display_without_formatter(
110+
expr: &ExprRef,
111+
inputs: &[ExprRef],
112+
) -> std::result::Result<String, std::fmt::Error> {
113+
let mut f = String::default();
114+
write!(&mut f, "{expr} IN (")?;
115+
for (i, input) in inputs.iter().enumerate() {
116+
if i != 0 {
117+
write!(&mut f, ", ")?;
118+
}
119+
write!(&mut f, "{input}")?;
120+
}
121+
write!(&mut f, ")")?;
122+
Ok(f)
123+
}
124+
109125
pub fn binary_op_display_without_formatter(
110126
op: &Operator,
111127
left: &ExprRef,

src/daft-dsl/src/python.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,10 @@ impl PyExpr {
452452
Ok(self.expr.clone().fill_null(fill_value.expr.clone()).into())
453453
}
454454

455-
pub fn is_in(&self, other: &Self) -> PyResult<Self> {
456-
Ok(self.expr.clone().is_in(other.expr.clone()).into())
455+
pub fn is_in(&self, other: Vec<Self>) -> PyResult<Self> {
456+
let other = other.into_iter().map(|e| e.into()).collect();
457+
458+
Ok(self.expr.clone().is_in(other).into())
457459
}
458460

459461
pub fn between(&self, lower: &Self, upper: &Self) -> PyResult<Self> {

src/daft-logical-plan/src/ops/project.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,23 @@ fn replace_column_with_semantic_id(
266266
Expr::IsIn(child, items) => {
267267
let child =
268268
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
269-
let items =
270-
replace_column_with_semantic_id(items.clone(), subexprs_to_replace, schema);
271-
if !child.transformed && !items.transformed {
269+
270+
let transforms = items
271+
.iter()
272+
.map(|e| {
273+
replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
274+
})
275+
.collect::<Vec<_>>();
276+
if !child.transformed && transforms.iter().all(|e| !e.transformed) {
272277
Transformed::no(e)
273278
} else {
274-
Transformed::yes(Expr::IsIn(child.data, items.data).into())
279+
Transformed::yes(
280+
Expr::IsIn(
281+
child.data,
282+
transforms.iter().map(|t| t.data.clone()).collect(),
283+
)
284+
.into(),
285+
)
275286
}
276287
}
277288
Expr::Between(child, lower, upper) => {

src/daft-logical-plan/src/partitioning.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,10 @@ fn translate_clustering_spec_expr(
291291
}
292292
Expr::IsIn(child, items) => {
293293
let newchild = translate_clustering_spec_expr(child, old_colname_to_new_colname)?;
294-
let newitems = translate_clustering_spec_expr(items, old_colname_to_new_colname)?;
294+
let newitems = items
295+
.iter()
296+
.map(|e| translate_clustering_spec_expr(e, old_colname_to_new_colname))
297+
.collect::<Result<Vec<_>, _>>()?;
295298
Ok(newchild.is_in(newitems))
296299
}
297300
Expr::Between(child, lower, upper) => {

src/daft-sql/src/planner.rs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,13 +1074,6 @@ impl SQLPlanner {
10741074
})
10751075
}
10761076

1077-
fn plan_lit(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<LiteralValue> {
1078-
if let sqlparser::ast::Expr::Value(v) = expr {
1079-
self.value_to_lit(v)
1080-
} else {
1081-
invalid_operation_err!("Only string, number, boolean and null literals are supported. Instead found: `{expr}`");
1082-
}
1083-
}
10841077
pub(crate) fn plan_expr(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<ExprRef> {
10851078
use sqlparser::ast::Expr as SQLExpr;
10861079
match expr {
@@ -1134,14 +1127,10 @@ impl SQLPlanner {
11341127
let expr = self.plan_expr(expr)?;
11351128
let list = list
11361129
.iter()
1137-
.map(|e| self.plan_lit(e))
1130+
.map(|e| self.plan_expr(e))
11381131
.collect::<SQLPlannerResult<Vec<_>>>()?;
1139-
// We should really have a better way to use `is_in` instead of all of this extra wrapping of the values
1140-
let series = literals_to_series(&list)?;
1141-
let series_lit = LiteralValue::Series(series);
1142-
let series_expr = Expr::Literal(series_lit);
1143-
let series_expr_arc = Arc::new(series_expr);
1144-
let expr = expr.is_in(series_expr_arc);
1132+
1133+
let expr = expr.is_in(list);
11451134
if *negated {
11461135
Ok(expr.not())
11471136
} else {

src/daft-table/src/lib.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,16 @@ impl Table {
527527
let fill_value = self.eval_expression(fill_value)?;
528528
self.eval_expression(child)?.fill_null(&fill_value)
529529
}
530-
IsIn(child, items) => self
530+
IsIn(child, items) => {
531+
let items = items.iter().map(|i| self.eval_expression(i)).collect::<DaftResult<Vec<_>>>()?;
532+
533+
let items = items.iter().collect::<Vec<&Series>>();
534+
let s = Series::concat(items.as_slice())?;
535+
self
531536
.eval_expression(child)?
532-
.is_in(&self.eval_expression(items)?),
537+
.is_in(&s)
538+
}
539+
533540
Between(child, lower, upper) => self
534541
.eval_expression(child)?
535542
.between(&self.eval_expression(lower)?, &self.eval_expression(upper)?),

0 commit comments

Comments
 (0)