|
1 | 1 | use datafusion::arrow::{ |
2 | 2 | array::{Array, RecordBatch, RecordBatchOptions}, |
3 | | - compute::cast, |
| 3 | + compute::{cast_with_options, CastOptions}, |
4 | 4 | datatypes::{DataType, IntervalUnit, SchemaRef}, |
5 | 5 | }; |
6 | 6 | use std::sync::Arc; |
@@ -153,8 +153,18 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res |
153 | 153 | DataType::Interval(IntervalUnit::DayTime), |
154 | 154 | ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col)) |
155 | 155 | .map_err(make_err), |
156 | | - _ => cast(&Arc::clone(record_batch_col), expected_field.data_type()) |
157 | | - .map_err(make_err), |
| 156 | + _ => { |
| 157 | + let options = CastOptions { |
| 158 | + safe: false, |
| 159 | + ..CastOptions::default() |
| 160 | + }; |
| 161 | + cast_with_options( |
| 162 | + &Arc::clone(record_batch_col), |
| 163 | + expected_field.data_type(), |
| 164 | + &options, |
| 165 | + ) |
| 166 | + .map_err(make_err) |
| 167 | + } |
158 | 168 | } |
159 | 169 | }) |
160 | 170 | .collect::<Result<Vec<Arc<dyn Array>>>>() |
@@ -295,4 +305,71 @@ mod test { |
295 | 305 | let expected = ["++", "++", "++"]; |
296 | 306 | assert_batches_eq!(expected, &[result]); |
297 | 307 | } |
| 308 | + |
| 309 | + /// Casting Decimal128(38,9) → Decimal128(38,27) must return an error when |
| 310 | + /// the upscale would overflow, instead of silently producing NULL. |
| 311 | + #[test] |
| 312 | + fn test_try_cast_to_decimal_overflow_returns_error() { |
| 313 | + // Value with 12 integer digits: 110_367_043_872.497010000 |
| 314 | + // Internal i128 at scale 9 = 110367043872497010000 |
| 315 | + let value_i128: i128 = 110_367_043_872_497_010_000; |
| 316 | + |
| 317 | + let source_schema = Arc::new(Schema::new(vec![Field::new( |
| 318 | + "sum_charge", |
| 319 | + DataType::Decimal128(38, 9), |
| 320 | + true, |
| 321 | + )])); |
| 322 | + |
| 323 | + let source_array = Decimal128Array::from(vec![Some(value_i128)]) |
| 324 | + .with_precision_and_scale(38, 9) |
| 325 | + .expect("valid Decimal128(38,9)"); |
| 326 | + |
| 327 | + let batch = |
| 328 | + RecordBatch::try_new(source_schema, vec![Arc::new(source_array)]).expect("valid batch"); |
| 329 | + |
| 330 | + // Target schema with wider scale (38,27) — only allows 11 integer digits |
| 331 | + let target_schema = Arc::new(Schema::new(vec![Field::new( |
| 332 | + "sum_charge", |
| 333 | + DataType::Decimal128(38, 27), |
| 334 | + true, |
| 335 | + )])); |
| 336 | + |
| 337 | + let result = try_cast_to(batch, target_schema); |
| 338 | + assert!( |
| 339 | + result.is_err(), |
| 340 | + "Decimal overflow should return an error, not silently produce NULL" |
| 341 | + ); |
| 342 | + } |
| 343 | + |
| 344 | + /// Casting Decimal128 with values that fit should succeed. |
| 345 | + #[test] |
| 346 | + fn test_try_cast_to_decimal_no_overflow_succeeds() { |
| 347 | + // Value with 11 integer digits: 99_999_999_999.000000000 (fits in 38-27=11 digits) |
| 348 | + let value_i128: i128 = 99_999_999_999_000_000_000; |
| 349 | + |
| 350 | + let source_schema = Arc::new(Schema::new(vec![Field::new( |
| 351 | + "amount", |
| 352 | + DataType::Decimal128(38, 9), |
| 353 | + true, |
| 354 | + )])); |
| 355 | + |
| 356 | + let source_array = Decimal128Array::from(vec![Some(value_i128)]) |
| 357 | + .with_precision_and_scale(38, 9) |
| 358 | + .expect("valid Decimal128(38,9)"); |
| 359 | + |
| 360 | + let batch = |
| 361 | + RecordBatch::try_new(source_schema, vec![Arc::new(source_array)]).expect("valid batch"); |
| 362 | + |
| 363 | + let target_schema = Arc::new(Schema::new(vec![Field::new( |
| 364 | + "amount", |
| 365 | + DataType::Decimal128(38, 27), |
| 366 | + true, |
| 367 | + )])); |
| 368 | + |
| 369 | + let result = try_cast_to(batch, target_schema); |
| 370 | + assert!( |
| 371 | + result.is_ok(), |
| 372 | + "Decimal cast should succeed when value fits: {result:?}" |
| 373 | + ); |
| 374 | + } |
298 | 375 | } |
0 commit comments