Skip to content

Commit 7c23fd8

Browse files
authored
Merge pull request #1 from mfakaehler/add_rprec_metric
Add R-Precision as metric
2 parents b4b2256 + 3500e7f commit 7c23fd8

1 file changed

Lines changed: 31 additions & 1 deletion

File tree

shared-task-eval-script/llms4subjects-evaluation.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,27 @@ def recall(true_labels: list, pred_labels: list, k: int):
9292
intersection = true_set & pred_set
9393
return round(len(intersection) / len(true_set), 4)
9494

95+
def rprec(true_labels: list, pred_labels: list, k: int):
96+
"""
97+
Calculates R-Precision@k as in
98+
Manning, C. D., Raghavan, P., & Schütze, H. (2012).
99+
Introduction to Information Retrieval. In Introduction to Information Retrieval.
100+
Cambridge University Press. https://doi.org/10.1017/CBO9780511809071
101+
102+
Args:
103+
true_labels (list): The list of true labels
104+
pred_labels (list): The list of predicted labels
105+
k (int): The value of K representing the top k values to consider
106+
107+
Returns:
108+
float: R-Precision@k
109+
"""
110+
true_set = set(true_labels)
111+
pred_set = set(pred_labels[:k])
112+
breakevenpoint = min(len(true_set), len(pred_set))
113+
intersection = true_set & pred_set
114+
return round(len(intersection) / breakevenpoint, 4)
115+
95116
def f1(precision_k: float, recall_k: float):
96117
"""
97118
Calculates the f1@k for the given precision@k and recall@k.
@@ -141,19 +162,22 @@ def evaluate_combined_record_type_language(true_dict: dict, predicted_dict: dict
141162
#Calculating the recall and precision at k
142163
recall_k = recall(true_labels, pred_labels, k)
143164
precision_k = precision(true_labels, pred_labels, k)
165+
rprec_k = rprec(true_labels, pred_labels, k)
144166

145167
total_recall += recall_k
146168
total_precision += precision_k
169+
total_rprec += rprec_k
147170

148171
#Averaging recall and precision and calculating the f1 score
149172
avg_recall = total_recall / count if count else 0.0
150173
avg_precision = total_precision / count if count else 0.0
174+
avg_rprec = total_rprec / count if count else 0.0
151175
avg_f1 = f1(avg_recall, avg_precision)
152176

153177
#Saving the metrics score in the dictionary
154178
if record_type not in combined_metrics:
155179
combined_metrics[record_type] = {}
156-
combined_metrics[record_type][language] = {f'precision_{k}': avg_precision, f'recall_{k}': avg_recall, f'f1_{k}': avg_f1}
180+
combined_metrics[record_type][language] = {f'precision_{k}': avg_precision, f'recall_{k}': avg_recall, f'rprec_{k}': avg_rprec, f'f1_{k}': avg_f1}
157181

158182
return combined_metrics
159183

@@ -193,15 +217,18 @@ def evaluate_record_type_level(true_dict: dict, predicted_dict: dict, k: int):
193217
#Calculating the recall and precision at k
194218
recall_k = recall(true_labels, pred_labels, k)
195219
precision_k = precision(true_labels, pred_labels, k)
220+
rprec_k = rprec(true_labels, pred_labels, k)
196221

197222
metrics_score[record_type][f'recall_{k}'] += recall_k
198223
metrics_score[record_type][f'precision_{k}'] += precision_k
224+
metrics_score[record_type][f'rprec_{k}'] += rprec_k
199225

200226
#Averaging recall and precision and calculating the f1 score
201227
for record_type, metrics in metrics_score.items():
202228
total_files = metrics['total_files']
203229
metrics[f'recall_{k}'] = metrics[f'recall_{k}'] / total_files if total_files else 0.0
204230
metrics[f'precision_{k}'] = metrics[f'precision_{k}'] / total_files if total_files else 0.0
231+
metrics[f'rprec_{k}'] = metrics[f'rprec_{k}'] / total_files if total_files else 0.0
205232
metrics[f'f1_{k}'] = f1(metrics[f'recall_{k}'], metrics[f'precision_{k}'])
206233

207234
#Deleting the total files key and value
@@ -245,15 +272,18 @@ def evaluate_language_level(true_dict: dict, predicted_dict: dict, k: int):
245272
#Calculating the recall and precision at k
246273
recall_k = recall(true_labels, pred_labels, k)
247274
precision_k = precision(true_labels, pred_labels, k)
275+
rprec_k = rprec(true_labels, pred_labels, k)
248276

249277
metrics_score[language][f'recall_{k}'] += recall_k
250278
metrics_score[language][f'precision_{k}'] += precision_k
279+
metrics_score[language][f'rprec_{k}'] += rprec_k
251280

252281
#Averaging recall and precision and calculating the f1 score
253282
for language, metrics in metrics_score.items():
254283
total_files = metrics['total_files']
255284
metrics[f'recall_{k}'] = metrics[f'recall_{k}'] / total_files if total_files else 0.0
256285
metrics[f'precision_{k}'] = metrics[f'precision_{k}'] / total_files if total_files else 0.0
286+
metrics[f'rprec_{k}'] = metrics[f'rprec_{k}'] / total_files if total_files else 0.0
257287
metrics[f'f1_{k}'] = f1(metrics[f'recall_{k}'], metrics[f'precision_{k}'])
258288

259289
#Deleting the total files key and value

0 commit comments

Comments
 (0)