diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 22c506a9b..5ebc682e3 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -62,6 +62,7 @@ def diff( source: exp.Expression, target: exp.Expression, matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + delta_only: bool = False, **kwargs: t.Any, ) -> t.List[Edit]: """ @@ -89,6 +90,8 @@ def diff( heuristics produce better results for subtrees that are known by a caller to be matching. Note: expression references in this list must refer to the same node objects that are referenced in source / target trees. + delta_only: excludes all `Keep` nodes from the diff. + kwargs: additional arguments to pass to the ChangeDistiller instance. Returns: the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the @@ -116,7 +119,12 @@ def compute_node_mappings( } matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings] - return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy) + return ChangeDistiller(**kwargs).diff( + source_copy, + target_copy, + matchings=matchings_copy, + delta_only=delta_only, + ) # The expression types for which Update edits are allowed. @@ -149,6 +157,7 @@ def diff( source: exp.Expression, target: exp.Expression, matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + delta_only: bool = False, ) -> t.List[Edit]: matchings = matchings or [] pre_matched_nodes = {id(s): id(t) for s, t in matchings} @@ -168,9 +177,13 @@ def diff( self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()} - return self._generate_edit_script(matching_set) + return self._generate_edit_script(matching_set, delta_only) - def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]: + def _generate_edit_script( + self, + matching_set: t.Set[t.Tuple[int, int]], + delta_only: bool, + ) -> t.List[Edit]: edit_script: t.List[Edit] = [] for removed_node_id in self._unmatched_source_nodes: edit_script.append(Remove(self._source_index[removed_node_id])) @@ -186,7 +199,8 @@ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.Lis edit_script.extend( self._generate_move_edits(source_node, target_node, matching_set) ) - edit_script.append(Keep(source_node, target_node)) + if not delta_only: + edit_script.append(Keep(source_node, target_node)) else: edit_script.append(Update(source_node, target_node)) diff --git a/tests/test_diff.py b/tests/test_diff.py index fa012a890..3befe00be 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -1,14 +1,18 @@ import unittest from sqlglot import exp, parse_one -from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff +from sqlglot.diff import Insert, Move, Remove, Update, diff from sqlglot.expressions import Join, to_table +def diff_delta_only(source, target, matchings=None, delta_only=True, **kwargs): + return diff(source, target, matchings=matchings, delta_only=delta_only, **kwargs) + + class TestDiff(unittest.TestCase): def test_simple(self): self._validate_delta_only( - diff(parse_one("SELECT a + b"), parse_one("SELECT a - b")), + diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")), [ Remove(parse_one("a + b")), # the Add node Insert(parse_one("a - b")), # the Sub node @@ -16,21 +20,21 @@ def test_simple(self): ) self._validate_delta_only( - diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), + diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), [ Remove(parse_one("b")), # the Column node ], ) self._validate_delta_only( - diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), + diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), [ Insert(parse_one("c")), # the Column node ], ) self._validate_delta_only( - diff( + diff_delta_only( parse_one("SELECT a FROM table_one"), parse_one("SELECT a FROM table_two"), ), @@ -44,7 +48,9 @@ def test_simple(self): def test_lambda(self): self._validate_delta_only( - diff(parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")), + diff_delta_only( + parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)") + ), [ Update( exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]), @@ -55,14 +61,16 @@ def test_lambda(self): def test_udf(self): self._validate_delta_only( - diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')), + diff_delta_only( + parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()') + ), [ Insert(parse_one('"my.udf2"()')), Remove(parse_one('"my.udf1"()')), ], ) self._validate_delta_only( - diff( + diff_delta_only( parse_one('SELECT a, b, "my.udf"(x, y, z)'), parse_one('SELECT a, b, "my.udf"(x, y, w)'), ), @@ -74,28 +82,28 @@ def test_udf(self): def test_node_position_changed(self): self._validate_delta_only( - diff(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")), + diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")), [ Move(parse_one("c")), # the Column node ], ) self._validate_delta_only( - diff(parse_one("SELECT a + b"), parse_one("SELECT b + a")), + diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT b + a")), [ Move(parse_one("a")), # the Column node ], ) self._validate_delta_only( - diff(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")), + diff_delta_only(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")), [ Move(parse_one("aaaa")), # the Column node ], ) self._validate_delta_only( - diff( + diff_delta_only( parse_one("SELECT aaaa OR bbbb OR cccc"), parse_one("SELECT cccc OR bbbb OR aaaa"), ), @@ -120,7 +128,7 @@ def test_cte(self): """ self._validate_delta_only( - diff(parse_one(expr_src), parse_one(expr_tgt)), + diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)), [ Remove(parse_one("LOWER(c) AS c")), # the Alias node Remove(parse_one("LOWER(c)")), # the Lower node @@ -133,8 +141,7 @@ def test_join(self): expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key" expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key" - changes = diff(parse_one(expr_src), parse_one(expr_tgt)) - changes = _delta_only(changes) + changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)) self.assertEqual(len(changes), 2) self.assertTrue(isinstance(changes[0], Remove)) @@ -145,10 +152,10 @@ def test_window_functions(self): expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)") expr_tgt = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)") - self._validate_delta_only(diff(expr_src, expr_src), []) + self._validate_delta_only(diff_delta_only(expr_src, expr_src), []) self._validate_delta_only( - diff(expr_src, expr_tgt), + diff_delta_only(expr_src, expr_tgt), [ Remove(parse_one("ROW_NUMBER()")), # the Anonymous node Insert(parse_one("RANK()")), # the Anonymous node @@ -160,7 +167,7 @@ def test_pre_matchings(self): expr_tgt = parse_one("SELECT 1, 2, 3, 4") self._validate_delta_only( - diff(expr_src, expr_tgt), + diff_delta_only(expr_src, expr_tgt), [ Remove(expr_src), Insert(expr_tgt), @@ -171,7 +178,7 @@ def test_pre_matchings(self): ) self._validate_delta_only( - diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]), + diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]), [ Insert(exp.Literal.number(2)), Insert(exp.Literal.number(3)), @@ -180,23 +187,20 @@ def test_pre_matchings(self): ) with self.assertRaises(ValueError): - diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)]) + diff_delta_only( + expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)] + ) def test_identifier(self): expr_src = parse_one("SELECT a FROM tbl") expr_tgt = parse_one("SELECT a, tbl.b from tbl") self._validate_delta_only( - diff(expr_src, expr_tgt), + diff_delta_only(expr_src, expr_tgt), [ Insert(expression=exp.to_column("tbl.b")), ], ) - def _validate_delta_only(self, actual_diff, expected_delta): - actual_delta = _delta_only(actual_diff) + def _validate_delta_only(self, actual_delta, expected_delta): self.assertEqual(set(actual_delta), set(expected_delta)) - - -def _delta_only(changes): - return [d for d in changes if not isinstance(d, Keep)]