11from lsprotocol .types import Range , Position
22import typing as t
33from pathlib import Path
4+ from pydantic import Field
45
56from sqlmesh .core .audit import StandaloneAudit
67from sqlmesh .core .dialect import normalize_model_name
2324import inspect
2425
2526
26- class Reference (PydanticModel ):
27- """
28- A reference to a model or CTE.
29-
30- Attributes:
31- range: The range of the reference in the source file
32- uri: The uri of the referenced model
33- markdown_description: The markdown description of the referenced model
34- target_range: The range of the definition for go-to-definition (optional, used for CTEs)
35- """
27+ class LSPBaseReference (PydanticModel ):
28+ """Base class for all LSP reference types."""
3629
3730 range : Range
3831 uri : str
3932 markdown_description : t .Optional [str ] = None
40- target_range : t .Optional [Range ] = None
33+
34+
35+ class LSPModelReference (LSPBaseReference ):
36+ """A LSP reference to a model."""
37+
38+ type : t .Literal ["model" ] = "model"
39+
40+
41+ class LSPCteReference (LSPBaseReference ):
42+ """A LSP reference to a CTE."""
43+
44+ type : t .Literal ["cte" ] = "cte"
45+ target_range : Range
46+
47+
48+ class LSPMacroReference (LSPBaseReference ):
49+ """A LSP reference to a macro."""
50+
51+ type : t .Literal ["macro" ] = "macro"
52+ target_range : Range
53+
54+
55+ Reference = t .Annotated [
56+ t .Union [LSPModelReference , LSPCteReference , LSPMacroReference ], Field (discriminator = "type" )
57+ ]
4158
4259
4360def by_position (position : Position ) -> t .Callable [[Reference ], bool ]:
@@ -136,7 +153,7 @@ def get_model_definitions_for_a_path(
136153 return []
137154
138155 # Find all possible references
139- references = []
156+ references : t . List [ Reference ] = []
140157
141158 with open (file_path , "r" , encoding = "utf-8" ) as file :
142159 read_file = file .readlines ()
@@ -173,7 +190,7 @@ def get_model_definitions_for_a_path(
173190 table_range = to_lsp_range (table_range_sqlmesh )
174191
175192 references .append (
176- Reference (
193+ LSPCteReference (
177194 uri = document_uri .value , # Same file
178195 range = table_range ,
179196 target_range = target_range ,
@@ -227,7 +244,7 @@ def get_model_definitions_for_a_path(
227244 description = generate_markdown_description (referenced_model )
228245
229246 references .append (
230- Reference (
247+ LSPModelReference (
231248 uri = referenced_model_uri .value ,
232249 range = Range (
233250 start = to_lsp_position (start_pos_sqlmesh ),
@@ -286,7 +303,7 @@ def get_macro_definitions_for_a_path(
286303 return []
287304
288305 references = []
289- config_for_model , config_path = lsp_context .context .config_for_path (
306+ _ , config_path = lsp_context .context .config_for_path (
290307 file_path ,
291308 )
292309
@@ -372,7 +389,7 @@ def get_macro_reference(
372389 # Create a reference to the macro definition
373390 macro_uri = URI .from_path (path )
374391
375- return Reference (
392+ return LSPMacroReference (
376393 uri = macro_uri .value ,
377394 range = to_lsp_range (macro_range ),
378395 target_range = Range (
@@ -405,7 +422,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
405422 # Calculate the end line number by counting the number of source lines
406423 end_line_number = line_number + len (source_lines ) - 1
407424
408- return Reference (
425+ return LSPMacroReference (
409426 uri = URI .from_path (Path (filename )).value ,
410427 range = macro_range ,
411428 target_range = Range (
@@ -416,9 +433,91 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
416433 )
417434
418435
436+ def get_model_find_all_references (
437+ lint_context : LSPContext , document_uri : URI , position : Position
438+ ) -> t .List [LSPModelReference ]:
439+ """
440+ Get all references to a model across the entire project.
441+
442+ This function finds all usages of a model in other files by searching through
443+ all models in the project and checking their dependencies.
444+
445+ Args:
446+ lint_context: The LSP context
447+ document_uri: The URI of the document
448+ position: The position to check for model references
449+
450+ Returns:
451+ A list of references to the model across all files
452+ """
453+ # First, get the references in the current file to determine what model we're looking for
454+ current_file_references = get_model_definitions_for_a_path (lint_context , document_uri )
455+
456+ # Find the model reference at the cursor position
457+ target_model_uri : t .Optional [str ] = None
458+ for ref in current_file_references :
459+ if _position_within_range (position , ref .range ) and isinstance (ref , LSPModelReference ):
460+ # This is a model reference, get the target model URI
461+ target_model_uri = ref .uri
462+ break
463+
464+ if target_model_uri is None :
465+ return []
466+
467+ # Start with the model definition
468+ all_references : t .List [LSPModelReference ] = [
469+ LSPModelReference (
470+ uri = ref .uri ,
471+ range = Range (
472+ start = Position (line = 0 , character = 0 ),
473+ end = Position (line = 0 , character = 0 ),
474+ ),
475+ markdown_description = ref .markdown_description ,
476+ )
477+ ]
478+
479+ # Then add the original reference
480+ for ref in current_file_references :
481+ if ref .uri == target_model_uri and isinstance (ref , LSPModelReference ):
482+ all_references .append (
483+ LSPModelReference (
484+ uri = document_uri .value ,
485+ range = ref .range ,
486+ markdown_description = ref .markdown_description ,
487+ )
488+ )
489+
490+ # Search through the models in the project
491+ for path , target in lint_context .map .items ():
492+ if not isinstance (target , (ModelTarget , AuditTarget )):
493+ continue
494+
495+ file_uri = URI .from_path (path )
496+
497+ # Skip current file, already processed
498+ if file_uri .value == document_uri .value :
499+ continue
500+
501+ # Get model references for this file
502+ file_references = get_model_definitions_for_a_path (lint_context , file_uri )
503+
504+ # Add references that point to the target model file
505+ for ref in file_references :
506+ if ref .uri == target_model_uri and isinstance (ref , LSPModelReference ):
507+ all_references .append (
508+ LSPModelReference (
509+ uri = file_uri .value ,
510+ range = ref .range ,
511+ markdown_description = ref .markdown_description ,
512+ )
513+ )
514+
515+ return all_references
516+
517+
419518def get_cte_references (
420519 lint_context : LSPContext , document_uri : URI , position : Position
421- ) -> t .List [Reference ]:
520+ ) -> t .List [LSPCteReference ]:
422521 """
423522 Get all references to a CTE at a specific position in a document.
424523
@@ -432,12 +531,12 @@ def get_cte_references(
432531 Returns:
433532 A list of references to the CTE (including its definition and all usages)
434533 """
435- references = get_model_definitions_for_a_path (lint_context , document_uri )
436534
437- # Filter for CTE references (those with target_range set and same URI)
438- # TODO: Consider extending Reference class to explicitly indicate reference type instead
439- cte_references = [
440- ref for ref in references if ref .target_range is not None and ref .uri == document_uri .value
535+ # Filter to get the CTE references
536+ cte_references : t .List [LSPCteReference ] = [
537+ ref
538+ for ref in get_model_definitions_for_a_path (lint_context , document_uri )
539+ if isinstance (ref , LSPCteReference )
441540 ]
442541
443542 if not cte_references :
@@ -450,7 +549,7 @@ def get_cte_references(
450549 target_cte_definition_range = ref .target_range
451550 break
452551 # Check if cursor is on the CTE definition
453- elif ref . target_range and _position_within_range (position , ref .target_range ):
552+ elif _position_within_range (position , ref .target_range ):
454553 target_cte_definition_range = ref .target_range
455554 break
456555
@@ -459,9 +558,10 @@ def get_cte_references(
459558
460559 # Add the CTE definition
461560 matching_references = [
462- Reference (
561+ LSPCteReference (
463562 uri = document_uri .value ,
464563 range = target_cte_definition_range ,
564+ target_range = target_cte_definition_range ,
465565 markdown_description = "CTE definition" ,
466566 )
467567 ]
@@ -470,16 +570,45 @@ def get_cte_references(
470570 for ref in cte_references :
471571 if ref .target_range == target_cte_definition_range :
472572 matching_references .append (
473- Reference (
573+ LSPCteReference (
474574 uri = document_uri .value ,
475575 range = ref .range ,
576+ target_range = ref .target_range ,
476577 markdown_description = "CTE usage" ,
477578 )
478579 )
479580
480581 return matching_references
481582
482583
584+ def get_all_references (
585+ lint_context : LSPContext , document_uri : URI , position : Position
586+ ) -> t .Sequence [Reference ]:
587+ """
588+ Get all references of a symbol at a specific position in a document.
589+
590+ This function determines the type of reference (CTE, model for now) at the cursor
591+ position and returns all references to that symbol across the project.
592+
593+ Args:
594+ lint_context: The LSP context
595+ document_uri: The URI of the document
596+ position: The position to check for references
597+
598+ Returns:
599+ A list of references to the symbol at the given position
600+ """
601+ # First try CTE references (within same file)
602+ if cte_references := get_cte_references (lint_context , document_uri , position ):
603+ return cte_references
604+
605+ # Then try model references (across files)
606+ if model_references := get_model_find_all_references (lint_context , document_uri , position ):
607+ return model_references
608+
609+ return []
610+
611+
483612def _position_within_range (position : Position , range : Range ) -> bool :
484613 """Check if a position is within a given range."""
485614 return (
0 commit comments