Commit 812cd50
Better eval
1 parent: 3379892

2 files changed: 94 additions & 8 deletions

squeez/training/evaluate.py (77 additions & 8 deletions)
@@ -104,6 +104,74 @@ def compute_partial_overlap(predicted: list[str], reference: list[str]) -> float
     return round(matched_chars / total_chars, 4) if total_chars > 0 else 0.0
 
 
+def _line_overlap_score(pred_line: str, ref_line: str) -> float:
+    """Compute a symmetric overlap score for two lines."""
+    if not pred_line or not ref_line:
+        return 0.0
+    if pred_line == ref_line:
+        return 1.0
+    if pred_line in ref_line or ref_line in pred_line:
+        return min(len(pred_line), len(ref_line)) / max(len(pred_line), len(ref_line))
+
+    pred_bigrams = (
+        {pred_line[i : i + 2] for i in range(len(pred_line) - 1)}
+        if len(pred_line) > 1
+        else {pred_line}
+    )
+    ref_bigrams = (
+        {ref_line[i : i + 2] for i in range(len(ref_line) - 1)} if len(ref_line) > 1 else {ref_line}
+    )
+    if not pred_bigrams or not ref_bigrams:
+        return 0.0
+
+    overlap = len(pred_bigrams & ref_bigrams)
+    precision = overlap / len(pred_bigrams)
+    recall = overlap / len(ref_bigrams)
+    if precision + recall == 0:
+        return 0.0
+    return 2 * precision * recall / (precision + recall)
+
+
+def compute_fuzzy_span_metrics(
+    predicted: list[str],
+    reference: list[str],
+    threshold: float = 0.5,
+) -> dict[str, float]:
+    """Compute one-to-one fuzzy line overlap metrics at a fixed threshold."""
+    if not reference and not predicted:
+        return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
+    if not reference or not predicted:
+        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
+
+    candidate_pairs: list[tuple[float, int, int]] = []
+    for pred_idx, pred_line in enumerate(predicted):
+        for ref_idx, ref_line in enumerate(reference):
+            score = _line_overlap_score(pred_line, ref_line)
+            if score >= threshold:
+                candidate_pairs.append((score, pred_idx, ref_idx))
+
+    matched_pred: set[int] = set()
+    matched_ref: set[int] = set()
+    tp = 0
+
+    for score, pred_idx, ref_idx in sorted(candidate_pairs, reverse=True):
+        del score
+        if pred_idx in matched_pred or ref_idx in matched_ref:
+            continue
+        matched_pred.add(pred_idx)
+        matched_ref.add(ref_idx)
+        tp += 1
+
+    precision = tp / len(predicted) if predicted else 0.0
+    recall = tp / len(reference) if reference else 0.0
+    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
+    return {
+        "precision": round(precision, 4),
+        "recall": round(recall, 4),
+        "f1": round(f1, 4),
+    }
+
+
 def compute_empty_accuracy(predicted: list[str], reference: list[str]) -> dict[str, float | str]:
     """Check if model correctly predicts empty vs non-empty.
 
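Two sanity checks on the new scorer, worked out by hand from the branches above (the log lines are invented examples, not project data). When neither line contains the other, the bigram branch is a character-bigram F1, which over the two unique-bigram sets P and R equals 2|P ∩ R| / (|P| + |R|):

    from squeez.training.evaluate import _line_overlap_score, compute_fuzzy_span_metrics

    # Containment branch: the score is the length ratio, 21/28 = 0.75.
    assert _line_overlap_score("foo failed at line 12",
                               "ERROR: foo failed at line 12") == 0.75

    # Bigram branch: 14 shared bigrams, 17 and 15 unique per side,
    # so the score is 2 * 14 / (17 + 15) = 0.875.
    score = _line_overlap_score("connection timed out", "connection timeout")
    assert round(score, 4) == 0.875

    # Greedy one-to-one matching at the default threshold: one of two
    # predicted lines pairs off, so precision 0.5, recall 1.0.
    metrics = compute_fuzzy_span_metrics(["connection timed out", "retrying"],
                                         ["connection timeout"])
    assert metrics == {"precision": 0.5, "recall": 1.0, "f1": 0.6667}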
@@ -226,6 +294,9 @@ def evaluate_model(
         "span_recall": [],
         "span_f1": [],
         "exact_match": [],
+        "fuzzy_span_precision": [],
+        "fuzzy_span_recall": [],
+        "fuzzy_span_f1": [],
         "partial_overlap": [],
         "empty_accuracy": [],
         "rouge_l": [],
@@ -257,30 +328,24 @@ def evaluate_sample(sample: dict) -> dict:
 
         # Span metrics
         span = compute_span_metrics(pred_lines, ref_lines)
-        all_metrics["span_precision"].append(span["precision"])
-        all_metrics["span_recall"].append(span["recall"])
-        all_metrics["span_f1"].append(span["f1"])
-        all_metrics["exact_match"].append(span["exact_match"])
+        fuzzy = compute_fuzzy_span_metrics(pred_lines, ref_lines, threshold=0.5)
 
         # Partial overlap
         partial = compute_partial_overlap(pred_lines, ref_lines)
-        all_metrics["partial_overlap"].append(partial)
 
         # Empty accuracy
         empty = compute_empty_accuracy(pred_lines, ref_lines)
-        all_metrics["empty_accuracy"].append(empty["correct"])
-        empty_confusion[empty["category"]] += 1
 
         # ROUGE-L on concatenated text
         pred_text = "\n".join(pred_lines)
         ref_text = "\n".join(ref_lines)
         rouge = compute_rouge_l(pred_text, ref_text)
-        all_metrics["rouge_l"].append(rouge)
 
         # Compression
         compression = compute_compression_ratio(tool_output, pred_text)
         return {
             "span": span,
+            "fuzzy": fuzzy,
             "partial": partial,
             "empty": empty,
             "rouge": rouge,
@@ -289,6 +354,7 @@ def evaluate_sample(sample: dict) -> dict:
 
     def record_result(result: dict) -> None:
         span = result["span"]
+        fuzzy = result["fuzzy"]
        partial = result["partial"]
         empty = result["empty"]
         rouge = result["rouge"]
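For reference, the shape record_result now expects, inferred from the unpacking here and the return statement in the previous hunk. The values and the empty "category" string are illustrative, and the compression entry of the result dict is cut off by both hunk boundaries, so its key name is not visible in this diff:

    result = {
        "span": {"precision": 1.0, "recall": 1.0, "f1": 1.0, "exact_match": 1.0},
        "fuzzy": {"precision": 1.0, "recall": 1.0, "f1": 1.0},
        "partial": 1.0,
        "empty": {"correct": 1.0, "category": "both_non_empty"},  # category label assumed
        "rouge": 1.0,
    }
    record_result(result)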
@@ -298,6 +364,9 @@ def record_result(result: dict) -> None:
         all_metrics["span_recall"].append(span["recall"])
         all_metrics["span_f1"].append(span["f1"])
         all_metrics["exact_match"].append(span["exact_match"])
+        all_metrics["fuzzy_span_precision"].append(fuzzy["precision"])
+        all_metrics["fuzzy_span_recall"].append(fuzzy["recall"])
+        all_metrics["fuzzy_span_f1"].append(fuzzy["f1"])
         all_metrics["partial_overlap"].append(partial)
         all_metrics["empty_accuracy"].append(empty["correct"])
         empty_confusion[empty["category"]] += 1
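A consequence of this refactor: evaluate_sample no longer mutates all_metrics or empty_confusion, so scoring and recording are decoupled. A sketch of one way to use that split, assuming a samples iterable is in scope (the diff itself does not show the call site):

    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical call site: score samples concurrently, then fold the
    # results into the shared lists from a single thread.
    with ThreadPoolExecutor() as pool:
        for result in pool.map(evaluate_sample, samples):
            record_result(result)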

tests/test_extractor.py (17 additions & 0 deletions)
@@ -191,6 +191,23 @@ def test_span_metrics(self):
         metrics = compute_span_metrics([], [])
         assert metrics["exact_match"] == 1.0
 
+    def test_fuzzy_span_metrics(self):
+        from squeez.training.evaluate import compute_fuzzy_span_metrics
+
+        metrics = compute_fuzzy_span_metrics(
+            ["ERROR: foo failed at line 12"],
+            ["foo failed at line 12"],
+            threshold=0.5,
+        )
+        assert metrics["precision"] == 1.0
+        assert metrics["recall"] == 1.0
+        assert metrics["f1"] == 1.0
+
+        metrics = compute_fuzzy_span_metrics(
+            ["completely different"], ["foo failed"], threshold=0.5
+        )
+        assert metrics["f1"] == 0.0
+
     def test_empty_accuracy(self):
         from squeez.training.evaluate import compute_empty_accuracy
 
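One boundary these tests leave open, should they grow: the containment branch still has to clear the threshold, so a short fragment of a much longer reference line is rejected at 0.5. A sketch of such a case (not part of the commit):

    metrics = compute_fuzzy_span_metrics(["foo"], ["foo failed at line 12"], threshold=0.5)
    assert metrics["f1"] == 0.0  # containment score 3/21 falls below 0.5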