Skip to content

Commit 6f3a8f0

Browse files
authored
feat: Adding with_schema_unchecked method for RecordBatch (#7402)
* feat: Adding `with_schema_force` method for `RecordBatch`
1 parent 959499b commit 6f3a8f0

File tree

1 file changed

+92
-2
lines changed

1 file changed

+92
-2
lines changed

Diff for: arrow-array/src/record_batch.rs

+92-2
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,18 @@ impl RecordBatch {
359359
})
360360
}
361361

362+
/// Overrides the schema of this [`RecordBatch`]
363+
/// without additional schema checks. Note, however, that this shifts all responsibility for schema compatibility
364+
/// to the caller. In particular, the caller guarantees that `schema` is a superset
365+
/// of the current schema as determined by [`Schema::contains`].
///
/// The columns and row count of the batch are carried over unchanged; only the
/// schema is replaced. The current implementation performs no validation and
/// always returns `Ok`.
366+
pub fn with_schema_unchecked(self, schema: SchemaRef) -> Result<Self, ArrowError> {
367+
// No checks are performed here: the new schema is stored as-is next to the existing columns.
Ok(Self {
368+
schema,
369+
columns: self.columns,
370+
row_count: self.row_count,
371+
})
372+
}
373+
362374
/// Returns the [`Schema`] of the record batch.
363375
pub fn schema(&self) -> SchemaRef {
364376
self.schema.clone()
@@ -744,12 +756,14 @@ impl RecordBatchOptions {
744756
row_count: None,
745757
}
746758
}
747-
/// Sets the row_count of RecordBatchOptions and returns self
759+
760+
/// Sets the `row_count` of `RecordBatchOptions` and returns this [`RecordBatchOptions`]
748761
pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
749762
self.row_count = row_count;
750763
self
751764
}
752-
/// Sets the match_field_names of RecordBatchOptions and returns self
765+
766+
/// Sets the `match_field_names` of `RecordBatchOptions` and returns this [`RecordBatchOptions`]
753767
pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
754768
self.match_field_names = match_field_names;
755769
self
@@ -1637,4 +1651,80 @@ mod tests {
16371651
"bar"
16381652
);
16391653
}
1654+
1655+
// Exercises `with_schema_unchecked` with mismatching and matching schemas.
// In every case the call itself is unwrapped successfully; any mismatch is only
// observed afterwards, through `project` or `column_by_name`.
#[test]
1656+
fn test_batch_with_unchecked_schema() {
1657+
// Applies `schema_ref` to a clone of `record_batch` without validation, then
// projects column `idx` and returns the projection error, if any.
fn apply_schema_unchecked(
1658+
record_batch: &RecordBatch,
1659+
schema_ref: SchemaRef,
1660+
idx: usize,
1661+
) -> Option<ArrowError> {
1662+
record_batch
1663+
.clone()
1664+
.with_schema_unchecked(schema_ref)
1665+
// Panics if `with_schema_unchecked` itself returned an error.
1666+
.unwrap()
1667+
.project(&[idx])
1668+
.err()
1669+
}
1669+
1670+
let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1671+
1672+
// Batch with a single string column named "c".
let record_batch =
1673+
RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1674+
1675+
// Test empty schema for non-empty schema batch
// Projecting index 0 is expected to fail as out of bounds.
1676+
let invalid_schema_empty = Schema::empty();
1677+
assert_eq!(
1678+
apply_schema_unchecked(&record_batch, invalid_schema_empty.into(), 0)
1679+
.unwrap()
1680+
.to_string(),
1681+
"Schema error: project index 0 out of bounds, max field 0"
1682+
);
1683+
1684+
// Wrong number of columns
// The schema declares two fields while the batch holds one column.
1685+
let invalid_schema_more_cols = Schema::new(vec![
1686+
Field::new("a", DataType::Utf8, false),
1687+
Field::new("b", DataType::Int32, false),
1688+
]);
1689+
1690+
// Projecting index 0 is expected to succeed.
assert!(
1691+
apply_schema_unchecked(&record_batch, invalid_schema_more_cols.clone().into(), 0)
1692+
.is_none()
1693+
);
1694+
1695+
// Projecting index 1 is expected to fail as out of bounds.
assert_eq!(
1696+
apply_schema_unchecked(&record_batch, invalid_schema_more_cols.into(), 1)
1697+
.unwrap()
1698+
.to_string(),
1699+
"Schema error: project index 1 out of bounds, max field 1"
1700+
);
1701+
1702+
// Wrong datatype
// The type mismatch is reported only when the batch is projected.
1703+
let invalid_schema_wrong_datatype =
1704+
Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1705+
assert_eq!(apply_schema_unchecked(&record_batch, invalid_schema_wrong_datatype.into(), 0).unwrap().to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0");
1706+
1707+
// Wrong column name: "a" instead of "c"
1708+
let invalid_schema_wrong_col_name =
1709+
Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1710+
1711+
// Looking up the original name "c" is expected to find nothing under the overriding schema.
assert!(record_batch
1712+
.clone()
1713+
.with_schema_unchecked(invalid_schema_wrong_col_name.into())
1714+
.unwrap()
1715+
.column_by_name("c")
1716+
.is_none());
1717+
1718+
// Valid schema
1719+
let valid_schema = Schema::new(vec![Field::new("c", DataType::Utf8, false)]);
1720+
1721+
// With a matching schema, the lookup of "c" is expected to equal the original batch's lookup.
assert_eq!(
1722+
record_batch
1723+
.clone()
1724+
.with_schema_unchecked(valid_schema.into())
1725+
.unwrap()
1726+
.column_by_name("c"),
1727+
record_batch.column_by_name("c")
1728+
);
1729+
}
16401730
}

0 commit comments

Comments
 (0)