Skip to content

Commit 098f04d

Browse files
Kipokshan18
andauthored
Refactoring of the metrics logic (#504)
Signed-off-by: Igor Gitman <igitman@nvidia.com> Co-authored-by: Shantanu Acharya <shantanua@nvidia.com> Co-authored-by: Shantanu Acharya <shan.sacharya@gmail.com>
1 parent 2693427 commit 098f04d

42 files changed

Lines changed: 1455 additions & 583 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

nemo_skills/evaluation/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from nemo_skills.evaluation.metrics.base import default_formatting
1516
from nemo_skills.evaluation.metrics.compute_metrics import ComputeMetrics
16-
from nemo_skills.evaluation.metrics.utils import read_predictions
1717
from nemo_skills.evaluation.metrics.map_metrics import get_metrics
18+
from nemo_skills.evaluation.metrics.utils import read_predictions

nemo_skills/evaluation/metrics/answer_judgement_metrics.py

Lines changed: 57 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -12,47 +12,61 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections import Counter, defaultdict
16-
from typing import Union
1715
from nemo_skills.evaluation.metrics.base import BaseMetrics
1816
from nemo_skills.evaluation.metrics.utils import is_correct_judgement
1917

2018

2119
class AnswerJudgementMetrics(BaseMetrics):
22-
def __init__(self):
23-
self.reset()
20+
def _get_score_dict(self, prediction: dict) -> dict[str, bool | int | float]:
21+
gt_judgement = is_correct_judgement(prediction['expected_judgement'])
22+
pred_judgement = is_correct_judgement(prediction['judgement'])
2423

25-
def update_perf_dict(self, perf_dict, is_correct, is_fp, is_fn, invalid_count):
26-
perf_dict["total_correct"] += float(is_correct)
27-
perf_dict["fp_count"] += float(is_fp)
28-
perf_dict["fn_count"] += float(is_fn)
29-
perf_dict["invalid_count"] += float(invalid_count)
30-
31-
def get_judgement_by_type(self, predictions, judgement_type: str, gt_judgement: bool) -> Union[bool, None]:
32-
answers = [c for elem in predictions if (c:=is_correct_judgement(elem['judgement'])) is not None]
33-
if len(answers) == 0:
34-
return None
35-
if judgement_type == "majority":
36-
return Counter(answers).most_common(1)[0][0]
37-
elif judgement_type == "pass":
38-
for answer in answers:
39-
if answer == gt_judgement:
40-
return answer
41-
return answers[0]
42-
else:
43-
raise ValueError(f"Invalid judgement type: {judgement_type}")
44-
45-
def get_judgement_metrics(self, pred_judgement, gt_judgement):
46-
is_fp, is_fn = False, False
47-
is_invalid = pred_judgement is None
48-
is_correct = pred_judgement == gt_judgement
49-
if not is_correct:
50-
if pred_judgement == True:
51-
is_fp = True
52-
elif pred_judgement == False:
53-
is_fn = True
54-
return is_correct, is_fp, is_fn, is_invalid
55-
24+
return {'correct_judgements': gt_judgement == pred_judgement}
25+
26+
def _update_fp_fn(self, metrics_dict, pred_judgement, gt_judgement, divide_by=1):
27+
is_fp = pred_judgement is True and gt_judgement is False
28+
is_fn = pred_judgement is False and gt_judgement is True
29+
metrics_dict['false_positives'] += float(is_fp) / divide_by
30+
metrics_dict['false_negatives'] += float(is_fn) / divide_by
31+
32+
def _update_score_metrics_for_majority(
33+
self,
34+
eval_dict: dict,
35+
k: int,
36+
score_method: str,
37+
score_dicts: list[dict],
38+
majority_score: bool | float | int,
39+
majority_answer: str,
40+
predictions: list[dict],
41+
predicted_answers: list[str],
42+
):
43+
assert score_method == 'correct_judgements'
44+
# expected answer is always the same for all predictions, so just take the first one
45+
gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
46+
self._update_fp_fn(eval_dict[f"majority@{k}"], majority_answer, gt_judgement)
47+
48+
def _update_score_metrics_for_pass(
49+
self,
50+
eval_dict: dict,
51+
k: int,
52+
score_method: str,
53+
score_dicts: list[dict],
54+
pass_score: bool | float | int,
55+
predictions: list[dict],
56+
predicted_answers: list[str] | None,
57+
):
58+
assert score_method == 'correct_judgements'
59+
# expected answer is always the same for all predictions, so just take the first one
60+
gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
61+
pred_judgement = is_correct_judgement(predictions[0]['judgement'])
62+
# if pass is not correct, means all predictions are the same and wrong
63+
if not pass_score:
64+
self._update_fp_fn(eval_dict[f"pass@{k}"], pred_judgement, gt_judgement)
65+
66+
for pred in predictions[:k]:
67+
gt_judgement = is_correct_judgement(pred['expected_judgement'])
68+
pred_judgement = is_correct_judgement(pred['judgement'])
69+
self._update_fp_fn(eval_dict[f"pass@1[{k}]"], pred_judgement, gt_judgement, divide_by=k)
5670

5771
def update(self, predictions):
5872
"""Updating the evaluation results with the current element.
@@ -61,47 +75,13 @@ def update(self, predictions):
6175
predictions (list[dict]): aggregated predictions across all generations.
6276
The content of the file is benchmark specific.
6377
"""
64-
self.total += 1
65-
gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
66-
if len(predictions) > 1:
67-
# Majority@k, Pass@k, Pass@1[k]
68-
for k in range(len(predictions), 0, -1):
69-
pred_subset = predictions[:k]
70-
majority_judgement = self.get_judgement_by_type(pred_subset, "majority", gt_judgement)
71-
majority_metrics = self.get_judgement_metrics(majority_judgement, gt_judgement)
72-
self.update_perf_dict(self.agg_mode_dict[f"majority@{k}"], *majority_metrics)
73-
74-
pass_judgement = self.get_judgement_by_type(pred_subset, "pass", gt_judgement)
75-
pass_metrics = self.get_judgement_metrics(pass_judgement, gt_judgement)
76-
self.update_perf_dict(self.agg_mode_dict[f"pass@{k}"], *pass_metrics)
77-
78-
pass1_k_metrics = [self.get_judgement_metrics(is_correct_judgement(prediction['judgement']), gt_judgement) for prediction in pred_subset]
79-
avg_pass1_k_metrics = [sum(metrics) / len(metrics) for metrics in zip(*pass1_k_metrics)]
80-
self.update_perf_dict(self.agg_mode_dict[f"pass@1[{k}]"], *avg_pass1_k_metrics)
81-
82-
# Greedy
83-
if len(predictions) == 1:
84-
per_sample_metrics = self.get_judgement_metrics(is_correct_judgement(predictions[0]['judgement']), gt_judgement)
85-
self.update_perf_dict(self.agg_mode_dict["greedy"], *per_sample_metrics)
86-
return
87-
78+
super().update(predictions)
79+
predicted_answers = [is_correct_judgement(pred['judgement']) for pred in predictions]
80+
self._compute_pass_at_k(predictions=predictions, predicted_answers=predicted_answers)
81+
self._compute_majority_at_k(predictions=predictions, predicted_answers=predicted_answers)
8882

8983
def get_metrics(self):
90-
metrics_dict = {}
91-
for agg_mode, agg_metric_dict in self.agg_mode_dict.items():
92-
metrics_dict[agg_mode] = {"num_entries": self.total}
93-
94-
metrics_dict[agg_mode]["correct_judgements"] = (agg_metric_dict["total_correct"] / self.total) * 100.0
95-
metrics_dict[agg_mode]["false_positives"] = (agg_metric_dict["fp_count"] / self.total) * 100.0
96-
metrics_dict[agg_mode]["false_negatives"] = (agg_metric_dict["fn_count"] / self.total) * 100.0
97-
metrics_dict[agg_mode]["invalid_judgements"] = (agg_metric_dict["invalid_count"] / self.total) * 100.0
98-
99-
return metrics_dict
100-
101-
def reset(self):
102-
self.total = 0
103-
self.agg_mode_dict = defaultdict(lambda: defaultdict(int))
104-
105-
def max_aggregations_to_print(self):
106-
# majority + pass + pass@1[k]
107-
return 1 + 1 + 1
84+
# renaming no_answer to invalid_judgements
85+
for agg_metric_dict in self.eval_dict.values():
86+
agg_metric_dict["invalid_judgements"] = agg_metric_dict.pop("no_answer")
87+
return super().get_metrics()

nemo_skills/evaluation/metrics/arena_metrics.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,10 @@ def update(self, predictions):
7676
"""
7777
# this shouldn't do any heavy calculation, but just read the metric from existing json entry
7878
# all the heavy lifting should be done in the evaluation script
79-
self.total += 1
79+
super().update(predictions)
8080
self.scores.append([])
81+
self.agg_mode = f"pass@{len(predictions)}"
8182
if len(predictions) > 1:
82-
self.agg_mode = f"pass@{len(predictions)}"
83-
8483
judge_scores = [self._get_judge_score(elem['judgement-gen-base']) for elem in predictions]
8584
# adding the best score out of all the generations
8685
possible_scores = ['A>>B', 'A>B', 'A=B', 'B>A', 'B>>A']
@@ -89,7 +88,7 @@ def update(self, predictions):
8988
if any([score == possible_score for score in judge_scores]):
9089
self.scores[-1].append(possible_score)
9190
best_id = judge_scores.index(possible_score)
92-
self.lengths += len(predictions[best_id]['generation'])
91+
self.lengths += predictions[best_id].get('num_generated_tokens', 0)
9392
break
9493
else:
9594
self.scores[-1].append(None) # in case judge didn't generate a valid score
@@ -101,15 +100,12 @@ def update(self, predictions):
101100
if any([score == possible_score for score in judge_scores]):
102101
self.scores[-1].append(possible_score)
103102
best_id = judge_scores.index(possible_score)
104-
self.lengths += len(predictions[best_id]['generation'])
103+
self.lengths += predictions[best_id].get('num_generated_tokens', 0)
105104
break
106105
else:
107106
self.scores[-1].append(None) # in case judge didn't generate a valid score
108107
else:
109-
# Single prediction
110-
self.agg_mode = "greedy"
111-
112-
self.lengths += len(predictions[0]['generation'])
108+
self.lengths += predictions[0].get('num_generated_tokens', 0)
113109
self.scores[-1] = [
114110
self._get_judge_score(predictions[0]['judgement-gen-base']),
115111
self._get_judge_score(predictions[0]['judgement-base-gen']),
@@ -120,12 +116,12 @@ def get_metrics(self):
120116

121117
metrics = {'num_entries': self.total}
122118
metrics.update(get_aggregate_score(self.scores))
123-
metrics['avg_response_length'] = self.lengths / self.total
119+
if self.lengths > 0:
120+
metrics['avg_response_tokens'] = int(self.lengths / self.total)
124121
return {self.agg_mode: metrics}
125122

126123
def reset(self):
124+
super().reset()
127125
self.scores = [] # list of lists
128126
self.lengths = 0
129-
self.total = 0
130-
# Set automatically
131-
self.agg_mode = "greedy"
127+
self.agg_mode = "pass@1"

0 commit comments

Comments
 (0)