Commit 2839e25

Add pass@1[avg-of-N] support to ArenaMetrics
Signed-off-by: Jakub Slowikowski <jslowikowski@nvidia.com>
1 parent 589294c commit 2839e25

2 files changed

Lines changed: 153 additions & 73 deletions
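
In effect, for runs with N repeats per prompt, ArenaMetrics.get_metrics() now reports two aggregation modes instead of a single entry that was always labelled pass@1. A rough sketch of the resulting dictionary shape for N=5 (the keys and field names come from the diff below; the numeric values are made up, and common fields added by update_common_metrics, such as token counts, are omitted):

{
    "pass@5": {"num_entries": 50, "score": 12.3, "95_CI": (9.8, 14.9), "invalid_scores": 0},
    "pass@1[avg-of-5]": {"num_entries": 50, "score": 7.4, "invalid_scores": 2},
}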

nemo_skills/evaluation/metrics/arena_metrics.py

Lines changed: 120 additions & 71 deletions
@@ -14,9 +14,18 @@
 
 import re
 from collections import defaultdict
+from statistics import mean
 
 from nemo_skills.evaluation.metrics.base import BaseMetrics
 
+# Score-label preference for picking the best of N predictions per judgement direction.
+# "Best" = most candidate-favorable. The two preference orders are mirrored because the
+# judge prompts swap the A/B slot assignments to mitigate position bias:
+#   judgement-gen-base: A = candidate's answer, B = baseline's answer
+#   judgement-base-gen: A = baseline's answer, B = candidate's answer
+_GEN_BASE_PREFERENCE = ("A>>B", "A>B", "A=B", "B>A", "B>>A")
+_BASE_GEN_PREFERENCE = ("B>>A", "B>A", "A=B", "A>B", "A>>B")
+
 
 class ArenaMetrics(BaseMetrics):
     def __init__(self):
@@ -40,87 +49,127 @@ def get_incorrect_sample(self, prediction: dict) -> dict:
         prediction["judgement-base-gen"] = "Rating: [[B>>A]]"
         return prediction
 
+    @staticmethod
+    def _best_pair(prompt_pairs):
+        """Pick the most candidate-favorable label across all predictions, per direction."""
+        gen_base_pool = [pair[0] for pair in prompt_pairs]
+        base_gen_pool = [pair[1] for pair in prompt_pairs]
+        return [
+            next((s for s in _GEN_BASE_PREFERENCE if s in gen_base_pool), None),
+            next((s for s in _BASE_GEN_PREFERENCE if s in base_gen_pool), None),
+        ]
+
     def update(self, predictions):
-        """Updating the evaluation results with the current element.
+        """Store all per-prediction (gen-base, base-gen) score pairs for this prompt.
 
-        Args:
-            predictions (list[dict]): aggregated predictions across all generations.
-                The content of the file is benchmark specific.
+        Aggregation is deferred to get_metrics() so that both pass@N (best-of-N) and
+        pass@1[avg-of-N] can be derived from the same stored data.
         """
-        # this shouldn't do any heavy calculation, but just read the metric from existing json entry
-        # all the heavy lifting should be done in the evaluation script
         super().update(predictions)
-        self.scores.append([])
-        self.agg_mode = f"pass@{len(predictions)}"
-
-        # Track category for per-category scoring (defaults to None for v1 compatibility)
-        category = predictions[0].get("category")
-        self.categories.append(category)
-
-        if len(predictions) > 1:
-            judge_scores = [self._get_judge_score(elem["judgement-gen-base"]) for elem in predictions]
-            # adding the best score out of all the generations
-            possible_scores = ["A>>B", "A>B", "A=B", "B>A", "B>>A"]
-            for possible_score in possible_scores:
-                # picking the best available score
-                if any([score == possible_score for score in judge_scores]):
-                    self.scores[-1].append(possible_score)
-                    best_id = judge_scores.index(possible_score)
-                    self.lengths += predictions[best_id].get("num_generated_tokens", 0)
-                    break
-            else:
-                self.scores[-1].append(None)  # in case judge didn't generate a valid score
-
-            judge_scores = [self._get_judge_score(elem["judgement-base-gen"]) for elem in predictions]
-            # second score is grading swapped answers, so we iterate from the end
-            for possible_score in possible_scores[::-1]:
-                # picking the best available score
-                if any([score == possible_score for score in judge_scores]):
-                    self.scores[-1].append(possible_score)
-                    best_id = judge_scores.index(possible_score)
-                    self.lengths += predictions[best_id].get("num_generated_tokens", 0)
-                    break
-            else:
-                self.scores[-1].append(None)  # in case judge didn't generate a valid score
-        else:
-            self.lengths += predictions[0].get("num_generated_tokens", 0)
-            self.scores[-1] = [
-                self._get_judge_score(predictions[0]["judgement-gen-base"]),
-                self._get_judge_score(predictions[0]["judgement-base-gen"]),
+        self.per_prompt_scores.append(
+            [
+                (
+                    self._get_judge_score(p["judgement-gen-base"]),
+                    self._get_judge_score(p["judgement-base-gen"]),
+                )
+                for p in predictions
             ]
+        )
+        self.categories.append(predictions[0].get("category"))
 
     def get_metrics(self):
-        from nemo_skills.evaluation.evaluator.arena import get_aggregate_score
-
-        metrics_dict = {}
-
-        # Compute overall metrics
-        overall_metrics = {"num_entries": self.total}
-        overall_metrics.update(get_aggregate_score(self.scores))
-        self.update_common_metrics(overall_metrics)
+        n = self.max_k or 1
+        emit_categories = len(set(self.categories)) > 1
+
+        # pass@N (best-of-N): pick the most candidate-favorable label per direction across
+        # all N predictions per prompt, then run Elo on those 1-pair-per-prompt lists.
+        best_of_n = [self._best_pair(pairs) for pairs in self.per_prompt_scores]
+        metrics_dict = {f"pass@{n}": self._aggregate(best_of_n, emit_categories)}

+        # pass@1[avg-of-N]: N independent single-shot Elo bootstraps (one per repeat),
+        # averaged. Skipped for N==1 since avg-of-1 is degenerate with pass@1.
+        if n > 1:
+            per_repeat_aggs = [
+                self._aggregate(
+                    [list(pairs[r]) for pairs in self.per_prompt_scores],
+                    emit_categories,
+                )
+                for r in range(n)
+            ]
+            metrics_dict[f"pass@1[avg-of-{n}]"] = self._average_aggregations(per_repeat_aggs)
 
-        # Group scores by category for per-category metrics
-        category_scores = defaultdict(list)
-        for score, category in zip(self.scores, self.categories, strict=True):
-            category_scores[category].append(score)
+        return metrics_dict
 
-        # If we have multiple categories, compute per-category metrics
-        unique_categories = set(self.categories)
-        if len(unique_categories) > 1:
-            for category, scores in category_scores.items():
-                cat_metrics = {"num_entries": len(scores)}
-                cat_metrics.update(get_aggregate_score(scores))
-                overall_metrics[f"category_{category}"] = cat_metrics
+    def _aggregate(self, prompt_pairs, emit_categories):
+        """Run get_aggregate_score on a list of (gen-base, base-gen) pairs (one per prompt)."""
+        from nemo_skills.evaluation.evaluator.arena import get_aggregate_score
 
-        metrics_dict[self.agg_mode] = overall_metrics
-        # arena metrics have their own confidence estimation, so not doing std metrics here
-        return metrics_dict
+        agg = {"num_entries": self.total}
+        agg.update(self._native_aggregate_score(get_aggregate_score(prompt_pairs)))
+        self.update_common_metrics(agg)
+
+        if emit_categories:
+            by_category = defaultdict(list)
+            for pair, category in zip(prompt_pairs, self.categories, strict=True):
+                by_category[category].append(pair)
+            for category, pairs in by_category.items():
+                cat_agg = {"num_entries": len(pairs)}
+                cat_agg.update(self._native_aggregate_score(get_aggregate_score(pairs)))
+                agg[f"category_{category}"] = cat_agg
+
+        return agg
+
+    @staticmethod
+    def _native_aggregate_score(agg):
+        """Cast get_aggregate_score's numpy types to native Python — yaml.safe_dump can't serialize numpy."""
+        return {
+            "score": float(agg["score"]),
+            "95_CI": tuple(float(x) for x in agg["95_CI"]),
+            "invalid_scores": int(agg["invalid_scores"]),
+        }
+
+    def _average_aggregations(self, per_repeat):
+        """Average a list of per-repeat aggregation dicts to produce pass@1[avg-of-N].
+
+        - 'score': mean across repeats.
+        - 'invalid_scores': summed across repeats (total invalid-judgement count).
+        - '95_CI': dropped (mean of CIs is not a meaningful CI).
+        - num_entries / avg_tokens / gen_seconds: same across repeats; populated by
+          update_common_metrics.
+        - Per-category sub-dicts: averaged using the same rules.
+
+        Per-repeat scores are not surfaced here because downstream metric parsers
+        (e.g. nemo-evaluator-launcher's `core_evals/nemo_skills/output.py`) wrap each
+        leaf value in a `Score(value=float)` and reject lists. Consumers who need the
+        per-repeat breakdown can recompute it from `output-rs*.jsonl`.
+        """
+        # Cast to native Python float — get_aggregate_score returns numpy.float64,
+        # which yaml.safe_dump can't serialize.
+        avg = {"num_entries": per_repeat[0]["num_entries"]}
+        avg["score"] = float(mean(m["score"] for m in per_repeat))
+        avg["invalid_scores"] = sum(m["invalid_scores"] for m in per_repeat)
+        self.update_common_metrics(avg)
+
+        for cat_key in [k for k in per_repeat[0] if k.startswith("category_")]:
+            cat_avg = {"num_entries": per_repeat[0][cat_key]["num_entries"]}
+            cat_avg["score"] = float(mean(m[cat_key]["score"] for m in per_repeat))
+            cat_avg["invalid_scores"] = sum(m[cat_key]["invalid_scores"] for m in per_repeat)
+            avg[cat_key] = cat_avg
+
+        return avg
+
+    def evaluations_to_print(self):
+        # Override BaseMetrics' default — Arena doesn't compute majority@k, so dropping
+        # that key avoids a missing-key request to the framework's printer (matches the
+        # OmniMetrics convention).
+        if self.max_k > 1:
+            return [f"pass@{self.max_k}", f"pass@1[avg-of-{self.max_k}]"]
+        return ["pass@1"]
 
     def reset(self):
         super().reset()
-        self.scores = []  # list of lists
-        self.categories = []  # list of category strings
-        self.lengths = 0
-        # TODO: the class should support pass@k, but this forces it to report as pass@1.
-        # There is some error here for k>1
-        self.agg_mode = "pass@1"
+        # Per-prompt list of (gen-base, base-gen) score pairs — N tuples per prompt where
+        # N == self.max_k. Aggregation is deferred to get_metrics() so both pass@N
+        # (best-of-N) and pass@1[avg-of-N] can be derived from the same data.
+        self.per_prompt_scores = []
+        self.categories = []
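
For readers skimming the diff, here is a small standalone sketch of the best-of-N selection that _best_pair implements. The preference tuples are copied from the diff above; the example input and the printed result are illustrative only and are not part of the commit:

# Mirrors _best_pair from the diff above (illustrative sketch, not the shipped code).
GEN_BASE_PREFERENCE = ("A>>B", "A>B", "A=B", "B>A", "B>>A")
BASE_GEN_PREFERENCE = ("B>>A", "B>A", "A=B", "A>B", "A>>B")

def best_pair(prompt_pairs):
    # prompt_pairs: one (gen-base label, base-gen label) tuple per repeat of the same prompt.
    gen_base_pool = [pair[0] for pair in prompt_pairs]
    base_gen_pool = [pair[1] for pair in prompt_pairs]
    return [
        next((s for s in GEN_BASE_PREFERENCE if s in gen_base_pool), None),
        next((s for s in BASE_GEN_PREFERENCE if s in base_gen_pool), None),
    ]

# Two repeats: the candidate wins in both directions on the first repeat and loses badly
# on the second. The best-of-N pick keeps the most candidate-favorable label per
# direction: "A>B" for gen-base and "B>A" for base-gen.
print(best_pair([("A>B", "B>A"), ("B>>A", "A>>B")]))  # ['A>B', 'B>A']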

tests/test_arena_metrics.py

Lines changed: 33 additions & 2 deletions
@@ -136,7 +136,9 @@ def test_arena_metrics_score_parsing():
     for gen_base, base_gen in test_cases:
         m.reset()
         m.update([_make_prediction(gen_base, base_gen, category="test")])
-        assert m.scores[0] == [gen_base, base_gen]
+        # per_prompt_scores stores the full per-prediction pair list (one tuple per
+        # prediction). For N=1 it's a single-element list.
+        assert m.per_prompt_scores[0] == [(gen_base, base_gen)]
 
 
 def test_arena_metrics_invalid_score_handling():
@@ -151,4 +153,33 @@ def test_arena_metrics_invalid_score_handling():
     }
     m.update([pred])
 
-    assert m.scores[0] == [None, None]
+    assert m.per_prompt_scores[0] == [(None, None)]
+
+
+def test_arena_metrics_pass_at_k_with_repeats():
+    """num_repeats > 1 should emit pass@N (best-of-N) and pass@1[avg-of-N] (avg-of-N)."""
+    m = ArenaMetrics()
+    random.seed(42)
+    scores_pool = [("A>B", "B>A"), ("B>A", "A>B"), ("A=B", "A=B"), ("A>>B", "B>>A"), ("B>>A", "A>>B")]
+
+    n_prompts, n_repeats = 50, 5
+    for _ in range(n_prompts):
+        preds = [_make_prediction(*random.choice(scores_pool), category="test") for _ in range(n_repeats)]
+        m.update(preds)
+
+    metrics = m.get_metrics()
+
+    pass_at_n = f"pass@{n_repeats}"
+    avg_of_n = f"pass@1[avg-of-{n_repeats}]"
+    assert set(metrics.keys()) == {pass_at_n, avg_of_n}, (
+        f"Expected exactly {{{pass_at_n}, {avg_of_n}}}, got {set(metrics.keys())}"
+    )
+
+    assert metrics[pass_at_n]["num_entries"] == n_prompts
+    assert metrics[avg_of_n]["num_entries"] == n_prompts
+
+    # best-of-N can never score worse than avg-of-N: picking the best out of N
+    # generations is at least as candidate-favorable as the average single-shot.
+    assert metrics[pass_at_n]["score"] >= metrics[avg_of_n]["score"], (
+        f"best-of-N ({metrics[pass_at_n]['score']}) < avg-of-N ({metrics[avg_of_n]['score']})"
+    )
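
As a quick sanity check on the averaging rules documented in _average_aggregations, here is a toy example with hypothetical per-repeat numbers. The field names match the diff; the values and the three-repeat setup are invented for illustration:

from statistics import mean

# Hypothetical per-repeat aggregations for N = 3 (values are made up).
per_repeat = [
    {"num_entries": 50, "score": 12.0, "invalid_scores": 1},
    {"num_entries": 50, "score": 8.0, "invalid_scores": 0},
    {"num_entries": 50, "score": 10.0, "invalid_scores": 2},
]

avg = {
    "num_entries": per_repeat[0]["num_entries"],                     # identical across repeats
    "score": float(mean(m["score"] for m in per_repeat)),            # (12 + 8 + 10) / 3 = 10.0
    "invalid_scores": sum(m["invalid_scores"] for m in per_repeat),  # 1 + 0 + 2 = 3
}
# The 95_CI field is intentionally dropped: a mean of per-repeat confidence
# intervals would not itself be a valid confidence interval.
print(avg)  # {'num_entries': 50, 'score': 10.0, 'invalid_scores': 3}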
