|
6 | 6 | from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget |
7 | 7 | from sqlglot import exp |
8 | 8 | from sqlmesh.lsp.description import generate_markdown_description |
| 9 | +from sqlglot.optimizer.scope import build_scope |
9 | 10 | from sqlmesh.lsp.uri import URI |
10 | 11 | from sqlmesh.utils.pydantic import PydanticModel |
| 12 | +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers |
11 | 13 |
|
12 | 14 |
|
13 | 15 | class Reference(PydanticModel): |
14 | 16 | """ |
15 | | - A reference to a model. |
| 17 | + A reference to a model or CTE. |
16 | 18 |
|
17 | 19 | Attributes: |
18 | 20 | range: The range of the reference in the source file |
19 | 21 | uri: The uri of the referenced model |
20 | 22 | markdown_description: The markdown description of the referenced model |
| 23 | + target_range: The range of the definition for go-to-definition (optional, used for CTEs) |
21 | 24 | """ |
22 | 25 |
|
23 | 26 | range: Range |
24 | 27 | uri: str |
25 | 28 | markdown_description: t.Optional[str] = None |
| 29 | + target_range: t.Optional[Range] = None |
26 | 30 |
|
27 | 31 |
|
28 | 32 | def by_position(position: Position) -> t.Callable[[Reference], bool]: |
@@ -88,6 +92,7 @@ def get_model_definitions_for_a_path( |
88 | 92 | - Need to normalize it before matching |
89 | 93 | - Try get_model before normalization |
90 | 94 | - Match to models that the model refers to |
| 95 | + - Also find CTE references within the query |
91 | 96 | """ |
92 | 97 | path = document_uri.to_path() |
93 | 98 | if path.suffix != ".sql": |
@@ -126,66 +131,95 @@ def get_model_definitions_for_a_path( |
126 | 131 | # Find all possible references |
127 | 132 | references = [] |
128 | 133 |
|
129 | | - # Get SQL query and find all table references |
130 | | - tables = list(query.find_all(exp.Table)) |
131 | | - if len(tables) == 0: |
132 | | - return [] |
133 | | - |
134 | 134 | with open(file_path, "r", encoding="utf-8") as file: |
135 | 135 | read_file = file.readlines() |
136 | 136 |
|
137 | | - for table in tables: |
138 | | - # Normalize the table reference |
139 | | - unaliased = table.copy() |
140 | | - if unaliased.args.get("alias") is not None: |
141 | | - unaliased.set("alias", None) |
142 | | - reference_name = unaliased.sql(dialect=dialect) |
143 | | - try: |
144 | | - normalized_reference_name = normalize_model_name( |
145 | | - reference_name, |
146 | | - default_catalog=lint_context.context.default_catalog, |
147 | | - dialect=dialect, |
148 | | - ) |
149 | | - if normalized_reference_name not in depends_on: |
150 | | - continue |
151 | | - except Exception: |
152 | | - # Skip references that cannot be normalized |
153 | | - continue |
154 | | - |
155 | | - # Get the referenced model uri |
156 | | - referenced_model = lint_context.context.get_model( |
157 | | - model_or_snapshot=normalized_reference_name, raise_if_missing=False |
158 | | - ) |
159 | | - if referenced_model is None: |
160 | | - continue |
161 | | - referenced_model_path = referenced_model._path |
162 | | - # Check whether the path exists |
163 | | - if not referenced_model_path.is_file(): |
164 | | - continue |
165 | | - referenced_model_uri = URI.from_path(referenced_model_path) |
166 | | - |
167 | | - # Extract metadata for positioning |
168 | | - table_meta = TokenPositionDetails.from_meta(table.this.meta) |
169 | | - table_range = _range_from_token_position_details(table_meta, read_file) |
170 | | - start_pos = table_range.start |
171 | | - end_pos = table_range.end |
172 | | - |
173 | | - # If there's a catalog or database qualifier, adjust the start position |
174 | | - catalog_or_db = table.args.get("catalog") or table.args.get("db") |
175 | | - if catalog_or_db is not None: |
176 | | - catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
177 | | - catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file) |
178 | | - start_pos = catalog_or_db_range.start |
179 | | - |
180 | | - description = generate_markdown_description(referenced_model) |
181 | | - |
182 | | - references.append( |
183 | | - Reference( |
184 | | - uri=referenced_model_uri.value, |
185 | | - range=Range(start=start_pos, end=end_pos), |
186 | | - markdown_description=description, |
187 | | - ) |
188 | | - ) |
| 137 | + # Build scope tree to properly handle nested CTEs |
| 138 | + query = normalize_identifiers(query.copy(), dialect=dialect) |
| 139 | + root_scope = build_scope(query) |
| 140 | + |
| 141 | + if root_scope: |
| 142 | + # Traverse all scopes to find CTE definitions and table references |
| 143 | + for scope in root_scope.traverse(): |
| 144 | + for table in scope.tables: |
| 145 | + table_name = table.name |
| 146 | + |
| 147 | + # Check if this table reference is a CTE in the current scope |
| 148 | + if cte_scope := scope.cte_sources.get(table_name): |
| 149 | + cte = cte_scope.expression.parent |
| 150 | + alias = cte.args["alias"] |
| 151 | + if isinstance(alias, exp.TableAlias): |
| 152 | + identifier = alias.this |
| 153 | + if isinstance(identifier, exp.Identifier): |
| 154 | + target_range = _range_from_token_position_details( |
| 155 | + TokenPositionDetails.from_meta(identifier.meta), read_file |
| 156 | + ) |
| 157 | + table_range = _range_from_token_position_details( |
| 158 | + TokenPositionDetails.from_meta(table.this.meta), read_file |
| 159 | + ) |
| 160 | + references.append( |
| 161 | + Reference( |
| 162 | + uri=document_uri.value, # Same file |
| 163 | + range=table_range, |
| 164 | + target_range=target_range, |
| 165 | + ) |
| 166 | + ) |
| 167 | + continue |
| 168 | + |
| 169 | + # For non-CTE tables, process as before (external model references) |
| 170 | + # Normalize the table reference |
| 171 | + unaliased = table.copy() |
| 172 | + if unaliased.args.get("alias") is not None: |
| 173 | + unaliased.set("alias", None) |
| 174 | + reference_name = unaliased.sql(dialect=dialect) |
| 175 | + try: |
| 176 | + normalized_reference_name = normalize_model_name( |
| 177 | + reference_name, |
| 178 | + default_catalog=lint_context.context.default_catalog, |
| 179 | + dialect=dialect, |
| 180 | + ) |
| 181 | + if normalized_reference_name not in depends_on: |
| 182 | + continue |
| 183 | + except Exception: |
| 184 | + # Skip references that cannot be normalized |
| 185 | + continue |
| 186 | + |
| 187 | + # Get the referenced model uri |
| 188 | + referenced_model = lint_context.context.get_model( |
| 189 | + model_or_snapshot=normalized_reference_name, raise_if_missing=False |
| 190 | + ) |
| 191 | + if referenced_model is None: |
| 192 | + continue |
| 193 | + referenced_model_path = referenced_model._path |
| 194 | + # Check whether the path exists |
| 195 | + if not referenced_model_path.is_file(): |
| 196 | + continue |
| 197 | + referenced_model_uri = URI.from_path(referenced_model_path) |
| 198 | + |
| 199 | + # Extract metadata for positioning |
| 200 | + table_meta = TokenPositionDetails.from_meta(table.this.meta) |
| 201 | + table_range = _range_from_token_position_details(table_meta, read_file) |
| 202 | + start_pos = table_range.start |
| 203 | + end_pos = table_range.end |
| 204 | + |
| 205 | + # If there's a catalog or database qualifier, adjust the start position |
| 206 | + catalog_or_db = table.args.get("catalog") or table.args.get("db") |
| 207 | + if catalog_or_db is not None: |
| 208 | + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
| 209 | + catalog_or_db_range = _range_from_token_position_details( |
| 210 | + catalog_or_db_meta, read_file |
| 211 | + ) |
| 212 | + start_pos = catalog_or_db_range.start |
| 213 | + |
| 214 | + description = generate_markdown_description(referenced_model) |
| 215 | + |
| 216 | + references.append( |
| 217 | + Reference( |
| 218 | + uri=referenced_model_uri.value, |
| 219 | + range=Range(start=start_pos, end=end_pos), |
| 220 | + markdown_description=description, |
| 221 | + ) |
| 222 | + ) |
189 | 223 |
|
190 | 224 | return references |
191 | 225 |
|
|
0 commit comments