Skip to content

Commit 3f2445d

Browse files
smahdavi4Kipok
andauthored
Improve Answer-Judge metrics (#507)
Co-authored-by: Igor Gitman <igitman@nvidia.com>
1 parent 6dfb7d3 commit 3f2445d

5 files changed

Lines changed: 197 additions & 58 deletions

File tree

nemo_skills/evaluation/metrics/answer_judgement_metrics.py

Lines changed: 108 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,49 @@
1414

1515
from nemo_skills.evaluation.metrics.base import BaseMetrics
1616
from nemo_skills.evaluation.metrics.utils import is_correct_judgement
17+
from functools import partial
18+
from collections import defaultdict
1719

20+
is_correct_judgement_or_none = partial(is_correct_judgement, return_none=True)
1821

1922
class AnswerJudgementMetrics(BaseMetrics):
23+
def __init__(self):
24+
super().__init__()
25+
# Store individual TP/FP/FN/TN values as N x K matrix (N datapoints, K samples each)
26+
self.total_positives = 0
27+
self.individual_metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
28+
29+
def reset(self):
30+
super().reset()
31+
self.individual_metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
32+
2033
def _get_score_dict(self, prediction: dict) -> dict[str, bool | int | float]:
21-
gt_judgement = is_correct_judgement(prediction['expected_judgement'])
22-
pred_judgement = is_correct_judgement(prediction['judgement'])
34+
gt_judgement = is_correct_judgement_or_none(prediction['expected_judgement'])
35+
pred_judgement = is_correct_judgement_or_none(prediction['judgement'])
2336

2437
return {'correct_judgements': gt_judgement == pred_judgement}
2538

39+
def _store_individual_metrics(self, agg_key, pred_judgement, gt_judgement, sample_idx=0):
40+
"""Store individual TP/FP/FN/TN values in N x K matrix structure."""
41+
is_fp = pred_judgement is True and gt_judgement is False
42+
is_fn = pred_judgement is False and gt_judgement is True
43+
is_tp = pred_judgement is True and gt_judgement is True
44+
is_tn = pred_judgement is False and gt_judgement is False
45+
46+
# Store in N x K matrix: [datapoint_idx][sample_idx]
47+
# This is hacky, but the only way to access the datapoint_idx
48+
datapoint_idx = self.total - 1
49+
self.individual_metrics[agg_key][datapoint_idx][sample_idx] = {
50+
'tp': float(is_tp),
51+
'fp': float(is_fp),
52+
'fn': float(is_fn),
53+
'tn': float(is_tn)
54+
}
55+
2656
def _update_fp_fn(self, metrics_dict, pred_judgement, gt_judgement, divide_by=1):
2757
is_fp = pred_judgement is True and gt_judgement is False
2858
is_fn = pred_judgement is False and gt_judgement is True
59+
2960
metrics_dict['false_positives'] += float(is_fp) / divide_by
3061
metrics_dict['false_negatives'] += float(is_fn) / divide_by
3162

@@ -42,8 +73,9 @@ def _update_score_metrics_for_majority(
4273
):
4374
assert score_method == 'correct_judgements'
4475
# expected answer is always the same for all predictions, so just take the first one
45-
gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
76+
gt_judgement = is_correct_judgement_or_none(predictions[0]['expected_judgement'])
4677
self._update_fp_fn(eval_dict[f"majority@{k}"], majority_answer, gt_judgement)
78+
self._store_individual_metrics(f"majority@{k}", majority_answer, gt_judgement)
4779

4880
def _update_score_metrics_for_pass(
4981
self,
@@ -57,16 +89,22 @@ def _update_score_metrics_for_pass(
5789
):
5890
assert score_method == 'correct_judgements'
5991
# expected answer is always the same for all predictions, so just take the first one
60-
gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
61-
pred_judgement = is_correct_judgement(predictions[0]['judgement'])
62-
# if pass is not correct, means all predictions are the same and wrong
63-
if not pass_score:
64-
self._update_fp_fn(eval_dict[f"pass@{k}"], pred_judgement, gt_judgement)
92+
gt_judgement = is_correct_judgement_or_none(predictions[0]['expected_judgement'])
93+
pred_judgements = [is_correct_judgement_or_none(pred['judgement']) for pred in predictions[:k]]
94+
if gt_judgement in pred_judgements:
95+
pred_judgement = gt_judgement
96+
else:
97+
not_none_pred_judgements = [pred_judgement for pred_judgement in pred_judgements if pred_judgement is not None]
98+
pred_judgement = not_none_pred_judgements[0] if not_none_pred_judgements else None
6599

66-
for pred in predictions[:k]:
67-
gt_judgement = is_correct_judgement(pred['expected_judgement'])
68-
pred_judgement = is_correct_judgement(pred['judgement'])
100+
self._update_fp_fn(eval_dict[f"pass@{k}"], pred_judgement, gt_judgement)
101+
self._store_individual_metrics(f"pass@{k}", pred_judgement, gt_judgement)
102+
103+
for sample_idx, pred in enumerate(predictions[:k]):
104+
gt_judgement = is_correct_judgement_or_none(pred['expected_judgement'])
105+
pred_judgement = is_correct_judgement_or_none(pred['judgement'])
69106
self._update_fp_fn(eval_dict[f"pass@1[{k}]"], pred_judgement, gt_judgement, divide_by=k)
107+
self._store_individual_metrics(f"pass@1[{k}]", pred_judgement, gt_judgement, sample_idx)
70108

71109
def update(self, predictions):
72110
"""Updating the evaluation results with the current element.
@@ -76,12 +114,69 @@ def update(self, predictions):
76114
The content of the file is benchmark specific.
77115
"""
78116
super().update(predictions)
79-
predicted_answers = [is_correct_judgement(pred['judgement']) for pred in predictions]
117+
self.total_positives += float(is_correct_judgement_or_none(predictions[0]['expected_judgement']) is True)
118+
predicted_answers = [is_correct_judgement_or_none(pred['judgement']) for pred in predictions]
80119
self._compute_pass_at_k(predictions=predictions, predicted_answers=predicted_answers)
81120
self._compute_majority_at_k(predictions=predictions, predicted_answers=predicted_answers)
82121

122+
def _compute_precision_recall_f1(self, datapoint_metrics):
123+
"""Compute unbiased precision, recall, F1 by averaging over K samples."""
124+
# Find the maximum number of samples K across all datapoints
125+
max_k = max(len(sample_metrics) for sample_metrics in datapoint_metrics.values())
126+
127+
# Compute metrics for each of the K samples, then average across K
128+
sample_precision_values = []
129+
sample_recall_values = []
130+
sample_f1_values = []
131+
132+
for sample_idx in range(max_k):
133+
# Aggregate TP, FP, FN across all N datapoints for sample k
134+
total_tp, total_fp, total_fn = 0, 0, 0
135+
136+
for sample_metrics in datapoint_metrics.values():
137+
metrics = sample_metrics[sample_idx]
138+
total_tp += metrics['tp']
139+
total_fp += metrics['fp']
140+
total_fn += metrics['fn']
141+
142+
# Compute precision for sample k
143+
if total_tp + total_fp > 0:
144+
sample_precision = total_tp / (total_tp + total_fp)
145+
else:
146+
sample_precision = 1.0
147+
sample_precision_values.append(sample_precision)
148+
149+
# Compute recall for sample k
150+
if self.total_positives > 0:
151+
sample_recall = total_tp / self.total_positives
152+
else:
153+
sample_recall = 1.0
154+
sample_recall_values.append(sample_recall)
155+
156+
# Compute F1 for sample k
157+
if sample_precision + sample_recall > 0:
158+
sample_f1 = 2 * (sample_precision * sample_recall) / (sample_precision + sample_recall)
159+
else:
160+
sample_f1 = 0.0
161+
sample_f1_values.append(sample_f1)
162+
163+
# Average across all K samples
164+
return {
165+
'precision': 100 * sum(sample_precision_values) / max_k,
166+
'recall': 100 * sum(sample_recall_values) / max_k,
167+
'f1': 100 * sum(sample_f1_values) / max_k,
168+
}
169+
83170
def get_metrics(self):
84171
# renaming no_answer to invalid_judgements
85172
for agg_metric_dict in self.eval_dict.values():
86173
agg_metric_dict["invalid_judgements"] = agg_metric_dict.pop("no_answer")
87-
return super().get_metrics()
174+
175+
metrics_dict = super().get_metrics()
176+
177+
# Compute unbiased precision, recall, F1 by averaging over K samples
178+
for agg_key, datapoint_metrics in self.individual_metrics.items():
179+
if agg_key in metrics_dict:
180+
metrics_dict[agg_key].update(self._compute_precision_recall_f1(datapoint_metrics))
181+
182+
return metrics_dict

nemo_skills/evaluation/metrics/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def read_predictions(predictions, line_idx, file_handles):
3636
def is_correct_judgement(judgement, return_none=False) -> Union[bool, None]:
3737
if 'Judgement:' in judgement:
3838
verdict = judgement.split('Judgement:')[-1].strip()
39-
if verdict.lower() == 'yes':
39+
if verdict.lower().startswith('yes'):
4040
return True
41-
elif verdict.lower() == 'no':
41+
elif verdict.lower().startswith('no'):
4242
return False
4343

4444
if return_none:

tests/data/eval_outputs/eval-results/metrics.json-test

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,88 +5,121 @@
55
"avg_tokens": 189,
66
"correct_judgements": 41.666666666666664,
77
"false_positives": 8.333333333333334,
8-
"false_negatives": 50.0,
9-
"invalid_judgements": 0.0
8+
"false_negatives": 33.333333333333336,
9+
"invalid_judgements": 16.666666666666668,
10+
"precision": 66.66666666666666,
11+
"recall": 28.57142857142857,
12+
"f1": 40.0
1013
},
1114
"pass@1[1]": {
1215
"num_entries": 12,
1316
"avg_tokens": 189,
1417
"correct_judgements": 41.666666666666664,
1518
"false_positives": 8.333333333333334,
16-
"false_negatives": 50.0,
17-
"invalid_judgements": 0.0
19+
"false_negatives": 33.333333333333336,
20+
"invalid_judgements": 16.666666666666668,
21+
"precision": 66.66666666666666,
22+
"recall": 28.57142857142857,
23+
"f1": 40.0
1824
},
1925
"pass@2": {
2026
"num_entries": 12,
2127
"avg_tokens": 189,
22-
"correct_judgements": 58.333333333333336,
23-
"false_positives": 0.0,
24-
"false_negatives": 41.666666666666664,
25-
"invalid_judgements": 0.0
28+
"correct_judgements": 50.0,
29+
"false_positives": 8.333333333333334,
30+
"false_negatives": 33.333333333333336,
31+
"invalid_judgements": 8.333333333333334,
32+
"precision": 75.0,
33+
"recall": 42.857142857142854,
34+
"f1": 54.54545454545454
2635
},
2736
"pass@1[2]": {
2837
"num_entries": 12,
2938
"avg_tokens": 189,
30-
"correct_judgements": 41.666666666666664,
31-
"false_positives": 8.333333333333334,
32-
"false_negatives": 50.0,
33-
"invalid_judgements": 0.0
39+
"correct_judgements": 37.5,
40+
"false_positives": 16.666666666666668,
41+
"false_negatives": 25.0,
42+
"invalid_judgements": 20.833333333333332,
43+
"precision": 58.33333333333333,
44+
"recall": 35.71428571428571,
45+
"f1": 43.07692307692308
3446
},
3547
"pass@3": {
3648
"num_entries": 12,
3749
"avg_tokens": 189,
38-
"correct_judgements": 66.66666666666667,
39-
"false_positives": 0.0,
40-
"false_negatives": 33.333333333333336,
41-
"invalid_judgements": 0.0
50+
"correct_judgements": 58.333333333333336,
51+
"false_positives": 16.666666666666668,
52+
"false_negatives": 25.0,
53+
"invalid_judgements": 0.0,
54+
"precision": 66.66666666666666,
55+
"recall": 57.14285714285714,
56+
"f1": 61.53846153846153
4257
},
4358
"pass@1[3]": {
4459
"num_entries": 12,
4560
"avg_tokens": 189,
46-
"correct_judgements": 44.444444444444436,
47-
"false_positives": 5.5555555555555545,
48-
"false_negatives": 49.99999999999998,
49-
"invalid_judgements": 0.0
61+
"correct_judgements": 41.666666666666664,
62+
"false_positives": 19.444444444444443,
63+
"false_negatives": 22.222222222222218,
64+
"invalid_judgements": 16.666666666666664,
65+
"precision": 57.93650793650793,
66+
"recall": 42.857142857142854,
67+
"f1": 47.765567765567766
5068
},
5169
"pass@4": {
5270
"num_entries": 12,
5371
"avg_tokens": 189,
5472
"correct_judgements": 66.66666666666667,
55-
"false_positives": 0.0,
56-
"false_negatives": 33.333333333333336,
57-
"invalid_judgements": 0.0
73+
"false_positives": 8.333333333333334,
74+
"false_negatives": 25.0,
75+
"invalid_judgements": 0.0,
76+
"precision": 80.0,
77+
"recall": 57.14285714285714,
78+
"f1": 66.66666666666666
5879
},
5980
"pass@1[4]": {
6081
"num_entries": 12,
6182
"avg_tokens": 189,
6283
"correct_judgements": 43.75,
63-
"false_positives": 4.166666666666667,
64-
"false_negatives": 52.083333333333336,
65-
"invalid_judgements": 0.0
84+
"false_positives": 18.75,
85+
"false_negatives": 22.916666666666668,
86+
"invalid_judgements": 14.583333333333334,
87+
"precision": 58.45238095238094,
88+
"recall": 42.857142857142854,
89+
"f1": 48.324175824175825
6690
},
6791
"majority@2": {
6892
"num_entries": 12,
6993
"avg_tokens": 189,
70-
"correct_judgements": 41.666666666666664,
71-
"false_positives": 0.0,
72-
"false_negatives": 58.333333333333336,
73-
"invalid_judgements": 0.0
94+
"correct_judgements": 50.0,
95+
"false_positives": 8.333333333333334,
96+
"false_negatives": 33.333333333333336,
97+
"invalid_judgements": 8.333333333333334,
98+
"precision": 75.0,
99+
"recall": 42.857142857142854,
100+
"f1": 54.54545454545454
74101
},
75102
"majority@3": {
76103
"num_entries": 12,
77104
"avg_tokens": 189,
78105
"correct_judgements": 41.666666666666664,
79-
"false_positives": 0.0,
80-
"false_negatives": 58.333333333333336,
81-
"invalid_judgements": 0.0
106+
"false_positives": 25.0,
107+
"false_negatives": 33.333333333333336,
108+
"invalid_judgements": 0.0,
109+
"precision": 50.0,
110+
"recall": 42.857142857142854,
111+
"f1": 46.15384615384615
82112
},
83113
"majority@4": {
84114
"num_entries": 12,
85115
"avg_tokens": 189,
86-
"correct_judgements": 41.666666666666664,
87-
"false_positives": 0.0,
88-
"false_negatives": 58.333333333333336,
89-
"invalid_judgements": 0.0
116+
"correct_judgements": 58.333333333333336,
117+
"false_positives": 16.666666666666668,
118+
"false_negatives": 25.0,
119+
"invalid_judgements": 0.0,
120+
"precision": 66.66666666666666,
121+
"recall": 57.14285714285714,
122+
"f1": 61.53846153846153
90123
}
91124
},
92125
"arena-hard": {

tests/data/eval_outputs/summarize_results_output.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
Please see metrics.json for MT-bench per-category breakdown
2-
----------------------------------------------------- answer-judge -----------------------------------------------------
3-
evaluation_mode | num_entries | avg_tokens | correct_judgements | false_positives | false_negatives | invalid_judgements
4-
pass@1[4] | 12 | 189 | 43.75% | 4.17% | 52.08% | 0.00%
5-
majority@4 | 12 | 189 | 41.67% | 0.00% | 58.33% | 0.00%
6-
pass@4 | 12 | 189 | 66.67% | 0.00% | 33.33% | 0.00%
2+
-------------------------------------------------------------------- answer-judge --------------------------------------------------------------------
3+
evaluation_mode | num_entries | avg_tokens | correct_judgements | false_positives | false_negatives | invalid_judgements | precision | recall | f1
4+
pass@1[4] | 12 | 189 | 43.75% | 18.75% | 22.92% | 14.58% | 58.45% | 42.86% | 48.32%
5+
majority@4 | 12 | 189 | 58.33% | 16.67% | 25.00% | 0.00% | 66.67% | 57.14% | 61.54%
6+
pass@4 | 12 | 189 | 66.67% | 8.33% | 25.00% | 0.00% | 80.00% | 57.14% | 66.67%
77

88

99
----------------------------------------- arena-hard -----------------------------------------

tests/test_metrics.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,15 @@ def test_metrics(tmp_path):
6868
metrics = json.load(f)
6969
with open(metrics_ref_path, "r") as f:
7070
metrics_ref = json.load(f)
71-
assert metrics == metrics_ref, "metrics.json does not match metrics.json-test"
71+
72+
def check_metrics_equal(metrics1, metrics2, path=""):
73+
if isinstance(metrics1, dict) and isinstance(metrics2, dict):
74+
assert set(metrics1.keys()) == set(metrics2.keys()), f"Keys mismatch at {path}"
75+
for k in metrics1:
76+
check_metrics_equal(metrics1[k], metrics2[k], f"{path}.{k}")
77+
elif isinstance(metrics1, (int, float)) and isinstance(metrics2, (int, float)):
78+
assert abs(metrics1 - metrics2) < 1e-6, f"Value mismatch at {path}: {metrics1} != {metrics2}"
79+
else:
80+
assert metrics1 == metrics2, f"Type mismatch at {path}: {type(metrics1)} != {type(metrics2)}"
81+
82+
check_metrics_equal(metrics, metrics_ref)

0 commit comments

Comments
 (0)