feat(duckdb): Add more Postgres operators #4199

Merged · 3 commits · Oct 3, 2024
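For reviewers, a quick sketch of what this change enables (illustrative, not part of the diff; the expected output is taken from the tests updated below):

import sqlglot

# DuckDB's Postgres-style @> operator now parses natively and
# transpiles to Postgres without the old ARRAY_HAS_ALL rewrite
print(sqlglot.transpile("SELECT [1, 2, 3] @> [1, 2]", read="duckdb", write="postgres")[0])
# SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]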
11 changes: 10 additions & 1 deletion sqlglot/dialects/duckdb.py
@@ -38,6 +38,7 @@
 )
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
+from sqlglot.parser import binary_range_parser
 
 DATETIME_DELTA = t.Union[
     exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.DateSub, exp.DatetimeSub
@@ -290,6 +291,9 @@ class Tokenizer(tokens.Tokenizer):
             **tokens.Tokenizer.KEYWORDS,
             "//": TokenType.DIV,
             "**": TokenType.DSTAR,
+            "^@": TokenType.CARET_AT,
+            "@>": TokenType.AT_GT,
+            "<@": TokenType.LT_AT,
             "ATTACH": TokenType.COMMAND,
             "BINARY": TokenType.VARBINARY,
             "BITSTRING": TokenType.BIT,
@@ -328,6 +332,12 @@ class Parser(parser.Parser):
         }
         BITWISE.pop(TokenType.CARET)
 
+        RANGE_PARSERS = {
+            **parser.Parser.RANGE_PARSERS,
+            TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
+            TokenType.CARET_AT: binary_range_parser(exp.StartsWith),
+        }
+
         EXPONENT = {
             **parser.Parser.EXPONENT,
             TokenType.CARET: exp.Pow,
@@ -488,7 +498,6 @@ class Generator(generator.Generator):
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: inline_array_unless_query,
-            exp.ArrayContainsAll: rename_func("ARRAY_HAS_ALL"),
             exp.ArrayFilter: rename_func("LIST_FILTER"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
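As an illustration of the new DuckDB operators (a sketch against sqlglot's public API; the expected output mirrors the test added in tests/dialects/test_duckdb.py):

import sqlglot

# ^@ tokenizes to the new CARET_AT token and parses to exp.StartsWith,
# which the DuckDB generator renders as a STARTS_WITH function call
print(sqlglot.transpile("SELECT a ^@ b", read="duckdb")[0])
# SELECT STARTS_WITH(a, b)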
6 changes: 0 additions & 6 deletions sqlglot/dialects/postgres.py
@@ -286,8 +286,6 @@ class Tokenizer(tokens.Tokenizer):
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "~~*": TokenType.ILIKE,
-            "~*": TokenType.IRLIKE,
             "~": TokenType.RLIKE,
             "@@": TokenType.DAT,
             "@>": TokenType.AT_GT,
@@ -385,12 +383,10 @@ class Parser(parser.Parser):
 
         RANGE_PARSERS = {
             **parser.Parser.RANGE_PARSERS,
-            TokenType.AT_GT: binary_range_parser(exp.ArrayContainsAll),
             TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
             TokenType.DAT: lambda self, this: self.expression(
                 exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
             ),
-            TokenType.LT_AT: binary_range_parser(exp.ArrayContainsAll, reverse_args=True),
             TokenType.OPERATOR: lambda self, this: self._parse_operator(this),
         }
 
@@ -488,8 +484,6 @@ class Generator(generator.Generator):
             **generator.Generator.TRANSFORMS,
             exp.AnyValue: any_value_to_max_sql,
             exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
-            exp.ArrayContainsAll: lambda self, e: self.binary(e, "@>"),
-            exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayFilter: filter_array_using_unnest,
             exp.ArraySize: lambda self, e: self.func("ARRAY_LENGTH", e.this, e.expression or "1"),
             exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
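The removals above are not behavior changes for Postgres: the @> / <@ / && handling moves into the base parser and generator (next two files), so Postgres SQL still round-trips unchanged. A minimal check, mirroring the updated tests:

from sqlglot import parse_one

sql = "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]"
# Postgres still parses && to exp.ArrayOverlaps and prints it back verbatim
assert parse_one(sql, read="postgres").sql(dialect="postgres") == sql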
2 changes: 2 additions & 0 deletions sqlglot/generator.py
@@ -114,6 +114,8 @@ class Generator(metaclass=_Generator):
         **JSON_PATH_PART_TRANSFORMS,
         exp.AllowedValuesProperty: lambda self,
         e: f"ALLOWED_VALUES {self.expressions(e, flat=True)}",
+        exp.ArrayContainsAll: lambda self, e: self.binary(e, "@>"),
+        exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
         exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
         exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}",
         exp.CaseSpecificColumnConstraint: lambda _,
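With these transforms in the base Generator, any dialect that doesn't override them now emits @> and && directly. A small sketch against the default ("") dialect:

from sqlglot import parse_one

# exp.ArrayContainsAll renders as the binary @> operator by default
print(parse_one("a @> b", read="postgres").sql())
# a @> b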
2 changes: 2 additions & 0 deletions sqlglot/parser.py
@@ -859,13 +859,15 @@ class Parser(metaclass=_Parser):
     }
 
     RANGE_PARSERS = {
+        TokenType.AT_GT: binary_range_parser(exp.ArrayContainsAll),
         TokenType.BETWEEN: lambda self, this: self._parse_between(this),
         TokenType.GLOB: binary_range_parser(exp.Glob),
        TokenType.ILIKE: binary_range_parser(exp.ILike),
         TokenType.IN: lambda self, this: self._parse_in(this),
         TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
         TokenType.IS: lambda self, this: self._parse_is(this),
         TokenType.LIKE: binary_range_parser(exp.Like),
+        TokenType.LT_AT: binary_range_parser(exp.ArrayContainsAll, reverse_args=True),
         TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
         TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
         TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
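Note the reverse_args=True on LT_AT: "a <@ b" is parsed as the flipped containment "b @> a", so both operators normalize to a single exp.ArrayContainsAll node. A sketch, matching the DuckDB test below:

from sqlglot import exp, parse_one

e = parse_one("a <@ b", read="duckdb")
assert isinstance(e, exp.ArrayContainsAll)
# the operands are swapped at parse time, so generation yields @>
assert e.sql(dialect="duckdb") == "b @> a"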
3 changes: 3 additions & 0 deletions sqlglot/tokens.py
@@ -60,6 +60,7 @@ class TokenType(AutoName):
     PIPE_SLASH = auto()
     DPIPE_SLASH = auto()
     CARET = auto()
+    CARET_AT = auto()
     TILDA = auto()
     ARROW = auto()
     DARROW = auto()
@@ -651,6 +652,8 @@ class Tokenizer(metaclass=_Tokenizer):
         "??": TokenType.DQMARK,
         "~~~": TokenType.GLOB,
         "~~": TokenType.LIKE,
+        "~~*": TokenType.ILIKE,
+        "~*": TokenType.IRLIKE,
         "ALL": TokenType.ALL,
         "ALWAYS": TokenType.ALWAYS,
         "AND": TokenType.AND,
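Moving ~~* and ~* into the base tokenizer means dialects like DuckDB pick up the Postgres-style ILIKE and case-insensitive-regex operators for free. A sketch, with the expected output taken from the new DuckDB test:

import sqlglot

# !~~* is the negated ~~* (ILIKE) operator
print(sqlglot.transpile("a !~~* b", read="duckdb")[0])
# NOT a ILIKE b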
12 changes: 12 additions & 0 deletions tests/dialects/test_duckdb.py
@@ -868,6 +868,18 @@ def test_duckdb(self):
         self.validate_identity("a ** b", "POWER(a, b)")
         self.validate_identity("a ~~~ b", "a GLOB b")
         self.validate_identity("a ~~ b", "a LIKE b")
+        self.validate_identity("a @> b")
+        self.validate_identity("a <@ b", "b @> a")
+        self.validate_identity("a && b").assert_is(exp.ArrayOverlaps)
+        self.validate_identity("a ^@ b", "STARTS_WITH(a, b)")
+        self.validate_identity(
+            "a !~~ b",
+            "NOT a LIKE b",
+        )
+        self.validate_identity(
+            "a !~~* b",
+            "NOT a ILIKE b",
+        )
 
     def test_array_index(self):
         with self.assertLogs(helper_logger) as cm:
12 changes: 3 additions & 9 deletions tests/dialects/test_postgres.py
@@ -354,10 +354,10 @@ def test_postgres(self):
         self.validate_all(
             "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
             read={
-                "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
+                "duckdb": "SELECT [1, 2, 3] @> [1, 2]",
             },
             write={
-                "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
+                "duckdb": "SELECT [1, 2, 3] @> [1, 2]",
                 "mysql": UnsupportedError,
                 "postgres": "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
             },
@@ -398,13 +398,6 @@ def test_postgres(self):
                 "postgres": "SELECT (data ->> 'en-US') AS acat FROM my_table",
             },
         )
-        self.validate_all(
-            "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]",
-            write={
-                "": "SELECT ARRAY_OVERLAPS(ARRAY(1, 2, 3), ARRAY(1, 2))",
-                "postgres": "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]",
-            },
-        )
         self.validate_all(
             "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t",
             read={
@@ -802,6 +795,7 @@ def test_postgres(self):
         )
         self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1)")
         self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1 FOR 1)")
+        self.validate_identity("ARRAY[1, 2, 3] && ARRAY[1, 2]").assert_is(exp.ArrayOverlaps)
 
     def test_ddl(self):
         # Checks that user-defined types are parsed into DataType instead of Identifier