Skip to content

Commit

Permalink
Feat: expose a flag to automatically exclude Keep diff nodes (#4168)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Sep 26, 2024
1 parent 93cef30 commit 0a5444d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 31 deletions.
22 changes: 18 additions & 4 deletions sqlglot/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def diff(
source: exp.Expression,
target: exp.Expression,
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
delta_only: bool = False,
**kwargs: t.Any,
) -> t.List[Edit]:
"""
Expand Down Expand Up @@ -89,6 +90,8 @@ def diff(
heuristics produce better results for subtrees that are known by a caller to be matching.
Note: expression references in this list must refer to the same node objects that are
referenced in source / target trees.
delta_only: excludes all `Keep` nodes from the diff.
kwargs: additional arguments to pass to the ChangeDistiller instance.
Returns:
the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
Expand Down Expand Up @@ -116,7 +119,12 @@ def compute_node_mappings(
}
matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]

return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)
return ChangeDistiller(**kwargs).diff(
source_copy,
target_copy,
matchings=matchings_copy,
delta_only=delta_only,
)


# The expression types for which Update edits are allowed.
Expand Down Expand Up @@ -149,6 +157,7 @@ def diff(
source: exp.Expression,
target: exp.Expression,
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
delta_only: bool = False,
) -> t.List[Edit]:
matchings = matchings or []
pre_matched_nodes = {id(s): id(t) for s, t in matchings}
Expand All @@ -168,9 +177,13 @@ def diff(
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}

matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
return self._generate_edit_script(matching_set)
return self._generate_edit_script(matching_set, delta_only)

def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
def _generate_edit_script(
self,
matching_set: t.Set[t.Tuple[int, int]],
delta_only: bool,
) -> t.List[Edit]:
edit_script: t.List[Edit] = []
for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id]))
Expand All @@ -186,7 +199,8 @@ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.Lis
edit_script.extend(
self._generate_move_edits(source_node, target_node, matching_set)
)
edit_script.append(Keep(source_node, target_node))
if not delta_only:
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))

Expand Down
58 changes: 31 additions & 27 deletions tests/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
import unittest

from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
from sqlglot.diff import Insert, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_table


def diff_delta_only(source, target, matchings=None, delta_only=True, **kwargs):
    """Test helper that calls `diff` with `delta_only` defaulting to True.

    Every argument is forwarded to `diff` unchanged; the only difference from
    calling `diff` directly is the default value of `delta_only`, so the
    returned edit script omits `Keep` edits unless the caller overrides it.
    """
    forwarded = dict(kwargs, matchings=matchings, delta_only=delta_only)
    return diff(source, target, **forwarded)


class TestDiff(unittest.TestCase):
def test_simple(self):
self._validate_delta_only(
diff(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
[
Remove(parse_one("a + b")), # the Add node
Insert(parse_one("a - b")), # the Sub node
],
)

self._validate_delta_only(
diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
[
Remove(parse_one("b")), # the Column node
],
)

self._validate_delta_only(
diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
[
Insert(parse_one("c")), # the Column node
],
)

self._validate_delta_only(
diff(
diff_delta_only(
parse_one("SELECT a FROM table_one"),
parse_one("SELECT a FROM table_two"),
),
Expand All @@ -44,7 +48,9 @@ def test_simple(self):

def test_lambda(self):
self._validate_delta_only(
diff(parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")),
diff_delta_only(
parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")
),
[
Update(
exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]),
Expand All @@ -55,14 +61,16 @@ def test_lambda(self):

def test_udf(self):
self._validate_delta_only(
diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')),
diff_delta_only(
parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')
),
[
Insert(parse_one('"my.udf2"()')),
Remove(parse_one('"my.udf1"()')),
],
)
self._validate_delta_only(
diff(
diff_delta_only(
parse_one('SELECT a, b, "my.udf"(x, y, z)'),
parse_one('SELECT a, b, "my.udf"(x, y, w)'),
),
Expand All @@ -74,28 +82,28 @@ def test_udf(self):

def test_node_position_changed(self):
self._validate_delta_only(
diff(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")),
diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")),
[
Move(parse_one("c")), # the Column node
],
)

self._validate_delta_only(
diff(parse_one("SELECT a + b"), parse_one("SELECT b + a")),
diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT b + a")),
[
Move(parse_one("a")), # the Column node
],
)

self._validate_delta_only(
diff(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")),
diff_delta_only(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")),
[
Move(parse_one("aaaa")), # the Column node
],
)

self._validate_delta_only(
diff(
diff_delta_only(
parse_one("SELECT aaaa OR bbbb OR cccc"),
parse_one("SELECT cccc OR bbbb OR aaaa"),
),
Expand All @@ -120,7 +128,7 @@ def test_cte(self):
"""

self._validate_delta_only(
diff(parse_one(expr_src), parse_one(expr_tgt)),
diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)),
[
Remove(parse_one("LOWER(c) AS c")), # the Alias node
Remove(parse_one("LOWER(c)")), # the Lower node
Expand All @@ -133,8 +141,7 @@ def test_join(self):
expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key"
expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key"

changes = diff(parse_one(expr_src), parse_one(expr_tgt))
changes = _delta_only(changes)
changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt))

self.assertEqual(len(changes), 2)
self.assertTrue(isinstance(changes[0], Remove))
Expand All @@ -145,10 +152,10 @@ def test_window_functions(self):
expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)")
expr_tgt = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)")

self._validate_delta_only(diff(expr_src, expr_src), [])
self._validate_delta_only(diff_delta_only(expr_src, expr_src), [])

self._validate_delta_only(
diff(expr_src, expr_tgt),
diff_delta_only(expr_src, expr_tgt),
[
Remove(parse_one("ROW_NUMBER()")), # the Anonymous node
Insert(parse_one("RANK()")), # the Anonymous node
Expand All @@ -160,7 +167,7 @@ def test_pre_matchings(self):
expr_tgt = parse_one("SELECT 1, 2, 3, 4")

self._validate_delta_only(
diff(expr_src, expr_tgt),
diff_delta_only(expr_src, expr_tgt),
[
Remove(expr_src),
Insert(expr_tgt),
Expand All @@ -171,7 +178,7 @@ def test_pre_matchings(self):
)

self._validate_delta_only(
diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
[
Insert(exp.Literal.number(2)),
Insert(exp.Literal.number(3)),
Expand All @@ -180,23 +187,20 @@ def test_pre_matchings(self):
)

with self.assertRaises(ValueError):
diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)])
diff_delta_only(
expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)]
)

def test_identifier(self):
expr_src = parse_one("SELECT a FROM tbl")
expr_tgt = parse_one("SELECT a, tbl.b from tbl")

self._validate_delta_only(
diff(expr_src, expr_tgt),
diff_delta_only(expr_src, expr_tgt),
[
Insert(expression=exp.to_column("tbl.b")),
],
)

def _validate_delta_only(self, actual_diff, expected_delta):
actual_delta = _delta_only(actual_diff)
def _validate_delta_only(self, actual_delta, expected_delta):
self.assertEqual(set(actual_delta), set(expected_delta))


def _delta_only(changes):
return [d for d in changes if not isinstance(d, Keep)]

0 comments on commit 0a5444d

Please sign in to comment.