Skip to content

Commit 957fad8

Browse files
committed
feat: implement variant_length UDF
Closes #41 Returns the number of elements in a Variant Array or Object at the given path. Returns NULL for non-array/object variants or NULL inputs. - variant_length(VariantArray, path) -> UInt64 (Nullable) - Supports scalar and columnar variant inputs with scalar path - 5 unit tests covering array, object, nested path, primitive->NULL, columnar
1 parent 32353a2 commit 957fad8

2 files changed

Lines changed: 251 additions & 0 deletions

File tree

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod is_variant_null;
88
mod json_to_variant;
99
mod variant_contains;
1010
mod variant_get;
11+
mod variant_length;
1112
mod variant_list_construct;
1213
mod variant_list_delete;
1314
mod variant_list_insert;
@@ -24,6 +25,7 @@ pub use is_variant_null::*;
2425
pub use json_to_variant::*;
2526
pub use variant_contains::*;
2627
pub use variant_get::*;
28+
pub use variant_length::*;
2729
pub use variant_list_construct::*;
2830
pub use variant_list_delete::*;
2931
pub use variant_list_insert::*;

src/variant_length.rs

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
use std::sync::Arc;
2+
3+
use arrow::array::{ArrayRef, UInt64Array};
4+
use arrow_schema::DataType;
5+
use datafusion::common::exec_datafusion_err;
6+
use datafusion::error::Result;
7+
use datafusion::logical_expr::{
8+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
9+
};
10+
use datafusion::scalar::ScalarValue;
11+
use parquet_variant::Variant;
12+
use parquet_variant_compute::VariantArray;
13+
14+
use crate::shared::{path_from_scalar, try_field_as_variant_array};
15+
16+
#[derive(Debug, Hash, PartialEq, Eq)]
17+
pub struct VariantLengthUdf {
18+
signature: Signature,
19+
}
20+
21+
impl Default for VariantLengthUdf {
22+
fn default() -> Self {
23+
Self {
24+
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
25+
}
26+
}
27+
}
28+
29+
/// Returns the number of elements in an Array or Object variant at the given path.
30+
/// Returns NULL if the variant at the path is not an Array or Object.
31+
fn variant_length(variant: Option<Variant<'_, '_>>) -> Option<u64> {
32+
let variant = variant?;
33+
match variant {
34+
Variant::List(list) => Some(list.len() as u64),
35+
Variant::Object(obj) => Some(obj.len() as u64),
36+
_ => None,
37+
}
38+
}
39+
40+
impl ScalarUDFImpl for VariantLengthUdf {
41+
fn as_any(&self) -> &dyn std::any::Any {
42+
self
43+
}
44+
45+
fn name(&self) -> &str {
46+
"variant_length"
47+
}
48+
49+
fn signature(&self) -> &Signature {
50+
&self.signature
51+
}
52+
53+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
54+
Ok(DataType::UInt64)
55+
}
56+
57+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
58+
let (variant_arg, path_arg) = match args.args.as_slice() {
59+
[variant_arg, path_arg] => (variant_arg, path_arg),
60+
_ => return datafusion::common::exec_err!("expected 2 arguments"),
61+
};
62+
63+
let variant_field = args
64+
.arg_fields
65+
.first()
66+
.ok_or_else(|| exec_datafusion_err!("expected argument field"))?;
67+
68+
try_field_as_variant_array(variant_field.as_ref())?;
69+
70+
match (variant_arg, path_arg) {
71+
(ColumnarValue::Array(variant_array), ColumnarValue::Scalar(path_scalar)) => {
72+
let path = path_from_scalar(path_scalar)?;
73+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
74+
let values = variant_array
75+
.iter()
76+
.map(|variant| {
77+
let at_path = variant.as_ref().and_then(|v| v.get_path(&path));
78+
variant_length(at_path)
79+
})
80+
.collect::<Vec<_>>();
81+
82+
Ok(ColumnarValue::Array(
83+
Arc::new(UInt64Array::from(values)) as ArrayRef,
84+
))
85+
}
86+
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(path_scalar)) => {
87+
let ScalarValue::Struct(variant_array) = scalar_variant else {
88+
return datafusion::common::exec_err!("expected struct array");
89+
};
90+
91+
let path = path_from_scalar(path_scalar)?;
92+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
93+
let variant = variant_array.iter().next().flatten();
94+
let at_path = variant.as_ref().and_then(|v| v.get_path(&path));
95+
let value = variant_length(at_path);
96+
97+
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(value)))
98+
}
99+
_ => datafusion::common::exec_err!(
100+
"unsupported argument combination: variant_length requires a scalar path"
101+
),
102+
}
103+
}
104+
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use std::sync::Arc;
109+
110+
use arrow::array::{Array, ArrayRef, UInt64Array};
111+
use arrow_schema::{DataType, Field, Fields};
112+
use parquet_variant_compute::VariantType;
113+
114+
use crate::shared::{build_variant_array_from_json_array, variant_scalar_from_json};
115+
116+
use super::*;
117+
118+
fn arg_fields() -> Vec<Arc<Field>> {
119+
vec![
120+
Arc::new(
121+
Field::new("input", DataType::Struct(Fields::empty()), true)
122+
.with_extension_type(VariantType),
123+
),
124+
Arc::new(Field::new("path", DataType::Utf8, true)),
125+
]
126+
}
127+
128+
#[test]
129+
fn test_scalar_array_length() {
130+
let udf = VariantLengthUdf::default();
131+
let args = ScalarFunctionArgs {
132+
args: vec![
133+
ColumnarValue::Scalar(variant_scalar_from_json(serde_json::json!([1, 2, 3]))),
134+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string()))),
135+
],
136+
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
137+
arg_fields: arg_fields(),
138+
number_rows: Default::default(),
139+
config_options: Default::default(),
140+
};
141+
142+
let result = udf.invoke_with_args(args).unwrap();
143+
let ColumnarValue::Scalar(ScalarValue::UInt64(Some(value))) = result else {
144+
panic!("expected u64 scalar")
145+
};
146+
assert_eq!(value, 3);
147+
}
148+
149+
#[test]
150+
fn test_scalar_object_length() {
151+
let udf = VariantLengthUdf::default();
152+
let args = ScalarFunctionArgs {
153+
args: vec![
154+
ColumnarValue::Scalar(variant_scalar_from_json(
155+
serde_json::json!({"a": 1, "b": 2}),
156+
)),
157+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string()))),
158+
],
159+
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
160+
arg_fields: arg_fields(),
161+
number_rows: Default::default(),
162+
config_options: Default::default(),
163+
};
164+
165+
let result = udf.invoke_with_args(args).unwrap();
166+
let ColumnarValue::Scalar(ScalarValue::UInt64(Some(value))) = result else {
167+
panic!("expected u64 scalar")
168+
};
169+
assert_eq!(value, 2);
170+
}
171+
172+
#[test]
173+
fn test_scalar_nested_path() {
174+
let udf = VariantLengthUdf::default();
175+
let args = ScalarFunctionArgs {
176+
args: vec![
177+
ColumnarValue::Scalar(variant_scalar_from_json(
178+
serde_json::json!({"a": [1, 2, 3, 4]}),
179+
)),
180+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
181+
],
182+
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
183+
arg_fields: arg_fields(),
184+
number_rows: Default::default(),
185+
config_options: Default::default(),
186+
};
187+
188+
let result = udf.invoke_with_args(args).unwrap();
189+
let ColumnarValue::Scalar(ScalarValue::UInt64(Some(value))) = result else {
190+
panic!("expected u64 scalar")
191+
};
192+
assert_eq!(value, 4);
193+
}
194+
195+
#[test]
196+
fn test_scalar_primitive_returns_null() {
197+
let udf = VariantLengthUdf::default();
198+
let args = ScalarFunctionArgs {
199+
args: vec![
200+
ColumnarValue::Scalar(variant_scalar_from_json(serde_json::json!(42))),
201+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string()))),
202+
],
203+
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
204+
arg_fields: arg_fields(),
205+
number_rows: Default::default(),
206+
config_options: Default::default(),
207+
};
208+
209+
let result = udf.invoke_with_args(args).unwrap();
210+
let ColumnarValue::Scalar(ScalarValue::UInt64(value)) = result else {
211+
panic!("expected u64 scalar")
212+
};
213+
assert_eq!(value, None);
214+
}
215+
216+
#[test]
217+
fn test_array_variants() {
218+
let udf = VariantLengthUdf::default();
219+
let input = build_variant_array_from_json_array(&[
220+
Some(serde_json::json!([1, 2])),
221+
Some(serde_json::json!({"a": 1, "b": 2, "c": 3})),
222+
Some(serde_json::json!(42)),
223+
None,
224+
]);
225+
let args = ScalarFunctionArgs {
226+
args: vec![
227+
ColumnarValue::Array(
228+
Arc::new(arrow::array::StructArray::from(input)) as ArrayRef
229+
),
230+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string()))),
231+
],
232+
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
233+
arg_fields: arg_fields(),
234+
number_rows: Default::default(),
235+
config_options: Default::default(),
236+
};
237+
238+
let result = udf.invoke_with_args(args).unwrap();
239+
let ColumnarValue::Array(values) = result else {
240+
panic!("expected array")
241+
};
242+
243+
let values = values.as_any().downcast_ref::<UInt64Array>().unwrap();
244+
assert_eq!(
245+
values.into_iter().collect::<Vec<_>>(),
246+
vec![Some(2), Some(3), None, None]
247+
);
248+
}
249+
}

0 commit comments

Comments
 (0)