Skip to content

Commit 4ad1e9b

Browse files
authored
support rewriting of ::numeric cast (#33)
1 parent 4f84c26 commit 4ad1e9b

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/rewrite.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
3434
}
3535
let func = match &cast.data_type {
3636
DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
37-
DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(),
37+
DataType::Float64 | DataType::Float32 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
38+
crate::json_get_float::json_get_float_udf()
39+
}
3840
DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
3941
DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
4042
_ => return None,

tests/main.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,20 @@ async fn test_json_get_cast_float() {
277277
assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string()));
278278
}
279279

280+
#[tokio::test]
281+
async fn test_json_get_cast_numeric() {
282+
let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::numeric"#;
283+
let batches = run_query(sql).await.unwrap();
284+
assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string()));
285+
}
286+
287+
#[tokio::test]
288+
async fn test_json_get_cast_numeric_equals() {
289+
let sql = r#"select json_get('{"foo": 420}', 'foo')::numeric = 420"#;
290+
let batches = run_query(sql).await.unwrap();
291+
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
292+
}
293+
280294
#[tokio::test]
281295
async fn test_json_get_bool() {
282296
let batches = run_query("select json_get_bool('[true]', 0)").await.unwrap();
@@ -1101,7 +1115,7 @@ async fn test_arrow_scalar_union_is_null() {
11011115
}
11021116

11031117
#[tokio::test]
1104-
async fn test_arrow_cast() {
1118+
async fn test_long_arrow_cast() {
11051119
let batches = run_query("select (json_data->>'foo')::int from other").await.unwrap();
11061120

11071121
let expected = [
@@ -1116,3 +1130,9 @@ async fn test_arrow_cast() {
11161130
];
11171131
assert_batches_eq!(expected, &batches);
11181132
}
1133+
1134+
async fn test_arrow_cast_numeric() {
1135+
let sql = r#"select ('{"foo": 420}'->'foo')::numeric = 420"#;
1136+
let batches = run_query(sql).await.unwrap();
1137+
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
1138+
}

0 commit comments

Comments
 (0)