Skip to content

Commit 2c2ad24

Browse files
committed
[flight] Coerce flight data into target schema if needed
1 parent 58531df commit 2c2ad24

4 files changed

Lines changed: 121 additions & 34 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ description = "Extend the capabilities of DataFusion to support additional data
1010
[dependencies]
1111
arrow = "52.2.0"
1212
arrow-array = { version = "52.2.0", optional = true }
13+
arrow-cast = { version = "52.2.0", optional = true }
1314
arrow-flight = { version = "52.2.0", optional = true, features = ["flight-sql-experimental", "tls"] }
1415
arrow-schema = { version = "52.2.0", optional = true, features = ["serde"] }
1516
arrow-json = "52.2.0"
@@ -83,6 +84,7 @@ sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"]
8384
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid"]
8485
flight = [
8586
"dep:arrow-array",
87+
"dep:arrow-cast",
8688
"dep:arrow-flight",
8789
"dep:arrow-schema",
8890
"dep:base64",

src/flight/exec.rs

Lines changed: 115 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use crate::flight::{FlightMetadata, FlightProperties};
2727
use arrow_array::RecordBatch;
2828
use arrow_flight::error::FlightError;
2929
use arrow_flight::{FlightClient, FlightEndpoint, Ticket};
30-
use arrow_schema::SchemaRef;
30+
use arrow_schema::{ArrowError, SchemaRef};
3131
use datafusion::arrow::datatypes::ToByteSlice;
3232
use datafusion::common::Result;
3333
use datafusion::common::{project_schema, DataFusionError};
@@ -206,35 +206,40 @@ async fn try_fetch_stream(
206206
.map_err(|e| FlightError::ExternalError(Box::new(e)))?;
207207
let mut client = FlightClient::new(channel);
208208
client.metadata_mut().clone_from(grpc_headers.as_ref());
209-
let stream = client.do_get(ticket).await?;
209+
let stream = client
210+
.do_get(ticket)
211+
.await?
212+
.map_err(|e| DataFusionError::External(Box::new(e)));
210213
Ok(Box::pin(RecordBatchStreamAdapter::new(
211214
schema.clone(),
212-
stream.map(move |rb| {
213-
let schema = schema.clone();
214-
rb.map(move |rb| {
215-
if schema.fields.is_empty() || rb.schema() == schema {
216-
rb
217-
} else if schema.contains(rb.schema_ref()) {
218-
rb.with_schema(schema.clone()).unwrap()
219-
} else {
220-
let columns = schema
221-
.fields
222-
.iter()
223-
.map(|field| {
224-
rb.column_by_name(field.name())
225-
.expect("missing fields in record batch")
226-
.clone()
227-
})
228-
.collect();
229-
RecordBatch::try_new(schema.clone(), columns)
230-
.expect("cannot impose desired schema on record batch")
231-
}
232-
})
233-
.map_err(|e| DataFusionError::External(Box::new(e)))
234-
}),
215+
stream.map(move |item| item.and_then(|rb| enforce_schema(rb, &schema).map_err(Into::into))),
235216
)))
236217
}
237218

219+
fn enforce_schema(rb: RecordBatch, target_schema: &SchemaRef) -> arrow::error::Result<RecordBatch> {
220+
if target_schema.fields.is_empty() || rb.schema() == *target_schema {
221+
Ok(rb)
222+
} else if target_schema.contains(rb.schema_ref()) {
223+
rb.with_schema(target_schema.clone())
224+
} else {
225+
let columns = target_schema
226+
.fields
227+
.iter()
228+
.map(|field| {
229+
rb.column_by_name(field.name())
230+
.ok_or(ArrowError::SchemaError(format!(
231+
"Required field `{}` is missing from the flight response",
232+
field.name()
233+
)))
234+
.and_then(|original_array| {
235+
arrow_cast::cast(original_array.as_ref(), field.data_type())
236+
})
237+
})
238+
.collect::<Result<_, _>>()?;
239+
RecordBatch::try_new(target_schema.clone(), columns)
240+
}
241+
}
242+
238243
impl DisplayAs for FlightExec {
239244
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
240245
match t {
@@ -297,9 +302,12 @@ impl ExecutionPlan for FlightExec {
297302

298303
#[cfg(test)]
299304
mod tests {
300-
use crate::flight::exec::{FlightConfig, FlightPartition, FlightTicket};
305+
use crate::flight::exec::{enforce_schema, FlightConfig, FlightPartition, FlightTicket};
301306
use crate::flight::FlightProperties;
302-
use arrow_schema::{DataType, Field, Schema};
307+
use arrow_array::{
308+
BooleanArray, Float32Array, Int32Array, RecordBatch, StringArray, StructArray,
309+
};
310+
use arrow_schema::{DataType, Field, Fields, Schema};
303311
use std::collections::HashMap;
304312
use std::sync::Arc;
305313

@@ -334,4 +342,84 @@ mod tests {
334342
let restored = serde_json::from_slice(json.as_slice()).expect("cannot decode json config");
335343
assert_eq!(config, restored);
336344
}
345+
346+
#[test]
347+
fn test_schema_enforcement() -> arrow::error::Result<()> {
348+
let data = StructArray::new(
349+
Fields::from(vec![
350+
Arc::new(Field::new("f_int", DataType::Int32, true)),
351+
Arc::new(Field::new("f_bool", DataType::Boolean, false)),
352+
]),
353+
vec![
354+
Arc::new(Int32Array::from(vec![10, 20])),
355+
Arc::new(BooleanArray::from(vec![true, false])),
356+
],
357+
None,
358+
);
359+
let input_rb = RecordBatch::from(data);
360+
361+
let empty_schema = Arc::new(Schema::empty());
362+
let same_rb = enforce_schema(input_rb.clone(), &empty_schema)?;
363+
assert_eq!(input_rb, same_rb);
364+
365+
let coerced_rb = enforce_schema(
366+
input_rb.clone(),
367+
&Arc::new(Schema::new(vec![
368+
// compatible yet different types with flipped nullability
369+
Arc::new(Field::new("f_int", DataType::Float32, false)),
370+
Arc::new(Field::new("f_bool", DataType::Utf8, true)),
371+
])),
372+
)?;
373+
assert_ne!(input_rb, coerced_rb);
374+
assert_eq!(coerced_rb.num_columns(), 2);
375+
assert_eq!(coerced_rb.num_rows(), 2);
376+
assert_eq!(
377+
coerced_rb.column(0).as_ref(),
378+
&Float32Array::from(vec![10.0, 20.0])
379+
);
380+
assert_eq!(
381+
coerced_rb.column(1).as_ref(),
382+
&StringArray::from(vec!["true", "false"])
383+
);
384+
385+
let projection_rb = enforce_schema(
386+
input_rb.clone(),
387+
&Arc::new(Schema::new(vec![
388+
// keep only the first column and make it non-nullable int16
389+
Arc::new(Field::new("f_int", DataType::Int16, false)),
390+
])),
391+
)?;
392+
assert_eq!(projection_rb.num_columns(), 1);
393+
assert_eq!(projection_rb.num_rows(), 2);
394+
assert_eq!(projection_rb.schema().fields().len(), 1);
395+
assert_eq!(projection_rb.schema().fields()[0].name(), "f_int");
396+
397+
let incompatible_schema_attempt = enforce_schema(
398+
input_rb.clone(),
399+
&Arc::new(Schema::new(vec![
400+
Arc::new(Field::new("f_int", DataType::Float32, true)),
401+
Arc::new(Field::new("f_bool", DataType::Date32, false)),
402+
])),
403+
);
404+
assert!(incompatible_schema_attempt.is_err());
405+
assert_eq!(
406+
incompatible_schema_attempt.unwrap_err().to_string(),
407+
"Cast error: Casting from Boolean to Date32 not supported"
408+
);
409+
410+
let broader_schema_attempt = enforce_schema(
411+
input_rb.clone(),
412+
&Arc::new(Schema::new(vec![
413+
Arc::new(Field::new("f_int", DataType::Int32, true)),
414+
Arc::new(Field::new("f_bool", DataType::Boolean, false)),
415+
Arc::new(Field::new("f_extra", DataType::Utf8, true)),
416+
])),
417+
);
418+
assert!(broader_schema_attempt.is_err());
419+
assert_eq!(
420+
broader_schema_attempt.unwrap_err().to_string(),
421+
"Schema error: Required field `f_extra` is missing from the flight response"
422+
);
423+
Ok(())
424+
}
337425
}

src/flight/sql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ impl FlightDriver for FlightSqlDriver {
6565
key.strip_prefix(HEADER_PREFIX)
6666
.map(|header_name| (header_name, value))
6767
});
68-
for header in headers {
69-
client.set_header(header.0, header.1)
68+
for (name, value) in headers {
69+
client.set_header(name, value)
7070
}
7171
if let Some(username) = options.get(USERNAME) {
7272
let default_password = "".to_string();

tests/flight/mod.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ async fn test_flight_sql_data_source() -> datafusion::common::Result<()> {
161161
Arc::new(Float32Array::from(vec![0.0, 0.1, 0.2, 0.3])),
162162
Arc::new(Int8Array::from(vec![10, 20, 30, 40])),
163163
],
164-
)
165-
.unwrap();
164+
)?;
166165
let rows_per_partition = partition_data.num_rows();
167166

168167
let query = "SELECT * FROM some_table";
@@ -174,9 +173,7 @@ async fn test_flight_sql_data_source() -> datafusion::common::Result<()> {
174173
endpoint_archetype,
175174
];
176175
let num_partitions = endpoints.len();
177-
let flight_info = FlightInfo::default()
178-
.try_with_schema(partition_data.schema().as_ref())
179-
.unwrap();
176+
let flight_info = FlightInfo::default().try_with_schema(partition_data.schema().as_ref())?;
180177
let flight_info = endpoints
181178
.into_iter()
182179
.fold(flight_info, |fi, e| fi.with_endpoint(e));

0 commit comments

Comments
 (0)