Skip to content

Commit 8cafbfe

Browse files
authored
Extension Scalar Cleanup (vortex-data#6699)
## Summary Tracking Issue: vortex-data#6618 Some cleanup work before I implement the validation, construction, and display logic for extension scalars. Also some renames / moving code around. Part of this was us realizing that we don't need the `ExtScalarValueRef` since we can just get the `ExtVTable` from the dtype! ## Testing N/A --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent ba98e29 commit 8cafbfe

23 files changed

Lines changed: 349 additions & 894 deletions

vortex-array/public-api.lock

Lines changed: 40 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -11390,80 +11390,6 @@ pub const vortex_array::patches::PATCH_CHUNK_SIZE: usize
1139011390

1139111391
pub mod vortex_array::scalar
1139211392

11393-
pub mod vortex_array::scalar::extension
11394-
11395-
pub struct vortex_array::scalar::extension::ExtScalarValue<V: vortex_array::dtype::extension::ExtVTable>(_)
11396-
11397-
impl<V: vortex_array::dtype::extension::ExtVTable> vortex_array::scalar::extension::ExtScalarValue<V>
11398-
11399-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::erased(self) -> vortex_array::scalar::extension::ExtScalarValueRef
11400-
11401-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::id(&self) -> vortex_array::dtype::extension::ExtId
11402-
11403-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::storage_value(&self) -> &vortex_array::scalar::ScalarValue
11404-
11405-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::try_new(ext_dtype: &vortex_array::dtype::extension::ExtDType<V>, storage: vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<Self>
11406-
11407-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::vtable(&self) -> &V
11408-
11409-
impl<V: core::clone::Clone + vortex_array::dtype::extension::ExtVTable> core::clone::Clone for vortex_array::scalar::extension::ExtScalarValue<V>
11410-
11411-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::clone(&self) -> vortex_array::scalar::extension::ExtScalarValue<V>
11412-
11413-
impl<V: core::cmp::Eq + vortex_array::dtype::extension::ExtVTable> core::cmp::Eq for vortex_array::scalar::extension::ExtScalarValue<V>
11414-
11415-
impl<V: core::cmp::PartialEq + vortex_array::dtype::extension::ExtVTable> core::cmp::PartialEq for vortex_array::scalar::extension::ExtScalarValue<V>
11416-
11417-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::eq(&self, other: &vortex_array::scalar::extension::ExtScalarValue<V>) -> bool
11418-
11419-
impl<V: core::fmt::Debug + vortex_array::dtype::extension::ExtVTable> core::fmt::Debug for vortex_array::scalar::extension::ExtScalarValue<V>
11420-
11421-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
11422-
11423-
impl<V: core::hash::Hash + vortex_array::dtype::extension::ExtVTable> core::hash::Hash for vortex_array::scalar::extension::ExtScalarValue<V>
11424-
11425-
pub fn vortex_array::scalar::extension::ExtScalarValue<V>::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
11426-
11427-
impl<V: vortex_array::dtype::extension::ExtVTable> core::marker::StructuralPartialEq for vortex_array::scalar::extension::ExtScalarValue<V>
11428-
11429-
pub struct vortex_array::scalar::extension::ExtScalarValueRef(_)
11430-
11431-
impl vortex_array::scalar::extension::ExtScalarValueRef
11432-
11433-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::downcast<V: vortex_array::dtype::extension::ExtVTable>(self) -> vortex_array::scalar::extension::ExtScalarValue<V>
11434-
11435-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::id(&self) -> vortex_array::dtype::extension::ExtId
11436-
11437-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::storage_value(&self) -> &vortex_array::scalar::ScalarValue
11438-
11439-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::try_downcast<V: vortex_array::dtype::extension::ExtVTable>(self) -> core::result::Result<vortex_array::scalar::extension::ExtScalarValue<V>, vortex_array::scalar::extension::ExtScalarValueRef>
11440-
11441-
impl core::clone::Clone for vortex_array::scalar::extension::ExtScalarValueRef
11442-
11443-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::clone(&self) -> vortex_array::scalar::extension::ExtScalarValueRef
11444-
11445-
impl core::cmp::Eq for vortex_array::scalar::extension::ExtScalarValueRef
11446-
11447-
impl core::cmp::PartialEq for vortex_array::scalar::extension::ExtScalarValueRef
11448-
11449-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::eq(&self, other: &Self) -> bool
11450-
11451-
impl core::cmp::PartialOrd for vortex_array::scalar::extension::ExtScalarValueRef
11452-
11453-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::partial_cmp(&self, other: &Self) -> core::option::Option<core::cmp::Ordering>
11454-
11455-
impl core::fmt::Debug for vortex_array::scalar::extension::ExtScalarValueRef
11456-
11457-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
11458-
11459-
impl core::fmt::Display for vortex_array::scalar::extension::ExtScalarValueRef
11460-
11461-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
11462-
11463-
impl core::hash::Hash for vortex_array::scalar::extension::ExtScalarValueRef
11464-
11465-
pub fn vortex_array::scalar::extension::ExtScalarValueRef::hash<H: core::hash::Hasher>(&self, state: &mut H)
11466-
1146711393
pub enum vortex_array::scalar::DecimalValue
1146811394

1146911395
pub vortex_array::scalar::DecimalValue::I128(i128)
@@ -11906,12 +11832,6 @@ pub fn vortex_array::scalar::ScalarValue::into_utf8(self) -> vortex_buffer::stri
1190611832

1190711833
impl vortex_array::scalar::ScalarValue
1190811834

11909-
pub fn vortex_array::scalar::ScalarValue::default_value(dtype: &vortex_array::dtype::DType) -> core::option::Option<Self>
11910-
11911-
pub fn vortex_array::scalar::ScalarValue::zero_value(dtype: &vortex_array::dtype::DType) -> Self
11912-
11913-
impl vortex_array::scalar::ScalarValue
11914-
1191511835
pub fn vortex_array::scalar::ScalarValue::from_proto(value: &vortex_proto::scalar::ScalarValue, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<Self>>
1191611836

1191711837
pub fn vortex_array::scalar::ScalarValue::from_proto_bytes(bytes: &[u8], dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<Self>>
@@ -12344,8 +12264,6 @@ pub fn vortex_array::scalar::ExtScalar<'a>::ext_dtype(&self) -> &'a vortex_array
1234412264

1234512265
pub fn vortex_array::scalar::ExtScalar<'a>::to_storage_scalar(&self) -> vortex_array::scalar::Scalar
1234612266

12347-
pub fn vortex_array::scalar::ExtScalar<'a>::try_new(dtype: &'a vortex_array::dtype::DType, value: core::option::Option<&'a vortex_array::scalar::ScalarValue>) -> vortex_error::VortexResult<Self>
12348-
1234912267
impl core::cmp::Eq for vortex_array::scalar::ExtScalar<'_>
1235012268

1235112269
impl core::cmp::PartialEq for vortex_array::scalar::ExtScalar<'_>
@@ -12520,6 +12438,40 @@ pub struct vortex_array::scalar::Scalar
1252012438

1252112439
impl vortex_array::scalar::Scalar
1252212440

12441+
pub fn vortex_array::scalar::Scalar::approx_nbytes(&self) -> usize
12442+
12443+
pub fn vortex_array::scalar::Scalar::default_value(dtype: &vortex_array::dtype::DType) -> Self
12444+
12445+
pub fn vortex_array::scalar::Scalar::dtype(&self) -> &vortex_array::dtype::DType
12446+
12447+
pub fn vortex_array::scalar::Scalar::eq_ignore_nullability(&self, other: &Self) -> bool
12448+
12449+
pub fn vortex_array::scalar::Scalar::into_parts(self) -> (vortex_array::dtype::DType, core::option::Option<vortex_array::scalar::ScalarValue>)
12450+
12451+
pub fn vortex_array::scalar::Scalar::into_value(self) -> core::option::Option<vortex_array::scalar::ScalarValue>
12452+
12453+
pub fn vortex_array::scalar::Scalar::is_null(&self) -> bool
12454+
12455+
pub fn vortex_array::scalar::Scalar::is_valid(&self) -> bool
12456+
12457+
pub fn vortex_array::scalar::Scalar::is_zero(&self) -> core::option::Option<bool>
12458+
12459+
pub unsafe fn vortex_array::scalar::Scalar::new_unchecked(dtype: vortex_array::dtype::DType, value: core::option::Option<vortex_array::scalar::ScalarValue>) -> Self
12460+
12461+
pub fn vortex_array::scalar::Scalar::null(dtype: vortex_array::dtype::DType) -> Self
12462+
12463+
pub fn vortex_array::scalar::Scalar::null_native<T: vortex_array::dtype::NativeDType>() -> Self
12464+
12465+
pub fn vortex_array::scalar::Scalar::primitive_reinterpret_cast(&self, ptype: vortex_array::dtype::PType) -> vortex_error::VortexResult<Self>
12466+
12467+
pub fn vortex_array::scalar::Scalar::try_new(dtype: vortex_array::dtype::DType, value: core::option::Option<vortex_array::scalar::ScalarValue>) -> vortex_error::VortexResult<Self>
12468+
12469+
pub fn vortex_array::scalar::Scalar::value(&self) -> core::option::Option<&vortex_array::scalar::ScalarValue>
12470+
12471+
pub fn vortex_array::scalar::Scalar::zero_value(dtype: &vortex_array::dtype::DType) -> Self
12472+
12473+
impl vortex_array::scalar::Scalar
12474+
1252312475
pub fn vortex_array::scalar::Scalar::as_binary(&self) -> vortex_array::scalar::BinaryScalar<'_>
1252412476

1252512477
pub fn vortex_array::scalar::Scalar::as_binary_opt(&self) -> core::option::Option<vortex_array::scalar::BinaryScalar<'_>>
@@ -12560,9 +12512,9 @@ pub fn vortex_array::scalar::Scalar::bool(value: bool, nullability: vortex_array
1256012512

1256112513
pub fn vortex_array::scalar::Scalar::decimal(value: vortex_array::scalar::DecimalValue, decimal_type: vortex_array::dtype::DecimalDType, nullability: vortex_array::dtype::Nullability) -> Self
1256212514

12563-
pub fn vortex_array::scalar::Scalar::extension<V: vortex_array::dtype::extension::ExtVTable + core::default::Default>(options: <V as vortex_array::dtype::extension::ExtVTable>::Metadata, value: vortex_array::scalar::Scalar) -> Self
12515+
pub fn vortex_array::scalar::Scalar::extension<V: vortex_array::dtype::extension::ExtVTable + core::default::Default>(options: <V as vortex_array::dtype::extension::ExtVTable>::Metadata, storage_scalar: vortex_array::scalar::Scalar) -> Self
1256412516

12565-
pub fn vortex_array::scalar::Scalar::extension_ref(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, value: vortex_array::scalar::Scalar) -> Self
12517+
pub fn vortex_array::scalar::Scalar::extension_ref(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_scalar: vortex_array::scalar::Scalar) -> Self
1256612518

1256712519
pub fn vortex_array::scalar::Scalar::fixed_size_list(element_dtype: impl core::convert::Into<alloc::sync::Arc<vortex_array::dtype::DType>>, children: alloc::vec::Vec<vortex_array::scalar::Scalar>, nullability: vortex_array::dtype::Nullability) -> Self
1256812520

@@ -12586,42 +12538,6 @@ pub fn vortex_array::scalar::Scalar::into_nullable(self) -> vortex_array::scalar
1258612538

1258712539
impl vortex_array::scalar::Scalar
1258812540

12589-
pub fn vortex_array::scalar::Scalar::default_value(dtype: &vortex_array::dtype::DType) -> Self
12590-
12591-
pub fn vortex_array::scalar::Scalar::dtype(&self) -> &vortex_array::dtype::DType
12592-
12593-
pub fn vortex_array::scalar::Scalar::eq_ignore_nullability(&self, other: &Self) -> bool
12594-
12595-
pub fn vortex_array::scalar::Scalar::into_parts(self) -> (vortex_array::dtype::DType, core::option::Option<vortex_array::scalar::ScalarValue>)
12596-
12597-
pub fn vortex_array::scalar::Scalar::into_value(self) -> core::option::Option<vortex_array::scalar::ScalarValue>
12598-
12599-
pub fn vortex_array::scalar::Scalar::is_compatible(dtype: &vortex_array::dtype::DType, value: core::option::Option<&vortex_array::scalar::ScalarValue>) -> bool
12600-
12601-
pub fn vortex_array::scalar::Scalar::is_null(&self) -> bool
12602-
12603-
pub fn vortex_array::scalar::Scalar::is_valid(&self) -> bool
12604-
12605-
pub fn vortex_array::scalar::Scalar::is_zero(&self) -> core::option::Option<bool>
12606-
12607-
pub fn vortex_array::scalar::Scalar::nbytes(&self) -> usize
12608-
12609-
pub unsafe fn vortex_array::scalar::Scalar::new_unchecked(dtype: vortex_array::dtype::DType, value: core::option::Option<vortex_array::scalar::ScalarValue>) -> Self
12610-
12611-
pub fn vortex_array::scalar::Scalar::null(dtype: vortex_array::dtype::DType) -> Self
12612-
12613-
pub fn vortex_array::scalar::Scalar::null_native<T: vortex_array::dtype::NativeDType>() -> Self
12614-
12615-
pub fn vortex_array::scalar::Scalar::primitive_reinterpret_cast(&self, ptype: vortex_array::dtype::PType) -> vortex_error::VortexResult<Self>
12616-
12617-
pub fn vortex_array::scalar::Scalar::try_new(dtype: vortex_array::dtype::DType, value: core::option::Option<vortex_array::scalar::ScalarValue>) -> vortex_error::VortexResult<Self>
12618-
12619-
pub fn vortex_array::scalar::Scalar::value(&self) -> core::option::Option<&vortex_array::scalar::ScalarValue>
12620-
12621-
pub fn vortex_array::scalar::Scalar::zero_value(dtype: &vortex_array::dtype::DType) -> Self
12622-
12623-
impl vortex_array::scalar::Scalar
12624-
1262512541
pub fn vortex_array::scalar::Scalar::from_proto(value: &vortex_proto::scalar::Scalar, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self>
1262612542

1262712543
pub fn vortex_array::scalar::Scalar::from_proto_value(value: &vortex_proto::scalar::ScalarValue, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self>
@@ -12630,6 +12546,10 @@ impl vortex_array::scalar::Scalar
1263012546

1263112547
pub fn vortex_array::scalar::Scalar::struct_(dtype: vortex_array::dtype::DType, children: alloc::vec::Vec<vortex_array::scalar::Scalar>) -> Self
1263212548

12549+
impl vortex_array::scalar::Scalar
12550+
12551+
pub fn vortex_array::scalar::Scalar::validate(dtype: &vortex_array::dtype::DType, value: core::option::Option<&vortex_array::scalar::ScalarValue>) -> vortex_error::VortexResult<()>
12552+
1263312553
impl core::clone::Clone for vortex_array::scalar::Scalar
1263412554

1263512555
pub fn vortex_array::scalar::Scalar::clone(&self) -> vortex_array::scalar::Scalar

vortex-array/src/extension/tests/divisible_int.rs

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55
66
use std::fmt;
77

8-
use vortex_error::VortexExpect;
98
use vortex_error::VortexResult;
109
use vortex_error::vortex_bail;
1110
use vortex_error::vortex_ensure;
1211

1312
use crate::dtype::DType;
14-
use crate::dtype::Nullability;
1513
use crate::dtype::PType;
16-
use crate::dtype::extension::ExtDType;
1714
use crate::dtype::extension::ExtId;
1815
use crate::dtype::extension::ExtVTable;
1916
use crate::scalar::ScalarValue;
@@ -32,14 +29,6 @@ impl fmt::Display for Divisor {
3229
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
3330
pub struct DivisibleInt;
3431

35-
impl DivisibleInt {
36-
/// Creates a new divisible integer extension dtype.
37-
pub fn new(divisor: u64, nullability: Nullability) -> ExtDType<Self> {
38-
ExtDType::try_new(Divisor(divisor), DType::Primitive(PType::U64, nullability))
39-
.vortex_expect("valid divisible int dtype")
40-
}
41-
}
42-
4332
impl ExtVTable for DivisibleInt {
4433
type Metadata = Divisor;
4534
type NativeValue<'a> = u64;
@@ -98,45 +87,6 @@ mod tests {
9887
use crate::dtype::Nullability;
9988
use crate::dtype::PType;
10089
use crate::dtype::extension::ExtVTable;
101-
use crate::scalar::PValue;
102-
use crate::scalar::ScalarValue;
103-
use crate::scalar::extension::ExtScalarValue;
104-
105-
#[test]
106-
fn accepts_divisible_values() -> VortexResult<()> {
107-
let div7 = DivisibleInt::new(7, Nullability::NonNullable);
108-
109-
for multiple in [0, 7, 14, 21, 7000] {
110-
let sv = ExtScalarValue::<DivisibleInt>::try_new(
111-
&div7,
112-
ScalarValue::Primitive(PValue::U64(multiple)),
113-
)?;
114-
assert_eq!(
115-
sv.storage_value(),
116-
&ScalarValue::Primitive(PValue::U64(multiple))
117-
);
118-
}
119-
120-
Ok(())
121-
}
122-
123-
#[test]
124-
fn rejects_non_divisible_values() -> VortexResult<()> {
125-
let div7 = DivisibleInt::new(7, Nullability::NonNullable);
126-
127-
for bad in [1, 2, 6, 8, 13, 15] {
128-
assert!(
129-
ExtScalarValue::<DivisibleInt>::try_new(
130-
&div7,
131-
ScalarValue::Primitive(PValue::U64(bad)),
132-
)
133-
.is_err(),
134-
"{bad} should not be accepted as divisible by 7"
135-
);
136-
}
137-
138-
Ok(())
139-
}
14090

14191
#[test]
14292
fn metadata_roundtrip() -> VortexResult<()> {

vortex-array/src/scalar/cast.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl Scalar {
4141
return Scalar::try_new(target_dtype.clone(), self.value().cloned());
4242
}
4343

44-
// TODO(connor): This isn't really correct but this will get fixed soon.
44+
// TODO(connor): This isn't really correct for extension types.
4545
// If the target is an extension type, then we want to cast to its storage type.
4646
if let Some(ext_dtype) = target_dtype.as_extension_opt() {
4747
let cast_storage_scalar_value = self.cast(ext_dtype.storage_dtype())?.into_value();

vortex-array/src/scalar/constructor.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,21 +171,22 @@ impl Scalar {
171171
}
172172

173173
/// Creates a new extension scalar wrapping the given storage value.
174-
pub fn extension<V: ExtVTable + Default>(options: V::Metadata, value: Scalar) -> Self {
175-
let ext_dtype = ExtDType::<V>::try_new(options, value.dtype().clone())
174+
pub fn extension<V: ExtVTable + Default>(options: V::Metadata, storage_scalar: Scalar) -> Self {
175+
let ext_dtype = ExtDType::<V>::try_new(options, storage_scalar.dtype().clone())
176176
.vortex_expect("Failed to create extension dtype");
177-
Self::try_new(DType::Extension(ext_dtype.erased()), value.into_value())
178-
.vortex_expect("unable to construct an extension `Scalar`")
177+
178+
Self::extension_ref(ext_dtype.erased(), storage_scalar)
179179
}
180180

181181
/// Creates a new extension scalar wrapping the given storage value.
182182
///
183183
/// # Panics
184184
///
185185
/// Panics if the storage dtype of `ext_dtype` does not match `value`'s dtype.
186-
pub fn extension_ref(ext_dtype: ExtDTypeRef, value: Scalar) -> Self {
187-
assert_eq!(ext_dtype.storage_dtype(), value.dtype());
188-
Self::try_new(DType::Extension(ext_dtype), value.into_value())
186+
pub fn extension_ref(ext_dtype: ExtDTypeRef, storage_scalar: Scalar) -> Self {
187+
assert_eq!(ext_dtype.storage_dtype(), storage_scalar.dtype());
188+
189+
Self::try_new(DType::Extension(ext_dtype), storage_scalar.into_value())
189190
.vortex_expect("unable to construct an extension `Scalar`")
190191
}
191192
}

vortex-array/src/scalar/display.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@ impl Display for Scalar {
2020
DType::Binary(_) => write!(f, "{}", self.as_binary()),
2121
DType::Struct(..) => write!(f, "{}", self.as_struct()),
2222
DType::List(..) | DType::FixedSizeList(..) => write!(f, "{}", self.as_list()),
23-
DType::Extension(_) => {
24-
// TODO(connor): This might need to change soon...
25-
write!(f, "{}", self.as_extension())
26-
}
23+
DType::Extension(_) => write!(f, "{}", self.as_extension()),
2724
}
2825
}
2926
}

vortex-array/src/scalar/downcast.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,13 @@ impl Scalar {
145145

146146
/// Returns a view of the scalar as an extension scalar if it has an extension type.
147147
pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
148-
ExtScalar::try_new(self.dtype(), self.value()).ok()
148+
if !self.dtype().is_extension() {
149+
return None;
150+
}
151+
152+
// SAFETY: Because we are a valid Scalar, we have already validated that the value is valid
153+
// for this extension type.
154+
Some(ExtScalar::new_unchecked(self.dtype(), self.value()))
149155
}
150156
}
151157

0 commit comments

Comments
 (0)