Skip to content

Commit 0c21fec

Browse files
committed
PR Feedback 1
1 parent c5b8514 commit 0c21fec

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

sqlglot/dialects/spark2.py

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -115,12 +115,13 @@ def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
115115
return expression
116116

117117

118-
def _annotate_by_same_args(
118+
def _annotate_by_similar_args(
119119
self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DataType.Type
120120
) -> E:
121121
"""
122-
Infers the type of the expression if all the param @args are of that type,
123-
otherwise defaults to param @target_type
122+
Infers the type of the expression according to the following rules:
123+
- If all args are of the same type, the expr is inferred as that type; if any arg is of target_type, the expr is inferred as target_type
124+
- If any arg is of UNKNOWN type and none is of target_type, the expr is inferred as UNKNOWN
124125
"""
125126
self._annotate_args(expression)
126127

@@ -129,25 +130,31 @@ def _annotate_by_same_args(
129130
arg_expr = expression.args.get(arg)
130131
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
131132

132-
last_datatype = expressions[0].type if expressions else None
133+
last_datatype = None
133134

135+
has_unknown = False
134136
for expr in expressions:
135-
if not expr.is_type(last_datatype):
136-
last_datatype = None
137+
if expr.is_type(exp.DataType.Type.UNKNOWN):
138+
has_unknown = True
139+
elif expr.is_type(target_type):
140+
has_unknown = False
141+
last_datatype = target_type
137142
break
143+
else:
144+
last_datatype = expr.type
138145

139-
self._set_type(expression, last_datatype or target_type)
146+
self._set_type(expression, exp.DataType.Type.UNKNOWN if has_unknown else last_datatype)
140147
return expression
141148

142149

143150
class Spark2(Hive):
144151
ANNOTATORS = {
145152
**Hive.ANNOTATORS,
146153
exp.Substring: lambda self, e: self._annotate_by_args(e, "this"),
147-
exp.Concat: lambda self, e: _annotate_by_same_args(
154+
exp.Concat: lambda self, e: _annotate_by_similar_args(
148155
self, e, "expressions", target_type=exp.DataType.Type.TEXT
149156
),
150-
exp.Pad: lambda self, e: _annotate_by_same_args(
157+
exp.Pad: lambda self, e: _annotate_by_similar_args(
151158
self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT
152159
),
153160
}

tests/test_optimizer.py

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1381,28 +1381,38 @@ def gen_expr(depth: int) -> exp.Expression:
13811381
def test_spark_annotators(self):
13821382
"""Test Spark annotators, mainly built-in string/binary functions"""
13831383

1384-
schema = {"tbl": {"bin_col": "BINARY", "str_col": "STRING"}}
1384+
spark_schema = {"tbl": {"bin_col": "BINARY", "str_col": "STRING"}}
1385+
1386+
from sqlglot.dialects import Dialect
13851387

13861388
def _assert_func_return_type(func: str, dialect: str, target_type: str):
13871389
ast = parse_one(f"SELECT {func} FROM tbl", read=dialect)
1388-
optimized = optimizer.optimize(ast, schema=schema, dialect=dialect)
1390+
annotators = Dialect.get_or_raise(dialect).ANNOTATORS
1391+
annotated = annotate_types(ast, annotators=annotators, schema=spark_schema)
1392+
13891393
self.assertEqual(
1390-
optimized.expressions[0].type.sql(dialect),
1394+
annotated.expressions[0].type.sql(dialect),
13911395
exp.DataType.build(target_type).sql(dialect),
13921396
)
13931397

1398+
str_col, bin_col = "tbl.str_col", "tbl.bin_col"
1399+
13941400
# In Spark hierarchy, SUBSTRING result type is dependent on input expr type
13951401
for dialect in ("spark2", "spark", "databricks"):
1396-
_assert_func_return_type("SUBSTRING(str_col, 0, 0)", dialect, "STRING")
1397-
_assert_func_return_type("SUBSTRING(bin_col, 0, 0)", dialect, "BINARY")
1402+
_assert_func_return_type(f"SUBSTRING({str_col}, 0, 0)", dialect, "STRING")
1403+
_assert_func_return_type(f"SUBSTRING({bin_col}, 0, 0)", dialect, "BINARY")
1404+
1405+
_assert_func_return_type(f"CONCAT({bin_col}, {bin_col})", dialect, "BINARY")
1406+
_assert_func_return_type(f"CONCAT({bin_col}, {str_col})", dialect, "STRING")
1407+
_assert_func_return_type(f"CONCAT({str_col}, {bin_col})", dialect, "STRING")
1408+
_assert_func_return_type(f"CONCAT({str_col}, {str_col})", dialect, "STRING")
13981409

1399-
_assert_func_return_type("CONCAT(bin_col, bin_col)", dialect, "BINARY")
1400-
_assert_func_return_type("CONCAT(bin_col, str_col)", dialect, "STRING")
1401-
_assert_func_return_type("CONCAT(str_col, bin_col)", dialect, "STRING")
1402-
_assert_func_return_type("CONCAT(str_col, str_col)", dialect, "STRING")
1410+
_assert_func_return_type(f"CONCAT({str_col}, foo)", dialect, "STRING")
1411+
_assert_func_return_type(f"CONCAT({bin_col}, bar)", dialect, "UNKNOWN")
1412+
_assert_func_return_type("CONCAT(foo, bar)", dialect, "UNKNOWN")
14031413

14041414
for func in ("LPAD", "RPAD"):
1405-
_assert_func_return_type(f"{func}(bin_col, 1, bin_col)", dialect, "BINARY")
1406-
_assert_func_return_type(f"{func}(bin_col, 1, str_col)", dialect, "STRING")
1407-
_assert_func_return_type(f"{func}(str_col, 1, bin_col)", dialect, "STRING")
1408-
_assert_func_return_type(f"{func}(str_col, 1, str_col)", dialect, "STRING")
1415+
_assert_func_return_type(f"{func}({bin_col}, 1, {bin_col})", dialect, "BINARY")
1416+
_assert_func_return_type(f"{func}({bin_col}, 1, {str_col})", dialect, "STRING")
1417+
_assert_func_return_type(f"{func}({str_col}, 1, {bin_col})", dialect, "STRING")
1418+
_assert_func_return_type(f"{func}({str_col}, 1, {str_col})", dialect, "STRING")

0 commit comments

Comments
 (0)