Skip to content

Commit a93b7e9

Browse files
committed
Feat(table_diff): Add option for case insensitive schema comparisons
1 parent ff87b03 commit a93b7e9

File tree

5 files changed

+190
-6
lines changed

5 files changed

+190
-6
lines changed

sqlmesh/cli/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,11 @@ def create_external_models(obj: Context, **kwargs: t.Any) -> None:
938938
multiple=True,
939939
help="Specify one or more models to data diff. Use wildcards to diff multiple models. Ex: '*' (all models with applied plan diffs), 'demo.model+' (this and downstream models), 'git:feature_branch' (models with direct modifications in this branch only)",
940940
)
941+
@click.option(
942+
"--schema-diff-ignore-case",
943+
is_flag=True,
944+
help="If set, when performing a schema diff the case of column names is ignored when matching between the two schemas. For example, 'col_a' in the source schema and 'COL_A' in the target schema will be treated as the same column.",
945+
)
941946
@click.pass_obj
942947
@error_handler
943948
@cli_analytics

sqlmesh/core/context.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,6 +1673,7 @@ def table_diff(
16731673
skip_grain_check: bool = False,
16741674
warn_grain_check: bool = False,
16751675
temp_schema: t.Optional[str] = None,
1676+
schema_diff_ignore_case: bool = False,
16761677
) -> t.List[TableDiff]:
16771678
"""Show a diff between two tables.
16781679
@@ -1796,6 +1797,7 @@ def table_diff(
17961797
show=show,
17971798
temp_schema=temp_schema,
17981799
skip_grain_check=skip_grain_check,
1800+
schema_diff_ignore_case=schema_diff_ignore_case,
17991801
),
18001802
tasks_num=tasks_num,
18011803
)
@@ -1821,6 +1823,7 @@ def table_diff(
18211823
on=on,
18221824
skip_columns=skip_columns,
18231825
where=where,
1826+
schema_diff_ignore_case=schema_diff_ignore_case,
18241827
)
18251828
]
18261829

@@ -1845,6 +1848,7 @@ def _model_diff(
18451848
show: bool = True,
18461849
temp_schema: t.Optional[str] = None,
18471850
skip_grain_check: bool = False,
1851+
schema_diff_ignore_case: bool = False,
18481852
) -> TableDiff:
18491853
self.console.start_table_diff_model_progress(model.name)
18501854

@@ -1860,6 +1864,7 @@ def _model_diff(
18601864
target=target,
18611865
source_alias=source_alias,
18621866
target_alias=target_alias,
1867+
schema_diff_ignore_case=schema_diff_ignore_case,
18631868
)
18641869

18651870
if show:
@@ -1883,6 +1888,7 @@ def _table_diff(
18831888
model: t.Optional[Model] = None,
18841889
skip_columns: t.Optional[t.List[str]] = None,
18851890
where: t.Optional[str | exp.Condition] = None,
1891+
schema_diff_ignore_case: bool = False,
18861892
) -> TableDiff:
18871893
if not on:
18881894
raise SQLMeshError(
@@ -1902,6 +1908,7 @@ def _table_diff(
19021908
decimals=decimals,
19031909
model_name=model.name if model else None,
19041910
model_dialect=model.dialect if model else None,
1911+
schema_diff_ignore_case=schema_diff_ignore_case,
19051912
)
19061913

19071914
@python_api_analytics

sqlmesh/core/table_diff.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,78 @@ class SchemaDiff(PydanticModel, frozen=True):
3636
source_alias: t.Optional[str] = None
3737
target_alias: t.Optional[str] = None
3838
model_name: t.Optional[str] = None
39+
ignore_case: bool = False
40+
41+
@property
42+
def _normalized_source_schema(self) -> t.Dict[str, exp.DataType]:
43+
return (
44+
self._lowercase_schema_names(self.source_schema)
45+
if self.ignore_case
46+
else self.source_schema
47+
)
48+
49+
@property
50+
def _normalized_target_schema(self) -> t.Dict[str, exp.DataType]:
51+
return (
52+
self._lowercase_schema_names(self.target_schema)
53+
if self.ignore_case
54+
else self.target_schema
55+
)
56+
57+
def _lowercase_schema_names(
58+
self, schema: t.Dict[str, exp.DataType]
59+
) -> t.Dict[str, exp.DataType]:
60+
return {c.lower(): t for c, t in schema.items()}
61+
62+
def _original_column_name(
63+
self, maybe_lowercased_column_name: str, schema: t.Dict[str, exp.DataType]
64+
) -> str:
65+
if not self.ignore_case:
66+
return maybe_lowercased_column_name
67+
68+
return next(c for c in schema if c.lower() == maybe_lowercased_column_name)
3969

4070
@property
4171
def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
4272
"""Added columns."""
43-
return [(c, t) for c, t in self.target_schema.items() if c not in self.source_schema]
73+
return [
74+
(self._original_column_name(c, self.target_schema), t)
75+
for c, t in self._normalized_target_schema.items()
76+
if c not in self._normalized_source_schema
77+
]
4478

4579
@property
4680
def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
4781
"""Removed columns."""
48-
return [(c, t) for c, t in self.source_schema.items() if c not in self.target_schema]
82+
return [
83+
(self._original_column_name(c, self.source_schema), t)
84+
for c, t in self._normalized_source_schema.items()
85+
if c not in self._normalized_target_schema
86+
]
4987

5088
@property
5189
def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
5290
"""Columns with modified types."""
5391
modified = {}
54-
for column in self.source_schema.keys() & self.target_schema.keys():
55-
source_type = self.source_schema[column]
56-
target_type = self.target_schema[column]
92+
for column in self._normalized_source_schema.keys() & self._normalized_target_schema.keys():
93+
source_type = self._normalized_source_schema[column]
94+
target_type = self._normalized_target_schema[column]
5795

5896
if source_type != target_type:
5997
modified[column] = (source_type, target_type)
98+
99+
if self.ignore_case:
100+
modified = {
101+
self._original_column_name(c, self.source_schema): dt for c, dt in modified.items()
102+
}
103+
60104
return modified
61105

106+
@property
107+
def has_changes(self) -> bool:
108+
"""Does the schema contain any changes at all between source and target"""
109+
return bool(self.added or self.removed or self.modified)
110+
62111

63112
class RowDiff(PydanticModel, frozen=True):
64113
"""Summary statistics and a sample dataframe."""
@@ -183,6 +232,7 @@ def __init__(
183232
model_name: t.Optional[str] = None,
184233
model_dialect: t.Optional[str] = None,
185234
decimals: int = 3,
235+
schema_diff_ignore_case: bool = False,
186236
):
187237
if not isinstance(adapter, RowDiffMixin):
188238
raise ValueError(f"Engine {adapter} doesnt support RowDiff")
@@ -198,6 +248,7 @@ def __init__(
198248
self.model_name = model_name
199249
self.model_dialect = model_dialect
200250
self.decimals = decimals
251+
self.schema_diff_ignore_case = schema_diff_ignore_case
201252

202253
# Support environment aliases for diff output improvement in certain cases
203254
self.source_alias = source_alias
@@ -282,6 +333,7 @@ def schema_diff(self) -> SchemaDiff:
282333
source_alias=self.source_alias,
283334
target_alias=self.target_alias,
284335
model_name=self.model_name,
336+
ignore_case=self.schema_diff_ignore_case,
285337
)
286338

287339
def row_diff(

tests/cli/test_cli.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import string
23
from contextlib import contextmanager
34
from os import getcwd, path, remove
45
from pathlib import Path
@@ -1759,3 +1760,34 @@ def test_ignore_warnings(runner: CliRunner, tmp_path: Path) -> None:
17591760
)
17601761
assert result.exit_code == 0
17611762
assert audit_warning not in result.output
1763+
1764+
1765+
def test_table_diff_schema_diff_ignore_case(runner: CliRunner, tmp_path: Path):
1766+
from sqlmesh.core.engine_adapter import DuckDBEngineAdapter
1767+
1768+
create_example_project(tmp_path)
1769+
1770+
ctx = Context(paths=tmp_path)
1771+
assert isinstance(ctx.engine_adapter, DuckDBEngineAdapter)
1772+
1773+
ctx.engine_adapter.execute('create table t1 (id int, "naME" varchar)')
1774+
ctx.engine_adapter.execute('create table t2 (id int, "name" varchar)')
1775+
1776+
# default behavior (case sensitive)
1777+
result = runner.invoke(
1778+
cli,
1779+
["--paths", str(tmp_path), "table_diff", "t1:t2", "-o", "id"],
1780+
)
1781+
assert result.exit_code == 0
1782+
stripped_output = "".join((x for x in result.output if x in string.printable))
1783+
assert "Added Columns:\n name (TEXT)" in stripped_output
1784+
assert "Removed Columns:\n naME (TEXT)" in stripped_output
1785+
1786+
# ignore case
1787+
result = runner.invoke(
1788+
cli,
1789+
["--paths", str(tmp_path), "table_diff", "t1:t2", "-o", "id", "--schema-diff-ignore-case"],
1790+
)
1791+
assert result.exit_code == 0
1792+
stripped_output = "".join((x for x in result.output if x in string.printable))
1793+
assert "Schema Diff Between 'T1' and 'T2':\n Schemas match" in stripped_output

tests/core/test_table_diff.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sqlmesh.core.context import Context
1212
from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig, DuckDBConnectionConfig
1313
from sqlmesh.core.model import SqlModel, load_sql_based_model
14-
from sqlmesh.core.table_diff import TableDiff
14+
from sqlmesh.core.table_diff import TableDiff, SchemaDiff
1515
import numpy as np # noqa: TID253
1616
from sqlmesh.utils.errors import SQLMeshError
1717

@@ -944,3 +944,91 @@ def test_data_diff_multiple_models_lacking_grain(sushi_context_fixed_date, capsy
944944
assert row_diff1.t_sample.shape == (0, 2)
945945
assert row_diff1.joined_sample.shape == (2, 3)
946946
assert row_diff1.sample.shape == (2, 4)
947+
948+
949+
def test_schema_diff_ignore_case():
950+
# no changes
951+
table_a = {"COL_A": exp.DataType.build("varchar"), "cOl_b": exp.DataType.build("int")}
952+
table_b = {"col_a": exp.DataType.build("varchar"), "COL_b": exp.DataType.build("int")}
953+
954+
diff = SchemaDiff(
955+
source="table_a",
956+
source_schema=table_a,
957+
target="table_b",
958+
target_schema=table_b,
959+
ignore_case=True,
960+
)
961+
962+
assert not diff.has_changes
963+
964+
# added in target
965+
table_a = {"COL_A": exp.DataType.build("varchar"), "cOl_b": exp.DataType.build("int")}
966+
table_b = {
967+
"col_a": exp.DataType.build("varchar"),
968+
"COL_b": exp.DataType.build("int"),
969+
"cOL__C": exp.DataType.build("date"),
970+
}
971+
972+
diff = SchemaDiff(
973+
source="table_a",
974+
source_schema=table_a,
975+
target="table_b",
976+
target_schema=table_b,
977+
ignore_case=True,
978+
)
979+
980+
assert diff.has_changes
981+
assert len(diff.added) == 1
982+
assert diff.added[0] == (
983+
"cOL__C",
984+
exp.DataType.build("date"),
985+
) # notice: case preserved on output
986+
assert not diff.removed
987+
assert not diff.modified
988+
989+
# removed from source
990+
table_a = {
991+
"cOL_fo0": exp.DataType.build("float"),
992+
"COL_A": exp.DataType.build("varchar"),
993+
"cOl_b": exp.DataType.build("int"),
994+
}
995+
table_b = {"col_a": exp.DataType.build("varchar"), "COL_b": exp.DataType.build("int")}
996+
997+
diff = SchemaDiff(
998+
source="table_a",
999+
source_schema=table_a,
1000+
target="table_b",
1001+
target_schema=table_b,
1002+
ignore_case=True,
1003+
)
1004+
1005+
assert diff.has_changes
1006+
assert not diff.added
1007+
assert len(diff.removed) == 1
1008+
assert diff.removed[0] == (
1009+
"cOL_fo0",
1010+
exp.DataType.build("float"),
1011+
) # notice: case preserved on output
1012+
assert not diff.modified
1013+
1014+
# column type change
1015+
table_a = {"CoL_A": exp.DataType.build("varchar"), "cOl_b": exp.DataType.build("int")}
1016+
table_b = {"col_a": exp.DataType.build("date"), "COL_b": exp.DataType.build("int")}
1017+
1018+
diff = SchemaDiff(
1019+
source="table_a",
1020+
source_schema=table_a,
1021+
target="table_b",
1022+
target_schema=table_b,
1023+
ignore_case=True,
1024+
)
1025+
1026+
assert diff.has_changes
1027+
assert not diff.added
1028+
assert not diff.removed
1029+
assert diff.modified == {
1030+
"CoL_A": (
1031+
exp.DataType.build("varchar"),
1032+
exp.DataType.build("date"),
1033+
) # notice: source casing used on output
1034+
}

0 commit comments

Comments
 (0)