Skip to content

Commit a22cf08

Browse files
authored
implement variant_get_int (#51)
* implement variant_get_int * move common variant_get_typed helpers to shared.rs * variant_get could use the same helpers in tests
1 parent 27aef49 commit a22cf08

7 files changed

Lines changed: 585 additions & 224 deletions

File tree

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod cast_to_variant;
66
mod is_variant_null;
77
mod json_to_variant;
88
mod variant_get;
9+
mod variant_get_int;
910
mod variant_get_str;
1011
mod variant_list_construct;
1112
mod variant_list_delete;
@@ -21,6 +22,7 @@ pub use cast_to_variant::*;
2122
pub use is_variant_null::*;
2223
pub use json_to_variant::*;
2324
pub use variant_get::*;
25+
pub use variant_get_int::*;
2426
pub use variant_get_str::*;
2527
pub use variant_list_construct::*;
2628
pub use variant_list_delete::*;

src/shared.rs

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
use std::sync::Arc;
22

3-
use arrow::array::{Array, cast::AsArray};
3+
#[cfg(test)]
4+
use arrow::array::StructArray;
5+
use arrow::array::{Array, ArrayRef, cast::AsArray};
6+
#[cfg(test)]
7+
use arrow_schema::Fields;
48
use arrow_schema::extension::ExtensionType;
59
use arrow_schema::{DataType, Field};
610
use datafusion::common::exec_datafusion_err;
711
use datafusion::error::Result;
12+
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs};
813
use datafusion::{common::exec_err, scalar::ScalarValue};
14+
use parquet_variant::{Variant, VariantPath};
915
use parquet_variant_compute::{VariantArray, VariantType};
1016

1117
#[cfg(test)]
@@ -118,6 +124,129 @@ pub fn try_parse_string_columnar(array: &Arc<dyn Array>) -> Result<Vec<Option<&s
118124
Err(exec_datafusion_err!("expected string array"))
119125
}
120126

127+
pub fn variant_get_single_value<T>(
128+
variant_array: &VariantArray,
129+
index: usize,
130+
path: &str,
131+
extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result<Option<T>>,
132+
) -> Result<Option<T>> {
133+
let Some(variant) = variant_array.iter().nth(index).flatten() else {
134+
return Ok(None);
135+
};
136+
137+
let variant_path = VariantPath::from(path);
138+
let Some(value) = variant.get_path(&variant_path) else {
139+
return Ok(None);
140+
};
141+
142+
extract(value)
143+
}
144+
145+
pub fn variant_get_array_values<T>(
146+
variant_array: &VariantArray,
147+
path: &str,
148+
extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result<Option<T>>,
149+
) -> Result<Vec<Option<T>>> {
150+
let variant_path = VariantPath::from(path);
151+
152+
variant_array
153+
.iter()
154+
.map(|maybe_variant| {
155+
let Some(variant) = maybe_variant else {
156+
return Ok(None);
157+
};
158+
159+
let Some(value) = variant.get_path(&variant_path) else {
160+
return Ok(None);
161+
};
162+
163+
extract(value)
164+
})
165+
.collect()
166+
}
167+
168+
pub fn invoke_variant_get_typed<T>(
169+
args: ScalarFunctionArgs,
170+
scalar_from_option: fn(Option<T>) -> ScalarValue,
171+
array_from_values: fn(Vec<Option<T>>) -> ArrayRef,
172+
extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result<Option<T>>,
173+
) -> Result<ColumnarValue> {
174+
let (variant_arg, path_arg) = match args.args.as_slice() {
175+
[variant_arg, path_arg] => (variant_arg, path_arg),
176+
_ => return exec_err!("expected 2 arguments"),
177+
};
178+
179+
let variant_field = args
180+
.arg_fields
181+
.first()
182+
.ok_or_else(|| exec_datafusion_err!("expected argument field"))?;
183+
184+
try_field_as_variant_array(variant_field.as_ref())?;
185+
186+
let out = match (variant_arg, path_arg) {
187+
(ColumnarValue::Array(variant_array), ColumnarValue::Scalar(path_scalar)) => {
188+
let path = try_parse_string_scalar(path_scalar)?
189+
.map(|s| s.as_str())
190+
.unwrap_or_default();
191+
192+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
193+
let values = variant_get_array_values(&variant_array, path, extract)?;
194+
ColumnarValue::Array(array_from_values(values))
195+
}
196+
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(path_scalar)) => {
197+
let ScalarValue::Struct(variant_array) = scalar_variant else {
198+
return exec_err!("expected struct array");
199+
};
200+
201+
let path = try_parse_string_scalar(path_scalar)?
202+
.map(|s| s.as_str())
203+
.unwrap_or_default();
204+
205+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
206+
let value = variant_get_single_value(&variant_array, 0, path, extract)?;
207+
208+
ColumnarValue::Scalar(scalar_from_option(value))
209+
}
210+
(ColumnarValue::Array(variant_array), ColumnarValue::Array(paths)) => {
211+
if variant_array.len() != paths.len() {
212+
return exec_err!("expected variant array and paths to be of same length");
213+
}
214+
215+
let paths = try_parse_string_columnar(paths)?;
216+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
217+
218+
let values: Vec<Option<T>> = (0..variant_array.len())
219+
.map(|i| {
220+
let path = paths[i].unwrap_or_default();
221+
variant_get_single_value(&variant_array, i, path, extract)
222+
})
223+
.collect::<Result<_>>()?;
224+
225+
ColumnarValue::Array(array_from_values(values))
226+
}
227+
(ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(paths)) => {
228+
let ScalarValue::Struct(variant_array) = scalar_variant else {
229+
return exec_err!("expected struct array");
230+
};
231+
232+
let variant_array = VariantArray::try_new(variant_array.as_ref())?;
233+
let paths = try_parse_string_columnar(paths)?;
234+
235+
let values: Vec<Option<T>> = paths
236+
.iter()
237+
.map(|path| {
238+
let path = path.unwrap_or_default();
239+
variant_get_single_value(&variant_array, 0, path, extract)
240+
})
241+
.collect::<Result<_>>()?;
242+
243+
ColumnarValue::Array(array_from_values(values))
244+
}
245+
};
246+
247+
Ok(out)
248+
}
249+
121250
/// This is similar to anyhow's ensure! macro
122251
/// If the `pred` fails, it will return a DataFusionError
123252
pub fn ensure(pred: bool, err_msg: &str) -> Result<()> {
@@ -139,6 +268,50 @@ pub fn build_variant_array_from_json(value: &serde_json::Value) -> VariantArray
139268
builder.build()
140269
}
141270

271+
#[cfg(test)]
272+
pub fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue {
273+
let mut builder = VariantArrayBuilder::new(1);
274+
builder.append_json(json.to_string().as_str()).unwrap();
275+
ScalarValue::Struct(Arc::new(builder.build().into()))
276+
}
277+
278+
#[cfg(test)]
279+
pub fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef {
280+
let mut builder = VariantArrayBuilder::new(json_rows.len());
281+
for value in json_rows {
282+
builder.append_json(value.to_string().as_str()).unwrap();
283+
}
284+
let variant_array: StructArray = builder.build().into();
285+
Arc::new(variant_array) as ArrayRef
286+
}
287+
288+
#[cfg(test)]
289+
pub fn standard_variant_get_arg_fields() -> Vec<Arc<Field>> {
290+
vec![
291+
Arc::new(
292+
Field::new("input", DataType::Struct(Fields::empty()), true)
293+
.with_extension_type(VariantType),
294+
),
295+
Arc::new(Field::new("path", DataType::Utf8, true)),
296+
]
297+
}
298+
299+
#[cfg(test)]
300+
pub fn build_variant_get_args(
301+
variant_input: ColumnarValue,
302+
path: ColumnarValue,
303+
return_data_type: DataType,
304+
arg_fields: Vec<Arc<Field>>,
305+
) -> ScalarFunctionArgs {
306+
ScalarFunctionArgs {
307+
args: vec![variant_input, path],
308+
return_field: Arc::new(Field::new("result", return_data_type, true)),
309+
arg_fields,
310+
number_rows: Default::default(),
311+
config_options: Default::default(),
312+
}
313+
}
314+
142315
#[cfg(test)]
143316
#[allow(unused)]
144317
pub fn build_variant_array_from_json_array(jsons: &[Option<serde_json::Value>]) -> VariantArray {

src/variant_get.rs

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -316,38 +316,17 @@ impl ScalarUDFImpl for VariantGetFieldUdf {
316316

317317
#[cfg(test)]
318318
mod tests {
319+
use super::*;
320+
use crate::shared::{
321+
standard_variant_get_arg_fields, variant_array_from_json_rows, variant_scalar_from_json,
322+
};
319323
use arrow::array::{Array, BinaryViewArray, Int64Array};
320-
use arrow_schema::{Field, Fields};
324+
use arrow_schema::Field;
321325
use datafusion::logical_expr::{ReturnFieldArgs, ScalarFunctionArgs};
322326
use parquet_variant::Variant;
323-
use parquet_variant_compute::{VariantArrayBuilder, VariantType};
324-
use parquet_variant_json::JsonToVariant;
325-
326-
use super::*;
327-
328-
fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue {
329-
let mut builder = VariantArrayBuilder::new(1);
330-
builder.append_json(json.to_string().as_str()).unwrap();
331-
ScalarValue::Struct(Arc::new(builder.build().into()))
332-
}
333-
334-
fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef {
335-
let mut builder = VariantArrayBuilder::new(json_rows.len());
336-
for value in json_rows {
337-
builder.append_json(value.to_string().as_str()).unwrap();
338-
}
339-
let variant_array: StructArray = builder.build().into();
340-
Arc::new(variant_array) as ArrayRef
341-
}
342327

343328
fn standard_arg_fields(with_type_hint: bool) -> Vec<FieldRef> {
344-
let mut fields = vec![
345-
Arc::new(
346-
Field::new("input", DataType::Struct(Fields::empty()), true)
347-
.with_extension_type(VariantType),
348-
),
349-
Arc::new(Field::new("path", DataType::Utf8, true)),
350-
];
329+
let mut fields = standard_variant_get_arg_fields();
351330
if with_type_hint {
352331
fields.push(Arc::new(Field::new("type", DataType::Utf8, true)));
353332
}

0 commit comments

Comments
 (0)