
Commit 82377fc

Fix ignore_eos
1 parent 91fd960 commit 82377fc

File tree

- README.md
- dart_math/eval.py
- nbs/index.ipynb
- pipeline/gen.ipynb
- pipeline/gen.py

5 files changed: +50 −15


README.md

Lines changed: 3 additions & 0 deletions
@@ -328,6 +328,9 @@ To reproduce other inference settings, just refer to the paper and
 modify the `--model_name_or_path` and `--gen_save_path` arguments
 accordingly.
 
+- We observed that Llama-3-8B(-Base) sometimes decodes EoS immediately.
+  Try using `--ignore_eos` as a workaround.
+
 For other general inference settings, please modify the command or
 directly modify the
 [script](https://github.com/hkust-nlp/dart-math/blob/main/pipeline/gen.py).
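
For reference, a minimal invocation with the new flag might look like the following (the model ID and save path are illustrative; see the README for the full command):

    python pipeline/gen.py \
        --model_name_or_path meta-llama/Meta-Llama-3-8B \
        --gen_save_path data/res/llama3-8b-gen.jsonl \
        --ignore_eos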

dart_math/eval.py

Lines changed: 15 additions & 11 deletions
@@ -126,17 +126,21 @@ def eq(self, ref: str, ans: str) -> bool:
 
     def extract_ans(self, resp_str: str) -> str:
         """Extract answer segment from complete `resp`."""
-
-        resp = self.extract_explicit_ans(resp_str)
-        if not self.strict_extract and resp is None:  # use the last number
-            pattern = r"-?\d*\.?\d+"
-            resp = re.findall(pattern, resp_str.replace(",", ""))
-            if len(resp) >= 1:
-                resp = resp[-1]
-            else:
-                resp = ""
-
-        return resp
+        ans = self.extract_explicit_ans(resp_str)
+        if ans is not None:
+            return ans
+        elif not self.strict_extract:
+            # Speculate with the last LaTeX formula
+            matches = re.findall(
+                r"(?:\$|\\\(|\\\[)([^\$]+)(?:\$|\\\)|\\\])", resp_str, re.DOTALL
+            )
+            if len(matches) > 0:
+                return matches[-1]
+            # Speculate with the last number
+            matches = re.findall(r"-?\d*\.?\d+", resp_str.replace(",", ""))
+            if len(matches) > 0:
+                return matches[-1]
+        return ""  # Empty str if no answer is found
 
     def extract_explicit_ans(self, resp_str: str) -> str:
         resp_str = self.clean_trailing(resp_str)
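
To see what the new fallback in `extract_ans` does, here is a standalone sketch applying the same two regexes outside the class (the sample response string is made up):

    import re

    resp_str = "We simplify to get $x = 2$, so the final answer is $7$."

    # 1) Speculate with the last LaTeX formula (content between $ / \( / \[ delimiters)
    matches = re.findall(
        r"(?:\$|\\\(|\\\[)([^\$]+)(?:\$|\\\)|\\\])", resp_str, re.DOTALL
    )
    print(matches[-1] if matches else None)  # -> 7

    # 2) Failing that, speculate with the last number, after stripping commas
    matches = re.findall(r"-?\d*\.?\d+", resp_str.replace(",", ""))
    print(matches[-1] if matches else None)  # -> 7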

nbs/index.ipynb

Lines changed: 2 additions & 0 deletions
@@ -350,6 +350,8 @@
     "\n",
     "To reproduce other inference settings, just refer to the paper and modify the `--model_name_or_path` and `--gen_save_path` arguments accordingly.\n",
     "\n",
+    "- We observed that Llama-3-8B(-Base) sometimes decodes EoS immediately. Try using `--ignore_eos` as a workaround.\n",
+    "\n",
     "For other general inference settings, please modify the command or directly modify the [script](https://github.com/hkust-nlp/dart-math/blob/main/pipeline/gen.py).\n",
     "\n",
     "- To test **base** models, please add the corresponding **ID** to `BASE_MODEL_IDS` from [dart_math.utils](https://github.com/hkust-nlp/dart-math/blob/main/dart_math/utils.py).\n",

pipeline/gen.ipynb

Lines changed: 15 additions & 2 deletions
@@ -113,6 +113,13 @@
     ")\n",
     "\n",
     "parser.add_argument(\n",
+    "    \"--revision\",\n",
+    "    type=str,\n",
+    "    default=None,\n",
+    "    help=\"Model revision.\",\n",
+    ")\n",
+    "\n",
+    "parser.add_argument(\n",
     "    \"--dtype\",\n",
     "    type=str,\n",
     "    default=\"bfloat16\",\n",
@@ -148,6 +155,12 @@
     "    help=\"Maximum number of new tokens.\",\n",
     ")\n",
     "parser.add_argument(\n",
+    "    \"--ignore_eos\",\n",
+    "    action=\"store_true\",\n",
+    "    default=False,\n",
+    "    help=\"Ignore the EOS token in generation. Llama-3-8B(-Base) sometimes decodes EoS immediately; try this flag if you encounter the issue.\",\n",
+    ")\n",
+    "parser.add_argument(\n",
     "    \"--n_shots\",\n",
     "    type=int,\n",
     "    default=-1,\n",
@@ -342,7 +355,7 @@
     "    temperature=args.temperature,\n",
     "    top_p=args.top_p,\n",
     "    max_tokens=args.max_new_toks,\n",
-    "    ignore_eos=True,  # Llama-3-8B(-Base) tends to decode EoS immediately\n",
+    "    ignore_eos=args.ignore_eos,\n",
     "    skip_special_tokens=True,\n",
     "    seed=args.inf_seed,\n",
     ")"
@@ -421,7 +434,7 @@
 "source": [
     "llm = LLM(\n",
     "    model=args.model_name_or_path,\n",
-    "    tokenizer=args.model_name_or_path,\n",
+    "    revision=args.revision,\n",
     "    tensor_parallel_size=torch.cuda.device_count(),\n",
     "    dtype=args.dtype,\n",
     "    seed=args.inf_seed,\n",

pipeline/gen.py

Lines changed: 15 additions & 2 deletions
@@ -60,6 +60,13 @@
     help="HF-style model name or path.",
 )
 
+parser.add_argument(
+    "--revision",
+    type=str,
+    default=None,
+    help="Model revision.",
+)
+
 parser.add_argument(
     "--dtype",
     type=str,
@@ -95,6 +102,12 @@
     default=2048,
     help="Maximum number of new tokens.",
 )
+parser.add_argument(
+    "--ignore_eos",
+    action="store_true",
+    default=False,
+    help="Ignore the EOS token in generation. Llama-3-8B(-Base) sometimes decodes EoS immediately; try this flag if you encounter the issue.",
+)
 parser.add_argument(
     "--n_shots",
     type=int,
@@ -231,7 +244,7 @@
     temperature=args.temperature,
     top_p=args.top_p,
     max_tokens=args.max_new_toks,
-    ignore_eos=True,  # Llama-3-8B(-Base) tends to decode EoS immediately
+    ignore_eos=args.ignore_eos,
     skip_special_tokens=True,
     seed=args.inf_seed,
 )
@@ -244,7 +257,7 @@
 
 llm = LLM(
     model=args.model_name_or_path,
-    tokenizer=args.model_name_or_path,
+    revision=args.revision,
     tensor_parallel_size=torch.cuda.device_count(),
     dtype=args.dtype,
     seed=args.inf_seed,
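
Taken together, the two new CLI arguments reach vLLM roughly as below (a condensed sketch of the script's setup; the model ID and sampling values are placeholders):

    from vllm import LLM, SamplingParams

    # ignore_eos is now user-controlled via --ignore_eos instead of hard-coded
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=2048,
        ignore_eos=False,
        skip_special_tokens=True,
    )

    # --revision pins a specific model revision; None uses the default branch
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B",
        revision=None,
        dtype="bfloat16",
    )
    outputs = llm.generate(["1 + 1 ="], sampling_params)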
