Skip to content

Commit f656bc3

Browse files
Feat(lsp): Add support for find references for Model usages
1 parent 53224fc commit f656bc3

File tree

6 files changed

+752
-34
lines changed

6 files changed

+752
-34
lines changed

sqlmesh/lsp/main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
RenderModelResponse,
3838
)
3939
from sqlmesh.lsp.hints import get_hints
40-
from sqlmesh.lsp.reference import get_references, get_cte_references
40+
from sqlmesh.lsp.reference import LSPModelReference, get_references, get_all_references
4141
from sqlmesh.lsp.uri import URI
4242
from web.server.api.endpoints.lineage import column_lineage, model_lineage
4343
from web.server.api.endpoints.models import get_models
@@ -369,8 +369,8 @@ def goto_definition(
369369
references = get_references(self.lsp_context, uri, params.position)
370370
location_links = []
371371
for reference in references:
372-
# Use target_range if available (for CTEs), otherwise default to start of file
373-
if reference.target_range:
372+
# Use target_range if available (CTEs, Macros), otherwise default to start of file
373+
if not isinstance(reference, LSPModelReference) and reference.target_range:
374374
target_range = reference.target_range
375375
target_selection_range = reference.target_range
376376
else:
@@ -400,18 +400,18 @@ def goto_definition(
400400
def find_references(
401401
ls: LanguageServer, params: types.ReferenceParams
402402
) -> t.Optional[t.List[types.Location]]:
403-
"""Find all references of a symbol (currently supporting CTEs)"""
403+
"""Find all references of a symbol (supporting CTEs, models for now)"""
404404
try:
405405
uri = URI(params.text_document.uri)
406406
self._ensure_context_for_document(uri)
407407
document = ls.workspace.get_text_document(params.text_document.uri)
408408
if self.lsp_context is None:
409409
raise RuntimeError(f"No context found for document: {document.path}")
410410

411-
cte_references = get_cte_references(self.lsp_context, uri, params.position)
411+
all_references = get_all_references(self.lsp_context, uri, params.position)
412412

413413
# Convert references to Location objects
414-
locations = [types.Location(uri=ref.uri, range=ref.range) for ref in cte_references]
414+
locations = [types.Location(uri=ref.uri, range=ref.range) for ref in all_references]
415415

416416
return locations if locations else None
417417
except Exception as e:

sqlmesh/lsp/reference.py

Lines changed: 155 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from lsprotocol.types import Range, Position
22
import typing as t
33
from pathlib import Path
4+
from pydantic import Field
45

56
from sqlmesh.core.audit import StandaloneAudit
67
from sqlmesh.core.dialect import normalize_model_name
@@ -23,21 +24,37 @@
2324
import inspect
2425

2526

26-
class Reference(PydanticModel):
27-
"""
28-
A reference to a model or CTE.
29-
30-
Attributes:
31-
range: The range of the reference in the source file
32-
uri: The uri of the referenced model
33-
markdown_description: The markdown description of the referenced model
34-
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
35-
"""
27+
class LSPBaseReference(PydanticModel):
28+
"""Base class for all LSP reference types."""
3629

3730
range: Range
3831
uri: str
3932
markdown_description: t.Optional[str] = None
40-
target_range: t.Optional[Range] = None
33+
34+
35+
class LSPModelReference(LSPBaseReference):
36+
"""A LSP reference to a model."""
37+
38+
type: t.Literal["model"] = "model"
39+
40+
41+
class LSPCteReference(LSPBaseReference):
42+
"""A LSP reference to a CTE."""
43+
44+
type: t.Literal["cte"] = "cte"
45+
target_range: Range
46+
47+
48+
class LSPMacroReference(LSPBaseReference):
49+
"""A LSP reference to a macro."""
50+
51+
type: t.Literal["macro"] = "macro"
52+
target_range: Range
53+
54+
55+
Reference = t.Annotated[
56+
t.Union[LSPModelReference, LSPCteReference, LSPMacroReference], Field(discriminator="type")
57+
]
4158

4259

4360
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -136,7 +153,7 @@ def get_model_definitions_for_a_path(
136153
return []
137154

138155
# Find all possible references
139-
references = []
156+
references: t.List[Reference] = []
140157

141158
with open(file_path, "r", encoding="utf-8") as file:
142159
read_file = file.readlines()
@@ -173,7 +190,7 @@ def get_model_definitions_for_a_path(
173190
table_range = to_lsp_range(table_range_sqlmesh)
174191

175192
references.append(
176-
Reference(
193+
LSPCteReference(
177194
uri=document_uri.value, # Same file
178195
range=table_range,
179196
target_range=target_range,
@@ -227,7 +244,7 @@ def get_model_definitions_for_a_path(
227244
description = generate_markdown_description(referenced_model)
228245

229246
references.append(
230-
Reference(
247+
LSPModelReference(
231248
uri=referenced_model_uri.value,
232249
range=Range(
233250
start=to_lsp_position(start_pos_sqlmesh),
@@ -286,7 +303,7 @@ def get_macro_definitions_for_a_path(
286303
return []
287304

288305
references = []
289-
config_for_model, config_path = lsp_context.context.config_for_path(
306+
_, config_path = lsp_context.context.config_for_path(
290307
file_path,
291308
)
292309

@@ -372,7 +389,7 @@ def get_macro_reference(
372389
# Create a reference to the macro definition
373390
macro_uri = URI.from_path(path)
374391

375-
return Reference(
392+
return LSPMacroReference(
376393
uri=macro_uri.value,
377394
range=to_lsp_range(macro_range),
378395
target_range=Range(
@@ -405,7 +422,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
405422
# Calculate the end line number by counting the number of source lines
406423
end_line_number = line_number + len(source_lines) - 1
407424

408-
return Reference(
425+
return LSPMacroReference(
409426
uri=URI.from_path(Path(filename)).value,
410427
range=macro_range,
411428
target_range=Range(
@@ -416,9 +433,91 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
416433
)
417434

418435

436+
def get_model_find_all_references(
437+
lint_context: LSPContext, document_uri: URI, position: Position
438+
) -> t.List[LSPModelReference]:
439+
"""
440+
Get all references to a model across the entire project.
441+
442+
This function finds all usages of a model in other files by searching through
443+
all models in the project and checking their dependencies.
444+
445+
Args:
446+
lint_context: The LSP context
447+
document_uri: The URI of the document
448+
position: The position to check for model references
449+
450+
Returns:
451+
A list of references to the model across all files
452+
"""
453+
# First, get the references in the current file to determine what model we're looking for
454+
current_file_references = get_model_definitions_for_a_path(lint_context, document_uri)
455+
456+
# Find the model reference at the cursor position
457+
target_model_uri: t.Optional[str] = None
458+
for ref in current_file_references:
459+
if _position_within_range(position, ref.range) and isinstance(ref, LSPModelReference):
460+
# This is a model reference, get the target model URI
461+
target_model_uri = ref.uri
462+
break
463+
464+
if target_model_uri is None:
465+
return []
466+
467+
# Start with the model definition
468+
all_references: t.List[LSPModelReference] = [
469+
LSPModelReference(
470+
uri=ref.uri,
471+
range=Range(
472+
start=Position(line=0, character=0),
473+
end=Position(line=0, character=0),
474+
),
475+
markdown_description=ref.markdown_description,
476+
)
477+
]
478+
479+
# Then add the original reference
480+
for ref in current_file_references:
481+
if ref.uri == target_model_uri and isinstance(ref, LSPModelReference):
482+
all_references.append(
483+
LSPModelReference(
484+
uri=document_uri.value,
485+
range=ref.range,
486+
markdown_description=ref.markdown_description,
487+
)
488+
)
489+
490+
# Search through the models in the project
491+
for path, target in lint_context.map.items():
492+
if not isinstance(target, (ModelTarget, AuditTarget)):
493+
continue
494+
495+
file_uri = URI.from_path(path)
496+
497+
# Skip current file, already processed
498+
if file_uri.value == document_uri.value:
499+
continue
500+
501+
# Get model references for this file
502+
file_references = get_model_definitions_for_a_path(lint_context, file_uri)
503+
504+
# Add references that point to the target model file
505+
for ref in file_references:
506+
if ref.uri == target_model_uri and isinstance(ref, LSPModelReference):
507+
all_references.append(
508+
LSPModelReference(
509+
uri=file_uri.value,
510+
range=ref.range,
511+
markdown_description=ref.markdown_description,
512+
)
513+
)
514+
515+
return all_references
516+
517+
419518
def get_cte_references(
420519
lint_context: LSPContext, document_uri: URI, position: Position
421-
) -> t.List[Reference]:
520+
) -> t.List[LSPCteReference]:
422521
"""
423522
Get all references to a CTE at a specific position in a document.
424523
@@ -432,12 +531,12 @@ def get_cte_references(
432531
Returns:
433532
A list of references to the CTE (including its definition and all usages)
434533
"""
435-
references = get_model_definitions_for_a_path(lint_context, document_uri)
436534

437-
# Filter for CTE references (those with target_range set and same URI)
438-
# TODO: Consider extending Reference class to explicitly indicate reference type instead
439-
cte_references = [
440-
ref for ref in references if ref.target_range is not None and ref.uri == document_uri.value
535+
# Filter to get the CTE references
536+
cte_references: t.List[LSPCteReference] = [
537+
ref
538+
for ref in get_model_definitions_for_a_path(lint_context, document_uri)
539+
if isinstance(ref, LSPCteReference)
441540
]
442541

443542
if not cte_references:
@@ -450,7 +549,7 @@ def get_cte_references(
450549
target_cte_definition_range = ref.target_range
451550
break
452551
# Check if cursor is on the CTE definition
453-
elif ref.target_range and _position_within_range(position, ref.target_range):
552+
elif _position_within_range(position, ref.target_range):
454553
target_cte_definition_range = ref.target_range
455554
break
456555

@@ -459,9 +558,10 @@ def get_cte_references(
459558

460559
# Add the CTE definition
461560
matching_references = [
462-
Reference(
561+
LSPCteReference(
463562
uri=document_uri.value,
464563
range=target_cte_definition_range,
564+
target_range=target_cte_definition_range,
465565
markdown_description="CTE definition",
466566
)
467567
]
@@ -470,16 +570,45 @@ def get_cte_references(
470570
for ref in cte_references:
471571
if ref.target_range == target_cte_definition_range:
472572
matching_references.append(
473-
Reference(
573+
LSPCteReference(
474574
uri=document_uri.value,
475575
range=ref.range,
576+
target_range=ref.target_range,
476577
markdown_description="CTE usage",
477578
)
478579
)
479580

480581
return matching_references
481582

482583

584+
def get_all_references(
585+
lint_context: LSPContext, document_uri: URI, position: Position
586+
) -> t.Sequence[Reference]:
587+
"""
588+
Get all references of a symbol at a specific position in a document.
589+
590+
This function determines the type of reference (CTE, model for now) at the cursor
591+
position and returns all references to that symbol across the project.
592+
593+
Args:
594+
lint_context: The LSP context
595+
document_uri: The URI of the document
596+
position: The position to check for references
597+
598+
Returns:
599+
A list of references to the symbol at the given position
600+
"""
601+
# First try CTE references (within same file)
602+
if cte_references := get_cte_references(lint_context, document_uri, position):
603+
return cte_references
604+
605+
# Then try model references (across files)
606+
if model_references := get_model_find_all_references(lint_context, document_uri, position):
607+
return model_references
608+
609+
return []
610+
611+
483612
def _position_within_range(position: Position, range: Range) -> bool:
484613
"""Check if a position is within a given range."""
485614
return (

tests/lsp/test_reference_macro.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlmesh.core.context import Context
22
from sqlmesh.lsp.context import LSPContext, ModelTarget
3-
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path
3+
from sqlmesh.lsp.reference import LSPMacroReference, get_macro_definitions_for_a_path
44
from sqlmesh.lsp.uri import URI
55

66

@@ -24,5 +24,6 @@ def test_macro_references() -> None:
2424

2525
# Check that all references point to the utils.py file
2626
for ref in macro_references:
27+
assert isinstance(ref, LSPMacroReference)
2728
assert ref.uri.endswith("sushi/macros/utils.py")
2829
assert ref.target_range is not None

tests/lsp/test_reference_macro_multi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlmesh.core.context import Context
22
from sqlmesh.lsp.context import LSPContext, ModelTarget
3-
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path
3+
from sqlmesh.lsp.reference import LSPMacroReference, get_macro_definitions_for_a_path
44
from sqlmesh.lsp.uri import URI
55

66

@@ -19,5 +19,6 @@ def test_macro_references_multirepo() -> None:
1919

2020
assert len(macro_references) == 2
2121
for ref in macro_references:
22+
assert isinstance(ref, LSPMacroReference)
2223
assert ref.uri.endswith("multi/repo_2/macros/__init__.py")
2324
assert ref.target_range is not None

0 commit comments

Comments
 (0)