Skip to content

Commit 7e5b081

Browse files
cast: support vortex.date -> vortex.timestamp extension casts (#28)
1 parent 3e77695 commit 7e5b081

1 file changed

Lines changed: 218 additions & 18 deletions

File tree

  • vortex-array/src/arrays/extension/compute

vortex-array/src/arrays/extension/compute/cast.rs

Lines changed: 218 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,187 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_buffer::BufferMut;
5+
use vortex_error::VortexResult;
6+
use vortex_error::vortex_bail;
7+
use vortex_mask::AllOr;
8+
49
use crate::ArrayRef;
510
use crate::IntoArray;
11+
use crate::ToCanonical;
612
use crate::arrays::ExtensionArray;
713
use crate::arrays::ExtensionVTable;
14+
use crate::arrays::PrimitiveArray;
815
use crate::builtins::ArrayBuiltins;
916
use crate::dtype::DType;
17+
use crate::dtype::PType;
18+
use crate::extension::datetime::AnyTemporal;
19+
use crate::extension::datetime::TemporalMetadata;
20+
use crate::extension::datetime::TimeUnit;
1021
use crate::scalar_fn::fns::cast::CastReduce;
22+
use crate::vtable::ValidityHelper;
1123

1224
impl CastReduce for ExtensionVTable {
13-
fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
14-
if !array.dtype().eq_ignore_nullability(dtype) {
25+
fn cast(array: &ExtensionArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
26+
let DType::Extension(ext_dtype) = dtype else {
1527
return Ok(None);
28+
};
29+
30+
if array.ext_dtype().eq_ignore_nullability(ext_dtype) {
31+
let new_storage = match array
32+
.storage()
33+
.cast(ext_dtype.storage_dtype().clone())
34+
.and_then(|a| a.to_canonical().map(|c| c.into_array()))
35+
{
36+
Ok(arr) => arr,
37+
Err(e) => {
38+
tracing::warn!("Failed to cast storage array: {e}");
39+
return Ok(None);
40+
}
41+
};
42+
43+
return Ok(Some(
44+
ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
45+
));
1646
}
1747

18-
let DType::Extension(ext_dtype) = dtype else {
19-
unreachable!("Already verified we have an extension dtype");
20-
};
48+
if let Some(new_storage) = cast_temporal_date_to_timestamp(array, dtype)? {
49+
return Ok(Some(
50+
ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
51+
));
52+
}
53+
54+
Ok(None)
55+
}
56+
}
57+
58+
fn cast_temporal_date_to_timestamp(
59+
array: &ExtensionArray,
60+
target_dtype: &DType,
61+
) -> VortexResult<Option<ArrayRef>> {
62+
let DType::Extension(target_ext_dtype) = target_dtype else {
63+
return Ok(None);
64+
};
65+
66+
let Some(source_temporal) = array.ext_dtype().metadata_opt::<AnyTemporal>() else {
67+
return Ok(None);
68+
};
69+
let Some(target_temporal) = target_ext_dtype.metadata_opt::<AnyTemporal>() else {
70+
return Ok(None);
71+
};
72+
73+
let (TemporalMetadata::Date(source_unit), TemporalMetadata::Timestamp(target_unit, _)) =
74+
(source_temporal, target_temporal)
75+
else {
76+
return Ok(None);
77+
};
78+
79+
let source_i64 = array
80+
.storage()
81+
.cast(DType::Primitive(PType::I64, array.dtype().nullability()))?;
82+
let source_i64 = source_i64.to_primitive();
2183

22-
let new_storage = match array
23-
.storage()
24-
.cast(ext_dtype.storage_dtype().clone())
25-
.and_then(|a| a.to_canonical().map(|c| c.into_array()))
26-
{
27-
Ok(arr) => arr,
28-
Err(e) => {
29-
tracing::warn!("Failed to cast storage array: {e}");
30-
return Ok(None);
84+
let converted = cast_date_values_to_timestamp(&source_i64, *source_unit, *target_unit)?;
85+
86+
converted
87+
.to_array()
88+
.cast(target_ext_dtype.storage_dtype().clone())
89+
.map(Some)
90+
}
91+
92+
fn cast_date_values_to_timestamp(
93+
values: &PrimitiveArray,
94+
source_unit: TimeUnit,
95+
target_unit: TimeUnit,
96+
) -> VortexResult<PrimitiveArray> {
97+
let (multiply, divide) = date_to_timestamp_scale(source_unit, target_unit)?;
98+
99+
let input = values.as_slice::<i64>();
100+
let mut output = BufferMut::with_capacity(input.len());
101+
match values.validity_mask()?.bit_buffer() {
102+
AllOr::All => {
103+
for &value in input {
104+
// SAFETY: output has sufficient capacity for all pushed values.
105+
unsafe { output.push_unchecked(convert_temporal_value(value, multiply, divide)?) };
31106
}
32-
};
107+
}
108+
AllOr::None => {
109+
for _ in 0..input.len() {
110+
// SAFETY: output has sufficient capacity for all pushed values.
111+
unsafe { output.push_unchecked(0i64) };
112+
}
113+
}
114+
AllOr::Some(bits) => {
115+
for (&value, valid) in input.iter().zip(bits.iter()) {
116+
if valid {
117+
// SAFETY: output has sufficient capacity for all pushed values.
118+
unsafe {
119+
output.push_unchecked(convert_temporal_value(value, multiply, divide)?)
120+
};
121+
} else {
122+
// SAFETY: output has sufficient capacity for all pushed values.
123+
unsafe { output.push_unchecked(0i64) };
124+
}
125+
}
126+
}
127+
}
128+
129+
Ok(PrimitiveArray::new(
130+
output.freeze(),
131+
values.validity().clone(),
132+
))
133+
}
134+
135+
fn date_to_timestamp_scale(
136+
source_unit: TimeUnit,
137+
target_unit: TimeUnit,
138+
) -> VortexResult<(i64, i64)> {
139+
let source_ns = to_nanoseconds(source_unit)?;
140+
let target_ns = to_nanoseconds(target_unit)?;
141+
142+
if source_ns >= target_ns {
143+
let multiply = source_ns / target_ns;
144+
return Ok((multiply, 1));
145+
}
33146

34-
Ok(Some(
35-
ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
36-
))
147+
let divide = target_ns / source_ns;
148+
Ok((1, divide))
149+
}
150+
151+
fn to_nanoseconds(unit: TimeUnit) -> VortexResult<i64> {
152+
match unit {
153+
TimeUnit::Nanoseconds => Ok(1),
154+
TimeUnit::Microseconds => Ok(1_000),
155+
TimeUnit::Milliseconds => Ok(1_000_000),
156+
TimeUnit::Seconds => Ok(1_000_000_000),
157+
TimeUnit::Days => Ok(86_400_000_000_000),
158+
}
159+
}
160+
161+
fn convert_temporal_value(value: i64, multiply: i64, divide: i64) -> VortexResult<i64> {
162+
let mut scaled = i128::from(value)
163+
.checked_mul(i128::from(multiply))
164+
.ok_or_else(|| {
165+
vortex_error::vortex_err!(
166+
Compute: "Date value {value} overflows while scaling to timestamp"
167+
)
168+
})?;
169+
170+
if divide != 1 {
171+
let divisor = i128::from(divide);
172+
if scaled % divisor != 0 {
173+
vortex_bail!(
174+
Compute: "Date value {value} cannot be represented exactly in target timestamp unit"
175+
);
176+
}
177+
scaled /= divisor;
178+
}
179+
180+
if scaled < i128::from(i64::MIN) || scaled > i128::from(i64::MAX) {
181+
vortex_bail!(Compute: "Date value {value} overflows target timestamp range");
37182
}
183+
184+
Ok(scaled as i64)
38185
}
39186

40187
#[cfg(test)]
@@ -45,11 +192,13 @@ mod tests {
45192
use vortex_buffer::buffer;
46193

47194
use super::*;
195+
use crate::Array;
48196
use crate::IntoArray;
49197
use crate::arrays::PrimitiveArray;
50198
use crate::builtins::ArrayBuiltins;
51199
use crate::compute::conformance::cast::test_cast_conformance;
52200
use crate::dtype::Nullability;
201+
use crate::extension::datetime::Date;
53202
use crate::extension::datetime::TimeUnit;
54203
use crate::extension::datetime::Timestamp;
55204

@@ -85,6 +234,57 @@ mod tests {
85234
assert_eq!(output.dtype(), &new_dtype);
86235
}
87236

237+
#[test]
238+
fn cast_date_days_to_timestamp_nanoseconds() {
239+
let source_dtype = Date::new(TimeUnit::Days, Nullability::NonNullable).erased();
240+
let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();
241+
242+
let arr = ExtensionArray::new(source_dtype, buffer![0i32, 1, -1].into_array());
243+
let output = arr
244+
.to_array()
245+
.cast(DType::Extension(target_dtype.clone()))
246+
.unwrap()
247+
.to_extension();
248+
249+
assert_eq!(output.dtype(), &DType::Extension(target_dtype));
250+
251+
let storage = output.storage().to_primitive();
252+
assert_eq!(
253+
storage.as_slice::<i64>(),
254+
&[0, 86_400_000_000_000, -86_400_000_000_000]
255+
);
256+
}
257+
258+
#[test]
259+
fn cast_date_days_to_timestamp_seconds_nullable() {
260+
let source_dtype = Date::new(TimeUnit::Days, Nullability::Nullable).erased();
261+
let target_dtype = Timestamp::new(TimeUnit::Seconds, Nullability::Nullable).erased();
262+
263+
let arr = ExtensionArray::new(
264+
source_dtype,
265+
PrimitiveArray::from_option_iter([Some(0i32), None, Some(2)]).into_array(),
266+
);
267+
268+
let output = arr
269+
.to_array()
270+
.cast(DType::Extension(target_dtype.clone()))
271+
.unwrap()
272+
.to_extension();
273+
274+
assert_eq!(output.dtype(), &DType::Extension(target_dtype));
275+
276+
let storage = output.storage().to_primitive();
277+
assert_eq!(
278+
storage.scalar_at(0).unwrap().as_primitive().as_::<i64>(),
279+
Some(0)
280+
);
281+
assert!(storage.scalar_at(1).unwrap().is_null());
282+
assert_eq!(
283+
storage.scalar_at(2).unwrap().as_primitive().as_::<i64>(),
284+
Some(172_800)
285+
);
286+
}
287+
88288
#[test]
89289
fn cast_different_ext_dtype() {
90290
let original_dtype =

0 commit comments

Comments
 (0)