Skip to content

Commit 7311652

Browse files
authored
Add write support for Decimal32 and Decimal64 (#519)
1 parent 51a39a3 commit 7311652

3 files changed

Lines changed: 83 additions & 25 deletions

File tree

core/src/sql/arrow_sql_gen/arrow.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use datafusion::arrow::{
22
array::{
33
types::Int8Type, ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Date64Builder,
4-
Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder, FixedSizeListBuilder,
5-
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
6-
IntervalMonthDayNanoBuilder, LargeBinaryBuilder, LargeStringBuilder, ListBuilder,
7-
NullBuilder, StringBuilder, StringDictionaryBuilder, StructBuilder,
8-
Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampMillisecondBuilder,
9-
TimestampNanosecondBuilder, TimestampSecondBuilder, UInt16Builder, UInt32Builder,
10-
UInt64Builder, UInt8Builder,
4+
Decimal128Builder, Decimal256Builder, Decimal32Builder, Decimal64Builder,
5+
FixedSizeBinaryBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder, Int16Builder,
6+
Int32Builder, Int64Builder, Int8Builder, IntervalMonthDayNanoBuilder, LargeBinaryBuilder,
7+
LargeStringBuilder, ListBuilder, NullBuilder, StringBuilder, StringDictionaryBuilder,
8+
StructBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder,
9+
TimestampMillisecondBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder,
10+
UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
1111
},
1212
datatypes::{DataType, TimeUnit, UInt16Type},
1313
};
@@ -40,6 +40,16 @@ pub fn map_data_type_to_array_builder(data_type: &DataType) -> Box<dyn ArrayBuil
4040
DataType::Binary => Box::new(BinaryBuilder::new()),
4141
DataType::LargeBinary => Box::new(LargeBinaryBuilder::new()),
4242
DataType::Interval(_) => Box::new(IntervalMonthDayNanoBuilder::new()),
43+
DataType::Decimal32(precision, scale) => Box::new(
44+
Decimal32Builder::new()
45+
.with_precision_and_scale(*precision, *scale)
46+
.unwrap_or_default(),
47+
),
48+
DataType::Decimal64(precision, scale) => Box::new(
49+
Decimal64Builder::new()
50+
.with_precision_and_scale(*precision, *scale)
51+
.unwrap_or_default(),
52+
),
4353
DataType::Decimal128(precision, scale) => Box::new(
4454
Decimal128Builder::new()
4555
.with_precision_and_scale(*precision, *scale)

core/src/sql/arrow_sql_gen/statement.rs

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ macro_rules! push_value {
175175
}};
176176
}
177177

178+
// Appends the decimal value found at `$row` of `$column` to `$row_values`.
//
// Arguments:
// - `$row_values`: collection that receives the converted value via `.push(..)`.
// - `$column`: array-like value exposing `as_any()`; it is downcast to `array::$array_type`.
// - `$row`: index of the cell to read.
// - `$scale`: reference-like expression (it is dereferenced with `*`) holding the decimal
//   scale; the call sites bind it from a `DataType::DecimalN(_, scale)` pattern.
// - `$array_type`: identifier of the concrete array type inside the `array` module
//   (the call sites pass `Decimal32Array`, `Decimal64Array` and `Decimal128Array`).
//
// NOTE: the expansion contains a bare `continue`, so the macro can only be used inside a
// loop. The enclosing loop is not visible in this hunk, so what the null path skips over
// depends on the expansion site and should be confirmed there.
macro_rules! push_big_decimal_value {
179+
($row_values:expr, $column:expr, $row:expr, $scale:expr, $array_type:ident) => {{
180+
// `downcast_ref` yields `None` when `$column` is not an `array::$array_type`.
let array = $column.as_any().downcast_ref::<array::$array_type>();
181+
// NOTE(review): when the downcast fails nothing is pushed and no error is reported.
if let Some(valid_array) = array {
182+
if valid_array.is_null($row) {
183+
// A null cell is emitted as the `Keyword::Null` value.
$row_values.push(Keyword::Null.into());
184+
// Jumps to the next iteration of the loop surrounding the macro invocation.
continue;
185+
}
186+
// Non-null cell: build a `BigDecimal` from the stored integer value and the scale
// widened to `i64` (e.g. the test in this commit expects 12345 with scale 2 to
// render as 123.45).
$row_values
187+
.push(BigDecimal::new(valid_array.value($row).into(), i64::from(*$scale)).into());
188+
}
189+
}};
190+
}
191+
178192
macro_rules! push_list_values {
179193
($data_type:expr, $list_array:expr, $row_values:expr, $array_type:ty, $vec_type:ty, $sql_type:expr) => {{
180194
let mut list_values: Vec<$vec_type> = Vec::new();
@@ -260,18 +274,14 @@ impl<'a> InsertBuilder<'a> {
260274
DataType::LargeUtf8 => push_value!(row_values, column, row, LargeStringArray),
261275
DataType::Utf8View => push_value!(row_values, column, row, StringViewArray),
262276
DataType::Boolean => push_value!(row_values, column, row, BooleanArray),
277+
DataType::Decimal32(_, scale) => {
278+
push_big_decimal_value!(row_values, column, row, scale, Decimal32Array)
279+
}
280+
DataType::Decimal64(_, scale) => {
281+
push_big_decimal_value!(row_values, column, row, scale, Decimal64Array)
282+
}
263283
DataType::Decimal128(_, scale) => {
264-
let array = column.as_any().downcast_ref::<array::Decimal128Array>();
265-
if let Some(valid_array) = array {
266-
if valid_array.is_null(row) {
267-
row_values.push(Keyword::Null.into());
268-
continue;
269-
}
270-
row_values.push(
271-
BigDecimal::new(valid_array.value(row).into(), i64::from(*scale))
272-
.into(),
273-
);
274-
}
284+
push_big_decimal_value!(row_values, column, row, scale, Decimal128Array)
275285
}
276286
DataType::Decimal256(_, scale) => {
277287
let array = column.as_any().downcast_ref::<array::Decimal256Array>();
@@ -1321,9 +1331,10 @@ pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
13211331
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => ColumnType::Text,
13221332
DataType::Boolean => ColumnType::Boolean,
13231333
#[allow(clippy::cast_sign_loss)] // This is safe because scale will never be negative
1324-
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
1325-
ColumnType::Decimal(Some((u32::from(*p), *s as u32)))
1326-
}
1334+
DataType::Decimal32(p, s)
1335+
| DataType::Decimal64(p, s)
1336+
| DataType::Decimal128(p, s)
1337+
| DataType::Decimal256(p, s) => ColumnType::Decimal(Some((u32::from(*p), *s as u32))),
13271338
DataType::Timestamp(_unit, time_zone) => {
13281339
if time_zone.is_some() {
13291340
return ColumnType::TimestampWithTimeZone;
@@ -1468,17 +1479,22 @@ mod tests {
14681479
Field::new("id", DataType::Int32, false),
14691480
Field::new("name", DataType::Utf8, false),
14701481
Field::new("age", DataType::Int32, true),
1482+
Field::new("balance", DataType::Decimal64(10, 2), true),
14711483
]);
14721484
let id_array = array::Int32Array::from(vec![1, 2, 3]);
14731485
let name_array = array::StringArray::from(vec!["a", "b", "c"]);
14741486
let age_array = array::Int32Array::from(vec![10, 20, 30]);
1487+
let balance_array = array::Decimal64Array::from(vec![12345, -12345, 12300])
1488+
.with_precision_and_scale(10, 2)
1489+
.unwrap();
14751490

14761491
let batch1 = RecordBatch::try_new(
14771492
Arc::new(schema1.clone()),
14781493
vec![
14791494
Arc::new(id_array.clone()),
14801495
Arc::new(name_array.clone()),
14811496
Arc::new(age_array.clone()),
1497+
Arc::new(balance_array.clone()),
14821498
],
14831499
)
14841500
.expect("Unable to build record batch");
@@ -1487,6 +1503,7 @@ mod tests {
14871503
Field::new("id", DataType::Int32, false),
14881504
Field::new("name", DataType::Utf8, false),
14891505
Field::new("blah", DataType::Int32, true),
1506+
Field::new("balance", DataType::Decimal64(10, 2), true),
14901507
]);
14911508

14921509
let batch2 = RecordBatch::try_new(
@@ -1495,6 +1512,7 @@ mod tests {
14951512
Arc::new(id_array),
14961513
Arc::new(name_array),
14971514
Arc::new(age_array),
1515+
Arc::new(balance_array),
14981516
],
14991517
)
15001518
.expect("Unable to build record batch");
@@ -1503,7 +1521,16 @@ mod tests {
15031521
let sql = InsertBuilder::new(&TableReference::from("users"), &record_batches)
15041522
.build_postgres(None)
15051523
.expect("Failed to build insert statement");
1506-
assert_eq!(sql, "INSERT INTO \"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
1524+
assert_eq!(
1525+
sql,
1526+
"INSERT INTO \"users\" (\"id\", \"name\", \"age\", \"balance\") VALUES \
1527+
(1, 'a', 10, 123.45), \
1528+
(2, 'b', 20, -123.45), \
1529+
(3, 'c', 30, 123.00), \
1530+
(1, 'a', 10, 123.45), \
1531+
(2, 'b', 20, -123.45), \
1532+
(3, 'c', 30, 123.00)"
1533+
);
15071534
}
15081535

15091536
#[test]

core/tests/arrow_record_batch_gen/mod.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,28 +395,42 @@ pub(crate) fn get_arrow_struct_record_batch() -> (RecordBatch, SchemaRef) {
395395
(record_batch, schema)
396396
}
397397

398-
// Decimal128/Decimal256
399398
pub(crate) fn get_arrow_decimal_record_batch() -> (RecordBatch, SchemaRef) {
399+
let decimal32_array =
400+
Decimal32Array::from(vec![i32::from(123), i32::from(222), i32::from(321)]);
401+
let decimal64_array =
402+
Decimal64Array::from(vec![i64::from(123), i64::from(222), i64::from(321)]);
400403
let decimal128_array =
401404
Decimal128Array::from(vec![i128::from(123), i128::from(222), i128::from(321)]);
402405
let decimal256_array =
403406
Decimal256Array::from(vec![i256::from(-123), i256::from(222), i256::from(0)]);
404407

405408
let schema = Arc::new(Schema::new(vec![
409+
Field::new("decimal32", DataType::Decimal32(9, 2), false),
410+
Field::new("decimal64", DataType::Decimal64(18, 6), false),
406411
Field::new("decimal128", DataType::Decimal128(38, 10), false),
407412
Field::new("decimal256", DataType::Decimal256(76, 10), false),
408413
]));
409414

410415
let record_batch = RecordBatch::try_new(
411416
Arc::clone(&schema),
412-
vec![Arc::new(decimal128_array), Arc::new(decimal256_array)],
417+
vec![
418+
Arc::new(decimal32_array),
419+
Arc::new(decimal64_array),
420+
Arc::new(decimal128_array),
421+
Arc::new(decimal256_array),
422+
],
413423
)
414424
.expect("Failed to created arrow decimal record batch");
415425

416426
(record_batch, schema)
417427
}
418428

419429
pub(crate) fn get_mysql_arrow_decimal_record() -> (RecordBatch, SchemaRef) {
430+
let decimal32_array =
431+
Decimal32Array::from(vec![i32::from(123), i32::from(222), i32::from(321)]);
432+
let decimal64_array =
433+
Decimal64Array::from(vec![i64::from(123), i64::from(222), i64::from(321)]);
420434
let decimal128_array =
421435
Decimal128Array::from(vec![i128::from(123), i128::from(222), i128::from(321)]);
422436
let decimal256_array =
@@ -425,13 +439,20 @@ pub(crate) fn get_mysql_arrow_decimal_record() -> (RecordBatch, SchemaRef) {
425439
.expect("Fail to create Decimal256(65, 10) array");
426440

427441
let schema = Arc::new(Schema::new(vec![
442+
Field::new("decimal32", DataType::Decimal32(9, 2), false),
443+
Field::new("decimal64", DataType::Decimal64(18, 6), false),
428444
Field::new("decimal128", DataType::Decimal128(38, 10), false),
429445
Field::new("decimal256", DataType::Decimal256(65, 10), false), // Maximum is 65.
430446
]));
431447

432448
let record_batch = RecordBatch::try_new(
433449
Arc::clone(&schema),
434-
vec![Arc::new(decimal128_array), Arc::new(decimal256_array)],
450+
vec![
451+
Arc::new(decimal32_array),
452+
Arc::new(decimal64_array),
453+
Arc::new(decimal128_array),
454+
Arc::new(decimal256_array),
455+
],
435456
)
436457
.expect("Failed to created arrow decimal record batch");
437458

0 commit comments

Comments
 (0)