|
| 1 | +use std::sync::Arc; |
| 2 | + |
1 | 3 | use datafusion::arrow::datatypes::DataType;
|
2 | 4 | use datafusion::common::config::ConfigOptions;
|
3 | 5 | use datafusion::common::tree_node::Transformed;
|
| 6 | +use datafusion::common::Column; |
4 | 7 | use datafusion::common::DFSchema;
|
5 | 8 | use datafusion::common::Result;
|
6 | 9 | use datafusion::logical_expr::expr::{Alias, Cast, Expr, ScalarFunction};
|
7 | 10 | use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
|
8 | 11 | use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
|
9 | 12 | use datafusion::logical_expr::sqlparser::ast::BinaryOperator;
|
| 13 | +use datafusion::logical_expr::ScalarUDF; |
| 14 | +use datafusion::scalar::ScalarValue; |
10 | 15 |
|
11 | 16 | #[derive(Debug)]
|
12 | 17 | pub(crate) struct JsonFunctionRewriter;
|
@@ -93,27 +98,91 @@ fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
|
93 | 98 | }
|
94 | 99 | }
|
95 | 100 |
|
| 101 | +#[derive(Debug, Clone, Copy)] |
| 102 | +enum JsonOperator { |
| 103 | + Arrow, |
| 104 | + LongArrow, |
| 105 | + Question, |
| 106 | +} |
| 107 | + |
| 108 | +impl TryFrom<&BinaryOperator> for JsonOperator { |
| 109 | + type Error = (); |
| 110 | + |
| 111 | + fn try_from(op: &BinaryOperator) -> Result<Self, Self::Error> { |
| 112 | + match op { |
| 113 | + BinaryOperator::Arrow => Ok(JsonOperator::Arrow), |
| 114 | + BinaryOperator::LongArrow => Ok(JsonOperator::LongArrow), |
| 115 | + BinaryOperator::Question => Ok(JsonOperator::Question), |
| 116 | + _ => Err(()), |
| 117 | + } |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +impl From<JsonOperator> for Arc<ScalarUDF> { |
| 122 | + fn from(op: JsonOperator) -> Arc<ScalarUDF> { |
| 123 | + match op { |
| 124 | + JsonOperator::Arrow => crate::udfs::json_get_udf(), |
| 125 | + JsonOperator::LongArrow => crate::udfs::json_as_text_udf(), |
| 126 | + JsonOperator::Question => crate::udfs::json_contains_udf(), |
| 127 | + } |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +impl std::fmt::Display for JsonOperator { |
| 132 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 133 | + match self { |
| 134 | + JsonOperator::Arrow => write!(f, "->"), |
| 135 | + JsonOperator::LongArrow => write!(f, "->>"), |
| 136 | + JsonOperator::Question => write!(f, "?"), |
| 137 | + } |
| 138 | + } |
| 139 | +} |
| 140 | + |
| 141 | +/// Convert an Expr to a String representatiion for use in alias names. |
| 142 | +fn expr_to_sql_repr(expr: &Expr) -> String { |
| 143 | + match expr { |
| 144 | + Expr::Column(Column { name, relation }) => relation |
| 145 | + .as_ref() |
| 146 | + .map_or_else(|| name.clone(), |r| format!("{r}.{name}")), |
| 147 | + Expr::Alias(alias) => alias.name.clone(), |
| 148 | + Expr::Literal(scalar) => match scalar { |
| 149 | + ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { |
| 150 | + format!("'{v}'") |
| 151 | + } |
| 152 | + ScalarValue::UInt8(Some(v)) => v.to_string(), |
| 153 | + ScalarValue::UInt16(Some(v)) => v.to_string(), |
| 154 | + ScalarValue::UInt32(Some(v)) => v.to_string(), |
| 155 | + ScalarValue::UInt64(Some(v)) => v.to_string(), |
| 156 | + ScalarValue::Int8(Some(v)) => v.to_string(), |
| 157 | + ScalarValue::Int16(Some(v)) => v.to_string(), |
| 158 | + ScalarValue::Int32(Some(v)) => v.to_string(), |
| 159 | + ScalarValue::Int64(Some(v)) => v.to_string(), |
| 160 | + _ => scalar.to_string(), |
| 161 | + }, |
| 162 | + Expr::Cast(cast) => expr_to_sql_repr(&cast.expr), |
| 163 | + _ => expr.to_string(), |
| 164 | + } |
| 165 | +} |
| 166 | + |
96 | 167 | /// Implement a custom SQL planner to replace postgres JSON operators with custom UDFs
|
97 | 168 | #[derive(Debug, Default)]
|
98 | 169 | pub struct JsonExprPlanner;
|
99 | 170 |
|
100 | 171 | impl ExprPlanner for JsonExprPlanner {
|
101 | 172 | fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result<PlannerResult<RawBinaryExpr>> {
|
102 |
| - let (func, op_display) = match &expr.op { |
103 |
| - BinaryOperator::Arrow => (crate::json_get::json_get_udf(), "->"), |
104 |
| - BinaryOperator::LongArrow => (crate::json_as_text::json_as_text_udf(), "->>"), |
105 |
| - BinaryOperator::Question => (crate::json_contains::json_contains_udf(), "?"), |
106 |
| - _ => return Ok(PlannerResult::Original(expr)), |
107 |
| - }; |
108 |
| - let alias_name = match &expr.left { |
109 |
| - Expr::Alias(alias) => format!("{} {} {}", alias.name, op_display, expr.right), |
110 |
| - left_expr => format!("{} {} {}", left_expr, op_display, expr.right), |
| 173 | + let Ok(op) = JsonOperator::try_from(&expr.op) else { |
| 174 | + return Ok(PlannerResult::Original(expr)); |
111 | 175 | };
|
112 | 176 |
|
| 177 | + let left_repr = expr_to_sql_repr(&expr.left); |
| 178 | + let right_repr = expr_to_sql_repr(&expr.right); |
| 179 | + |
| 180 | + let alias_name = format!("{left_repr} {op} {right_repr}"); |
| 181 | + |
113 | 182 | // we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)`
|
114 | 183 | Ok(PlannerResult::Planned(Expr::Alias(Alias::new(
|
115 | 184 | Expr::ScalarFunction(ScalarFunction {
|
116 |
| - func, |
| 185 | + func: op.into(), |
117 | 186 | args: vec![expr.left, expr.right],
|
118 | 187 | }),
|
119 | 188 | None::<&str>,
|
|
0 commit comments