Skip to content

Commit 9f0b9d1

Browse files
committed
fix(duckdb): cast query_arrow results to projected_schema
DuckDB's query_arrow ignored the projected_schema parameter, returning batches with DuckDB's native types (e.g. Timestamp(µs)) even when the caller expected different types (e.g. Timestamp(ns)). This caused schema mismatches for downstream operators pushed below SchemaCastScanExec. Cast result batches to projected_schema in the output stream when types differ. Add shared cast_batch_to_schema utility in util/arrow.rs for reuse by other Arrow-native connectors (ADBC, ODBC).
1 parent 040aa83 commit 9f0b9d1

3 files changed

Lines changed: 205 additions & 3 deletions

File tree

core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use snafu::{prelude::*, ResultExt};
2323
use tokio::sync::mpsc::Sender;
2424

2525
use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
26+
use crate::util::arrow::cast_batch_to_schema;
2627
use crate::util::schema::SchemaValidator;
2728
use crate::UnsupportedTypeAction;
2829

@@ -453,7 +454,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
453454
&self,
454455
sql: &str,
455456
params: &[DuckDBParameter],
456-
_projected_schema: Option<SchemaRef>,
457+
projected_schema: Option<SchemaRef>,
457458
) -> Result<SendableRecordBatchStream> {
458459
let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
459460

@@ -501,9 +502,18 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
501502
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
502503
});
503504

505+
let stream_schema = projected_schema.clone().unwrap_or(schema);
506+
504507
let output_stream = stream! {
505508
while let Some(batch) = batch_rx.recv().await {
506-
yield Ok(batch);
509+
if let Some(ref target_schema) = projected_schema {
510+
match cast_batch_to_schema(&batch, target_schema) {
511+
Ok(casted) => yield Ok(casted),
512+
Err(e) => yield Err(e),
513+
}
514+
} else {
515+
yield Ok(batch);
516+
}
507517
}
508518

509519
match join_handle.await {
@@ -522,7 +532,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
522532
};
523533

524534
Ok(Box::pin(RecordBatchStreamAdapter::new(
525-
schema,
535+
stream_schema,
526536
output_stream,
527537
)))
528538
};
@@ -945,4 +955,70 @@ mod tests {
945955

946956
Ok(())
947957
}
958+
959+
#[test]
fn test_query_arrow_casts_to_projected_schema() {
    use arrow::datatypes::{Schema, TimeUnit};
    use futures::StreamExt;

    use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
    use crate::sql::db_connection_pool::DbConnectionPool;

    // End-to-end check: a TIMESTAMPTZ column read through `query_arrow` must come
    // back in the type requested by `projected_schema`, both in the stream's
    // advertised schema and in the batches it actually yields.
    let runtime = tokio::runtime::Runtime::new().expect("runtime");
    runtime.block_on(async {
        let pool = DuckDbConnectionPool::new_memory().expect("pool created");
        let pooled = pool.connect().await.expect("connection");
        let conn = pooled.as_sync().expect("sync connection");

        // Source table uses TIMESTAMPTZ; the projection below asks for a
        // nanosecond timestamp instead of whatever DuckDB returns natively.
        conn.execute(
            "CREATE TABLE test_ts (id INTEGER, created_at TIMESTAMPTZ)",
            &[],
        )
        .expect("table created");
        conn.execute(
            "INSERT INTO test_ts VALUES (1, '2023-01-01T00:00:00Z')",
            &[],
        )
        .expect("data inserted");

        // The type the caller wants for `created_at`.
        let nanos_utc = DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));
        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("created_at", nanos_utc.clone(), true),
        ]));

        let stream = conn
            .query_arrow(
                "SELECT id, created_at FROM test_ts",
                &[],
                Some(Arc::clone(&projected_schema)),
            )
            .expect("query_arrow should succeed");

        // The stream must advertise the projected schema before any batch is read.
        assert_eq!(stream.schema(), projected_schema);

        // Drain the stream, failing on the first errored batch.
        let batches: Vec<_> = stream
            .map(|item| item.expect("batch should be Ok"))
            .collect()
            .await;

        assert_eq!(batches.len(), 1);
        let only_batch = &batches[0];
        assert_eq!(only_batch.schema(), projected_schema);

        // The timestamp column itself must carry the requested nanosecond type.
        assert_eq!(only_batch.column(1).data_type(), &nanos_utc);
    });
}
9481024
}

core/src/util/arrow.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::sync::Arc;
2+
3+
use arrow::array::RecordBatch;
4+
use arrow::compute::cast;
5+
use datafusion::arrow::datatypes::SchemaRef;
6+
use datafusion::error::DataFusionError;
7+
8+
/// Cast a `RecordBatch` to match the target schema, casting columns whose types differ.
9+
/// Columns that already match are passed through unchanged.
10+
pub fn cast_batch_to_schema(
11+
batch: &RecordBatch,
12+
target_schema: &SchemaRef,
13+
) -> Result<RecordBatch, DataFusionError> {
14+
let columns: Vec<_> = batch
15+
.columns()
16+
.iter()
17+
.zip(target_schema.fields())
18+
.map(|(col, target_field)| {
19+
if col.data_type() == target_field.data_type() {
20+
Ok(Arc::clone(col))
21+
} else {
22+
cast(col, target_field.data_type()).map_err(|e| {
23+
DataFusionError::Execution(format!(
24+
"Failed to cast column '{}' from {:?} to {:?}: {e}",
25+
target_field.name(),
26+
col.data_type(),
27+
target_field.data_type(),
28+
))
29+
})
30+
}
31+
})
32+
.collect::<Result<_, _>>()?;
33+
34+
RecordBatch::try_new(Arc::clone(target_schema), columns).map_err(|e| {
35+
DataFusionError::Execution(format!("Failed to create RecordBatch after cast: {e}"))
36+
})
37+
}
38+
39+
#[cfg(test)]
40+
mod tests {
41+
use super::*;
42+
use arrow::array::{Int32Array, TimestampMicrosecondArray, TimestampNanosecondArray};
43+
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
44+
45+
#[test]
46+
fn test_cast_timestamp_us_to_ns() {
47+
let source_schema = Arc::new(Schema::new(vec![
48+
Field::new("id", DataType::Int32, false),
49+
Field::new(
50+
"ts",
51+
DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
52+
false,
53+
),
54+
]));
55+
56+
let target_schema = Arc::new(Schema::new(vec![
57+
Field::new("id", DataType::Int32, false),
58+
Field::new(
59+
"ts",
60+
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
61+
false,
62+
),
63+
]));
64+
65+
let batch = RecordBatch::try_new(
66+
source_schema,
67+
vec![
68+
Arc::new(Int32Array::from(vec![1, 2, 3])),
69+
Arc::new(
70+
TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000])
71+
.with_timezone("UTC"),
72+
),
73+
],
74+
)
75+
.unwrap();
76+
77+
let result = cast_batch_to_schema(&batch, &target_schema).unwrap();
78+
assert_eq!(result.schema(), target_schema);
79+
assert_eq!(
80+
result.column(1).data_type(),
81+
&DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
82+
);
83+
84+
// Values should be multiplied by 1000
85+
let ts_col = result
86+
.column(1)
87+
.as_any()
88+
.downcast_ref::<TimestampNanosecondArray>()
89+
.unwrap();
90+
assert_eq!(ts_col.value(0), 1_000_000_000);
91+
assert_eq!(ts_col.value(1), 2_000_000_000);
92+
assert_eq!(ts_col.value(2), 3_000_000_000);
93+
}
94+
95+
#[test]
96+
fn test_no_cast_when_types_match() {
97+
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
98+
99+
let batch = RecordBatch::try_new(
100+
Arc::clone(&schema),
101+
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
102+
)
103+
.unwrap();
104+
105+
let result = cast_batch_to_schema(&batch, &schema).unwrap();
106+
assert_eq!(result.schema(), schema);
107+
}
108+
109+
#[test]
110+
fn test_cast_incompatible_types_returns_error() {
111+
use arrow::array::StringArray;
112+
113+
let source_schema = Arc::new(Schema::new(vec![Field::new("val", DataType::Utf8, false)]));
114+
let target_schema = Arc::new(Schema::new(vec![Field::new("val", DataType::Int32, false)]));
115+
116+
let batch = RecordBatch::try_new(
117+
source_schema,
118+
vec![Arc::new(StringArray::from(vec!["not_a_number"]))],
119+
)
120+
.unwrap();
121+
122+
let result = cast_batch_to_schema(&batch, &target_schema);
123+
assert!(result.is_err());
124+
}
125+
}

core/src/util/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::collections::HashMap;
66

77
use crate::UnsupportedTypeAction;
88

9+
pub mod arrow;
910
pub mod column_reference;
1011
pub mod constraints;
1112
pub mod count_exec;

0 commit comments

Comments
 (0)