Skip to content

Commit 133f1d3

Browse files
authored
fix(lsp): overwriting the formatted file (#4547)
1 parent a69f1a4 commit 133f1d3

File tree

2 files changed

+73
-34
lines changed

2 files changed

+73
-34
lines changed

sqlmesh/core/context.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,33 +1129,15 @@ def format(
11291129

11301130
with open(target._path, "r+", encoding="utf-8") as file:
11311131
before = file.read()
1132-
expressions = parse(before, default_dialect=self.config_for_node(target).dialect)
1133-
if transpile and is_meta_expression(expressions[0]):
1134-
for prop in expressions[0].expressions:
1135-
if prop.name.lower() == "dialect":
1136-
prop.replace(
1137-
exp.Property(
1138-
this="dialect",
1139-
value=exp.Literal.string(transpile or target.dialect),
1140-
)
1141-
)
1142-
1143-
format_config = self.config_for_node(target).format
1144-
after = format_model_expressions(
1145-
expressions,
1146-
transpile or target.dialect,
1147-
rewrite_casts=(
1148-
rewrite_casts
1149-
if rewrite_casts is not None
1150-
else not format_config.no_rewrite_casts
1151-
),
1152-
**{**format_config.generator_options, **kwargs},
1153-
)
11541132

1155-
if append_newline is None:
1156-
append_newline = format_config.append_newline
1157-
if append_newline:
1158-
after += "\n"
1133+
after = self._format(
1134+
target,
1135+
before,
1136+
transpile=transpile,
1137+
rewrite_casts=rewrite_casts,
1138+
append_newline=append_newline,
1139+
**kwargs,
1140+
)
11591141

11601142
if not check:
11611143
file.seek(0)
@@ -1174,6 +1156,44 @@ def format(
11741156

11751157
return True
11761158

1159+
def _format(
1160+
self,
1161+
target: Model | Audit,
1162+
before: str,
1163+
*,
1164+
transpile: t.Optional[str] = None,
1165+
rewrite_casts: t.Optional[bool] = None,
1166+
append_newline: t.Optional[bool] = None,
1167+
**kwargs: t.Any,
1168+
) -> str:
1169+
expressions = parse(before, default_dialect=self.config_for_node(target).dialect)
1170+
if transpile and is_meta_expression(expressions[0]):
1171+
for prop in expressions[0].expressions:
1172+
if prop.name.lower() == "dialect":
1173+
prop.replace(
1174+
exp.Property(
1175+
this="dialect",
1176+
value=exp.Literal.string(transpile or target.dialect),
1177+
)
1178+
)
1179+
1180+
format_config = self.config_for_node(target).format
1181+
after = format_model_expressions(
1182+
expressions,
1183+
transpile or target.dialect,
1184+
rewrite_casts=(
1185+
rewrite_casts if rewrite_casts is not None else not format_config.no_rewrite_casts
1186+
),
1187+
**{**format_config.generator_options, **kwargs},
1188+
)
1189+
1190+
if append_newline is None:
1191+
append_newline = format_config.append_newline
1192+
if append_newline:
1193+
after += "\n"
1194+
1195+
return after
1196+
11771197
@python_api_analytics
11781198
def plan(
11791199
self,

sqlmesh/lsp/main.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4+
from itertools import chain
45
import logging
56
import typing as t
67
from pathlib import Path
@@ -20,7 +21,11 @@
2021
ApiResponseGetModels,
2122
)
2223
from sqlmesh.lsp.completions import get_sql_completions
23-
from sqlmesh.lsp.context import LSPContext, ModelTarget, render_model as render_model_context
24+
from sqlmesh.lsp.context import (
25+
LSPContext,
26+
ModelTarget,
27+
render_model as render_model_context,
28+
)
2429
from sqlmesh.lsp.custom import (
2530
ALL_MODELS_FEATURE,
2631
RENDER_MODEL_FEATURE,
@@ -213,15 +218,29 @@ def formatting(
213218
uri = URI(params.text_document.uri)
214219
self._ensure_context_for_document(uri)
215220
document = ls.workspace.get_text_document(params.text_document.uri)
221+
before = document.source
216222
if self.lsp_context is None:
217223
raise RuntimeError(f"No context found for document: {document.path}")
218224

219-
# Perform formatting using the loaded context
220-
self.lsp_context.context.format(paths=(str(uri.to_path()),))
221-
with open(uri.to_path(), "r+", encoding="utf-8") as file:
222-
new_text = file.read()
223-
224-
# Return a single edit that replaces the entire file.
225+
target = next(
226+
(
227+
target
228+
for target in chain(
229+
self.lsp_context.context._models.values(),
230+
self.lsp_context.context._audits.values(),
231+
)
232+
if target._path is not None
233+
and target._path.suffix == ".sql"
234+
and (target._path.samefile(uri.to_path()))
235+
),
236+
None,
237+
)
238+
if target is None:
239+
return []
240+
after = self.lsp_context.context._format(
241+
target=target,
242+
before=before,
243+
)
225244
return [
226245
types.TextEdit(
227246
range=types.Range(
@@ -231,7 +250,7 @@ def formatting(
231250
character=len(document.lines[-1]) if document.lines else 0,
232251
),
233252
),
234-
new_text=new_text,
253+
new_text=after,
235254
)
236255
]
237256
except Exception as e:

0 commit comments

Comments
 (0)