Skip to content

Commit 29f6608

Browse files
authored
fix(DuckDB): replace cast_batch_to_schema with federation's try_cast_to (#653)
1 parent 846d4de commit 29f6608

5 files changed

Lines changed: 7 additions & 133 deletions

File tree

core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,7 @@ duckdb = [
121121
"dep:arrow-schema",
122122
"dep:byte-unit",
123123
"dep:datafusion-physical-expr",
124+
"federation",
124125
]
125126
duckdb-federation = ["duckdb", "federation"]
126127
federation = ["dep:datafusion-federation"]

core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ use snafu::{prelude::*, ResultExt};
2323
use tokio::sync::mpsc::Sender;
2424

2525
use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
26-
use crate::util::arrow::cast_batch_to_schema;
2726
use crate::util::schema::SchemaValidator;
2827
use crate::UnsupportedTypeAction;
2928

@@ -507,9 +506,11 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
507506
let output_stream = stream! {
508507
while let Some(batch) = batch_rx.recv().await {
509508
if let Some(ref target_schema) = projected_schema {
510-
match cast_batch_to_schema(&batch, target_schema) {
509+
match datafusion_federation::schema_cast::record_convert::try_cast_to(batch, Arc::clone(target_schema)) {
511510
Ok(casted) => yield Ok(casted),
512-
Err(e) => yield Err(e),
511+
Err(e) => yield Err(DataFusionError::Execution(format!(
512+
"Failed to cast DuckDB result batch to projected schema: {e}"
513+
))),
513514
}
514515
} else {
515516
yield Ok(batch);

core/src/util/arrow.rs

Lines changed: 0 additions & 125 deletions
This file was deleted.

core/src/util/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@ use std::collections::HashMap;
66

77
use crate::UnsupportedTypeAction;
88

9-
pub mod arrow;
109
pub mod column_reference;
1110
pub mod constraints;
1211
pub mod count_exec;

core/tests/duckdb/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,15 +8,14 @@ use datafusion::execution::context::SessionContext;
88
use datafusion::logical_expr::dml::InsertOp;
99
use datafusion::logical_expr::CreateExternalTable;
1010
use datafusion::physical_plan::collect;
11-
use datafusion_federation::schema_cast::record_convert::try_cast_to;
1211
use datafusion_table_providers::duckdb::DuckDBTableProviderFactory;
1312
use rstest::rstest;
1413
use std::collections::HashMap;
1514
use std::sync::Arc;
1615

1716
async fn arrow_duckdb_round_trip(
1817
arrow_record: RecordBatch,
19-
source_schema: SchemaRef,
18+
_source_schema: SchemaRef,
2019
table_name: &str,
2120
) {
2221
let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
@@ -68,7 +67,6 @@ async fn arrow_duckdb_round_trip(
6867
.expect("DataFrame should be created from query");
6968

7069
let record_batch = df.collect().await.expect("RecordBatch should be collected");
71-
let casted_record = try_cast_to(record_batch[0].clone(), source_schema).unwrap();
7270

7371
tracing::debug!("Original Arrow Record Batch: {:?}", arrow_record.columns());
7472
tracing::debug!(
@@ -80,7 +78,7 @@ async fn arrow_duckdb_round_trip(
8078
assert_eq!(record_batch.len(), 1);
8179
assert_eq!(record_batch[0].num_rows(), arrow_record.num_rows());
8280
assert_eq!(record_batch[0].num_columns(), arrow_record.num_columns());
83-
assert_eq!(casted_record, arrow_record);
81+
assert_eq!(record_batch[0], arrow_record);
8482
}
8583

8684
#[rstest]

0 commit comments

Comments
 (0)