Skip to content

Commit 31bdf47

Browse files
authored
[FEATURE] infer primary keys during column_types metric fetch (#11554)
1 parent 6b72405 commit 31bdf47

File tree

9 files changed

+298
-88
lines changed

9 files changed

+298
-88
lines changed

great_expectations/expectations/metrics/util.py

Lines changed: 105 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -376,87 +376,132 @@ def __getitem__(self, key: Any) -> Any:
376376
return item
377377

378378

379-
def get_sqlalchemy_column_metadata(
    execution_engine: SqlAlchemyExecutionEngine,
    table_selectable: sqlalchemy.Select,
    schema_name: Optional[str] = None,
) -> Sequence[Mapping[str, Any]] | None:
    """Introspect the columns of ``table_selectable`` and flag primary-key members.

    Column metadata comes from ``_get_columns_from_selectable``; when that yields
    nothing, ``column_reflection_fallback`` is used instead. Primary-key names are
    looked up through the inspector on a best-effort basis and handed, together
    with the columns, to ``_build_column_metadata_result``.

    Args:
        execution_engine: Supplies the SQLAlchemy engine, inspector and dialect.
        table_selectable: The table identifier or ``TextClause`` custom query to
            describe.
        schema_name: Schema forwarded to the inspector lookups.

    Returns:
        The sequence built by ``_build_column_metadata_result``, or ``None`` when
        an ``AttributeError`` escapes introspection.
    """
    try:
        engine = execution_engine.engine
        inspector = execution_engine.get_inspector()

        # Determine selectable type once
        is_text_clause = sqlalchemy.TextClause and isinstance(  # type: ignore[truthy-function]
            table_selectable, sqlalchemy.TextClause
        )
        is_quoted_name = sqlalchemy.quoted_name and isinstance(  # type: ignore[truthy-function]
            table_selectable, sqlalchemy.quoted_name
        )
        table_name = str(table_selectable)

        # Fetch primary key info (skip for custom queries/TextClause).
        # This is best effort: any failure leaves the set empty rather than
        # aborting the column lookup below.
        primary_key_columns: set[str] = set()
        if not is_text_clause:
            try:
                pk_constraint = inspector.get_pk_constraint(
                    table_name=table_name,
                    schema=schema_name,
                )
                # `or []` keeps a None "constrained_columns" value from raising TypeError
                primary_key_columns = set(pk_constraint.get("constrained_columns") or [])
            except (
                # KeyError is tolerated for the same reason the column lookup
                # tolerates it: reflection may not find a temporary table's schema.
                KeyError,
                sa.exc.NoSuchTableError,
                sa.exc.ProgrammingError,
                NotImplementedError,
                AttributeError,
            ) as e:
                logger.debug(f"Could not fetch primary key info for {table_name}: {e!r}")

        # Fetch column metadata
        columns = _get_columns_from_selectable(
            table_selectable,
            inspector,
            schema_name,
            is_text_clause,
            is_quoted_name,
        )

        # Use fallback for mssql/trino or when primary introspection fails
        if not columns:
            columns = column_reflection_fallback(
                selectable=table_selectable,
                dialect=engine.dialect,
                sqlalchemy_engine=engine,
            )

        # Build result: copy columns, add PK info, apply dialect-specific formatting
        return _build_column_metadata_result(columns, primary_key_columns, execution_engine)
    except AttributeError as e:
        logger.debug(f"Error while introspecting columns: {e!r}", exc_info=e)
        return None
458436

459437

438+
def _get_columns_from_selectable(
    table_selectable: sqlalchemy.Select,
    inspector: Any,
    schema_name: Optional[str],
    is_text_clause: bool,
    is_quoted_name: bool,
) -> Sequence[Dict[str, Any]]:
    """Look up column metadata for ``table_selectable``.

    An empty list means "nothing found here": it is returned when a text clause
    exposes no ``selected_columns`` and when the inspector lookup raises one of
    the handled exceptions. The caller is expected to run the reflection
    fallback in that case.
    """
    try:
        if not is_text_clause:
            # Table-like selectable: ask the inspector.
            if is_quoted_name:
                lookup_name = table_selectable
            else:
                logger.warning("unexpected table_selectable type")
                lookup_name = str(table_selectable)
            return inspector.get_columns(table_name=lookup_name, schema=schema_name)

        # Custom SQL query: read the columns off the clause itself (SQLAlchemy 1.4+).
        if hasattr(table_selectable, "selected_columns"):
            return [
                {"name": selected.name, "type": selected.type}
                for selected in table_selectable.selected_columns.values()
            ]

        # Pre-1.4 SQLAlchemy is no longer supported; let the caller fall back to reflection.
        logger.debug("TextClause without selected_columns; using reflection fallback")
        return []
    except (KeyError, AttributeError, sa.exc.NoSuchTableError, sa.exc.ProgrammingError) as exc:
        logger.debug(f"{type(exc).__name__} while introspecting columns", exc_info=exc)
        logger.info(f"While introspecting columns {exc!r}; attempting reflection fallback")
        # Empty result signals the caller to use column_reflection_fallback.
        return []
469+
470+
471+
def _build_column_metadata_result(
    columns: Sequence[Dict[str, Any]],
    primary_key_columns: set[str],
    execution_engine: SqlAlchemyExecutionEngine,
) -> Sequence[Mapping[str, Any]]:
    """Annotate each column with a ``primary_key`` flag and apply dialect formatting.

    Column names are compared to the primary-key names case-insensitively
    (via ``casefold``). For Databricks, PostgreSQL, Snowflake and Trino the
    column type is additionally compiled to a ``CaseInsensitiveString`` and each
    entry is wrapped in a ``CaseInsensitiveNameDict``.
    """
    folded_pk_names = {pk_name.casefold() for pk_name in primary_key_columns}

    # Work on copies so the (inspector-cached) input dicts are never mutated.
    annotated = []
    for source_column in columns:
        entry = source_column.copy()
        entry["primary_key"] = entry.get("name", "").casefold() in folded_pk_names
        annotated.append(entry)

    dialect_name = execution_engine.dialect.name
    needs_case_insensitive_wrapping = dialect_name in {
        GXSqlDialect.DATABRICKS,
        GXSqlDialect.POSTGRESQL,
        GXSqlDialect.SNOWFLAKE,
        GXSqlDialect.TRINO,
    }
    if not needs_case_insensitive_wrapping:
        return annotated

    for entry in annotated:
        # Entries without a usable type are left as they are.
        if not entry.get("type"):
            continue
        rendered_type = entry["type"].compile(dialect=execution_engine.dialect)
        entry["type"] = CaseInsensitiveString(str(rendered_type))
    return [CaseInsensitiveNameDict(entry) for entry in annotated]
503+
504+
460505
def column_reflection_fallback( # noqa: C901, PLR0912, PLR0915 # FIXME CoP
461506
selectable: sqlalchemy.Select,
462507
dialect: sqlalchemy.Dialect,

great_expectations/experimental/metric_repository/metric_retriever.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,16 @@ def _get_table_column_types(self, batch_request: BatchRequest) -> Metric:
316316
{
317317
"name": raw_column_type["name"],
318318
"type": str(raw_column_type["type"]),
319+
"primary_key": raw_column_type.get("primary_key", False),
319320
}
320321
)
321322
else:
322-
column_types_converted_to_str.append({"name": raw_column_type["name"]})
323+
column_types_converted_to_str.append(
324+
{
325+
"name": raw_column_type["name"],
326+
"primary_key": raw_column_type.get("primary_key", False),
327+
}
328+
)
323329

324330
return TableMetric[List[str]](
325331
batch_id=batch_id,

great_expectations/metrics/metric_results.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class Config:
7979

8080
name: str
8181
type: str
82+
primary_key: bool
8283

8384

8485
# Metric result whose value is a list of ColumnType entries (per its generic parameter).
class TableColumnTypesResult(MetricResult[list[ColumnType]]): ...

tests/expectations/metrics/conftest.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Iterable, Optional, Union
1+
from typing import Generator, Iterable, Optional, Union
22

33
import pytest
44

55
from great_expectations.compatibility.sqlalchemy import (
66
sqlalchemy as sa,
77
)
88
from great_expectations.core.metric_domain_types import MetricDomainTypes
9+
from great_expectations.data_context.util import file_relative_path
910
from great_expectations.execution_engine import SqlAlchemyExecutionEngine
1011
from great_expectations.execution_engine.sqlalchemy_batch_data import SqlAlchemyBatchData
1112

@@ -72,3 +73,25 @@ def mock_sqlalchemy_execution_engine():
7273
execution_engine = MockSqlAlchemyExecutionEngine()
7374
execution_engine._batch_manager = MockBatchManager()
7475
return execution_engine
76+
77+
78+
@pytest.fixture
def sql_data_connector_test_db_execution_engine() -> Generator[
    SqlAlchemyExecutionEngine, None, None
]:
    """Yield a sqlite-backed ExecutionEngine for the SQL data connector test database.

    ``close()`` is called on the execution engine during teardown, whether or
    not the test body raised, so connections are not left open after the test.
    """
    sqlite_path = file_relative_path(
        __file__,
        "../../test_sets/test_cases_for_sql_data_connector.db",
    )
    sqlite_engine: sa.engine.Engine = sa.create_engine(f"sqlite:///{sqlite_path}")
    execution_engine = SqlAlchemyExecutionEngine(engine=sqlite_engine)

    try:
        yield execution_engine
    finally:
        # Teardown runs even when the test fails.
        execution_engine.close()

tests/expectations/metrics/test_metrics_util.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,140 @@ def test__eq__(
632632
assert input_case_insensitive != other
633633

634634

635+
@pytest.mark.unit
def test_get_sqlalchemy_column_metadata_includes_primary_key_field(
    sql_data_connector_test_db_execution_engine,
):
    """Every column carries a primary_key flag; only ``id`` is flagged for a single-PK table."""
    from great_expectations.execution_engine.sqlalchemy_batch_data import SqlAlchemyBatchData
    from great_expectations.expectations.metrics.util import get_sqlalchemy_column_metadata

    table = "table_with_single_pk"
    execution_engine = sql_data_connector_test_db_execution_engine

    # Register the single-PK table as batch data before introspecting it.
    execution_engine.load_batch_data(
        "__test_single_pk",
        SqlAlchemyBatchData(execution_engine=execution_engine, table_name=table),
    )

    metadata = get_sqlalchemy_column_metadata(
        execution_engine=execution_engine,
        table_selectable=sqlalchemy.quoted_name(table, quote=False),
        schema_name=None,
    )

    assert metadata is not None
    assert len(metadata) == 3  # id, name, value

    # The flag must be present on every column.
    assert all("primary_key" in column for column in metadata)

    # Exactly 'id' is a primary key.
    flagged = [column["name"] for column in metadata if column["primary_key"]]
    assert flagged == ["id"]

    # The remaining columns are not.
    unflagged = {column["name"] for column in metadata if not column["primary_key"]}
    assert unflagged == {"name", "value"}
668+
669+
670+
@pytest.mark.unit
def test_get_sqlalchemy_column_metadata_composite_primary_key(
    sql_data_connector_test_db_execution_engine,
):
    """Both members of a composite primary key are flagged, and nothing else."""
    from great_expectations.execution_engine.sqlalchemy_batch_data import SqlAlchemyBatchData
    from great_expectations.expectations.metrics.util import get_sqlalchemy_column_metadata

    table = "table_with_composite_pk"
    execution_engine = sql_data_connector_test_db_execution_engine

    execution_engine.load_batch_data(
        "__test_composite_pk",
        SqlAlchemyBatchData(execution_engine=execution_engine, table_name=table),
    )

    metadata = get_sqlalchemy_column_metadata(
        execution_engine=execution_engine,
        table_selectable=sqlalchemy.quoted_name(table, quote=False),
        schema_name=None,
    )

    assert metadata is not None
    assert len(metadata) == 4  # user_id, order_id, product, quantity

    # The flag must be present on every column.
    assert all("primary_key" in column for column in metadata)

    # user_id and order_id together form the key.
    flagged = sorted(column["name"] for column in metadata if column["primary_key"])
    assert flagged == ["order_id", "user_id"]

    # The remaining columns are not part of it.
    unflagged = sorted(column["name"] for column in metadata if not column["primary_key"])
    assert unflagged == ["product", "quantity"]
702+
703+
704+
@pytest.mark.unit
def test_get_sqlalchemy_column_metadata_no_primary_key(
    sql_data_connector_test_db_execution_engine,
):
    """A table without a primary key still yields metadata, with every flag False."""
    from great_expectations.execution_engine.sqlalchemy_batch_data import SqlAlchemyBatchData
    from great_expectations.expectations.metrics.util import get_sqlalchemy_column_metadata

    table = "table_without_pk"
    execution_engine = sql_data_connector_test_db_execution_engine

    execution_engine.load_batch_data(
        "__test_no_pk",
        SqlAlchemyBatchData(execution_engine=execution_engine, table_name=table),
    )

    metadata = get_sqlalchemy_column_metadata(
        execution_engine=execution_engine,
        table_selectable=sqlalchemy.quoted_name(table, quote=False),
        schema_name=None,
    )

    assert metadata is not None
    assert len(metadata) == 2  # description, amount

    # The flag must be present on every column.
    assert all("primary_key" in column for column in metadata)

    # Nothing is flagged as a primary key.
    flagged = [column["name"] for column in metadata if column["primary_key"]]
    assert flagged == []

    # Equivalently, every flag is falsy.
    assert all(not column["primary_key"] for column in metadata)
735+
736+
737+
@pytest.mark.unit
def test_get_sqlalchemy_column_metadata_quoted_pk_column(
    sql_data_connector_test_db_execution_engine,
):
    """A primary key declared with a quoted (mixed-case) column name is still detected."""
    from great_expectations.execution_engine.sqlalchemy_batch_data import SqlAlchemyBatchData
    from great_expectations.expectations.metrics.util import get_sqlalchemy_column_metadata

    table = "table_with_quoted_pk"
    execution_engine = sql_data_connector_test_db_execution_engine

    execution_engine.load_batch_data(
        "__test_quoted_pk",
        SqlAlchemyBatchData(execution_engine=execution_engine, table_name=table),
    )

    metadata = get_sqlalchemy_column_metadata(
        execution_engine=execution_engine,
        table_selectable=sqlalchemy.quoted_name(table, quote=False),
        schema_name=None,
    )

    assert metadata is not None
    assert len(metadata) == 2  # UserId, UserName

    # The flag must be present on every column.
    assert all("primary_key" in column for column in metadata)

    # Exactly one column is flagged, and it is UserId (compared case-insensitively).
    flagged = [column["name"] for column in metadata if column["primary_key"]]
    assert [name.lower() for name in flagged] == ["userid"]
767+
768+
635769
@pytest.mark.unit
636770
@patch("great_expectations.expectations.metrics.util.sa")
637771
def test_get_dialect_like_pattern_expression_is_resilient_to_missing_dialects(mock_sqlalchemy):

0 commit comments

Comments
 (0)