Skip to content

Commit 1ea7739

Browse files
authored
Move more generator args to use dataclass (#233)
* prompt * chat_mode, num_samples * move more args * more gen args * update * args * undo some changes * typos
1 parent 55aa360 commit 1ea7739

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

generate.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
@dataclass
3232
class GeneratorArgs:
3333
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
34+
encoded_prompt: Optional[torch.Tensor] = None
3435
chat_mode: bool = False
3536
gui_mode: bool = False
3637
num_samples: int = 1
@@ -45,6 +46,7 @@ class GeneratorArgs:
4546
def from_args(cls, args): # -> GeneratorArgs:
4647
return cls(
4748
prompt=args.prompt,
49+
encoded_prompt=None,
4850
chat_mode=args.chat,
4951
gui_mode=args.gui,
5052
num_samples=args.num_samples,
@@ -305,7 +307,7 @@ def generate(
305307
return seq, generate_stats
306308

307309

308-
def encode_tokens(tokenizer, string, bos=True, device="cuda"):
310+
def encode_tokens(tokenizer, string, bos=True, device="cpu"):
309311
tokens = tokenizer.encode(string)
310312
if bos:
311313
tokens = [tokenizer.bos_id()] + tokens
@@ -317,13 +319,9 @@ def _main(
317319
speculative_builder_args: BuilderArgs,
318320
tokenizer_args: TokenizerArgs,
319321
generator_args: GeneratorArgs,
320-
max_new_tokens: int = 100,
321-
top_k: int = 200,
322-
temperature: float = 0.8,
323322
compile: bool = True,
324323
compile_prefill: bool = False,
325324
profile: Optional[Path] = None,
326-
speculate_k: int = 5,
327325
quantize=None,
328326
) -> None:
329327
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
@@ -436,6 +434,7 @@ def callback(x):
436434
t0 = time.perf_counter()
437435
import contextlib
438436

437+
generator_args.encoded_prompt = encoded
439438
if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0):
440439
prof = contextlib.nullcontext()
441440
else:
@@ -445,13 +444,13 @@ def callback(x):
445444
y, metrics = generate(
446445
model,
447446
encoded,
448-
max_new_tokens,
447+
generator_args.max_new_tokens,
449448
draft_model=draft_model,
450-
speculate_k=speculate_k,
449+
speculate_k=generator_args.speculate_k,
451450
chat_mode=generator_args.chat_mode,
452451
callback=callback,
453-
temperature=temperature,
454-
top_k=top_k,
452+
temperature=generator_args.temperature,
453+
top_k=generator_args.top_k,
455454
)
456455
aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
457456
if i == -1:
@@ -502,13 +501,9 @@ def main(args):
502501
speculative_builder_args,
503502
tokenizer_args,
504503
generator_args,
505-
args.max_new_tokens,
506-
args.top_k,
507-
args.temperature,
508504
args.compile,
509505
args.compile_prefill,
510506
args.profile,
511-
args.speculate_k,
512507
args.quantize,
513508
)
514509

0 commit comments

Comments
 (0)