@@ -359,6 +359,18 @@ impl RecordBatch {
359359 } )
360360 }
361361
362+ /// Overrides the schema of this [`RecordBatch`]
363+ /// without additional schema checks. Note, however, that this pushes all the schema compatibility responsibilities
364+ /// to the caller site. In particular, the caller guarantees that `schema` is a superset
365+ /// of the current schema as determined by [`Schema::contains`].
366+ pub fn with_schema_unchecked ( self , schema : SchemaRef ) -> Result < Self , ArrowError > {
367+ Ok ( Self {
368+ schema,
369+ columns : self . columns ,
370+ row_count : self . row_count ,
371+ } )
372+ }
373+
362374 /// Returns the [`Schema`] of the record batch.
363375 pub fn schema ( & self ) -> SchemaRef {
364376 self . schema . clone ( )
@@ -744,12 +756,14 @@ impl RecordBatchOptions {
744756 row_count : None ,
745757 }
746758 }
747- /// Sets the row_count of RecordBatchOptions and returns self
759+
760+ /// Sets the `row_count` of `RecordBatchOptions` and returns this [`RecordBatch`]
748761 pub fn with_row_count ( mut self , row_count : Option < usize > ) -> Self {
749762 self . row_count = row_count;
750763 self
751764 }
752- /// Sets the match_field_names of RecordBatchOptions and returns self
765+
766+ /// Sets the `match_field_names` of [`RecordBatchOptions`] and returns self
753767 pub fn with_match_field_names ( mut self , match_field_names : bool ) -> Self {
754768 self . match_field_names = match_field_names;
755769 self
@@ -1637,4 +1651,80 @@ mod tests {
16371651 "bar"
16381652 ) ;
16391653 }
1654+
#[test]
fn test_batch_with_unchecked_schema() {
    /// Overrides the schema on a copy of `batch` without validation, then
    /// projects column `column_idx` and returns the projection error, if any.
    fn projection_error_after_override(
        batch: &RecordBatch,
        schema: SchemaRef,
        column_idx: usize,
    ) -> Option<ArrowError> {
        let overridden = batch.clone().with_schema_unchecked(schema).unwrap();
        overridden.project(&[column_idx]).err()
    }

    let strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
    let batch = RecordBatch::try_from_iter(vec![("c", strings)]).expect("valid conversion");

    // An empty schema on a batch that has a column: projecting the column fails.
    let empty_schema: SchemaRef = Arc::new(Schema::empty());
    let empty_schema_err = projection_error_after_override(&batch, empty_schema, 0).unwrap();
    assert_eq!(
        empty_schema_err.to_string(),
        "Schema error: project index 0 out of bounds, max field 0"
    );

    // A schema with more fields than the batch has columns.
    let two_field_schema: SchemaRef = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Utf8, false),
        Field::new("b", DataType::Int32, false),
    ]));

    // The first field lines up with the only column, so projecting it succeeds.
    assert!(projection_error_after_override(&batch, two_field_schema.clone(), 0).is_none());

    // The second field has no backing column.
    let missing_column_err = projection_error_after_override(&batch, two_field_schema, 1).unwrap();
    assert_eq!(
        missing_column_err.to_string(),
        "Schema error: project index 1 out of bounds, max field 1"
    );

    // A schema whose data type differs from the column's data type.
    let int_schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let type_mismatch_err = projection_error_after_override(&batch, int_schema, 0).unwrap();
    assert_eq!(
        type_mismatch_err.to_string(),
        "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0"
    );

    // A schema that names the column "a" instead of "c": lookup by the old name finds nothing.
    let renamed_schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
    let renamed_batch = batch.clone().with_schema_unchecked(renamed_schema).unwrap();
    assert!(renamed_batch.column_by_name("c").is_none());

    // A schema that matches the batch: the column is still reachable by name.
    let matching_schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("c", DataType::Utf8, false)]));
    let matching_batch = batch.clone().with_schema_unchecked(matching_schema).unwrap();
    assert_eq!(
        matching_batch.column_by_name("c"),
        batch.column_by_name("c")
    );
}
16401730}
0 commit comments