Skip to content

Commit 904123c

Browse files
Fix predicates not matching the Arrow type of columns read from parquet
files
1 parent 16529b5 commit 904123c

4 files changed

Lines changed: 241 additions & 20 deletions

File tree

crates/iceberg/src/arrow/reader.rs

Lines changed: 221 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use std::str::FromStr;
2323
use std::sync::Arc;
2424

2525
use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
26-
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
26+
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
27+
use arrow_cast::cast::cast;
2728
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
2829
use arrow_schema::{
2930
ArrowError, DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
@@ -907,6 +908,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
907908

908909
Ok(Box::new(move |batch| {
909910
let left = project_column(&batch, idx)?;
911+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
910912
lt(&left, literal.as_ref())
911913
}))
912914
} else {
@@ -926,6 +928,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
926928

927929
Ok(Box::new(move |batch| {
928930
let left = project_column(&batch, idx)?;
931+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
929932
lt_eq(&left, literal.as_ref())
930933
}))
931934
} else {
@@ -945,6 +948,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
945948

946949
Ok(Box::new(move |batch| {
947950
let left = project_column(&batch, idx)?;
951+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
948952
gt(&left, literal.as_ref())
949953
}))
950954
} else {
@@ -964,6 +968,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
964968

965969
Ok(Box::new(move |batch| {
966970
let left = project_column(&batch, idx)?;
971+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
967972
gt_eq(&left, literal.as_ref())
968973
}))
969974
} else {
@@ -983,6 +988,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
983988

984989
Ok(Box::new(move |batch| {
985990
let left = project_column(&batch, idx)?;
991+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
986992
eq(&left, literal.as_ref())
987993
}))
988994
} else {
@@ -1002,6 +1008,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10021008

10031009
Ok(Box::new(move |batch| {
10041010
let left = project_column(&batch, idx)?;
1011+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10051012
neq(&left, literal.as_ref())
10061013
}))
10071014
} else {
@@ -1021,6 +1028,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10211028

10221029
Ok(Box::new(move |batch| {
10231030
let left = project_column(&batch, idx)?;
1031+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10241032
starts_with(&left, literal.as_ref())
10251033
}))
10261034
} else {
@@ -1040,7 +1048,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10401048

10411049
Ok(Box::new(move |batch| {
10421050
let left = project_column(&batch, idx)?;
1043-
1051+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10441052
// update here if arrow ever adds a native not_starts_with
10451053
not(&starts_with(&left, literal.as_ref())?)
10461054
}))
@@ -1065,8 +1073,10 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10651073
Ok(Box::new(move |batch| {
10661074
// update this if arrow ever adds a native is_in kernel
10671075
let left = project_column(&batch, idx)?;
1076+
10681077
let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
10691078
for literal in &literals {
1079+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10701080
acc = or(&acc, &eq(&left, literal.as_ref())?)?
10711081
}
10721082

@@ -1095,6 +1105,7 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
10951105
let left = project_column(&batch, idx)?;
10961106
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
10971107
for literal in &literals {
1108+
let literal = cast_literal_if_required(Arc::clone(&literal), left.data_type())?;
10981109
acc = and(&acc, &neq(&left, literal.as_ref())?)?
10991110
}
11001111

@@ -1150,21 +1161,53 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
11501161
}
11511162
}
11521163

1164+
/// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1165+
/// that Iceberg uses for literals - but they are effectively the same logical type,
1166+
/// e.g. LargeUtf8 and Utf8, Utf8View and Utf8, or Utf8View and LargeUtf8.
1167+
///
1168+
/// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1169+
/// into the type of the batch we read from Parquet before sending it to the compute kernel.
1170+
fn cast_literal_if_required(
1171+
literal: Arc<dyn ArrowDatum + Send + Sync>,
1172+
column_type: &DataType,
1173+
) -> std::result::Result<Arc<dyn ArrowDatum + Send + Sync>, ArrowError> {
1174+
let literal_array = literal.get().0;
1175+
1176+
// No cast required
1177+
if literal_array.data_type() == column_type {
1178+
return Ok(literal);
1179+
}
1180+
1181+
let literal_array = cast(literal_array, column_type)?;
1182+
Ok(Arc::new(Scalar::new(literal_array)))
1183+
}
1184+
11531185
#[cfg(test)]
11541186
mod tests {
11551187
use std::collections::{HashMap, HashSet};
1188+
use std::fs::File;
11561189
use std::sync::Arc;
11571190

1191+
use arrow_array::cast::AsArray;
1192+
use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
11581193
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
1159-
use parquet::arrow::ProjectionMask;
1194+
use futures::TryStreamExt;
1195+
use parquet::arrow::{ArrowWriter, ProjectionMask};
1196+
use parquet::basic::Compression;
1197+
use parquet::file::properties::WriterProperties;
11601198
use parquet::schema::parser::parse_message_type;
11611199
use parquet::schema::types::SchemaDescriptor;
1200+
use tempfile::TempDir;
11621201

11631202
use crate::arrow::reader::{CollectFieldIdVisitor, PARQUET_FIELD_ID_META_KEY};
1164-
use crate::arrow::ArrowReader;
1203+
use crate::arrow::{ArrowReader, ArrowReaderBuilder};
11651204
use crate::expr::visitors::bound_predicate_visitor::visit;
1166-
use crate::expr::{Bind, Reference};
1167-
use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
1205+
use crate::expr::{Bind, Predicate, Reference};
1206+
use crate::io::FileIO;
1207+
use crate::scan::{FileScanTask, FileScanTaskStream};
1208+
use crate::spec::{
1209+
DataContentType, DataFileFormat, Datum, NestedField, PrimitiveType, Schema, SchemaRef, Type,
1210+
};
11681211
use crate::ErrorKind;
11691212

11701213
fn table_schema_simple() -> SchemaRef {
@@ -1324,4 +1367,176 @@ message schema {
13241367
.expect("Some ProjectionMask");
13251368
assert_eq!(mask, ProjectionMask::leaves(&parquet_schema, vec![0]));
13261369
}
1370+
1371+
#[tokio::test]
1372+
async fn test_predicate_cast_literal() {
1373+
let predicates = vec![
1374+
// a == 'foo'
1375+
(
1376+
Reference::new("a").equal_to(Datum::string("foo")),
1377+
vec![Some("foo".to_string())],
1378+
),
1379+
// a != 'foo'
1380+
(
1381+
Reference::new("a").not_equal_to(Datum::string("foo")),
1382+
vec![Some("bar".to_string())],
1383+
),
1384+
// STARTS_WITH(a, 'f')
1385+
(
1386+
Reference::new("a").starts_with(Datum::string("f")),
1387+
vec![Some("foo".to_string())],
1388+
),
1389+
// NOT STARTS_WITH(a, 'f')
1390+
(
1391+
Reference::new("a").not_starts_with(Datum::string("f")),
1392+
vec![Some("bar".to_string())],
1393+
),
1394+
// a < 'foo'
1395+
(
1396+
Reference::new("a").less_than(Datum::string("foo")),
1397+
vec![Some("bar".to_string())],
1398+
),
1399+
// a <= 'foo'
1400+
(
1401+
Reference::new("a").less_than_or_equal_to(Datum::string("foo")),
1402+
vec![Some("foo".to_string()), Some("bar".to_string())],
1403+
),
1404+
// a > 'bar'
1405+
(
1406+
Reference::new("a").greater_than(Datum::string("bar")),
1407+
vec![Some("foo".to_string())],
1408+
),
1409+
// a >= 'foo'
1410+
(
1411+
Reference::new("a").greater_than_or_equal_to(Datum::string("foo")),
1412+
vec![Some("foo".to_string())],
1413+
),
1414+
// a IN ('foo', 'baz')
1415+
(
1416+
Reference::new("a").is_in([Datum::string("foo"), Datum::string("baz")]),
1417+
vec![Some("foo".to_string())],
1418+
),
1419+
// a NOT IN ('foo', 'baz')
1420+
(
1421+
Reference::new("a").is_not_in([Datum::string("foo"), Datum::string("baz")]),
1422+
vec![Some("bar".to_string())],
1423+
),
1424+
];
1425+
1426+
// Table data: ["foo", "bar"]
1427+
let data_for_col_a = vec![Some("foo".to_string()), Some("bar".to_string())];
1428+
1429+
let (file_io, schema, table_location, _temp_dir) =
1430+
setup_kleene_logic(data_for_col_a, DataType::LargeUtf8);
1431+
let reader = ArrowReaderBuilder::new(file_io).build();
1432+
1433+
for (predicate, expected) in predicates {
1434+
println!("testing predicate {predicate}");
1435+
let result_data = test_perform_read(
1436+
predicate.clone(),
1437+
schema.clone(),
1438+
table_location.clone(),
1439+
reader.clone(),
1440+
)
1441+
.await;
1442+
1443+
assert_eq!(result_data, expected, "predicate={predicate}");
1444+
}
1445+
}
1446+
1447+
async fn test_perform_read(
1448+
predicate: Predicate,
1449+
schema: SchemaRef,
1450+
table_location: String,
1451+
reader: ArrowReader,
1452+
) -> Vec<Option<String>> {
1453+
let tasks = Box::pin(futures::stream::iter(
1454+
vec![Ok(FileScanTask {
1455+
start: 0,
1456+
length: 0,
1457+
record_count: None,
1458+
data_file_path: format!("{}/1.parquet", table_location),
1459+
data_file_content: DataContentType::Data,
1460+
data_file_format: DataFileFormat::Parquet,
1461+
schema: schema.clone(),
1462+
project_field_ids: vec![1],
1463+
predicate: Some(predicate.bind(schema, true).unwrap()),
1464+
})]
1465+
.into_iter(),
1466+
)) as FileScanTaskStream;
1467+
1468+
let result = reader
1469+
.read(tasks)
1470+
.await
1471+
.unwrap()
1472+
.try_collect::<Vec<RecordBatch>>()
1473+
.await
1474+
.unwrap();
1475+
1476+
let result_data = result[0].columns()[0]
1477+
.as_string_opt::<i32>()
1478+
.unwrap()
1479+
.iter()
1480+
.map(|v| v.map(ToOwned::to_owned))
1481+
.collect::<Vec<_>>();
1482+
1483+
result_data
1484+
}
1485+
1486+
fn setup_kleene_logic(
1487+
data_for_col_a: Vec<Option<String>>,
1488+
col_a_type: DataType,
1489+
) -> (FileIO, SchemaRef, String, TempDir) {
1490+
let schema = Arc::new(
1491+
Schema::builder()
1492+
.with_schema_id(1)
1493+
.with_fields(vec![NestedField::optional(
1494+
1,
1495+
"a",
1496+
Type::Primitive(PrimitiveType::String),
1497+
)
1498+
.into()])
1499+
.build()
1500+
.unwrap(),
1501+
);
1502+
1503+
let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
1504+
"a",
1505+
col_a_type.clone(),
1506+
true,
1507+
)
1508+
.with_metadata(HashMap::from([(
1509+
PARQUET_FIELD_ID_META_KEY.to_string(),
1510+
"1".to_string(),
1511+
)]))]));
1512+
1513+
let tmp_dir = TempDir::new().unwrap();
1514+
let table_location = tmp_dir.path().to_str().unwrap().to_string();
1515+
1516+
let file_io = FileIO::from_path(&table_location).unwrap().build().unwrap();
1517+
1518+
let col = match col_a_type {
1519+
DataType::Utf8 => Arc::new(StringArray::from(data_for_col_a)) as ArrayRef,
1520+
DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data_for_col_a)) as ArrayRef,
1521+
_ => panic!("unexpected col_a_type"),
1522+
};
1523+
1524+
let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![col]).unwrap();
1525+
1526+
// Write the Parquet file
1527+
let props = WriterProperties::builder()
1528+
.set_compression(Compression::SNAPPY)
1529+
.build();
1530+
1531+
let file = File::create(format!("{}/1.parquet", &table_location)).unwrap();
1532+
let mut writer =
1533+
ArrowWriter::try_new(file, to_write.schema(), Some(props.clone())).unwrap();
1534+
1535+
writer.write(&to_write).expect("Writing batch");
1536+
1537+
// writer must be closed to write footer
1538+
writer.close().unwrap();
1539+
1540+
(file_io, schema, table_location, tmp_dir)
1541+
}
13271542
}

crates/iceberg/src/arrow/schema.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use arrow_array::{
2828
Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
2929
};
3030
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
31-
use bitvec::macros::internal::funty::Fundamental;
3231
use num_bigint::BigInt;
3332
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
3433
use parquet::file::statistics::Statistics;
@@ -647,33 +646,33 @@ pub fn type_to_arrow_type(ty: &crate::spec::Type) -> crate::Result<DataType> {
647646
}
648647

649648
/// Convert Iceberg Datum to Arrow Datum.
650-
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send>> {
649+
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Arc<dyn ArrowDatum + Send + Sync>> {
651650
match (datum.data_type(), datum.literal()) {
652651
(PrimitiveType::Boolean, PrimitiveLiteral::Boolean(value)) => {
653-
Ok(Box::new(BooleanArray::new_scalar(*value)))
652+
Ok(Arc::new(BooleanArray::new_scalar(*value)))
654653
}
655654
(PrimitiveType::Int, PrimitiveLiteral::Int(value)) => {
656-
Ok(Box::new(Int32Array::new_scalar(*value)))
655+
Ok(Arc::new(Int32Array::new_scalar(*value)))
657656
}
658657
(PrimitiveType::Long, PrimitiveLiteral::Long(value)) => {
659-
Ok(Box::new(Int64Array::new_scalar(*value)))
658+
Ok(Arc::new(Int64Array::new_scalar(*value)))
660659
}
661660
(PrimitiveType::Float, PrimitiveLiteral::Float(value)) => {
662-
Ok(Box::new(Float32Array::new_scalar(value.as_f32())))
661+
Ok(Arc::new(Float32Array::new_scalar(value.to_f32().unwrap())))
663662
}
664663
(PrimitiveType::Double, PrimitiveLiteral::Double(value)) => {
665-
Ok(Box::new(Float64Array::new_scalar(value.as_f64())))
664+
Ok(Arc::new(Float64Array::new_scalar(value.to_f64().unwrap())))
666665
}
667666
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
668-
Ok(Box::new(StringArray::new_scalar(value.as_str())))
667+
Ok(Arc::new(StringArray::new_scalar(value.as_str())))
669668
}
670669
(PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
671-
Ok(Box::new(Date32Array::new_scalar(*value)))
670+
Ok(Arc::new(Date32Array::new_scalar(*value)))
672671
}
673672
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
674-
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
673+
Ok(Arc::new(TimestampMicrosecondArray::new_scalar(*value)))
675674
}
676-
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Box::new(Scalar::new(
675+
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Arc::new(Scalar::new(
677676
PrimitiveArray::<TimestampMicrosecondType>::new(vec![*value; 1].into(), None)
678677
.with_timezone("UTC"),
679678
))),

crates/iceberg/src/inspect/metadata_table.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ impl<'a> MetadataTable<'a> {
4444
}
4545
}
4646

47+
#[allow(missing_docs)]
4748
#[cfg(test)]
4849
pub mod tests {
4950
use expect_test::Expect;

0 commit comments

Comments
 (0)