Skip to content

Commit b6c14b7

Browse files
committed
Use strict cast in try_cast_to to error on overflow instead of silent NULL
1 parent 7b37778 commit b6c14b7

1 file changed

Lines changed: 102 additions & 28 deletions

File tree

datafusion-federation/src/schema_cast/record_convert.rs

Lines changed: 102 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use datafusion::arrow::{
22
array::{Array, RecordBatch, RecordBatchOptions},
3-
compute::cast,
3+
compute::{cast_with_options, CastOptions},
44
datatypes::{DataType, IntervalUnit, SchemaRef},
55
};
66
use std::sync::Arc;
@@ -84,6 +84,11 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res
8484
});
8585
}
8686

87+
let cast_options = CastOptions {
88+
safe: false,
89+
..CastOptions::default()
90+
};
91+
8792
let cols = expected_schema
8893
.fields()
8994
.iter()
@@ -102,59 +107,55 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res
102107

103108
match (record_batch_col.data_type(), expected_field.data_type()) {
104109
(DataType::Utf8, DataType::List(item_type)) => {
105-
cast_string_to_list::<i32>(&Arc::clone(record_batch_col), item_type)
106-
.map_err(make_err)
110+
cast_string_to_list::<i32>(record_batch_col, item_type).map_err(make_err)
107111
}
108112
(DataType::Utf8, DataType::LargeList(item_type)) => {
109-
cast_string_to_large_list::<i32>(&Arc::clone(record_batch_col), item_type)
110-
.map_err(make_err)
113+
cast_string_to_large_list::<i32>(record_batch_col, item_type).map_err(make_err)
111114
}
112115
(DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => {
113116
cast_string_to_fixed_size_list::<i32>(
114-
&Arc::clone(record_batch_col),
117+
record_batch_col,
115118
item_type,
116119
*value_length,
117120
)
118121
.map_err(make_err)
119122
}
120-
(DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::<i32>(
121-
&Arc::clone(record_batch_col),
122-
expected_field.clone(),
123-
)
124-
.map_err(make_err),
125-
(DataType::LargeUtf8, DataType::List(item_type)) => {
126-
cast_string_to_list::<i64>(&Arc::clone(record_batch_col), item_type)
123+
(DataType::Utf8, DataType::Struct(_)) => {
124+
cast_string_to_struct::<i32>(record_batch_col, expected_field.clone())
127125
.map_err(make_err)
128126
}
127+
(DataType::LargeUtf8, DataType::List(item_type)) => {
128+
cast_string_to_list::<i64>(record_batch_col, item_type).map_err(make_err)
129+
}
129130
(DataType::LargeUtf8, DataType::LargeList(item_type)) => {
130-
cast_string_to_large_list::<i64>(&Arc::clone(record_batch_col), item_type)
131-
.map_err(make_err)
131+
cast_string_to_large_list::<i64>(record_batch_col, item_type).map_err(make_err)
132132
}
133133
(DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => {
134134
cast_string_to_fixed_size_list::<i64>(
135-
&Arc::clone(record_batch_col),
135+
record_batch_col,
136136
item_type,
137137
*value_length,
138138
)
139139
.map_err(make_err)
140140
}
141-
(DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::<i64>(
142-
&Arc::clone(record_batch_col),
143-
expected_field.clone(),
144-
)
145-
.map_err(make_err),
141+
(DataType::LargeUtf8, DataType::Struct(_)) => {
142+
cast_string_to_struct::<i64>(record_batch_col, expected_field.clone())
143+
.map_err(make_err)
144+
}
146145
(
147146
DataType::Interval(IntervalUnit::MonthDayNano),
148147
DataType::Interval(IntervalUnit::YearMonth),
149-
) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col))
150-
.map_err(make_err),
148+
) => cast_interval_monthdaynano_to_yearmonth(record_batch_col).map_err(make_err),
151149
(
152150
DataType::Interval(IntervalUnit::MonthDayNano),
153151
DataType::Interval(IntervalUnit::DayTime),
154-
) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col))
155-
.map_err(make_err),
156-
_ => cast(&Arc::clone(record_batch_col), expected_field.data_type())
157-
.map_err(make_err),
152+
) => cast_interval_monthdaynano_to_daytime(record_batch_col).map_err(make_err),
153+
_ => cast_with_options(
154+
record_batch_col.as_ref(),
155+
expected_field.data_type(),
156+
&cast_options,
157+
)
158+
.map_err(make_err),
158159
}
159160
})
160161
.collect::<Result<Vec<Arc<dyn Array>>>>()
@@ -182,7 +183,7 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res
182183
#[cfg(test)]
183184
mod test {
184185
use super::*;
185-
use datafusion::arrow::array::{LargeStringArray, RecordBatchOptions};
186+
use datafusion::arrow::array::{Decimal128Array, LargeStringArray, RecordBatchOptions};
186187
use datafusion::arrow::{
187188
array::{Int32Array, StringArray},
188189
datatypes::{DataType, Field, Schema, TimeUnit},
@@ -295,4 +296,77 @@ mod test {
295296
let expected = ["++", "++", "++"];
296297
assert_batches_eq!(expected, &[result]);
297298
}
299+
300+
/// Casting Decimal128(38,9) → Decimal128(38,27) must return an error when
301+
/// the upscale would overflow, instead of silently producing NULL.
302+
#[test]
303+
fn test_try_cast_to_decimal_overflow_returns_error() {
304+
// Value with 12 integer digits: 110_367_043_872.497010000
305+
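// Internal i128 at scale 9 = 110367043872497010000 (21 digits); rescaling to scale 27 multiplies by 10^18, giving 39 digits, which exceeds precision 38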
306+
let value_i128: i128 = 110_367_043_872_497_010_000;
307+
308+
let source_schema = Arc::new(Schema::new(vec![Field::new(
309+
"sum_charge",
310+
DataType::Decimal128(38, 9),
311+
true,
312+
)]));
313+
314+
let source_array = Decimal128Array::from(vec![Some(value_i128)])
315+
.with_precision_and_scale(38, 9)
316+
.expect("valid Decimal128(38,9)");
317+
318+
let batch =
319+
RecordBatch::try_new(source_schema, vec![Arc::new(source_array)]).expect("valid batch");
320+
321+
// Target schema with a larger scale (38,27), which leaves only 38 - 27 = 11 integer digits
322+
let target_schema = Arc::new(Schema::new(vec![Field::new(
323+
"sum_charge",
324+
DataType::Decimal128(38, 27),
325+
true,
326+
)]));
327+
328+
let err =
329+
try_cast_to(batch, target_schema).expect_err("Decimal overflow should return an error");
330+
assert!(
331+
matches!(err, Error::UnableToCastColumn { .. }),
332+
"Expected UnableToCastColumn, got: {err:?}"
333+
);
334+
let err_msg = err.to_string();
335+
assert!(
336+
err_msg.contains("is too large to store in a Decimal128"),
337+
"Expected overflow message, got: {err_msg}"
338+
);
339+
}
340+
341+
/// Casting Decimal128(38,9) → Decimal128(38,27) must succeed when the upscaled value still fits in precision 38.
342+
#[test]
343+
fn test_try_cast_to_decimal_no_overflow_succeeds() {
344+
// Value with 11 integer digits: 99_999_999_999.000000000 (fits in 38-27=11 digits)
345+
let value_i128: i128 = 99_999_999_999_000_000_000;
346+
347+
let source_schema = Arc::new(Schema::new(vec![Field::new(
348+
"amount",
349+
DataType::Decimal128(38, 9),
350+
true,
351+
)]));
352+
353+
let source_array = Decimal128Array::from(vec![Some(value_i128)])
354+
.with_precision_and_scale(38, 9)
355+
.expect("valid Decimal128(38,9)");
356+
357+
let batch =
358+
RecordBatch::try_new(source_schema, vec![Arc::new(source_array)]).expect("valid batch");
359+
360+
let target_schema = Arc::new(Schema::new(vec![Field::new(
361+
"amount",
362+
DataType::Decimal128(38, 27),
363+
true,
364+
)]));
365+
366+
let result = try_cast_to(batch, target_schema);
367+
assert!(
368+
result.is_ok(),
369+
"Decimal cast should succeed when value fits: {result:?}"
370+
);
371+
}
298372
}

0 commit comments

Comments
 (0)