
Commit 930d837

Longbench bugfix (EleutherAI#2895)
* add warning for default `until`
* fix stop tokens; add vcsum
* bugfix: fix doc_to_target to string
* fix lsht, trec
* add task to README
* add debugging logs for multiple input/output
1 parent 82fe48e commit 930d837

39 files changed: +320 −268 lines

lm_eval/api/task.py

Lines changed: 13 additions & 0 deletions
@@ -113,6 +113,9 @@ def __post_init__(self) -> None:
                 )
 
             if "until" not in self.generation_kwargs:
+                eval_logger.warning(
+                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
+                )
                 self.generation_kwargs["until"] = [self.fewshot_delimiter]
         else:
             if self.output_type == "generate_until":
@@ -124,7 +127,11 @@ def __post_init__(self) -> None:
                         else [self.fewshot_delimiter]
                     ),
                     "do_sample": False,
+                    "temperature": 0,
                 }
+                eval_logger.warning(
+                    f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
+                )
 
     def __getitem__(self, item):
         return getattr(self, item)
@@ -928,11 +935,17 @@ def __init__(
                 num_choice = len(test_choice)
 
             if isinstance(test_text, int):
+                eval_logger.debug(
+                    "doc_to_text returned an int. Assuming multiple inputs."
+                )
                 self.multiple_input = num_choice
         else:
             test_choice = None
 
         if isinstance(test_target, list):
+            eval_logger.debug(
+                "doc_to_target returned a list. Assuming multiple targets."
+            )
            self.multiple_target = len(test_target)
         else:
             if (isinstance(test_target, int)) and (test_choice is not None):
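The first warning above only fires when a `generate_until` task omits `until`; the fallback itself is unchanged. Below is a minimal standalone sketch of that behavior (a simplified mirror of the `__post_init__` logic, using a plain dict and the usual "\n\n" few-shot delimiter rather than the real `TaskConfig` dataclass):

import logging

eval_logger = logging.getLogger("lm-eval")


def apply_until_default(task_name, generation_kwargs, fewshot_delimiter="\n\n"):
    # Simplified mirror of the change above: warn, then fall back to the
    # few-shot delimiter as the stop sequence.
    if "until" not in generation_kwargs:
        eval_logger.warning(
            f"{task_name}: No `until` specified in `generation_kwargs`! "
            f"Defaulting to the fewshot_delimiter={fewshot_delimiter!r}"
        )
        generation_kwargs["until"] = [fewshot_delimiter]
    return generation_kwargs


# A LongBench-style config that sets `until` explicitly keeps it as-is,
# while one without it falls back to ["\n\n"] and logs the warning.
print(apply_until_default("longbench_2wikimqa", {"max_gen_toks": 32, "until": []}))
print(apply_until_default("some_other_task", {"max_gen_toks": 32}))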

lm_eval/tasks/README.md

Lines changed: 158 additions & 156 deletions
Large diffs are not rendered by default.

lm_eval/tasks/longbench/2wikimqa.yaml

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
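The `doc_to_target` change is the "fix doc_to_target to string" bugfix from the commit message: LongBench stores its references as a list in `answers`, and rendering the whole list through the template produces a stringified Python list rather than a reference answer. A small sketch of the difference, assuming a made-up document and rendering with Jinja2 (which is what lm-eval uses for these templates):

from jinja2 import Template

# Hypothetical LongBench-style document; the real dataset keeps references in `answers`.
doc = {"answers": ["Steven Spielberg"]}

print(Template("{{answers}}").render(**doc))     # "['Steven Spielberg']" -- a stringified list
print(Template("{{answers[0]}}").render(**doc))  # "Steven Spielberg"     -- a clean string target

The added `until: []` keeps generation from stopping at the default few-shot delimiter, so output length is governed by `max_gen_toks` instead.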
lm_eval/tasks/longbench/2wikimqa_e.yaml

Lines changed: 4 additions & 2 deletions
@@ -1,18 +1,20 @@
+
 tag:
   - longbench_e
 task: longbench_2wikimqa_e
 dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0

lm_eval/tasks/longbench/README.md

Lines changed: 1 addition & 0 deletions
@@ -95,3 +95,4 @@ If other tasks on this dataset are already supported:
 * [x] Have you noted which, if any, published evaluation setups are matched by this variant?
 
 ### Changelog
+v2.: fix doc_to_target; add vcsum

lm_eval/tasks/longbench/_generate_config.py

Lines changed: 15 additions & 35 deletions
@@ -138,7 +138,7 @@
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--save_prefix_path", default="longbench")
+    parser.add_argument("--save_prefix_path", default="")
     return parser.parse_args()
 
 
@@ -156,6 +156,7 @@ def parse_args():
   max_gen_toks: {{ generation_kwargs.max_gen_toks }}
   temperature: {{ generation_kwargs.temperature }}
   do_sample: {{ generation_kwargs.do_sample }}
+  until: {{ generation_kwargs.until }}
 metric_list:
   - metric: {{ metric_list[0].metric }}
     aggregation: {{ metric_list[0].aggregation }}
@@ -171,10 +172,21 @@ def parse_args():
 template = env.from_string(template_str)
 for ds in DATASETS:
     df = ds[:-2] if ds.endswith("_e") else ds
+    # from https://github.com/THUDM/LongBench/blob/2e00731f8d0bff23dc4325161044d0ed8af94c1e/LongBench/eval.py#L52C25-L52C29
+    if df in ["trec", "triviaqa", "samsum", "lsht"] + [
+        "trec_e",
+        "triviaqa_e",
+        "samsum_e",
+        "lsht_e",
+    ]:
+        until = ["\n"]
+    else:
+        until = []
     generation_kwargs = {
         "max_gen_toks": dataset2maxlen[df],
         "temperature": 1,
         "do_sample": True,
+        "until": until,
     }
     raw_doc_to_text = (
         dataset2prompt[df]
@@ -199,10 +211,10 @@ def parse_args():
         "test_split": "test",
         "dataset_name": ds,
         "doc_to_text": raw_doc_to_text,
-        "doc_to_target": "{{answers}}",
+        "doc_to_target": "{{answers[0]}}",
         "generation_kwargs": generation_kwargs,
         "metric_list": metric_list,
-        "metadata": {"version": "1.0"},
+        "metadata": {"version": "2.0"},
     }
 
     # Render template
@@ -211,35 +223,3 @@ def parse_args():
     # Save to file
     with open(args.save_prefix_path + f"{ds}.yaml", "w") as f:
         f.write(rendered_yaml)
-
-# for ds in DATASETS:
-#     df = ds[:-2] if ds.endswith("_e") else ds
-#     generation_kwargs = {"max_gen_toks": dataset2maxlen[df], "temperature": 1, "do_sample": False}
-#     # Escape newlines and curly braces
-#     raw_doc_to_text = dataset2prompt[df].replace("\n", "\\n").replace("{", "{{").replace("}", "}}")
-#     metric_list = [
-#         {"metric": f"!function metrics.{dataset2metric[df]}", "aggregation": "mean", "higher_is_better": True}]
-#     yaml_dict = {
-#         "tag": ["longbench_e" if ds.endswith("_e") else "longbench"],
-#         "task": f"longbench_{ds}",
-#         "dataset_path": "THUDM/LongBench",
-#         "test_split": "test",
-#         "dataset_name": ds,
-#         "doc_to_text": raw_doc_to_text,
-#         "doc_to_target": "{{answers}}",
-#         "generation_kwargs": generation_kwargs,
-#         "metric_list": metric_list,
-#         "metadata": {"version": "1.0"}
-#     }
-#     template = env.from_string(yaml_dict)
-#
-#
-#     file_save_path = args.save_prefix_path + f"{ds}.yaml"
-#     with open(file_save_path, "w", encoding="utf-8") as yaml_file:
-#         yaml.dump(
-#             yaml_dict,
-#             yaml_file,
-#             allow_unicode=True,
-#             default_flow_style=False,
-#             sort_keys=False
-#         )
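The new branch reproduces the stop-token choice from the upstream LongBench eval script linked in the comment: the few-shot-style datasets (trec, triviaqa, samsum, lsht and their `_e` variants) stop at the first newline, while everything else generates freely up to `max_gen_toks`. A standalone restatement of that rule for illustration (not part of the generator script itself):

# Same rule as the added code above, pulled out as a helper.
NEWLINE_STOPPED = {"trec", "triviaqa", "samsum", "lsht",
                   "trec_e", "triviaqa_e", "samsum_e", "lsht_e"}


def stop_tokens_for(dataset: str) -> list:
    """Return the `until` list the generator writes into the task YAML."""
    return ["\n"] if dataset in NEWLINE_STOPPED else []


print(stop_tokens_for("lsht"))        # ['\n'] -- stop at the first newline
print(stop_tokens_for("gov_report"))  # []     -- rely on max_gen_toks (512 for gov_report)

Since the output path is just the prefix concatenated with `{ds}.yaml`, the new empty `--save_prefix_path` default appears intended so that running the script from the task directory writes files named exactly like the checked-in YAMLs (e.g. `2wikimqa.yaml`).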

lm_eval/tasks/longbench/dureader.yaml

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: dureader
 doc_to_text: '请基于给定的文章回答下述问题。\n\n文章:{{context}}\n\n请基于上述文章回答下面的问题。\n\n问题:{{input}}\n回答:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 128
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_zh_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0

lm_eval/tasks/longbench/gov_report.yaml

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0

lm_eval/tasks/longbench/gov_report_e.yaml

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report_e
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0

lm_eval/tasks/longbench/hotpotqa.yaml

Lines changed: 3 additions & 2 deletions
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
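The remaining per-dataset YAMLs in this commit follow the same pattern shown above: `doc_to_target` switched to `answers[0]`, an explicit `until`, and the version bumped to 2.0. A quick way to sanity-check one of the re-versioned tasks, sketched with the Python API; the model choice and `limit` are illustrative and not part of the commit:

import lm_eval

# Minimal smoke test of a v2.0 LongBench task.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m",  # any HF checkpoint works; this one is just small
    tasks=["longbench_2wikimqa"],
    limit=5,  # only a handful of documents, enough to confirm the config loads and scores
)
print(results["results"]["longbench_2wikimqa"])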
