Skip to content

Commit 15ae26b

Browse files
authored
Merge branch 'spiceai-52' into viktor/duckdb-deletes-stack-overflow
2 parents 8df86df + 466acff commit 15ae26b

5 files changed

Lines changed: 85 additions & 9 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ arrow-json = "57.1.0"
4141
arrow-odbc = { version = "23.1" }
4242
datafusion = { version = "52", default-features = false }
4343
datafusion-expr = { version = "52" }
44-
datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "a4fce79433e5c2fe779427c0bfce6a599193e300" }
44+
datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "6b6bfb0d30da8e5c2eb851094e366d98fa839575" }
4545
datafusion-ffi = { version = "52" }
4646
datafusion-proto = { version = "52" }
4747
datafusion-physical-expr = { version = "52" }

core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,7 @@ duckdb = [
121121
"dep:arrow-schema",
122122
"dep:byte-unit",
123123
"dep:datafusion-physical-expr",
124+
"federation",
124125
]
125126
duckdb-federation = ["duckdb", "federation"]
126127
federation = ["dep:datafusion-federation"]

core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs

Lines changed: 80 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
453453
&self,
454454
sql: &str,
455455
params: &[DuckDBParameter],
456-
_projected_schema: Option<SchemaRef>,
456+
projected_schema: Option<SchemaRef>,
457457
) -> Result<SendableRecordBatchStream> {
458458
let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
459459

@@ -501,9 +501,20 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
501501
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
502502
});
503503

504+
let stream_schema = projected_schema.clone().unwrap_or(schema);
505+
504506
let output_stream = stream! {
505507
while let Some(batch) = batch_rx.recv().await {
506-
yield Ok(batch);
508+
if let Some(ref target_schema) = projected_schema {
509+
match datafusion_federation::schema_cast::record_convert::try_cast_to(batch, Arc::clone(target_schema)) {
510+
Ok(casted) => yield Ok(casted),
511+
Err(e) => yield Err(DataFusionError::Execution(format!(
512+
"Failed to cast DuckDB result batch to projected schema: {e}"
513+
))),
514+
}
515+
} else {
516+
yield Ok(batch);
517+
}
507518
}
508519

509520
match join_handle.await {
@@ -522,7 +533,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
522533
};
523534

524535
Ok(Box::pin(RecordBatchStreamAdapter::new(
525-
schema,
536+
stream_schema,
526537
output_stream,
527538
)))
528539
};
@@ -945,4 +956,70 @@ mod tests {
945956

946957
Ok(())
947958
}
959+
960+
#[test]
961+
fn test_query_arrow_casts_to_projected_schema() {
962+
use arrow::datatypes::{Schema, TimeUnit};
963+
use futures::StreamExt;
964+
965+
use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
966+
use crate::sql::db_connection_pool::DbConnectionPool;
967+
968+
let rt = tokio::runtime::Runtime::new().expect("runtime");
969+
rt.block_on(async {
970+
let pool = DuckDbConnectionPool::new_memory().expect("pool created");
971+
972+
// Create table with TIMESTAMPTZ (DuckDB stores as Microsecond)
973+
let conn = pool.connect().await.expect("connection");
974+
let conn = conn.as_sync().expect("sync connection");
975+
conn.execute(
976+
"CREATE TABLE test_ts (id INTEGER, created_at TIMESTAMPTZ)",
977+
&[],
978+
)
979+
.expect("table created");
980+
conn.execute(
981+
"INSERT INTO test_ts VALUES (1, '2023-01-01T00:00:00Z')",
982+
&[],
983+
)
984+
.expect("data inserted");
985+
986+
// Request Nanosecond via projected_schema
987+
let projected_schema = Arc::new(Schema::new(vec![
988+
Field::new("id", DataType::Int32, true),
989+
Field::new(
990+
"created_at",
991+
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
992+
true,
993+
),
994+
]));
995+
996+
let stream = conn
997+
.query_arrow(
998+
"SELECT id, created_at FROM test_ts",
999+
&[],
1000+
Some(projected_schema.clone()),
1001+
)
1002+
.expect("query_arrow should succeed");
1003+
1004+
// Verify stream schema matches projected_schema
1005+
assert_eq!(stream.schema(), projected_schema);
1006+
1007+
let mut batches = vec![];
1008+
let mut stream = stream;
1009+
while let Some(batch) = stream.next().await {
1010+
batches.push(batch.expect("batch should be Ok"));
1011+
}
1012+
1013+
assert_eq!(batches.len(), 1);
1014+
let batch = &batches[0];
1015+
assert_eq!(batch.schema(), projected_schema);
1016+
1017+
// Verify the timestamp column is Nanosecond
1018+
let ts_col = batch.column(1);
1019+
assert_eq!(
1020+
ts_col.data_type(),
1021+
&DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
1022+
);
1023+
});
1024+
}
9481025
}

core/tests/duckdb/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,15 +8,14 @@ use datafusion::execution::context::SessionContext;
88
use datafusion::logical_expr::dml::InsertOp;
99
use datafusion::logical_expr::CreateExternalTable;
1010
use datafusion::physical_plan::collect;
11-
use datafusion_federation::schema_cast::record_convert::try_cast_to;
1211
use datafusion_table_providers::duckdb::DuckDBTableProviderFactory;
1312
use rstest::rstest;
1413
use std::collections::HashMap;
1514
use std::sync::Arc;
1615

1716
async fn arrow_duckdb_round_trip(
1817
arrow_record: RecordBatch,
19-
source_schema: SchemaRef,
18+
_source_schema: SchemaRef,
2019
table_name: &str,
2120
) {
2221
let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
@@ -68,7 +67,6 @@ async fn arrow_duckdb_round_trip(
6867
.expect("DataFrame should be created from query");
6968

7069
let record_batch = df.collect().await.expect("RecordBatch should be collected");
71-
let casted_record = try_cast_to(record_batch[0].clone(), source_schema).unwrap();
7270

7371
tracing::debug!("Original Arrow Record Batch: {:?}", arrow_record.columns());
7472
tracing::debug!(
@@ -80,7 +78,7 @@ async fn arrow_duckdb_round_trip(
8078
assert_eq!(record_batch.len(), 1);
8179
assert_eq!(record_batch[0].num_rows(), arrow_record.num_rows());
8280
assert_eq!(record_batch[0].num_columns(), arrow_record.num_columns());
83-
assert_eq!(casted_record, arrow_record);
81+
assert_eq!(record_batch[0], arrow_record);
8482
}
8583

8684
#[rstest]

0 commit comments

Comments
 (0)