Skip to content

Commit 7f3907e

Browse files
authored
Replace with_schema_unchecked with new_unchecked (#7405)
* Make with_schema_unchecked unsafe * Fix tests * Revert with_schema_unchecked * Add new_unchecked * Add into_parts * Review feedback
1 parent 6f3a8f0 commit 7f3907e

File tree

1 file changed

+39
-96
lines changed

1 file changed

+39
-96
lines changed

Diff for: arrow-array/src/record_batch.rs

+39-96
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,11 @@ impl RecordBatch {
211211
/// Creates a `RecordBatch` from a schema and columns.
212212
///
213213
/// Expects the following:
214-
/// * the vec of columns to not be empty
215-
/// * the schema and column data types to have equal lengths
216-
/// and match
217-
/// * each array in columns to have the same length
214+
///
215+
/// * `!columns.is_empty()`
216+
/// * `schema.fields.len() == columns.len()`
217+
/// * `schema.fields[i].data_type() == columns[i].data_type()`
218+
/// * `columns[i].len() == columns[j].len()`
218219
///
219220
/// If the conditions are not met, an error is returned.
220221
///
@@ -240,6 +241,33 @@ impl RecordBatch {
240241
Self::try_new_impl(schema, columns, &options)
241242
}
242243

244+
/// Creates a `RecordBatch` from a schema and columns, without validation.
245+
///
246+
/// See [`Self::try_new`] for the checked version.
247+
///
248+
/// # Safety
249+
///
250+
/// Expects the following:
251+
///
252+
/// * `schema.fields.len() == columns.len()`
253+
/// * `schema.fields[i].data_type() == columns[i].data_type()`
254+
/// * `columns[i].len() == row_count`
255+
///
256+
/// Note: if the schema does not match the underlying data exactly, it can lead to undefined
257+
/// behavior, for example, via conversion to a `StructArray`, which in turn could lead
258+
/// to incorrect access.
259+
pub unsafe fn new_unchecked(
260+
schema: SchemaRef,
261+
columns: Vec<Arc<dyn Array>>,
262+
row_count: usize,
263+
) -> Self {
264+
Self {
265+
schema,
266+
columns,
267+
row_count,
268+
}
269+
}
270+
243271
/// Creates a `RecordBatch` from a schema and columns, with additional options,
244272
/// such as whether to strictly validate field names.
245273
///
@@ -340,6 +368,11 @@ impl RecordBatch {
340368
})
341369
}
342370

371+
/// Return the schema, columns and row count of this [`RecordBatch`]
372+
pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
373+
(self.schema, self.columns, self.row_count)
374+
}
375+
343376
/// Override the schema of this [`RecordBatch`]
344377
///
345378
/// Returns an error if `schema` is not a superset of the current schema
@@ -359,18 +392,6 @@ impl RecordBatch {
359392
})
360393
}
361394

362-
/// Overrides the schema of this [`RecordBatch`]
363-
/// without additional schema checks. Note, however, that this pushes all the schema compatibility responsibilities
364-
/// to the caller site. In particular, the caller guarantees that `schema` is a superset
365-
/// of the current schema as determined by [`Schema::contains`].
366-
pub fn with_schema_unchecked(self, schema: SchemaRef) -> Result<Self, ArrowError> {
367-
Ok(Self {
368-
schema,
369-
columns: self.columns,
370-
row_count: self.row_count,
371-
})
372-
}
373-
374395
/// Returns the [`Schema`] of the record batch.
375396
pub fn schema(&self) -> SchemaRef {
376397
self.schema.clone()
@@ -756,14 +777,12 @@ impl RecordBatchOptions {
756777
row_count: None,
757778
}
758779
}
759-
760-
/// Sets the `row_count` of `RecordBatchOptions` and returns this [`RecordBatch`]
780+
/// Sets the row_count of RecordBatchOptions and returns self
761781
pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
762782
self.row_count = row_count;
763783
self
764784
}
765-
766-
/// Sets the `match_field_names` of `RecordBatchOptions` and returns this [`RecordBatch`]
785+
/// Sets the match_field_names of RecordBatchOptions and returns self
767786
pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
768787
self.match_field_names = match_field_names;
769788
self
@@ -1651,80 +1670,4 @@ mod tests {
16511670
"bar"
16521671
);
16531672
}
1654-
1655-
#[test]
1656-
fn test_batch_with_unchecked_schema() {
1657-
fn apply_schema_unchecked(
1658-
record_batch: &RecordBatch,
1659-
schema_ref: SchemaRef,
1660-
idx: usize,
1661-
) -> Option<ArrowError> {
1662-
record_batch
1663-
.clone()
1664-
.with_schema_unchecked(schema_ref)
1665-
.unwrap()
1666-
.project(&[idx])
1667-
.err()
1668-
}
1669-
1670-
let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1671-
1672-
let record_batch =
1673-
RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1674-
1675-
// Test empty schema for non-empty schema batch
1676-
let invalid_schema_empty = Schema::empty();
1677-
assert_eq!(
1678-
apply_schema_unchecked(&record_batch, invalid_schema_empty.into(), 0)
1679-
.unwrap()
1680-
.to_string(),
1681-
"Schema error: project index 0 out of bounds, max field 0"
1682-
);
1683-
1684-
// Wrong number of columns
1685-
let invalid_schema_more_cols = Schema::new(vec![
1686-
Field::new("a", DataType::Utf8, false),
1687-
Field::new("b", DataType::Int32, false),
1688-
]);
1689-
1690-
assert!(
1691-
apply_schema_unchecked(&record_batch, invalid_schema_more_cols.clone().into(), 0)
1692-
.is_none()
1693-
);
1694-
1695-
assert_eq!(
1696-
apply_schema_unchecked(&record_batch, invalid_schema_more_cols.into(), 1)
1697-
.unwrap()
1698-
.to_string(),
1699-
"Schema error: project index 1 out of bounds, max field 1"
1700-
);
1701-
1702-
// Wrong datatype
1703-
let invalid_schema_wrong_datatype =
1704-
Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1705-
assert_eq!(apply_schema_unchecked(&record_batch, invalid_schema_wrong_datatype.into(), 0).unwrap().to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0");
1706-
1707-
// Wrong column name. A instead C
1708-
let invalid_schema_wrong_col_name =
1709-
Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1710-
1711-
assert!(record_batch
1712-
.clone()
1713-
.with_schema_unchecked(invalid_schema_wrong_col_name.into())
1714-
.unwrap()
1715-
.column_by_name("c")
1716-
.is_none());
1717-
1718-
// Valid schema
1719-
let valid_schema = Schema::new(vec![Field::new("c", DataType::Utf8, false)]);
1720-
1721-
assert_eq!(
1722-
record_batch
1723-
.clone()
1724-
.with_schema_unchecked(valid_schema.into())
1725-
.unwrap()
1726-
.column_by_name("c"),
1727-
record_batch.column_by_name("c")
1728-
);
1729-
}
17301673
}

0 commit comments

Comments
 (0)