Skip to content

Commit 1b24781

Browse files
[FEATURE] Support asymmetric quoted identifiers in dialect quoting (#11652)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f3bf7d9 commit 1b24781

File tree

3 files changed

+198
-26
lines changed

3 files changed

+198
-26
lines changed

great_expectations/execution_engine/sqlalchemy_dialect.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import Any, Final, List, Literal, Mapping, Union, overload
4+
from typing import Any, Final, List, Mapping, Union, overload
55

66
from great_expectations.compatibility.sqlalchemy import quoted_name
77
from great_expectations.compatibility.typing_extensions import override
@@ -63,30 +63,33 @@ def get_all_dialects(cls) -> List[GXSqlDialect]:
6363
return [dialect for dialect in cls if dialect != GXSqlDialect.OTHER]
6464

6565

66-
DIALECT_IDENTIFIER_QUOTE_STRINGS: Final[Mapping[GXSqlDialect, Literal['"', "`"]]] = {
67-
# TODO: add other dialects
68-
GXSqlDialect.DATABRICKS: "`",
69-
GXSqlDialect.MYSQL: "`",
70-
GXSqlDialect.POSTGRESQL: '"',
71-
GXSqlDialect.SNOWFLAKE: '"',
72-
GXSqlDialect.SQLITE: '"',
73-
GXSqlDialect.TRINO: "`",
66+
DIALECT_IDENTIFIER_QUOTE_STRINGS: Final[Mapping[GXSqlDialect, tuple[str, str]]] = {
67+
GXSqlDialect.DATABRICKS: ("`", "`"),
68+
GXSqlDialect.MSSQL: ("[", "]"),
69+
GXSqlDialect.MYSQL: ("`", "`"),
70+
GXSqlDialect.POSTGRESQL: ('"', '"'),
71+
GXSqlDialect.SNOWFLAKE: ('"', '"'),
72+
GXSqlDialect.SQLITE: ('"', '"'),
73+
GXSqlDialect.TRINO: ("`", "`"),
7474
}
7575

7676

7777
def quote_str(unquoted_identifier: str, dialect: GXSqlDialect) -> str:
7878
"""Quote a string using the specified dialect's quote character."""
79-
quote_char = DIALECT_IDENTIFIER_QUOTE_STRINGS[dialect]
80-
if unquoted_identifier.startswith(quote_char) or unquoted_identifier.endswith(quote_char):
79+
open_q, close_q = DIALECT_IDENTIFIER_QUOTE_STRINGS[dialect]
80+
if unquoted_identifier.startswith(open_q) or unquoted_identifier.endswith(close_q):
8181
raise ValueError( # noqa: TRY003 # FIXME CoP
82-
f"Identifier {unquoted_identifier} already uses quote character {quote_char}"
82+
f"Identifier {unquoted_identifier} already uses quote characters {open_q}{close_q}"
8383
)
84-
return f"{quote_char}{unquoted_identifier}{quote_char}"
84+
return f"{open_q}{unquoted_identifier}{close_q}"
8585

8686

8787
def _strip_quotes(s: str, dialect: GXSqlDialect) -> str:
88-
quote_str = DIALECT_IDENTIFIER_QUOTE_STRINGS[dialect]
89-
if s.startswith(quote_str) and s.endswith(quote_str):
88+
open_q, close_q = DIALECT_IDENTIFIER_QUOTE_STRINGS[dialect]
89+
if s.startswith(open_q) and s.endswith(close_q):
90+
return s[len(open_q) : -len(close_q)]
91+
# MSSQL also accepts double quotes when QUOTED_IDENTIFIER is ON
92+
if dialect == GXSqlDialect.MSSQL and s.startswith('"') and s.endswith('"'):
9093
return s[1:-1]
9194
return s
9295

tests/execution_engine/test_sqlalchemy_dialect.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import pytest
22

3-
from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
3+
from great_expectations.execution_engine.sqlalchemy_dialect import (
4+
GXSqlDialect,
5+
_strip_quotes,
6+
quote_str,
7+
wrap_identifier,
8+
)
49

510

611
@pytest.mark.unit
@@ -31,3 +36,93 @@ def test_get_all_dialect_names_no_other_dialects():
3136
@pytest.mark.unit
3237
def test_get_all_dialects_no_other_dialects():
3338
assert GXSqlDialect.OTHER not in GXSqlDialect.get_all_dialects()
39+
40+
41+
@pytest.mark.unit
42+
@pytest.mark.parametrize(
43+
"dialect,expected",
44+
[
45+
(GXSqlDialect.DATABRICKS, "`col`"),
46+
(GXSqlDialect.MSSQL, "[col]"),
47+
(GXSqlDialect.MYSQL, "`col`"),
48+
(GXSqlDialect.POSTGRESQL, '"col"'),
49+
(GXSqlDialect.SNOWFLAKE, '"col"'),
50+
(GXSqlDialect.SQLITE, '"col"'),
51+
(GXSqlDialect.TRINO, "`col`"),
52+
],
53+
)
54+
def test_quote_str(dialect, expected):
55+
assert quote_str("col", dialect) == expected
56+
57+
58+
@pytest.mark.unit
59+
@pytest.mark.parametrize(
60+
"dialect,quoted_input",
61+
[
62+
(GXSqlDialect.DATABRICKS, "`col`"),
63+
(GXSqlDialect.MSSQL, "[col]"),
64+
(GXSqlDialect.MYSQL, "`col`"),
65+
(GXSqlDialect.POSTGRESQL, '"col"'),
66+
(GXSqlDialect.SNOWFLAKE, '"col"'),
67+
(GXSqlDialect.SQLITE, '"col"'),
68+
(GXSqlDialect.TRINO, "`col`"),
69+
],
70+
)
71+
def test_quote_str_already_quoted_raises(dialect, quoted_input):
72+
with pytest.raises(ValueError, match="already uses quote characters"):
73+
quote_str(quoted_input, dialect)
74+
75+
76+
@pytest.mark.unit
77+
@pytest.mark.parametrize(
78+
"dialect,quoted_input",
79+
[
80+
(GXSqlDialect.DATABRICKS, "`col`"),
81+
(GXSqlDialect.MSSQL, "[col]"),
82+
(GXSqlDialect.MSSQL, '"col"'),
83+
(GXSqlDialect.MYSQL, "`col`"),
84+
(GXSqlDialect.POSTGRESQL, '"col"'),
85+
(GXSqlDialect.SNOWFLAKE, '"col"'),
86+
(GXSqlDialect.SQLITE, '"col"'),
87+
(GXSqlDialect.TRINO, "`col`"),
88+
],
89+
)
90+
def test_strip_quotes(dialect, quoted_input):
91+
assert _strip_quotes(quoted_input, dialect) == "col"
92+
93+
94+
@pytest.mark.unit
95+
@pytest.mark.parametrize(
96+
"dialect",
97+
[
98+
GXSqlDialect.DATABRICKS,
99+
GXSqlDialect.MSSQL,
100+
GXSqlDialect.MYSQL,
101+
GXSqlDialect.POSTGRESQL,
102+
GXSqlDialect.SNOWFLAKE,
103+
GXSqlDialect.SQLITE,
104+
GXSqlDialect.TRINO,
105+
],
106+
)
107+
def test_strip_quotes_unquoted_noop(dialect):
108+
assert _strip_quotes("col", dialect) == "col"
109+
110+
111+
@pytest.mark.unit
112+
@pytest.mark.parametrize(
113+
"dialect,quoted_input",
114+
[
115+
(GXSqlDialect.DATABRICKS, "`col`"),
116+
(GXSqlDialect.MSSQL, "[col]"),
117+
(GXSqlDialect.MSSQL, '"col"'),
118+
(GXSqlDialect.MYSQL, "`col`"),
119+
(GXSqlDialect.POSTGRESQL, '"col"'),
120+
(GXSqlDialect.SNOWFLAKE, '"col"'),
121+
(GXSqlDialect.SQLITE, '"col"'),
122+
(GXSqlDialect.TRINO, "`col`"),
123+
],
124+
)
125+
def test_wrap_identifier_strips_quotes(dialect, quoted_input):
126+
result = wrap_identifier(quoted_input, dialect=dialect)
127+
assert str(result) == "col"
128+
assert result.quote is True

tests/integration/fluent/test_sql_datasources.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
)
4646
from great_expectations.core.expectation_suite import ExpectationSuite
4747
from great_expectations.core.validation_definition import ValidationDefinition
48+
from great_expectations.datasource.fluent.sql_server_datasource import (
49+
SQLServerAuthConnectionDetails,
50+
)
4851
from great_expectations.execution_engine.sqlalchemy_dialect import (
4952
DIALECT_IDENTIFIER_QUOTE_STRINGS,
5053
GXSqlDialect,
@@ -64,6 +67,9 @@
6467
SQLDatasource,
6568
SqliteDatasource,
6669
)
70+
from great_expectations.datasource.fluent.sql_server_datasource import (
71+
SQLServerDatasource,
72+
)
6773
from great_expectations.execution_engine import SqlAlchemyExecutionEngine
6874

6975
TERMINAL_WIDTH: Final = shutil.get_terminal_size().columns
@@ -83,7 +89,7 @@
8389
# sqlite db files should be using fresh tmp_path on every test
8490
DO_NOT_DROP_TABLES: set[str] = {"sqlite"}
8591

86-
DatabaseType: TypeAlias = Literal["postgres", "sqlite", "trino"]
92+
DatabaseType: TypeAlias = Literal["postgres", "sqlite", "trino", "mssql"]
8793
TableNameCase: TypeAlias = Literal[
8894
"quoted_lower",
8995
"quoted_mixed",
@@ -122,6 +128,14 @@
122128
"quoted_mixed": f'"{TEST_TABLE_NAME.title()}"',
123129
"unquoted_mixed": TEST_TABLE_NAME.title(),
124130
},
131+
"mssql": {
132+
"unquoted_lower": TEST_TABLE_NAME.lower(),
133+
"quoted_lower": f"[{TEST_TABLE_NAME.lower()}]",
134+
"unquoted_upper": TEST_TABLE_NAME.upper(),
135+
"quoted_upper": f"[{TEST_TABLE_NAME.upper()}]",
136+
"quoted_mixed": f"[{TEST_TABLE_NAME.title()}]",
137+
"unquoted_mixed": TEST_TABLE_NAME.title(),
138+
},
125139
}
126140

127141
# column names
@@ -343,6 +357,24 @@ def __call__(
343357
) -> None: ...
344358

345359

360+
def _create_schema_ddl(schema: str, is_mssql: bool) -> str:
361+
if is_mssql:
362+
return (
363+
f"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema}')"
364+
f" EXEC('CREATE SCHEMA {schema}')"
365+
)
366+
return f"CREATE SCHEMA IF NOT EXISTS {schema}"
367+
368+
369+
def _create_table_ddl(qualified_table_name: str, table_columns: str, is_mssql: bool) -> str:
370+
if is_mssql:
371+
return (
372+
f"IF OBJECT_ID(N'{qualified_table_name}', N'U') IS NULL"
373+
f" CREATE TABLE {qualified_table_name}{table_columns}"
374+
)
375+
return f"CREATE TABLE IF NOT EXISTS {qualified_table_name}{table_columns}"
376+
377+
346378
@pytest.fixture(
347379
scope="class",
348380
)
@@ -382,6 +414,7 @@ def _table_factory(
382414
)
383415
dialect = GXSqlDialect(sa_engine.dialect.name)
384416
created_tables: list[dict[Literal["table_name", "schema"], str | None]] = []
417+
is_mssql = dialect == GXSqlDialect.MSSQL
385418

386419
with gx_engine.get_connection() as conn:
387420
quoted_upper_col: str = quote_str(QUOTED_UPPER_COL, dialect=dialect)
@@ -390,18 +423,18 @@ def _table_factory(
390423
quoted_mixed_case: str = quote_str(QUOTED_MIXED_CASE, dialect=dialect)
391424

392425
if schema:
393-
conn.execute(TextClause(f"CREATE SCHEMA IF NOT EXISTS {schema}"))
426+
conn.execute(TextClause(_create_schema_ddl(schema, is_mssql)))
394427
for name in table_names:
395428
qualified_table_name = f"{schema}.{name}" if schema else name
396-
# TODO: use dialect specific quotes
397-
create_tables: str = (
398-
f"CREATE TABLE IF NOT EXISTS {qualified_table_name}"
429+
table_columns: str = (
399430
" (id INTEGER, name VARCHAR(255),"
400431
f" {quoted_upper_col} VARCHAR(255), {quoted_lower_col} VARCHAR(255),"
401432
f" {UNQUOTED_UPPER_COL} VARCHAR(255), {UNQUOTED_LOWER_COL} VARCHAR(255),"
402433
f" {quoted_mixed_case} VARCHAR(255), {quoted_w_dots} VARCHAR(255))"
403434
)
404-
conn.execute(TextClause(create_tables))
435+
conn.execute(
436+
TextClause(_create_table_ddl(qualified_table_name, table_columns, is_mssql))
437+
)
405438
if data:
406439
insert_data = (
407440
f"INSERT INTO {qualified_table_name}"
@@ -470,6 +503,24 @@ def sqlite_ds(context: EphemeralDataContext, tmp_path: pathlib.Path) -> SqliteDa
470503
return ds
471504

472505

506+
@pytest.fixture
507+
def mssql_ds(context: EphemeralDataContext) -> SQLServerDatasource:
508+
ds = context.data_sources.add_sql_server(
509+
"mssql",
510+
connection_string=SQLServerAuthConnectionDetails(
511+
host="127.0.0.1",
512+
port=1433,
513+
database="test_ci",
514+
schema="dbo",
515+
username="sa",
516+
password="ReallyStrongPwd1234%^&*",
517+
driver="ODBC Driver 18 for SQL Server",
518+
encrypt="Optional",
519+
),
520+
)
521+
return ds
522+
523+
473524
@pytest.fixture(
474525
params=[
475526
param(
@@ -551,6 +602,24 @@ def test_sqlite(
551602

552603
sqlite_ds.add_table_asset(asset_name, table_name=table_name)
553604

605+
@pytest.mark.mssql
606+
def test_mssql(
607+
self,
608+
mssql_ds: SQLServerDatasource,
609+
asset_name: TableNameCase,
610+
table_factory: TableFactory,
611+
):
612+
table_name = TABLE_NAME_MAPPING["mssql"].get(asset_name)
613+
if not table_name:
614+
pytest.skip(f"no '{asset_name}' table_name for mssql")
615+
# create table
616+
table_factory(gx_engine=mssql_ds.get_execution_engine(), table_names={table_name})
617+
618+
table_names: list[str] = inspect(mssql_ds.get_engine()).get_table_names()
619+
print(f"mssql tables:\n{pf(table_names)}))")
620+
621+
mssql_ds.add_table_asset(asset_name, table_name=table_name)
622+
554623
@pytest.mark.filterwarnings( # snowflake `add_table_asset` raises warning on passing a schema
555624
"once::great_expectations.datasource.fluent.GxDatasourceWarning"
556625
)
@@ -560,6 +629,7 @@ def test_sqlite(
560629
param("trino", None, marks=[pytest.mark.trino]),
561630
param("postgres", None, marks=[pytest.mark.postgresql]),
562631
param("sqlite", None, marks=[pytest.mark.sqlite]),
632+
param("mssql", None, marks=[pytest.mark.mssql]),
563633
],
564634
)
565635
def test_checkpoint_run(
@@ -625,10 +695,10 @@ def _is_quote_char_dialect_mismatch(
625695
dialect: GXSqlDialect,
626696
column_name: str | quoted_name,
627697
) -> bool:
628-
quote_char = column_name[0] if column_name[0] in ("'", '"', "`") else None
698+
quote_char = column_name[0] if column_name[0] in ("'", '"', "`", "[") else None
629699
if quote_char:
630-
dialect_quote_char = DIALECT_IDENTIFIER_QUOTE_STRINGS[dialect]
631-
if quote_char != dialect_quote_char:
700+
expected = DIALECT_IDENTIFIER_QUOTE_STRINGS[dialect][0]
701+
if quote_char != expected:
632702
return True
633703
return False
634704

@@ -639,8 +709,12 @@ def _raw_query_check_column_exists(
639709
gx_execution_engine: SqlAlchemyExecutionEngine,
640710
) -> bool:
641711
"""Use a simple 'SELECT {column_name_param} from {qualified_table_name};' query to check if the column exists.'""" # noqa: E501 # FIXME CoP
642-
with gx_execution_engine.get_connection() as connection:
712+
dialect_name = gx_execution_engine.engine.dialect.name
713+
if dialect_name == "mssql":
714+
query = f"""SELECT TOP 1 {column_name_param} FROM {qualified_table_name};"""
715+
else:
643716
query = f"""SELECT {column_name_param} FROM {qualified_table_name} LIMIT 1;"""
717+
with gx_execution_engine.get_connection() as connection:
644718
print(f"query:\n {query}")
645719
# an exception will be raised if the column does not exist
646720
try:

0 commit comments

Comments
 (0)