Skip to content

Commit 55aa360

Browse files
authored
Use generator args to group all arguments to generator (#231)
* prompt
* chat_mode, num_samples
1 parent f8236e4 commit 55aa360

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

generate.py

+24 -24
Original file line number | Diff line number | Diff line change
@@ -31,22 +31,22 @@
3131
@dataclass
3232
class GeneratorArgs:
3333
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
34-
chat: bool = (False,)
35-
gui: bool = (False,)
36-
num_samples: int = (1,)
37-
max_new_tokens: int = (200,)
38-
top_k: int = (200,)
39-
temperature: int = (0,) # deterministic argmax
40-
compile: bool = (False,)
41-
compile_prefill: bool = (False,)
42-
speculate_k: int = (5,)
34+
chat_mode: bool = False
35+
gui_mode: bool = False
36+
num_samples: int = 1
37+
max_new_tokens: int = 200
38+
top_k: int = 200
39+
temperature: int = 0 # deterministic argmax
40+
compile: bool = False
41+
compile_prefill: bool = False
42+
speculate_k: int = 5
4343

4444
@classmethod
4545
def from_args(cls, args): # -> GeneratorArgs:
4646
return cls(
4747
prompt=args.prompt,
48-
chat=args.chat,
49-
gui=args.gui,
48+
chat_mode=args.chat,
49+
gui_mode=args.gui,
5050
num_samples=args.num_samples,
5151
max_new_tokens=args.max_new_tokens,
5252
top_k=args.top_k,
@@ -316,9 +316,7 @@ def _main(
316316
builder_args: BuilderArgs,
317317
speculative_builder_args: BuilderArgs,
318318
tokenizer_args: TokenizerArgs,
319-
prompt: str = "Hello, my name is",
320-
chat_mode: bool = False,
321-
num_samples: int = 5,
319+
generator_args: GeneratorArgs,
322320
max_new_tokens: int = 100,
323321
top_k: int = 200,
324322
temperature: float = 0.8,
@@ -365,7 +363,9 @@ def _main(
365363
else:
366364
draft_model = None
367365

368-
encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device)
366+
encoded = encode_tokens(
367+
tokenizer, generator_args.prompt, bos=True, device=builder_args.device
368+
)
369369
print(encoded)
370370
prompt_length = encoded.size(0)
371371

@@ -404,17 +404,17 @@ def _main(
404404
}
405405
start = -1 if compile else 0
406406

407-
for i in range(start, num_samples):
407+
for i in range(start, generator_args.num_samples):
408408
device_sync(device=builder_args.device)
409-
if i >= 0 and chat_mode:
409+
if i >= 0 and generator_args.chat_mode:
410410
prompt = input("What is your prompt? ")
411411
if is_chat:
412412
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
413413
encoded = encode_tokens(
414414
tokenizer, prompt, bos=True, device=builder_args.device
415415
)
416416

417-
if chat_mode and i >= 0:
417+
if generator_args.chat_mode and i >= 0:
418418
buffer = []
419419
period_id = tokenizer.encode(".")[0]
420420
done_generating = False
@@ -436,7 +436,7 @@ def callback(x):
436436
t0 = time.perf_counter()
437437
import contextlib
438438

439-
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
439+
if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0):
440440
prof = contextlib.nullcontext()
441441
else:
442442
torch.profiler._utils._init_for_cuda_graphs()
@@ -448,7 +448,7 @@ def callback(x):
448448
max_new_tokens,
449449
draft_model=draft_model,
450450
speculate_k=speculate_k,
451-
chat_mode=chat_mode,
451+
chat_mode=generator_args.chat_mode,
452452
callback=callback,
453453
temperature=temperature,
454454
top_k=top_k,
@@ -465,7 +465,7 @@ def callback(x):
465465
device_sync(device=builder_args.device)
466466
t = time.perf_counter() - t0
467467

468-
if not chat_mode:
468+
if not generator_args.chat_mode:
469469
print(tokenizer.decode(y.tolist()))
470470
else:
471471
print()
@@ -495,13 +495,13 @@ def main(args):
495495
builder_args = BuilderArgs.from_args(args)
496496
speculative_builder_args = BuilderArgs.from_speculative_args(args)
497497
tokenizer_args = TokenizerArgs.from_args(args)
498+
generator_args = GeneratorArgs.from_args(args)
499+
498500
_main(
499501
builder_args,
500502
speculative_builder_args,
501503
tokenizer_args,
502-
args.prompt,
503-
args.chat,
504-
args.num_samples,
504+
generator_args,
505505
args.max_new_tokens,
506506
args.top_k,
507507
args.temperature,

0 commit comments

Comments
 (0)