Commit 8eef76e

fix: Raise proper error for mismatching parquet schema instead of panicking (#17321)
1 parent f73937a commit 8eef76e

File tree: 10 files changed (+233, -137 lines)


crates/polars-arrow/src/datatypes/schema.rs

Lines changed: 19 additions & 0 deletions
@@ -1,5 +1,6 @@
 use std::sync::Arc;

+use polars_error::{polars_bail, PolarsResult};
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

@@ -62,6 +63,24 @@ impl ArrowSchema {
             metadata: self.metadata,
         }
     }
+
+    pub fn try_project(&self, indices: &[usize]) -> PolarsResult<Self> {
+        let fields = indices.iter().map(|&i| {
+            let Some(out) = self.fields.get(i) else {
+                polars_bail!(
+                    SchemaFieldNotFound: "projection index {} is out of bounds for schema of length {}",
+                    i, self.fields.len()
+                );
+            };
+
+            Ok(out.clone())
+        }).collect::<PolarsResult<Vec<_>>>()?;
+
+        Ok(ArrowSchema {
+            fields,
+            metadata: self.metadata.clone(),
+        })
+    }
 }

 impl From<Vec<Field>> for ArrowSchema {
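
The new ArrowSchema::try_project is the piece that turns an out-of-bounds projection into a recoverable error. Below is a minimal usage sketch, not part of the commit: the import paths and the three-argument Field::new(name, dtype, nullable) constructor are assumptions about the surrounding polars-arrow crate; only try_project and the From<Vec<Field>> conversion appear in the diff above.

// Sketch only; import paths and Field::new signature are assumed.
use polars_arrow::datatypes::{ArrowDataType, ArrowSchema, Field};
use polars_error::PolarsResult;

fn project_example() -> PolarsResult<()> {
    let schema = ArrowSchema::from(vec![
        Field::new("a", ArrowDataType::Int64, true),
        Field::new("b", ArrowDataType::Utf8, true),
    ]);

    // An in-bounds projection keeps only the selected fields.
    let projected = schema.try_project(&[1])?;
    assert_eq!(projected.fields.len(), 1);

    // An out-of-bounds index now surfaces as a SchemaFieldNotFound error
    // instead of panicking on a slice index.
    assert!(schema.try_project(&[5]).is_err());

    Ok(())
}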

crates/polars-core/src/schema.rs

Lines changed: 97 additions & 0 deletions
@@ -446,6 +446,9 @@ pub trait IndexOfSchema: Debug {
     /// Get a vector of all column names.
     fn get_names(&self) -> Vec<&str>;

+    /// Get a vector of (name, dtype) pairs
+    fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, DataType)>;
+
     fn try_index_of(&self, name: &str) -> PolarsResult<usize> {
         self.index_of(name).ok_or_else(|| {
             polars_err!(
@@ -464,6 +467,13 @@ impl IndexOfSchema for Schema {
     fn get_names(&self) -> Vec<&str> {
         self.iter_names().map(|name| name.as_str()).collect()
     }
+
+    fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, DataType)> {
+        self.inner
+            .iter()
+            .map(|(name, dtype)| (name.as_str(), dtype.clone()))
+            .collect()
+    }
 }

 impl IndexOfSchema for ArrowSchema {
@@ -474,6 +484,45 @@ impl IndexOfSchema for ArrowSchema {
     fn get_names(&self) -> Vec<&str> {
         self.fields.iter().map(|f| f.name.as_str()).collect()
     }
+
+    fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, DataType)> {
+        self.fields
+            .iter()
+            .map(|x| (x.name.as_str(), DataType::from_arrow(&x.data_type, true)))
+            .collect()
+    }
+}
+
+pub trait SchemaNamesAndDtypes {
+    const IS_ARROW: bool;
+    type DataType: Debug + PartialEq;
+
+    /// Get a vector of (name, dtype) pairs
+    fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)>;
+}
+
+impl SchemaNamesAndDtypes for Schema {
+    const IS_ARROW: bool = false;
+    type DataType = DataType;
+
+    fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> {
+        self.inner
+            .iter()
+            .map(|(name, dtype)| (name.as_str(), dtype.clone()))
+            .collect()
+    }
+}
+
+impl SchemaNamesAndDtypes for ArrowSchema {
+    const IS_ARROW: bool = true;
+    type DataType = ArrowDataType;
+
+    fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> {
+        self.fields
+            .iter()
+            .map(|x| (x.name.as_str(), x.data_type.clone()))
+            .collect()
+    }
 }

 impl From<&ArrowSchema> for Schema {
@@ -498,3 +547,51 @@ impl From<&ArrowSchemaRef> for Schema {
         Self::from(value.as_ref())
     }
 }
+
+pub fn ensure_matching_schema<S: SchemaNamesAndDtypes>(lhs: &S, rhs: &S) -> PolarsResult<()> {
+    let lhs = lhs.get_names_and_dtypes();
+    let rhs = rhs.get_names_and_dtypes();
+
+    if lhs.len() != rhs.len() {
+        polars_bail!(
+            SchemaMismatch:
+            "schemas contained differing number of columns: {} != {}",
+            lhs.len(), rhs.len(),
+        );
+    }
+
+    for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.iter().zip(&rhs).enumerate() {
+        if l_name != r_name {
+            polars_bail!(
+                SchemaMismatch:
+                "schema names differ at index {}: {} != {}",
+                i, l_name, r_name
+            )
+        }
+        if l_dtype != r_dtype
+            && (!S::IS_ARROW
+                || unsafe {
+                    // For timezone normalization. Easier than writing out the entire PartialEq.
+                    DataType::from_arrow(
+                        std::mem::transmute::<&<S as SchemaNamesAndDtypes>::DataType, &ArrowDataType>(
+                            l_dtype,
+                        ),
+                        true,
+                    ) != DataType::from_arrow(
+                        std::mem::transmute::<&<S as SchemaNamesAndDtypes>::DataType, &ArrowDataType>(
+                            r_dtype,
+                        ),
+                        true,
+                    )
+                })
+        {
+            polars_bail!(
+                SchemaMismatch:
+                "schema dtypes differ at index {} for column {}: {:?} != {:?}",
+                i, l_name, l_dtype, r_dtype
+            )
+        }
+    }
+
+    Ok(())
+}
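
ensure_matching_schema is the shared check both parquet readers now call, producing a descriptive SchemaMismatch error for a differing column count, name, or dtype. A rough sketch of the call pattern follows; the Schema/Field/DataType constructors and the exact re-export path are assumptions about the surrounding polars-core crate, and only the function itself comes from this commit.

// Sketch only; constructor APIs and module paths are assumed.
use polars_core::prelude::{DataType, Field, Schema};
use polars_core::schema::ensure_matching_schema;

fn compare_example() {
    let lhs = Schema::from_iter([Field::new("a", DataType::Int64)]);
    let rhs = Schema::from_iter([Field::new("a", DataType::String)]);

    // Same column name, different dtype at index 0: this now returns a
    // SchemaMismatch error ("schema dtypes differ at index 0 for column a")
    // rather than letting the mismatch panic later in the read path.
    assert!(ensure_matching_schema(&lhs, &rhs).is_err());
}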

crates/polars-io/src/parquet/read/reader.rs

Lines changed: 48 additions & 13 deletions
@@ -80,22 +80,38 @@ impl<R: MmapBytesReader> ParquetReader<R> {
         self
     }

-    /// Set the [`Schema`] if already known. This must be exactly the same as
-    /// the schema in the file itself.
-    pub fn with_schema(mut self, schema: Option<ArrowSchemaRef>) -> Self {
-        self.schema = schema;
-        self
+    /// Ensure the schema of the file matches the given schema. Calling this
+    /// after setting the projection will ensure only the projected indices
+    /// are checked.
+    pub fn check_schema(mut self, schema: &ArrowSchema) -> PolarsResult<Self> {
+        let self_schema = self.schema()?;
+        let self_schema = self_schema.as_ref();
+
+        if let Some(ref projection) = self.projection {
+            let projection = projection.as_slice();
+
+            ensure_matching_schema(
+                &schema.try_project(projection)?,
+                &self_schema.try_project(projection)?,
+            )?;
+        } else {
+            ensure_matching_schema(schema, self_schema)?;
+        }
+
+        Ok(self)
     }

     /// [`Schema`] of the file.
     pub fn schema(&mut self) -> PolarsResult<ArrowSchemaRef> {
-        match &self.schema {
-            Some(schema) => Ok(schema.clone()),
+        self.schema = Some(match &self.schema {
+            Some(schema) => schema.clone(),
             None => {
                 let metadata = self.get_metadata()?;
-                Ok(Arc::new(read::infer_schema(metadata)?))
+                Arc::new(read::infer_schema(metadata)?)
             },
-        }
+        });
+
+        Ok(self.schema.clone().unwrap())
     }

     /// Use statistics in the parquet to determine if pages
@@ -226,7 +242,6 @@ impl ParquetAsyncReader {
     pub async fn from_uri(
         uri: &str,
         cloud_options: Option<&CloudOptions>,
-        schema: Option<ArrowSchemaRef>,
         metadata: Option<FileMetaDataRef>,
     ) -> PolarsResult<ParquetAsyncReader> {
         Ok(ParquetAsyncReader {
@@ -238,20 +253,40 @@ impl ParquetAsyncReader {
             predicate: None,
             use_statistics: true,
             hive_partition_columns: None,
-            schema,
+            schema: None,
             parallel: Default::default(),
         })
     }

+    pub async fn check_schema(mut self, schema: &ArrowSchema) -> PolarsResult<Self> {
+        let self_schema = self.schema().await?;
+        let self_schema = self_schema.as_ref();
+
+        if let Some(ref projection) = self.projection {
+            let projection = projection.as_slice();
+
+            ensure_matching_schema(
+                &schema.try_project(projection)?,
+                &self_schema.try_project(projection)?,
+            )?;
+        } else {
+            ensure_matching_schema(schema, self_schema)?;
+        }
+
+        Ok(self)
+    }
+
     pub async fn schema(&mut self) -> PolarsResult<ArrowSchemaRef> {
-        Ok(match self.schema.as_ref() {
+        self.schema = Some(match self.schema.as_ref() {
             Some(schema) => Arc::clone(schema),
             None => {
                 let metadata = self.reader.get_metadata().await?;
                 let arrow_schema = polars_parquet::arrow::read::infer_schema(metadata)?;
                 Arc::new(arrow_schema)
             },
-        })
+        });
+
+        Ok(self.schema.clone().unwrap())
     }

     pub async fn num_rows(&mut self) -> PolarsResult<usize> {
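
On the reader side, the fallible check_schema builder replaces the old with_schema setter and respects a previously set projection. A sketch of the synchronous call site is below; ParquetReader::new, SerReader, with_projection, and finish are assumed from the existing polars-io API, and only check_schema is introduced by this commit.

// Sketch only; everything except check_schema is assumed existing API.
use std::fs::File;

use polars_arrow::datatypes::ArrowSchema;
use polars_core::prelude::{DataFrame, PolarsResult};
use polars_io::prelude::{ParquetReader, SerReader};

fn read_checked(path: &str, expected: &ArrowSchema) -> PolarsResult<DataFrame> {
    ParquetReader::new(File::open(path)?)
        // With a projection set first, only the projected indices are compared.
        .with_projection(Some(vec![0, 1]))
        // A mismatch is reported as a SchemaMismatch error here, instead of a
        // panic deeper inside the parquet deserialization.
        .check_schema(expected)?
        .finish()
}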

crates/polars-io/src/utils.rs

Lines changed: 0 additions & 47 deletions
@@ -270,53 +270,6 @@ pub fn materialize_projection(
     }
 }

-pub fn check_projected_schema_impl(
-    a: &Schema,
-    b: &Schema,
-    projected_names: Option<&[String]>,
-    msg: &str,
-) -> PolarsResult<()> {
-    if !projected_names
-        .map(|projected_names| {
-            projected_names
-                .iter()
-                .all(|name| a.get(name) == b.get(name))
-        })
-        .unwrap_or_else(|| a == b)
-    {
-        polars_bail!(ComputeError: "{msg}\n\n\
-        Expected: {:?}\n\n\
-        Got: {:?}", a, b)
-    }
-    Ok(())
-}
-
-/// Checks if the projected columns are equal
-pub fn check_projected_arrow_schema(
-    a: &ArrowSchema,
-    b: &ArrowSchema,
-    projected_names: Option<&[String]>,
-    msg: &str,
-) -> PolarsResult<()> {
-    if a != b {
-        let a = Schema::from(a);
-        let b = Schema::from(b);
-        check_projected_schema_impl(&a, &b, projected_names, msg)
-    } else {
-        Ok(())
-    }
-}
-
-/// Checks if the projected columns are equal
-pub fn check_projected_schema(
-    a: &Schema,
-    b: &Schema,
-    projected_names: Option<&[String]>,
-    msg: &str,
-) -> PolarsResult<()> {
-    check_projected_schema_impl(a, b, projected_names, msg)
-}
-
 /// Split DataFrame into chunks in preparation for writing. The chunks have a
 /// maximum number of rows per chunk to ensure reasonable memory efficiency when
 /// reading the resulting file, and a minimum size per chunk to ensure
