Skip to content

Commit cf2dc5d

Browse files
incorporate pr comments
1 parent c0df538 commit cf2dc5d

File tree

4 files changed

+100
-90
lines changed

4 files changed

+100
-90
lines changed

.github/workflows/relevant-warnings.yml

-6
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ jobs:
3030
pip install mypy==1.15.0
3131
pip install ruff==0.3.3
3232
33-
- name: Install Build Dependencies
34-
run: |
35-
sudo apt-get update
36-
wget http://archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb
37-
sudo dpkg -i libssl1.1_1.1.1f-1ubuntu2_amd64.deb
38-
3933
- name: Install Overwatch CLI
4034
run: |
4135
curl -o overwatch-cli https://overwatch.codecov.io/linux/cli

src/seer/automation/codegen/models.py

-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ class PrFile(BaseModel):
156156
changes: int
157157
sha: str
158158

159-
160159
class FilterWarningsRequest(BaseComponentRequest):
161160
warnings: list[StaticAnalysisWarning]
162161
pr_files: list[PrFile]

src/seer/automation/codegen/relevant_warnings_component.py

+69-60
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import bisect
12
import logging
23
import re
34
import textwrap
@@ -12,7 +13,7 @@
1213

1314
from seer.automation.agent.client import GeminiProvider, LlmClient
1415
from seer.automation.agent.embeddings import GoogleProviderEmbeddings
15-
from seer.automation.codebase.models import StaticAnalysisWarning
16+
from seer.automation.codebase.models import Location, StaticAnalysisWarning
1617
from seer.automation.codegen.codegen_context import CodegenContext
1718
from seer.automation.codegen.models import (
1819
AssociateWarningsWithIssuesOutput,
@@ -66,8 +67,35 @@ def _left_truncated_paths(path: Path, max_num_paths: int = 2) -> list[str]:
6667
result.append(Path(*parts).as_posix())
6768
return result
6869

69-
def _get_pr_changed_lines(self, pr_file: PrFile) -> list[int]:
70-
"""Returns the 1-indexed changed line numbers in the updated pr file.
70+
def _build_filepath_mapping(self, pr_files: list[PrFile]) -> dict[str, PrFile]:
71+
"""Build mapping of possible filepaths to PR files, including truncated variations."""
72+
filepath_to_pr_file: dict[str, PrFile] = {}
73+
for pr_file in pr_files:
74+
pr_path = Path(pr_file.filename)
75+
filepath_to_pr_file[pr_path.as_posix()] = pr_file
76+
for truncated in self._left_truncated_paths(pr_path, max_num_paths=1):
77+
filepath_to_pr_file[truncated] = pr_file
78+
return filepath_to_pr_file
79+
80+
def _is_warning_in_diff(
81+
self,
82+
warning: StaticAnalysisWarning,
83+
filepath_to_pr_file: dict[str, PrFile],
84+
) -> bool:
85+
matching_pr_files = self._get_matching_pr_files(warning, filepath_to_pr_file)
86+
warning_location = Location.from_encoded(warning.encoded_location)
87+
for pr_file in matching_pr_files:
88+
hunk_ranges = self._get_sorted_hunk_ranges(pr_file)
89+
if self._do_ranges_overlap(
90+
(int(warning_location.start_line), int(warning_location.end_line)),
91+
hunk_ranges,
92+
):
93+
return True
94+
95+
return False
96+
97+
def _get_sorted_hunk_ranges(self, pr_file: PrFile) -> list[tuple[int, int]]:
98+
"""Returns sorted tuples of 1-indexed line numbers (start_inclusive, end_exclusive) in the updated pr file.
7199
72100
Determined by parsing git diff hunk headers of the form:
73101
@@ -n,m +p,q @@ where:
@@ -83,46 +111,43 @@ def hello():
83111
+ print("world") # Line 3 is added
84112
print("goodbye")
85113
86-
This would return [1,2,3,4] since these are all the lines in the updated file
114+
@@ -20,3 +21,4 @@
115+
print("end")
116+
+ print("new end") # Line 22 is added
117+
return
118+
119+
This would return [(1,5), (21,25)] representing the modified file's hunk ranges.
87120
88121
Args:
89-
pr_file: PrFile object containing the patch/diff
122+
pr_file: PrFile object containing the patch/diff (sorted by line number)
90123
91124
Returns:
92-
List of 1-indexed line numbers in the updated file
125+
List of sorted tuples containing 1-indexed line numbers (start_inclusive, end_exclusive) in the updated file
93126
"""
94127
patch_lines = pr_file.patch.split("\n")
95-
changed_lines: list[int] = []
128+
hunk_ranges: list[tuple[int, int]] = []
96129

97130
for line in patch_lines:
98131
if line.startswith("@@"):
99132
try:
100133
match = re.match(r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@", line)
101134
if match:
102135
_, _, new_start, num_lines = map(int, match.groups())
103-
changed_lines.extend(range(new_start, new_start + num_lines))
136+
hunk_ranges.append((new_start, new_start + num_lines))
104137
except Exception:
105138
self.logger.warning(f"Could not parse hunk header: {line}")
106139
continue
107140

108-
return changed_lines
141+
return hunk_ranges
109142

110-
def _get_possible_pr_files(
111-
self, warning: StaticAnalysisWarning, pr_files: list[PrFile]
143+
def _get_matching_pr_files(
144+
self, warning: StaticAnalysisWarning, filepath_to_pr_file: dict[str, PrFile]
112145
) -> list[PrFile]:
113146
"""Find PR files that may match a warning's location.
114-
115147
This handles cases where the warning location and PR file paths may be specified differently:
116148
- With different numbers of parent directories
117149
- With or without a repo prefix
118150
- With relative vs absolute paths
119-
120-
Args:
121-
warning: The static analysis warning to check
122-
pr_files: List of PR files to match against
123-
124-
Returns:
125-
List of PR files that may match the warning's location
126151
"""
127152
filename = warning.encoded_location.split(":")[0]
128153
path = Path(filename)
@@ -135,62 +160,46 @@ def _get_possible_pr_files(
135160
f"Found `..` in the middle of path. Encoded location: {warning.encoded_location}"
136161
)
137162

138-
# Make possible variations of the warning's path
139163
warning_filepath_variations = {
140164
path.as_posix(),
141165
*self._left_truncated_paths(path, max_num_paths=2),
142166
}
143167

144-
# Make possible variations of the pr files' paths
145-
pr_file_by_filepath_variation: dict[str, PrFile] = {}
146-
for pr_file in pr_files:
147-
pr_path = Path(pr_file.filename)
148-
pr_file_by_filepath_variation[pr_path.as_posix()] = pr_file
149-
for truncated in self._left_truncated_paths(pr_path, max_num_paths=1):
150-
pr_file_by_filepath_variation[truncated] = pr_file
151-
152-
# Find all matching PR files
153-
matching_pr_files: list[PrFile] = []
154-
for filepath in warning_filepath_variations:
155-
if filepath in pr_file_by_filepath_variation:
156-
matching_pr_files.append(pr_file_by_filepath_variation[filepath])
157-
158-
if matching_pr_files:
159-
self.logger.debug(
160-
"Found matching PR files",
161-
extra={
162-
"warning_location": warning.encoded_location,
163-
"matching_files": [pf.filename for pf in matching_pr_files],
164-
},
165-
)
166-
167-
return matching_pr_files
168+
return [
169+
filepath_to_pr_file[filepath]
170+
for filepath in warning_filepath_variations & set(filepath_to_pr_file)
171+
]
168172

169-
def _is_warning_line_in_pr_files(
170-
self, warning: StaticAnalysisWarning, matching_pr_files: list[PrFile]
173+
def _do_ranges_overlap(
174+
self, warning_range: tuple[int, int], sorted_hunk_ranges: list[tuple[int, int]]
171175
) -> bool:
172-
# Encoded location format: "file:line:col"
173-
location_parts = warning.encoded_location.split(":")
174-
if len(location_parts) < 2:
175-
self.logger.warning(
176-
f"Invalid warning location format - missing line number: {warning.encoded_location}"
177-
)
176+
if not sorted_hunk_ranges or not warning_range:
178177
return False
179-
180-
warning_line = int(location_parts[1])
181-
return any(
182-
warning_line in self._get_pr_changed_lines(pr_file) for pr_file in matching_pr_files
178+
target_start, target_end = warning_range
179+
# Handle special case of single line warning by making end inclusive
180+
if target_start == target_end:
181+
target_end += 1
182+
index = bisect.bisect_left(sorted_hunk_ranges, (target_start,))
183+
return (index > 0 and sorted_hunk_ranges[index - 1][1] > target_start) or (
184+
index < len(sorted_hunk_ranges) and sorted_hunk_ranges[index][0] < target_end
183185
)
184186

185187
@observe(name="Codegen - Relevant Warnings - Filter Warnings Component")
186188
@ai_track(description="Codegen - Relevant Warnings - Filter Warnings Component")
187189
def invoke(self, request: FilterWarningsRequest) -> FilterWarningsOutput:
190+
filepath_to_pr_file = self._build_filepath_mapping(request.pr_files)
191+
188192
filtered_warnings: list[StaticAnalysisWarning] = []
189193
for warning in request.warnings:
190-
possible_pr_files = self._get_possible_pr_files(warning, request.pr_files)
191-
192-
if self._is_warning_line_in_pr_files(warning, possible_pr_files):
193-
filtered_warnings.append(warning)
194+
try:
195+
if self._is_warning_in_diff(warning, filepath_to_pr_file):
196+
filtered_warnings.append(warning)
197+
except Exception as e:
198+
self.logger.warning(
199+
f"Failed to evaluate warning, skipping: {warning.id} ({warning.encoded_location})",
200+
exc_info=e,
201+
)
202+
continue
194203

195204
return FilterWarningsOutput(warnings=filtered_warnings)
196205

tests/automation/codegen/test_relevant_warnings.py

+31-23
Original file line numberDiff line numberDiff line change
@@ -52,39 +52,47 @@ def test_bad_encoded_locations_cause_errors(self, component: FilterWarningsCompo
5252
ValueError,
5353
match=f"Found `..` in the middle of path. Encoded location: {warning.encoded_location}",
5454
):
55-
component._get_possible_pr_files(
55+
component._get_matching_pr_files(
5656
warning,
5757
[PrFile(filename="file1.py", patch="", status="modified", changes=1, sha="abc")],
5858
)
5959

60-
def test_get_changed_lines(self, component: FilterWarningsComponent):
61-
pr_file = PrFile(
62-
filename="hello.py",
63-
patch="""@@ -1,3 +1,4 @@
64-
def hello():
65-
print("hello")
66-
+ print("world")
67-
print("goodbye")""",
68-
status="modified",
69-
changes=10,
70-
sha="sha1",
71-
)
72-
assert component._get_pr_changed_lines(pr_file) == [1, 2, 3, 4]
73-
60+
def test_get_hunk_ranges(self, component: FilterWarningsComponent):
7461
pr_file = PrFile(
7562
filename="test.py",
76-
patch="""@@ -10,3 +10,6 @@
77-
class MyTest:
78-
def test_one(self):
79-
assert True
80-
+ def test_two(self):
81-
+ assert False
82-
+ def test_three(self):""",
63+
patch="""@@ -1,3 +1,4 @@
64+
def hello():
65+
print("hello")
66+
+ print("world") # Line 3 is added
67+
print("goodbye")
68+
69+
@@ -20,3 +21,4 @@
70+
print("end")
71+
+ print("new end") # Line 22 is added
72+
return""",
8373
status="modified",
8474
changes=15,
8575
sha="sha2",
8676
)
87-
assert component._get_pr_changed_lines(pr_file) == [10, 11, 12, 13, 14, 15]
77+
assert component._get_sorted_hunk_ranges(pr_file) == [(1, 5), (21, 25)]
78+
79+
def test_do_ranges_overlap(self, component: FilterWarningsComponent):
80+
# Test overlapping ranges
81+
assert component._do_ranges_overlap((1, 5), [(1, 5)]) # Exact match
82+
assert component._do_ranges_overlap((2, 4), [(1, 5)]) # Warning contained within hunk
83+
assert component._do_ranges_overlap((1, 3), [(2, 5)]) # Partial overlap at start
84+
assert component._do_ranges_overlap((4, 6), [(2, 5)]) # Partial overlap at end
85+
assert component._do_ranges_overlap((1, 6), [(2, 4)]) # Hunk contained within warning
86+
assert component._do_ranges_overlap((3, 6), [(1, 4), (5, 7)]) # Overlaps multiple hunks
87+
assert component._do_ranges_overlap((1, 1), [(1, 4), (5, 7)]) # Warning only has 1 line
88+
89+
# Test non-overlapping ranges
90+
assert not component._do_ranges_overlap((1, 2), [(3, 4)]) # Warning before hunk
91+
assert not component._do_ranges_overlap((5, 6), [(2, 4)]) # Warning after hunk
92+
assert not component._do_ranges_overlap((1, 2), []) # Empty hunks
93+
assert not component._do_ranges_overlap(
94+
(1, 2), [(10, 12), (20, 25)]
95+
) # No overlap with any hunks
8896

8997
class _TestInvokeTestCase(BaseModel):
9098
id: str

0 commit comments

Comments
 (0)