Skip to content

Commit

Permalink
PR Feedback 1
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD committed Sep 26, 2024
1 parent c5b8514 commit 1caba37
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
19 changes: 12 additions & 7 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,13 @@ def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
return expression


def _annotate_by_same_args(
def _annotate_by_similar_args(
self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DataType.Type
) -> E:
"""
Infers the type of the expression if all the param @args are of that type,
otherwise defaults to param @target_type
Infers the type of the expression according to the following rules:
- If all args are UNKNOWN OR the known args are of mixed types, the type is set to target_type
- If all known args (UNKNOWN args are ignored) are of the same type, the type is set to that type
"""
self._annotate_args(expression)

Expand All @@ -129,10 +130,14 @@ def _annotate_by_same_args(
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

last_datatype = expressions[0].type if expressions else None
last_datatype = None

for expr in expressions:
if not expr.is_type(last_datatype):
if expr.is_type(exp.DataType.Type.UNKNOWN):
continue
elif not last_datatype:
last_datatype = expr.type
elif not expr.is_type(last_datatype):
last_datatype = None
break

Expand All @@ -144,10 +149,10 @@ class Spark2(Hive):
ANNOTATORS = {
**Hive.ANNOTATORS,
exp.Substring: lambda self, e: self._annotate_by_args(e, "this"),
exp.Concat: lambda self, e: _annotate_by_same_args(
exp.Concat: lambda self, e: _annotate_by_similar_args(
self, e, "expressions", target_type=exp.DataType.Type.TEXT
),
exp.Pad: lambda self, e: _annotate_by_same_args(
exp.Pad: lambda self, e: _annotate_by_similar_args(
self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT
),
}
Expand Down
36 changes: 23 additions & 13 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,28 +1381,38 @@ def gen_expr(depth: int) -> exp.Expression:
def test_spark_annotators(self):
"""Test Spark annotators, mainly built-in string/binary functions"""

schema = {"tbl": {"bin_col": "BINARY", "str_col": "STRING"}}
spark_schema = {"tbl": {"bin_col": "BINARY", "str_col": "STRING"}}

from sqlglot.dialects import Dialect

def _assert_func_return_type(func: str, dialect: str, target_type: str):
ast = parse_one(f"SELECT {func} FROM tbl", read=dialect)
optimized = optimizer.optimize(ast, schema=schema, dialect=dialect)
annotators = Dialect.get_or_raise(dialect).ANNOTATORS
annotated = annotate_types(ast, annotators=annotators, schema=spark_schema)

self.assertEqual(
optimized.expressions[0].type.sql(dialect),
annotated.expressions[0].type.sql(dialect),
exp.DataType.build(target_type).sql(dialect),
)

str_col, bin_col = "tbl.str_col", "tbl.bin_col"

# In Spark hierarchy, SUBSTRING result type is dependent on input expr type
for dialect in ("spark2", "spark", "databricks"):
_assert_func_return_type("SUBSTRING(str_col, 0, 0)", dialect, "STRING")
_assert_func_return_type("SUBSTRING(bin_col, 0, 0)", dialect, "BINARY")
_assert_func_return_type(f"SUBSTRING({str_col}, 0, 0)", dialect, "STRING")
_assert_func_return_type(f"SUBSTRING({bin_col}, 0, 0)", dialect, "BINARY")

_assert_func_return_type(f"CONCAT({bin_col}, {bin_col})", dialect, "BINARY")
_assert_func_return_type(f"CONCAT({bin_col}, {str_col})", dialect, "STRING")
_assert_func_return_type(f"CONCAT({str_col}, {bin_col})", dialect, "STRING")
_assert_func_return_type(f"CONCAT({str_col}, {str_col})", dialect, "STRING")

_assert_func_return_type("CONCAT(bin_col, bin_col)", dialect, "BINARY")
_assert_func_return_type("CONCAT(bin_col, str_col)", dialect, "STRING")
_assert_func_return_type("CONCAT(str_col, bin_col)", dialect, "STRING")
_assert_func_return_type("CONCAT(str_col, str_col)", dialect, "STRING")
_assert_func_return_type(f"CONCAT({str_col}, foo)", dialect, "STRING")
_assert_func_return_type(f"CONCAT({bin_col}, bar)", dialect, "BINARY")
_assert_func_return_type("CONCAT(foo, bar)", dialect, "STRING")

for func in ("LPAD", "RPAD"):
_assert_func_return_type(f"{func}(bin_col, 1, bin_col)", dialect, "BINARY")
_assert_func_return_type(f"{func}(bin_col, 1, str_col)", dialect, "STRING")
_assert_func_return_type(f"{func}(str_col, 1, bin_col)", dialect, "STRING")
_assert_func_return_type(f"{func}(str_col, 1, str_col)", dialect, "STRING")
_assert_func_return_type(f"{func}({bin_col}, 1, {bin_col})", dialect, "BINARY")
_assert_func_return_type(f"{func}({bin_col}, 1, {str_col})", dialect, "STRING")
_assert_func_return_type(f"{func}({str_col}, 1, {bin_col})", dialect, "STRING")
_assert_func_return_type(f"{func}({str_col}, 1, {str_col})", dialect, "STRING")

0 comments on commit 1caba37

Please sign in to comment.