Skip to content

Commit e419f40

Browse files
committed
Add prompt for Llama-3-Instruct
1 parent 6d06e74 commit e419f40

File tree

3 files changed

+43
-13
lines changed

3 files changed

+43
-13
lines changed

dart_math/utils.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,28 @@
132132
# {resp}
133133
delim="\n\n",
134134
),
135+
"llama3-math": dict( # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
136+
id="llama3-math",
137+
sys_prompt=(
138+
"<|begin_of_text|>"
139+
# + "<|start_header_id|>system<|end_header_id|>\n\n"
140+
# + "You are a helpful assistant."
141+
# + "<|eot_id|>"
142+
),
143+
query_prompt="<|start_header_id|>" + "user" + "<|end_header_id|>" + "\n\n",
144+
# {query}
145+
prompt_after_query="<|eot_id|>",
146+
resp_prompt="<|start_header_id|>" + "assistant" + "<|end_header_id|>" + "\n\n",
147+
prompt_before_resp="",
148+
# {resp}
149+
delim="<|eot_id|>" + "\n",
150+
model_ids=[
151+
"meta-llama--Meta-Llama-3-8B-Instruct",
152+
"meta-llama--Meta-Llama-3-70B-Instruct",
153+
"meta-llama--Meta-Llama-3.1-8B-Instruct",
154+
"meta-llama--Meta-Llama-3.1-70B-Instruct",
155+
],
156+
),
135157
}
136158

137159

@@ -184,7 +206,13 @@ def __init__(
184206
def load_from_id_or_path(prompt_template: str = "alpaca") -> "PromptTemplate":
185207
"""Load prompt template from ID or file path."""
186208
if prompt_template in PROMPT_TEMPLATE_ID2DICT: # ID
187-
return PromptTemplate(**PROMPT_TEMPLATE_ID2DICT[prompt_template])
209+
return PromptTemplate(
210+
**{
211+
k: v
212+
for k, v in PROMPT_TEMPLATE_ID2DICT[prompt_template].items()
213+
if k != "model_ids"
214+
}
215+
)
188216
elif isinstance(prompt_template, str) and os.path.exists(prompt_template):
189217
# File path
190218
stem = os.path.splitext(os.path.basename(prompt_template))[0]
@@ -216,15 +244,15 @@ def make_full_prompt(self, query: str, eg_qas: list[tuple[str, str]] = []) -> st
216244
@staticmethod
217245
def get_prompt_template_from_prompt_type_and_model(
218246
prompt_type: str,
219-
model_name_or_path: str,
247+
model_dirname: str,
220248
) -> "PromptTemplate":
221249
"""Get the prompt template suitable for the model.
222250
223251
Parameters
224252
----------
225253
prompt_type : str
226254
Prompt type, like "cot" or "tool".
227-
model_name_or_path : str
255+
model_dirname : str
228256
Model directory name — the HF model ID with "/" replaced by "--" (e.g. "meta-llama--Meta-Llama-3-8B-Instruct").
229257
230258
Returns
@@ -234,29 +262,31 @@ def get_prompt_template_from_prompt_type_and_model(
234262
"""
235263
prompt_template = None
236264
if prompt_type == "cot":
237-
if model_name_or_path in BASE_MODEL_IDS + MATH_SHEPHERD_MODEL_IDS:
265+
if model_dirname in BASE_MODEL_IDS + MATH_SHEPHERD_MODEL_IDS:
238266
prompt_template = "qa"
239-
elif model_name_or_path.startswith("dart-math"):
267+
elif model_dirname.startswith("dart-math"):
240268
prompt_template = "alpaca"
241-
elif model_name_or_path in DEEPSEEK_INSTR_MODEL_IDS:
269+
elif model_dirname in DEEPSEEK_INSTR_MODEL_IDS:
242270
prompt_template = "deepseekmath"
243-
elif model_name_or_path.startswith("Xwin-LM/Xwin-Math"):
271+
elif model_dirname.startswith("Xwin-LM/Xwin-Math"):
244272
prompt_template = "xwinmath"
245-
elif model_name_or_path.startswith("TIGER-Lab--MAmmoTH2"):
273+
elif model_dirname.startswith("TIGER-Lab--MAmmoTH2"):
246274
prompt_template = "mammoth2-cot"
275+
elif model_dirname in PROMPT_TEMPLATE_ID2DICT["llama3-math"]["model_ids"]:
276+
prompt_template = "llama3-math"
247277
else: # default
248278
prompt_template = "alpaca"
249279
elif prompt_type == "tool":
250-
if model_name_or_path in DEEPSEEK_INSTR_MODEL_IDS:
280+
if model_dirname in DEEPSEEK_INSTR_MODEL_IDS:
251281
prompt_template = "deepseekmath-tool"
252282

253283
if prompt_template is None:
254284
raise ValueError(
255-
f"Unknown prompt type {prompt_type} for model {model_name_or_path}."
285+
f"Unknown prompt type {prompt_type} for model {model_dirname}."
256286
)
257287

258288
prompt_template = PromptTemplate.load_from_id_or_path(prompt_template)
259-
if "MMIQC" in model_name_or_path:
289+
if "MMIQC" in model_dirname:
260290
prompt_template.prompt_before_resp = (
261291
'Please solve the following problem and put your answer at the end with "The answer is: ".'
262292
+ " "

pipeline/gen.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@
303303
"\n",
304304
"prompt_template = (\n",
305305
" PromptTemplate.get_prompt_template_from_prompt_type_and_model(\n",
306-
" prompt_type=args.prompt_template, model_name_or_path=model_dirname\n",
306+
" prompt_type=args.prompt_template, model_dirname=model_dirname\n",
307307
" )\n",
308308
" if args.prompt_template in [\"cot\", \"tool\"]\n",
309309
" else PromptTemplate.load_from_id_or_path(args.prompt_template)\n",

pipeline/gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@
208208

209209
prompt_template = (
210210
PromptTemplate.get_prompt_template_from_prompt_type_and_model(
211-
prompt_type=args.prompt_template, model_name_or_path=model_dirname
211+
prompt_type=args.prompt_template, model_dirname=model_dirname
212212
)
213213
if args.prompt_template in ["cot", "tool"]
214214
else PromptTemplate.load_from_id_or_path(args.prompt_template)

0 commit comments

Comments
 (0)