Skip to content

Commit

Permalink
feat(bigquery): Native math function annotations (#4201)
Browse files (browse the repository at this point in the history)
* feat(bigquery): Custom type annotations

* PR Feedback 1
  • Loading branch information
VaggelisD authored Oct 3, 2024
1 parent 332c74b commit 89c0703
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
5 changes: 5 additions & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,11 @@ class BigQuery(Dialect):
# All set operations require either a DISTINCT or ALL specifier
SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys((exp.Except, exp.Intersect, exp.Union), None)

# Type annotators for this dialect: start from the base Dialect mapping and
# add/override entries for expression types that need dialect-specific handling.
ANNOTATORS = {
**Dialect.ANNOTATORS,
# Sign is annotated from its "this" argument via _annotate_by_args
# (presumably so the result type follows the input type -- confirm against
# BigQuery's SIGN documentation).
exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
}

def normalize_identifier(self, expression: E) -> E:
if (
isinstance(expression, exp.Identifier)
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,6 @@ class Dialect(metaclass=_Dialect):
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
exp.Div,
exp.Exp,
exp.Ln,
exp.Log,
Expand Down Expand Up @@ -689,9 +688,10 @@ class Dialect(metaclass=_Dialect):
exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
e, exp.DataType.build("ARRAY<TIMESTAMP>")
),
exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
Expand Down
36 changes: 35 additions & 1 deletion tests/fixtures/optimizer/annotate_functions.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
--------------------------------------
-- Spark2 / Spark3 / Databricks functions
-- Dialect
--------------------------------------
ABS(1);
INT;

ABS(1.5);
DOUBLE;

GREATEST(1, 2, 3);
INT;

GREATEST(1, 2.5, 3);
DOUBLE;

LEAST(1, 2, 3);
INT;

LEAST(1, 2.5, 3);
DOUBLE;

--------------------------------------
-- Spark2 / Spark3 / Databricks
--------------------------------------

# dialect: spark2, spark, databricks
Expand Down Expand Up @@ -69,3 +90,16 @@ STRING;
# dialect: spark2, spark, databricks
RPAD(tbl.str_col, 1, tbl.str_col);
STRING;


--------------------------------------
-- BigQuery
--------------------------------------

# dialect: bigquery
SIGN(1);
INT;

# dialect: bigquery
SIGN(1.5);
DOUBLE;
4 changes: 2 additions & 2 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,10 @@ def test_annotate_funcs(self):
load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1
):
title = meta.get("title") or f"{i}, {sql}"
dialects = meta.get("dialect").split(", ")
dialect = meta.get("dialect") or ""
sql = f"SELECT {sql} FROM tbl"

for dialect in dialects:
for dialect in dialect.split(", "):
result = parse_and_optimize(
annotate_functions, sql, dialect, schema=test_schema, dialect=dialect
)
Expand Down

0 comments on commit 89c0703

Please sign in to comment.