Commit 0636ce3

Add pre-generated prompts option for benchmark (#1091)
During benchmarking, we wanted to use pre-generated prompts that were prepared in advance for more consistent benchmark results. This is handy because, in our tests, we wanted to focus only on token generation and sampling on SLMs.
1 parent 4db5f2a commit 0636ce3

File tree

2 files changed: +32 -0 lines changed


benchmark/python/benchmark_e2e.py (+21 lines)
@@ -82,6 +82,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
         generator.generate_next_token()
     return tokenizer.decode(generator.get_sequence(0))
 
+# Use the requested prompt length to look up a pre-defined prompt from prompts.json
+def get_prompt_by_length(prompt_length):
+    json_path = "prompts.json"
+    with open(json_path) as prompts_file:
+        content = prompts_file.read()
+    data = json.loads(content)
+    return data[f"{prompt_length}"]
+
 def get_target_pip_package_version(target_pip_package_name_list):
     # get package name and version
     import pkg_resources
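For reference, get_prompt_by_length expects a prompts.json file that maps each supported prompt length (as a string key) to a prepared prompt. The contents of the committed prompts.json are not shown in this diff, so the following is only a minimal sketch with made-up prompts that writes such a file and looks an entry up the same way the helper does:

import json

# Hypothetical prompts.json contents: keys are prompt lengths (as strings),
# values are the pre-generated prompts used for benchmarking.
sample_prompts = {
    "16": "Explain in one sentence why the sky appears blue during the day.",
    "64": "Write a short story about a robot learning to paint. " * 4,
}

with open("prompts.json", "w") as prompts_file:
    json.dump(sample_prompts, prompts_file, indent=2)

# Look up a prompt by its requested length, mirroring get_prompt_by_length above
prompt_length = 16
with open("prompts.json") as prompts_file:
    data = json.loads(prompts_file.read())
print(data[f"{prompt_length}"])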
@@ -231,6 +239,18 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         # use random tokens instead of generating a prompt using the model and then tokenizing it
         tokens = np.random.randint(100, size=(batch_size, prompt_length))
         prompt = [tokenizer.decode(tokens[0])] * batch_size
+    elif args.use_prompt_set:
+        prompt = [get_prompt_by_length(prompt_length)] * batch_size
+        tokens = tokenizer.encode_batch(prompt)
+
+        if tokens.shape[1] > max_length:
+            # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+            tokens = tokens[:, :max_length]
+        elif tokens.shape[1] < max_length:
+            # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length) by left-padding with the first token column
+            tokens_first_col = tokens[:, :1]
+            for _ in range(max_length - tokens.shape[1]):
+                tokens = np.hstack((tokens_first_col, tokens))
     else:
         prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
         tokens = tokenizer.encode_batch(prompt)
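To see what the shorten/lengthen step above does: tokens is treated as a 2-D array of shape (batch_size, tokenized_length), truncated on the right when it is too long and left-padded with copies of its first token column when it is too short. A small self-contained numpy sketch, with made-up token ids and shapes that are illustrative only:

import numpy as np

max_length = 6
tokens = np.array([[101, 7592, 2088],
                   [101, 2129, 2024]])   # shape (batch_size=2, tokenized_length=3)

if tokens.shape[1] > max_length:
    tokens = tokens[:, :max_length]      # truncate to the requested length
elif tokens.shape[1] < max_length:
    tokens_first_col = tokens[:, :1]     # first token of every row, kept 2-D
    for _ in range(max_length - tokens.shape[1]):
        tokens = np.hstack((tokens_first_col, tokens))   # prepend one column per missing token

print(tokens.shape)   # (2, 6)
print(tokens)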
@@ -416,6 +436,7 @@ def str2strlist(value):
     parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
     parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
     parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
+    parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
     args = parser.parse_args()
 
     # check max_lengths
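The new --use_prompt_set flag is a plain store_true argparse switch alongside --use_random_tokens. A minimal, standalone sketch (not the benchmark's real argument set) showing how the flag parses and which prompt path it selects:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')

# Simulate: python benchmark_e2e.py --use_prompt_set
args = parser.parse_args(['--use_prompt_set'])

if args.use_random_tokens:
    print("random-token path")
elif args.use_prompt_set:
    print("pre-generated prompt path")   # prompts come from prompts.json
else:
    print("model-generated prompt path")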
