Skip to content

Commit 3483c12

Browse files
committed
Extract node_comments to new module, return nodes instead of strings
ghstack-source-id: adf9fc5 Pull Request resolved: #450
1 parent cca9f96 commit 3483c12

File tree

7 files changed

+213
-109
lines changed

7 files changed

+213
-109
lines changed

src/fixit/comments.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Generator, Optional, Sequence
7+
8+
from libcst import (
9+
BaseSuite,
10+
Comma,
11+
Comment,
12+
CSTNode,
13+
Decorator,
14+
EmptyLine,
15+
IndentedBlock,
16+
LeftSquareBracket,
17+
Module,
18+
RightSquareBracket,
19+
SimpleStatementSuite,
20+
TrailingWhitespace,
21+
)
22+
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider
23+
24+
25+
def node_comments(
    node: CSTNode, metadata: MetadataWrapper
) -> Generator[Comment, None, None]:
    """
    Yield the comment nodes attached to ``node``.

    Both trailing (inline) comments and leading comment lines are considered.
    The search starts at ``node`` and climbs through its parents, as resolved
    via ``ParentNodeProvider``, until it reaches a non-decorator node that has
    ``leading_lines``, the module itself, or a node with no recorded parent.
    Only comments whose position ends on or before the last line of ``node``
    are yielded.
    """
    parents = metadata.resolve(ParentNodeProvider)
    positions = metadata.resolve(PositionProvider)
    last_line = positions[node].end.line

    def from_whitespace(
        ws: Optional[TrailingWhitespace],
    ) -> Generator[Comment, None, None]:
        # a single piece of trailing whitespace carries at most one comment
        if ws and ws.comment:
            yield ws.comment

    def from_lines(
        lines: Optional[Sequence[EmptyLine]],
    ) -> Generator[Comment, None, None]:
        # each empty line may carry its own comment
        if lines is None:
            return
        for entry in lines:
            if entry.comment:
                yield entry.comment

    def walk(current: CSTNode) -> Generator[Comment, None, None]:
        while not isinstance(current, Module):
            # the trailing whitespace lives either on the node itself or, for
            # compound statements, on the suite that forms the node's body
            trailing = getattr(current, "trailing_whitespace", None)
            if trailing is None:
                suite = getattr(current, "body", None)
                if isinstance(suite, SimpleStatementSuite):
                    trailing = suite.trailing_whitespace
                elif isinstance(suite, IndentedBlock):
                    trailing = suite.header
            yield from from_whitespace(trailing)

            # comment following the comma after an element
            separator = getattr(current, "comma", None)
            if isinstance(separator, Comma):
                yield from from_whitespace(
                    getattr(separator.whitespace_after, "first_line", None)
                )

            # comment sitting just before a closing bracket
            closing = getattr(current, "rbracket", None)
            if closing is not None:
                yield from from_whitespace(
                    getattr(closing.whitespace_before, "first_line", None)
                )

            # comment lines directly after an opening bracket
            opening = getattr(current, "lbracket", None)
            if opening is not None:
                yield from from_lines(
                    getattr(opening.whitespace_after, "empty_lines", None)
                )

            # comment lines between the last decorator and the definition
            yield from from_lines(getattr(current, "lines_after_decorators", None))

            leading = getattr(current, "leading_lines", None)
            if leading is not None:
                yield from from_lines(leading)
                if not isinstance(current, Decorator):
                    # a node with leading_lines is as far up as we need to go,
                    # whether or not any of those lines held a comment
                    break

            parent = parents.get(current)
            if parent is None:
                break
            current = parent

        # Comments at the very top of the file belong to the module header, not
        # to the first statement's leading_lines, so include the header when the
        # search ended on the module or on the first statement of the module.
        if isinstance(current, Module):
            yield from from_lines(current.header)
        else:
            owner = parents.get(current)
            if isinstance(owner, Module) and owner.body and owner.body[0] == current:
                yield from from_lines(owner.header)

    # drop anything located after the line on which the original node ends
    for comment in walk(node):
        if positions[comment].end.line <= last_line:
            yield comment

src/fixit/engine.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, path: Path, source: FileContent) -> None:
5454
self.source = source
5555
self.module: Module = parse_module(source)
5656
self.timings: Timings = defaultdict(lambda: 0)
57+
self.wrapper = MetadataWrapper(self.module)
5758

5859
def collect_violations(
5960
self,
@@ -79,10 +80,14 @@ def visit_hook(name: str) -> Iterator[None]:
7980
self.timings[name] += duration_us
8081

8182
metadata_cache: Mapping[ProviderT, object] = {}
83+
self.wrapper = MetadataWrapper(
84+
self.module, unsafe_skip_copy=True, cache=metadata_cache
85+
)
8286
needs_repo_manager: Set[ProviderT] = set()
8387

8488
for rule in rules:
8589
rule._visit_hook = visit_hook
90+
rule._metadata_wrapper = self.wrapper
8691
for provider in rule.get_inherited_dependencies():
8792
if provider.gen_cache is not None:
8893
# TODO: find a better way to declare this requirement in LibCST
@@ -95,12 +100,11 @@ def visit_hook(name: str) -> Iterator[None]:
95100
providers=needs_repo_manager,
96101
)
97102
repo_manager.resolve_cache()
98-
metadata_cache = repo_manager.get_cache_for_path(config.path.as_posix())
103+
self.wrapper._cache = repo_manager.get_cache_for_path(
104+
config.path.as_posix()
105+
)
99106

100-
wrapper = MetadataWrapper(
101-
self.module, unsafe_skip_copy=True, cache=metadata_cache
102-
)
103-
wrapper.visit_batched(rules)
107+
self.wrapper.visit_batched(rules)
104108
count = 0
105109
for rule in rules:
106110
for violation in rule._violations:

src/fixit/ftypes.py

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import platform
77
import re
88
from dataclasses import dataclass, field
9+
from enum import Enum
910
from pathlib import Path
1011
from typing import (
1112
Any,
@@ -65,6 +66,11 @@ class Valid:
6566
code: str
6667

6768

69+
class LintIgnoreStyle(Enum):
    """
    Closed set of lint-ignore styles: ``fixme`` and ``ignore``.

    Each member's value is the member's own name as a lowercase string, so a
    member can be looked up from that string, e.g. ``LintIgnoreStyle("fixme")``.
    """

    # NOTE(review): presumably these mirror the directive keyword captured by
    # LintIgnoreRegex — confirm against the regex definition.
    fixme = "fixme"
    ignore = "ignore"
72+
73+
6874
LintIgnoreRegex = re.compile(
6975
r"""
7076
\#\s* # leading hash and whitespace

src/fixit/rule.py

+6-102
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,9 @@
77

88
import functools
99
from dataclasses import replace
10-
from typing import (
11-
ClassVar,
12-
Collection,
13-
Generator,
14-
List,
15-
Mapping,
16-
Optional,
17-
Sequence,
18-
Set,
19-
Union,
20-
)
10+
from typing import ClassVar, Collection, List, Mapping, Optional, Set, Union
2111

22-
from libcst import (
23-
BaseSuite,
24-
BatchableCSTVisitor,
25-
Comma,
26-
CSTNode,
27-
Decorator,
28-
EmptyLine,
29-
IndentedBlock,
30-
LeftSquareBracket,
31-
Module,
32-
RightSquareBracket,
33-
SimpleStatementSuite,
34-
TrailingWhitespace,
35-
)
12+
from libcst import BatchableCSTVisitor, CSTNode, MetadataWrapper, Module
3613
from libcst.metadata import (
3714
CodePosition,
3815
CodeRange,
@@ -41,6 +18,7 @@
4118
ProviderT,
4219
)
4320

21+
from .comments import node_comments
4422
from .ftypes import (
4523
Invalid,
4624
LintIgnoreRegex,
@@ -115,81 +93,7 @@ def __str__(self) -> str:
11593
return f"{self.__class__.__module__}:{self.__class__.__name__}"
11694

11795
_visit_hook: Optional[VisitHook] = None
118-
119-
def node_comments(self, node: CSTNode) -> Generator[str, None, None]:
120-
"""
121-
Yield all comments associated with the given node.
122-
123-
Includes comments from both leading comments and trailing inline comments.
124-
"""
125-
while not isinstance(node, Module):
126-
# trailing_whitespace can either be a property of the node itself, or in
127-
# case of blocks, be part of the block's body element
128-
tw: Optional[TrailingWhitespace] = getattr(
129-
node, "trailing_whitespace", None
130-
)
131-
if tw is None:
132-
body: Optional[BaseSuite] = getattr(node, "body", None)
133-
if isinstance(body, SimpleStatementSuite):
134-
tw = body.trailing_whitespace
135-
elif isinstance(body, IndentedBlock):
136-
tw = body.header
137-
138-
if tw and tw.comment:
139-
yield tw.comment.value
140-
141-
comma: Optional[Comma] = getattr(node, "comma", None)
142-
if isinstance(comma, Comma):
143-
tw = getattr(comma.whitespace_after, "first_line", None)
144-
if tw and tw.comment:
145-
yield tw.comment.value
146-
147-
rb: Optional[RightSquareBracket] = getattr(node, "rbracket", None)
148-
if rb is not None:
149-
tw = getattr(rb.whitespace_before, "first_line", None)
150-
if tw and tw.comment:
151-
yield tw.comment.value
152-
153-
el: Optional[Sequence[EmptyLine]] = None
154-
lb: Optional[LeftSquareBracket] = getattr(node, "lbracket", None)
155-
if lb is not None:
156-
el = getattr(lb.whitespace_after, "empty_lines", None)
157-
if el is not None:
158-
for line in el:
159-
if line.comment:
160-
yield line.comment.value
161-
162-
el = getattr(node, "lines_after_decorators", None)
163-
if el is not None:
164-
for line in el:
165-
if line.comment:
166-
yield line.comment.value
167-
168-
ll: Optional[Sequence[EmptyLine]] = getattr(node, "leading_lines", None)
169-
if ll is not None:
170-
for line in ll:
171-
if line.comment:
172-
yield line.comment.value
173-
if not isinstance(node, Decorator):
174-
# stop looking once we've gone up far enough for leading_lines,
175-
# even if there are no comment lines here at all
176-
break
177-
178-
node = self.get_metadata(ParentNodeProvider, node)
179-
180-
# comments at the start of the file are part of the module header rather than
181-
# part of the first statement's leading_lines, so we need to look there in case
182-
# the reported node is part of the first statement.
183-
if isinstance(node, Module):
184-
for line in node.header:
185-
if line.comment:
186-
yield line.comment.value
187-
else:
188-
parent = self.get_metadata(ParentNodeProvider, node)
189-
if isinstance(parent, Module) and parent.body and parent.body[0] == node:
190-
for line in parent.header:
191-
if line.comment:
192-
yield line.comment.value
96+
_metadata_wrapper: MetadataWrapper = MetadataWrapper(Module([]))
19397

19498
def ignore_lint(self, node: CSTNode) -> bool:
19599
"""
@@ -199,8 +103,8 @@ def ignore_lint(self, node: CSTNode) -> bool:
199103
current rule by name, or if the directives have no rule names listed.
200104
"""
201105
rule_names = (self.name, self.name.lower())
202-
for comment in self.node_comments(node):
203-
if match := LintIgnoreRegex.search(comment):
106+
for comment in node_comments(node, self._metadata_wrapper):
107+
if match := LintIgnoreRegex.search(comment.value):
204108
_style, names = match.groups()
205109

206110
# directive

src/fixit/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fixit.ftypes import Config, QualifiedRule
88

99
from fixit.testing import add_lint_rule_tests_to_module
10+
from .comments import CommentsTest
1011
from .config import ConfigTest
1112
from .engine import EngineTest
1213
from .ftypes import TypesTest

src/fixit/tests/comments.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from textwrap import dedent
7+
from unittest import TestCase
8+
9+
import libcst.matchers as m
10+
from libcst import MetadataWrapper, parse_module
11+
12+
from ..comments import node_comments
13+
14+
15+
class CommentsTest(TestCase):
    """Exercise ``node_comments`` against small parsed source fixtures."""

    def test_node_comments(self) -> None:
        """
        Compare the comments reported for matched nodes with expected values.

        Each fixture pairs a source snippet with ``(matcher, expected)`` cases.
        Every node found by the matcher must report exactly the expected
        comment strings, compared without regard to order.
        """
        for idx, (code, test_cases) in enumerate(
            (
                (
                    """
                    # module-level comment
                    print("hello") # trailing comment
                    """,
                    (
                        # matcher that finds nothing: no node, no comments
                        (m.Call(func=m.Name("something")), ()),
                        (m.Call(), ["# module-level comment", "# trailing comment"]),
                    ),
                ),
                (
                    """
                    import sys

                    # leading comment
                    print("hello") # trailing comment
                    """,
                    ((m.Call(), ["# leading comment", "# trailing comment"]),),
                ),
                (
                    """
                    import sys

                    # leading comment
                    @alpha # first decorator comment
                    # between-decorator comment
                    @beta # second decorator comment
                    # after-decorator comment
                    class Foo: # trailing comment
                        pass
                    """,
                    (
                        (
                            m.ClassDef(),
                            [
                                "# leading comment",
                                "# after-decorator comment",
                                "# trailing comment",
                            ],
                        ),
                        (
                            m.Decorator(decorator=m.Name("alpha")),
                            ["# leading comment", "# first decorator comment"],
                        ),
                    ),
                ),
            ),
            start=1,
        ):
            code = dedent(code)
            module = parse_module(code)
            # skip the copy so nodes found in `module` are the wrapper's nodes
            wrapper = MetadataWrapper(module, unsafe_skip_copy=True)
            for idx2, (matcher, expected) in enumerate(test_cases):
                with self.subTest(f"node comments {idx}-{chr(ord('a')+idx2)}"):
                    nodes = m.findall(module, matcher)
                    if expected:
                        # without this guard a matcher that finds nothing would
                        # skip the loop below and pass without checking anything
                        self.assertTrue(nodes, "matcher found no nodes in fixture")
                    for node in nodes:
                        comments = [c.value for c in node_comments(node, wrapper)]
                        self.assertEqual(sorted(expected), sorted(comments))

0 commit comments

Comments
 (0)