11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use vortex_buffer:: BufferMut ;
5+ use vortex_error:: VortexResult ;
6+ use vortex_error:: vortex_bail;
7+ use vortex_mask:: AllOr ;
8+
49use crate :: ArrayRef ;
510use crate :: IntoArray ;
11+ use crate :: ToCanonical ;
612use crate :: arrays:: ExtensionArray ;
713use crate :: arrays:: ExtensionVTable ;
14+ use crate :: arrays:: PrimitiveArray ;
815use crate :: builtins:: ArrayBuiltins ;
916use crate :: dtype:: DType ;
17+ use crate :: dtype:: PType ;
18+ use crate :: extension:: datetime:: AnyTemporal ;
19+ use crate :: extension:: datetime:: TemporalMetadata ;
20+ use crate :: extension:: datetime:: TimeUnit ;
1021use crate :: scalar_fn:: fns:: cast:: CastReduce ;
22+ use crate :: vtable:: ValidityHelper ;
1123
1224impl CastReduce for ExtensionVTable {
13- fn cast ( array : & ExtensionArray , dtype : & DType ) -> vortex_error :: VortexResult < Option < ArrayRef > > {
14- if !array . dtype ( ) . eq_ignore_nullability ( dtype) {
25+ fn cast ( array : & ExtensionArray , dtype : & DType ) -> VortexResult < Option < ArrayRef > > {
26+ let DType :: Extension ( ext_dtype ) = dtype else {
1527 return Ok ( None ) ;
28+ } ;
29+
30+ if array. ext_dtype ( ) . eq_ignore_nullability ( ext_dtype) {
31+ let new_storage = match array
32+ . storage ( )
33+ . cast ( ext_dtype. storage_dtype ( ) . clone ( ) )
34+ . and_then ( |a| a. to_canonical ( ) . map ( |c| c. into_array ( ) ) )
35+ {
36+ Ok ( arr) => arr,
37+ Err ( e) => {
38+ tracing:: warn!( "Failed to cast storage array: {e}" ) ;
39+ return Ok ( None ) ;
40+ }
41+ } ;
42+
43+ return Ok ( Some (
44+ ExtensionArray :: new ( ext_dtype. clone ( ) , new_storage) . into_array ( ) ,
45+ ) ) ;
1646 }
1747
18- let DType :: Extension ( ext_dtype) = dtype else {
19- unreachable ! ( "Already verified we have an extension dtype" ) ;
20- } ;
48+ if let Some ( new_storage) = cast_temporal_date_to_timestamp ( array, dtype) ? {
49+ return Ok ( Some (
50+ ExtensionArray :: new ( ext_dtype. clone ( ) , new_storage) . into_array ( ) ,
51+ ) ) ;
52+ }
53+
54+ Ok ( None )
55+ }
56+ }
57+
58+ fn cast_temporal_date_to_timestamp (
59+ array : & ExtensionArray ,
60+ target_dtype : & DType ,
61+ ) -> VortexResult < Option < ArrayRef > > {
62+ let DType :: Extension ( target_ext_dtype) = target_dtype else {
63+ return Ok ( None ) ;
64+ } ;
65+
66+ let Some ( source_temporal) = array. ext_dtype ( ) . metadata_opt :: < AnyTemporal > ( ) else {
67+ return Ok ( None ) ;
68+ } ;
69+ let Some ( target_temporal) = target_ext_dtype. metadata_opt :: < AnyTemporal > ( ) else {
70+ return Ok ( None ) ;
71+ } ;
72+
73+ let ( TemporalMetadata :: Date ( source_unit) , TemporalMetadata :: Timestamp ( target_unit, _) ) =
74+ ( source_temporal, target_temporal)
75+ else {
76+ return Ok ( None ) ;
77+ } ;
78+
79+ let source_i64 = array
80+ . storage ( )
81+ . cast ( DType :: Primitive ( PType :: I64 , array. dtype ( ) . nullability ( ) ) ) ?;
82+ let source_i64 = source_i64. to_primitive ( ) ;
2183
22- let new_storage = match array
23- . storage ( )
24- . cast ( ext_dtype. storage_dtype ( ) . clone ( ) )
25- . and_then ( |a| a. to_canonical ( ) . map ( |c| c. into_array ( ) ) )
26- {
27- Ok ( arr) => arr,
28- Err ( e) => {
29- tracing:: warn!( "Failed to cast storage array: {e}" ) ;
30- return Ok ( None ) ;
84+ let converted = cast_date_values_to_timestamp ( & source_i64, * source_unit, * target_unit) ?;
85+
86+ converted
87+ . to_array ( )
88+ . cast ( target_ext_dtype. storage_dtype ( ) . clone ( ) )
89+ . map ( Some )
90+ }
91+
92+ fn cast_date_values_to_timestamp (
93+ values : & PrimitiveArray ,
94+ source_unit : TimeUnit ,
95+ target_unit : TimeUnit ,
96+ ) -> VortexResult < PrimitiveArray > {
97+ let ( multiply, divide) = date_to_timestamp_scale ( source_unit, target_unit) ?;
98+
99+ let input = values. as_slice :: < i64 > ( ) ;
100+ let mut output = BufferMut :: with_capacity ( input. len ( ) ) ;
101+ match values. validity_mask ( ) ?. bit_buffer ( ) {
102+ AllOr :: All => {
103+ for & value in input {
104+ // SAFETY: output has sufficient capacity for all pushed values.
105+ unsafe { output. push_unchecked ( convert_temporal_value ( value, multiply, divide) ?) } ;
31106 }
32- } ;
107+ }
108+ AllOr :: None => {
109+ for _ in 0 ..input. len ( ) {
110+ // SAFETY: output has sufficient capacity for all pushed values.
111+ unsafe { output. push_unchecked ( 0i64 ) } ;
112+ }
113+ }
114+ AllOr :: Some ( bits) => {
115+ for ( & value, valid) in input. iter ( ) . zip ( bits. iter ( ) ) {
116+ if valid {
117+ // SAFETY: output has sufficient capacity for all pushed values.
118+ unsafe {
119+ output. push_unchecked ( convert_temporal_value ( value, multiply, divide) ?)
120+ } ;
121+ } else {
122+ // SAFETY: output has sufficient capacity for all pushed values.
123+ unsafe { output. push_unchecked ( 0i64 ) } ;
124+ }
125+ }
126+ }
127+ }
128+
129+ Ok ( PrimitiveArray :: new (
130+ output. freeze ( ) ,
131+ values. validity ( ) . clone ( ) ,
132+ ) )
133+ }
134+
135+ fn date_to_timestamp_scale (
136+ source_unit : TimeUnit ,
137+ target_unit : TimeUnit ,
138+ ) -> VortexResult < ( i64 , i64 ) > {
139+ let source_ns = to_nanoseconds ( source_unit) ?;
140+ let target_ns = to_nanoseconds ( target_unit) ?;
141+
142+ if source_ns >= target_ns {
143+ let multiply = source_ns / target_ns;
144+ return Ok ( ( multiply, 1 ) ) ;
145+ }
33146
34- Ok ( Some (
35- ExtensionArray :: new ( ext_dtype. clone ( ) , new_storage) . into_array ( ) ,
36- ) )
147+ let divide = target_ns / source_ns;
148+ Ok ( ( 1 , divide) )
149+ }
150+
151+ fn to_nanoseconds ( unit : TimeUnit ) -> VortexResult < i64 > {
152+ match unit {
153+ TimeUnit :: Nanoseconds => Ok ( 1 ) ,
154+ TimeUnit :: Microseconds => Ok ( 1_000 ) ,
155+ TimeUnit :: Milliseconds => Ok ( 1_000_000 ) ,
156+ TimeUnit :: Seconds => Ok ( 1_000_000_000 ) ,
157+ TimeUnit :: Days => Ok ( 86_400_000_000_000 ) ,
158+ }
159+ }
160+
161+ fn convert_temporal_value ( value : i64 , multiply : i64 , divide : i64 ) -> VortexResult < i64 > {
162+ let mut scaled = i128:: from ( value)
163+ . checked_mul ( i128:: from ( multiply) )
164+ . ok_or_else ( || {
165+ vortex_error:: vortex_err!(
166+ Compute : "Date value {value} overflows while scaling to timestamp"
167+ )
168+ } ) ?;
169+
170+ if divide != 1 {
171+ let divisor = i128:: from ( divide) ;
172+ if scaled % divisor != 0 {
173+ vortex_bail ! (
174+ Compute : "Date value {value} cannot be represented exactly in target timestamp unit"
175+ ) ;
176+ }
177+ scaled /= divisor;
178+ }
179+
180+ if scaled < i128:: from ( i64:: MIN ) || scaled > i128:: from ( i64:: MAX ) {
181+ vortex_bail ! ( Compute : "Date value {value} overflows target timestamp range" ) ;
37182 }
183+
184+ Ok ( scaled as i64 )
38185}
39186
40187#[ cfg( test) ]
@@ -45,11 +192,13 @@ mod tests {
45192 use vortex_buffer:: buffer;
46193
47194 use super :: * ;
195+ use crate :: Array ;
48196 use crate :: IntoArray ;
49197 use crate :: arrays:: PrimitiveArray ;
50198 use crate :: builtins:: ArrayBuiltins ;
51199 use crate :: compute:: conformance:: cast:: test_cast_conformance;
52200 use crate :: dtype:: Nullability ;
201+ use crate :: extension:: datetime:: Date ;
53202 use crate :: extension:: datetime:: TimeUnit ;
54203 use crate :: extension:: datetime:: Timestamp ;
55204
@@ -85,6 +234,57 @@ mod tests {
85234 assert_eq ! ( output. dtype( ) , & new_dtype) ;
86235 }
87236
237+ #[ test]
238+ fn cast_date_days_to_timestamp_nanoseconds ( ) {
239+ let source_dtype = Date :: new ( TimeUnit :: Days , Nullability :: NonNullable ) . erased ( ) ;
240+ let target_dtype = Timestamp :: new ( TimeUnit :: Nanoseconds , Nullability :: NonNullable ) . erased ( ) ;
241+
242+ let arr = ExtensionArray :: new ( source_dtype, buffer ! [ 0i32 , 1 , -1 ] . into_array ( ) ) ;
243+ let output = arr
244+ . to_array ( )
245+ . cast ( DType :: Extension ( target_dtype. clone ( ) ) )
246+ . unwrap ( )
247+ . to_extension ( ) ;
248+
249+ assert_eq ! ( output. dtype( ) , & DType :: Extension ( target_dtype) ) ;
250+
251+ let storage = output. storage ( ) . to_primitive ( ) ;
252+ assert_eq ! (
253+ storage. as_slice:: <i64 >( ) ,
254+ & [ 0 , 86_400_000_000_000 , -86_400_000_000_000 ]
255+ ) ;
256+ }
257+
258+ #[ test]
259+ fn cast_date_days_to_timestamp_seconds_nullable ( ) {
260+ let source_dtype = Date :: new ( TimeUnit :: Days , Nullability :: Nullable ) . erased ( ) ;
261+ let target_dtype = Timestamp :: new ( TimeUnit :: Seconds , Nullability :: Nullable ) . erased ( ) ;
262+
263+ let arr = ExtensionArray :: new (
264+ source_dtype,
265+ PrimitiveArray :: from_option_iter ( [ Some ( 0i32 ) , None , Some ( 2 ) ] ) . into_array ( ) ,
266+ ) ;
267+
268+ let output = arr
269+ . to_array ( )
270+ . cast ( DType :: Extension ( target_dtype. clone ( ) ) )
271+ . unwrap ( )
272+ . to_extension ( ) ;
273+
274+ assert_eq ! ( output. dtype( ) , & DType :: Extension ( target_dtype) ) ;
275+
276+ let storage = output. storage ( ) . to_primitive ( ) ;
277+ assert_eq ! (
278+ storage. scalar_at( 0 ) . unwrap( ) . as_primitive( ) . as_:: <i64 >( ) ,
279+ Some ( 0 )
280+ ) ;
281+ assert ! ( storage. scalar_at( 1 ) . unwrap( ) . is_null( ) ) ;
282+ assert_eq ! (
283+ storage. scalar_at( 2 ) . unwrap( ) . as_primitive( ) . as_:: <i64 >( ) ,
284+ Some ( 172_800 )
285+ ) ;
286+ }
287+
88288 #[ test]
89289 fn cast_different_ext_dtype ( ) {
90290 let original_dtype =
0 commit comments