
Commit 7e1ca6b

Add code review evaluation pipeline

1 parent 7f1400f

File tree

4 files changed: +713 -0 lines changed

bugbug/tools/code_review/scorer.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
from functools import cached_property
from logging import getLogger

import weave

from bugbug.tools.comment_matching.agent import CommentMatchingTool
from bugbug.tools.suggestion_filtering.agent import SuggestionFilteringTool

logger = getLogger(__name__)


class BasicMetricsScorer(weave.Scorer):
    """Score basic metrics: comment counts and error tracking."""

    @weave.op()
    def score(
        self,
        output: dict,
        ground_truth_comments: list[dict],
    ) -> dict:
        valid_comment_count = sum(
            comment["evaluation"] == "VALID" for comment in ground_truth_comments
        )
        invalid_comment_count = sum(
            comment["evaluation"] == "INVALID" for comment in ground_truth_comments
        )

        return {
            "generated_comment_count": len(output["comments"]),
            "ground_truth_valid_count": valid_comment_count,
            "ground_truth_invalid_count": invalid_comment_count,
            "ground_truth_total_count": len(ground_truth_comments),
            "successful": not output["error"],
        }

    def summarize(self, score_rows: list[dict]) -> dict:
        """Aggregate scores across all examples."""
        total_generated = sum(r["generated_comment_count"] for r in score_rows)
        total_gt_valid = sum(r["ground_truth_valid_count"] for r in score_rows)
        total_gt_invalid = sum(r["ground_truth_invalid_count"] for r in score_rows)
        total_gt = sum(r["ground_truth_total_count"] for r in score_rows)
        error_count = sum(not r["successful"] for r in score_rows)

        return {
            "total_generated_comments": total_generated,
            "total_ground_truth_valid": total_gt_valid,
            "total_ground_truth_invalid": total_gt_invalid,
            "total_ground_truth": total_gt,
            "avg_generated_per_diff": (
                total_generated / len(score_rows) if score_rows else 0
            ),
            "error_rate": error_count / len(score_rows) if score_rows else 0,
            "num_examples": len(score_rows),
        }


class LLMCommentMatchingScorer(weave.Scorer):
    """Score comment matching using LLM-based semantic comparison.

    This scorer uses an LLM to match generated comments against ground truth
    comments, calculating recall and precision metrics.
    """

    @cached_property
    def matching_tool(self):
        return CommentMatchingTool.create()

    @cached_property
    def filtering_tool(self):
        return SuggestionFilteringTool.create()

    @weave.op()
    def score(
        self,
        output: dict,
        ground_truth_comments: list[dict],
        diff_id: int,
    ) -> dict:
        generated_comments = output["comments"]

        retained_indices = set(
            self.filtering_tool.get_indices_of_retained_comments(generated_comments)
        )
        retained_comments = [
            c for i, c in enumerate(generated_comments) if i in retained_indices
        ]
        excluded_comments = [
            c for i, c in enumerate(generated_comments) if i not in retained_indices
        ]

        old_comments = [
            {"id": i, "content": c["comment"], "file": c["file_path"]}
            for i, c in enumerate(ground_truth_comments)
        ]

        new_comments = [
            {"id": i, "content": c.comment, "file": c.file}
            for i, c in enumerate(generated_comments)
        ]

        matches = self.matching_tool.run(
            old_comments=old_comments, new_comments=new_comments
        )

        seen_old: set[int] = set()
        seen_new: set[int] = set()
        matched_valid_retained = []
        matched_valid_excluded = []
        matched_invalid_retained = []
        matched_invalid_excluded = []

        for match in matches:
            old_idx = match.old_comment_id
            new_idx = match.new_comment_id

            if old_idx >= len(ground_truth_comments) or new_idx >= len(
                generated_comments
            ):
                continue

            # Validate file match
            gt_comment = ground_truth_comments[old_idx]
            gen_comment = generated_comments[new_idx]

            if gt_comment["file_path"] != gen_comment.file:
                logger.debug(
                    f"File mismatch for diff {diff_id}: "
                    f"{gt_comment['file_path']} != {gen_comment.file}"
                )
                continue

            seen_old.add(old_idx)
            seen_new.add(new_idx)

            is_retained = new_idx in retained_indices
            match_comments = {
                "ground_truth_comment": gt_comment,
                "generated_comment": gen_comment,
            }

            if gt_comment["evaluation"] == "VALID":
                if is_retained:
                    matched_valid_retained.append(match_comments)
                else:
                    matched_valid_excluded.append(match_comments)
            else:
                if is_retained:
                    matched_invalid_retained.append(match_comments)
                else:
                    matched_invalid_excluded.append(match_comments)

        unmatched_gt_valid = []
        unmatched_gt_invalid = []

        for i in range(len(ground_truth_comments)):
            if i in seen_old:
                continue

            comment = ground_truth_comments[i]
            evaluation = comment["evaluation"]
            if evaluation == "VALID":
                unmatched_gt_valid.append(comment)
            else:
                unmatched_gt_invalid.append(comment)

        unmatched_gen_retained = []
        unmatched_gen_excluded = []

        for i in range(len(generated_comments)):
            if i in seen_new:
                continue

            comment = new_comments[i]
            if i in retained_indices:
                unmatched_gen_retained.append(comment)
            else:
                unmatched_gen_excluded.append(comment)

        return {
            # Matched counts (derived from lists)
            "matched_valid_count": len(matched_valid_retained)
            + len(matched_valid_excluded),
            "matched_invalid_count": len(matched_invalid_retained)
            + len(matched_invalid_excluded),
            # Unmatched counts
            "unmatched_generated_count": len(unmatched_gen_retained)
            + len(unmatched_gen_excluded),
            "unmatched_ground_truth_valid_count": len(unmatched_gt_valid),
            "unmatched_ground_truth_invalid_count": len(unmatched_gt_invalid),
            # Unmatched details
            "unmatched_ground_truth_valid": unmatched_gt_valid,
            "unmatched_ground_truth_invalid": unmatched_gt_invalid,
            "unmatched_gen_retained": unmatched_gen_retained,
            "unmatched_gen_excluded": unmatched_gen_excluded,
            # Filtering metrics
            "filtering_retained_count": len(retained_comments),
            "filtering_excluded_count": len(excluded_comments),
            "filtering_retention_rate": (
                len(retained_comments) / len(generated_comments)
                if generated_comments
                else 0
            ),
            "filtering_retained_comments": retained_comments,
            "filtering_excluded_comments": excluded_comments,
            # Filtering x Matching breakdown (lists with details)
            "matched_valid_retained": matched_valid_retained,
            "matched_valid_excluded": matched_valid_excluded,
            "matched_invalid_retained": matched_invalid_retained,
            "matched_invalid_excluded": matched_invalid_excluded,
        }

    def summarize(self, score_rows: list[dict]) -> dict:
        total_matched_valid = sum(r["matched_valid_count"] for r in score_rows)
        total_matched_invalid = sum(r["matched_invalid_count"] for r in score_rows)
        total_unmatched_gen = sum(r["unmatched_generated_count"] for r in score_rows)
        total_unmatched_gt_valid = sum(
            r["unmatched_ground_truth_valid_count"] for r in score_rows
        )
        total_unmatched_gt_invalid = sum(
            r["unmatched_ground_truth_invalid_count"] for r in score_rows
        )

        total_gt_valid = total_matched_valid + total_unmatched_gt_valid
        total_gt_invalid = total_matched_invalid + total_unmatched_gt_invalid

        # Filtering aggregates
        total_retained = sum(r["filtering_retained_count"] for r in score_rows)
        total_excluded = sum(r["filtering_excluded_count"] for r in score_rows)
        total_generated = total_retained + total_excluded

        # Filtering x Matching aggregates (use len() since values are lists)
        total_matched_valid_retained = sum(
            len(r["matched_valid_retained"]) for r in score_rows
        )
        total_matched_valid_excluded = sum(
            len(r["matched_valid_excluded"]) for r in score_rows
        )
        total_matched_invalid_retained = sum(
            len(r["matched_invalid_retained"]) for r in score_rows
        )
        total_matched_invalid_excluded = sum(
            len(r["matched_invalid_excluded"]) for r in score_rows
        )

        return {
            "total_matched_valid": total_matched_valid,
            "total_matched_invalid": total_matched_invalid,
            "total_unmatched_generated": total_unmatched_gen,
            "recall_valid": (
                total_matched_valid / total_gt_valid if total_gt_valid > 0 else 0
            ),
            "recall_invalid": (
                total_matched_invalid / total_gt_invalid if total_gt_invalid > 0 else 0
            ),
            "missed_valid_rate": (
                total_unmatched_gt_valid / total_gt_valid if total_gt_valid > 0 else 0
            ),
            "missed_invalid_rate": (
                total_unmatched_gt_invalid / total_gt_invalid
                if total_gt_invalid > 0
                else 0
            ),
            # Filtering summary metrics
            "total_filtering_retained": total_retained,
            "total_filtering_excluded": total_excluded,
            "overall_retention_rate": (
                total_retained / total_generated if total_generated > 0 else 0
            ),
            # Filtering x Matching summary
            "total_matched_valid_retained": total_matched_valid_retained,
            "total_matched_valid_excluded": total_matched_valid_excluded,
            "total_matched_invalid_retained": total_matched_invalid_retained,
            "total_matched_invalid_excluded": total_matched_invalid_excluded,
            # Derived rates
            "false_exclusion_rate": (
                total_matched_valid_excluded / total_matched_valid
                if total_matched_valid > 0
                else 0
            ),
            "true_exclusion_rate": (
                total_matched_invalid_excluded / total_matched_invalid
                if total_matched_invalid > 0
                else 0
            ),
        }
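
A quick worked example of how the derived rates in summarize() fall out of the matched/unmatched tallies. The counts below are made up for illustration and do not come from the commit:

# Hypothetical tallies across one evaluation run:
total_matched_valid = 8            # generated comments matched to VALID ground truth
total_unmatched_gt_valid = 2       # VALID ground-truth comments nothing matched
total_matched_valid_excluded = 2   # matched-valid comments the filter dropped

total_gt_valid = total_matched_valid + total_unmatched_gt_valid   # 10
recall_valid = total_matched_valid / total_gt_valid               # 0.8
missed_valid_rate = total_unmatched_gt_valid / total_gt_valid     # 0.2
false_exclusion_rate = total_matched_valid_excluded / total_matched_valid  # 0.25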

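For context, a minimal sketch of how these scorers might be wired into a Weave evaluation. The project name, dataset row, and review_model stub are hypothetical, not part of this commit; Weave passes dataset columns to score() by keyword, so each row carries ground_truth_comments and diff_id, and the model's return value arrives as output:

import asyncio

import weave

from bugbug.tools.code_review.scorer import (
    BasicMetricsScorer,
    LLMCommentMatchingScorer,
)

weave.init("code-review-eval")  # hypothetical project name

# Hypothetical dataset row; real rows would come from curated review data.
dataset = [
    {
        "diff_id": 12345,
        "ground_truth_comments": [
            {
                "comment": "This lock is never released on the error path.",
                "file_path": "netwerk/cache/CacheFile.cpp",
                "evaluation": "VALID",
            }
        ],
    }
]


@weave.op()
def review_model(diff_id: int) -> dict:
    # Stand-in for the real review tool; returns the shape both
    # scorers expect: {"comments": [...], "error": ...}.
    return {"comments": [], "error": None}


evaluation = weave.Evaluation(
    dataset=dataset,
    scorers=[BasicMetricsScorer(), LLMCommentMatchingScorer()],
)
asyncio.run(evaluation.evaluate(review_model))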