73 changes: 71 additions & 2 deletions arrow-array/src/record_batch.rs
@@ -359,6 +359,20 @@ impl RecordBatch {
})
}

/// Forcibly overrides the schema of this [`RecordBatch`]
/// without additional schema checks however bringing all the schema compatibility responsibilities
/// to the caller site.
Contributor:
Suggested change
/// Forcibly overrides the schema of this [`RecordBatch`]
/// without additional schema checks however bringing all the schema compatibility responsibilities
/// to the caller site.
/// Overrides the schema of this [`RecordBatch`]
/// without additional schema checks. Note, however, that this pushes all the schema compatibility responsibilities
/// to the caller site. In particular, the caller guarantees that `schema` is a superset
/// of the current schema as determined by [`Schema::contains`].

Contributor Author:
Nicely said!

///
/// If the provided schema is not compatible with the columns of this [`RecordBatch`],
/// the runtime behavior is undefined
pub fn with_schema_force(self, schema: SchemaRef) -> Result<Self, ArrowError> {
Contributor:
Maybe with_schema_unchecked would be better?

Ok(Self {
schema,
columns: self.columns,
row_count: self.row_count,
})
}
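For reference, a minimal caller-side sketch of the compatibility check the suggested docs describe, assuming the method keeps the `with_schema_force` name from this diff; the `override_schema` helper is hypothetical and not part of the PR:

use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, SchemaRef};

// Hypothetical guard: verify the new schema is a superset of the batch's
// schema (per `Schema::contains`) before forcibly overriding it.
fn override_schema(batch: RecordBatch, new_schema: SchemaRef) -> Result<RecordBatch, ArrowError> {
    if !new_schema.contains(batch.schema().as_ref()) {
        return Err(ArrowError::SchemaError(
            "new schema is not a superset of the batch schema".to_string(),
        ));
    }
    batch.with_schema_force(new_schema)
}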

/// Returns the [`Schema`] of the record batch.
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
@@ -744,12 +758,14 @@ impl RecordBatchOptions {
row_count: None,
}
}
/// Sets the row_count of RecordBatchOptions and returns self

/// Sets the `row_count` of [`RecordBatchOptions`] and returns self
pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
self.row_count = row_count;
self
}
/// Sets the match_field_names of RecordBatchOptions and returns self

/// Sets the `match_field_names` of [`RecordBatchOptions`] and returns self
pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
self.match_field_names = match_field_names;
self
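As a usage note (not part of this diff), a hedged sketch of how these builder methods are typically combined with `RecordBatch::try_new_with_options`, for example to build a column-less batch that still carries a row count:

use std::sync::Arc;
use arrow_array::{RecordBatch, RecordBatchOptions};
use arrow_schema::{ArrowError, Schema};

fn empty_batch_with_rows() -> Result<RecordBatch, ArrowError> {
    // The row count must be given explicitly because there are no columns to infer it from.
    let options = RecordBatchOptions::new()
        .with_row_count(Some(3))
        .with_match_field_names(true);
    RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options)
}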
@@ -1637,4 +1653,57 @@ mod tests {
"bar"
);
}

#[test]
fn test_batch_with_force_schema() {
Contributor:
Could you please also add a check where the forced schema succeeds?
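For illustration only, a rough sketch of the success case being asked for, assuming the method keeps the `with_schema_force` name from this diff; this snippet is not part of the PR:

// Hypothetical success case: build a one-column batch, force a compatible
// schema onto it, and check that downstream operations still work.
let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
let batch = RecordBatch::try_from_iter(vec![("c", c)]).unwrap();
let ok_schema = Schema::new(vec![Field::new("c", DataType::Utf8, true)]);
let forced = batch.with_schema_force(ok_schema.into()).unwrap();
assert_eq!(forced.num_columns(), 1);
assert!(forced.project(&[0]).is_ok());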

fn force_schema_and_get_err_from_batch(
record_batch: &RecordBatch,
schema_ref: SchemaRef,
idx: usize,
) -> Option<ArrowError> {
record_batch
.clone()
.with_schema_force(schema_ref)
.unwrap()
.project(&[idx])
.err()
}

let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

let record_batch =
RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

// Test empty schema for non-empty schema batch
let invalid_schema_empty = Schema::empty();
assert_eq!(
force_schema_and_get_err_from_batch(&record_batch, invalid_schema_empty.into(), 0)
.unwrap()
.to_string(),
"Schema error: project index 0 out of bounds, max field 0"
);

// Wrong number of columns
let invalid_schema_more_cols = Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("a", DataType::Int32, false),
Contributor:
Suggested change
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),

?. This triggers my OCD 😅

]);
assert!(force_schema_and_get_err_from_batch(
&record_batch,
invalid_schema_more_cols.clone().into(),
0
)
.is_none());
assert_eq!(
force_schema_and_get_err_from_batch(&record_batch, invalid_schema_more_cols.into(), 1)
.unwrap()
.to_string(),
"Schema error: project index 1 out of bounds, max field 1"
);

// Wrong datatype
let invalid_schema_wrong_datatype =
Schema::new(vec![Field::new("a", DataType::Int32, false)]);
assert_eq!(
force_schema_and_get_err_from_batch(&record_batch, invalid_schema_wrong_datatype.into(), 0)
.unwrap()
.to_string(),
"Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0"
);
}
}