From 89c07039da402fb2ad77e00edb4f09079ecbb41d Mon Sep 17 00:00:00 2001 From: Vaggelis Danias Date: Thu, 3 Oct 2024 18:40:19 +0300 Subject: [PATCH] feat(bigquery): Native math function annotations (#4201) * feat(bigquery): Custom type annotations * PR Feedback 1 --- sqlglot/dialects/bigquery.py | 5 +++ sqlglot/dialects/dialect.py | 4 +-- .../fixtures/optimizer/annotate_functions.sql | 36 ++++++++++++++++++- tests/test_optimizer.py | 4 +-- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5f8933ef8..fa685542f 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -293,6 +293,11 @@ class BigQuery(Dialect): # All set operations require either a DISTINCT or ALL specifier SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys((exp.Except, exp.Intersect, exp.Union), None) + ANNOTATORS = { + **Dialect.ANNOTATORS, + exp.Sign: lambda self, e: self._annotate_by_args(e, "this"), + } + def normalize_identifier(self, expression: E) -> E: if ( isinstance(expression, exp.Identifier) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index f32f6945d..2e26cb46d 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -577,7 +577,6 @@ class Dialect(metaclass=_Dialect): exp.DataType.Type.DOUBLE: { exp.ApproxQuantile, exp.Avg, - exp.Div, exp.Exp, exp.Ln, exp.Log, @@ -689,9 +688,10 @@ class Dialect(metaclass=_Dialect): exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( e, exp.DataType.build("ARRAY") ), + exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), - exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Literal: lambda self, e: self._annotate_literal(e), exp.Map: lambda self, e: self._annotate_map(e), exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 0ab1b5cb8..4fd3a6ec9 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -1,5 +1,26 @@ -------------------------------------- --- Spark2 / Spark3 / Databricks functions +-- Dialect +-------------------------------------- +ABS(1); +INT; + +ABS(1.5); +DOUBLE; + +GREATEST(1, 2, 3); +INT; + +GREATEST(1, 2.5, 3); +DOUBLE; + +LEAST(1, 2, 3); +INT; + +LEAST(1, 2.5, 3); +DOUBLE; + +-------------------------------------- +-- Spark2 / Spark3 / Databricks -------------------------------------- # dialect: spark2, spark, databricks @@ -69,3 +90,16 @@ STRING; # dialect: spark2, spark, databricks RPAD(tbl.str_col, 1, tbl.str_col); STRING; + + +-------------------------------------- +-- BigQuery +-------------------------------------- + +# dialect: bigquery +SIGN(1); +INT; + +# dialect: bigquery +SIGN(1.5); +DOUBLE; diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 1bbe86a43..d76a81caf 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -806,10 +806,10 @@ def test_annotate_funcs(self): load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1 ): title = meta.get("title") or f"{i}, {sql}" - dialects = meta.get("dialect").split(", ") + dialect = meta.get("dialect") or "" sql = f"SELECT {sql} FROM tbl" - for dialect in dialects: + for dialect in dialect.split(", "): result = parse_and_optimize( annotate_functions, sql, dialect, schema=test_schema, dialect=dialect )