@@ -49,28 +49,15 @@ def extract(
4949 subqueries .append (sq )
5050
5151 if is_set_expression (segment ):
52- for _ , sub_segment in enumerate (
53- segment .get_children ("select_statement" , "bracketed" )
54- ):
55- for seg in list_child_segments (sub_segment ):
56- for sq in self .list_subquery (seg ):
57- subqueries .append (sq )
52+ subqueries .extend (self ._collect_subqueries_in_set_expression (segment ))
5853
5954 self .extract_subquery (subqueries , holder )
6055
6156 for segment in segments :
6257 self ._handle_select_statement_child_segments (segment , holder )
6358
6459 if is_set_expression (segment ):
65- for idx , sub_segment in enumerate (
66- segment .get_children ("select_statement" , "bracketed" )
67- ):
68- if idx != 0 :
69- self .union_barriers .append (
70- (len (self .columns ), len (self .tables ))
71- )
72- for seg in list_child_segments (sub_segment ):
73- self ._handle_select_statement_child_segments (seg , holder )
60+ self ._handle_set_expression (segment , holder )
7461
7562 self .end_of_query_cleanup (holder )
7663
@@ -126,6 +113,44 @@ def _handle_select_into(self, segment: BaseSegment, holder: SubQueryLineageHolde
126113 if table := self .find_table (identifier ):
127114 holder .add_write (table )
128115
116+ def _handle_set_expression (
117+ self , segment : BaseSegment , holder : SubQueryLineageHolder
118+ ) -> None :
119+ # Recursively handle set_expression and nested bracketed set_expressions
120+ for idx , child in enumerate (
121+ segment .get_children ("select_statement" , "bracketed" )
122+ ):
123+ if idx != 0 :
124+ self .union_barriers .append ((len (self .columns ), len (self .tables )))
125+ if child .type == "select_statement" :
126+ for seg in list_child_segments (child ):
127+ self ._handle_select_statement_child_segments (seg , holder )
128+ elif child .type == "bracketed" :
129+ # If the bracketed child contains another set_expression, recurse; otherwise handle its contents
130+ inner_children = list_child_segments (child )
131+ if any (c .type == "set_expression" for c in inner_children ):
132+ for c in inner_children :
133+ if c .type == "set_expression" :
134+ self ._handle_set_expression (c , holder )
135+ else :
136+ for seg in inner_children :
137+ self ._handle_select_statement_child_segments (seg , holder )
138+
139+ def _collect_subqueries_in_set_expression (self , segment : BaseSegment ):
140+ subqueries = []
141+ for child in segment .get_children ("select_statement" , "bracketed" ):
142+ if child .type == "select_statement" :
143+ for seg in list_child_segments (child ):
144+ subqueries .extend (self .list_subquery (seg ))
145+ elif child .type == "bracketed" :
146+ inner_children = list_child_segments (child )
147+ for c in inner_children :
148+ if c .type == "set_expression" :
149+ subqueries .extend (self ._collect_subqueries_in_set_expression (c ))
150+ else :
151+ subqueries .extend (self .list_subquery (c ))
152+ return subqueries
153+
129154 def _handle_column (self , segment : BaseSegment ) -> None :
130155 """
131156 Column handler method
0 commit comments