Skip to content

Commit d7ee92c

Browse files
Feat(lsp): Add find all and go to references support for Macros
1 parent 8efbe5f commit d7ee92c

File tree

7 files changed

+418
-7
lines changed

7 files changed

+418
-7
lines changed

examples/sushi/models/customers.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ LEFT JOIN (
3333
WITH current_marketing AS (
3434
SELECT
3535
customer_id,
36-
status
36+
status,
37+
@ADD_ONE(1) AS another_column,
3738
FROM current_marketing_outer
3839
)
3940
SELECT * FROM current_marketing

sqlmesh/lsp/reference.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,13 +587,88 @@ def get_cte_references(
587587
return matching_references
588588

589589

590+
def get_macro_find_all_references(
591+
lint_context: LSPContext, document_uri: URI, position: Position
592+
) -> t.List[LSPMacroReference]:
593+
"""
594+
Get all references to a macro at a specific position in a document.
595+
596+
This function finds all usages of a macro across the entire project.
597+
598+
Args:
599+
lint_context: The LSP context
600+
document_uri: The URI of the document
601+
position: The position to check for macro references
602+
603+
Returns:
604+
A list of references to the macro across all files
605+
"""
606+
# First, get the macro references in the current file to identify the target macro
607+
current_file_references = [
608+
ref
609+
for ref in get_macro_definitions_for_a_path(lint_context, document_uri)
610+
if isinstance(ref, LSPMacroReference)
611+
]
612+
613+
# Find the macro reference at the cursor position
614+
target_macro_uri: t.Optional[str] = None
615+
target_macro_target_range: t.Optional[Range] = None
616+
617+
for ref in current_file_references:
618+
if _position_within_range(position, ref.range):
619+
target_macro_uri = ref.uri
620+
target_macro_target_range = ref.target_range
621+
break
622+
623+
if target_macro_uri is None:
624+
return []
625+
626+
# Start with the macro definition
627+
all_references: t.List[LSPMacroReference] = [
628+
LSPMacroReference(
629+
uri=target_macro_uri,
630+
range=target_macro_target_range,
631+
target_range=target_macro_target_range,
632+
markdown_description=None,
633+
)
634+
]
635+
636+
# Search through all SQL and audit files in the project
637+
for path, target in lint_context.map.items():
638+
if not isinstance(target, (ModelTarget, AuditTarget)):
639+
continue
640+
641+
file_uri = URI.from_path(path)
642+
643+
# Get macro references for this file
644+
file_macro_references = [
645+
ref
646+
for ref in get_macro_definitions_for_a_path(lint_context, file_uri)
647+
if isinstance(ref, LSPMacroReference)
648+
]
649+
650+
# Add references that point to the same macro definition
651+
for ref in file_macro_references:
652+
if ref.uri == target_macro_uri and ref.target_range == target_macro_target_range:
653+
all_references.append(
654+
LSPMacroReference(
655+
uri=file_uri.value,
656+
range=ref.range,
657+
target_range=ref.target_range,
658+
markdown_description=ref.markdown_description,
659+
)
660+
)
661+
662+
return all_references
663+
664+
590665
def get_all_references(
591666
lint_context: LSPContext, document_uri: URI, position: Position
592667
) -> t.Sequence[Reference]:
593668
"""
594669
Get all references of a symbol at a specific position in a document.
595670
596-
This function determines the type of reference (CTE, model for now) at the cursor
671+
This function determines the type of reference (CTE, model or macro) at the cursor
597672
position and returns all references to that symbol across the project.
598673
599674
Args:
@@ -612,6 +687,10 @@ def get_all_references(
612687
if model_references := get_model_find_all_references(lint_context, document_uri, position):
613688
return model_references
614689

690+
# Finally try macro references (across files)
691+
if macro_references := get_macro_find_all_references(lint_context, document_uri, position):
692+
return macro_references
693+
615694
return []
616695

617696

tests/integrations/github/cicd/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,7 @@ def test_overlapping_changes_models(
16841684
16851685
+++
16861686
1687-
@@ -25,7 +25,8 @@
1687+
@@ -29,7 +29,8 @@
16881688
16891689
SELECT DISTINCT
16901690
CAST(o.customer_id AS INT) AS customer_id,
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from lsprotocol.types import Position
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import (
5+
get_macro_find_all_references,
6+
get_macro_definitions_for_a_path,
7+
)
8+
from sqlmesh.lsp.uri import URI
9+
10+
11+
def test_find_all_references_for_macro_add_one():
12+
"""Test finding all references to the @ADD_ONE macro."""
13+
context = Context(paths=["examples/sushi"])
14+
lsp_context = LSPContext(context)
15+
16+
# Find the top_waiters model that uses @ADD_ONE macro
17+
top_waiters_path = next(
18+
path
19+
for path, info in lsp_context.map.items()
20+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
21+
)
22+
23+
top_waiters_uri = URI.from_path(top_waiters_path)
24+
macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri)
25+
26+
# Find the @ADD_ONE reference
27+
add_one_ref = next((ref for ref in macro_references if ref.range.start.line == 12), None)
28+
assert add_one_ref is not None, "Should find @ADD_ONE reference in top_waiters"
29+
30+
# Click on the @ADD_ONE macro at line 13, character 5 (the @ symbol)
31+
position = Position(line=12, character=5)
32+
33+
all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position)
34+
35+
# Should find at least 2 references: the definition and the usage in top_waiters
36+
assert len(all_references) >= 2, f"Expected at least 2 references, found {len(all_references)}"
37+
38+
# Verify the macro definition is included
39+
definition_refs = [ref for ref in all_references if "utils.py" in ref.uri]
40+
assert len(definition_refs) >= 1, "Should include the macro definition in utils.py"
41+
42+
# Verify the usage in top_waiters is included
43+
usage_refs = [ref for ref in all_references if "top_waiters" in ref.uri]
44+
assert len(usage_refs) >= 1, "Should include the usage in top_waiters.sql"
45+
46+
# breakpoint()
47+
expected_ranges = [
48+
# Macro definition in utils.py
49+
{
50+
"uri": "file:///Users/themistoklisvaltinos/Developer/sqlmesh/examples/sushi/macros/utils.py",
51+
"range": ((6, 0), (9, 22)),
52+
},
53+
# Usage in customers.sql
54+
{
55+
"uri": "file:///Users/themistoklisvaltinos/Developer/sqlmesh/examples/sushi/models/customers.sql",
56+
"range": ((36, 7), (36, 14)),
57+
},
58+
# Usage in top_waiters.sql
59+
{
60+
"uri": "file:///Users/themistoklisvaltinos/Developer/sqlmesh/examples/sushi/models/top_waiters.sql",
61+
"range": ((12, 5), (12, 12)),
62+
},
63+
]
64+
65+
for expected in expected_ranges:
66+
assert any(
67+
ref.uri == expected["uri"] and
68+
ref.range.start.line == expected["range"][0][0] and
69+
ref.range.start.character == expected["range"][0][1] and
70+
ref.range.end.line == expected["range"][1][0] and
71+
ref.range.end.character == expected["range"][1][1]
72+
for ref in all_references
73+
), f"Expected reference with uri {expected['uri']} and range {expected['range']} not found"
74+
75+
76+
def test_find_all_references_for_macro_multiply():
77+
"""Test finding all references to the @MULTIPLY macro."""
78+
context = Context(paths=["examples/sushi"])
79+
lsp_context = LSPContext(context)
80+
81+
# Find the top_waiters model that uses @MULTIPLY macro
82+
top_waiters_path = next(
83+
path
84+
for path, info in lsp_context.map.items()
85+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
86+
)
87+
88+
top_waiters_uri = URI.from_path(top_waiters_path)
89+
macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri)
90+
91+
# Find the @MULTIPLY reference
92+
multiply_ref = next((ref for ref in macro_references if ref.range.start.line == 13), None)
93+
assert multiply_ref is not None, "Should find @MULTIPLY reference in top_waiters"
94+
95+
# Click on the @MULTIPLY macro at line 14, character 5 (the @ symbol)
96+
position = Position(line=13, character=5)
97+
all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position)
98+
99+
# Should find at least 2 references: the definition and the usage
100+
assert len(all_references) >= 2, f"Expected at least 2 references, found {len(all_references)}"
101+
102+
# Verify both definition and usage are included
103+
assert any("utils.py" in ref.uri for ref in all_references), "Should include macro definition"
104+
assert any("top_waiters" in ref.uri for ref in all_references), "Should include usage"
105+
106+
107+
def test_find_all_references_for_sql_literal_macro():
108+
"""Test finding references to @SQL_LITERAL macro ."""
109+
context = Context(paths=["examples/sushi"])
110+
lsp_context = LSPContext(context)
111+
112+
# Find the top_waiters model that uses @SQL_LITERAL macro
113+
top_waiters_path = next(
114+
path
115+
for path, info in lsp_context.map.items()
116+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
117+
)
118+
119+
top_waiters_uri = URI.from_path(top_waiters_path)
120+
macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri)
121+
122+
# Find the @SQL_LITERAL reference
123+
sql_literal_ref = next((ref for ref in macro_references if ref.range.start.line == 14), None)
124+
assert sql_literal_ref is not None, "Should find @SQL_LITERAL reference in top_waiters"
125+
126+
# Click on the @SQL_LITERAL macro
127+
position = Position(line=14, character=5)
128+
all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position)
129+
130+
# For user-defined macros in utils.py, should find references
131+
assert len(all_references) >= 2, f"Expected at least 2 references, found {len(all_references)}"
132+
133+
134+
def test_find_references_from_outside_macro_position():
135+
"""Test that clicking outside a macro doesn't return macro references."""
136+
context = Context(paths=["examples/sushi"])
137+
lsp_context = LSPContext(context)
138+
139+
top_waiters_path = next(
140+
path
141+
for path, info in lsp_context.map.items()
142+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
143+
)
144+
145+
top_waiters_uri = URI.from_path(top_waiters_path)
146+
147+
# Click on a position that is not on a macro
148+
position = Position(line=0, character=0) # First line, which is a comment
149+
all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position)
150+
151+
# Should return empty list when not on a macro
152+
assert len(all_references) == 0, "Should not find macro references when not on a macro"
153+
154+
155+
def test_multi_repo_macro_references():
156+
"""Test finding macro references across multiple repositories."""
157+
context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory")
158+
lsp_context = LSPContext(context)
159+
160+
# Find model 'd' which uses macros from repo_2
161+
d_path = next(
162+
path
163+
for path, info in lsp_context.map.items()
164+
if isinstance(info, ModelTarget) and "silver.d" in info.names
165+
)
166+
167+
d_uri = URI.from_path(d_path)
168+
macro_references = get_macro_definitions_for_a_path(lsp_context, d_uri)
169+
170+
if macro_references:
171+
# Click on the second macro reference which appears under the same name in repo_1 ('dup')
172+
first_ref = macro_references[1]
173+
position = Position(
174+
line=first_ref.range.start.line,
175+
character=first_ref.range.start.character + 1
176+
)
177+
all_references = get_macro_find_all_references(lsp_context, d_uri, position)
178+
179+
# Should find the definition and usage
180+
assert len(all_references) == 2, f"Expected 2 references, found {len(all_references)}"
181+
182+
# Verify references from repo_2
183+
assert any("repo_2" in ref.uri for ref in all_references), "Should find macro in repo_2"
184+
185+
# But not references in repo_1 since despite identical name they're different macros
186+
assert not any("repo_1" in ref.uri for ref in all_references), "Shouldn't find macro in repo_1"

tests/test_forking.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ def test_parallel_load(assert_exp_eq, mocker):
4646
WITH "current_marketing" AS (
4747
SELECT
4848
"current_marketing_outer"."customer_id" AS "customer_id",
49-
"current_marketing_outer"."status" AS "status"
49+
"current_marketing_outer"."status" AS "status",
50+
2 AS "another_column"
5051
FROM "current_marketing_outer" AS "current_marketing_outer"
5152
)
5253
SELECT
5354
"current_marketing"."customer_id" AS "customer_id",
54-
"current_marketing"."status" AS "status"
55+
"current_marketing"."status" AS "status",
56+
"current_marketing"."another_column" AS "another_column"
5557
FROM "current_marketing" AS "current_marketing"
5658
) AS "m"
5759
ON "m"."customer_id" = "o"."customer_id"

tests/web/test_lineage.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ def test_get_lineage(client: TestClient, web_sushi_context: Context) -> None:
6262
WITH "current_marketing" AS (
6363
SELECT
6464
"current_marketing_outer"."customer_id" AS "customer_id",
65-
"current_marketing_outer"."status" AS "status"
65+
"current_marketing_outer"."status" AS "status",
66+
2 AS "another_column"
6667
FROM "current_marketing_outer" AS "current_marketing_outer"
6768
)
6869
SELECT
6970
"current_marketing"."customer_id" AS "customer_id",
70-
"current_marketing"."status" AS "status"
71+
"current_marketing"."status" AS "status",
72+
"current_marketing"."another_column" AS "another_column"
7173
FROM "current_marketing" AS "current_marketing"
7274
) AS "m"
7375
ON "m"."customer_id" = "o"."customer_id"

0 commit comments

Comments
 (0)