Skip to content

Commit 846d4de

Browse files
authored
fix(duckdb): cast query_arrow results to projected_schema (#652)
* fix(duckdb): cast query_arrow results to projected_schema

  DuckDB's query_arrow ignored the projected_schema parameter, returning batches with DuckDB's native types (e.g. Timestamp(µs)) even when the caller expected different types (e.g. Timestamp(ns)). This caused schema mismatches for downstream operators pushed below SchemaCastScanExec.

  Cast result batches to projected_schema in the output stream when types differ. Add a shared cast_batch_to_schema utility in util/arrow.rs for reuse by other Arrow-native connectors (ADBC, ODBC).

* Revert "fix(duckdb): use actual DuckDB schema for read provider (#650)"

  This reverts commit 040aa83.
1 parent 040aa83 commit 846d4de

4 files changed

Lines changed: 216 additions & 78 deletions

File tree

core/src/duckdb.rs

Lines changed: 11 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,15 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
485485
.with_indexes(indexes.clone());
486486

487487
let pool = Arc::new(pool);
488-
make_initial_table(Arc::new(table_definition), &pool)?;
488+
make_initial_table(Arc::new(table_definition.clone()), &pool)?;
489+
490+
let write_settings = DuckDBWriteSettings::from_params(&options);
491+
492+
let table_writer_builder = DuckDBTableWriterBuilder::new()
493+
.with_table_definition(table_definition)
494+
.with_pool(pool)
495+
.set_on_conflict(on_conflict)
496+
.with_write_settings(write_settings);
489497

490498
let dyn_pool: Arc<DynDuckDbConnectionPool> = Arc::new(read_pool);
491499

@@ -502,28 +510,9 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
502510
self.settings_registry
503511
.apply_settings(conn, &options, DuckDBSettingScope::Global)?;
504512

505-
// Read actual DuckDB schema after table creation (may differ from cmd.schema).
506-
let schema_conn = dyn_pool.connect().await?;
507-
let schema = get_schema(schema_conn, &TableReference::bare(name.clone()))
508-
.await
509-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
510-
511-
let table_definition =
512-
TableDefinition::new(RelationName::new(name.clone()), Arc::clone(&schema))
513-
.with_constraints(cmd.constraints.clone())
514-
.with_indexes(indexes.clone());
515-
516-
let write_settings = DuckDBWriteSettings::from_params(&options);
517-
518-
let table_writer_builder = DuckDBTableWriterBuilder::new()
519-
.with_table_definition(table_definition)
520-
.with_pool(pool)
521-
.set_on_conflict(on_conflict)
522-
.with_write_settings(write_settings);
523-
524513
let read_provider = Arc::new(DuckDBTable::new_with_schema(
525514
&dyn_pool,
526-
schema,
515+
Arc::clone(&schema),
527516
TableReference::bare(name.clone()),
528517
None,
529518
Some(self.dialect.clone()),
@@ -809,7 +798,7 @@ pub(crate) mod tests {
809798
use crate::duckdb::write::DuckDBTableWriter;
810799

811800
use super::*;
812-
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
801+
use arrow::datatypes::{DataType, Field, Schema};
813802
use datafusion::common::{Constraints, ToDFSchema};
814803
use datafusion::logical_expr::CreateExternalTable;
815804
use datafusion::prelude::SessionContext;
@@ -1128,57 +1117,4 @@ pub(crate) mod tests {
11281117
assert_eq!(e.to_string(), "External error: Query execution failed.\nInvalid Input Error: Failed to cast value: Could not convert string 'invalid' to BOOL\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/");
11291118
}
11301119
}
1131-
1132-
/// Verifies the read provider advertises actual DuckDB storage types,
1133-
/// not the requested cmd.schema types.
1134-
#[tokio::test]
1135-
async fn test_read_provider_schema_reflects_actual_duckdb_types() {
1136-
let table_name = TableReference::bare("test_timestamp_schema");
1137-
let schema = Schema::new(vec![
1138-
Field::new("id", DataType::Int32, false),
1139-
Field::new(
1140-
"created_at",
1141-
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
1142-
false,
1143-
),
1144-
]);
1145-
1146-
let mut options = HashMap::new();
1147-
options.insert("mode".to_string(), "memory".to_string());
1148-
1149-
let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
1150-
let ctx = SessionContext::new();
1151-
let cmd = CreateExternalTable {
1152-
schema: Arc::new(schema.to_dfschema().expect("to df schema")),
1153-
name: table_name,
1154-
location: "".to_string(),
1155-
file_type: "".to_string(),
1156-
table_partition_cols: vec![],
1157-
if_not_exists: false,
1158-
definition: None,
1159-
order_exprs: vec![],
1160-
unbounded: false,
1161-
options,
1162-
constraints: Constraints::default(),
1163-
column_defaults: HashMap::new(),
1164-
temporary: false,
1165-
or_replace: false,
1166-
};
1167-
1168-
let table_provider = factory
1169-
.create(&ctx.state(), &cmd)
1170-
.await
1171-
.expect("table provider created");
1172-
1173-
let read_schema = table_provider.schema();
1174-
let ts_field = read_schema
1175-
.field_with_name("created_at")
1176-
.expect("created_at field exists");
1177-
1178-
// DuckDB stores TIMESTAMPTZ as Microsecond regardless of requested precision.
1179-
match ts_field.data_type() {
1180-
DataType::Timestamp(TimeUnit::Microsecond, _) => {}
1181-
other => panic!("Expected Timestamp(Microsecond, _), got {other:?}"),
1182-
}
1183-
}
11841120
}

core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use snafu::{prelude::*, ResultExt};
2323
use tokio::sync::mpsc::Sender;
2424

2525
use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
26+
use crate::util::arrow::cast_batch_to_schema;
2627
use crate::util::schema::SchemaValidator;
2728
use crate::UnsupportedTypeAction;
2829

@@ -453,7 +454,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
453454
&self,
454455
sql: &str,
455456
params: &[DuckDBParameter],
456-
_projected_schema: Option<SchemaRef>,
457+
projected_schema: Option<SchemaRef>,
457458
) -> Result<SendableRecordBatchStream> {
458459
let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
459460

@@ -501,9 +502,18 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
501502
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
502503
});
503504

505+
let stream_schema = projected_schema.clone().unwrap_or(schema);
506+
504507
let output_stream = stream! {
505508
while let Some(batch) = batch_rx.recv().await {
506-
yield Ok(batch);
509+
if let Some(ref target_schema) = projected_schema {
510+
match cast_batch_to_schema(&batch, target_schema) {
511+
Ok(casted) => yield Ok(casted),
512+
Err(e) => yield Err(e),
513+
}
514+
} else {
515+
yield Ok(batch);
516+
}
507517
}
508518

509519
match join_handle.await {
@@ -522,7 +532,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
522532
};
523533

524534
Ok(Box::pin(RecordBatchStreamAdapter::new(
525-
schema,
535+
stream_schema,
526536
output_stream,
527537
)))
528538
};
@@ -945,4 +955,70 @@ mod tests {
945955

946956
Ok(())
947957
}
958+
959+
/// Checks that `query_arrow` returns batches (and a stream schema) matching the
/// `projected_schema` passed by the caller, even though the table column is declared
/// as `TIMESTAMPTZ` while the projection asks for a nanosecond timestamp.
#[test]
fn test_query_arrow_casts_to_projected_schema() {
    use arrow::datatypes::{Schema, TimeUnit};
    use futures::StreamExt;

    use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
    use crate::sql::db_connection_pool::DbConnectionPool;

    let runtime = tokio::runtime::Runtime::new().expect("runtime");
    runtime.block_on(async {
        let pool = DuckDbConnectionPool::new_memory().expect("pool created");

        let pooled_conn = pool.connect().await.expect("connection");
        let conn = pooled_conn.as_sync().expect("sync connection");

        // Set up a table with a TIMESTAMPTZ column (DuckDB stores it as Microsecond)
        // and a single row to read back.
        let setup_statements = [
            (
                "CREATE TABLE test_ts (id INTEGER, created_at TIMESTAMPTZ)",
                "table created",
            ),
            (
                "INSERT INTO test_ts VALUES (1, '2023-01-01T00:00:00Z')",
                "data inserted",
            ),
        ];
        for (sql, expectation) in setup_statements {
            conn.execute(sql, &[]).expect(expectation);
        }

        // The caller-side schema asks for Nanosecond precision instead.
        let nanosecond_utc = DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));
        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("created_at", nanosecond_utc.clone(), true),
        ]));

        let mut stream = conn
            .query_arrow(
                "SELECT id, created_at FROM test_ts",
                &[],
                Some(Arc::clone(&projected_schema)),
            )
            .expect("query_arrow should succeed");

        // The stream must advertise the projected schema, not DuckDB's native one.
        assert_eq!(stream.schema(), projected_schema);

        let mut batches = Vec::new();
        while let Some(item) = stream.next().await {
            batches.push(item.expect("batch should be Ok"));
        }

        assert_eq!(batches.len(), 1);
        let only_batch = &batches[0];
        assert_eq!(only_batch.schema(), projected_schema);

        // The timestamp column itself must carry the Nanosecond type.
        assert_eq!(only_batch.column(1).data_type(), &nanosecond_utc);
    });
}
9481024
}

core/src/util/arrow.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::sync::Arc;
2+
3+
use arrow::array::RecordBatch;
4+
use arrow::compute::cast;
5+
use datafusion::arrow::datatypes::SchemaRef;
6+
use datafusion::error::DataFusionError;
7+
8+
/// Cast a `RecordBatch` to match the target schema, casting columns whose types differ.
9+
/// Columns that already match are passed through unchanged.
10+
pub fn cast_batch_to_schema(
11+
batch: &RecordBatch,
12+
target_schema: &SchemaRef,
13+
) -> Result<RecordBatch, DataFusionError> {
14+
let columns: Vec<_> = batch
15+
.columns()
16+
.iter()
17+
.zip(target_schema.fields())
18+
.map(|(col, target_field)| {
19+
if col.data_type() == target_field.data_type() {
20+
Ok(Arc::clone(col))
21+
} else {
22+
cast(col, target_field.data_type()).map_err(|e| {
23+
DataFusionError::Execution(format!(
24+
"Failed to cast column '{}' from {:?} to {:?}: {e}",
25+
target_field.name(),
26+
col.data_type(),
27+
target_field.data_type(),
28+
))
29+
})
30+
}
31+
})
32+
.collect::<Result<_, _>>()?;
33+
34+
RecordBatch::try_new(Arc::clone(target_schema), columns).map_err(|e| {
35+
DataFusionError::Execution(format!("Failed to create RecordBatch after cast: {e}"))
36+
})
37+
}
38+
39+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, TimestampMicrosecondArray, TimestampNanosecondArray};
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};

    /// Builds the non-nullable (`id`: Int32, `ts`: UTC timestamp) schema used by the
    /// timestamp test, with `ts` expressed in the given time unit.
    fn id_ts_schema(unit: TimeUnit) -> SchemaRef {
        let ts_type = DataType::Timestamp(unit, Some("UTC".into()));
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("ts", ts_type, false),
        ]))
    }

    /// A microsecond timestamp column is converted to nanoseconds and the batch
    /// takes on the target schema.
    #[test]
    fn test_cast_timestamp_us_to_ns() {
        let target_schema = id_ts_schema(TimeUnit::Nanosecond);

        let ids = Int32Array::from(vec![1, 2, 3]);
        let micros = TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000])
            .with_timezone("UTC");
        let batch = RecordBatch::try_new(
            id_ts_schema(TimeUnit::Microsecond),
            vec![Arc::new(ids), Arc::new(micros)],
        )
        .unwrap();

        let casted = cast_batch_to_schema(&batch, &target_schema).unwrap();
        assert_eq!(casted.schema(), target_schema);

        let ts_column = casted.column(1);
        assert_eq!(
            ts_column.data_type(),
            &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
        );

        // Microseconds -> nanoseconds: every value is scaled by 1000.
        let nanos = ts_column
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();
        let expected_nanos = [1_000_000_000_i64, 2_000_000_000, 3_000_000_000];
        for (row, expected) in expected_nanos.into_iter().enumerate() {
            assert_eq!(nanos.value(row), expected);
        }
    }

    /// When source and target types already agree, the call succeeds and the
    /// schema is unchanged.
    #[test]
    fn test_no_cast_when_types_match() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

        let ids = Int32Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids)]).unwrap();

        let passed_through = cast_batch_to_schema(&batch, &schema).unwrap();
        assert_eq!(passed_through.schema(), schema);
    }

    /// A string column that cannot be represented as the non-nullable Int32 target
    /// yields an error rather than a batch.
    // NOTE(review): depending on Arrow's default cast options, the unparsable string may
    // become a null instead of a cast error, in which case the failure would come from the
    // non-nullable target field — confirm which path this test actually exercises.
    #[test]
    fn test_cast_incompatible_types_returns_error() {
        use arrow::array::StringArray;

        let utf8_schema = Arc::new(Schema::new(vec![Field::new("val", DataType::Utf8, false)]));
        let int_schema = Arc::new(Schema::new(vec![Field::new("val", DataType::Int32, false)]));

        let strings = StringArray::from(vec!["not_a_number"]);
        let batch = RecordBatch::try_new(utf8_schema, vec![Arc::new(strings)]).unwrap();

        let outcome = cast_batch_to_schema(&batch, &int_schema);
        assert!(outcome.is_err());
    }
}

core/src/util/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::collections::HashMap;
66

77
use crate::UnsupportedTypeAction;
88

9+
pub mod arrow;
910
pub mod column_reference;
1011
pub mod constraints;
1112
pub mod count_exec;

0 commit comments

Comments
 (0)