Skip to content

Commit a0f5fb0

Browse files
[BUGFIX] Preserve quoting when serializing quoted table names (#11357)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c8c8d82 commit a0f5fb0

File tree

2 files changed

+93
-14
lines changed

2 files changed

+93
-14
lines changed

great_expectations/datasource/fluent/sql_datasource.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@
8282
)
8383

8484
if TYPE_CHECKING:
85-
from sqlalchemy.sql import quoted_name # noqa: TID251 # type-checking only
86-
8785
# We re-import sqlalchemy here to make type-checking and our compatability layer
8886
# play nice with one another
8987
from great_expectations.compatibility import sqlalchemy
@@ -622,7 +620,7 @@ def get_batch(self, batch_request: BatchRequest) -> Batch:
622620
else:
623621
sql_partitioner = None
624622

625-
batch_spec_kwargs: dict[str, str | dict | None]
623+
batch_spec_kwargs: Dict[str, str | dict | None]
626624
requests = self._fully_specified_batch_requests(batch_request)
627625
unsorted_metadata_dicts = [self._get_batch_metadata_from_batch_request(r) for r in requests]
628626

@@ -921,7 +919,7 @@ def _validate_batch_request(self, batch_request: BatchRequest) -> None:
921919
f"but actually has form:\n{pf(batch_request.dict())}\n"
922920
)
923921

924-
def _create_batch_spec_kwargs(self) -> dict[str, Any]:
922+
def _create_batch_spec_kwargs(self) -> Dict[str, Any]:
925923
"""Creates batch_spec_kwargs used to instantiate a SqlAlchemyDatasourceBatchSpec or RuntimeQueryBatchSpec
926924
927925
This is called by get_batch to generate the batch.
@@ -974,7 +972,7 @@ def as_selectable(self) -> sqlalchemy.Selectable:
974972
return sa.select(sa.text(self.query.lstrip()[6:])).subquery()
975973

976974
@override
977-
def _create_batch_spec_kwargs(self) -> dict[str, Any]:
975+
def _create_batch_spec_kwargs(self) -> Dict[str, Any]:
978976
return {
979977
"data_asset_name": self.name,
980978
"query": self.query,
@@ -1005,6 +1003,8 @@ class TableAsset(_SQLAsset):
10051003
)
10061004
schema_name: Optional[str] = None
10071005

1006+
_quote_character: Optional[str] = None
1007+
10081008
@property
10091009
def qualified_name(self) -> str:
10101010
return f"{self.schema_name}.{self.table_name}" if self.schema_name else self.table_name
@@ -1019,9 +1019,7 @@ def _default_table_name(cls, table_name: str, values: dict, **kwargs) -> str:
10191019
return validated_table_name
10201020

10211021
@pydantic.validator("table_name")
1022-
def _resolve_quoted_name(cls, table_name: str) -> str | quoted_name:
1023-
table_name_is_quoted: bool = cls._is_bracketed_by_quotes(table_name)
1024-
1022+
def _resolve_quoted_name(cls, table_name: str, values: Dict[str, Any]) -> str:
10251023
# We reimport sqlalchemy from our compatability layer because we make
10261024
# quoted_name a top level import there.
10271025
from great_expectations.compatibility import sqlalchemy
@@ -1030,19 +1028,34 @@ def _resolve_quoted_name(cls, table_name: str) -> str | quoted_name:
10301028
if isinstance(table_name, sqlalchemy.quoted_name):
10311029
return table_name
10321030

1033-
if table_name_is_quoted:
1031+
quote: bool = cls._is_bracketed_by_quotes(table_name)
1032+
1033+
if quote:
10341034
# https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.quoted_name.quote
10351035
# Remove the quotes and add them back using the sqlalchemy.quoted_name function
10361036
# TODO: We need to handle nested quotes
1037-
table_name = table_name.strip("'").strip('"')
1037+
values["_quote_character"] = table_name[0]
1038+
quote = True
1039+
table_name = table_name.strip("".join(DEFAULT_QUOTE_CHARACTERS))
10381040

10391041
return sqlalchemy.quoted_name(
10401042
value=table_name,
1041-
quote=table_name_is_quoted,
1043+
quote=quote,
10421044
)
10431045

10441046
return table_name
10451047

1048+
@override
1049+
def dict(self, **kwargs) -> Dict[str, Any]:
1050+
original_dict = super().dict(**kwargs)
1051+
1052+
# we need to ensure we retain the quotes when serializing quoted names
1053+
qc = self._quote_character
1054+
if qc is not None:
1055+
original_dict["table_name"] = f"{qc}{self.table_name}{qc}"
1056+
1057+
return original_dict
1058+
10461059
@override
10471060
def test_connection(self) -> None:
10481061
"""Test the connection for the TableAsset.
@@ -1081,7 +1094,7 @@ def as_selectable(self) -> sqlalchemy.Selectable:
10811094
return sa.table(self.table_name, schema=self.schema_name)
10821095

10831096
@override
1084-
def _create_batch_spec_kwargs(self) -> dict[str, Any]:
1097+
def _create_batch_spec_kwargs(self) -> Dict[str, Any]:
10851098
return {
10861099
"type": "table",
10871100
"data_asset_name": self.name,
@@ -1091,7 +1104,7 @@ def _create_batch_spec_kwargs(self) -> dict[str, Any]:
10911104
}
10921105

10931106
@override
1094-
def _create_batch_spec(self, batch_spec_kwargs: dict) -> SqlAlchemyDatasourceBatchSpec:
1107+
def _create_batch_spec(self, batch_spec_kwargs: Dict) -> SqlAlchemyDatasourceBatchSpec:
10951108
return SqlAlchemyDatasourceBatchSpec(**batch_spec_kwargs)
10961109

10971110
@staticmethod
@@ -1135,7 +1148,7 @@ def _warn_for_more_specific_datasource_type(connection_string: str) -> None:
11351148

11361149
connector: str = connection_string.split("://")[0].split("+")[0]
11371150

1138-
type_lookup_plus: dict[str, str] = {
1151+
type_lookup_plus: Dict[str, str] = {
11391152
n: DataSourceManager.type_lookup[n].__name__
11401153
for n in DataSourceManager.type_lookup.type_names()
11411154
}

tests/datasource/fluent/test_sql_datasources.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,23 @@ def test_unquoted_schema_names_are_added_as_lowercase(
379379
)
380380
assert table_asset.schema_name == schema_name.lower()
381381

382+
@pytest.mark.parametrize("table_name", ["my_table", "MY_TABLE", "My_Table"])
383+
def test_unquoted_table_names_are_unquoted(
384+
self,
385+
sql_datasource_table_asset_test_connection_noop: SQLDatasource,
386+
table_name: str,
387+
):
388+
my_datasource: SQLDatasource = sql_datasource_table_asset_test_connection_noop
389+
390+
table_asset = my_datasource.add_table_asset(
391+
name="my_table_asset",
392+
table_name=table_name,
393+
schema_name="my_schema",
394+
)
395+
assert isinstance(table_asset.table_name, sqlalchemy.quoted_name)
396+
assert table_asset.table_name == table_name
397+
assert not table_asset.table_name.quote
398+
382399
@pytest.mark.parametrize(
383400
"schema_name",
384401
[
@@ -404,6 +421,55 @@ def test_quoted_schema_names_are_not_modified(
404421
)
405422
assert table_asset.schema_name == schema_name
406423

424+
@pytest.mark.parametrize(
425+
"table_name",
426+
[
427+
'"my_table"',
428+
'"MY_TABLE"',
429+
'"My_Table"',
430+
"'my_table'",
431+
"'MY_TABLE'",
432+
"'My_Table'",
433+
],
434+
)
435+
def test_quoted_table_names_are_quoted(
436+
self,
437+
sql_datasource_table_asset_test_connection_noop: SQLDatasource,
438+
table_name: str,
439+
):
440+
my_datasource: SQLDatasource = sql_datasource_table_asset_test_connection_noop
441+
442+
table_asset = my_datasource.add_table_asset(
443+
name="my_table_asset",
444+
table_name=table_name,
445+
schema_name="my_schema",
446+
)
447+
assert isinstance(table_asset.table_name, sqlalchemy.quoted_name)
448+
assert table_asset.table_name == table_name.lstrip("\"'").rstrip("\"'")
449+
assert table_asset.table_name.quote
450+
451+
@pytest.mark.parametrize(
452+
"table_name",
453+
[
454+
'"my_table"',
455+
'"MY_TABLE"',
456+
'"My_Table"',
457+
"'my_table'",
458+
"'MY_TABLE'",
459+
"'My_Table'",
460+
"my_table",
461+
"MY_TABLE",
462+
"My_Table",
463+
],
464+
)
465+
def test_table_name_serialization_preserves_quotes(
466+
self,
467+
table_name: str,
468+
):
469+
table_asset = TableAsset(name="my_table_asset", table_name=table_name)
470+
serialized = table_asset.dict()
471+
assert serialized["table_name"] == table_name
472+
407473

408474
if __name__ == "__main__":
409475
pytest.main([__file__, "-vv"])

0 commit comments

Comments
 (0)