Skip to content

Commit 92eb5a3

Browse files
getChanCopilot
andauthored
Add variant_contains UDF (#59)
* Add variant_contains UDF * chore: remove .idea from gitignore Agent-Logs-Url: https://github.com/getChan/datafusion-variant/sessions/4c154f78-d449-4fc1-b0da-6c643e55e364 Co-authored-by: getChan <33323415+getChan@users.noreply.github.com> * cargo fmt --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
1 parent a334066 commit 92eb5a3

7 files changed

Lines changed: 360 additions & 9 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
/target
22
/data/*
3-
profile.json.gz
3+
profile.json.gz

examples/cli.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use arrow::datatypes::{DataType, Field, Schema};
44
use datafusion::logical_expr::ScalarUDF;
55
use datafusion::prelude::*;
66
use datafusion_variant::{
7-
CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetUdf, VariantListConstruct,
8-
VariantListInsert, VariantObjectConstruct, VariantObjectInsert, VariantObjectKeys,
9-
VariantPretty, VariantToJsonUdf,
7+
CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantContainsUdf, VariantGetUdf,
8+
VariantListConstruct, VariantListInsert, VariantObjectConstruct, VariantObjectInsert,
9+
VariantObjectKeys, VariantPretty, VariantToJsonUdf,
1010
};
1111
use flate2::read::GzDecoder;
1212
use rustyline::error::ReadlineError;
@@ -112,6 +112,7 @@ async fn main() -> Result<()> {
112112
ctx.register_udf(ScalarUDF::new_from_impl(JsonToVariantUdf::default()));
113113
ctx.register_udf(ScalarUDF::new_from_impl(CastToVariantUdf::default()));
114114
ctx.register_udf(ScalarUDF::new_from_impl(IsVariantNullUdf::default()));
115+
ctx.register_udf(ScalarUDF::new_from_impl(VariantContainsUdf::default()));
115116
ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default()));
116117
ctx.register_udf(ScalarUDF::new_from_impl(VariantPretty::default()));
117118
ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectConstruct::default()));

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod cast_to_variant;
66
mod impl_variant_get;
77
mod is_variant_null;
88
mod json_to_variant;
9+
mod variant_contains;
910
mod variant_get;
1011
mod variant_list_construct;
1112
mod variant_list_delete;
@@ -21,6 +22,7 @@ mod variant_to_json;
2122
pub use cast_to_variant::*;
2223
pub use is_variant_null::*;
2324
pub use json_to_variant::*;
25+
pub use variant_contains::*;
2426
pub use variant_get::*;
2527
pub use variant_list_construct::*;
2628
pub use variant_list_delete::*;

src/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ pub fn variant_get_array_values<T>(
168168
/// - **List** scalars treat each element as a single field name
169169
/// (e.g. `['a.b', 'c']` → path `[a.b, c]`), which is critical for keys that
170170
/// contain dots such as OTEL attribute keys like `http.response.status_code`.
171-
fn path_from_scalar(scalar: &ScalarValue) -> Result<VariantPath<'static>> {
171+
pub(crate) fn path_from_scalar(scalar: &ScalarValue) -> Result<VariantPath<'static>> {
172172
match scalar {
173173
ScalarValue::Utf8(Some(s))
174174
| ScalarValue::Utf8View(Some(s))

src/variant_contains.rs

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
use std::sync::Arc;
2+
3+
use arrow::array::{ArrayRef, BooleanArray};
4+
use arrow_schema::DataType;
5+
use datafusion::common::{exec_datafusion_err, exec_err};
6+
use datafusion::error::{DataFusionError, Result};
7+
use datafusion::logical_expr::{
8+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
9+
};
10+
use datafusion::scalar::ScalarValue;
11+
use parquet_variant::{Variant, VariantPath};
12+
use parquet_variant_compute::VariantArray;
13+
14+
use crate::shared::{path_from_scalar, try_field_as_variant_array, try_parse_string_columnar};
15+
16+
#[derive(Debug, Hash, PartialEq, Eq)]
17+
pub struct VariantContainsUdf {
18+
signature: Signature,
19+
}
20+
21+
impl Default for VariantContainsUdf {
22+
fn default() -> Self {
23+
Self {
24+
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
25+
}
26+
}
27+
}
28+
29+
fn variant_contains(variant: Option<&Variant<'_, '_>>, path: &VariantPath<'_>) -> Option<bool> {
30+
variant.map(|value| value.get_path(path).is_some())
31+
}
32+
33+
impl ScalarUDFImpl for VariantContainsUdf {
34+
fn as_any(&self) -> &dyn std::any::Any {
35+
self
36+
}
37+
38+
fn name(&self) -> &str {
39+
"variant_contains"
40+
}
41+
42+
fn signature(&self) -> &Signature {
43+
&self.signature
44+
}
45+
46+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
47+
Ok(DataType::Boolean)
48+
}
49+
50+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
51+
let (variant_arg, path_arg) = match args.args.as_slice() {
52+
[variant_arg, path_arg] => (variant_arg, path_arg),
53+
_ => return exec_err!("expected 2 arguments"),
54+
};
55+
56+
let variant_field = args
57+
.arg_fields
58+
.first()
59+
.ok_or_else(|| exec_datafusion_err!("expected argument field"))?;
60+
61+
try_field_as_variant_array(variant_field.as_ref())?;
62+
63+
match (variant_arg, path_arg) {
64+
(ColumnarValue::Array(variant_array), ColumnarValue::Scalar(path_scalar)) => {
65+
if path_scalar.is_null() {
66+
return exec_err!("path argument must be non-null");
67+
}
68+
69+
let path = path_from_scalar(path_scalar)?;
70+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
71+
let values = variant_array
72+
.iter()
73+
.map(|variant| variant_contains(variant.as_ref(), &path))
74+
.collect::<Vec<_>>();
75+
76+
Ok(ColumnarValue::Array(
77+
Arc::new(BooleanArray::from(values)) as ArrayRef
78+
))
79+
}
80+
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(path_scalar)) => {
81+
let ScalarValue::Struct(variant_array) = scalar_variant else {
82+
return exec_err!("expected struct array");
83+
};
84+
85+
if path_scalar.is_null() {
86+
return exec_err!("path argument must be non-null");
87+
}
88+
89+
let path = path_from_scalar(path_scalar)?;
90+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
91+
let variant = variant_array.iter().next().flatten();
92+
let value = variant_contains(variant.as_ref(), &path);
93+
94+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(value)))
95+
}
96+
(ColumnarValue::Array(variant_array), ColumnarValue::Array(paths)) => {
97+
if variant_array.len() != paths.len() {
98+
return exec_err!("expected variant array and paths to be of same length");
99+
}
100+
101+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
102+
let paths = try_parse_string_columnar(paths)?;
103+
104+
let values = variant_array
105+
.iter()
106+
.zip(paths)
107+
.map(|(maybe_variant, path_str)| {
108+
let path_str = path_str.ok_or_else(|| {
109+
exec_datafusion_err!("path argument must be non-null")
110+
})?;
111+
let path = VariantPath::try_from(path_str)
112+
.map_err(Into::<DataFusionError>::into)?;
113+
114+
Ok(variant_contains(maybe_variant.as_ref(), &path))
115+
})
116+
.collect::<Result<Vec<_>>>()?;
117+
118+
Ok(ColumnarValue::Array(
119+
Arc::new(BooleanArray::from(values)) as ArrayRef
120+
))
121+
}
122+
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(paths)) => {
123+
let ScalarValue::Struct(variant_array) = scalar_variant else {
124+
return exec_err!("expected struct array");
125+
};
126+
127+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
128+
let variant = variant_array.iter().next().flatten();
129+
let paths = try_parse_string_columnar(paths)?;
130+
131+
let values = paths
132+
.into_iter()
133+
.map(|path_str| {
134+
let path_str = path_str.ok_or_else(|| {
135+
exec_datafusion_err!("path argument must be non-null")
136+
})?;
137+
let path = VariantPath::try_from(path_str)
138+
.map_err(Into::<DataFusionError>::into)?;
139+
140+
Ok(variant_contains(variant.as_ref(), &path))
141+
})
142+
.collect::<Result<Vec<_>>>()?;
143+
144+
Ok(ColumnarValue::Array(
145+
Arc::new(BooleanArray::from(values)) as ArrayRef
146+
))
147+
}
148+
}
149+
}
150+
}
151+
152+
#[cfg(test)]
153+
mod tests {
154+
use std::sync::Arc;
155+
156+
use arrow::array::{Array, ArrayRef, BooleanArray, StringArray};
157+
use arrow_schema::{Field, Fields};
158+
use parquet_variant_compute::VariantType;
159+
160+
use crate::shared::{build_variant_array_from_json_array, variant_scalar_from_json};
161+
162+
use super::*;
163+
164+
fn arg_fields() -> Vec<Arc<Field>> {
165+
vec![
166+
Arc::new(
167+
Field::new("input", DataType::Struct(Fields::empty()), true)
168+
.with_extension_type(VariantType),
169+
),
170+
Arc::new(Field::new("path", DataType::Utf8, true)),
171+
]
172+
}
173+
174+
#[test]
175+
fn test_scalar_existing_path_returns_true() {
176+
let udf = VariantContainsUdf::default();
177+
let args = ScalarFunctionArgs {
178+
args: vec![
179+
ColumnarValue::Scalar(variant_scalar_from_json(serde_json::json!({
180+
"a": {"b": null}
181+
}))),
182+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a.b".to_string()))),
183+
],
184+
return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
185+
arg_fields: arg_fields(),
186+
number_rows: Default::default(),
187+
config_options: Default::default(),
188+
};
189+
190+
let result = udf.invoke_with_args(args).unwrap();
191+
let ColumnarValue::Scalar(ScalarValue::Boolean(Some(value))) = result else {
192+
panic!("expected boolean scalar")
193+
};
194+
195+
assert!(value);
196+
}
197+
198+
#[test]
199+
fn test_scalar_missing_path_returns_false() {
200+
let udf = VariantContainsUdf::default();
201+
let args = ScalarFunctionArgs {
202+
args: vec![
203+
ColumnarValue::Scalar(variant_scalar_from_json(serde_json::json!({
204+
"a": 1
205+
}))),
206+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a.b".to_string()))),
207+
],
208+
return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
209+
arg_fields: arg_fields(),
210+
number_rows: Default::default(),
211+
config_options: Default::default(),
212+
};
213+
214+
let result = udf.invoke_with_args(args).unwrap();
215+
let ColumnarValue::Scalar(ScalarValue::Boolean(Some(value))) = result else {
216+
panic!("expected boolean scalar")
217+
};
218+
219+
assert!(!value);
220+
}
221+
222+
#[test]
223+
fn test_array_paths_and_null_variant() {
224+
let udf = VariantContainsUdf::default();
225+
let input = build_variant_array_from_json_array(&[
226+
Some(serde_json::json!({"a": 1})),
227+
Some(serde_json::json!({"a": null})),
228+
None,
229+
]);
230+
let args = ScalarFunctionArgs {
231+
args: vec![
232+
ColumnarValue::Array(Arc::new(arrow::array::StructArray::from(input)) as ArrayRef),
233+
ColumnarValue::Array(Arc::new(StringArray::from(vec![
234+
Some("a"),
235+
Some("a"),
236+
Some("a"),
237+
])) as ArrayRef),
238+
],
239+
return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
240+
arg_fields: arg_fields(),
241+
number_rows: Default::default(),
242+
config_options: Default::default(),
243+
};
244+
245+
let result = udf.invoke_with_args(args).unwrap();
246+
let ColumnarValue::Array(values) = result else {
247+
panic!("expected boolean array")
248+
};
249+
250+
let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
251+
assert_eq!(
252+
values.into_iter().collect::<Vec<_>>(),
253+
vec![Some(true), Some(true), None]
254+
);
255+
}
256+
257+
#[test]
258+
fn test_array_variant_scalar_path() {
259+
let udf = VariantContainsUdf::default();
260+
let input = build_variant_array_from_json_array(&[
261+
Some(serde_json::json!({"a": 1})),
262+
Some(serde_json::json!({"b": 1})),
263+
]);
264+
let args = ScalarFunctionArgs {
265+
args: vec![
266+
ColumnarValue::Array(Arc::new(arrow::array::StructArray::from(input)) as ArrayRef),
267+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
268+
],
269+
return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
270+
arg_fields: arg_fields(),
271+
number_rows: Default::default(),
272+
config_options: Default::default(),
273+
};
274+
275+
let result = udf.invoke_with_args(args).unwrap();
276+
let ColumnarValue::Array(values) = result else {
277+
panic!("expected boolean array")
278+
};
279+
280+
let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
281+
assert_eq!(
282+
values.into_iter().collect::<Vec<_>>(),
283+
vec![Some(true), Some(false)]
284+
);
285+
}
286+
}

tests/sqllogictests.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use datafusion::{logical_expr::ScalarUDF, prelude::*};
22
use datafusion_sqllogictest::{DataFusion, TestContext};
33
use datafusion_variant::{
4-
CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetBoolUdf, VariantGetFieldUdf,
5-
VariantGetFloatUdf, VariantGetIntUdf, VariantGetJsonUdf, VariantGetStrUdf, VariantGetUdf,
6-
VariantListConstruct, VariantListDelete, VariantListInsert, VariantObjectConstruct,
7-
VariantObjectDelete, VariantObjectInsert, VariantObjectKeys, VariantPretty, VariantToJsonUdf,
4+
CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantContainsUdf, VariantGetBoolUdf,
5+
VariantGetFieldUdf, VariantGetFloatUdf, VariantGetIntUdf, VariantGetJsonUdf, VariantGetStrUdf,
6+
VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert,
7+
VariantObjectConstruct, VariantObjectDelete, VariantObjectInsert, VariantObjectKeys,
8+
VariantPretty, VariantToJsonUdf,
89
};
910
use indicatif::ProgressBar;
1011
use sqllogictest::strict_column_validator;
@@ -49,6 +50,7 @@ async fn run_sqllogictests() -> Result<(), Box<dyn std::error::Error>> {
4950
ctx.register_udf(ScalarUDF::new_from_impl(JsonToVariantUdf::default()));
5051
ctx.register_udf(ScalarUDF::new_from_impl(CastToVariantUdf::default()));
5152
ctx.register_udf(ScalarUDF::new_from_impl(IsVariantNullUdf::default()));
53+
ctx.register_udf(ScalarUDF::new_from_impl(VariantContainsUdf::default()));
5254
ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default()));
5355
ctx.register_udf(ScalarUDF::new_from_impl(VariantGetStrUdf::default()));
5456
ctx.register_udf(ScalarUDF::new_from_impl(VariantGetFloatUdf::default()));

0 commit comments

Comments
 (0)