From 89e6763776013abe037133ba5acfc86c22161950 Mon Sep 17 00:00:00 2001 From: daihuynh Date: Fri, 4 Oct 2024 15:41:12 +0930 Subject: [PATCH 1/2] feat(tsql): SPLIT_PART function and conversion to PARSENAME in tsql --- sqlglot/dialects/tsql.py | 31 ++++++++++++++++++++- sqlglot/expressions.py | 5 ++++ tests/dialects/test_tsql.py | 54 +++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 65002427c..b675386e6 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -324,6 +324,34 @@ def _parse(args: t.List[exp.Expression]) -> exp.Expression: return _parse +# https://learn.microsoft.com/en-us/sql/t-sql/functions/parsename-transact-sql?view=sql-server-ver16 +def _build_parsename(args: t.List) -> t.Optional[exp.SplitPart]: + if len(args) != 2: + return None + arg_this: exp.Literal = seq_get(args, 0) or exp.Literal.string("") + arg_partnum: exp.Literal = seq_get(args, 1) or exp.Literal.number(1) + text = arg_this.this + part_num = int(arg_partnum.this) + length = 1 if isinstance(arg_this, exp.Null) else len(text.split(".")) + 1 # Reverse index + idx = 0 if isinstance(arg_this, exp.Null) else int(part_num) + return exp.SplitPart( + this=arg_this, delimiter=exp.Literal.string("."), part_num=exp.Literal.number(length - idx) + ) + + +def _parsename_sql(self: TSQL.Generator, expression: exp.SplitPart) -> str: + delimiter: exp.Literal = expression.args.get("delimiter") or exp.Literal.string(".") + if delimiter.this != ".": + return str(expression) + arg_this: exp.Literal = expression.args.get("this") or exp.Literal.string("") + arg_part_num: exp.Literal = expression.args.get("part_num") or exp.Literal.number(1) + text = arg_this.this + part_num = int(arg_part_num.this) + length = 1 if isinstance(arg_this, exp.Null) else len(text.split(".")) + 1 # Reverse index + idx = 0 if isinstance(arg_this, exp.Null) else part_num + return self.func("PARSENAME", arg_this, exp.Literal.number(length - idx)) + + def _build_json_query(args: t.List, dialect: Dialect) -> exp.JSONExtract: if len(args) == 1: # The default value for path is '$'. As a result, if you don't provide a @@ -542,7 +570,7 @@ class Parser(parser.Parser): "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar), "LEN": _build_with_arg_as_text(exp.Length), "LEFT": _build_with_arg_as_text(exp.Left), - "RIGHT": _build_with_arg_as_text(exp.Right), + "PARSENAME": _build_parsename, "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, @@ -886,6 +914,7 @@ class Generator(generator.Generator): transforms.unnest_generate_date_array_using_recursive_cte, ] ), + exp.SplitPart: _parsename_sql, exp.Stddev: rename_func("STDEV"), exp.StrPosition: lambda self, e: self.func( "CHARINDEX", e.args.get("substr"), e.this, e.args.get("position") diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 59de679f3..9d5e6ef04 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6054,6 +6054,11 @@ class Split(Func): arg_types = {"this": True, "expression": True, "limit": False} +# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html +class SplitPart(Func): + arg_types = {"this": True, "delimiter": True, "part_num": True} + + # Start may be omitted in the case of postgres # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 class Substring(Func): diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index cec78894e..a9d76ef1a 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -2009,3 +2009,57 @@ def test_grant(self): self.validate_identity( "GRANT EXECUTE ON TestProc TO User2 AS TesterRole", check_command_warning=True ) + + def test_parsename(self): + # Test default case + self.validate_all( + "SELECT PARSENAME('1.2.3', 1)", + read={ + "spark": "SELECT SPLIT_PART('1.2.3', '.', 3)", + "databricks": "SELECT SPLIT_PART('1.2.3', '.', 3)", + }, + write={ + "spark": "SELECT SPLIT_PART('1.2.3', '.', 3)", + "databricks": "SELECT SPLIT_PART('1.2.3', '.', 3)", + "tsql": "SELECT PARSENAME('1.2.3', 1)", + }, + ) + # Test zero index + self.validate_all( + "SELECT PARSENAME('1.2.3', 0)", + read={ + "spark": "SELECT SPLIT_PART('1.2.3', '.', 4)", + "databricks": "SELECT SPLIT_PART('1.2.3', '.', 4)", + }, + write={ + "spark": "SELECT SPLIT_PART('1.2.3', '.', 4)", + "databricks": "SELECT SPLIT_PART('1.2.3', '.', 4)", + "tsql": "SELECT PARSENAME('1.2.3', 0)", + }, + ) + # Test null value + self.validate_all( + "SELECT PARSENAME(NULL, 1)", + read={ + "spark": "SELECT SPLIT_PART(NULL, '.', 1)", + "databricks": "SELECT SPLIT_PART(NULL, '.', 1)", + }, + write={ + "spark": "SELECT SPLIT_PART(NULL, '.', 1)", + "databricks": "SELECT SPLIT_PART(NULL, '.', 1)", + "tsql": "SELECT PARSENAME(NULL, 1)", + }, + ) + # Test non-dot delimiter + self.validate_all( + "SELECT SPLIT_PART('1.2.3', ',', 1)", + read={ + "spark": "SELECT SPLIT_PART('1.2.3', ',', 1)", + "databricks": "SELECT SPLIT_PART('1.2.3', ',', 1)", + }, + write={ + "spark": "SELECT SPLIT_PART('1.2.3', ',', 1)", + "databricks": "SELECT SPLIT_PART('1.2.3', ',', 1)", + "tsql": "SELECT SPLIT_PART('1.2.3', ',', 1)", + }, + ) From 39eb8ee522ca03b99515c0a7b64f66ca97bae40d Mon Sep 17 00:00:00 2001 From: daihuynh Date: Fri, 4 Oct 2024 16:12:57 +0930 Subject: [PATCH 2/2] Fix: restore RIGHT in FUNCTIONS --- sqlglot/dialects/tsql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b675386e6..a4ff42f6f 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -570,6 +570,7 @@ class Parser(parser.Parser): "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar), "LEN": _build_with_arg_as_text(exp.Length), "LEFT": _build_with_arg_as_text(exp.Left), + "RIGHT": _build_with_arg_as_text(exp.Right), "PARSENAME": _build_parsename, "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),