Skip to content

Commit 41237c0

Browse files
Support more cast_to_variant arguments
1 parent b394b75 commit 41237c0

2 files changed

Lines changed: 83 additions & 38 deletions

File tree

src/cast_to_variant.rs

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use datafusion::{
1313
use parquet_variant::Variant;
1414
use parquet_variant_compute::{VariantArray, VariantArrayBuilder};
1515

16-
use crate::shared::{try_field_as_binary, try_parse_binary_columnar, try_parse_binary_scalar};
16+
use crate::shared::{try_parse_binary_columnar, try_parse_binary_scalar};
1717

1818
#[derive(Debug, Hash, PartialEq, Eq)]
1919
pub struct CastToVariantUdf {
@@ -23,48 +23,26 @@ pub struct CastToVariantUdf {
2323
impl Default for CastToVariantUdf {
2424
fn default() -> Self {
2525
Self {
26-
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
26+
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
2727
}
2828
}
2929
}
3030

31-
impl ScalarUDFImpl for CastToVariantUdf {
32-
fn as_any(&self) -> &dyn std::any::Any {
33-
self
34-
}
35-
36-
fn name(&self) -> &str {
37-
"cast_to_variant"
38-
}
39-
40-
fn signature(&self) -> &Signature {
41-
&self.signature
42-
}
43-
44-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
45-
Ok(DataType::Struct(Fields::from(vec![
46-
Field::new("metadata", DataType::BinaryView, false),
47-
Field::new("value", DataType::BinaryView, true),
48-
])))
49-
}
31+
fn build_variant_array<'m, 'v, T: Into<Variant<'m, 'v>>>(
32+
value_opt: Option<T>,
33+
) -> Result<ColumnarValue> {
34+
let variant_array = VariantArray::from_iter([value_opt.map(|v| v.into())]).into();
5035

51-
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
52-
match args.arg_fields.as_slice() {
53-
[metadata_field, variant_field] => {
54-
try_field_as_binary(metadata_field.as_ref())?;
55-
try_field_as_binary(variant_field.as_ref())?;
56-
}
57-
_ => {
58-
// right now, let's only support (BinaryViewArray, BinaryViewArray)
59-
// but I don't see why we couldn't call cast_to_variant(string_column) -> VariantArray...
60-
return exec_err!("unsupported, expected 2 arguments");
61-
}
62-
}
63-
64-
let [metadata_argument, variant_argument] = args.args.as_slice() else {
65-
return exec_err!("expected 2 arguments");
66-
};
36+
Ok(ColumnarValue::Scalar(ScalarValue::Struct(Arc::new(
37+
variant_array,
38+
))))
39+
}
6740

41+
impl CastToVariantUdf {
42+
fn from_metadata_value(
43+
metadata_argument: &ColumnarValue,
44+
variant_argument: &ColumnarValue,
45+
) -> Result<ColumnarValue> {
6846
let out = match (metadata_argument, variant_argument) {
6947
(ColumnarValue::Array(metadata_array), ColumnarValue::Array(value_array)) => {
7048
if metadata_array.len() != value_array.len() {
@@ -154,6 +132,73 @@ impl ScalarUDFImpl for CastToVariantUdf {
154132

155133
Ok(out)
156134
}
135+
136+
fn from_array(_array: &ArrayRef) -> Result<ColumnarValue> {
137+
todo!()
138+
}
139+
140+
fn from_scalar_value(scalar_value: &ScalarValue) -> Result<ColumnarValue> {
141+
match scalar_value {
142+
ScalarValue::Null => build_variant_array(Some(Variant::Null)),
143+
// String values
144+
ScalarValue::Utf8(string_opt)
145+
| ScalarValue::Utf8View(string_opt)
146+
| ScalarValue::LargeUtf8(string_opt) => {
147+
build_variant_array(string_opt.as_ref().map(|s| s.as_str()))
148+
}
149+
// Binary values
150+
ScalarValue::Binary(binary_opt)
151+
| ScalarValue::BinaryView(binary_opt)
152+
| ScalarValue::LargeBinary(binary_opt) => {
153+
build_variant_array(binary_opt.as_ref().map(|b| b.as_slice()))
154+
}
155+
// Boolean
156+
ScalarValue::Boolean(b) => build_variant_array(b.as_ref().map(|b| *b)),
157+
// Numbers
158+
ScalarValue::Int8(i) => build_variant_array(i.as_ref().map(|i| *i)),
159+
ScalarValue::Int16(i) => build_variant_array(i.as_ref().map(|i| *i)),
160+
ScalarValue::Int32(i) => build_variant_array(i.as_ref().map(|i| *i)),
161+
ScalarValue::Int64(i) => build_variant_array(i.as_ref().map(|i| *i)),
162+
ScalarValue::UInt8(i) => build_variant_array(i.as_ref().map(|i| *i)),
163+
ScalarValue::UInt16(i) => build_variant_array(i.as_ref().map(|i| *i)),
164+
ScalarValue::UInt32(i) => build_variant_array(i.as_ref().map(|i| *i)),
165+
ScalarValue::UInt64(i) => build_variant_array(i.as_ref().map(|i| *i)),
166+
167+
_ => todo!(),
168+
}
169+
}
170+
}
171+
172+
impl ScalarUDFImpl for CastToVariantUdf {
173+
fn as_any(&self) -> &dyn std::any::Any {
174+
self
175+
}
176+
177+
fn name(&self) -> &str {
178+
"cast_to_variant"
179+
}
180+
181+
fn signature(&self) -> &Signature {
182+
&self.signature
183+
}
184+
185+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
186+
Ok(DataType::Struct(Fields::from(vec![
187+
Field::new("metadata", DataType::BinaryView, false),
188+
Field::new("value", DataType::BinaryView, true),
189+
])))
190+
}
191+
192+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
193+
match args.args.as_slice() {
194+
[metadata_value, variant_value] => {
195+
Self::from_metadata_value(metadata_value, variant_value)
196+
}
197+
[ColumnarValue::Scalar(scalar_value)] => Self::from_scalar_value(scalar_value),
198+
[ColumnarValue::Array(array)] => Self::from_array(array),
199+
_ => exec_err!("unrecognized argument"),
200+
}
201+
}
157202
}
158203

159204
#[cfg(test)]

src/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub fn try_field_as_variant_array(field: &Field) -> Result<()> {
2626
Ok(())
2727
}
2828

29-
pub fn try_field_as_binary(field: &Field) -> Result<()> {
29+
pub fn _try_field_as_binary(field: &Field) -> Result<()> {
3030
match field.data_type() {
3131
DataType::Binary | DataType::BinaryView | DataType::LargeBinary => {}
3232
unsupported => return exec_err!("expected binary field, got {unsupported} field"),

0 commit comments

Comments
 (0)