Skip to content

Commit 66e8817

Browse files
authored
Update examples and prompts (#1199)
- Add default system prompt templates for phi2, phi3, phi4, llama2, and llama3 models to improve the user experience and provide more accurate responses. - Improve chat templates for phi2, phi3, phi4, llama2, and llama3 models to enhance the user experience. - Add info about the system prompt.
1 parent 0636ce3 commit 66e8817

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

examples/python/model-chat.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ def main(args):
3434
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
3535
raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
3636
else:
37-
if model.type.startswith("phi"):
37+
if model.type.startswith("phi2") or model.type.startswith("phi3"):
3838
args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
39-
elif model.type.startswith("llama"):
40-
args.chat_template = '<|start_header_id|>user<|end_header_id|>{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
39+
elif model.type.startswith("phi4"):
40+
args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
41+
elif model.type.startswith("llama3"):
42+
args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
43+
elif model.type.startswith("llama2"):
44+
args.chat_template = '<s>{input}'
4145
else:
4246
raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template")
4347

@@ -51,7 +55,17 @@ def main(args):
5155
if args.verbose: print("Generator created")
5256

5357
# Set system prompt
54-
system_prompt = args.system_prompt
58+
if model.type.startswith('phi2') or model.type.startswith('phi3'):
59+
system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
60+
elif model.type.startswith('phi4'):
61+
system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
62+
elif model.type.startswith("llama3"):
63+
system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
64+
elif model.type.startswith("llama2"):
65+
system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
66+
else:
67+
system_prompt = args.system_prompt
68+
5569
system_tokens = tokenizer.encode(system_prompt)
5670
generator.append_tokens(system_tokens)
5771
system_prompt_length = len(system_tokens)
@@ -103,7 +117,7 @@ def main(args):
103117
run_time = time.time() - first_token_timestamp
104118
print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
105119

106-
# Rewind the generator to the system prompt
120+
# Rewind the generator to the system prompt, this will erase all the memory of the model.
107121
if args.rewind:
108122
generator.rewind_to(system_prompt_length)
109123

examples/python/model-qa.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,34 @@ def main(args):
3030
if args.chat_template:
3131
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
3232
raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
33+
else:
34+
if model.type.startswith("phi2") or model.type.startswith("phi3"):
35+
args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
36+
elif model.type.startswith("phi4"):
37+
args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
38+
elif model.type.startswith("llama3"):
39+
args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
40+
elif model.type.startswith("llama2"):
41+
args.chat_template = '<s>{input}'
42+
else:
43+
raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template")
3344

3445
params = og.GeneratorParams(model)
3546
params.set_search_options(**search_options)
3647
generator = og.Generator(model, params)
3748

3849
# Set system prompt
39-
system_prompt = args.system_prompt
50+
if model.type.startswith('phi2') or model.type.startswith('phi3'):
51+
system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
52+
elif model.type.startswith('phi4'):
53+
system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
54+
elif model.type.startswith("llama3"):
55+
system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
56+
elif model.type.startswith("llama2"):
57+
system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
58+
else:
59+
system_prompt = args.system_prompt
60+
4061
system_tokens = tokenizer.encode(system_prompt)
4162
generator.append_tokens(system_tokens)
4263
system_prompt_length = len(system_tokens)
@@ -89,7 +110,7 @@ def main(args):
89110
run_time = time.time() - first_token_timestamp
90111
print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
91112

92-
# Rewind the generator to the system prompt
113+
# Rewind the generator to the system prompt, this will erase all the memory of the model.
93114
if args.rewind:
94115
generator.rewind_to(system_prompt_length)
95116

@@ -108,6 +129,6 @@ def main(args):
108129
parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
109130
parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}')
110131
parser.add_argument('-s', '--system_prompt', type=str, default='You are a helpful AI assistant.', help='System prompt to use for the prompt.')
111-
parser.add_argument('-r', '--rewind', action='store_true', default=False, help='Rewind to the system prompt after each generation. Defaults to false')
132+
parser.add_argument('-r', '--rewind', action='store_true', default=True, help='Rewind to the system prompt after each generation. Defaults to true')
112133
args = parser.parse_args()
113134
main(args)

0 commit comments

Comments (0)