import random
import re
import string
from abc import ABC, abstractmethod
from collections import Counter
from pathlib import Path
from typing import Callable, List, Optional, Union

import jieba
import numpy as np

from common.config_utils import config_utils as config_instance
from common.uc_eval.utils.data_class import MultiTurnDialogRecord, RequestRecord

stopwords_path = Path(
    config_instance.get_nested_config(
        "dataset.stopwords", "/mnt/private/dataset/uc-eval/cn_stopwords.txt"
    )
)
# Load the stop-word list once at import time; read_text() avoids leaving the
# file handle open.
STOPWORDS: List[str] = [
    line.strip() for line in stopwords_path.read_text(encoding="utf-8").splitlines()
]


def normalize_text(text: str) -> str:
    """Strip punctuation, segment with jieba, lowercase, and drop stop-words."""
    # Remove punctuation (CJK full-width & ASCII)
    pattern = r"[\u3000-\u303F\uFF00-\uFFEF" + re.escape(string.punctuation) + "]"
    text = re.sub(pattern, "", text)
    # Segment with jieba (precise mode) and lowercase
    words = jieba.lcut(text)
    words = [word.strip().lower() for word in words]
    # Drop stop-words and empty tokens
    filtered_words = [word for word in words if word not in STOPWORDS and word != ""]
    return " ".join(filtered_words)
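
# Illustrative example (the exact tokens depend on jieba's dictionary and the
# configured stop-word list):
#     normalize_text("你好，世界！")  ->  "你好 世界"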


class MetricClass(ABC):
    """Base class: marks each record as matched or not, then aggregates metrics."""

    def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]):
        self.record_list = record_list
        self.ACCURACY_METRIC_FUNCTION_MAP: dict[str, Callable] = {
            "accuracy": self.get_accuracy,
            "bootstrap-accuracy": self.get_bootstrap_accuracy_std,
            "f1-score": self.get_f1_score,
        }

    def calculate_metric(self, metric_names: List[str]):
        """
        Mark each record as matched or not, then compute the requested metrics.
        """
        for record in self.record_list:
            expected_output, real_output = self.get_expected_and_real_output(record)
            if self.match(expected_output, real_output):
                record.is_match = True
        metric_dict = {}
        for metric in metric_names:
            metric_function = self.ACCURACY_METRIC_FUNCTION_MAP[metric]
            metric_dict[metric] = metric_function(self.record_list)
        return metric_dict

    def get_expected_and_real_output(
        self, record: Union[RequestRecord, MultiTurnDialogRecord]
    ):
        expected_output = record.expected_output
        real_output = self.del_chain_of_thought(record.output_data)
        # Normalize expected_output to a list of candidate answers
        if isinstance(expected_output, tuple):
            expected_output = list(expected_output)
        elif not isinstance(expected_output, list):
            expected_output = [expected_output]
        return expected_output, real_output

    def get_normalize_text(self, record: Union[RequestRecord, MultiTurnDialogRecord]):
        """
        Return the record's expected and real outputs, normalized for comparison.
        """
        expected_output, real_output = self.get_expected_and_real_output(record)
        expected_output = [normalize_text(output) for output in expected_output]
        real_output = normalize_text(real_output)
        return expected_output, real_output

    def del_chain_of_thought(
        self,
        output_data: str,
        think_start_tokens: str = "<think>",
        think_end_tokens: str = "</think>",
    ):
        """
        Delete the chain of thought from the output data
        """
        if (
            think_start_tokens not in output_data
            and think_end_tokens not in output_data
        ):
            return output_data
        elif think_start_tokens not in output_data and think_end_tokens in output_data:
            # Some models emit only the closing tag; keep everything after it
            return output_data.split(think_end_tokens)[-1].strip()
        start_escaped = re.escape(think_start_tokens)
        end_escaped = re.escape(think_end_tokens)
        reason_data = re.compile(rf"{start_escaped}(.*?){end_escaped}", re.DOTALL)
        return reason_data.sub("", output_data).strip()
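
    # Deterministic examples:
    #     del_chain_of_thought("<think>step by step...</think>42")  ->  "42"
    #     del_chain_of_thought("step by step...</think>42")         ->  "42"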

    @abstractmethod
    def match(
        self,
        expected_output: Union[str, List[str], tuple[str, ...]],
        real_output: str,
        **kwargs,
    ):
        pass

    def get_accuracy(
        self, record_list: List[RequestRecord | MultiTurnDialogRecord]
    ) -> float:
        record_total = len(record_list)
        match_num = sum(record.is_match for record in record_list)
        return match_num / record_total if record_total != 0 else float("nan")

    def get_bootstrap_accuracy_std(
        self,
        record_list: List[RequestRecord | MultiTurnDialogRecord],
        num_samples: int = 1000,
    ):
        """
        Estimate the standard deviation of accuracy by resampling: each of the
        num_samples iterations averages a random half of the records (drawn
        without replacement), and the std of those means is returned.
        """
        if not record_list:
            return float("nan")
        vals = [record.is_match for record in record_list]
        return np.std(
            [np.mean(random.sample(vals, len(vals) // 2)) for _ in range(num_samples)]
        ).item()

    def get_f1_score(
        self,
        record_list: List[RequestRecord | MultiTurnDialogRecord],
    ):
        f1_score = []
        for record in record_list:
            expected_output, real_output = self.get_normalize_text(record)
            f1_score.append(self._f1_score(expected_output, real_output))
        return np.mean(f1_score).item()

    def _f1_score(self, expected_output: List[str], real_output: str) -> float:
        """Token-level F1 against each candidate answer; return the best score."""
        max_f1_score = 0.0
        for output in expected_output:
            common = Counter(output.split()) & Counter(real_output.split())
            num_same = sum(common.values())
            if num_same != 0:
                precision = 1.0 * num_same / len(output.split())
                recall = 1.0 * num_same / len(real_output.split())
                f1 = (2 * precision * recall) / (precision + recall)
                max_f1_score = max(max_f1_score, f1)
        return max_f1_score
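
# Worked _f1_score example: expected "a b c" vs. real "a b d" share the tokens
# {a, b}, so precision = recall = 2/3 and F1 = 2/3 ≈ 0.667.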


class Match(MetricClass):
    def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]):
        super().__init__(record_list)

    def match(
        self,
        expected_output: List[str],
        real_output: str,
        separator: Optional[Callable[[str], bool]] = None,
        options: Optional[list[str]] = None,
    ) -> bool:
        """
        Exact match: the option picked from the start of real_output must be
        one of the expected answers
        :param expected_output: the answer from dataset
        :param real_output: actual output generated by model
        :param separator: predicate on the character that follows the matched
            option, used to reject partial matches (e.g. "A" inside "AB")
        :param options: optional list of candidate options; for multiple-choice
            questions the full option list must be supplied
        """
        if options is None:
            options = expected_output
        picked = None
        for option in options:
            if not real_output.startswith(option):
                continue
            if (
                separator is not None
                and len(real_output) > len(option)
                and not separator(real_output[len(option)])
            ):
                continue
            picked = option
            break
        return picked in expected_output
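
# Example for Match: with options ["A", "B", "C", "D"] and separator
# (lambda ch: not ch.isalnum()), real_output "A. Because..." picks "A", while
# "AB" picks nothing, since the character after "A" is alphanumeric.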


class Includes(MetricClass):
    def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]):
        super().__init__(record_list)

    def match(
        self,
        expected_output: List[str],
        real_output: str,
    ) -> bool:
        """
        Match succeeds if any element of expected_output occurs in real_output
        :param expected_output: the answer from dataset
        :param real_output: actual output generated by model
        """
        return any(output in real_output for output in expected_output)


class FuzzyMatch(MetricClass):
    def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]):
        super().__init__(record_list)

    def match(
        self,
        expected_output: List[str],
        real_output: str,
        strategy: str = "substring",
        threshold: float = 0.8,
    ) -> bool:
        """
        Fuzzy matching
        :param expected_output: the answer from dataset
        :param real_output: actual output generated by model
        :param strategy: matching strategy; currently "substring" and "jaccard"
        :param threshold: similarity threshold for the jaccard strategy
        """
        return any(
            self._single_match(expected, real_output, strategy, threshold)
            for expected in expected_output
        )

    def _single_match(
        self,
        expected: str,
        real: str,
        strategy: str = "substring",
        threshold: float = 0.8,
    ) -> bool:
        if strategy == "substring":
            return expected in real or real in expected
        # Any other strategy value falls through to Jaccard similarity over
        # whitespace-separated tokens
        set_exp, set_real = set(expected.split()), set(real.split())
        if not set_exp and not set_real:
            return True
        if not set_exp or not set_real:
            return False
        inter = len(set_exp & set_real)
        union = len(set_exp | set_real)
        return (inter / union) >= threshold
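
# Worked Jaccard example: expected "a b c" vs. real "b c d" gives
# |{b, c}| / |{a, b, c, d}| = 2 / 4 = 0.5, below the default 0.8 threshold.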


class MatchPatterns(MetricClass):
    def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]):
        super().__init__(record_list)

    def match(
        self,
        expected_output: List[str],
        real_output: str,
    ) -> bool:
        """
        Extract an answer from real_output using the configured regex patterns,
        then check whether it is one of the expected answers
        :param expected_output: the answer from dataset
        :param real_output: actual output generated by model
        """
        pred = self.get_answer_from_match_patterns(real_output)
        return bool(pred and pred in expected_output)

    def get_answer_from_match_patterns(self, real_output: str):
        """
        Return the first capture group of the first pattern that matches
        real_output, or None
        """
        # Deferred import so the patterns are only loaded when needed
        from common.uc_eval.utils.prompt_config import match_patterns

        for pattern in match_patterns:
            match = re.search(pattern, real_output)
            if match:
                return match.group(1)
        return None
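

if __name__ == "__main__":
    # Minimal smoke test (a sketch: it assumes the stopwords file configured
    # above exists so the module imports cleanly). Empty record lists are fine
    # here because match() never reads self.record_list.
    inc = Includes([])
    assert inc.match(["Paris"], "The capital of France is Paris.")
    assert not inc.match(["Paris"], "The capital of France is Lyon.")
    assert inc.del_chain_of_thought("<think>reasoning...</think>42") == "42"

    fuzzy = FuzzyMatch([])
    assert fuzzy.match(["the capital is paris"], "indeed, the capital is paris")
    print("smoke tests passed")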