Skip to content

Commit 8ff6aa2

Browse files
authored
feat: external model takes you to the model not just file (#4711)
1 parent 28155b3 commit 8ff6aa2

File tree

4 files changed

+157
-28
lines changed

4 files changed

+157
-28
lines changed

sqlmesh/lsp/helpers.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,35 @@
66
)
77

88

9-
def to_lsp_range(
10-
range: SQLMeshRange,
11-
) -> Range:
9+
def to_sqlmesh_position(position: Position) -> SQLMeshPosition:
1210
"""
13-
Converts a SQLMesh Range to an LSP Range.
11+
Converts an LSP Position to a SQLMesh Position.
1412
"""
15-
return Range(
16-
start=Position(line=range.start.line, character=range.start.character),
17-
end=Position(line=range.end.line, character=range.end.character),
18-
)
13+
return SQLMeshPosition(line=position.line, character=position.character)
1914

2015

21-
def to_lsp_position(
22-
position: SQLMeshPosition,
23-
) -> Position:
16+
def to_lsp_position(position: SQLMeshPosition) -> Position:
2417
"""
2518
Converts a SQLMesh Position to an LSP Position.
2619
"""
2720
return Position(line=position.line, character=position.character)
21+
22+
23+
def to_sqlmesh_range(range: Range) -> SQLMeshRange:
24+
"""
25+
Converts an LSP Range to a SQLMesh Range.
26+
"""
27+
return SQLMeshRange(
28+
start=to_sqlmesh_position(range.start),
29+
end=to_sqlmesh_position(range.end),
30+
)
31+
32+
33+
def to_lsp_range(range: SQLMeshRange) -> Range:
34+
"""
35+
Converts a SQLMesh Range to an LSP Range.
36+
"""
37+
return Range(
38+
start=to_lsp_position(range.start),
39+
end=to_lsp_position(range.end),
40+
)

sqlmesh/lsp/main.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from sqlmesh.lsp.reference import (
4949
LSPCteReference,
5050
LSPModelReference,
51+
LSPExternalModelReference,
5152
get_references,
5253
get_all_references,
5354
)
@@ -444,11 +445,19 @@ def goto_definition(
444445
references = get_references(self.lsp_context, uri, params.position)
445446
location_links = []
446447
for reference in references:
447-
# Use target_range if available (CTEs, Macros), otherwise default to start of file
448-
if not isinstance(reference, LSPModelReference):
449-
target_range = reference.target_range
450-
target_selection_range = reference.target_range
451-
else:
448+
# Use target_range if available (CTEs, Macros, and external models in YAML)
449+
if isinstance(reference, LSPModelReference):
450+
# Regular SQL models - default to start of file
451+
target_range = types.Range(
452+
start=types.Position(line=0, character=0),
453+
end=types.Position(line=0, character=0),
454+
)
455+
target_selection_range = types.Range(
456+
start=types.Position(line=0, character=0),
457+
end=types.Position(line=0, character=0),
458+
)
459+
elif isinstance(reference, LSPExternalModelReference):
460+
# External models may have target_range set for YAML files
452461
target_range = types.Range(
453462
start=types.Position(line=0, character=0),
454463
end=types.Position(line=0, character=0),
@@ -457,6 +466,13 @@ def goto_definition(
457466
start=types.Position(line=0, character=0),
458467
end=types.Position(line=0, character=0),
459468
)
469+
if reference.target_range is not None:
470+
target_range = reference.target_range
471+
target_selection_range = reference.target_range
472+
else:
473+
# CTEs and Macros always have target_range
474+
target_range = reference.target_range
475+
target_selection_range = reference.target_range
460476

461477
location_links.append(
462478
types.LocationLink(

sqlmesh/lsp/reference.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlmesh.core.linter.helpers import (
99
TokenPositionDetails,
1010
)
11-
from sqlmesh.core.model.definition import SqlModel
11+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
1212
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
1313
from sqlglot import exp
1414
from sqlmesh.lsp.description import generate_markdown_description
@@ -22,17 +22,28 @@
2222
from sqlmesh.core.model import Model
2323
from sqlmesh import macro
2424
import inspect
25+
from ruamel.yaml import YAML
2526

2627

2728
class LSPModelReference(PydanticModel):
28-
"""A LSP reference to a model."""
29+
"""A LSP reference to a model, excluding external models."""
2930

3031
type: t.Literal["model"] = "model"
3132
uri: str
3233
range: Range
3334
markdown_description: t.Optional[str] = None
3435

3536

37+
class LSPExternalModelReference(PydanticModel):
38+
"""A LSP reference to an external model."""
39+
40+
type: t.Literal["external_model"] = "external_model"
41+
uri: str
42+
range: Range
43+
markdown_description: t.Optional[str] = None
44+
target_range: t.Optional[Range] = None
45+
46+
3647
class LSPCteReference(PydanticModel):
3748
"""A LSP reference to a CTE."""
3849

@@ -53,7 +64,8 @@ class LSPMacroReference(PydanticModel):
5364

5465

5566
Reference = t.Annotated[
56-
t.Union[LSPModelReference, LSPCteReference, LSPMacroReference], Field(discriminator="type")
67+
t.Union[LSPModelReference, LSPCteReference, LSPMacroReference, LSPExternalModelReference],
68+
Field(discriminator="type"),
5769
]
5870

5971

@@ -243,16 +255,38 @@ def get_model_definitions_for_a_path(
243255

244256
description = generate_markdown_description(referenced_model)
245257

246-
references.append(
247-
LSPModelReference(
248-
uri=referenced_model_uri.value,
249-
range=Range(
250-
start=to_lsp_position(start_pos_sqlmesh),
251-
end=to_lsp_position(end_pos_sqlmesh),
252-
),
253-
markdown_description=description,
258+
# For external models in YAML files, find the specific model block
259+
if isinstance(referenced_model, ExternalModel):
260+
yaml_target_range: t.Optional[Range] = None
261+
if (
262+
referenced_model_path.suffix in (".yaml", ".yml")
263+
and referenced_model_path.is_file()
264+
):
265+
yaml_target_range = _get_yaml_model_range(
266+
referenced_model_path, referenced_model.name
267+
)
268+
references.append(
269+
LSPExternalModelReference(
270+
uri=referenced_model_uri.value,
271+
range=Range(
272+
start=to_lsp_position(start_pos_sqlmesh),
273+
end=to_lsp_position(end_pos_sqlmesh),
274+
),
275+
markdown_description=description,
276+
target_range=yaml_target_range,
277+
)
278+
)
279+
else:
280+
references.append(
281+
LSPModelReference(
282+
uri=referenced_model_uri.value,
283+
range=Range(
284+
start=to_lsp_position(start_pos_sqlmesh),
285+
end=to_lsp_position(end_pos_sqlmesh),
286+
),
287+
markdown_description=description,
288+
)
254289
)
255-
)
256290

257291
return references
258292

@@ -699,3 +733,31 @@ def _position_within_range(position: Position, range: Range) -> bool:
699733
range.end.line > position.line
700734
or (range.end.line == position.line and range.end.character >= position.character)
701735
)
736+
737+
738+
def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
739+
"""
740+
Find the range of a specific model block in a YAML file.
741+
742+
Args:
743+
yaml_path: Path to the YAML file
744+
model_name: Name of the model to find
745+
746+
Returns:
747+
The Range of the model block in the YAML file, or None if not found
748+
"""
749+
yaml = YAML()
750+
with path.open("r", encoding="utf-8") as f:
751+
data = yaml.load(f)
752+
753+
if not isinstance(data, list):
754+
return None
755+
756+
for item in data:
757+
if isinstance(item, dict) and item.get("name") == model_name:
758+
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
759+
position_data = item.lc.data["name"] # type: ignore
760+
start = Position(line=position_data[2], character=position_data[3])
761+
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
762+
return Range(start=start, end=end)
763+
return None
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from lsprotocol.types import Position
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.core.linter.helpers import read_range_from_file
4+
from sqlmesh.lsp.context import LSPContext, ModelTarget
5+
from sqlmesh.lsp.helpers import to_sqlmesh_range
6+
from sqlmesh.lsp.reference import get_references, LSPExternalModelReference
7+
from sqlmesh.lsp.uri import URI
8+
9+
10+
def test_reference() -> None:
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Find model URIs
15+
customers = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
19+
)
20+
21+
# Position of reference in file sushi.customers for sushi.raw_demographics
22+
position = Position(line=42, character=20)
23+
references = get_references(lsp_context, URI.from_path(customers), position)
24+
25+
assert len(references) == 1
26+
reference = references[0]
27+
assert isinstance(reference, LSPExternalModelReference)
28+
assert reference.uri.endswith("external_models.yaml")
29+
30+
source_range = read_range_from_file(customers, to_sqlmesh_range(reference.range))
31+
assert source_range == "raw.demographics"
32+
33+
if reference.target_range is None:
34+
raise AssertionError("Reference target range should not be None")
35+
target_range = read_range_from_file(
36+
URI(reference.uri).to_path(), to_sqlmesh_range(reference.target_range)
37+
)
38+
assert target_range == "raw.demographics"

0 commit comments

Comments
 (0)