@@ -31,22 +31,22 @@
 @dataclass
 class GeneratorArgs:
     prompt: str = "torchchat is pronounced torch-chat and is so cool because"
-    chat: bool = (False,)
-    gui: bool = (False,)
-    num_samples: int = (1,)
-    max_new_tokens: int = (200,)
-    top_k: int = (200,)
-    temperature: int = (0,)  # deterministic argmax
-    compile: bool = (False,)
-    compile_prefill: bool = (False,)
-    speculate_k: int = (5,)
+    chat_mode: bool = False
+    gui_mode: bool = False
+    num_samples: int = 1
+    max_new_tokens: int = 200
+    top_k: int = 200
+    temperature: int = 0  # deterministic argmax
+    compile: bool = False
+    compile_prefill: bool = False
+    speculate_k: int = 5

     @classmethod
     def from_args(cls, args):  # -> GeneratorArgs:
         return cls(
             prompt=args.prompt,
-            chat=args.chat,
-            gui=args.gui,
+            chat_mode=args.chat,
+            gui_mode=args.gui,
             num_samples=args.num_samples,
             max_new_tokens=args.max_new_tokens,
             top_k=args.top_k,
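
The defaults in the removed lines are worth flagging: each trailing comma turned the value into a one-element tuple, so `chat: bool = (False,)` defaulted to the tuple `(False,)`, which is truthy, not to `False`. A minimal standalone sketch of the pitfall (illustrative class names, not torchchat code):

from dataclasses import dataclass

@dataclass
class BadArgs:
    chat: bool = (False,)  # trailing comma: the default is the tuple (False,)

@dataclass
class FixedArgs:
    chat_mode: bool = False  # scalar default, as in the corrected GeneratorArgs

print(BadArgs().chat)         # (False,)
print(bool(BadArgs().chat))   # True -- any non-empty tuple is truthy
print(FixedArgs().chat_mode)  # False

So any code that tested the old defaults directly would have taken the truthy branch even when the flag was never passed.
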
@@ -316,9 +316,7 @@ def _main(
     builder_args: BuilderArgs,
     speculative_builder_args: BuilderArgs,
     tokenizer_args: TokenizerArgs,
-    prompt: str = "Hello, my name is",
-    chat_mode: bool = False,
-    num_samples: int = 5,
+    generator_args: GeneratorArgs,
     max_new_tokens: int = 100,
     top_k: int = 200,
     temperature: float = 0.8,
@@ -365,7 +363,9 @@ def _main(
     else:
         draft_model = None

-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device)
+    encoded = encode_tokens(
+        tokenizer, generator_args.prompt, bos=True, device=builder_args.device
+    )
     print(encoded)
     prompt_length = encoded.size(0)

@@ -404,17 +404,17 @@ def _main(
     }
     start = -1 if compile else 0

-    for i in range(start, num_samples):
+    for i in range(start, generator_args.num_samples):
         device_sync(device=builder_args.device)
-        if i >= 0 and chat_mode:
+        if i >= 0 and generator_args.chat_mode:
             prompt = input("What is your prompt? ")
             if is_chat:
                 prompt = f"{B_INST} {prompt.strip()} {E_INST}"
             encoded = encode_tokens(
                 tokenizer, prompt, bos=True, device=builder_args.device
             )

-        if chat_mode and i >= 0:
+        if generator_args.chat_mode and i >= 0:
             buffer = []
             period_id = tokenizer.encode(".")[0]
             done_generating = False
@@ -436,7 +436,7 @@ def callback(x):
         t0 = time.perf_counter()
         import contextlib

-        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
+        if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
             torch.profiler._utils._init_for_cuda_graphs()
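
The `prof` selection above is a small idiom worth naming: `contextlib.nullcontext()` is a no-op context manager, so the same `with prof:` body runs whether or not this iteration is profiled. A self-contained sketch of the pattern, with a timing stub standing in for `torch.profiler.profile` (the stub name and workload are illustrative):

import contextlib
import time

@contextlib.contextmanager
def fake_profiler(label):
    # Stand-in for torch.profiler.profile(); just times the block.
    t0 = time.perf_counter()
    yield
    print(f"{label}: {time.perf_counter() - t0:.4f}s")

num_samples = 3
for i in range(num_samples):
    # Mirror the diff: profile only the last sample, no-op otherwise.
    if i != num_samples - 1:
        prof = contextlib.nullcontext()
    else:
        prof = fake_profiler("last sample")
    with prof:
        sum(x * x for x in range(100_000))  # placeholder for the generate() call
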
@@ -448,7 +448,7 @@ def callback(x):
                 max_new_tokens,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
-                chat_mode=chat_mode,
+                chat_mode=generator_args.chat_mode,
                 callback=callback,
                 temperature=temperature,
                 top_k=top_k,
@@ -465,7 +465,7 @@ def callback(x):
         device_sync(device=builder_args.device)
         t = time.perf_counter() - t0

-        if not chat_mode:
+        if not generator_args.chat_mode:
             print(tokenizer.decode(y.tolist()))
         else:
             print()
@@ -495,13 +495,13 @@ def main(args):
     builder_args = BuilderArgs.from_args(args)
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
+    generator_args = GeneratorArgs.from_args(args)
+
     _main(
         builder_args,
         speculative_builder_args,
         tokenizer_args,
-        args.prompt,
-        args.chat,
-        args.num_samples,
+        generator_args,
         args.max_new_tokens,
         args.top_k,
         args.temperature,
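
Taken together, `main` now builds one `GeneratorArgs` from the parsed namespace and passes it to `_main` in place of the loose `prompt`, `chat`, and `num_samples` arguments. A sketch of the hand-off, assuming `from_args` maps the remaining fields the same way as the part visible in the first hunk (the `Namespace` here is a hypothetical stand-in for torchchat's real CLI parser):

from argparse import Namespace

# Hypothetical parsed-args object; the real one comes from torchchat's argparse setup.
args = Namespace(
    prompt="hello",
    chat=True,
    gui=False,
    num_samples=3,
    max_new_tokens=200,
    top_k=200,
    temperature=0,
    compile=False,
    compile_prefill=False,
    speculate_k=5,
)

generator_args = GeneratorArgs.from_args(args)
print(generator_args.chat_mode, generator_args.num_samples)  # True 3

Passing one object at the call site also keeps the `_main` signature stable as generation options grow.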