Skip to content

Commit 792193e

Browse files
committed
Fix decimal casts to primitive arrays
Signed-off-by: "Luke Kim" <80174+lukekim@users.noreply.github.com>
1 parent e8869db commit 792193e

1 file changed

Lines changed: 143 additions & 4 deletions

File tree

  • vortex-array/src/arrays/decimal/compute

vortex-array/src/arrays/decimal/compute/cast.rs

Lines changed: 143 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,28 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_buffer::Buffer;
5+
use vortex_buffer::BufferMut;
56
use vortex_error::VortexExpect;
67
use vortex_error::VortexResult;
78
use vortex_error::vortex_bail;
9+
use vortex_error::vortex_err;
810
use vortex_error::vortex_panic;
11+
use vortex_mask::AllOr;
12+
use vortex_mask::Mask;
913

1014
use crate::ArrayRef;
1115
use crate::ExecutionCtx;
1216
use crate::IntoArray;
1317
use crate::array::ArrayView;
1418
use crate::arrays::Decimal;
1519
use crate::arrays::DecimalArray;
20+
use crate::arrays::primitive::PrimitiveArray;
1621
use crate::dtype::DType;
1722
use crate::dtype::DecimalType;
1823
use crate::dtype::NativeDecimalType;
24+
use crate::dtype::NativePType;
1925
use crate::match_each_decimal_value_type;
26+
use crate::match_each_native_ptype;
2027
use crate::scalar_fn::fns::cast::CastKernel;
2128
use crate::scalar_fn::fns::cast::CastReduce;
2229

@@ -66,17 +73,40 @@ impl CastKernel for Decimal {
6673
dtype: &DType,
6774
ctx: &mut ExecutionCtx,
6875
) -> VortexResult<Option<ArrayRef>> {
69-
// Early return if not casting to decimal
70-
let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
71-
return Ok(None);
72-
};
7376
let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
7477
vortex_panic!(
7578
"DecimalArray must have decimal dtype, got {:?}",
7679
array.dtype()
7780
);
7881
};
7982

83+
if let DType::Primitive(to_ptype, to_nullability) = dtype {
84+
let validity = array.validity()?;
85+
let new_validity =
86+
validity
87+
.clone()
88+
.cast_nullability(*to_nullability, array.len(), ctx)?;
89+
let mask = validity.execute_mask(array.len(), ctx)?;
90+
91+
return Ok(Some(match_each_native_ptype!(*to_ptype, |T| {
92+
match_each_decimal_value_type!(array.values_type(), |F| {
93+
PrimitiveArray::new(
94+
cast_decimal_buffer_to_primitive::<F, T>(
95+
array.buffer::<F>(),
96+
from_decimal_dtype.scale(),
97+
mask,
98+
)?,
99+
new_validity,
100+
)
101+
.into_array()
102+
})
103+
})));
104+
}
105+
106+
let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
107+
return Ok(None);
108+
};
109+
80110
// Scale changes are not yet supported
81111
if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
82112
vortex_bail!(
@@ -180,6 +210,57 @@ fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffe
180210
.collect()
181211
}
182212

213+
fn cast_decimal_buffer_to_primitive<F, T>(
214+
from: Buffer<F>,
215+
scale: i8,
216+
mask: Mask,
217+
) -> VortexResult<Buffer<T>>
218+
where
219+
F: NativeDecimalType,
220+
T: NativePType,
221+
{
222+
let scale_factor = 10_f64.powi(i32::from(scale));
223+
224+
match mask.bit_buffer() {
225+
AllOr::All => {
226+
let mut buffer = BufferMut::<T>::with_capacity(from.len());
227+
for value in from {
228+
let value = cast_decimal_value_to_primitive::<F, T>(value, scale_factor)?;
229+
buffer.push(value);
230+
}
231+
Ok(buffer.freeze())
232+
}
233+
AllOr::None => Ok(Buffer::zeroed(from.len())),
234+
AllOr::Some(validity) => {
235+
let mut buffer = BufferMut::<T>::with_capacity(from.len());
236+
for (value, valid) in from.iter().zip(validity.iter()) {
237+
if valid {
238+
let value = cast_decimal_value_to_primitive::<F, T>(*value, scale_factor)?;
239+
buffer.push(value);
240+
} else {
241+
buffer.push(T::default());
242+
}
243+
}
244+
Ok(buffer.freeze())
245+
}
246+
}
247+
}
248+
249+
fn cast_decimal_value_to_primitive<F, T>(value: F, scale_factor: f64) -> VortexResult<T>
250+
where
251+
F: NativeDecimalType,
252+
T: NativePType,
253+
{
254+
let value = value
255+
.to_f64()
256+
.ok_or_else(|| vortex_err!(Compute: "Failed to cast decimal value {value} to f64"))?
257+
/ scale_factor;
258+
259+
T::from(value).ok_or_else(
260+
|| vortex_err!(Compute: "Failed to cast decimal value {value} to {:?}", T::PTYPE),
261+
)
262+
}
263+
183264
#[cfg(test)]
184265
mod tests {
185266
use rstest::rstest;
@@ -198,6 +279,7 @@ mod tests {
198279
use crate::dtype::DecimalDType;
199280
use crate::dtype::DecimalType;
200281
use crate::dtype::Nullability;
282+
use crate::dtype::PType;
201283
use crate::validity::Validity;
202284

203285
#[test]
@@ -331,6 +413,63 @@ mod tests {
331413
assert_eq!(casted.values_type(), DecimalType::I128);
332414
}
333415

416+
#[test]
417+
fn cast_decimal_to_f64_applies_scale() {
418+
let array = DecimalArray::new(
419+
buffer![12345i64, -50, 0],
420+
DecimalDType::new(15, 2),
421+
Validity::NonNullable,
422+
);
423+
let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
424+
425+
#[expect(deprecated)]
426+
let casted = array
427+
.into_array()
428+
.cast(dtype.clone())
429+
.unwrap()
430+
.to_primitive();
431+
432+
assert_eq!(casted.as_ref().dtype(), &dtype);
433+
assert!(matches!(
434+
casted.as_ref().validity(),
435+
Ok(Validity::NonNullable)
436+
));
437+
let values = casted.as_slice::<f64>();
438+
assert!((values[0] - 123.45).abs() < 0.000000000001);
439+
assert_eq!(values[1], -0.5);
440+
assert_eq!(values[2], 0.0);
441+
}
442+
443+
#[test]
444+
fn cast_nullable_decimal_to_nullable_f64_preserves_validity() {
445+
let array = DecimalArray::from_option_iter(
446+
[Some(12345i64), None, Some(-50)],
447+
DecimalDType::new(15, 2),
448+
);
449+
let dtype = DType::Primitive(PType::F64, Nullability::Nullable);
450+
451+
#[expect(deprecated)]
452+
let casted = array
453+
.into_array()
454+
.cast(dtype.clone())
455+
.unwrap()
456+
.to_primitive();
457+
458+
assert_eq!(casted.as_ref().dtype(), &dtype);
459+
let mask = casted
460+
.as_ref()
461+
.validity()
462+
.unwrap()
463+
.execute_mask(casted.len(), &mut LEGACY_SESSION.create_execution_ctx())
464+
.unwrap();
465+
assert!(mask.value(0));
466+
assert!(!mask.value(1));
467+
assert!(mask.value(2));
468+
let values = casted.as_slice::<f64>();
469+
assert!((values[0] - 123.45).abs() < 0.000000000001);
470+
assert_eq!(values[2], -0.5);
471+
}
472+
334473
#[test]
335474
fn cast_to_non_decimal_returns_err() {
336475
let array = DecimalArray::new(

0 commit comments

Comments
 (0)