@@ -23,7 +23,8 @@ use std::str::FromStr;
2323use std:: sync:: Arc ;
2424
2525use arrow_arith:: boolean:: { and, is_not_null, is_null, not, or} ;
26- use arrow_array:: { Array , ArrayRef , BooleanArray , RecordBatch } ;
26+ use arrow_array:: { Array , ArrayRef , BooleanArray , Datum as ArrowDatum , RecordBatch , Scalar } ;
27+ use arrow_cast:: cast:: cast;
2728use arrow_ord:: cmp:: { eq, gt, gt_eq, lt, lt_eq, neq} ;
2829use arrow_schema:: {
2930 ArrowError , DataType , FieldRef , Schema as ArrowSchema , SchemaRef as ArrowSchemaRef ,
@@ -907,6 +908,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
907908
908909 Ok ( Box :: new ( move |batch| {
909910 let left = project_column ( & batch, idx) ?;
911+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
910912 lt ( & left, literal. as_ref ( ) )
911913 } ) )
912914 } else {
@@ -926,6 +928,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
926928
927929 Ok ( Box :: new ( move |batch| {
928930 let left = project_column ( & batch, idx) ?;
931+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
929932 lt_eq ( & left, literal. as_ref ( ) )
930933 } ) )
931934 } else {
@@ -945,6 +948,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
945948
946949 Ok ( Box :: new ( move |batch| {
947950 let left = project_column ( & batch, idx) ?;
951+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
948952 gt ( & left, literal. as_ref ( ) )
949953 } ) )
950954 } else {
@@ -964,6 +968,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
964968
965969 Ok ( Box :: new ( move |batch| {
966970 let left = project_column ( & batch, idx) ?;
971+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
967972 gt_eq ( & left, literal. as_ref ( ) )
968973 } ) )
969974 } else {
@@ -983,6 +988,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
983988
984989 Ok ( Box :: new ( move |batch| {
985990 let left = project_column ( & batch, idx) ?;
991+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
986992 eq ( & left, literal. as_ref ( ) )
987993 } ) )
988994 } else {
@@ -1002,6 +1008,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10021008
10031009 Ok ( Box :: new ( move |batch| {
10041010 let left = project_column ( & batch, idx) ?;
1011+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10051012 neq ( & left, literal. as_ref ( ) )
10061013 } ) )
10071014 } else {
@@ -1021,6 +1028,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10211028
10221029 Ok ( Box :: new ( move |batch| {
10231030 let left = project_column ( & batch, idx) ?;
1031+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10241032 starts_with ( & left, literal. as_ref ( ) )
10251033 } ) )
10261034 } else {
@@ -1040,7 +1048,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10401048
10411049 Ok ( Box :: new ( move |batch| {
10421050 let left = project_column ( & batch, idx) ?;
1043-
1051+ let literal = cast_literal_if_required ( Arc :: clone ( & literal ) , left . data_type ( ) ) ? ;
10441052 // update here if arrow ever adds a native not_starts_with
10451053 not ( & starts_with ( & left, literal. as_ref ( ) ) ?)
10461054 } ) )
@@ -1065,8 +1073,10 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10651073 Ok ( Box :: new ( move |batch| {
10661074 // update this if arrow ever adds a native is_in kernel
10671075 let left = project_column ( & batch, idx) ?;
1076+
10681077 let mut acc = BooleanArray :: from ( vec ! [ false ; batch. num_rows( ) ] ) ;
10691078 for literal in & literals {
1079+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10701080 acc = or ( & acc, & eq ( & left, literal. as_ref ( ) ) ?) ?
10711081 }
10721082
@@ -1095,6 +1105,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10951105 let left = project_column ( & batch, idx) ?;
10961106 let mut acc = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
10971107 for literal in & literals {
1108+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10981109 acc = and ( & acc, & neq ( & left, literal. as_ref ( ) ) ?) ?
10991110 }
11001111
@@ -1150,11 +1161,34 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
11501161 }
11511162}
11521163
1164+ /// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1165+ /// that Iceberg uses for literals - but they are effectively the same logical type,
1166+ /// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8.
1167+ ///
1168+ /// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1169+ /// into the type of the batch we read from Parquet before sending it to the compute kernel.
1170+ fn cast_literal_if_required (
1171+ literal : Arc < dyn ArrowDatum + Send + Sync > ,
1172+ column_type : & DataType ,
1173+ ) -> std:: result:: Result < Arc < dyn ArrowDatum + Send + Sync > , ArrowError > {
1174+ let literal_array = literal. get ( ) . 0 ;
1175+
1176+ // No cast required
1177+ if literal_array. data_type ( ) == column_type {
1178+ return Ok ( literal) ;
1179+ }
1180+
1181+ let literal_array = cast ( literal_array, column_type) ?;
1182+ Ok ( Arc :: new ( Scalar :: new ( literal_array) ) )
1183+ }
1184+
11531185#[ cfg( test) ]
11541186mod tests {
11551187 use std:: collections:: { HashMap , HashSet } ;
11561188 use std:: sync:: Arc ;
11571189
1190+ use arrow_array:: cast:: AsArray ;
1191+ use arrow_array:: { ArrayRef , LargeStringArray , RecordBatch , StringArray } ;
11581192 use arrow_schema:: { DataType , Field , Schema as ArrowSchema , TimeUnit } ;
11591193 use parquet:: arrow:: ProjectionMask ;
11601194 use parquet:: schema:: parser:: parse_message_type;
@@ -1324,4 +1358,177 @@ message schema {
13241358 . expect ( "Some ProjectionMask" ) ;
13251359 assert_eq ! ( mask, ProjectionMask :: leaves( & parquet_schema, vec![ 0 ] ) ) ;
13261360 }
1361+
1362+ #[ tokio:: test]
1363+ async fn test_predicate_cast_literal ( ) {
1364+ let predicates = vec ! [
1365+ // a == 'foo'
1366+ (
1367+ Reference :: new( "a" ) . equal_to( Datum :: string( "foo" ) ) ,
1368+ vec![ Some ( "foo" . to_string( ) ) ] ,
1369+ ) ,
1370+ // a != 'foo'
1371+ (
1372+ Reference :: new( "a" ) . not_equal_to( Datum :: string( "foo" ) ) ,
1373+ vec![ Some ( "bar" . to_string( ) ) ] ,
1374+ ) ,
1375+ // STARTS_WITH(a, 'foo')
1376+ (
1377+ Reference :: new( "a" ) . starts_with( Datum :: string( "f" ) ) ,
1378+ vec![ Some ( "foo" . to_string( ) ) ] ,
1379+ ) ,
1380+ // NOT STARTS_WITH(a, 'foo')
1381+ (
1382+ Reference :: new( "a" ) . not_starts_with( Datum :: string( "f" ) ) ,
1383+ vec![ Some ( "bar" . to_string( ) ) ] ,
1384+ ) ,
1385+ // a < 'foo'
1386+ (
1387+ Reference :: new( "a" ) . less_than( Datum :: string( "foo" ) ) ,
1388+ vec![ Some ( "bar" . to_string( ) ) ] ,
1389+ ) ,
1390+ // a <= 'foo'
1391+ (
1392+ Reference :: new( "a" ) . less_than_or_equal_to( Datum :: string( "foo" ) ) ,
1393+ vec![ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ,
1394+ ) ,
1395+ // a > 'foo'
1396+ (
1397+ Reference :: new( "a" ) . greater_than( Datum :: string( "bar" ) ) ,
1398+ vec![ Some ( "foo" . to_string( ) ) ] ,
1399+ ) ,
1400+ // a >= 'foo'
1401+ (
1402+ Reference :: new( "a" ) . greater_than_or_equal_to( Datum :: string( "foo" ) ) ,
1403+ vec![ Some ( "foo" . to_string( ) ) ] ,
1404+ ) ,
1405+ // a IN ('foo', 'bar')
1406+ (
1407+ Reference :: new( "a" ) . is_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1408+ vec![ Some ( "foo" . to_string( ) ) ] ,
1409+ ) ,
1410+ // a NOT IN ('foo', 'bar')
1411+ (
1412+ Reference :: new( "a" ) . is_not_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1413+ vec![ Some ( "bar" . to_string( ) ) ] ,
1414+ ) ,
1415+ ] ;
1416+
1417+ // Table data: ["foo", "bar"]
1418+ let data_for_col_a = vec ! [ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ;
1419+
1420+ let ( file_io, schema, table_location, _temp_dir) =
1421+ setup_kleene_logic ( data_for_col_a, DataType :: LargeUtf8 ) ;
1422+ let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
1423+
1424+ for ( predicate, expected) in predicates {
1425+ println ! ( "testing predicate {predicate}" ) ;
1426+ let result_data = test_perform_read (
1427+ predicate. clone ( ) ,
1428+ schema. clone ( ) ,
1429+ table_location. clone ( ) ,
1430+ reader. clone ( ) ,
1431+ )
1432+ . await ;
1433+
1434+ assert_eq ! ( result_data, expected, "predicate={predicate}" ) ;
1435+ }
1436+ }
1437+
1438+ async fn test_perform_read (
1439+ predicate : Predicate ,
1440+ schema : SchemaRef ,
1441+ table_location : String ,
1442+ reader : ArrowReader ,
1443+ ) -> Vec < Option < String > > {
1444+ let tasks = Box :: pin ( futures:: stream:: iter (
1445+ vec ! [ Ok ( FileScanTask {
1446+ start: 0 ,
1447+ length: 0 ,
1448+ record_count: None ,
1449+ data_file_path: format!( "{}/1.parquet" , table_location) ,
1450+ data_file_content: DataContentType :: Data ,
1451+ data_file_format: DataFileFormat :: Parquet ,
1452+ schema: schema. clone( ) ,
1453+ project_field_ids: vec![ 1 ] ,
1454+ predicate: Some ( predicate. bind( schema, true ) . unwrap( ) ) ,
1455+ deletes: vec![ ] ,
1456+ } ) ]
1457+ . into_iter ( ) ,
1458+ ) ) as FileScanTaskStream ;
1459+
1460+ let result = reader
1461+ . read ( tasks)
1462+ . await
1463+ . unwrap ( )
1464+ . try_collect :: < Vec < RecordBatch > > ( )
1465+ . await
1466+ . unwrap ( ) ;
1467+
1468+ let result_data = result[ 0 ] . columns ( ) [ 0 ]
1469+ . as_string_opt :: < i32 > ( )
1470+ . unwrap ( )
1471+ . iter ( )
1472+ . map ( |v| v. map ( ToOwned :: to_owned) )
1473+ . collect :: < Vec < _ > > ( ) ;
1474+
1475+ result_data
1476+ }
1477+
1478+ fn setup_kleene_logic (
1479+ data_for_col_a : Vec < Option < String > > ,
1480+ col_a_type : DataType ,
1481+ ) -> ( FileIO , SchemaRef , String , TempDir ) {
1482+ let schema = Arc :: new (
1483+ Schema :: builder ( )
1484+ . with_schema_id ( 1 )
1485+ . with_fields ( vec ! [ NestedField :: optional(
1486+ 1 ,
1487+ "a" ,
1488+ Type :: Primitive ( PrimitiveType :: String ) ,
1489+ )
1490+ . into( ) ] )
1491+ . build ( )
1492+ . unwrap ( ) ,
1493+ ) ;
1494+
1495+ let arrow_schema = Arc :: new ( ArrowSchema :: new ( vec ! [ Field :: new(
1496+ "a" ,
1497+ col_a_type. clone( ) ,
1498+ true ,
1499+ )
1500+ . with_metadata( HashMap :: from( [ (
1501+ PARQUET_FIELD_ID_META_KEY . to_string( ) ,
1502+ "1" . to_string( ) ,
1503+ ) ] ) ) ] ) ) ;
1504+
1505+ let tmp_dir = TempDir :: new ( ) . unwrap ( ) ;
1506+ let table_location = tmp_dir. path ( ) . to_str ( ) . unwrap ( ) . to_string ( ) ;
1507+
1508+ let file_io = FileIO :: from_path ( & table_location) . unwrap ( ) . build ( ) . unwrap ( ) ;
1509+
1510+ let col = match col_a_type {
1511+ DataType :: Utf8 => Arc :: new ( StringArray :: from ( data_for_col_a) ) as ArrayRef ,
1512+ DataType :: LargeUtf8 => Arc :: new ( LargeStringArray :: from ( data_for_col_a) ) as ArrayRef ,
1513+ _ => panic ! ( "unexpected col_a_type" ) ,
1514+ } ;
1515+
1516+ let to_write = RecordBatch :: try_new ( arrow_schema. clone ( ) , vec ! [ col] ) . unwrap ( ) ;
1517+
1518+ // Write the Parquet files
1519+ let props = WriterProperties :: builder ( )
1520+ . set_compression ( Compression :: SNAPPY )
1521+ . build ( ) ;
1522+
1523+ let file = File :: create ( format ! ( "{}/1.parquet" , & table_location) ) . unwrap ( ) ;
1524+ let mut writer =
1525+ ArrowWriter :: try_new ( file, to_write. schema ( ) , Some ( props. clone ( ) ) ) . unwrap ( ) ;
1526+
1527+ writer. write ( & to_write) . expect ( "Writing batch" ) ;
1528+
1529+ // writer must be closed to write footer
1530+ writer. close ( ) . unwrap ( ) ;
1531+
1532+ ( file_io, schema, table_location, tmp_dir)
1533+ }
13271534}
0 commit comments