Skip to content

Commit 89c0703

Browse files
authored
feat(bigquery): Native math function annotations (#4201)
* feat(bigquery): Custom type annotations * PR Feedback 1
1 parent 332c74b commit 89c0703

File tree

4 files changed

+44
-5
lines changed

4 files changed

+44
-5
lines changed

sqlglot/dialects/bigquery.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,11 @@ class BigQuery(Dialect):
293293
# All set operations require either a DISTINCT or ALL specifier
294294
SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys((exp.Except, exp.Intersect, exp.Union), None)
295295

296+
ANNOTATORS = {
297+
**Dialect.ANNOTATORS,
298+
exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
299+
}
300+
296301
def normalize_identifier(self, expression: E) -> E:
297302
if (
298303
isinstance(expression, exp.Identifier)

sqlglot/dialects/dialect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,6 @@ class Dialect(metaclass=_Dialect):
577577
exp.DataType.Type.DOUBLE: {
578578
exp.ApproxQuantile,
579579
exp.Avg,
580-
exp.Div,
581580
exp.Exp,
582581
exp.Ln,
583582
exp.Log,
@@ -689,9 +688,10 @@ class Dialect(metaclass=_Dialect):
689688
exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
690689
e, exp.DataType.build("ARRAY<TIMESTAMP>")
691690
),
691+
exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
692692
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
693693
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
694-
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
694+
exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
695695
exp.Literal: lambda self, e: self._annotate_literal(e),
696696
exp.Map: lambda self, e: self._annotate_map(e),
697697
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
--------------------------------------
2-
-- Spark2 / Spark3 / Databricks functions
2+
-- Dialect
3+
--------------------------------------
4+
ABS(1);
5+
INT;
6+
7+
ABS(1.5);
8+
DOUBLE;
9+
10+
GREATEST(1, 2, 3);
11+
INT;
12+
13+
GREATEST(1, 2.5, 3);
14+
DOUBLE;
15+
16+
LEAST(1, 2, 3);
17+
INT;
18+
19+
LEAST(1, 2.5, 3);
20+
DOUBLE;
21+
22+
--------------------------------------
23+
-- Spark2 / Spark3 / Databricks
324
--------------------------------------
425

526
# dialect: spark2, spark, databricks
@@ -69,3 +90,16 @@ STRING;
6990
# dialect: spark2, spark, databricks
7091
RPAD(tbl.str_col, 1, tbl.str_col);
7192
STRING;
93+
94+
95+
--------------------------------------
96+
-- BigQuery
97+
--------------------------------------
98+
99+
# dialect: bigquery
100+
SIGN(1);
101+
INT;
102+
103+
# dialect: bigquery
104+
SIGN(1.5);
105+
DOUBLE;

tests/test_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -806,10 +806,10 @@ def test_annotate_funcs(self):
806806
load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1
807807
):
808808
title = meta.get("title") or f"{i}, {sql}"
809-
dialects = meta.get("dialect").split(", ")
809+
dialect = meta.get("dialect") or ""
810810
sql = f"SELECT {sql} FROM tbl"
811811

812-
for dialect in dialects:
812+
for dialect in dialect.split(", "):
813813
result = parse_and_optimize(
814814
annotate_functions, sql, dialect, schema=test_schema, dialect=dialect
815815
)

0 commit comments

Comments
 (0)