31
31
@dataclass
32
32
class GeneratorArgs :
33
33
prompt : str = "torchchat is pronounced torch-chat and is so cool because"
34
+ encoded_prompt : Optional [torch .Tensor ] = None
34
35
chat_mode : bool = False
35
36
gui_mode : bool = False
36
37
num_samples : int = 1
@@ -45,6 +46,7 @@ class GeneratorArgs:
45
46
def from_args (cls , args ): # -> GeneratorArgs:
46
47
return cls (
47
48
prompt = args .prompt ,
49
+ encoded_prompt = None ,
48
50
chat_mode = args .chat ,
49
51
gui_mode = args .gui ,
50
52
num_samples = args .num_samples ,
@@ -305,7 +307,7 @@ def generate(
305
307
return seq , generate_stats
306
308
307
309
308
- def encode_tokens (tokenizer , string , bos = True , device = "cuda " ):
310
+ def encode_tokens (tokenizer , string , bos = True , device = "cpu " ):
309
311
tokens = tokenizer .encode (string )
310
312
if bos :
311
313
tokens = [tokenizer .bos_id ()] + tokens
@@ -317,13 +319,9 @@ def _main(
317
319
speculative_builder_args : BuilderArgs ,
318
320
tokenizer_args : TokenizerArgs ,
319
321
generator_args : GeneratorArgs ,
320
- max_new_tokens : int = 100 ,
321
- top_k : int = 200 ,
322
- temperature : float = 0.8 ,
323
322
compile : bool = True ,
324
323
compile_prefill : bool = False ,
325
324
profile : Optional [Path ] = None ,
326
- speculate_k : int = 5 ,
327
325
quantize = None ,
328
326
) -> None :
329
327
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
@@ -436,6 +434,7 @@ def callback(x):
436
434
t0 = time .perf_counter ()
437
435
import contextlib
438
436
437
+ generator_args .encoded_prompt = encoded
439
438
if (i != generator_args .num_samples - 1 or not profile ) or (use_tp and rank != 0 ):
440
439
prof = contextlib .nullcontext ()
441
440
else :
@@ -445,13 +444,13 @@ def callback(x):
445
444
y , metrics = generate (
446
445
model ,
447
446
encoded ,
448
- max_new_tokens ,
447
+ generator_args . max_new_tokens ,
449
448
draft_model = draft_model ,
450
- speculate_k = speculate_k ,
449
+ speculate_k = generator_args . speculate_k ,
451
450
chat_mode = generator_args .chat_mode ,
452
451
callback = callback ,
453
- temperature = temperature ,
454
- top_k = top_k ,
452
+ temperature = generator_args . temperature ,
453
+ top_k = generator_args . top_k ,
455
454
)
456
455
aggregate_metrics ["accept_counts" ].append (metrics ["accept_counts" ])
457
456
if i == - 1 :
@@ -502,13 +501,9 @@ def main(args):
502
501
speculative_builder_args ,
503
502
tokenizer_args ,
504
503
generator_args ,
505
- args .max_new_tokens ,
506
- args .top_k ,
507
- args .temperature ,
508
504
args .compile ,
509
505
args .compile_prefill ,
510
506
args .profile ,
511
- args .speculate_k ,
512
507
args .quantize ,
513
508
)
514
509
0 commit comments