Skip to content

Commit 967d0e6

Browse files
refactors; adapt with filtering logic
1 parent 5f3e229 commit 967d0e6

File tree

2 files changed

+111
-113
lines changed

2 files changed

+111
-113
lines changed

sqlmesh/lsp/reference.py

Lines changed: 80 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -450,75 +450,76 @@ def get_model_find_all_references(
450450
Returns:
451451
A list of references to the model across all files
452452
"""
453-
# First, get the references in the current file to determine what model we're looking for
454-
current_file_references = [
455-
ref
456-
for ref in get_model_definitions_for_a_path(lint_context, document_uri)
457-
if isinstance(ref, LSPModelReference)
458-
]
459-
460453
# Find the model reference at the cursor position
461-
target_model_uri: t.Optional[str] = None
462-
for ref in current_file_references:
463-
if _position_within_range(position, ref.range):
464-
# This is a model reference, get the target model URI
465-
target_model_uri = ref.uri
466-
break
454+
model_at_position = next(
455+
filter(
456+
lambda ref: isinstance(ref, LSPModelReference)
457+
and _position_within_range(position, ref.range),
458+
get_model_definitions_for_a_path(lint_context, document_uri),
459+
),
460+
None,
461+
)
467462

468-
if target_model_uri is None:
463+
if not model_at_position:
469464
return []
470465

466+
assert isinstance(model_at_position, LSPModelReference) # for mypy
467+
468+
target_model_uri = model_at_position.uri
469+
471470
# Start with the model definition
472471
all_references: t.List[LSPModelReference] = [
473472
LSPModelReference(
474-
uri=ref.uri,
473+
uri=model_at_position.uri,
475474
range=Range(
476475
start=Position(line=0, character=0),
477476
end=Position(line=0, character=0),
478477
),
479-
markdown_description=ref.markdown_description,
478+
markdown_description=model_at_position.markdown_description,
480479
)
481480
]
482481

483-
# Then add the original reference
484-
for ref in current_file_references:
485-
if ref.uri == target_model_uri and isinstance(ref, LSPModelReference):
486-
all_references.append(
487-
LSPModelReference(
488-
uri=document_uri.value,
489-
range=ref.range,
490-
markdown_description=ref.markdown_description,
491-
)
482+
# Then add references from the current file
483+
current_file_refs = filter(
484+
lambda ref: isinstance(ref, LSPModelReference) and ref.uri == target_model_uri,
485+
get_model_definitions_for_a_path(lint_context, document_uri),
486+
)
487+
488+
for ref in current_file_refs:
489+
assert isinstance(ref, LSPModelReference) # for mypy
490+
491+
all_references.append(
492+
LSPModelReference(
493+
uri=document_uri.value,
494+
range=ref.range,
495+
markdown_description=ref.markdown_description,
492496
)
497+
)
493498

494499
# Search through the models in the project
495-
for path, target in lint_context.map.items():
496-
if not isinstance(target, (ModelTarget, AuditTarget)):
497-
continue
498-
500+
for path, _ in lint_context.map.items():
499501
file_uri = URI.from_path(path)
500502

501503
# Skip current file, already processed
502504
if file_uri.value == document_uri.value:
503505
continue
504506

505-
# Get model references for this file
506-
file_references = [
507-
ref
508-
for ref in get_model_definitions_for_a_path(lint_context, file_uri)
509-
if isinstance(ref, LSPModelReference)
510-
]
511-
512-
# Add references that point to the target model file
513-
for ref in file_references:
514-
if ref.uri == target_model_uri and isinstance(ref, LSPModelReference):
515-
all_references.append(
516-
LSPModelReference(
517-
uri=file_uri.value,
518-
range=ref.range,
519-
markdown_description=ref.markdown_description,
520-
)
507+
# Get model references that point to the target model
508+
matching_refs = filter(
509+
lambda ref: isinstance(ref, LSPModelReference) and ref.uri == target_model_uri,
510+
get_model_definitions_for_a_path(lint_context, file_uri),
511+
)
512+
513+
for ref in matching_refs:
514+
assert isinstance(ref, LSPModelReference) # for mypy
515+
516+
all_references.append(
517+
LSPModelReference(
518+
uri=file_uri.value,
519+
range=ref.range,
520+
markdown_description=ref.markdown_description,
521521
)
522+
)
522523

523524
return all_references
524525

@@ -588,41 +589,39 @@ def get_cte_references(
588589

589590

590591
def get_macro_find_all_references(
591-
lint_context: LSPContext, document_uri: URI, position: Position
592+
lsp_context: LSPContext, document_uri: URI, position: Position
592593
) -> t.List[LSPMacroReference]:
593594
"""
594595
Get all references to a macro at a specific position in a document.
595596
596597
This function finds all usages of a macro across the entire project.
597598
598599
Args:
599-
lint_context: The LSP context
600+
lsp_context: The LSP context
600601
document_uri: The URI of the document
601602
position: The position to check for macro references
602603
603604
Returns:
604605
A list of references to the macro across all files
605606
"""
606-
# First, get the macro references in the current file to identify the target macro
607-
current_file_references = [
608-
ref
609-
for ref in get_macro_definitions_for_a_path(lint_context, document_uri)
610-
if isinstance(ref, LSPMacroReference)
611-
]
612-
613607
# Find the macro reference at the cursor position
614-
target_macro_uri: t.Optional[str] = None
615-
target_macro_target_range: t.Optional[Range] = None
616-
617-
for ref in current_file_references:
618-
if _position_within_range(position, ref.range):
619-
target_macro_uri = ref.uri
620-
target_macro_target_range = ref.target_range
621-
break
608+
macro_at_position = next(
609+
filter(
610+
lambda ref: isinstance(ref, LSPMacroReference)
611+
and _position_within_range(position, ref.range),
612+
get_macro_definitions_for_a_path(lsp_context, document_uri),
613+
),
614+
None,
615+
)
622616

623-
if target_macro_uri is None:
617+
if not macro_at_position:
624618
return []
625619

620+
assert isinstance(macro_at_position, LSPMacroReference) # for mypy
621+
622+
target_macro_uri = macro_at_position.uri
623+
target_macro_target_range = macro_at_position.target_range
624+
626625
# Start with the macro definition
627626
all_references: t.List[LSPMacroReference] = [
628627
LSPMacroReference(
@@ -634,30 +633,27 @@ def get_macro_find_all_references(
634633
]
635634

636635
# Search through all SQL and audit files in the project
637-
for path, target in lint_context.map.items():
638-
if not isinstance(target, (ModelTarget, AuditTarget)):
639-
continue
640-
636+
for path, _ in lsp_context.map.items():
641637
file_uri = URI.from_path(path)
642638

643-
# Get macro references for this file
644-
file_macro_references = [
645-
ref
646-
for ref in get_macro_definitions_for_a_path(lint_context, file_uri)
647-
if isinstance(ref, LSPMacroReference)
648-
]
649-
650-
# Add references that point to the same macro definition
651-
for ref in file_macro_references:
652-
if ref.uri == target_macro_uri and ref.target_range == target_macro_target_range:
653-
all_references.append(
654-
LSPMacroReference(
655-
uri=file_uri.value,
656-
range=ref.range,
657-
target_range=ref.target_range,
658-
markdown_description=ref.markdown_description,
659-
)
639+
# Get macro references that point to the same macro definition
640+
matching_refs = filter(
641+
lambda ref: isinstance(ref, LSPMacroReference)
642+
and ref.uri == target_macro_uri
643+
and ref.target_range == target_macro_target_range,
644+
get_macro_definitions_for_a_path(lsp_context, file_uri),
645+
)
646+
647+
for ref in matching_refs:
648+
assert isinstance(ref, LSPMacroReference) # for mypy
649+
all_references.append(
650+
LSPMacroReference(
651+
uri=file_uri.value,
652+
range=ref.range,
653+
target_range=ref.target_range,
654+
markdown_description=ref.markdown_description,
660655
)
656+
)
661657

662658
return all_references
663659

tests/lsp/test_reference_macro_find_all.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
get_macro_definitions_for_a_path,
77
)
88
from sqlmesh.lsp.uri import URI
9+
from sqlmesh.core.linter.helpers import (
10+
read_range_from_file,
11+
Range as SQLMeshRange,
12+
Position as SQLMeshPosition,
13+
)
914

1015

1116
def test_find_all_references_for_macro_add_one():
@@ -43,35 +48,32 @@ def test_find_all_references_for_macro_add_one():
4348
usage_refs = [ref for ref in all_references if "top_waiters" in ref.uri]
4449
assert len(usage_refs) >= 1, "Should include the usage in top_waiters.sql"
4550

46-
# breakpoint()
47-
expected_ranges = [
48-
# Macro definition in utils.py
49-
{
50-
"uri_file": "utils.py",
51-
"range": ((6, 0), (9, 22)),
52-
},
53-
# Usage in customers.sql
54-
{
55-
"uri_file": "customers.sql",
56-
"range": ((36, 7), (36, 14)),
57-
},
58-
# Usage in top_waiters.sql
59-
{
60-
"uri_file": "top_waiters.sql",
61-
"range": ((12, 5), (12, 12)),
62-
},
63-
]
64-
65-
for expected in expected_ranges:
66-
assert any(
67-
expected["uri_file"] in ref.uri
68-
and ref.range.start.line == expected["range"][0][0]
69-
and ref.range.start.character == expected["range"][0][1]
70-
and ref.range.end.line == expected["range"][1][0]
71-
and ref.range.end.character == expected["range"][1][1]
72-
for ref in all_references
73-
), (
74-
f"Expected reference with uri {expected['uri_file']} and range {expected['range']} not found"
51+
expected_files = {
52+
"utils.py": {"pattern": r"def add_one", "expected_content": "def add_one"},
53+
"customers.sql": {"pattern": r"@ADD_ONE\s*\(", "expected_content": "ADD_ONE"},
54+
"top_waiters.sql": {"pattern": r"@ADD_ONE\s*\(", "expected_content": "ADD_ONE"},
55+
}
56+
57+
for expected_file, expectations in expected_files.items():
58+
file_refs = [ref for ref in all_references if expected_file in ref.uri]
59+
assert len(file_refs) >= 1, f"Should find at least one reference in {expected_file}"
60+
61+
file_ref = file_refs[0]
62+
file_path = URI(file_ref.uri).to_path()
63+
64+
sqlmesh_range = SQLMeshRange(
65+
start=SQLMeshPosition(
66+
line=file_ref.range.start.line, character=file_ref.range.start.character
67+
),
68+
end=SQLMeshPosition(
69+
line=file_ref.range.end.line, character=file_ref.range.end.character
70+
),
71+
)
72+
73+
# Read the content at the reference location
74+
content = read_range_from_file(file_path, sqlmesh_range)
75+
assert content.startswith(expectations["expected_content"]), (
76+
f"Expected content to start with '{expectations['expected_content']}', got: {content}"
7577
)
7678

7779

0 commit comments

Comments
 (0)