@@ -1590,6 +1590,24 @@ pub fn from_cast(
15901590 schema : & DFSchemaRef ,
15911591) -> Result < Expression > {
15921592 let Cast { expr, data_type } = cast;
1593+ // Substrait nulls must always be typed, so a cast(null, dt) is emitted directly as a typed null of type dt
1594+ if let Expr :: Literal ( lit) = expr. as_ref ( ) {
1595+ // Only the untyped null literal (a `ScalarValue::Null`) needs this special handling;
1596+ // every other kind of null is already typed and can be handled by Substrait,
1597+ // e.g. null::<Int32Type> or null::<Utf8Type>
1598+ if matches ! ( lit, ScalarValue :: Null ) {
1599+ let lit = Literal {
1600+ nullable : true ,
1601+ type_variation_reference : DEFAULT_TYPE_VARIATION_REF ,
1602+ literal_type : Some ( LiteralType :: Null ( to_substrait_type (
1603+ data_type, true ,
1604+ ) ?) ) ,
1605+ } ;
1606+ return Ok ( Expression {
1607+ rex_type : Some ( RexType :: Literal ( lit) ) ,
1608+ } ) ;
1609+ }
1610+ }
15931611 Ok ( Expression {
15941612 rex_type : Some ( RexType :: Cast ( Box :: new (
15951613 substrait:: proto:: expression:: Cast {
@@ -2575,6 +2593,7 @@ mod test {
25752593 use datafusion:: common:: scalar:: ScalarStructBuilder ;
25762594 use datafusion:: common:: DFSchema ;
25772595 use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
2596+ use datafusion:: logical_expr:: ExprSchemable ;
25782597 use datafusion:: prelude:: SessionContext ;
25792598 use std:: sync:: LazyLock ;
25802599
@@ -2912,4 +2931,70 @@ mod test {
29122931
29132932 assert ! ( matches!( err, Err ( DataFusionError :: SchemaError ( _, _) ) ) ) ;
29142933 }
2934+
2935+ #[ tokio:: test]
2936+ async fn fold_cast_null ( ) {
2937+ let state = SessionStateBuilder :: default ( ) . build ( ) ;
2938+ let empty_schema = DFSchemaRef :: new ( DFSchema :: empty ( ) ) ;
2939+ let field = Field :: new ( "out" , DataType :: Int32 , false ) ;
2940+
2941+ let expr = Expr :: Literal ( ScalarValue :: Null )
2942+ . cast_to ( & DataType :: Int32 , & empty_schema)
2943+ . unwrap ( ) ;
2944+
2945+ let typed_null =
2946+ to_substrait_extended_expr ( & [ ( & expr, & field) ] , & empty_schema, & state)
2947+ . unwrap ( ) ;
2948+
2949+ if let ExprType :: Expression ( expr) =
2950+ typed_null. referred_expr [ 0 ] . expr_type . as_ref ( ) . unwrap ( )
2951+ {
2952+ let lit = Literal {
2953+ nullable : true ,
2954+ type_variation_reference : DEFAULT_TYPE_VARIATION_REF ,
2955+ literal_type : Some ( LiteralType :: Null (
2956+ to_substrait_type ( & DataType :: Int32 , true ) . unwrap ( ) ,
2957+ ) ) ,
2958+ } ;
2959+ let expected = Expression {
2960+ rex_type : Some ( RexType :: Literal ( lit) ) ,
2961+ } ;
2962+ assert_eq ! ( * expr, expected) ;
2963+ } else {
2964+ panic ! ( "Expected expression type" ) ;
2965+ }
2966+
2967+ // a typed null should not be folded
2968+ let expr = Expr :: Literal ( ScalarValue :: Int64 ( None ) )
2969+ . cast_to ( & DataType :: Int32 , & empty_schema)
2970+ . unwrap ( ) ;
2971+
2972+ let typed_null =
2973+ to_substrait_extended_expr ( & [ ( & expr, & field) ] , & empty_schema, & state)
2974+ . unwrap ( ) ;
2975+
2976+ if let ExprType :: Expression ( expr) =
2977+ typed_null. referred_expr [ 0 ] . expr_type . as_ref ( ) . unwrap ( )
2978+ {
2979+ let cast_expr = substrait:: proto:: expression:: Cast {
2980+ r#type : Some ( to_substrait_type ( & DataType :: Int32 , true ) . unwrap ( ) ) ,
2981+ input : Some ( Box :: new ( Expression {
2982+ rex_type : Some ( RexType :: Literal ( Literal {
2983+ nullable : true ,
2984+ type_variation_reference : DEFAULT_TYPE_VARIATION_REF ,
2985+ literal_type : Some ( LiteralType :: Null (
2986+ to_substrait_type ( & DataType :: Int64 , true ) . unwrap ( ) ,
2987+ ) ) ,
2988+ } ) ) ,
2989+ } ) ) ,
2990+ failure_behavior : FailureBehavior :: ThrowException as i32 ,
2991+ } ;
2992+ let expected = Expression {
2993+ rex_type : Some ( RexType :: Cast ( Box :: new ( cast_expr) ) ) ,
2994+ } ;
2995+ assert_eq ! ( * expr, expected) ;
2996+ } else {
2997+ panic ! ( "Expected expression type" ) ;
2998+ }
2999+ }
29153000}
0 commit comments