Skip to content

Commit ed6cd16

Browse files
committed
Refactoring: Introduce abstraction GroupedDiagnostics, removing duplication
1 parent b3649d7 commit ed6cd16

3 files changed

Lines changed: 71 additions & 120 deletions

File tree

src/serena/tools/symbol_tools.py

Lines changed: 9 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -18,44 +18,8 @@
1818
ToolMarkerSymbolicRead,
1919
)
2020
from serena.tools.tools_base import ToolMarkerOptional
21-
from solidlsp import ls_types
21+
from serena.util.ls_diagnostics import GroupedDiagnostics
2222
from solidlsp.ls_types import SymbolKind
23-
from solidlsp.lsp_protocol_handler.lsp_types import DiagnosticSeverity
24-
25-
FILE_LEVEL_DIAGNOSTIC_BUCKET = "<file>"
26-
27-
28-
def _diagnostic_severity_name(severity: int | None) -> str:
29-
if severity is None:
30-
return "Unknown"
31-
try:
32-
return DiagnosticSeverity(severity).name
33-
except ValueError:
34-
return f"Severity_{severity}"
35-
36-
37-
def _diagnostic_output_dict(diagnostic: ls_types.Diagnostic) -> dict[str, Any]:
38-
result: dict[str, Any] = {
39-
"message": diagnostic["message"],
40-
"range": diagnostic["range"],
41-
}
42-
if "code" in diagnostic:
43-
result["code"] = diagnostic["code"]
44-
if "source" in diagnostic:
45-
result["source"] = diagnostic["source"]
46-
return result
47-
48-
49-
def _add_grouped_diagnostic(
50-
grouped_result: dict[str, dict[str, dict[str, list[dict[str, Any]]]]],
51-
relative_path: str,
52-
severity_name: str,
53-
name_path: str,
54-
diagnostic: ls_types.Diagnostic,
55-
) -> None:
56-
grouped_result.setdefault(relative_path, {}).setdefault(severity_name, {}).setdefault(name_path, []).append(
57-
_diagnostic_output_dict(diagnostic)
58-
)
5923

6024

6125
def _offset_to_line_and_column(text: str, offset: int) -> tuple[int, int]:
@@ -663,7 +627,7 @@ class GetDiagnosticsForFileTool(Tool, ToolMarkerSymbolicRead):
663627
Gets diagnostics for a file, optionally restricted to a line range, grouped by file, severity, and containing symbol.
664628
"""
665629

666-
_ENABLE_DIAGNOSTICS: bool = True
630+
FILE_LEVEL_DIAGNOSTIC_BUCKET = "<file>"
667631

668632
def apply(
669633
self,
@@ -693,26 +657,20 @@ def apply(
693657
min_severity=min_severity,
694658
)
695659

696-
grouped_result: dict[str, dict[str, dict[str, list[dict[str, Any]]]]] = {}
660+
grouped_diagnostics = GroupedDiagnostics()
697661
for diagnostic in diagnostics:
698662
diag_range = diagnostic["range"]["start"]
699-
name_path = FILE_LEVEL_DIAGNOSTIC_BUCKET
663+
name_path = self.FILE_LEVEL_DIAGNOSTIC_BUCKET
700664
owner_symbol = symbol_retriever.find_diagnostic_owner_symbol(
701665
relative_file_path=relative_path,
702666
line=diag_range["line"],
703667
column=diag_range["character"],
704668
)
705669
if owner_symbol is not None:
706670
name_path = owner_symbol.get_name_path()
707-
_add_grouped_diagnostic(
708-
grouped_result,
709-
relative_path=relative_path,
710-
severity_name=_diagnostic_severity_name(diagnostic.get("severity")),
711-
name_path=name_path,
712-
diagnostic=diagnostic,
713-
)
671+
grouped_diagnostics.add(relative_path, name_path, diagnostic)
714672

715-
result = self._to_json(grouped_result)
673+
result = self._to_json(grouped_diagnostics.get_dict())
716674
return self._limit_length(result, max_answer_chars)
717675

718676

@@ -721,8 +679,6 @@ class GetDiagnosticsForSymbolTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOption
721679
Gets diagnostics for a symbol and, optionally, for symbols that reference it.
722680
"""
723681

724-
_ENABLE_DIAGNOSTICS: bool = True
725-
726682
def apply(
727683
self,
728684
name_path: str,
@@ -752,22 +708,16 @@ def apply(
752708
min_severity=min_severity,
753709
)
754710

755-
grouped_result: dict[str, dict[str, dict[str, list[dict[str, Any]]]]] = {}
711+
grouped_diagnostics = GroupedDiagnostics()
756712
for symbol, diagnostics in diagnostics_by_symbol.items():
757713
relative_path = symbol.relative_path
758714
if relative_path is None:
759715
continue
760716
symbol_name_path = symbol.get_name_path()
761717
for diagnostic in diagnostics:
762-
_add_grouped_diagnostic(
763-
grouped_result,
764-
relative_path=relative_path,
765-
severity_name=_diagnostic_severity_name(diagnostic.get("severity")),
766-
name_path=symbol_name_path,
767-
diagnostic=diagnostic,
768-
)
718+
grouped_diagnostics.add(relative_path, symbol_name_path, diagnostic)
769719

770-
result = self._to_json(grouped_result)
720+
result = self._to_json(grouped_diagnostics.get_dict())
771721
return self._limit_length(result, max_answer_chars)
772722

773723

src/serena/tools/tools_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def format_result(
448448

449449
assert self._symbol_retriever is not None
450450
diagnostics_diff = DiagnosticsDiff(self._before_edit_diagnostics_snapshot, self._edited_files, self._symbol_retriever)
451-
grouped_diagnostics = diagnostics_diff.get_grouped_diagnostics()
451+
grouped_diagnostics = diagnostics_diff.get_grouped_diagnostics().get_dict()
452452

453453
if not grouped_diagnostics:
454454
return base_result

src/serena/util/ls_diagnostics.py

Lines changed: 61 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,70 @@ def __init__(self, edited_file_paths: Iterable[EditedFilePath], symbol_retriever
8888
self.warning_identities_by_before_path = warning_identities_by_before_path
8989

9090

91+
class GroupedDiagnostics:
92+
def __init__(self) -> None:
93+
self._grouped_diagnostics: dict[str, dict[str, dict[str, list[dict[str, Any]]]]] = {}
94+
95+
def add(self, relative_path: str, name_path: str, diagnostic: ls_types.Diagnostic) -> None:
96+
severity_name = self._diagnostic_severity_name(diagnostic.get("severity"))
97+
self._grouped_diagnostics.setdefault(relative_path, {}).setdefault(severity_name, {}).setdefault(name_path, []).append(
98+
self._diagnostic_output_dict(diagnostic)
99+
)
100+
101+
def get_dict(self) -> dict[str, dict[str, dict[str, list[dict[str, Any]]]]]:
102+
"""
103+
Returns a nested dictionary of the form:
104+
{
105+
relative_file_path: {
106+
severity_name: {
107+
name_path: [
108+
diagnostic_dict,
109+
...
110+
],
111+
...
112+
},
113+
...
114+
},
115+
...
116+
}
117+
where:
118+
- relative_file_path is the relative path of the file containing the diagnostic
119+
- severity_name is the name of the diagnostic severity (e.g. "Warning", "Error")
120+
- name_path is the name path of the symbol that owns the diagnostic (or "<file>" if no owner symbol was found)
121+
- diagnostic_dict is a dictionary containing the diagnostic's message, range, and optionally code and source
122+
"""
123+
return self._grouped_diagnostics
124+
125+
@staticmethod
126+
def _diagnostic_severity_name(severity: int | None) -> str:
127+
if severity is None:
128+
return "Unknown"
129+
try:
130+
return DiagnosticSeverity(severity).name
131+
except ValueError:
132+
return f"Severity_{severity}"
133+
134+
@staticmethod
135+
def _diagnostic_output_dict(diagnostic: ls_types.Diagnostic) -> dict[str, Any]:
136+
result: dict[str, Any] = {
137+
"message": diagnostic["message"],
138+
"range": diagnostic["range"],
139+
}
140+
if "code" in diagnostic:
141+
result["code"] = diagnostic["code"]
142+
if "source" in diagnostic:
143+
result["source"] = diagnostic["source"]
144+
return result
145+
146+
91147
class DiagnosticsDiff:
92148
def __init__(
93149
self,
94150
before_snapshot: PublishedDiagnosticsSnapshot,
95151
edited_files: Iterable[EditedFilePath],
96152
symbol_retriever: "LanguageServerSymbolRetriever",
97153
):
98-
grouped_result: dict[str, dict[str, dict[str, list[dict[str, Any]]]]] = {}
154+
grouped_diagnostics = GroupedDiagnostics()
99155

100156
for edited_file_path in edited_files:
101157
try:
@@ -139,64 +195,9 @@ def __init__(
139195
column=diagnostic_start["character"],
140196
)
141197
name_path = owner_symbol.get_name_path() if owner_symbol is not None else "<file>"
142-
self._add_grouped_diagnostic(grouped_result, edited_file_path.after_relative_path, name_path, diagnostic)
198+
grouped_diagnostics.add(edited_file_path.after_relative_path, name_path, diagnostic)
143199

144-
self._grouped_result = grouped_result
200+
self._grouped_diagnostics = grouped_diagnostics
145201

146-
def get_grouped_diagnostics(self) -> dict[str, dict[str, dict[str, list[dict[str, Any]]]]]:
147-
"""
148-
Returns a nested dictionary of the form:
149-
{
150-
relative_file_path: {
151-
severity_name: {
152-
name_path: [
153-
diagnostic_dict,
154-
...
155-
],
156-
...
157-
},
158-
...
159-
},
160-
...
161-
}
162-
where:
163-
- relative_file_path is the relative path of the file containing the diagnostic
164-
- severity_name is the name of the diagnostic severity (e.g. "Warning", "Error")
165-
- name_path is the name path of the symbol that owns the diagnostic (or "<file>" if no owner symbol was found)
166-
- diagnostic_dict is a dictionary containing the diagnostic's message, range, and optionally code and source
167-
"""
168-
return self._grouped_result
169-
170-
@classmethod
171-
def _add_grouped_diagnostic(
172-
cls,
173-
grouped_result: dict[str, dict[str, dict[str, list[dict[str, Any]]]]],
174-
relative_path: str,
175-
name_path: str,
176-
diagnostic: ls_types.Diagnostic,
177-
) -> None:
178-
severity_name = cls._diagnostic_severity_name(diagnostic.get("severity"))
179-
grouped_result.setdefault(relative_path, {}).setdefault(severity_name, {}).setdefault(name_path, []).append(
180-
cls._diagnostic_output_dict(diagnostic)
181-
)
182-
183-
@staticmethod
184-
def _diagnostic_severity_name(severity: int | None) -> str:
185-
if severity is None:
186-
return "Unknown"
187-
try:
188-
return DiagnosticSeverity(severity).name
189-
except ValueError:
190-
return f"Severity_{severity}"
191-
192-
@staticmethod
193-
def _diagnostic_output_dict(diagnostic: ls_types.Diagnostic) -> dict[str, Any]:
194-
result: dict[str, Any] = {
195-
"message": diagnostic["message"],
196-
"range": diagnostic["range"],
197-
}
198-
if "code" in diagnostic:
199-
result["code"] = diagnostic["code"]
200-
if "source" in diagnostic:
201-
result["source"] = diagnostic["source"]
202-
return result
202+
def get_grouped_diagnostics(self) -> GroupedDiagnostics:
203+
return self._grouped_diagnostics

0 commit comments

Comments
 (0)