22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44use vortex_buffer:: Buffer ;
5+ use vortex_buffer:: BufferMut ;
56use vortex_error:: VortexExpect ;
67use vortex_error:: VortexResult ;
78use vortex_error:: vortex_bail;
9+ use vortex_error:: vortex_err;
810use vortex_error:: vortex_panic;
11+ use vortex_mask:: AllOr ;
12+ use vortex_mask:: Mask ;
913
1014use crate :: ArrayRef ;
1115use crate :: ExecutionCtx ;
1216use crate :: IntoArray ;
1317use crate :: array:: ArrayView ;
1418use crate :: arrays:: Decimal ;
1519use crate :: arrays:: DecimalArray ;
20+ use crate :: arrays:: primitive:: PrimitiveArray ;
1621use crate :: dtype:: DType ;
1722use crate :: dtype:: DecimalType ;
1823use crate :: dtype:: NativeDecimalType ;
24+ use crate :: dtype:: NativePType ;
1925use crate :: match_each_decimal_value_type;
26+ use crate :: match_each_native_ptype;
2027use crate :: scalar_fn:: fns:: cast:: CastKernel ;
2128use crate :: scalar_fn:: fns:: cast:: CastReduce ;
2229
@@ -66,17 +73,40 @@ impl CastKernel for Decimal {
6673 dtype : & DType ,
6774 ctx : & mut ExecutionCtx ,
6875 ) -> VortexResult < Option < ArrayRef > > {
69- // Early return if not casting to decimal
70- let DType :: Decimal ( to_decimal_dtype, to_nullability) = dtype else {
71- return Ok ( None ) ;
72- } ;
7376 let DType :: Decimal ( from_decimal_dtype, _) = array. dtype ( ) else {
7477 vortex_panic ! (
7578 "DecimalArray must have decimal dtype, got {:?}" ,
7679 array. dtype( )
7780 ) ;
7881 } ;
7982
83+ if let DType :: Primitive ( to_ptype, to_nullability) = dtype {
84+ let validity = array. validity ( ) ?;
85+ let new_validity =
86+ validity
87+ . clone ( )
88+ . cast_nullability ( * to_nullability, array. len ( ) , ctx) ?;
89+ let mask = validity. execute_mask ( array. len ( ) , ctx) ?;
90+
91+ return Ok ( Some ( match_each_native_ptype ! ( * to_ptype, |T | {
92+ match_each_decimal_value_type!( array. values_type( ) , |F | {
93+ PrimitiveArray :: new(
94+ cast_decimal_buffer_to_primitive:: <F , T >(
95+ array. buffer:: <F >( ) ,
96+ from_decimal_dtype. scale( ) ,
97+ mask,
98+ ) ?,
99+ new_validity,
100+ )
101+ . into_array( )
102+ } )
103+ } ) ) ) ;
104+ }
105+
106+ let DType :: Decimal ( to_decimal_dtype, to_nullability) = dtype else {
107+ return Ok ( None ) ;
108+ } ;
109+
80110 // Scale changes are not yet supported
81111 if from_decimal_dtype. scale ( ) != to_decimal_dtype. scale ( ) {
82112 vortex_bail ! (
@@ -180,6 +210,57 @@ fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffe
180210 . collect ( )
181211}
182212
213+ fn cast_decimal_buffer_to_primitive < F , T > (
214+ from : Buffer < F > ,
215+ scale : i8 ,
216+ mask : Mask ,
217+ ) -> VortexResult < Buffer < T > >
218+ where
219+ F : NativeDecimalType ,
220+ T : NativePType ,
221+ {
222+ let scale_factor = 10_f64 . powi ( i32:: from ( scale) ) ;
223+
224+ match mask. bit_buffer ( ) {
225+ AllOr :: All => {
226+ let mut buffer = BufferMut :: < T > :: with_capacity ( from. len ( ) ) ;
227+ for value in from {
228+ let value = cast_decimal_value_to_primitive :: < F , T > ( value, scale_factor) ?;
229+ buffer. push ( value) ;
230+ }
231+ Ok ( buffer. freeze ( ) )
232+ }
233+ AllOr :: None => Ok ( Buffer :: zeroed ( from. len ( ) ) ) ,
234+ AllOr :: Some ( validity) => {
235+ let mut buffer = BufferMut :: < T > :: with_capacity ( from. len ( ) ) ;
236+ for ( value, valid) in from. iter ( ) . zip ( validity. iter ( ) ) {
237+ if valid {
238+ let value = cast_decimal_value_to_primitive :: < F , T > ( * value, scale_factor) ?;
239+ buffer. push ( value) ;
240+ } else {
241+ buffer. push ( T :: default ( ) ) ;
242+ }
243+ }
244+ Ok ( buffer. freeze ( ) )
245+ }
246+ }
247+ }
248+
249+ fn cast_decimal_value_to_primitive < F , T > ( value : F , scale_factor : f64 ) -> VortexResult < T >
250+ where
251+ F : NativeDecimalType ,
252+ T : NativePType ,
253+ {
254+ let value = value
255+ . to_f64 ( )
256+ . ok_or_else ( || vortex_err ! ( Compute : "Failed to cast decimal value {value} to f64" ) ) ?
257+ / scale_factor;
258+
259+ T :: from ( value) . ok_or_else (
260+ || vortex_err ! ( Compute : "Failed to cast decimal value {value} to {:?}" , T :: PTYPE ) ,
261+ )
262+ }
263+
183264#[ cfg( test) ]
184265mod tests {
185266 use rstest:: rstest;
@@ -198,6 +279,7 @@ mod tests {
198279 use crate :: dtype:: DecimalDType ;
199280 use crate :: dtype:: DecimalType ;
200281 use crate :: dtype:: Nullability ;
282+ use crate :: dtype:: PType ;
201283 use crate :: validity:: Validity ;
202284
203285 #[ test]
@@ -331,6 +413,63 @@ mod tests {
331413 assert_eq ! ( casted. values_type( ) , DecimalType :: I128 ) ;
332414 }
333415
416+ #[ test]
417+ fn cast_decimal_to_f64_applies_scale ( ) {
418+ let array = DecimalArray :: new (
419+ buffer ! [ 12345i64 , -50 , 0 ] ,
420+ DecimalDType :: new ( 15 , 2 ) ,
421+ Validity :: NonNullable ,
422+ ) ;
423+ let dtype = DType :: Primitive ( PType :: F64 , Nullability :: NonNullable ) ;
424+
425+ #[ expect( deprecated) ]
426+ let casted = array
427+ . into_array ( )
428+ . cast ( dtype. clone ( ) )
429+ . unwrap ( )
430+ . to_primitive ( ) ;
431+
432+ assert_eq ! ( casted. as_ref( ) . dtype( ) , & dtype) ;
433+ assert ! ( matches!(
434+ casted. as_ref( ) . validity( ) ,
435+ Ok ( Validity :: NonNullable )
436+ ) ) ;
437+ let values = casted. as_slice :: < f64 > ( ) ;
438+ assert ! ( ( values[ 0 ] - 123.45 ) . abs( ) < 0.000000000001 ) ;
439+ assert_eq ! ( values[ 1 ] , -0.5 ) ;
440+ assert_eq ! ( values[ 2 ] , 0.0 ) ;
441+ }
442+
443+ #[ test]
444+ fn cast_nullable_decimal_to_nullable_f64_preserves_validity ( ) {
445+ let array = DecimalArray :: from_option_iter (
446+ [ Some ( 12345i64 ) , None , Some ( -50 ) ] ,
447+ DecimalDType :: new ( 15 , 2 ) ,
448+ ) ;
449+ let dtype = DType :: Primitive ( PType :: F64 , Nullability :: Nullable ) ;
450+
451+ #[ expect( deprecated) ]
452+ let casted = array
453+ . into_array ( )
454+ . cast ( dtype. clone ( ) )
455+ . unwrap ( )
456+ . to_primitive ( ) ;
457+
458+ assert_eq ! ( casted. as_ref( ) . dtype( ) , & dtype) ;
459+ let mask = casted
460+ . as_ref ( )
461+ . validity ( )
462+ . unwrap ( )
463+ . execute_mask ( casted. len ( ) , & mut LEGACY_SESSION . create_execution_ctx ( ) )
464+ . unwrap ( ) ;
465+ assert ! ( mask. value( 0 ) ) ;
466+ assert ! ( !mask. value( 1 ) ) ;
467+ assert ! ( mask. value( 2 ) ) ;
468+ let values = casted. as_slice :: < f64 > ( ) ;
469+ assert ! ( ( values[ 0 ] - 123.45 ) . abs( ) < 0.000000000001 ) ;
470+ assert_eq ! ( values[ 2 ] , -0.5 ) ;
471+ }
472+
334473 #[ test]
335474 fn cast_to_non_decimal_returns_err ( ) {
336475 let array = DecimalArray :: new (
0 commit comments