Skip to content

Commit fbd2286

Browse files
committed
feat: add hover information for columns
[ci skip]
1 parent fb6bafb commit fbd2286

File tree

6 files changed

+145
-25
lines changed

6 files changed

+145
-25
lines changed

examples/sushi/models/customers.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ MODEL (
77
grain customer_id,
88
description 'Sushi customer data',
99
column_descriptions (
10-
customer_id = 'customer_id uniquely identifies customers'
10+
customer_id = 'customer_id uniquely identifies customers',
11+
status = 'status of the customer'
1112
)
1213
);
1314

package-lock.json

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlmesh/lsp/columns.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import typing as t
2+
from dataclasses import dataclass
3+
4+
from lsprotocol.types import Range
5+
6+
from sqlmesh.core.model.definition import SqlModel
7+
from sqlglot import exp
8+
9+
from sqlmesh.lsp.reference import range_from_token_position_details, TokenPositionDetails
10+
11+
12+
@dataclass
13+
class ColumnDescriptionMap:
14+
range: Range
15+
column_name: str
16+
data_type: t.Optional[str] = None
17+
description: t.Optional[str] = None
18+
19+
20+
def get_columns_and_ranges_for_model(model: SqlModel) -> t.Optional[t.List[ColumnDescriptionMap]]:
21+
"""
22+
Get the top level columns and their position in the file to be able to provide hover information.
23+
24+
If the column information is not available, return None.
25+
"""
26+
type_definitions = model.columns_to_types
27+
columns = model.column_descriptions
28+
query = model.query
29+
30+
if not isinstance(query, exp.Query):
31+
return None
32+
33+
path = model._path
34+
if not path.is_file():
35+
return None
36+
with open(path, "r") as f:
37+
lines = f.readlines()
38+
39+
# Get the top-level columns from the SELECT
40+
outs = []
41+
top_level_columns = query.expressions
42+
for projection in top_level_columns:
43+
if isinstance(projection, exp.Alias):
44+
column = projection.get('alias')
45+
elif isinstance(projection, exp.Column):
46+
column = projection
47+
else:
48+
continue
49+
50+
if not isinstance(column, exp.Column):
51+
continue
52+
53+
column_name = column.name
54+
data_type = type_definitions[column_name] if type_definitions is not None else None
55+
description = columns[column_name] if column_name in columns else None
56+
token_details = TokenPositionDetails.from_meta(column.this.meta)
57+
column_range = range_from_token_position_details(token_details, lines)
58+
column_description_map = ColumnDescriptionMap(
59+
range=column_range,
60+
column_name=column_name,
61+
data_type=str(data_type) if data_type else None,
62+
description=description,
63+
)
64+
outs.append(column_description_map)
65+
66+
return outs

sqlmesh/lsp/main.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
from sqlmesh._version import __version__
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
14+
from sqlmesh.core.model import SqlModel
15+
from sqlmesh.lsp.columns import get_columns_and_ranges_for_model
1416
from sqlmesh.lsp.completions import get_sql_completions
1517
from sqlmesh.lsp.context import LSPContext, ModelTarget
1618
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1719
from sqlmesh.lsp.reference import (
1820
get_references,
21+
is_position_in_range,
1922
)
2023

2124

@@ -189,17 +192,39 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
189192
references = get_references(
190193
self.lsp_context, params.text_document.uri, params.position
191194
)
192-
if not references:
195+
if references:
196+
reference = references[0]
197+
if reference.description:
198+
return types.Hover(
199+
contents=types.MarkupContent(
200+
kind=types.MarkupKind.Markdown, value=reference.description
201+
),
202+
range=reference.range,
203+
)
204+
205+
# Try columns if no description is found
206+
models = self.lsp_context.map[params.text_document.uri]
207+
if not models:
193208
return None
194-
reference = references[0]
195-
if not reference.description:
209+
if not isinstance(models, ModelTarget):
196210
return None
197-
return types.Hover(
198-
contents=types.MarkupContent(
199-
kind=types.MarkupKind.Markdown, value=reference.description
200-
),
201-
range=reference.range,
202-
)
211+
model = self.lsp_context.context.get_model(models.names[0])
212+
if not isinstance(model, SqlModel):
213+
return None
214+
columns = get_columns_and_ranges_for_model(model)
215+
if not columns:
216+
return None
217+
218+
for column in columns:
219+
if column.description and is_position_in_range(column.range, params.position):
220+
return types.Hover(
221+
contents=types.MarkupContent(
222+
kind=types.MarkupKind.Markdown,
223+
value=column.description,
224+
),
225+
range=column.range,
226+
)
227+
return None
203228

204229
except Exception as e:
205230
ls.show_message(f"Error getting hover information: {e}", types.MessageType.Error)

sqlmesh/lsp/reference.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ class Reference(PydanticModel):
2424
description: t.Optional[str] = None
2525

2626

27+
def is_position_in_range(range: Range, position: Position) -> bool:
28+
return (
29+
range.start.line < position.line
30+
or (range.start.line == position.line and range.start.character <= position.character)
31+
) and (
32+
range.end.line > position.line
33+
or (range.end.line == position.line and range.end.character >= position.character)
34+
)
35+
36+
2737
def by_position(position: Position) -> t.Callable[[Reference], bool]:
2838
"""
2939
Filter reference to only filter references that contain the given position.
@@ -35,17 +45,8 @@ def by_position(position: Position) -> t.Callable[[Reference], bool]:
3545
A function that returns True if the reference contains the position, False otherwise
3646
"""
3747

38-
def contains_position(r: Reference) -> bool:
39-
return (
40-
r.range.start.line < position.line
41-
or (
42-
r.range.start.line == position.line
43-
and r.range.start.character <= position.character
44-
)
45-
) and (
46-
r.range.end.line > position.line
47-
or (r.range.end.line == position.line and r.range.end.character >= position.character)
48-
)
48+
def contains_position(reference: Reference) -> bool:
49+
return is_position_in_range(reference.range, position)
4950

5051
return contains_position
5152

@@ -167,15 +168,15 @@ def get_model_definitions_for_a_path(
167168

168169
# Extract metadata for positioning
169170
table_meta = TokenPositionDetails.from_meta(table.this.meta)
170-
table_range = _range_from_token_position_details(table_meta, read_file)
171+
table_range = range_from_token_position_details(table_meta, read_file)
171172
start_pos = table_range.start
172173
end_pos = table_range.end
173174

174175
# If there's a catalog or database qualifier, adjust the start position
175176
catalog_or_db = table.args.get("catalog") or table.args.get("db")
176177
if catalog_or_db is not None:
177178
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
178-
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
179+
catalog_or_db_range = range_from_token_position_details(catalog_or_db_meta, read_file)
179180
start_pos = catalog_or_db_range.start
180181

181182
references.append(
@@ -215,7 +216,7 @@ def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
215216
)
216217

217218

218-
def _range_from_token_position_details(
219+
def range_from_token_position_details(
219220
token_position_details: TokenPositionDetails, read_file: t.List[str]
220221
) -> Range:
221222
"""

tests/lsp/test_columns.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from sqlmesh import Context
2+
from sqlmesh.core.model.definition import SqlModel
3+
from sqlmesh.lsp.columns import get_columns_and_ranges_for_model
4+
from sqlmesh.lsp.context import LSPContext
5+
6+
7+
def test_get_columns_and_ranges_for_model():
8+
context = Context(paths=["examples/sushi"])
9+
lsp_context = LSPContext(context)
10+
11+
model = lsp_context.context.get_model("sushi.customers")
12+
if not isinstance(model, SqlModel):
13+
raise ValueError("Model is not a SqlModel")
14+
15+
columns = get_columns_and_ranges_for_model(model)
16+
assert columns is not None
17+
18+
assert len(columns) == 3
19+
assert columns[0].column_name == "customer_id"
20+
assert columns[0].description == "customer_id uniquely identifies customers"
21+
assert columns[0].data_type == "INT"
22+
assert columns[0].range is not None
23+
assert columns[0].range.start.line == 27
24+
assert columns[0].range.end.line == 27
25+
assert columns[1].column_name == "status"
26+
assert columns[1].description is None
27+
assert columns[2].column_name == "zip"
28+
assert columns[2].description is None

0 commit comments

Comments
 (0)