Skip to content

Commit 8fad18d

Browse files
fix cast
1 parent 30e0581 commit 8fad18d

3 files changed

Lines changed: 23 additions & 7 deletions

File tree

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ pedantic = { level = "deny", priority = -1 }
3535
[[bench]]
3636
name = "main"
3737
harness = false
38+
39+
[patch.crates-io]
40+
datafusion = { git = "https://github.com/pydantic/datafusion.git", branch = "pydantic-main" }

src/rewrite.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
3838
if scalar_func.func.name() != "json_get" {
3939
return None;
4040
}
41-
let func = match &cast.data_type {
41+
let func = match cast.field.data_type() {
4242
DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
4343
DataType::Float64 | DataType::Float32 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
4444
crate::json_get_float::json_get_float_udf()

tests/main.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,27 @@ async fn test_json_get_array_with_path() {
163163

164164
#[tokio::test]
165165
async fn test_json_get_equals() {
166-
let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test")
166+
// union comparison now works thanks to the union coercions upport in datafusion
167+
// (previously failed with "Cannot infer common argument type for comparison operation Union")
168+
// see https://github.com/apache/datafusion/issues/10180
169+
let batches = run_query(r"select name, json_get(json_data, 'foo')='abc' from test")
167170
.await
168-
.unwrap_err();
171+
.unwrap();
169172

170-
// see https://github.com/apache/datafusion/issues/10180
171-
assert!(e
172-
.to_string()
173-
.starts_with("Error during planning: Cannot infer common argument type for comparison operation Union"));
173+
let expected = [
174+
"+------------------+----------------------------------------------------+",
175+
r#"| name | json_get(test.json_data,Utf8("foo")) = Utf8("abc") |"#,
176+
"+------------------+----------------------------------------------------+",
177+
"| object_foo | true |",
178+
"| object_foo_array | |",
179+
"| object_foo_obj | |",
180+
"| object_foo_null | |",
181+
"| object_bar | |",
182+
"| list_foo | |",
183+
"| invalid_json | |",
184+
"+------------------+----------------------------------------------------+",
185+
];
186+
assert_batches_eq!(expected, &batches);
174187
}
175188

176189
#[tokio::test]

0 commit comments

Comments
 (0)