@@ -27,7 +27,7 @@ use crate::flight::{FlightMetadata, FlightProperties};
2727use arrow_array:: RecordBatch ;
2828use arrow_flight:: error:: FlightError ;
2929use arrow_flight:: { FlightClient , FlightEndpoint , Ticket } ;
30- use arrow_schema:: SchemaRef ;
30+ use arrow_schema:: { ArrowError , SchemaRef } ;
3131use datafusion:: arrow:: datatypes:: ToByteSlice ;
3232use datafusion:: common:: Result ;
3333use datafusion:: common:: { project_schema, DataFusionError } ;
@@ -206,35 +206,40 @@ async fn try_fetch_stream(
206206 . map_err ( |e| FlightError :: ExternalError ( Box :: new ( e) ) ) ?;
207207 let mut client = FlightClient :: new ( channel) ;
208208 client. metadata_mut ( ) . clone_from ( grpc_headers. as_ref ( ) ) ;
209- let stream = client. do_get ( ticket) . await ?;
209+ let stream = client
210+ . do_get ( ticket)
211+ . await ?
212+ . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) ) ;
210213 Ok ( Box :: pin ( RecordBatchStreamAdapter :: new (
211214 schema. clone ( ) ,
212- stream. map ( move |rb| {
213- let schema = schema. clone ( ) ;
214- rb. map ( move |rb| {
215- if schema. fields . is_empty ( ) || rb. schema ( ) == schema {
216- rb
217- } else if schema. contains ( rb. schema_ref ( ) ) {
218- rb. with_schema ( schema. clone ( ) ) . unwrap ( )
219- } else {
220- let columns = schema
221- . fields
222- . iter ( )
223- . map ( |field| {
224- rb. column_by_name ( field. name ( ) )
225- . expect ( "missing fields in record batch" )
226- . clone ( )
227- } )
228- . collect ( ) ;
229- RecordBatch :: try_new ( schema. clone ( ) , columns)
230- . expect ( "cannot impose desired schema on record batch" )
231- }
232- } )
233- . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) )
234- } ) ,
215+ stream. map ( move |item| item. and_then ( |rb| enforce_schema ( rb, & schema) . map_err ( Into :: into) ) ) ,
235216 ) ) )
236217}
237218
219+ fn enforce_schema ( rb : RecordBatch , target_schema : & SchemaRef ) -> arrow:: error:: Result < RecordBatch > {
220+ if target_schema. fields . is_empty ( ) || rb. schema ( ) == * target_schema {
221+ Ok ( rb)
222+ } else if target_schema. contains ( rb. schema_ref ( ) ) {
223+ rb. with_schema ( target_schema. clone ( ) )
224+ } else {
225+ let columns = target_schema
226+ . fields
227+ . iter ( )
228+ . map ( |field| {
229+ rb. column_by_name ( field. name ( ) )
230+ . ok_or ( ArrowError :: SchemaError ( format ! (
231+ "Required field `{}` is missing from the flight response" ,
232+ field. name( )
233+ ) ) )
234+ . and_then ( |original_array| {
235+ arrow_cast:: cast ( original_array. as_ref ( ) , field. data_type ( ) )
236+ } )
237+ } )
238+ . collect :: < Result < _ , _ > > ( ) ?;
239+ RecordBatch :: try_new ( target_schema. clone ( ) , columns)
240+ }
241+ }
242+
238243impl DisplayAs for FlightExec {
239244 fn fmt_as ( & self , t : DisplayFormatType , f : & mut Formatter ) -> std:: fmt:: Result {
240245 match t {
@@ -297,9 +302,12 @@ impl ExecutionPlan for FlightExec {
297302
298303#[ cfg( test) ]
299304mod tests {
300- use crate :: flight:: exec:: { FlightConfig , FlightPartition , FlightTicket } ;
305+ use crate :: flight:: exec:: { enforce_schema , FlightConfig , FlightPartition , FlightTicket } ;
301306 use crate :: flight:: FlightProperties ;
302- use arrow_schema:: { DataType , Field , Schema } ;
307+ use arrow_array:: {
308+ BooleanArray , Float32Array , Int32Array , RecordBatch , StringArray , StructArray ,
309+ } ;
310+ use arrow_schema:: { DataType , Field , Fields , Schema } ;
303311 use std:: collections:: HashMap ;
304312 use std:: sync:: Arc ;
305313
@@ -334,4 +342,84 @@ mod tests {
334342 let restored = serde_json:: from_slice ( json. as_slice ( ) ) . expect ( "cannot decode json config" ) ;
335343 assert_eq ! ( config, restored) ;
336344 }
345+
346+ #[ test]
347+ fn test_schema_enforcement ( ) -> arrow:: error:: Result < ( ) > {
348+ let data = StructArray :: new (
349+ Fields :: from ( vec ! [
350+ Arc :: new( Field :: new( "f_int" , DataType :: Int32 , true ) ) ,
351+ Arc :: new( Field :: new( "f_bool" , DataType :: Boolean , false ) ) ,
352+ ] ) ,
353+ vec ! [
354+ Arc :: new( Int32Array :: from( vec![ 10 , 20 ] ) ) ,
355+ Arc :: new( BooleanArray :: from( vec![ true , false ] ) ) ,
356+ ] ,
357+ None ,
358+ ) ;
359+ let input_rb = RecordBatch :: from ( data) ;
360+
361+ let empty_schema = Arc :: new ( Schema :: empty ( ) ) ;
362+ let same_rb = enforce_schema ( input_rb. clone ( ) , & empty_schema) ?;
363+ assert_eq ! ( input_rb, same_rb) ;
364+
365+ let coerced_rb = enforce_schema (
366+ input_rb. clone ( ) ,
367+ & Arc :: new ( Schema :: new ( vec ! [
368+ // compatible yet different types with flipped nullability
369+ Arc :: new( Field :: new( "f_int" , DataType :: Float32 , false ) ) ,
370+ Arc :: new( Field :: new( "f_bool" , DataType :: Utf8 , true ) ) ,
371+ ] ) ) ,
372+ ) ?;
373+ assert_ne ! ( input_rb, coerced_rb) ;
374+ assert_eq ! ( coerced_rb. num_columns( ) , 2 ) ;
375+ assert_eq ! ( coerced_rb. num_rows( ) , 2 ) ;
376+ assert_eq ! (
377+ coerced_rb. column( 0 ) . as_ref( ) ,
378+ & Float32Array :: from( vec![ 10.0 , 20.0 ] )
379+ ) ;
380+ assert_eq ! (
381+ coerced_rb. column( 1 ) . as_ref( ) ,
382+ & StringArray :: from( vec![ "true" , "false" ] )
383+ ) ;
384+
385+ let projection_rb = enforce_schema (
386+ input_rb. clone ( ) ,
387+ & Arc :: new ( Schema :: new ( vec ! [
388+ // keep only the first column and make it non-nullable int16
389+ Arc :: new( Field :: new( "f_int" , DataType :: Int16 , false ) ) ,
390+ ] ) ) ,
391+ ) ?;
392+ assert_eq ! ( projection_rb. num_columns( ) , 1 ) ;
393+ assert_eq ! ( projection_rb. num_rows( ) , 2 ) ;
394+ assert_eq ! ( projection_rb. schema( ) . fields( ) . len( ) , 1 ) ;
395+ assert_eq ! ( projection_rb. schema( ) . fields( ) [ 0 ] . name( ) , "f_int" ) ;
396+
397+ let incompatible_schema_attempt = enforce_schema (
398+ input_rb. clone ( ) ,
399+ & Arc :: new ( Schema :: new ( vec ! [
400+ Arc :: new( Field :: new( "f_int" , DataType :: Float32 , true ) ) ,
401+ Arc :: new( Field :: new( "f_bool" , DataType :: Date32 , false ) ) ,
402+ ] ) ) ,
403+ ) ;
404+ assert ! ( incompatible_schema_attempt. is_err( ) ) ;
405+ assert_eq ! (
406+ incompatible_schema_attempt. unwrap_err( ) . to_string( ) ,
407+ "Cast error: Casting from Boolean to Date32 not supported"
408+ ) ;
409+
410+ let broader_schema_attempt = enforce_schema (
411+ input_rb. clone ( ) ,
412+ & Arc :: new ( Schema :: new ( vec ! [
413+ Arc :: new( Field :: new( "f_int" , DataType :: Int32 , true ) ) ,
414+ Arc :: new( Field :: new( "f_bool" , DataType :: Boolean , false ) ) ,
415+ Arc :: new( Field :: new( "f_extra" , DataType :: Utf8 , true ) ) ,
416+ ] ) ) ,
417+ ) ;
418+ assert ! ( broader_schema_attempt. is_err( ) ) ;
419+ assert_eq ! (
420+ broader_schema_attempt. unwrap_err( ) . to_string( ) ,
421+ "Schema error: Required field `f_extra` is missing from the flight response"
422+ ) ;
423+ Ok ( ( ) )
424+ }
337425}
0 commit comments