Skip to content

Commit de0e2e2

Browse files
Fix predicates not matching the Arrow type of columns read from parquet
files
1 parent 58308b2 commit de0e2e2

2 files changed

Lines changed: 219 additions & 12 deletions

File tree

crates/iceberg/src/arrow/reader.rs

Lines changed: 209 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use std::str::FromStr;
2323
use std::sync::Arc;
2424

2525
use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
26-
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
26+
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
27+
use arrow_cast::cast::cast;
2728
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
2829
use arrow_schema::{
2930
ArrowError, DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
@@ -907,6 +908,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
907908

908909
Ok(Box::new(move |batch| {
909910
let left = project_column(&batch, idx)?;
911+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
910912
lt(&left, literal.as_ref())
911913
}))
912914
} else {
@@ -926,6 +928,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
926928

927929
Ok(Box::new(move |batch| {
928930
let left = project_column(&batch, idx)?;
931+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
929932
lt_eq(&left, literal.as_ref())
930933
}))
931934
} else {
@@ -945,6 +948,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
945948

946949
Ok(Box::new(move |batch| {
947950
let left = project_column(&batch, idx)?;
951+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
948952
gt(&left, literal.as_ref())
949953
}))
950954
} else {
@@ -964,6 +968,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
964968

965969
Ok(Box::new(move |batch| {
966970
let left = project_column(&batch, idx)?;
971+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
967972
gt_eq(&left, literal.as_ref())
968973
}))
969974
} else {
@@ -983,6 +988,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
983988

984989
Ok(Box::new(move |batch| {
985990
let left = project_column(&batch, idx)?;
991+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
986992
eq(&left, literal.as_ref())
987993
}))
988994
} else {
@@ -1002,6 +1008,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10021008

10031009
Ok(Box::new(move |batch| {
10041010
let left = project_column(&batch, idx)?;
1011+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10051012
neq(&left, literal.as_ref())
10061013
}))
10071014
} else {
@@ -1021,6 +1028,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10211028

10221029
Ok(Box::new(move |batch| {
10231030
let left = project_column(&batch, idx)?;
1031+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10241032
starts_with(&left, literal.as_ref())
10251033
}))
10261034
} else {
@@ -1040,7 +1048,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10401048

10411049
Ok(Box::new(move |batch| {
10421050
let left = project_column(&batch, idx)?;
1043-
1051+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10441052
// update here if arrow ever adds a native not_starts_with
10451053
not(&starts_with(&left, literal.as_ref())?)
10461054
}))
@@ -1065,8 +1073,10 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10651073
Ok(Box::new(move |batch| {
10661074
// update this if arrow ever adds a native is_in kernel
10671075
let left = project_column(&batch, idx)?;
1076+
10681077
let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
10691078
for literal in &literals {
1079+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10701080
acc = or(&acc, &eq(&left, literal.as_ref())?)?
10711081
}
10721082

@@ -1095,6 +1105,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10951105
let left = project_column(&batch, idx)?;
10961106
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
10971107
for literal in &literals {
1108+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10981109
acc = and(&acc, &neq(&left, literal.as_ref())?)?
10991110
}
11001111

@@ -1150,11 +1161,34 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
11501161
}
11511162
}
11521163

1164+
/// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1165+
/// that Iceberg uses for literals - but they are effectively the same logical type,
1166+
/// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8.
1167+
///
1168+
/// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1169+
/// into the type of the batch we read from Parquet before sending it to the compute kernel.
1170+
fn cast_literal_if_required(
1171+
literal: Arc<dyn ArrowDatum + Send + Sync>,
1172+
column_type: &DataType,
1173+
) -> std::result::Result<Arc<dyn ArrowDatum + Send + Sync>, ArrowError> {
1174+
let literal_array = literal.get().0;
1175+
1176+
// No cast required
1177+
if literal_array.data_type() == column_type {
1178+
return Ok(literal);
1179+
}
1180+
1181+
let literal_array = cast(literal_array, column_type)?;
1182+
Ok(Arc::new(Scalar::new(literal_array)))
1183+
}
1184+
11531185
#[cfg(test)]
11541186
mod tests {
11551187
use std::collections::{HashMap, HashSet};
11561188
use std::sync::Arc;
11571189

1190+
use arrow_array::cast::AsArray;
1191+
use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
11581192
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
11591193
use parquet::arrow::ProjectionMask;
11601194
use parquet::schema::parser::parse_message_type;
@@ -1324,4 +1358,177 @@ message schema {
13241358
.expect("Some ProjectionMask");
13251359
assert_eq!(mask, ProjectionMask::leaves(&parquet_schema, vec![0]));
13261360
}
1361+
1362+
#[tokio::test]
1363+
async fn test_predicate_cast_literal() {
1364+
let predicates = vec![
1365+
// a == 'foo'
1366+
(
1367+
Reference::new("a").equal_to(Datum::string("foo")),
1368+
vec![Some("foo".to_string())],
1369+
),
1370+
// a != 'foo'
1371+
(
1372+
Reference::new("a").not_equal_to(Datum::string("foo")),
1373+
vec![Some("bar".to_string())],
1374+
),
1375+
// STARTS_WITH(a, 'foo')
1376+
(
1377+
Reference::new("a").starts_with(Datum::string("f")),
1378+
vec![Some("foo".to_string())],
1379+
),
1380+
// NOT STARTS_WITH(a, 'foo')
1381+
(
1382+
Reference::new("a").not_starts_with(Datum::string("f")),
1383+
vec![Some("bar".to_string())],
1384+
),
1385+
// a < 'foo'
1386+
(
1387+
Reference::new("a").less_than(Datum::string("foo")),
1388+
vec![Some("bar".to_string())],
1389+
),
1390+
// a <= 'foo'
1391+
(
1392+
Reference::new("a").less_than_or_equal_to(Datum::string("foo")),
1393+
vec![Some("foo".to_string()), Some("bar".to_string())],
1394+
),
1395+
// a > 'foo'
1396+
(
1397+
Reference::new("a").greater_than(Datum::string("bar")),
1398+
vec![Some("foo".to_string())],
1399+
),
1400+
// a >= 'foo'
1401+
(
1402+
Reference::new("a").greater_than_or_equal_to(Datum::string("foo")),
1403+
vec![Some("foo".to_string())],
1404+
),
1405+
// a IN ('foo', 'bar')
1406+
(
1407+
Reference::new("a").is_in([Datum::string("foo"), Datum::string("baz")]),
1408+
vec![Some("foo".to_string())],
1409+
),
1410+
// a NOT IN ('foo', 'bar')
1411+
(
1412+
Reference::new("a").is_not_in([Datum::string("foo"), Datum::string("baz")]),
1413+
vec![Some("bar".to_string())],
1414+
),
1415+
];
1416+
1417+
// Table data: ["foo", "bar"]
1418+
let data_for_col_a = vec![Some("foo".to_string()), Some("bar".to_string())];
1419+
1420+
let (file_io, schema, table_location, _temp_dir) =
1421+
setup_kleene_logic(data_for_col_a, DataType::LargeUtf8);
1422+
let reader = ArrowReaderBuilder::new(file_io).build();
1423+
1424+
for (predicate, expected) in predicates {
1425+
println!("testing predicate {predicate}");
1426+
let result_data = test_perform_read(
1427+
predicate.clone(),
1428+
schema.clone(),
1429+
table_location.clone(),
1430+
reader.clone(),
1431+
)
1432+
.await;
1433+
1434+
assert_eq!(result_data, expected, "predicate={predicate}");
1435+
}
1436+
}
1437+
1438+
async fn test_perform_read(
1439+
predicate: Predicate,
1440+
schema: SchemaRef,
1441+
table_location: String,
1442+
reader: ArrowReader,
1443+
) -> Vec<Option<String>> {
1444+
let tasks = Box::pin(futures::stream::iter(
1445+
vec![Ok(FileScanTask {
1446+
start: 0,
1447+
length: 0,
1448+
record_count: None,
1449+
data_file_path: format!("{}/1.parquet", table_location),
1450+
data_file_content: DataContentType::Data,
1451+
data_file_format: DataFileFormat::Parquet,
1452+
schema: schema.clone(),
1453+
project_field_ids: vec![1],
1454+
predicate: Some(predicate.bind(schema, true).unwrap()),
1455+
deletes: vec![],
1456+
})]
1457+
.into_iter(),
1458+
)) as FileScanTaskStream;
1459+
1460+
let result = reader
1461+
.read(tasks)
1462+
.await
1463+
.unwrap()
1464+
.try_collect::<Vec<RecordBatch>>()
1465+
.await
1466+
.unwrap();
1467+
1468+
let result_data = result[0].columns()[0]
1469+
.as_string_opt::<i32>()
1470+
.unwrap()
1471+
.iter()
1472+
.map(|v| v.map(ToOwned::to_owned))
1473+
.collect::<Vec<_>>();
1474+
1475+
result_data
1476+
}
1477+
1478+
fn setup_kleene_logic(
1479+
data_for_col_a: Vec<Option<String>>,
1480+
col_a_type: DataType,
1481+
) -> (FileIO, SchemaRef, String, TempDir) {
1482+
let schema = Arc::new(
1483+
Schema::builder()
1484+
.with_schema_id(1)
1485+
.with_fields(vec![NestedField::optional(
1486+
1,
1487+
"a",
1488+
Type::Primitive(PrimitiveType::String),
1489+
)
1490+
.into()])
1491+
.build()
1492+
.unwrap(),
1493+
);
1494+
1495+
let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
1496+
"a",
1497+
col_a_type.clone(),
1498+
true,
1499+
)
1500+
.with_metadata(HashMap::from([(
1501+
PARQUET_FIELD_ID_META_KEY.to_string(),
1502+
"1".to_string(),
1503+
)]))]));
1504+
1505+
let tmp_dir = TempDir::new().unwrap();
1506+
let table_location = tmp_dir.path().to_str().unwrap().to_string();
1507+
1508+
let file_io = FileIO::from_path(&table_location).unwrap().build().unwrap();
1509+
1510+
let col = match col_a_type {
1511+
DataType::Utf8 => Arc::new(StringArray::from(data_for_col_a)) as ArrayRef,
1512+
DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data_for_col_a)) as ArrayRef,
1513+
_ => panic!("unexpected col_a_type"),
1514+
};
1515+
1516+
let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![col]).unwrap();
1517+
1518+
// Write the Parquet files
1519+
let props = WriterProperties::builder()
1520+
.set_compression(Compression::SNAPPY)
1521+
.build();
1522+
1523+
let file = File::create(format!("{}/1.parquet", &table_location)).unwrap();
1524+
let mut writer =
1525+
ArrowWriter::try_new(file, to_write.schema(), Some(props.clone())).unwrap();
1526+
1527+
writer.write(&to_write).expect("Writing batch");
1528+
1529+
// writer must be closed to write footer
1530+
writer.close().unwrap();
1531+
1532+
(file_io, schema, table_location, tmp_dir)
1533+
}
13271534
}

crates/iceberg/src/arrow/schema.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -647,33 +647,33 @@ pub fn type_to_arrow_type(ty: &crate::spec::Type) -> crate::Result<DataType> {
647647
}
648648

649649
/// Convert Iceberg Datum to Arrow Datum.
650-
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send>> {
650+
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Arc<dyn ArrowDatum + Send + Sync>> {
651651
match (datum.data_type(), datum.literal()) {
652652
(PrimitiveType::Boolean, PrimitiveLiteral::Boolean(value)) => {
653-
Ok(Box::new(BooleanArray::new_scalar(*value)))
653+
Ok(Arc::new(BooleanArray::new_scalar(*value)))
654654
}
655655
(PrimitiveType::Int, PrimitiveLiteral::Int(value)) => {
656-
Ok(Box::new(Int32Array::new_scalar(*value)))
656+
Ok(Arc::new(Int32Array::new_scalar(*value)))
657657
}
658658
(PrimitiveType::Long, PrimitiveLiteral::Long(value)) => {
659-
Ok(Box::new(Int64Array::new_scalar(*value)))
659+
Ok(Arc::new(Int64Array::new_scalar(*value)))
660660
}
661661
(PrimitiveType::Float, PrimitiveLiteral::Float(value)) => {
662-
Ok(Box::new(Float32Array::new_scalar(value.as_f32())))
662+
Ok(Arc::new(Float32Array::new_scalar(value.to_f32().unwrap())))
663663
}
664664
(PrimitiveType::Double, PrimitiveLiteral::Double(value)) => {
665-
Ok(Box::new(Float64Array::new_scalar(value.as_f64())))
665+
Ok(Arc::new(Float64Array::new_scalar(value.to_f64().unwrap())))
666666
}
667667
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
668-
Ok(Box::new(StringArray::new_scalar(value.as_str())))
668+
Ok(Arc::new(StringArray::new_scalar(value.as_str())))
669669
}
670670
(PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
671-
Ok(Box::new(Date32Array::new_scalar(*value)))
671+
Ok(Arc::new(Date32Array::new_scalar(*value)))
672672
}
673673
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
674-
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
674+
Ok(Arc::new(TimestampMicrosecondArray::new_scalar(*value)))
675675
}
676-
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Box::new(Scalar::new(
676+
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Arc::new(Scalar::new(
677677
PrimitiveArray::<TimestampMicrosecondType>::new(vec![*value; 1].into(), None)
678678
.with_timezone("UTC"),
679679
))),

0 commit comments

Comments
 (0)