@@ -23,7 +23,8 @@ use std::str::FromStr;
2323use std:: sync:: Arc ;
2424
2525use arrow_arith:: boolean:: { and, is_not_null, is_null, not, or} ;
26- use arrow_array:: { Array , ArrayRef , BooleanArray , RecordBatch } ;
26+ use arrow_array:: { Array , ArrayRef , BooleanArray , Datum as ArrowDatum , RecordBatch , Scalar } ;
27+ use arrow_cast:: cast:: cast;
2728use arrow_ord:: cmp:: { eq, gt, gt_eq, lt, lt_eq, neq} ;
2829use arrow_schema:: {
2930 ArrowError , DataType , FieldRef , Schema as ArrowSchema , SchemaRef as ArrowSchemaRef ,
@@ -907,6 +908,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
907908
908909 Ok ( Box :: new ( move |batch| {
909910 let left = project_column ( & batch, idx) ?;
911+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
910912 lt ( & left, literal. as_ref ( ) )
911913 } ) )
912914 } else {
@@ -926,6 +928,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
926928
927929 Ok ( Box :: new ( move |batch| {
928930 let left = project_column ( & batch, idx) ?;
931+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
929932 lt_eq ( & left, literal. as_ref ( ) )
930933 } ) )
931934 } else {
@@ -945,6 +948,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
945948
946949 Ok ( Box :: new ( move |batch| {
947950 let left = project_column ( & batch, idx) ?;
951+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
948952 gt ( & left, literal. as_ref ( ) )
949953 } ) )
950954 } else {
@@ -964,6 +968,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
964968
965969 Ok ( Box :: new ( move |batch| {
966970 let left = project_column ( & batch, idx) ?;
971+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
967972 gt_eq ( & left, literal. as_ref ( ) )
968973 } ) )
969974 } else {
@@ -983,6 +988,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
983988
984989 Ok ( Box :: new ( move |batch| {
985990 let left = project_column ( & batch, idx) ?;
991+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
986992 eq ( & left, literal. as_ref ( ) )
987993 } ) )
988994 } else {
@@ -1002,6 +1008,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10021008
10031009 Ok ( Box :: new ( move |batch| {
10041010 let left = project_column ( & batch, idx) ?;
1011+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10051012 neq ( & left, literal. as_ref ( ) )
10061013 } ) )
10071014 } else {
@@ -1021,6 +1028,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10211028
10221029 Ok ( Box :: new ( move |batch| {
10231030 let left = project_column ( & batch, idx) ?;
1031+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10241032 starts_with ( & left, literal. as_ref ( ) )
10251033 } ) )
10261034 } else {
@@ -1040,7 +1048,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10401048
10411049 Ok ( Box :: new ( move |batch| {
10421050 let left = project_column ( & batch, idx) ?;
1043-
1051+ let literal = cast_literal_if_required ( Arc :: clone ( & literal ) , left . data_type ( ) ) ? ;
10441052 // update here if arrow ever adds a native not_starts_with
10451053 not ( & starts_with ( & left, literal. as_ref ( ) ) ?)
10461054 } ) )
@@ -1065,8 +1073,10 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10651073 Ok ( Box :: new ( move |batch| {
10661074 // update this if arrow ever adds a native is_in kernel
10671075 let left = project_column ( & batch, idx) ?;
1076+
10681077 let mut acc = BooleanArray :: from ( vec ! [ false ; batch. num_rows( ) ] ) ;
10691078 for literal in & literals {
1079+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10701080 acc = or ( & acc, & eq ( & left, literal. as_ref ( ) ) ?) ?
10711081 }
10721082
@@ -1095,6 +1105,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10951105 let left = project_column ( & batch, idx) ?;
10961106 let mut acc = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
10971107 for literal in & literals {
1108+ let literal = cast_literal_if_required ( Arc :: clone ( & literal) , left. data_type ( ) ) ?;
10981109 acc = and ( & acc, & neq ( & left, literal. as_ref ( ) ) ?) ?
10991110 }
11001111
@@ -1150,21 +1161,53 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
11501161 }
11511162}
11521163
1164+ /// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1165+ /// that Iceberg uses for literals - but they are effectively the same logical type,
1166+ /// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8.
1167+ ///
1168+ /// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1169+ /// into the type of the batch we read from Parquet before sending it to the compute kernel.
1170+ fn cast_literal_if_required (
1171+ literal : Arc < dyn ArrowDatum + Send + Sync > ,
1172+ column_type : & DataType ,
1173+ ) -> std:: result:: Result < Arc < dyn ArrowDatum + Send + Sync > , ArrowError > {
1174+ let literal_array = literal. get ( ) . 0 ;
1175+
1176+ // No cast required
1177+ if literal_array. data_type ( ) == column_type {
1178+ return Ok ( literal) ;
1179+ }
1180+
1181+ let literal_array = cast ( literal_array, column_type) ?;
1182+ Ok ( Arc :: new ( Scalar :: new ( literal_array) ) )
1183+ }
1184+
11531185#[ cfg( test) ]
11541186mod tests {
11551187 use std:: collections:: { HashMap , HashSet } ;
1188+ use std:: fs:: File ;
11561189 use std:: sync:: Arc ;
11571190
1191+ use arrow_array:: cast:: AsArray ;
1192+ use arrow_array:: { ArrayRef , LargeStringArray , RecordBatch , StringArray } ;
11581193 use arrow_schema:: { DataType , Field , Schema as ArrowSchema , TimeUnit } ;
1159- use parquet:: arrow:: ProjectionMask ;
1194+ use futures:: TryStreamExt ;
1195+ use parquet:: arrow:: { ArrowWriter , ProjectionMask } ;
1196+ use parquet:: basic:: Compression ;
1197+ use parquet:: file:: properties:: WriterProperties ;
11601198 use parquet:: schema:: parser:: parse_message_type;
11611199 use parquet:: schema:: types:: SchemaDescriptor ;
1200+ use tempfile:: TempDir ;
11621201
11631202 use crate :: arrow:: reader:: { CollectFieldIdVisitor , PARQUET_FIELD_ID_META_KEY } ;
1164- use crate :: arrow:: ArrowReader ;
1203+ use crate :: arrow:: { ArrowReader , ArrowReaderBuilder } ;
11651204 use crate :: expr:: visitors:: bound_predicate_visitor:: visit;
1166- use crate :: expr:: { Bind , Reference } ;
1167- use crate :: spec:: { NestedField , PrimitiveType , Schema , SchemaRef , Type } ;
1205+ use crate :: expr:: { Bind , Predicate , Reference } ;
1206+ use crate :: io:: FileIO ;
1207+ use crate :: scan:: { FileScanTask , FileScanTaskStream } ;
1208+ use crate :: spec:: {
1209+ DataContentType , DataFileFormat , Datum , NestedField , PrimitiveType , Schema , SchemaRef , Type ,
1210+ } ;
11681211 use crate :: ErrorKind ;
11691212
11701213 fn table_schema_simple ( ) -> SchemaRef {
@@ -1324,4 +1367,176 @@ message schema {
13241367 . expect ( "Some ProjectionMask" ) ;
13251368 assert_eq ! ( mask, ProjectionMask :: leaves( & parquet_schema, vec![ 0 ] ) ) ;
13261369 }
1370+
1371+ #[ tokio:: test]
1372+ async fn test_predicate_cast_literal ( ) {
1373+ let predicates = vec ! [
1374+ // a == 'foo'
1375+ (
1376+ Reference :: new( "a" ) . equal_to( Datum :: string( "foo" ) ) ,
1377+ vec![ Some ( "foo" . to_string( ) ) ] ,
1378+ ) ,
1379+ // a != 'foo'
1380+ (
1381+ Reference :: new( "a" ) . not_equal_to( Datum :: string( "foo" ) ) ,
1382+ vec![ Some ( "bar" . to_string( ) ) ] ,
1383+ ) ,
1384+ // STARTS_WITH(a, 'foo')
1385+ (
1386+ Reference :: new( "a" ) . starts_with( Datum :: string( "f" ) ) ,
1387+ vec![ Some ( "foo" . to_string( ) ) ] ,
1388+ ) ,
1389+ // NOT STARTS_WITH(a, 'foo')
1390+ (
1391+ Reference :: new( "a" ) . not_starts_with( Datum :: string( "f" ) ) ,
1392+ vec![ Some ( "bar" . to_string( ) ) ] ,
1393+ ) ,
1394+ // a < 'foo'
1395+ (
1396+ Reference :: new( "a" ) . less_than( Datum :: string( "foo" ) ) ,
1397+ vec![ Some ( "bar" . to_string( ) ) ] ,
1398+ ) ,
1399+ // a <= 'foo'
1400+ (
1401+ Reference :: new( "a" ) . less_than_or_equal_to( Datum :: string( "foo" ) ) ,
1402+ vec![ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ,
1403+ ) ,
1404+ // a > 'foo'
1405+ (
1406+ Reference :: new( "a" ) . greater_than( Datum :: string( "bar" ) ) ,
1407+ vec![ Some ( "foo" . to_string( ) ) ] ,
1408+ ) ,
1409+ // a >= 'foo'
1410+ (
1411+ Reference :: new( "a" ) . greater_than_or_equal_to( Datum :: string( "foo" ) ) ,
1412+ vec![ Some ( "foo" . to_string( ) ) ] ,
1413+ ) ,
1414+ // a IN ('foo', 'bar')
1415+ (
1416+ Reference :: new( "a" ) . is_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1417+ vec![ Some ( "foo" . to_string( ) ) ] ,
1418+ ) ,
1419+ // a NOT IN ('foo', 'bar')
1420+ (
1421+ Reference :: new( "a" ) . is_not_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1422+ vec![ Some ( "bar" . to_string( ) ) ] ,
1423+ ) ,
1424+ ] ;
1425+
1426+ // Table data: ["foo", "bar"]
1427+ let data_for_col_a = vec ! [ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ;
1428+
1429+ let ( file_io, schema, table_location, _temp_dir) =
1430+ setup_kleene_logic ( data_for_col_a, DataType :: LargeUtf8 ) ;
1431+ let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
1432+
1433+ for ( predicate, expected) in predicates {
1434+ println ! ( "testing predicate {predicate}" ) ;
1435+ let result_data = test_perform_read (
1436+ predicate. clone ( ) ,
1437+ schema. clone ( ) ,
1438+ table_location. clone ( ) ,
1439+ reader. clone ( ) ,
1440+ )
1441+ . await ;
1442+
1443+ assert_eq ! ( result_data, expected, "predicate={predicate}" ) ;
1444+ }
1445+ }
1446+
1447+ async fn test_perform_read (
1448+ predicate : Predicate ,
1449+ schema : SchemaRef ,
1450+ table_location : String ,
1451+ reader : ArrowReader ,
1452+ ) -> Vec < Option < String > > {
1453+ let tasks = Box :: pin ( futures:: stream:: iter (
1454+ vec ! [ Ok ( FileScanTask {
1455+ start: 0 ,
1456+ length: 0 ,
1457+ record_count: None ,
1458+ data_file_path: format!( "{}/1.parquet" , table_location) ,
1459+ data_file_content: DataContentType :: Data ,
1460+ data_file_format: DataFileFormat :: Parquet ,
1461+ schema: schema. clone( ) ,
1462+ project_field_ids: vec![ 1 ] ,
1463+ predicate: Some ( predicate. bind( schema, true ) . unwrap( ) ) ,
1464+ } ) ]
1465+ . into_iter ( ) ,
1466+ ) ) as FileScanTaskStream ;
1467+
1468+ let result = reader
1469+ . read ( tasks)
1470+ . await
1471+ . unwrap ( )
1472+ . try_collect :: < Vec < RecordBatch > > ( )
1473+ . await
1474+ . unwrap ( ) ;
1475+
1476+ let result_data = result[ 0 ] . columns ( ) [ 0 ]
1477+ . as_string_opt :: < i32 > ( )
1478+ . unwrap ( )
1479+ . iter ( )
1480+ . map ( |v| v. map ( ToOwned :: to_owned) )
1481+ . collect :: < Vec < _ > > ( ) ;
1482+
1483+ result_data
1484+ }
1485+
1486+ fn setup_kleene_logic (
1487+ data_for_col_a : Vec < Option < String > > ,
1488+ col_a_type : DataType ,
1489+ ) -> ( FileIO , SchemaRef , String , TempDir ) {
1490+ let schema = Arc :: new (
1491+ Schema :: builder ( )
1492+ . with_schema_id ( 1 )
1493+ . with_fields ( vec ! [ NestedField :: optional(
1494+ 1 ,
1495+ "a" ,
1496+ Type :: Primitive ( PrimitiveType :: String ) ,
1497+ )
1498+ . into( ) ] )
1499+ . build ( )
1500+ . unwrap ( ) ,
1501+ ) ;
1502+
1503+ let arrow_schema = Arc :: new ( ArrowSchema :: new ( vec ! [ Field :: new(
1504+ "a" ,
1505+ col_a_type. clone( ) ,
1506+ true ,
1507+ )
1508+ . with_metadata( HashMap :: from( [ (
1509+ PARQUET_FIELD_ID_META_KEY . to_string( ) ,
1510+ "1" . to_string( ) ,
1511+ ) ] ) ) ] ) ) ;
1512+
1513+ let tmp_dir = TempDir :: new ( ) . unwrap ( ) ;
1514+ let table_location = tmp_dir. path ( ) . to_str ( ) . unwrap ( ) . to_string ( ) ;
1515+
1516+ let file_io = FileIO :: from_path ( & table_location) . unwrap ( ) . build ( ) . unwrap ( ) ;
1517+
1518+ let col = match col_a_type {
1519+ DataType :: Utf8 => Arc :: new ( StringArray :: from ( data_for_col_a) ) as ArrayRef ,
1520+ DataType :: LargeUtf8 => Arc :: new ( LargeStringArray :: from ( data_for_col_a) ) as ArrayRef ,
1521+ _ => panic ! ( "unexpected col_a_type" ) ,
1522+ } ;
1523+
1524+ let to_write = RecordBatch :: try_new ( arrow_schema. clone ( ) , vec ! [ col] ) . unwrap ( ) ;
1525+
1526+ // Write the Parquet files
1527+ let props = WriterProperties :: builder ( )
1528+ . set_compression ( Compression :: SNAPPY )
1529+ . build ( ) ;
1530+
1531+ let file = File :: create ( format ! ( "{}/1.parquet" , & table_location) ) . unwrap ( ) ;
1532+ let mut writer =
1533+ ArrowWriter :: try_new ( file, to_write. schema ( ) , Some ( props. clone ( ) ) ) . unwrap ( ) ;
1534+
1535+ writer. write ( & to_write) . expect ( "Writing batch" ) ;
1536+
1537+ // writer must be closed to write footer
1538+ writer. close ( ) . unwrap ( ) ;
1539+
1540+ ( file_io, schema, table_location, tmp_dir)
1541+ }
13271542}
0 commit comments