Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 11 additions & 75 deletions core/src/duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,15 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
.with_indexes(indexes.clone());

let pool = Arc::new(pool);
make_initial_table(Arc::new(table_definition), &pool)?;
make_initial_table(Arc::new(table_definition.clone()), &pool)?;

let write_settings = DuckDBWriteSettings::from_params(&options);

let table_writer_builder = DuckDBTableWriterBuilder::new()
.with_table_definition(table_definition)
.with_pool(pool)
.set_on_conflict(on_conflict)
.with_write_settings(write_settings);

let dyn_pool: Arc<DynDuckDbConnectionPool> = Arc::new(read_pool);

Expand All @@ -502,28 +510,9 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
self.settings_registry
.apply_settings(conn, &options, DuckDBSettingScope::Global)?;

// Read actual DuckDB schema after table creation (may differ from cmd.schema).
let schema_conn = dyn_pool.connect().await?;
let schema = get_schema(schema_conn, &TableReference::bare(name.clone()))
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;

let table_definition =
TableDefinition::new(RelationName::new(name.clone()), Arc::clone(&schema))
.with_constraints(cmd.constraints.clone())
.with_indexes(indexes.clone());

let write_settings = DuckDBWriteSettings::from_params(&options);

let table_writer_builder = DuckDBTableWriterBuilder::new()
.with_table_definition(table_definition)
.with_pool(pool)
.set_on_conflict(on_conflict)
.with_write_settings(write_settings);

let read_provider = Arc::new(DuckDBTable::new_with_schema(
&dyn_pool,
schema,
Arc::clone(&schema),
TableReference::bare(name.clone()),
None,
Some(self.dialect.clone()),
Expand Down Expand Up @@ -809,7 +798,7 @@ pub(crate) mod tests {
use crate::duckdb::write::DuckDBTableWriter;

use super::*;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{Constraints, ToDFSchema};
use datafusion::logical_expr::CreateExternalTable;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -1128,57 +1117,4 @@ pub(crate) mod tests {
assert_eq!(e.to_string(), "External error: Query execution failed.\nInvalid Input Error: Failed to cast value: Could not convert string 'invalid' to BOOL\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/");
}
}

/// Creates an in-memory table through the factory with a nanosecond-precision
/// UTC timestamp column, then checks that the schema reported by the returned
/// provider carries microsecond precision for that column instead of the
/// precision requested in `cmd.schema`.
#[tokio::test]
async fn test_read_provider_schema_reflects_actual_duckdb_types() {
    // Requested schema: the timestamp column asks for nanosecond precision.
    let requested_ts = DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));
    let requested_schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("created_at", requested_ts, false),
    ]);

    let options = HashMap::from([("mode".to_string(), "memory".to_string())]);

    let cmd = CreateExternalTable {
        schema: Arc::new(requested_schema.to_dfschema().expect("to df schema")),
        name: TableReference::bare("test_timestamp_schema"),
        location: "".to_string(),
        file_type: "".to_string(),
        table_partition_cols: vec![],
        if_not_exists: false,
        definition: None,
        order_exprs: vec![],
        unbounded: false,
        options,
        constraints: Constraints::default(),
        column_defaults: HashMap::new(),
        temporary: false,
        or_replace: false,
    };

    let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
    let session = SessionContext::new();
    let provider = factory
        .create(&session.state(), &cmd)
        .await
        .expect("table provider created");

    let advertised = provider.schema();
    let created_at = advertised
        .field_with_name("created_at")
        .expect("created_at field exists");

    // DuckDB stores TIMESTAMPTZ as Microsecond regardless of requested precision.
    match created_at.data_type() {
        DataType::Timestamp(TimeUnit::Microsecond, _) => {}
        other => panic!("Expected Timestamp(Microsecond, _), got {other:?}"),
    }
}
}
82 changes: 79 additions & 3 deletions core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use snafu::{prelude::*, ResultExt};
use tokio::sync::mpsc::Sender;

use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
use crate::util::arrow::cast_batch_to_schema;
use crate::util::schema::SchemaValidator;
use crate::UnsupportedTypeAction;

Expand Down Expand Up @@ -453,7 +454,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
&self,
sql: &str,
params: &[DuckDBParameter],
_projected_schema: Option<SchemaRef>,
projected_schema: Option<SchemaRef>,
) -> Result<SendableRecordBatchStream> {
let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);

Expand Down Expand Up @@ -501,9 +502,18 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
});

let stream_schema = projected_schema.clone().unwrap_or(schema);

let output_stream = stream! {
while let Some(batch) = batch_rx.recv().await {
yield Ok(batch);
if let Some(ref target_schema) = projected_schema {
match cast_batch_to_schema(&batch, target_schema) {
Ok(casted) => yield Ok(casted),
Err(e) => yield Err(e),
}
} else {
yield Ok(batch);
}
}

match join_handle.await {
Expand All @@ -522,7 +532,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
};

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
stream_schema,
output_stream,
)))
};
Expand Down Expand Up @@ -945,4 +955,70 @@ mod tests {

Ok(())
}

/// Runs `query_arrow` against a `TIMESTAMPTZ` column while passing a projected
/// schema that asks for nanosecond precision, and checks that both the stream
/// schema and the single emitted batch match the projected schema.
#[test]
fn test_query_arrow_casts_to_projected_schema() {
    use arrow::datatypes::{Schema, TimeUnit};
    use futures::StreamExt;

    use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
    use crate::sql::db_connection_pool::DbConnectionPool;

    // The type requested through the projected schema and expected on output.
    let nanos_utc = DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));

    let runtime = tokio::runtime::Runtime::new().expect("runtime");
    runtime.block_on(async {
        let pool = DuckDbConnectionPool::new_memory().expect("pool created");

        let pooled = pool.connect().await.expect("connection");
        let conn = pooled.as_sync().expect("sync connection");

        // Create table with TIMESTAMPTZ (DuckDB stores as Microsecond), then seed one row.
        let setup = [
            (
                "CREATE TABLE test_ts (id INTEGER, created_at TIMESTAMPTZ)",
                "table created",
            ),
            (
                "INSERT INTO test_ts VALUES (1, '2023-01-01T00:00:00Z')",
                "data inserted",
            ),
        ];
        for (sql, step) in setup {
            conn.execute(sql, &[]).expect(step);
        }

        // Request Nanosecond via projected_schema.
        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("created_at", nanos_utc.clone(), true),
        ]));

        let stream = conn
            .query_arrow(
                "SELECT id, created_at FROM test_ts",
                &[],
                Some(projected_schema.clone()),
            )
            .expect("query_arrow should succeed");

        // The stream must advertise the projected schema, not the stored one.
        assert_eq!(stream.schema(), projected_schema);

        let batches: Vec<_> = stream
            .map(|item| item.expect("batch should be Ok"))
            .collect()
            .await;

        assert_eq!(batches.len(), 1);
        let only_batch = &batches[0];
        assert_eq!(only_batch.schema(), projected_schema);

        // The timestamp column itself must have been cast to Nanosecond.
        assert_eq!(only_batch.column(1).data_type(), &nanos_utc);
    });
}
}
125 changes: 125 additions & 0 deletions core/src/util/arrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::compute::cast;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::error::DataFusionError;

/// Cast a `RecordBatch` to match the target schema, casting columns whose types differ.
/// Columns that already match are passed through unchanged.
pub fn cast_batch_to_schema(
batch: &RecordBatch,
target_schema: &SchemaRef,
) -> Result<RecordBatch, DataFusionError> {
let columns: Vec<_> = batch
.columns()
.iter()
.zip(target_schema.fields())
.map(|(col, target_field)| {
if col.data_type() == target_field.data_type() {
Ok(Arc::clone(col))
} else {
cast(col, target_field.data_type()).map_err(|e| {
DataFusionError::Execution(format!(
"Failed to cast column '{}' from {:?} to {:?}: {e}",
target_field.name(),
col.data_type(),
target_field.data_type(),
))
})
}
})
.collect::<Result<_, _>>()?;

RecordBatch::try_new(Arc::clone(target_schema), columns).map_err(|e| {
DataFusionError::Execution(format!("Failed to create RecordBatch after cast: {e}"))
})
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, TimestampMicrosecondArray, TimestampNanosecondArray};
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};

    /// Builds a UTC-zoned timestamp data type with the given unit.
    fn utc_timestamp(unit: TimeUnit) -> DataType {
        DataType::Timestamp(unit, Some("UTC".into()))
    }

    /// Builds a two-field schema (`id: Int32`, `ts: Timestamp(unit, UTC)`), both non-nullable.
    fn id_ts_schema(unit: TimeUnit) -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("ts", utc_timestamp(unit), false),
        ]))
    }

    /// Microsecond timestamps are widened to nanoseconds and the values scale by 1000.
    #[test]
    fn test_cast_timestamp_us_to_ns() {
        let target_schema = id_ts_schema(TimeUnit::Nanosecond);

        let ids = Int32Array::from(vec![1, 2, 3]);
        let micros = TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000])
            .with_timezone("UTC");
        let batch = RecordBatch::try_new(
            id_ts_schema(TimeUnit::Microsecond),
            vec![Arc::new(ids), Arc::new(micros)],
        )
        .unwrap();

        let result = cast_batch_to_schema(&batch, &target_schema).unwrap();
        assert_eq!(result.schema(), target_schema);
        assert_eq!(
            result.column(1).data_type(),
            &utc_timestamp(TimeUnit::Nanosecond)
        );

        // Values should be multiplied by 1000
        let nanos = result
            .column(1)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();
        let expected: [i64; 3] = [1_000_000_000, 2_000_000_000, 3_000_000_000];
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(nanos.value(row), want);
        }
    }

    /// A batch whose schema already equals the target comes back with that schema.
    #[test]
    fn test_no_cast_when_types_match() {
        let schema: SchemaRef =
            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

        let ids = Int32Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids)]).unwrap();

        let result = cast_batch_to_schema(&batch, &schema).unwrap();
        assert_eq!(result.schema(), schema);
    }

    /// A non-numeric string cast into a non-nullable Int32 field yields an error.
    #[test]
    fn test_cast_incompatible_types_returns_error() {
        use arrow::array::StringArray;

        let single_field_schema = |data_type: DataType| -> SchemaRef {
            Arc::new(Schema::new(vec![Field::new("val", data_type, false)]))
        };
        let source_schema = single_field_schema(DataType::Utf8);
        let target_schema = single_field_schema(DataType::Int32);

        let values = StringArray::from(vec!["not_a_number"]);
        let batch = RecordBatch::try_new(source_schema, vec![Arc::new(values)]).unwrap();

        assert!(cast_batch_to_schema(&batch, &target_schema).is_err());
    }
}
1 change: 1 addition & 0 deletions core/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::collections::HashMap;

use crate::UnsupportedTypeAction;

pub mod arrow;
pub mod column_reference;
pub mod constraints;
pub mod count_exec;
Expand Down
Loading