Skip to content

Commit 7ac5659

Browse files
sgrebnov authored and phillipleblanc committed
DuckDB: don't use ArrowVTab while creating the table (more types support) (#319)
1 parent 0888b06 commit 7ac5659

4 files changed

Lines changed: 90 additions & 17 deletions

File tree

core/src/duckdb.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,9 @@ pub enum Error {
156156
#[snafu(display("Failed to register Arrow scan view for DuckDB ingestion: {source}"))]
157157
UnableToRegisterArrowScanView { source: duckdb::Error },
158158

159+
#[snafu(display("Failed to register Arrow scan view to build table creation statement: {source}"))]
160+
UnableToRegisterArrowScanViewForTableCreation { source: duckdb::Error },
161+
159162
#[snafu(display("Failed to drop Arrow scan view for DuckDB ingestion: {source}"))]
160163
UnableToDropArrowScanView { source: duckdb::Error },
161164
}

core/src/duckdb/creator.rs

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,15 @@ use crate::sql::arrow_sql_gen::statement::IndexBuilder;
22
use crate::sql::db_connection_pool::dbconnection::duckdbconn::DuckDbConnection;
33
use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
44
use crate::util::on_conflict::OnConflict;
5-
use arrow::{array::RecordBatch, datatypes::SchemaRef};
5+
use arrow::{
6+
array::{RecordBatch, RecordBatchIterator, RecordBatchReader},
7+
datatypes::SchemaRef,
8+
ffi_stream::FFI_ArrowArrayStream,
9+
};
610
use datafusion::common::utils::quote_identifier;
711
use datafusion::common::Constraints;
812
use datafusion::sql::TableReference;
9-
use duckdb::{vtab::arrow_recordbatch_to_query_params, ToSql, Transaction};
13+
use duckdb::Transaction;
1014
use itertools::Itertools;
1115
use snafu::prelude::*;
1216
use std::collections::HashSet;
@@ -392,19 +396,25 @@ impl TableManager {
392396
.transaction()
393397
.context(super::UnableToBeginTransactionSnafu)?;
394398
let table_name = self.table_name();
395-
let empty_record = RecordBatch::new_empty(Arc::clone(&self.table_definition.schema));
399+
let record_batch_reader =
400+
create_empty_record_batch_reader(Arc::clone(&self.table_definition.schema));
401+
let stream = FFI_ArrowArrayStream::new(Box::new(record_batch_reader));
396402

397-
let arrow_params = arrow_recordbatch_to_query_params(empty_record);
398-
let arrow_params_vec: Vec<&dyn ToSql> = arrow_params
399-
.iter()
400-
.map(|p| p as &dyn ToSql)
401-
.collect::<Vec<_>>();
402-
let arrow_params_ref: &[&dyn ToSql] = &arrow_params_vec;
403-
let sql =
404-
format!(r#"CREATE TABLE IF NOT EXISTS "{table_name}" AS SELECT * FROM arrow(?, ?)"#,);
403+
let current_ts = std::time::SystemTime::now()
404+
.duration_since(std::time::UNIX_EPOCH)
405+
.context(super::UnableToGetSystemTimeSnafu)?
406+
.as_millis();
407+
408+
let view_name = format!("__scan_{}_{current_ts}", table_name);
409+
tx.register_arrow_scan_view(&view_name, &stream)
410+
.context(super::UnableToRegisterArrowScanViewForTableCreationSnafu)?;
411+
412+
let sql = format!(
413+
r#"CREATE TABLE IF NOT EXISTS "{table_name}" AS SELECT * FROM "{view_name}""#,
414+
);
405415
tracing::debug!("{sql}");
406416

407-
tx.execute(&sql, arrow_params_ref)
417+
tx.execute(&sql, [])
408418
.context(super::UnableToCreateDuckDBTableSnafu)?;
409419

410420
let create_stmt = tx
@@ -670,6 +680,12 @@ impl TableManager {
670680
}
671681
}
672682

683+
fn create_empty_record_batch_reader(schema: SchemaRef) -> impl RecordBatchReader {
684+
let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
685+
let batches = vec![empty_batch];
686+
RecordBatchIterator::new(batches.into_iter().map(Ok), schema)
687+
}
688+
673689
#[derive(Debug, Clone)]
674690
pub(crate) struct ViewCreator {
675691
name: RelationName,

core/tests/arrow_record_batch_gen/mod.rs

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -766,6 +766,60 @@ pub(crate) fn get_arrow_dictionary_array_record_batch() -> (RecordBatch, SchemaR
766766
(record_batch, schema)
767767
}
768768

769+
pub(crate) fn get_arrow_map_record_batch() -> (RecordBatch, SchemaRef) {
770+
let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
771+
let values_data = UInt32Array::from(vec![
772+
Some(0u32),
773+
None,
774+
Some(20),
775+
Some(30),
776+
None,
777+
Some(50),
778+
Some(60),
779+
Some(70),
780+
]);
781+
// Construct a buffer for value offsets, for the nested array:
782+
// [[a, b, c], [d, e, f], [g, h]]
783+
let entry_offsets = [0, 3, 6, 8];
784+
let map_array =
785+
MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets)
786+
.expect("Failed to create MapArray");
787+
let schema = Arc::new(Schema::new(vec![Field::new(
788+
"map_array",
789+
map_array.data_type().clone(),
790+
true,
791+
)]));
792+
let rb = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(map_array)])
793+
.expect("Failed to created arrow Map array record batch");
794+
(rb, schema)
795+
}
796+
797+
// Custom Test Case for Sqlite <-> Arrow Decimal Roundtrip
798+
// SQLite supports up to 16 precision for decimal numbers through REAL type, conforming to IEEE 754 Binary-64 format - https://www.sqlite.org/floatingpoint.html
799+
pub(crate) fn get_sqlite_arrow_decimal_record_batch() -> (RecordBatch, SchemaRef) {
800+
let decimal128_array =
801+
Decimal128Array::from(vec![i128::from(123), i128::from(222), i128::from(321)])
802+
.with_precision_and_scale(16, 10)
803+
.expect("Fail to create Decimal128 array");
804+
let decimal256_array =
805+
Decimal256Array::from(vec![i256::from(-123), i256::from(222), i256::from(0)])
806+
.with_precision_and_scale(16, 10)
807+
.expect("Fail to create Decimal256 array");
808+
809+
let schema = Arc::new(Schema::new(vec![
810+
Field::new("decimal128", DataType::Decimal128(16, 10), false),
811+
Field::new("decimal256", DataType::Decimal256(16, 10), false),
812+
]));
813+
814+
let record_batch = RecordBatch::try_new(
815+
Arc::clone(&schema),
816+
vec![Arc::new(decimal128_array), Arc::new(decimal256_array)],
817+
)
818+
.expect("Failed to created arrow decimal record batch");
819+
820+
(record_batch, schema)
821+
}
822+
769823
fn parse_json_to_batch(json_data: &str, schema: SchemaRef) -> RecordBatch {
770824
let reader = arrow_json::ReaderBuilder::new(schema)
771825
.build(std::io::Cursor::new(json_data))

core/tests/duckdb/mod.rs

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,6 @@ async fn arrow_duckdb_round_trip(
8383
}
8484

8585
#[rstest]
86-
#[ignore]
87-
// Binder Error: Unsupported data type: FixedSizeBinary(2), please file an issue https://github.com/wangfenjin/duckdb-rs"
8886
#[case::binary(get_arrow_binary_record_batch(), "binary")]
8987
#[case::int(get_arrow_int_record_batch(), "int")]
9088
#[case::float(get_arrow_float_record_batch(), "float")]
@@ -93,11 +91,11 @@ async fn arrow_duckdb_round_trip(
9391
#[case::timestamp(get_arrow_timestamp_record_batch(), "timestamp")]
9492
#[case::date(get_arrow_date_record_batch(), "date")]
9593
#[case::struct_type(get_arrow_struct_record_batch(), "struct")]
96-
#[ignore] // Decimal256(76,10) is not yet supported for ArrowVTab
94+
#[ignore] // DuckDB does not support Decimal256 / duckdb_arrow_scan failed to register view
9795
#[case::decimal(get_arrow_decimal_record_batch(), "decimal")]
98-
#[ignore] // Interval(DayTime) is not yet supported for ArrowVTab
96+
#[ignore] // Interval(DayTime) is not supported: / "Conversion Error: Could not convert Interval to Microsecond"
9997
#[case::interval(get_arrow_interval_record_batch(), "interval")]
100-
#[ignore] // Duration(NanoSecond) is not yet supported for ArrowVTab
98+
#[ignore] // TimeUnit::Nanosecond is not correctly supported; written values are zeros
10199
#[case::duration(get_arrow_duration_record_batch(), "duration")]
102100
#[case::list(get_arrow_list_record_batch(), "list")]
103101
#[case::null(get_arrow_null_record_batch(), "null")]
@@ -107,6 +105,8 @@ async fn arrow_duckdb_round_trip(
107105
"list_of_fixed_size_lists"
108106
)]
109107
#[case::list_of_lists(get_arrow_list_of_lists_record_batch(), "list_of_lists")]
108+
#[case::map(get_arrow_map_record_batch(), "map")]
109+
#[case::dictionary(get_arrow_dictionary_array_record_batch(), "dictionary")]
110110
#[test_log::test(tokio::test)]
111111
async fn test_arrow_duckdb_roundtrip(
112112
#[case] arrow_result: (RecordBatch, SchemaRef),

0 commit comments

Comments
 (0)