Skip to content

Commit 0a5444d

Browse files
authored
Feat: expose a flag to automatically exclude Keep diff nodes (#4168)
1 parent 93cef30 commit 0a5444d

File tree

2 files changed

+49
-31
lines changed

2 files changed

+49
-31
lines changed

sqlglot/diff.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@ def diff(
6262
source: exp.Expression,
6363
target: exp.Expression,
6464
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
65+
delta_only: bool = False,
6566
**kwargs: t.Any,
6667
) -> t.List[Edit]:
6768
"""
@@ -89,6 +90,8 @@ def diff(
8990
heuristics produce better results for subtrees that are known by a caller to be matching.
9091
Note: expression references in this list must refer to the same node objects that are
9192
referenced in source / target trees.
93+
delta_only: excludes all `Keep` nodes from the diff.
94+
kwargs: additional arguments to pass to the ChangeDistiller instance.
9295
9396
Returns:
9497
the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
@@ -116,7 +119,12 @@ def compute_node_mappings(
116119
}
117120
matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]
118121

119-
return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)
122+
return ChangeDistiller(**kwargs).diff(
123+
source_copy,
124+
target_copy,
125+
matchings=matchings_copy,
126+
delta_only=delta_only,
127+
)
120128

121129

122130
# The expression types for which Update edits are allowed.
@@ -149,6 +157,7 @@ def diff(
149157
source: exp.Expression,
150158
target: exp.Expression,
151159
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
160+
delta_only: bool = False,
152161
) -> t.List[Edit]:
153162
matchings = matchings or []
154163
pre_matched_nodes = {id(s): id(t) for s, t in matchings}
@@ -168,9 +177,13 @@ def diff(
168177
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
169178

170179
matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
171-
return self._generate_edit_script(matching_set)
180+
return self._generate_edit_script(matching_set, delta_only)
172181

173-
def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
182+
def _generate_edit_script(
183+
self,
184+
matching_set: t.Set[t.Tuple[int, int]],
185+
delta_only: bool,
186+
) -> t.List[Edit]:
174187
edit_script: t.List[Edit] = []
175188
for removed_node_id in self._unmatched_source_nodes:
176189
edit_script.append(Remove(self._source_index[removed_node_id]))
@@ -186,7 +199,8 @@ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.Lis
186199
edit_script.extend(
187200
self._generate_move_edits(source_node, target_node, matching_set)
188201
)
189-
edit_script.append(Keep(source_node, target_node))
202+
if not delta_only:
203+
edit_script.append(Keep(source_node, target_node))
190204
else:
191205
edit_script.append(Update(source_node, target_node))
192206

tests/test_diff.py

Lines changed: 31 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,36 +1,40 @@
11
import unittest
22

33
from sqlglot import exp, parse_one
4-
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
4+
from sqlglot.diff import Insert, Move, Remove, Update, diff
55
from sqlglot.expressions import Join, to_table
66

77

8+
def diff_delta_only(source, target, matchings=None, delta_only=True, **kwargs):
9+
return diff(source, target, matchings=matchings, delta_only=delta_only, **kwargs)
10+
11+
812
class TestDiff(unittest.TestCase):
913
def test_simple(self):
1014
self._validate_delta_only(
11-
diff(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
15+
diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
1216
[
1317
Remove(parse_one("a + b")), # the Add node
1418
Insert(parse_one("a - b")), # the Sub node
1519
],
1620
)
1721

1822
self._validate_delta_only(
19-
diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
23+
diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
2024
[
2125
Remove(parse_one("b")), # the Column node
2226
],
2327
)
2428

2529
self._validate_delta_only(
26-
diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
30+
diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
2731
[
2832
Insert(parse_one("c")), # the Column node
2933
],
3034
)
3135

3236
self._validate_delta_only(
33-
diff(
37+
diff_delta_only(
3438
parse_one("SELECT a FROM table_one"),
3539
parse_one("SELECT a FROM table_two"),
3640
),
@@ -44,7 +48,9 @@ def test_simple(self):
4448

4549
def test_lambda(self):
4650
self._validate_delta_only(
47-
diff(parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")),
51+
diff_delta_only(
52+
parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")
53+
),
4854
[
4955
Update(
5056
exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]),
@@ -55,14 +61,16 @@ def test_lambda(self):
5561

5662
def test_udf(self):
5763
self._validate_delta_only(
58-
diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')),
64+
diff_delta_only(
65+
parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')
66+
),
5967
[
6068
Insert(parse_one('"my.udf2"()')),
6169
Remove(parse_one('"my.udf1"()')),
6270
],
6371
)
6472
self._validate_delta_only(
65-
diff(
73+
diff_delta_only(
6674
parse_one('SELECT a, b, "my.udf"(x, y, z)'),
6775
parse_one('SELECT a, b, "my.udf"(x, y, w)'),
6876
),
@@ -74,28 +82,28 @@ def test_udf(self):
7482

7583
def test_node_position_changed(self):
7684
self._validate_delta_only(
77-
diff(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")),
85+
diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")),
7886
[
7987
Move(parse_one("c")), # the Column node
8088
],
8189
)
8290

8391
self._validate_delta_only(
84-
diff(parse_one("SELECT a + b"), parse_one("SELECT b + a")),
92+
diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT b + a")),
8593
[
8694
Move(parse_one("a")), # the Column node
8795
],
8896
)
8997

9098
self._validate_delta_only(
91-
diff(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")),
99+
diff_delta_only(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")),
92100
[
93101
Move(parse_one("aaaa")), # the Column node
94102
],
95103
)
96104

97105
self._validate_delta_only(
98-
diff(
106+
diff_delta_only(
99107
parse_one("SELECT aaaa OR bbbb OR cccc"),
100108
parse_one("SELECT cccc OR bbbb OR aaaa"),
101109
),
@@ -120,7 +128,7 @@ def test_cte(self):
120128
"""
121129

122130
self._validate_delta_only(
123-
diff(parse_one(expr_src), parse_one(expr_tgt)),
131+
diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)),
124132
[
125133
Remove(parse_one("LOWER(c) AS c")), # the Alias node
126134
Remove(parse_one("LOWER(c)")), # the Lower node
@@ -133,8 +141,7 @@ def test_join(self):
133141
expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key"
134142
expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key"
135143

136-
changes = diff(parse_one(expr_src), parse_one(expr_tgt))
137-
changes = _delta_only(changes)
144+
changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt))
138145

139146
self.assertEqual(len(changes), 2)
140147
self.assertTrue(isinstance(changes[0], Remove))
@@ -145,10 +152,10 @@ def test_window_functions(self):
145152
expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)")
146153
expr_tgt = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)")
147154

148-
self._validate_delta_only(diff(expr_src, expr_src), [])
155+
self._validate_delta_only(diff_delta_only(expr_src, expr_src), [])
149156

150157
self._validate_delta_only(
151-
diff(expr_src, expr_tgt),
158+
diff_delta_only(expr_src, expr_tgt),
152159
[
153160
Remove(parse_one("ROW_NUMBER()")), # the Anonymous node
154161
Insert(parse_one("RANK()")), # the Anonymous node
@@ -160,7 +167,7 @@ def test_pre_matchings(self):
160167
expr_tgt = parse_one("SELECT 1, 2, 3, 4")
161168

162169
self._validate_delta_only(
163-
diff(expr_src, expr_tgt),
170+
diff_delta_only(expr_src, expr_tgt),
164171
[
165172
Remove(expr_src),
166173
Insert(expr_tgt),
@@ -171,7 +178,7 @@ def test_pre_matchings(self):
171178
)
172179

173180
self._validate_delta_only(
174-
diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
181+
diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
175182
[
176183
Insert(exp.Literal.number(2)),
177184
Insert(exp.Literal.number(3)),
@@ -180,23 +187,20 @@ def test_pre_matchings(self):
180187
)
181188

182189
with self.assertRaises(ValueError):
183-
diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)])
190+
diff_delta_only(
191+
expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)]
192+
)
184193

185194
def test_identifier(self):
186195
expr_src = parse_one("SELECT a FROM tbl")
187196
expr_tgt = parse_one("SELECT a, tbl.b from tbl")
188197

189198
self._validate_delta_only(
190-
diff(expr_src, expr_tgt),
199+
diff_delta_only(expr_src, expr_tgt),
191200
[
192201
Insert(expression=exp.to_column("tbl.b")),
193202
],
194203
)
195204

196-
def _validate_delta_only(self, actual_diff, expected_delta):
197-
actual_delta = _delta_only(actual_diff)
205+
def _validate_delta_only(self, actual_delta, expected_delta):
198206
self.assertEqual(set(actual_delta), set(expected_delta))
199-
200-
201-
def _delta_only(changes):
202-
return [d for d in changes if not isinstance(d, Keep)]

0 commit comments

Comments
 (0)