Skip to content

Commit cd7c55d

Browse files
mingjerliclaude
andcommitted
feat: Add star expansion for cross-query column lineage
Implement two-phase parsing to resolve SELECT * to individual columns when upstream table schemas are known from previous queries. ## Key Changes ### Architecture - **Phase 1**: Topological ordering (existing) - **Phase 2**: Sequential parsing with upstream context (NEW) - Collect output columns from upstream tables - Pass as `external_table_columns` to RecursiveLineageBuilder - Expand SELECT * to actual columns during parsing ### Implementation **pipeline.py:** - `_collect_upstream_table_schemas()`: Collect columns from upstream tables - Enhanced `build()`: Pass external_table_columns to lineage builder - Improved cross-query edge logic: Distinguish SELECT * from COUNT(*) **lineage_builder.py:** - `_try_expand_star_from_external_table()`: Expand stars from external tables - Enhanced star expansion to handle both CTEs and external tables - Support for EXCEPT and REPLACE clauses in star expansion ### Features - ✅ SELECT * → expands to individual columns with precise lineage - ✅ SELECT * REPLACE → expands with transformations applied - ✅ SELECT * EXCEPT → expands excluding specified columns - ✅ COUNT(*) → preserved as * node for aggregation semantics - ✅ Multi-query star chains → full expansion through pipeline ### Testing - Added comprehensive test suite: test_star_expansion_cross_query.py - 13 new tests covering expansion, preservation, mixed usage, edge cases - Updated existing test: test_multi_query.py (SELECT * EXCEPT verification) - All 294 tests pass ### Backward Compatibility - Fully backward compatible - Automatic and transparent - No breaking API changes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent af3a079 commit cd7c55d

4 files changed

Lines changed: 970 additions & 45 deletions

File tree

src/clgraph/lineage_builder.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,23 @@ def _extract_output_columns(self, unit: QueryUnit) -> List[Dict]:
199199

200200
# STAR EXPANSION: Try to expand star if we know the source columns
201201
# This only applies to the main query output (not CTEs/subqueries)
202-
if unit.unit_type == QueryUnitType.MAIN_QUERY and col_info.get("source_unit"):
203-
source_unit_id = str(col_info["source_unit"])
204-
expanded_cols = self._try_expand_star(unit, source_unit_id, col_info)
202+
if unit.unit_type == QueryUnitType.MAIN_QUERY:
203+
expanded_cols = None
204+
205+
# Case 1: Source is a CTE or subquery (internal query unit)
206+
if col_info.get("source_unit"):
207+
source_unit_id = str(col_info["source_unit"])
208+
expanded_cols = self._try_expand_star(unit, source_unit_id, col_info)
209+
210+
# Case 2: Source is an external table with known columns
211+
elif col_info.get("source_table"):
212+
source_table = col_info["source_table"]
213+
# Type guard: ensure source_table is a string
214+
if isinstance(source_table, str):
215+
expanded_cols = self._try_expand_star_from_external_table(
216+
unit, source_table, col_info
217+
)
218+
205219
if expanded_cols:
206220
# Replace star with expanded columns
207221
output_cols.extend(expanded_cols)
@@ -1040,6 +1054,58 @@ def _try_expand_star(
10401054

10411055
return expanded if expanded else None
10421056

1057+
def _try_expand_star_from_external_table(
1058+
self, unit: QueryUnit, source_table: str, star_col_info: Dict
1059+
) -> Optional[List[Dict]]:
1060+
"""
1061+
Try to expand a star from an external table using external_table_columns.
1062+
1063+
This handles the cross-query scenario where Query 2 does SELECT * FROM staging.orders,
1064+
and staging.orders was created by Query 1 with known columns.
1065+
1066+
Args:
1067+
unit: The current query unit (main query)
1068+
source_table: The external table name (e.g., "staging.orders")
1069+
star_col_info: The star column info dict
1070+
1071+
Returns:
1072+
List of expanded column dicts, or None if expansion not possible
1073+
"""
1074+
# Check if we have column information for this external table
1075+
if source_table not in self.external_table_columns:
1076+
return None
1077+
1078+
column_names = self.external_table_columns[source_table]
1079+
if not column_names:
1080+
return None
1081+
1082+
# Great! We can expand. Create individual column entries
1083+
expanded: List[Dict] = []
1084+
except_cols = star_col_info.get("except_columns", set())
1085+
replace_cols = star_col_info.get("replace_columns", {})
1086+
1087+
for col_name in column_names:
1088+
# Skip columns in EXCEPT clause
1089+
if col_name in except_cols:
1090+
continue
1091+
1092+
# Create a new column info for this expanded column
1093+
expanded_col = {
1094+
"index": len(expanded),
1095+
"ast_node": None, # No specific AST node for expanded columns
1096+
"is_star": False, # This is now an explicit column
1097+
"name": col_name,
1098+
"type": "direct_column", # Direct pass-through from source
1099+
"expression": replace_cols.get(col_name, col_name), # Use REPLACE if specified
1100+
"source_columns": [(source_table, col_name)],
1101+
# Mark this as star-expanded so we can trace it properly
1102+
"star_expanded": True,
1103+
"star_source_table": source_table,
1104+
}
1105+
expanded.append(expanded_col)
1106+
1107+
return expanded if expanded else None
1108+
10431109
# ========================================================================
10441110
# Validation Methods (Static Analysis)
10451111
# ========================================================================

src/clgraph/pipeline.py

Lines changed: 222 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,18 @@ def build(self, pipeline_or_graph) -> "Pipeline":
6969
sql_for_lineage = self._extract_select_from_query(query)
7070

7171
if sql_for_lineage:
72+
# Collect upstream table schemas from already-processed queries
73+
external_table_columns = self._collect_upstream_table_schemas(
74+
pipeline, query, table_graph
75+
)
76+
7277
# RecursiveLineageBuilder handles parsing internally
73-
lineage_builder = RecursiveLineageBuilder(sql_for_lineage)
78+
# Pass external_table_columns so it can resolve * to actual columns
79+
lineage_builder = RecursiveLineageBuilder(
80+
sql_for_lineage,
81+
external_table_columns=external_table_columns,
82+
dialect=pipeline.dialect,
83+
)
7484
query_lineage = lineage_builder.build()
7585

7686
# Store query lineage
@@ -98,6 +108,171 @@ def build(self, pipeline_or_graph) -> "Pipeline":
98108

99109
return pipeline
100110

111+
def _expand_star_nodes_in_pipeline(
112+
self, pipeline: "Pipeline", query: ParsedQuery, nodes: list[ColumnNode]
113+
) -> list[ColumnNode]:
114+
"""
115+
Expand * nodes in output layer when upstream columns are known.
116+
117+
For cross-query scenarios:
118+
- If query_1 does SELECT * EXCEPT (col1) FROM staging.table
119+
- And staging.table was created by query_0 with known columns
120+
- We should expand the * to the actual columns (minus excepted ones)
121+
122+
This gives users precise column-level lineage instead of just *.
123+
"""
124+
result = []
125+
126+
# Find all input layer * nodes to get source table info
127+
input_star_nodes = {
128+
node.table_name: node for node in nodes if node.is_star and node.layer == "input"
129+
}
130+
131+
for node in nodes:
132+
# Only expand output layer * nodes
133+
if not (node.is_star and node.layer == "output"):
134+
result.append(node)
135+
continue
136+
137+
# Get the source table from the corresponding input * node
138+
# The output * has EXCEPT/REPLACE info, but we need the source table from input *
139+
source_table_name = None
140+
except_columns = node.except_columns
141+
replace_columns = node.replace_columns
142+
143+
# Find which input table this output * is selecting from
144+
for input_table, input_star in input_star_nodes.items():
145+
# Check if the input * feeds into this output *
146+
# (in simple cases, there's only one input * per query)
147+
# Infer the fully qualified table name for the input table
148+
source_table_name = self._infer_table_name(input_star, query) or input_table
149+
break
150+
151+
if not source_table_name:
152+
# Can't expand - keep the * node
153+
result.append(node)
154+
continue
155+
156+
# Try to find upstream table columns
157+
upstream_columns = self._get_upstream_table_columns(pipeline, source_table_name)
158+
159+
if not upstream_columns:
160+
# Can't expand - keep the * node
161+
result.append(node)
162+
continue
163+
164+
# Expand the * to individual columns
165+
for upstream_col in upstream_columns:
166+
col_name = upstream_col.column_name
167+
168+
# Skip excepted columns
169+
if col_name in except_columns:
170+
continue
171+
172+
# Create expanded column node
173+
# Get the properly inferred destination table name
174+
dest_table_name = self._infer_table_name(node, query) or node.table_name
175+
176+
expanded_node = ColumnNode(
177+
column_name=col_name,
178+
table_name=dest_table_name,
179+
full_name=f"{dest_table_name}.{col_name}",
180+
unit_id=node.unit_id,
181+
layer=node.layer,
182+
query_id=node.query_id,
183+
node_type="direct_column",
184+
is_star=False,
185+
# Check if this column is being replaced
186+
expression=(
187+
replace_columns.get(col_name, col_name)
188+
if col_name in replace_columns
189+
else col_name
190+
),
191+
# Preserve metadata from upstream if available
192+
description=upstream_col.description,
193+
pii=upstream_col.pii,
194+
owner=upstream_col.owner,
195+
tags=upstream_col.tags.copy(),
196+
)
197+
result.append(expanded_node)
198+
199+
return result
200+
201+
def _collect_upstream_table_schemas(
202+
self,
203+
pipeline: "Pipeline",
204+
query: ParsedQuery,
205+
table_graph: TableDependencyGraph,
206+
) -> Dict[str, List[str]]:
207+
"""
208+
Collect column names from upstream tables that this query reads from.
209+
210+
This is used to pass to RecursiveLineageBuilder so it can resolve * properly.
211+
212+
Args:
213+
pipeline: Pipeline being built
214+
query: Current query being processed
215+
table_graph: Table dependency graph
216+
217+
Returns:
218+
Dict mapping table_name -> list of column names
219+
Example: {"staging.orders": ["order_id", "user_id", "amount", "status", "order_date"]}
220+
"""
221+
external_table_columns = {}
222+
223+
# For each source table this query reads from
224+
for source_table in query.source_tables:
225+
# Get the table node
226+
table_node = table_graph.tables.get(source_table)
227+
if not table_node:
228+
continue
229+
230+
# If this table was created by a previous query, get its output columns
231+
if table_node.created_by:
232+
creating_query_id = table_node.created_by
233+
234+
# Get output columns from the creating query
235+
output_cols = [
236+
col.column_name
237+
for col in pipeline.columns.values()
238+
if col.query_id == creating_query_id
239+
and col.table_name == source_table
240+
and col.layer == "output"
241+
and not col.is_star # Don't include * nodes
242+
]
243+
244+
if output_cols:
245+
external_table_columns[source_table] = output_cols
246+
247+
return external_table_columns
248+
249+
def _get_upstream_table_columns(
250+
self, pipeline: "Pipeline", table_name: str
251+
) -> list[ColumnNode]:
252+
"""
253+
Get columns from an upstream table that was created in the pipeline.
254+
255+
Returns the output columns from the query that created this table.
256+
"""
257+
# Find which query created this table
258+
table_node = pipeline.table_graph.tables.get(table_name)
259+
if not table_node or not table_node.created_by:
260+
return []
261+
262+
creating_query_id = table_node.created_by
263+
264+
# Get output columns from the creating query
265+
upstream_cols = [
266+
col
267+
for col in pipeline.columns.values()
268+
if col.query_id == creating_query_id
269+
and col.table_name == table_name
270+
and col.layer == "output"
271+
and not col.is_star # Don't use * nodes as source
272+
]
273+
274+
return upstream_cols
275+
101276
def _add_query_columns(
102277
self,
103278
pipeline: "Pipeline",
@@ -110,9 +285,17 @@ def _add_query_columns(
110285
Note: We add all nodes (both input and output layers) to maintain full lineage.
111286
Input layer nodes represent source columns, output layer nodes represent derived columns.
112287
Both are needed for complete lineage tracing.
288+
289+
Special handling for star expansion:
290+
- If output layer has a * node and we know the upstream columns, expand it
291+
- This is crucial for cross-query lineage to show exact columns
113292
"""
293+
# Check if we need to expand any * nodes in the output layer
294+
nodes_to_add = list(query_lineage.nodes.values())
295+
expanded_nodes = self._expand_star_nodes_in_pipeline(pipeline, query, nodes_to_add)
296+
114297
# Add columns with table context
115-
for node in query_lineage.nodes.values():
298+
for node in expanded_nodes:
116299
# Extract metadata from SQL comments if available
117300
description = None
118301
description_source = None
@@ -234,13 +417,34 @@ def _add_cross_query_edges(self, pipeline: "Pipeline"):
234417
output_star_column = col
235418
break
236419

420+
# Check if this query had SELECT * that was expanded to individual output columns
421+
# Key indicator: ALL upstream columns appear in output with same names
422+
# This distinguishes SELECT * (all columns) from SELECT user_id, COUNT(*) (partial)
423+
has_select_star_expanded = False
424+
if output_star_column is None and input_star_column is not None:
425+
# Count how many upstream columns appear in output
426+
upstream_col_names = {oc.column_name for oc in output_columns}
427+
output_col_names = {
428+
col.column_name
429+
for col in pipeline.columns.values()
430+
if col.query_id == reading_query_id
431+
and col.layer == "output"
432+
and not col.is_star
433+
}
434+
matching_cols = upstream_col_names & output_col_names
435+
436+
# If ALL upstream columns appear in output (or most of them, accounting for EXCEPT),
437+
# then this was likely SELECT * that got expanded
438+
has_select_star_expanded = len(matching_cols) >= len(upstream_col_names) * 0.8
439+
237440
# Use output * for EXCEPT/REPLACE info, but connect to input *
238441
star_column = input_star_column
239442
except_columns = output_star_column.except_columns if output_star_column else set()
240443

241-
# If there's a star column, connect all output columns to it
444+
# If there's a star column AND it hasn't been expanded, connect all output columns to it
242445
# BUT respect EXCEPT clause - skip columns that are excepted
243-
if star_column:
446+
# IMPORTANT: Skip this if SELECT * was expanded - we'll create direct edges instead
447+
if star_column and not has_select_star_expanded:
244448
for output_col in output_columns:
245449
# Skip columns in EXCEPT clause
246450
if output_col.column_name in except_columns:
@@ -259,15 +463,26 @@ def _add_cross_query_edges(self, pipeline: "Pipeline"):
259463
# ALWAYS match columns by name (not just when there's no star)
260464
# This handles cases where the query uses both * (for COUNT(*))
261465
# and specific columns (for SUM(amount), etc.)
466+
#
467+
# Also handles star expansion: when the reading query has SELECT * FROM table,
468+
# the star is expanded to individual OUTPUT columns at parse time.
469+
# We need to connect upstream columns to those expanded output columns.
262470
for output_col in output_columns:
263-
# Find corresponding input column in reading query
264-
# Search for this column in reading query's lineage
471+
# Find corresponding column in reading query by NAME
472+
# This could be:
473+
# 1. Input layer column with same table_name (explicit reference)
474+
# 2. Output layer column with matching name (star-expanded column)
265475
for col in pipeline.columns.values():
266476
if (
267477
col.query_id == reading_query_id
268-
and col.table_name == table_name
269478
and col.column_name == output_col.column_name
270479
):
480+
# Check if this is the right column:
481+
# - Input layer: table_name must match
482+
# - Output layer: any match (star-expanded columns)
483+
if col.layer == "input" and col.table_name != table_name:
484+
continue # Wrong input table
485+
271486
# Create cross-query edge
272487
edge = ColumnEdge(
273488
from_node=output_col,

0 commit comments

Comments
 (0)