# See the License for the specific language governing permissions and
# limitations under the License.
1414
15- from collections import Counter , defaultdict
16- from typing import Union
1715from nemo_skills .evaluation .metrics .base import BaseMetrics
1816from nemo_skills .evaluation .metrics .utils import is_correct_judgement
1917
2018
class AnswerJudgementMetrics(BaseMetrics):
    """Metrics for tasks where the model judges whether an answer is correct.

    Every prediction dict is read through two keys: ``'judgement'`` (the
    model's verdict) and ``'expected_judgement'`` (the reference verdict).
    Both are parsed with ``is_correct_judgement``; this class treats its
    result as ``True``, ``False`` or ``None`` (verdict could not be parsed).

    On top of whatever the base class aggregates, this class tracks
    ``false_positives`` / ``false_negatives`` for the ``majority@k``,
    ``pass@k`` and ``pass@1[k]`` aggregations and reports the base class's
    ``no_answer`` counter under the name ``invalid_judgements``.
    """

    def _get_score_dict(self, prediction: dict) -> dict[str, bool | int | float]:
        """Score one prediction: is the parsed judgement equal to the expected one?"""
        gt_judgement = is_correct_judgement(prediction['expected_judgement'])
        pred_judgement = is_correct_judgement(prediction['judgement'])

        return {'correct_judgements': gt_judgement == pred_judgement}

    def _update_fp_fn(self, metrics_dict, pred_judgement, gt_judgement, divide_by=1):
        """Add one sample's false-positive / false-negative contribution.

        Args:
            metrics_dict: Mutable per-aggregation counters; the
                ``'false_positives'`` and ``'false_negatives'`` entries are
                incremented in place.
            pred_judgement: Parsed model verdict (``True``/``False``/``None``).
            gt_judgement: Parsed expected verdict.
            divide_by: Weight divisor, so that ``k`` samples can each
                contribute ``1/k`` to an averaged metric.
        """
        # Identity checks on purpose: a ``None`` (unparseable) verdict is
        # neither a false positive nor a false negative.
        is_fp = pred_judgement is True and gt_judgement is False
        is_fn = pred_judgement is False and gt_judgement is True
        metrics_dict['false_positives'] += float(is_fp) / divide_by
        metrics_dict['false_negatives'] += float(is_fn) / divide_by

    def _update_score_metrics_for_majority(
        self,
        eval_dict: dict,
        k: int,
        score_method: str,
        score_dicts: list[dict],
        majority_score: bool | float | int,
        majority_answer: bool | None,
        predictions: list[dict],
        predicted_answers: list[bool | None],
    ):
        """Record FP/FN for the ``majority@k`` aggregation.

        ``majority_answer`` is the majority-voted parsed judgement; it is
        compared against the expected judgement of the first prediction.
        """
        assert score_method == 'correct_judgements'
        # expected answer is always the same for all predictions, so just take the first one
        gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
        self._update_fp_fn(eval_dict[f"majority@{k}"], majority_answer, gt_judgement)

    def _update_score_metrics_for_pass(
        self,
        eval_dict: dict,
        k: int,
        score_method: str,
        score_dicts: list[dict],
        pass_score: bool | float | int,
        predictions: list[dict],
        predicted_answers: list[bool | None] | None,
    ):
        """Record FP/FN for the ``pass@k`` and ``pass@1[k]`` aggregations.

        ``pass@k`` only receives an FP/FN when the pass failed; ``pass@1[k]``
        averages the FP/FN of each of the first ``k`` predictions.
        """
        assert score_method == 'correct_judgements'

        if not pass_score:
            # expected answer is always the same for all predictions, so just take the first one
            gt_judgement = is_correct_judgement(predictions[0]['expected_judgement'])
            # A failed pass means no judgement among the first k matched the
            # expected one, so every *valid* judgement there is the same wrong
            # value. Some may still be unparseable (None), so classify by the
            # first valid judgement rather than blindly by predictions[0];
            # otherwise a leading invalid judgement hides the FP/FN.
            parsed = (is_correct_judgement(pred['judgement']) for pred in predictions[:k])
            pred_judgement = next((j for j in parsed if j is not None), None)
            self._update_fp_fn(eval_dict[f"pass@{k}"], pred_judgement, gt_judgement)

        for pred in predictions[:k]:
            gt_judgement = is_correct_judgement(pred['expected_judgement'])
            pred_judgement = is_correct_judgement(pred['judgement'])
            self._update_fp_fn(eval_dict[f"pass@1[{k}]"], pred_judgement, gt_judgement, divide_by=k)

    def update(self, predictions):
        """Updating the evaluation results with the current element.

        Args:
            predictions (list[dict]): aggregated predictions across all generations.
                The content of the file is benchmark specific.
        """
        super().update(predictions)
        predicted_answers = [is_correct_judgement(pred['judgement']) for pred in predictions]
        self._compute_pass_at_k(predictions=predictions, predicted_answers=predicted_answers)
        self._compute_majority_at_k(predictions=predictions, predicted_answers=predicted_answers)

    def get_metrics(self):
        """Return the base-class metrics with ``no_answer`` reported as ``invalid_judgements``."""
        # renaming no_answer to invalid_judgements
        for agg_metric_dict in self.eval_dict.values():
            # The rename mutates self.eval_dict, so on a repeated call the key
            # may already be gone; popping unconditionally would raise KeyError.
            # Folding into the existing value keeps counts that were renamed by
            # an earlier call when more updates arrived in between.
            if "no_answer" in agg_metric_dict:
                already_renamed = agg_metric_dict.get("invalid_judgements", 0)
                agg_metric_dict["invalid_judgements"] = already_renamed + agg_metric_dict.pop("no_answer")
        return super().get_metrics()
0 commit comments