         # {resp}
         delim="\n\n",
     ),
+    "llama3-math": dict(  # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+        id="llama3-math",
+        sys_prompt=(
+            "<|begin_of_text|>"
+            # + "<|start_header_id|>system<|end_header_id|>\n\n"
+            # + "You are a helpful assistant."
+            # + "<|eot_id|>"
+        ),
+        query_prompt="<|start_header_id|>" + "user" + "<|end_header_id|>" + "\n\n",
+        # {query}
+        prompt_after_query="<|eot_id|>",
+        resp_prompt="<|start_header_id|>" + "assistant" + "<|end_header_id|>" + "\n\n",
+        prompt_before_resp="",
+        # {resp}
+        delim="<|eot_id|>" + "\n",
+        model_ids=[
+            "meta-llama--Meta-Llama-3-8B-Instruct",
+            "meta-llama--Meta-Llama-3-70B-Instruct",
+            "meta-llama--Meta-Llama-3.1-8B-Instruct",
+            "meta-llama--Meta-Llama-3.1-70B-Instruct",
+        ],
+    ),
 }
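For orientation, a minimal sketch (not part of the commit) of the zero-shot prompt string the new "llama3-math" template should render, assuming the fields are concatenated in the order sys_prompt, query_prompt, {query}, prompt_after_query, resp_prompt, prompt_before_resp:

# Hedged sketch: builds the expected Llama 3 chat-format prompt by hand.
# The literals mirror the dict above; the assembly order is an assumption.
query = "What is 1 + 1?"
expected_prompt = (
    "<|begin_of_text|>"                                          # sys_prompt
    + "<|start_header_id|>user<|end_header_id|>" + "\n\n"        # query_prompt
    + query
    + "<|eot_id|>"                                               # prompt_after_query
    + "<|start_header_id|>assistant<|end_header_id|>" + "\n\n"   # resp_prompt
    + ""                                                         # prompt_before_resp
)
print(expected_prompt)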
@@ -184,7 +206,13 @@ def __init__(
     def load_from_id_or_path(prompt_template: str = "alpaca") -> "PromptTemplate":
         """Load prompt template from ID or file path."""
         if prompt_template in PROMPT_TEMPLATE_ID2DICT:  # ID
-            return PromptTemplate(**PROMPT_TEMPLATE_ID2DICT[prompt_template])
+            return PromptTemplate(
+                **{
+                    k: v
+                    for k, v in PROMPT_TEMPLATE_ID2DICT[prompt_template].items()
+                    if k != "model_ids"
+                }
+            )
         elif isinstance(prompt_template, str) and os.path.exists(prompt_template):
             # File path
             stem = os.path.splitext(os.path.basename(prompt_template))[0]
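The new dict comprehension presumably drops the "model_ids" key because the PromptTemplate constructor has no such parameter; the key only exists to drive the model dispatch further down. A self-contained sketch of that filtering, using a made-up template dict:

# Hedged sketch with a hypothetical template dict; only the filtering idiom
# matches the diff, the values are illustrative.
template_dict = {
    "id": "llama3-math",
    "sys_prompt": "<|begin_of_text|>",
    "model_ids": ["meta-llama--Meta-Llama-3-8B-Instruct"],
}
ctor_kwargs = {k: v for k, v in template_dict.items() if k != "model_ids"}
assert "model_ids" not in ctor_kwargs  # safe to unpack as **kwargs into the constructor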
@@ -216,15 +244,15 @@ def make_full_prompt(self, query: str, eg_qas: list[tuple[str, str]] = []) -> st
     @staticmethod
     def get_prompt_template_from_prompt_type_and_model(
         prompt_type: str,
-        model_name_or_path: str,
+        model_dirname: str,
     ) -> "PromptTemplate":
         """Get the prompt template suitable for the model.

         Parameters
         ----------
         prompt_type : str
             Prompt type, like "cot" or "tool".
-        model_name_or_path : str
+        model_dirname : str
             HF ID or path to the model.

         Returns
@@ -234,29 +262,31 @@ def get_prompt_template_from_prompt_type_and_model(
         """
         prompt_template = None
         if prompt_type == "cot":
-            if model_name_or_path in BASE_MODEL_IDS + MATH_SHEPHERD_MODEL_IDS:
+            if model_dirname in BASE_MODEL_IDS + MATH_SHEPHERD_MODEL_IDS:
                 prompt_template = "qa"
-            elif model_name_or_path.startswith("dart-math"):
+            elif model_dirname.startswith("dart-math"):
                 prompt_template = "alpaca"
-            elif model_name_or_path in DEEPSEEK_INSTR_MODEL_IDS:
+            elif model_dirname in DEEPSEEK_INSTR_MODEL_IDS:
                 prompt_template = "deepseekmath"
-            elif model_name_or_path.startswith("Xwin-LM/Xwin-Math"):
+            elif model_dirname.startswith("Xwin-LM/Xwin-Math"):
                 prompt_template = "xwinmath"
-            elif model_name_or_path.startswith("TIGER-Lab--MAmmoTH2"):
+            elif model_dirname.startswith("TIGER-Lab--MAmmoTH2"):
                 prompt_template = "mammoth2-cot"
+            elif model_dirname in PROMPT_TEMPLATE_ID2DICT["llama3-math"]["model_ids"]:
+                prompt_template = "llama3-math"
             else:  # default
                 prompt_template = "alpaca"
         elif prompt_type == "tool":
-            if model_name_or_path in DEEPSEEK_INSTR_MODEL_IDS:
+            if model_dirname in DEEPSEEK_INSTR_MODEL_IDS:
                 prompt_template = "deepseekmath-tool"

         if prompt_template is None:
             raise ValueError(
-                f"Unknown prompt type {prompt_type} for model {model_name_or_path}."
+                f"Unknown prompt type {prompt_type} for model {model_dirname}."
             )

         prompt_template = PromptTemplate.load_from_id_or_path(prompt_template)
-        if "MMIQC" in model_name_or_path:
+        if "MMIQC" in model_dirname:
             prompt_template.prompt_before_resp = (
                 'Please solve the following problem and put your answer at the end with "The answer is: ".'
                 + " "