-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmetrics.py
More file actions
70 lines (49 loc) · 2.02 KB
/
metrics.py
File metadata and controls
70 lines (49 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import List
import re
from collections import Counter
import string
def normalize_answer(s: str) -> str:
    """Canonicalize an answer string so trivially different answers compare equal.

    The transformations are applied in this fixed order:

    1. lowercase the text;
    2. delete every character in ``string.punctuation``;
    3. replace the whole words "a", "an" and "the" with a space;
    4. collapse runs of whitespace into single spaces and trim both ends.

    Args:
        s: Raw answer text.

    Returns:
        The normalized string (may be empty).
    """
    text = s.lower()
    # Single-pass deletion of the same ASCII punctuation set.
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Articles become a space (not ""), so neighbouring words never fuse.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def exact_match(prediction: str, ground_truth: str) -> float:
    """Return 1.0 when both answers normalize to identical text, otherwise 0.0.

    Both inputs go through ``normalize_answer`` first. Differences in case,
    punctuation, the articles a/an/the, and whitespace are therefore ignored.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return 1.0 if normalized_prediction == normalized_truth else 0.0
def f1_score(prediction: str, ground_truth: str) -> float:
    """Return the token-level F1 overlap between a prediction and a reference.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Shared tokens are counted as a multiset intersection, so a
    repeated token only matches as many times as it occurs on both sides.

    Args:
        prediction: Model answer text.
        ground_truth: Reference answer text.

    Returns:
        A float in [0.0, 1.0]. If either side has no tokens, the result is 1.0
        when both are empty and 0.0 otherwise.
    """
    predicted = normalize_answer(prediction).split()
    expected = normalize_answer(ground_truth).split()

    # Precision/recall are undefined for an empty side; score exact emptiness.
    if not predicted or not expected:
        return 1.0 if predicted == expected else 0.0

    overlap = sum((Counter(predicted) & Counter(expected)).values())
    if not overlap:
        return 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(expected)
    return 2 * precision * recall / (precision + recall)
def compute_metrics(predictions: List[str], ground_truths: List[str]) -> dict:
    """Compute average exact-match and F1 scores over paired answers.

    Args:
        predictions: Model answers, one per example.
        ground_truths: Reference answers, aligned index-by-index with
            ``predictions``.

    Returns:
        A dict with these keys:

        - ``"exact_match"``: mean of per-example ``exact_match`` scores.
        - ``"f1"``: mean of per-example ``f1_score`` scores.
        - ``"count"``: number of examples scored.

        Both means are 0.0 when the inputs are empty.

    Raises:
        ValueError: If the two lists have different lengths.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``. Without the check, ``zip`` would silently drop unmatched
    # items, and the averages would then disagree with the reported count.
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    pairs = list(zip(predictions, ground_truths))
    em_scores = [exact_match(pred, truth) for pred, truth in pairs]
    f1_scores = [f1_score(pred, truth) for pred, truth in pairs]

    count = len(pairs)
    return {
        "exact_match": sum(em_scores) / count if count else 0.0,
        "f1": sum(f1_scores) / count if count else 0.0,
        "count": count,
    }