@@ -82,6 +82,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
         generator.generate_next_token()
     return tokenizer.decode(generator.get_sequence(0))
 
+# Use prompt length to get pre-defined prompt
+def get_prompt_by_length(prompt_length):
+    json_path = "prompts.json"
+    with open(json_path) as prompts_file:
+        content = prompts_file.read()
+    data = json.loads(content)
+    return data[f"{prompt_length}"]
+
 def get_target_pip_package_version(target_pip_package_name_list):
     # get package name and version
     import pkg_resources
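Note: `prompts.json` itself is not part of this diff. Judging from the lookup `data[f"{prompt_length}"]`, the file presumably maps stringified prompt lengths to prompt text. A minimal sketch with hypothetical contents (the keys and strings below are illustrative, not from the PR):

    import json

    # Hypothetical prompts.json layout assumed by get_prompt_by_length:
    # keys are prompt lengths as strings, values are the prompt text.
    sample = {
        "16": "Explain the difference between a process and a thread.",
        "64": "Summarize the history of the transistor in a few sentences.",
    }
    with open("prompts.json", "w") as f:
        json.dump(sample, f, indent=2)

    # Mirrors the lookup performed by get_prompt_by_length
    with open("prompts.json") as f:
        data = json.loads(f.read())
    print(data[f"{16}"])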
@@ -231,6 +239,18 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         # use random tokens instead of generating a prompt using the model and then tokenizing it
         tokens = np.random.randint(100, size=(batch_size, prompt_length))
         prompt = [tokenizer.decode(tokens[0])] * batch_size
+    elif args.use_prompt_set:
+        prompt = [get_prompt_by_length(prompt_length)] * batch_size
+        tokens = tokenizer.encode_batch(prompt)
+
+        if tokens.shape[1] > max_length:
+            # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+            tokens = tokens[:, :max_length]
+        elif tokens.shape[1] < max_length:
+            # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+            tokens_first_col = tokens[:, 0].reshape(-1, 1)
+            for _ in range(max_length - tokens.shape[1]):
+                tokens = np.hstack((tokens_first_col, tokens))
     else:
         prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
         tokens = tokenizer.encode_batch(prompt)
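The pad/truncate branch above normalizes the tokenized batch to exactly `max_length` columns: truncation slices off trailing tokens, and lengthening left-pads each row by repeating its first token. A standalone sanity check of the padding logic (a sketch with made-up token values, not code from the PR):

    import numpy as np

    tokens = np.array([[5, 6, 7],
                       [8, 9, 10]])  # (batch_size=2, tokenized_length=3)
    max_length = 5

    # Left-pad each row with its first token until reaching max_length
    tokens_first_col = tokens[:, 0].reshape(-1, 1)
    for _ in range(max_length - tokens.shape[1]):
        tokens = np.hstack((tokens_first_col, tokens))

    print(tokens)
    # [[ 5  5  5  6  7]
    #  [ 8  8  8  9 10]]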
@@ -416,6 +436,7 @@ def str2strlist(value):
     parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
     parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
     parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
+    parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
     args = parser.parse_args()
 
     # check max_lengths
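With the new flag in place, the benchmark can draw prompts from the canned set rather than generating them with the model. A hypothetical invocation (the script name and any remaining arguments are assumed, not shown in this diff):

    python benchmark_e2e.py --use_prompt_set  # plus the usual model/prompt-length arguments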