Skip to content

Commit a4725d8

Browse files
committed
2 parents d678021 + ebddd5e commit a4725d8

4 files changed

Lines changed: 332 additions & 53 deletions

File tree

src/impl_variant_get.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,24 @@ macro_rules! impl_variant_get_typed {
1212
#[derive(Debug, Hash, PartialEq, Eq)]
1313
pub struct $struct_name {
1414
signature: Signature,
15+
path_mode: crate::variant_get::PathMode,
1516
}
1617

1718
impl Default for $struct_name {
1819
fn default() -> Self {
1920
Self {
2021
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
22+
path_mode: crate::variant_get::PathMode::DotNotation,
23+
}
24+
}
25+
}
26+
27+
impl $struct_name {
28+
/// Create a new instance with the specified path mode.
29+
pub fn with_path_mode(path_mode: crate::variant_get::PathMode) -> Self {
30+
Self {
31+
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
32+
path_mode,
2133
}
2234
}
2335
}
@@ -40,7 +52,7 @@ macro_rules! impl_variant_get_typed {
4052
}
4153

4254
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
43-
invoke_variant_get_typed(args, $scalar_from, $array_from, $extract)
55+
invoke_variant_get_typed(args, $scalar_from, $array_from, $extract, self.path_mode)
4456
}
4557
}
4658
};

src/shared.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ use datafusion::common::exec_datafusion_err;
1111
use datafusion::error::{DataFusionError, Result};
1212
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs};
1313
use datafusion::{common::exec_err, scalar::ScalarValue};
14-
use parquet_variant::{Variant, VariantPath};
14+
use parquet_variant::Variant;
1515
use parquet_variant_compute::{VariantArray, VariantType};
1616

17+
use crate::variant_get::PathMode;
18+
1719
#[cfg(test)]
1820
use parquet_variant_compute::VariantArrayBuilder;
1921

@@ -128,13 +130,14 @@ pub fn variant_get_single_value<T>(
128130
variant_array: &VariantArray,
129131
index: usize,
130132
path: &str,
133+
path_mode: PathMode,
131134
extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result<Option<T>>,
132135
) -> Result<Option<T>> {
133136
let Some(variant) = variant_array.iter().nth(index).flatten() else {
134137
return Ok(None);
135138
};
136139

137-
let variant_path = VariantPath::try_from(path)?;
140+
let variant_path = path_mode.try_build_path(path)?;
138141
let Some(value) = variant.get_path(&variant_path) else {
139142
return Ok(None);
140143
};
@@ -145,9 +148,10 @@ pub fn variant_get_single_value<T>(
145148
pub fn variant_get_array_values<T>(
146149
variant_array: &VariantArray,
147150
path: &str,
151+
path_mode: PathMode,
148152
extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result<Option<T>>,
149153
) -> Result<Vec<Option<T>>> {
150-
let variant_path = VariantPath::try_from(path)?;
154+
let variant_path = path_mode.try_build_path(path)?;
151155

152156
variant_array
153157
.iter()
@@ -170,6 +174,7 @@ pub fn invoke_variant_get_typed<T>(
170174
scalar_from_option: fn(Option<T>) -> ScalarValue,
171175
array_from_values: fn(Vec<Option<T>>) -> ArrayRef,
172176
extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result<Option<T>>,
177+
path_mode: PathMode,
173178
) -> Result<ColumnarValue> {
174179
let (variant_arg, path_arg) = match args.args.as_slice() {
175180
[variant_arg, path_arg] => (variant_arg, path_arg),
@@ -190,7 +195,7 @@ pub fn invoke_variant_get_typed<T>(
190195
.unwrap_or_default();
191196

192197
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
193-
let values = variant_get_array_values(&variant_array, path, extract)?;
198+
let values = variant_get_array_values(&variant_array, path, path_mode, extract)?;
194199
ColumnarValue::Array(array_from_values(values))
195200
}
196201
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(path_scalar)) => {
@@ -203,7 +208,7 @@ pub fn invoke_variant_get_typed<T>(
203208
.unwrap_or_default();
204209

205210
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
206-
let value = variant_get_single_value(&variant_array, 0, path, extract)?;
211+
let value = variant_get_single_value(&variant_array, 0, path, path_mode, extract)?;
207212

208213
ColumnarValue::Scalar(scalar_from_option(value))
209214
}
@@ -218,7 +223,7 @@ pub fn invoke_variant_get_typed<T>(
218223
let values: Vec<Option<T>> = (0..variant_array.len())
219224
.map(|i| {
220225
let path = paths[i].unwrap_or_default();
221-
variant_get_single_value(&variant_array, i, path, extract)
226+
variant_get_single_value(&variant_array, i, path, path_mode, extract)
222227
})
223228
.collect::<Result<_>>()?;
224229

@@ -236,7 +241,7 @@ pub fn invoke_variant_get_typed<T>(
236241
.iter()
237242
.map(|path| {
238243
let path = path.unwrap_or_default();
239-
variant_get_single_value(&variant_array, 0, path, extract)
244+
variant_get_single_value(&variant_array, 0, path, path_mode, extract)
240245
})
241246
.collect::<Result<_>>()?;
242247

src/variant_get.rs

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ fn build_get_options<'a>(path: VariantPath<'a>, as_type: &Option<FieldRef>) -> G
7474

7575
/// Determines how a string path is converted to a [`VariantPath`].
7676
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
77-
enum PathMode {
77+
pub enum PathMode {
7878
/// Splits the path on `.` for dot-notation traversal (e.g., `"a.b.c"` → `["a", "b", "c"]`).
7979
DotNotation,
8080
/// Treats the entire path string as a single field name (e.g., `"a.b.c"` → `["a.b.c"]`).
@@ -84,7 +84,7 @@ enum PathMode {
8484
}
8585

8686
impl PathMode {
87-
fn try_build_path<'a>(&self, path: &'a str) -> Result<VariantPath<'a>> {
87+
pub fn try_build_path<'a>(&self, path: &'a str) -> Result<VariantPath<'a>> {
8888
match self {
8989
PathMode::DotNotation => VariantPath::try_from(path).map_err(Into::into),
9090
PathMode::SingleField => Ok(VariantPath::new(vec![VariantPathElement::field(path)])),
@@ -1556,4 +1556,134 @@ mod tests {
15561556
assert!(bool_arr.is_null(2));
15571557
assert!(bool_arr.is_null(3));
15581558
}
1559+
1560+
#[test]
1561+
fn test_get_str_with_single_field_mode_dotted_key() {
1562+
// VariantGetStrUdf with SingleField mode should treat dotted keys as a single field
1563+
let variant_input = variant_scalar_from_json(serde_json::json!({
1564+
"http.response.status_code": 200,
1565+
"service.name": "my-service"
1566+
}));
1567+
1568+
let udf = VariantGetStrUdf::with_path_mode(PathMode::SingleField);
1569+
let args = build_variant_get_args(
1570+
ColumnarValue::Scalar(variant_input),
1571+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
1572+
"http.response.status_code".to_string(),
1573+
))),
1574+
DataType::Utf8View,
1575+
standard_variant_get_arg_fields(),
1576+
);
1577+
1578+
let result = udf.invoke_with_args(args).unwrap();
1579+
1580+
let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else {
1581+
panic!("expected Utf8View scalar, got {result:?}");
1582+
};
1583+
// Integer 200 should be JSON-serialized to "200"
1584+
assert_eq!(s, "200");
1585+
}
1586+
1587+
#[test]
1588+
fn test_get_str_with_single_field_mode_string_value() {
1589+
let variant_input = variant_scalar_from_json(serde_json::json!({
1590+
"service.name": "my-service"
1591+
}));
1592+
1593+
let udf = VariantGetStrUdf::with_path_mode(PathMode::SingleField);
1594+
let args = build_variant_get_args(
1595+
ColumnarValue::Scalar(variant_input),
1596+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("service.name".to_string()))),
1597+
DataType::Utf8View,
1598+
standard_variant_get_arg_fields(),
1599+
);
1600+
1601+
let result = udf.invoke_with_args(args).unwrap();
1602+
1603+
let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else {
1604+
panic!("expected Utf8View scalar, got {result:?}");
1605+
};
1606+
// String values are returned as-is (no JSON quotes)
1607+
assert_eq!(s, "my-service");
1608+
}
1609+
1610+
#[test]
1611+
fn test_get_str_with_single_field_mode_array() {
1612+
// Test array input with SingleField mode
1613+
let json_rows = vec![
1614+
serde_json::json!({"http.status": 200, "http.method": "GET"}),
1615+
serde_json::json!({"http.status": 404, "http.method": "POST"}),
1616+
serde_json::json!({"http.status": 500}),
1617+
];
1618+
1619+
let variant_array = variant_array_from_json_rows(&json_rows);
1620+
1621+
let udf = VariantGetStrUdf::with_path_mode(PathMode::SingleField);
1622+
let args = build_variant_get_args(
1623+
ColumnarValue::Array(variant_array),
1624+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("http.status".to_string()))),
1625+
DataType::Utf8View,
1626+
standard_variant_get_arg_fields(),
1627+
);
1628+
1629+
let result = udf.invoke_with_args(args).unwrap();
1630+
1631+
let ColumnarValue::Array(arr) = result else {
1632+
panic!("expected array output");
1633+
};
1634+
let str_arr = arr.as_any().downcast_ref::<StringViewArray>().unwrap();
1635+
assert_eq!(str_arr.len(), 3);
1636+
assert_eq!(str_arr.value(0), "200");
1637+
assert_eq!(str_arr.value(1), "404");
1638+
assert_eq!(str_arr.value(2), "500");
1639+
}
1640+
1641+
#[test]
1642+
fn test_get_int_with_single_field_mode() {
1643+
let variant_input = variant_scalar_from_json(serde_json::json!({
1644+
"http.status": 200
1645+
}));
1646+
1647+
let udf = VariantGetIntUdf::with_path_mode(PathMode::SingleField);
1648+
let args = build_variant_get_args(
1649+
ColumnarValue::Scalar(variant_input),
1650+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("http.status".to_string()))),
1651+
DataType::Int64,
1652+
standard_variant_get_arg_fields(),
1653+
);
1654+
1655+
let result = udf.invoke_with_args(args).unwrap();
1656+
1657+
let ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) = result else {
1658+
panic!("expected Int64 scalar, got {result:?}");
1659+
};
1660+
assert_eq!(v, 200);
1661+
}
1662+
1663+
#[test]
1664+
fn test_get_str_dot_notation_splits_dotted_key() {
1665+
// With the default DotNotation mode the path is split on dots, so looking up
1666+
// "http.response.status_code" returns NULL when that name is stored as a single field
1667+
let variant_input = variant_scalar_from_json(serde_json::json!({
1668+
"http.response.status_code": 200
1669+
}));
1670+
1671+
let udf = VariantGetStrUdf::default(); // DotNotation
1672+
let args = build_variant_get_args(
1673+
ColumnarValue::Scalar(variant_input),
1674+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
1675+
"http.response.status_code".to_string(),
1676+
))),
1677+
DataType::Utf8View,
1678+
standard_variant_get_arg_fields(),
1679+
);
1680+
1681+
let result = udf.invoke_with_args(args).unwrap();
1682+
1683+
// DotNotation splits the path on dots and tries to traverse http -> response -> status_code;
1684+
// that nested path doesn't exist, so the result is NULL
1685+
let ColumnarValue::Scalar(ScalarValue::Utf8View(None)) = result else {
1686+
panic!("expected NULL (dot notation splits the key), got {result:?}");
1687+
};
1688+
}
15591689
}

0 commit comments

Comments
 (0)