Skip to content

Commit f26a437

Browse files
address comments; remove LSPBaseReference class
1 parent 2c8eeba commit f26a437

File tree

4 files changed

+41
-26
lines changed

4 files changed

+41
-26
lines changed

sqlmesh/lsp/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@
4444
CustomMethod,
4545
)
4646
from sqlmesh.lsp.hints import get_hints
47-
from sqlmesh.lsp.reference import LSPModelReference, get_references, get_all_references
47+
from sqlmesh.lsp.reference import (
48+
LSPCteReference,
49+
LSPModelReference,
50+
get_references,
51+
get_all_references,
52+
)
4853
from sqlmesh.lsp.uri import URI
4954
from web.server.api.endpoints.lineage import column_lineage, model_lineage
5055
from web.server.api.endpoints.models import get_models
@@ -368,7 +373,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
368373
if not references:
369374
return None
370375
reference = references[0]
371-
if not reference.markdown_description:
376+
if isinstance(reference, LSPCteReference) or not reference.markdown_description:
372377
return None
373378
return types.Hover(
374379
contents=types.MarkupContent(
@@ -419,7 +424,7 @@ def goto_definition(
419424
location_links = []
420425
for reference in references:
421426
# Use target_range if available (CTEs, Macros), otherwise default to start of file
422-
if not isinstance(reference, LSPModelReference) and reference.target_range:
427+
if not isinstance(reference, LSPModelReference):
423428
target_range = reference.target_range
424429
target_selection_range = reference.target_range
425430
else:

sqlmesh/lsp/reference.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,32 @@
2424
import inspect
2525

2626

27-
class LSPBaseReference(PydanticModel):
28-
"""Base class for all LSP reference types."""
29-
30-
range: Range
31-
uri: str
32-
markdown_description: t.Optional[str] = None
33-
34-
35-
class LSPModelReference(LSPBaseReference):
27+
class LSPModelReference(PydanticModel):
3628
"""A LSP reference to a model."""
3729

3830
type: t.Literal["model"] = "model"
31+
uri: str
32+
range: Range
33+
markdown_description: t.Optional[str] = None
3934

4035

41-
class LSPCteReference(LSPBaseReference):
36+
class LSPCteReference(PydanticModel):
4237
"""A LSP reference to a CTE."""
4338

4439
type: t.Literal["cte"] = "cte"
40+
uri: str
41+
range: Range
4542
target_range: Range
4643

4744

48-
class LSPMacroReference(LSPBaseReference):
45+
class LSPMacroReference(PydanticModel):
4946
"""A LSP reference to a macro."""
5047

5148
type: t.Literal["macro"] = "macro"
49+
uri: str
50+
range: Range
5251
target_range: Range
52+
markdown_description: t.Optional[str] = None
5353

5454

5555
Reference = t.Annotated[
@@ -451,12 +451,16 @@ def get_model_find_all_references(
451451
A list of references to the model across all files
452452
"""
453453
# First, get the references in the current file to determine what model we're looking for
454-
current_file_references = get_model_definitions_for_a_path(lint_context, document_uri)
454+
current_file_references = [
455+
ref
456+
for ref in get_model_definitions_for_a_path(lint_context, document_uri)
457+
if isinstance(ref, LSPModelReference)
458+
]
455459

456460
# Find the model reference at the cursor position
457461
target_model_uri: t.Optional[str] = None
458462
for ref in current_file_references:
459-
if _position_within_range(position, ref.range) and isinstance(ref, LSPModelReference):
463+
if _position_within_range(position, ref.range):
460464
# This is a model reference, get the target model URI
461465
target_model_uri = ref.uri
462466
break
@@ -499,7 +503,11 @@ def get_model_find_all_references(
499503
continue
500504

501505
# Get model references for this file
502-
file_references = get_model_definitions_for_a_path(lint_context, file_uri)
506+
file_references = [
507+
ref
508+
for ref in get_model_definitions_for_a_path(lint_context, file_uri)
509+
if isinstance(ref, LSPModelReference)
510+
]
503511

504512
# Add references that point to the target model file
505513
for ref in file_references:
@@ -562,7 +570,6 @@ def get_cte_references(
562570
uri=document_uri.value,
563571
range=target_cte_definition_range,
564572
target_range=target_cte_definition_range,
565-
markdown_description="CTE definition",
566573
)
567574
]
568575

@@ -574,7 +581,6 @@ def get_cte_references(
574581
uri=document_uri.value,
575582
range=ref.range,
576583
target_range=ref.target_range,
577-
markdown_description="CTE usage",
578584
)
579585
)
580586

tests/lsp/test_reference.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from lsprotocol.types import Position
22
from sqlmesh.core.context import Context
33
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
4-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, by_position
4+
from sqlmesh.lsp.reference import LSPModelReference, get_model_definitions_for_a_path, by_position
55
from sqlmesh.lsp.uri import URI
66

77

@@ -47,9 +47,13 @@ def test_reference_with_alias() -> None:
4747
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
4848
)
4949

50-
references = get_model_definitions_for_a_path(
51-
lsp_context, URI.from_path(waiter_revenue_by_day_path)
52-
)
50+
references = [
51+
ref
52+
for ref in get_model_definitions_for_a_path(
53+
lsp_context, URI.from_path(waiter_revenue_by_day_path)
54+
)
55+
if isinstance(ref, LSPModelReference)
56+
]
5357
assert len(references) == 3
5458

5559
with open(waiter_revenue_by_day_path, "r") as file:

tests/lsp/test_reference_cte.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from sqlmesh.core.context import Context
33
from sqlmesh.lsp.context import LSPContext, ModelTarget
4-
from sqlmesh.lsp.reference import get_references
4+
from sqlmesh.lsp.reference import LSPCteReference, get_references
55
from sqlmesh.lsp.uri import URI
66
from lsprotocol.types import Range, Position
77
import typing as t
@@ -28,7 +28,7 @@ def test_cte_parsing():
2828
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
2929
assert len(references) == 1
3030
assert references[0].uri == URI.from_path(sushi_customers_path).value
31-
assert references[0].markdown_description is None
31+
assert isinstance(references[0], LSPCteReference)
3232
assert (
3333
references[0].range.start.line == ranges[1].start.line
3434
) # The reference location (where we clicked)
@@ -43,7 +43,7 @@ def test_cte_parsing():
4343
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
4444
assert len(references) == 1
4545
assert references[0].uri == URI.from_path(sushi_customers_path).value
46-
assert references[0].markdown_description is None
46+
assert isinstance(references[0], LSPCteReference)
4747
assert (
4848
references[0].range.start.line == ranges[1].start.line
4949
) # The reference location (where we clicked)

0 commit comments

Comments
 (0)