1414
1515from nemo_skills .evaluation .metrics .base import BaseMetrics
1616from nemo_skills .evaluation .metrics .utils import is_correct_judgement
17+ from functools import partial
18+ from collections import defaultdict
1719
# Variant of is_correct_judgement with return_none=True pre-bound.
# NOTE(review): presumably this makes unparseable judgements come back as None
# instead of a boolean -- confirm against the helper's implementation.
is_correct_judgement_or_none = partial(is_correct_judgement, return_none=True)
1821
1922class AnswerJudgementMetrics (BaseMetrics ):
def __init__(self):
    """Initialise the base metrics plus the per-sample confusion bookkeeping."""
    super().__init__()

    def _samples_by_index():
        # sample_idx -> {'tp': ..., 'fp': ..., 'fn': ..., 'tn': ...}
        return defaultdict(dict)

    def _datapoints_by_index():
        # datapoint_idx -> samples of that datapoint
        return defaultdict(_samples_by_index)

    # Layout: individual_metrics[agg_key][datapoint_idx][sample_idx] -> outcome record,
    # i.e. an N x K table (N datapoints, K samples each) per aggregation key.
    self.individual_metrics = defaultdict(_datapoints_by_index)
    # Number of datapoints whose expected judgement is positive (recall denominator).
    self.total_positives = 0
28+
def reset(self):
    """Clear all accumulated state so the object can score a fresh run.

    On top of the base-class reset this drops the per-sample TP/FP/FN/TN
    records and zeroes ``total_positives``.
    """
    super().reset()
    self.individual_metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    # total_positives is the recall denominator; if it survived a reset, recall
    # for the next run would be divided by positives counted in the previous one.
    self.total_positives = 0
32+
def _get_score_dict(self, prediction: dict) -> dict[str, bool | int | float]:
    """Score a single prediction by comparing predicted and expected judgement.

    Both ``prediction['expected_judgement']`` and ``prediction['judgement']`` are
    parsed with ``is_correct_judgement_or_none``; the returned dict holds one
    ``'correct_judgements'`` flag that is True when the parsed values are equal.
    """
    expected, predicted = (
        is_correct_judgement_or_none(prediction[field])
        for field in ('expected_judgement', 'judgement')
    )
    # NOTE(review): if both values parse to None they compare equal and count
    # as a correct judgement -- confirm that this is intended.
    return {'correct_judgements': expected == predicted}
2538
39+ def _store_individual_metrics (self , agg_key , pred_judgement , gt_judgement , sample_idx = 0 ):
40+ """Store individual TP/FP/FN/TN values in N x K matrix structure."""
41+ is_fp = pred_judgement is True and gt_judgement is False
42+ is_fn = pred_judgement is False and gt_judgement is True
43+ is_tp = pred_judgement is True and gt_judgement is True
44+ is_tn = pred_judgement is False and gt_judgement is False
45+
46+ # Store in N x K matrix: [datapoint_idx][sample_idx]
47+ # This is hacky, but the only way to access the datapoint_idx
48+ datapoint_idx = self .total - 1
49+ self .individual_metrics [agg_key ][datapoint_idx ][sample_idx ] = {
50+ 'tp' : float (is_tp ),
51+ 'fp' : float (is_fp ),
52+ 'fn' : float (is_fn ),
53+ 'tn' : float (is_tn )
54+ }
55+
2656 def _update_fp_fn (self , metrics_dict , pred_judgement , gt_judgement , divide_by = 1 ):
2757 is_fp = pred_judgement is True and gt_judgement is False
2858 is_fn = pred_judgement is False and gt_judgement is True
59+
2960 metrics_dict ['false_positives' ] += float (is_fp ) / divide_by
3061 metrics_dict ['false_negatives' ] += float (is_fn ) / divide_by
3162
@@ -42,8 +73,9 @@ def _update_score_metrics_for_majority(
4273 ):
4374 assert score_method == 'correct_judgements'
4475 # expected answer is always the same for all predictions, so just take the first one
45- gt_judgement = is_correct_judgement (predictions [0 ]['expected_judgement' ])
76+ gt_judgement = is_correct_judgement_or_none (predictions [0 ]['expected_judgement' ])
4677 self ._update_fp_fn (eval_dict [f"majority@{ k } " ], majority_answer , gt_judgement )
78+ self ._store_individual_metrics (f"majority@{ k } " , majority_answer , gt_judgement )
4779
4880 def _update_score_metrics_for_pass (
4981 self ,
@@ -57,16 +89,22 @@ def _update_score_metrics_for_pass(
5789 ):
5890 assert score_method == 'correct_judgements'
5991 # expected answer is always the same for all predictions, so just take the first one
60- gt_judgement = is_correct_judgement (predictions [0 ]['expected_judgement' ])
61- pred_judgement = is_correct_judgement (predictions [0 ]['judgement' ])
62- # if pass is not correct, means all predictions are the same and wrong
63- if not pass_score :
64- self ._update_fp_fn (eval_dict [f"pass@{ k } " ], pred_judgement , gt_judgement )
92+ gt_judgement = is_correct_judgement_or_none (predictions [0 ]['expected_judgement' ])
93+ pred_judgements = [is_correct_judgement_or_none (pred ['judgement' ]) for pred in predictions [:k ]]
94+ if gt_judgement in pred_judgements :
95+ pred_judgement = gt_judgement
96+ else :
97+ not_none_pred_judgements = [pred_judgement for pred_judgement in pred_judgements if pred_judgement is not None ]
98+ pred_judgement = not_none_pred_judgements [0 ] if not_none_pred_judgements else None
6599
66- for pred in predictions [:k ]:
67- gt_judgement = is_correct_judgement (pred ['expected_judgement' ])
68- pred_judgement = is_correct_judgement (pred ['judgement' ])
100+ self ._update_fp_fn (eval_dict [f"pass@{ k } " ], pred_judgement , gt_judgement )
101+ self ._store_individual_metrics (f"pass@{ k } " , pred_judgement , gt_judgement )
102+
103+ for sample_idx , pred in enumerate (predictions [:k ]):
104+ gt_judgement = is_correct_judgement_or_none (pred ['expected_judgement' ])
105+ pred_judgement = is_correct_judgement_or_none (pred ['judgement' ])
69106 self ._update_fp_fn (eval_dict [f"pass@1[{ k } ]" ], pred_judgement , gt_judgement , divide_by = k )
107+ self ._store_individual_metrics (f"pass@1[{ k } ]" , pred_judgement , gt_judgement , sample_idx )
70108
71109 def update (self , predictions ):
72110 """Updating the evaluation results with the current element.
@@ -76,12 +114,69 @@ def update(self, predictions):
76114 The content of the file is benchmark specific.
77115 """
78116 super ().update (predictions )
79- predicted_answers = [is_correct_judgement (pred ['judgement' ]) for pred in predictions ]
117+ self .total_positives += float (is_correct_judgement_or_none (predictions [0 ]['expected_judgement' ]) is True )
118+ predicted_answers = [is_correct_judgement_or_none (pred ['judgement' ]) for pred in predictions ]
80119 self ._compute_pass_at_k (predictions = predictions , predicted_answers = predicted_answers )
81120 self ._compute_majority_at_k (predictions = predictions , predicted_answers = predicted_answers )
82121
122+ def _compute_precision_recall_f1 (self , datapoint_metrics ):
123+ """Compute unbiased precision, recall, F1 by averaging over K samples."""
124+ # Find the maximum number of samples K across all datapoints
125+ max_k = max (len (sample_metrics ) for sample_metrics in datapoint_metrics .values ())
126+
127+ # Compute metrics for each of the K samples, then average across K
128+ sample_precision_values = []
129+ sample_recall_values = []
130+ sample_f1_values = []
131+
132+ for sample_idx in range (max_k ):
133+ # Aggregate TP, FP, FN across all N datapoints for sample k
134+ total_tp , total_fp , total_fn = 0 , 0 , 0
135+
136+ for sample_metrics in datapoint_metrics .values ():
137+ metrics = sample_metrics [sample_idx ]
138+ total_tp += metrics ['tp' ]
139+ total_fp += metrics ['fp' ]
140+ total_fn += metrics ['fn' ]
141+
142+ # Compute precision for sample k
143+ if total_tp + total_fp > 0 :
144+ sample_precision = total_tp / (total_tp + total_fp )
145+ else :
146+ sample_precision = 1.0
147+ sample_precision_values .append (sample_precision )
148+
149+ # Compute recall for sample k
150+ if self .total_positives > 0 :
151+ sample_recall = total_tp / self .total_positives
152+ else :
153+ sample_recall = 1.0
154+ sample_recall_values .append (sample_recall )
155+
156+ # Compute F1 for sample k
157+ if sample_precision + sample_recall > 0 :
158+ sample_f1 = 2 * (sample_precision * sample_recall ) / (sample_precision + sample_recall )
159+ else :
160+ sample_f1 = 0.0
161+ sample_f1_values .append (sample_f1 )
162+
163+ # Average across all K samples
164+ return {
165+ 'precision' : 100 * sum (sample_precision_values ) / max_k ,
166+ 'recall' : 100 * sum (sample_recall_values ) / max_k ,
167+ 'f1' : 100 * sum (sample_f1_values ) / max_k ,
168+ }
169+
def get_metrics(self):
    """Return the aggregated metrics with precision, recall and F1 attached.

    The ``no_answer`` counter of every aggregation in ``self.eval_dict`` is
    renamed to ``invalid_judgements`` before the base-class metrics are built.
    Afterwards, each aggregation key that has per-sample records and appears in
    the base-class result is extended with the output of
    ``_compute_precision_recall_f1``.
    """
    # renaming no_answer to invalid_judgements
    for agg_metric_dict in self.eval_dict.values():
        # Guarded: an unconditional pop() raises KeyError when get_metrics() is
        # called a second time (the key was already renamed) or when an
        # aggregation never had the counter.
        if "no_answer" in agg_metric_dict:
            agg_metric_dict["invalid_judgements"] = agg_metric_dict.pop("no_answer")

    metrics_dict = super().get_metrics()

    # Compute unbiased precision, recall, F1 by averaging over K samples
    for agg_key, datapoint_metrics in self.individual_metrics.items():
        if agg_key in metrics_dict:
            metrics_dict[agg_key].update(self._compute_precision_recall_f1(datapoint_metrics))

    return metrics_dict
0 commit comments