Skip to content

Commit 16d6709

Browse files
authored
fix: nested UNION inside bracketed set_expression (#726)
1 parent 89e5a61 commit 16d6709

File tree

2 files changed

+52
-15
lines changed

2 files changed

+52
-15
lines changed

sqllineage/core/parser/sqlfluff/extractors/select.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,28 +49,15 @@ def extract(
4949
subqueries.append(sq)
5050

5151
if is_set_expression(segment):
52-
for _, sub_segment in enumerate(
53-
segment.get_children("select_statement", "bracketed")
54-
):
55-
for seg in list_child_segments(sub_segment):
56-
for sq in self.list_subquery(seg):
57-
subqueries.append(sq)
52+
subqueries.extend(self._collect_subqueries_in_set_expression(segment))
5853

5954
self.extract_subquery(subqueries, holder)
6055

6156
for segment in segments:
6257
self._handle_select_statement_child_segments(segment, holder)
6358

6459
if is_set_expression(segment):
65-
for idx, sub_segment in enumerate(
66-
segment.get_children("select_statement", "bracketed")
67-
):
68-
if idx != 0:
69-
self.union_barriers.append(
70-
(len(self.columns), len(self.tables))
71-
)
72-
for seg in list_child_segments(sub_segment):
73-
self._handle_select_statement_child_segments(seg, holder)
60+
self._handle_set_expression(segment, holder)
7461

7562
self.end_of_query_cleanup(holder)
7663

@@ -126,6 +113,44 @@ def _handle_select_into(self, segment: BaseSegment, holder: SubQueryLineageHolde
126113
if table := self.find_table(identifier):
127114
holder.add_write(table)
128115

116+
def _handle_set_expression(
    self, segment: BaseSegment, holder: SubQueryLineageHolder
) -> None:
    """
    Process each branch of a set_expression, recursing into nested ones

    Iterates the "select_statement" and "bracketed" children of ``segment``.
    Before every branch except the first, the pair
    ``(len(self.columns), len(self.tables))`` is appended to
    ``self.union_barriers``.

    A bracketed branch that directly contains one or more set_expression
    segments is handled by recursing into each of them; its other child
    segments are not processed. Every other branch has each of its child
    segments passed to ``_handle_select_statement_child_segments``.

    :param segment: the set_expression segment to walk
    :param holder: lineage holder forwarded to the child segment handler
    """
    branches = segment.get_children("select_statement", "bracketed")
    for position, branch in enumerate(branches):
        if position > 0:
            barrier = (len(self.columns), len(self.tables))
            self.union_barriers.append(barrier)

        branch_is_bracketed = branch.type == "bracketed"
        if not branch_is_bracketed and branch.type != "select_statement":
            continue

        contents = list_child_segments(branch)
        # Only a bracketed branch may wrap a further set_expression.
        nested_sets = (
            [part for part in contents if part.type == "set_expression"]
            if branch_is_bracketed
            else []
        )
        if nested_sets:
            for nested in nested_sets:
                self._handle_set_expression(nested, holder)
            continue

        for part in contents:
            self._handle_select_statement_child_segments(part, holder)
138+
139+
def _collect_subqueries_in_set_expression(self, segment: BaseSegment):
    """
    Gather the subqueries found in every branch of a set_expression

    Iterates the "select_statement" and "bracketed" children of ``segment``
    and runs ``self.list_subquery`` on each of their child segments. A
    set_expression found directly inside a bracketed branch is instead
    collected recursively with this same method.

    :param segment: the set_expression segment to search
    :return: list of everything yielded by ``list_subquery``, in traversal order
    """
    found = []
    for branch in segment.get_children("select_statement", "bracketed"):
        branch_is_bracketed = branch.type == "bracketed"
        if not branch_is_bracketed and branch.type != "select_statement":
            continue
        for part in list_child_segments(branch):
            if branch_is_bracketed and part.type == "set_expression":
                found.extend(self._collect_subqueries_in_set_expression(part))
            else:
                found.extend(self.list_subquery(part))
    return found
153+
129154
def _handle_column(self, segment: BaseSegment) -> None:
130155
"""
131156
Column handler method

tests/sql/table/test_select.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,18 @@ def test_select_union_all_with_nested_subquery():
346346
assert_table_lineage_equal(sql, {"tab1", "tab2"})
347347

348348

349+
def test_select_nested_union_all():
    """
    Table lineage of a UNION ALL whose second branch is a bracketed UNION ALL
    """
    expected_tables = {"tab1", "tab2", "tab3"}
    query = """SELECT id
FROM tab1
UNION ALL
(SELECT id
FROM tab2
UNION ALL
SELECT id
FROM tab3)"""
    assert_table_lineage_equal(query, expected_tables)
359+
360+
349361
def test_non_reserved_keyword_as_source():
350362
assert_table_lineage_equal(
351363
"SELECT col1, col2 FROM segment", {"segment"}, test_sqlparse=False

0 commit comments

Comments
 (0)