Skip to content

Commit 3f72c9e

Browse files
committed
feat: add go to definition to lsp
1 parent fd19aa0 commit 3f72c9e

File tree

6 files changed

+300
-17
lines changed

6 files changed

+300
-17
lines changed

sqlmesh/lsp/context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from collections import defaultdict
2+
from pathlib import Path
3+
from sqlmesh.core.context import Context
4+
import typing as t
5+
6+
7+
class LSPContext:
8+
"""
9+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
10+
"""
11+
12+
def __init__(self, context: Context) -> None:
13+
self.context = context
14+
map: t.Dict[str, t.List[str]] = defaultdict(list)
15+
for model in context.models.values():
16+
if model._path is not None:
17+
path = Path(model._path).resolve()
18+
map[f"file://{path.as_posix()}"].append(model.name)
19+
self.map = map

sqlmesh/lsp/main.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4-
from collections import defaultdict
54
import logging
65
import typing as t
76
from pathlib import Path
@@ -12,21 +11,8 @@
1211
from sqlmesh._version import __version__
1312
from sqlmesh.core.context import Context
1413
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15-
16-
17-
class LSPContext:
18-
"""
19-
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20-
"""
21-
22-
def __init__(self, context: Context) -> None:
23-
self.context = context
24-
map: t.Dict[str, t.List[str]] = defaultdict(list)
25-
for model in context.models.values():
26-
if model._path is not None:
27-
path = Path(model._path).resolve()
28-
map[f"file://{path.as_posix()}"].append(model.name)
29-
self.map = map
14+
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
3016

3117

3218
class SQLMeshLanguageServer:
@@ -144,6 +130,43 @@ def formatting(
144130
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
145131
return []
146132

133+
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
134+
def goto_definition(
135+
ls: LanguageServer, params: types.DefinitionParams
136+
) -> t.List[types.LocationLink]:
137+
"""Jump to an object's definition."""
138+
try:
139+
self._ensure_context_for_document(params.text_document.uri)
140+
document = ls.workspace.get_document(params.text_document.uri)
141+
if self.lsp_context is None:
142+
raise RuntimeError(f"No context found for document: {document.path}")
143+
144+
references = get_model_definitions_for_a_path(
145+
self.lsp_context, params.text_document.uri
146+
)
147+
if len(references) == 0:
148+
return []
149+
150+
return [
151+
types.LocationLink(
152+
target_uri=reference.uri,
153+
target_selection_range=types.Range(
154+
start=types.Position(line=0, character=0),
155+
end=types.Position(line=0, character=0),
156+
),
157+
target_range=types.Range(
158+
start=types.Position(line=0, character=0),
159+
end=types.Position(line=0, character=0),
160+
),
161+
origin_selection_range=reference.range,
162+
)
163+
for reference in references
164+
]
165+
166+
except Exception as e:
167+
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
168+
return []
169+
147170
def _context_get_or_load(self, document_uri: str) -> LSPContext:
148171
if self.lsp_context is None:
149172
self._ensure_context_for_document(document_uri)

sqlmesh/lsp/reference.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from pathlib import Path
2+
3+
from lsprotocol.types import Range, Position
4+
import typing as t
5+
6+
from sqlmesh.core.dialect import normalize_model_name
7+
from sqlmesh.core.model.definition import SqlModel
8+
from sqlmesh.lsp.context import LSPContext
9+
from sqlglot import exp
10+
11+
from sqlmesh.utils.pydantic import PydanticModel
12+
13+
14+
class Reference(PydanticModel):
15+
range: Range
16+
uri: str
17+
18+
19+
def get_model_definitions_for_a_path(
20+
lint_context: LSPContext, document_uri: str
21+
) -> t.List[Reference]:
22+
"""
23+
Get the model references for a given path.
24+
25+
Works for models and audits.
26+
Works for targeting sql and python models.
27+
28+
Steps:
29+
- Get the parsed query
30+
- Find all table objects using find_all exp.Table
31+
- Match the string against all model names
32+
- Need to normalize it before matching
33+
- Try get_model before normalization
34+
- Match to models that the model refers to
35+
"""
36+
# Ensure the path is a sql model
37+
if not document_uri.endswith(".sql"):
38+
return []
39+
40+
# Get the model
41+
models = lint_context.map[document_uri]
42+
if models is None:
43+
return []
44+
if len(models) == 0:
45+
return []
46+
model_name = models[0]
47+
model = lint_context.context.get_model(model_or_snapshot=model_name, raise_if_missing=False)
48+
if model is None:
49+
return []
50+
if not isinstance(model, SqlModel):
51+
return []
52+
53+
# Find all possible references
54+
tables = list(model.query.find_all(exp.Table))
55+
if len(tables) == 0:
56+
return []
57+
58+
references = []
59+
for table in tables:
60+
depends_on = model.depends_on
61+
62+
# Normalize the table reference
63+
reference_name = table.this.this if table.db is None else f"{table.db}.{table.this.this}"
64+
normalized_reference_name = normalize_model_name(
65+
reference_name, default_catalog=lint_context.context.default_catalog
66+
)
67+
if normalized_reference_name not in depends_on:
68+
continue
69+
70+
# Get the referenced model uri
71+
referenced_model = lint_context.context.get_model(
72+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
73+
)
74+
if referenced_model is None:
75+
continue
76+
# Get the model uri
77+
referenced_model_path = referenced_model._path
78+
if referenced_model_path is None:
79+
continue
80+
# Fully qualify the path in case
81+
path = Path.resolve(Path(referenced_model_path))
82+
referenced_model_uri = f"file://{path}"
83+
read_file = open(path, "r").readlines()
84+
85+
# Extract metadata for positioning
86+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
87+
table_range = _range_from_token_position_details(table_meta, read_file)
88+
start_pos = table_range.start
89+
end_pos = table_range.end
90+
91+
# If there's a database qualifier, adjust the start position
92+
db = table.args.get("db")
93+
if db is not None:
94+
db_meta = TokenPositionDetails.from_meta(db.meta)
95+
db_range = _range_from_token_position_details(db_meta, read_file)
96+
start_pos = db_range.start
97+
98+
# If there's a catalog qualifier, adjust the start position further
99+
catalog = table.args.get("catalog")
100+
if catalog is not None:
101+
catalog_meta = TokenPositionDetails.from_meta(catalog.meta)
102+
catalog_range = _range_from_token_position_details(catalog_meta, read_file)
103+
start_pos = catalog_range.start
104+
105+
references.append(
106+
Reference(uri=referenced_model_uri, range=Range(start=start_pos, end=end_pos))
107+
)
108+
109+
return references
110+
111+
112+
class TokenPositionDetails(PydanticModel):
113+
"""
114+
Details about a token's position in the source code.
115+
116+
Attributes:
117+
line (int): The line that the token ends on.
118+
col (int): The column that the token ends on.
119+
start (int): The start index of the token.
120+
end (int): The ending index of the token.
121+
"""
122+
123+
line: int
124+
col: int
125+
start: int
126+
end: int
127+
128+
@staticmethod
129+
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
130+
return TokenPositionDetails(
131+
line=meta["line"],
132+
col=meta["col"],
133+
start=meta["start"],
134+
end=meta["end"],
135+
)
136+
137+
138+
def _range_from_token_position_details(
139+
token_position_details: TokenPositionDetails, read_file: t.List[str]
140+
) -> Range:
141+
"""
142+
Convert a TokenPositionDetails object to a Range object.
143+
144+
:param token_position_details: Details about a token's position
145+
:param read_file: List of lines from the file
146+
:return: A Range object representing the token's position
147+
"""
148+
# Convert from 1-indexed to 0-indexed for line and column
149+
end_line_0 = token_position_details.line - 1
150+
end_col_0 = token_position_details.col
151+
152+
# Find the start line and column by counting backwards from the end position
153+
start_pos = token_position_details.start
154+
end_pos = token_position_details.end
155+
156+
# Initialize with the end position
157+
start_line_0 = end_line_0
158+
start_col_0 = end_col_0 - (end_pos - start_pos + 1)
159+
160+
# If start_col_0 is negative, we need to go back to previous lines
161+
while start_col_0 < 0 and start_line_0 > 0:
162+
start_line_0 -= 1
163+
start_col_0 += len(read_file[start_line_0])
164+
# Account for newline character
165+
if start_col_0 >= 0:
166+
break
167+
start_col_0 += 1 # For the newline character
168+
169+
# Ensure we don't have negative values
170+
start_col_0 = max(0, start_col_0)
171+
return Range(
172+
start=Position(line=start_line_0, character=start_col_0),
173+
end=Position(line=end_line_0, character=end_col_0),
174+
)

tests/lsp/test_context.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from sqlmesh.core.context import Context
2+
from sqlmesh.lsp.context import LSPContext
3+
4+
5+
def test_lsp_context():
6+
context = Context(paths=["examples/sushi"])
7+
lsp_context = LSPContext(context)
8+
9+
assert lsp_context is not None
10+
assert lsp_context.context is not None
11+
assert lsp_context.map is not None
12+
13+
# find one model in the map
14+
active_customers_key = next(
15+
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
16+
)
17+
assert lsp_context.map[active_customers_key] == ["sushi.active_customers"]

tests/lsp/test_reference.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from sqlmesh.core.context import Context
2+
from sqlmesh.lsp.context import LSPContext
3+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
4+
5+
6+
def test_reference() -> None:
7+
context = Context(paths=["examples/sushi"])
8+
lsp_context = LSPContext(context)
9+
10+
active_customers_uri = next(
11+
uri for uri, models in lsp_context.map.items() if "sushi.active_customers" in models
12+
)
13+
if active_customers_uri is None:
14+
raise ValueError("active_customers_uri is not found")
15+
sushi_customers_uri = next(
16+
uri for uri, models in lsp_context.map.items() if "sushi.customers" in models
17+
)
18+
if sushi_customers_uri is None:
19+
raise ValueError("sushi_customers_uri is not found")
20+
21+
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
22+
23+
assert len(references) == 1
24+
assert references[0].uri == sushi_customers_uri
25+
26+
# Check that the reference in the correct range is sushi.customers
27+
path = active_customers_uri.removeprefix("file://")
28+
read_file = open(path, "r").readlines()
29+
# Get the string range in the read file
30+
reference_range = references[0].range
31+
start_line = reference_range.start.line
32+
end_line = reference_range.end.line
33+
start_character = reference_range.start.character
34+
end_character = reference_range.end.character
35+
# Get the string from the file
36+
37+
# If the reference spans multiple lines, handle it accordingly
38+
if start_line == end_line:
39+
# Reference is on a single line
40+
line_content = read_file[start_line]
41+
referenced_text = line_content[start_character:end_character]
42+
else:
43+
# Reference spans multiple lines
44+
referenced_text = read_file[start_line][
45+
start_character:
46+
] # First line from start_character to end
47+
for line_num in range(start_line + 1, end_line): # Middle lines (if any)
48+
referenced_text += read_file[line_num]
49+
referenced_text += read_file[end_line][:end_character] # Last line up to end_character
50+
assert referenced_text == "sushi.customers"

vscode/extension/src/lsp/lsp.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export class LSPClient implements Disposable {
2727

2828
const sqlmesh = await sqlmesh_lsp_exec()
2929
if (isErr(sqlmesh)) {
30-
traceError(`Failed to get sqlmesh_lsp_exec, ${sqlmesh.error.type}`)
30+
traceError(`Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`)
3131
return sqlmesh
3232
}
3333
const workspaceFolders = getWorkspaceFolders()

0 commit comments

Comments
 (0)