@@ -25,7 +25,7 @@ use super::functions::FunctionExpr;
2525use crate :: {
2626 functions:: {
2727 binary_op_display_without_formatter, function_display_without_formatter,
28- function_semantic_id,
28+ function_semantic_id, is_in_display_without_formatter ,
2929 python:: PythonUDF ,
3030 scalar_function_semantic_id,
3131 sketch:: { HashableVecPercentiles , SketchExpr } ,
@@ -130,8 +130,8 @@ pub enum Expr {
130130 #[ display( "fill_null({_0}, {_1})" ) ]
131131 FillNull ( ExprRef , ExprRef ) ,
132132
133- #[ display( "{_0} in {_1}" ) ]
134- IsIn ( ExprRef , ExprRef ) ,
133+ #[ display( "{}" , is_in_display_without_formatter ( _0 , _1 ) ? ) ]
134+ IsIn ( ExprRef , Vec < ExprRef > ) ,
135135
136136 #[ display( "{_0} in [{_1},{_2}]" ) ]
137137 Between ( ExprRef , ExprRef , ExprRef ) ,
@@ -603,7 +603,7 @@ impl Expr {
603603 Self :: FillNull ( self , fill_value) . into ( )
604604 }
605605
606- pub fn is_in ( self : ExprRef , items : ExprRef ) -> ExprRef {
606+ pub fn is_in ( self : ExprRef , items : Vec < ExprRef > ) -> ExprRef {
607607 Self :: IsIn ( self , items) . into ( )
608608 }
609609
@@ -679,7 +679,10 @@ impl Expr {
679679 }
680680 Self :: IsIn ( expr, items) => {
681681 let child_id = expr. semantic_id ( schema) ;
682- let items_id = items. semantic_id ( schema) ;
682+ let items_id = items. iter ( ) . fold ( String :: new ( ) , |acc, item| {
683+ format ! ( "{},{}" , acc, item. semantic_id( schema) )
684+ } ) ;
685+
683686 FieldID :: new ( format ! ( "{child_id}.is_in({items_id})" ) )
684687 }
685688 Self :: Between ( expr, lower, upper) => {
@@ -741,7 +744,9 @@ impl Expr {
741744 Self :: BinaryOp { left, right, .. } => {
742745 vec ! [ left. clone( ) , right. clone( ) ]
743746 }
744- Self :: IsIn ( expr, items) => vec ! [ expr. clone( ) , items. clone( ) ] ,
747+ Self :: IsIn ( expr, items) => std:: iter:: once ( expr. clone ( ) )
748+ . chain ( items. iter ( ) . cloned ( ) )
749+ . collect :: < Vec < _ > > ( ) ,
745750 Self :: Between ( expr, lower, upper) => vec ! [ expr. clone( ) , lower. clone( ) , upper. clone( ) ] ,
746751 Self :: IfElse {
747752 if_true,
@@ -788,10 +793,18 @@ impl Expr {
788793 left : children. first ( ) . expect ( "Should have 1 child" ) . clone ( ) ,
789794 right : children. get ( 1 ) . expect ( "Should have 2 child" ) . clone ( ) ,
790795 } ,
791- Self :: IsIn ( ..) => Self :: IsIn (
792- children. first ( ) . expect ( "Should have 1 child" ) . clone ( ) ,
793- children. get ( 1 ) . expect ( "Should have 2 child" ) . clone ( ) ,
794- ) ,
796+ Self :: IsIn ( _, old_children) => {
797+ assert_eq ! (
798+ children. len( ) ,
799+ old_children. len( ) + 1 ,
800+ "Should have same number of children"
801+ ) ;
802+ let mut children_iter = children. into_iter ( ) ;
803+ let expr = children_iter. next ( ) . expect ( "Should have 1 child" ) ;
804+ let items = children_iter. collect ( ) ;
805+
806+ Self :: IsIn ( expr, items)
807+ }
795808 Self :: Between ( ..) => Self :: Between (
796809 children. first ( ) . expect ( "Should have 1 child" ) . clone ( ) ,
797810 children. get ( 1 ) . expect ( "Should have 2 child" ) . clone ( ) ,
@@ -865,10 +878,28 @@ impl Expr {
865878 }
866879 Self :: IsIn ( left, right) => {
867880 let left_field = left. to_field ( schema) ?;
868- let right_field = right. to_field ( schema) ?;
881+
882+ let first_right_field = right
883+ . first ( )
884+ . expect ( "Should have at least 1 child" )
885+ . to_field ( schema) ?;
886+ let all_same_type = right. iter ( ) . all ( |expr| {
887+ let field = expr. to_field ( schema) . unwrap ( ) ;
888+ // allow nulls to be compared with anything
889+ if field. dtype == DataType :: Null || first_right_field. dtype == DataType :: Null {
890+ return true ;
891+ }
892+ field. dtype == first_right_field. dtype
893+ } ) ;
894+ if !all_same_type {
895+ return Err ( DaftError :: TypeError ( format ! (
896+ "Expected all arguments to be of the same type, but received {first_right_field} and others" ,
897+ ) ) ) ;
898+ }
899+
869900 let ( result_type, _intermediate, _comp_type) =
870901 InferDataType :: from ( & left_field. dtype )
871- . membership_op ( & InferDataType :: from ( & right_field . dtype ) ) ?;
902+ . membership_op ( & InferDataType :: from ( & first_right_field . dtype ) ) ?;
872903 Ok ( Field :: new ( left_field. name . as_str ( ) , result_type) )
873904 }
874905 Self :: Between ( value, lower, upper) => {
0 commit comments