Skip to content

Commit a0b141a

Browse files
rdraskicTT, kpaigwar, sraizada-tt, and djordje-tt
committed
Fix non-uniform seeding (#35906)
Co-authored-by: kpaigwar <kpaigwar@tenstorrent.com>
Co-authored-by: Stuti Raizada <sraizada@tenstorrent.com>
Co-authored-by: Djordje Ivanovic <divanovic@tenstorrent.com>
1 parent 6599616 commit a0b141a

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

models/tt_transformers/demo/simple_text_demo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@ def test_demo_text(
10011001
temperature=sampling_params["temperature"],
10021002
top_k=sampling_params["top_k"],
10031003
top_p=sampling_params["top_p"],
1004+
seed=sampling_params["seed"] if "seed" in sampling_params else None,
10041005
frequency_penalty=sampling_params["frequency_penalty"]
10051006
if "frequency_penalty" in sampling_params
10061007
else 0.0,
@@ -1110,6 +1111,7 @@ def test_demo_text(
11101111
enable_trace=enable_trace,
11111112
page_table=page_table,
11121113
kv_cache=tt_kv_cache,
1114+
reset_batch=(iteration == 0),
11131115
sampling_params=device_sampling_params,
11141116
prompt_tokens=input_tokens_prefill_pt,
11151117
output_tokens=out_tok,

models/tt_transformers/tt/generator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,13 @@ def _apply_prefill_sampling_state(
7676
*,
7777
sampling_params: SamplingParams,
7878
prompt_tokens: torch.Tensor | None,
79+
empty_slots: list[int],
7980
):
80-
sampling_module = getattr(model_instance, "sampling_prefill", None)
81+
sampling_module = getattr(model_instance, "sampling", None)
8182
assert sampling_module is not None, "Sampling module not found in model for sampling on device."
8283
sampling_module.reset_sampling_params(sampling_params)
83-
sampling_module.reset_seed(sampling_params.seed)
84+
sampling_module.seed_manager.reset_seed(sampling_params.seed, empty_slots)
85+
sampling_module.seed_manager.get_new_values(empty_slots, replicate_seeds=True)
8486
if prompt_tokens is not None:
8587
sampling_module.reset_prompt_tokens(prompt_tokens)
8688
sampling_module.reset_output_state()
@@ -422,6 +424,7 @@ def prefill_forward_text(
422424
self.model[model_id],
423425
sampling_params=per_request_params,
424426
prompt_tokens=prefill_ids[:, :seq_len].repeat(32, 1),
427+
empty_slots=[user_id % 32],
425428
)
426429

427430
if enable_trace_current_prompt:
@@ -471,7 +474,7 @@ def prefill_forward_text(
471474
logits = self.model[model_id].process_logits_after_prefill_trace(logits, last_token_idx)
472475

473476
if sampling_enabled:
474-
tt_tokens, tt_log_probs = self.model[model_id].sampling_prefill.sample(
477+
tt_tokens, tt_log_probs = self.model[model_id].sampling.sample(
475478
logits,
476479
enable_trace=False,
477480
)
@@ -732,8 +735,8 @@ def decode_forward_text(
732735
sampling_module = getattr(self.model[i], "sampling", None)
733736
assert sampling_module is not None, "Sampling module not found in model for sampling on device."
734737
sampling_module.reset_sampling_params(formatted_params)
738+
sampling_module.seed_manager.get_new_values()
735739
if reset_batch:
736-
sampling_module.reset_seed(formatted_params.seed)
737740
sampling_module.reset_prompt_tokens(prompt_chunks[i])
738741
sampling_module.reset_output_state(output_chunks[i])
739742

models/tt_transformers/tt/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,6 @@ def __init__(
137137
sampling_splits = self.args.num_devices if list(self.mesh_device.shape) != [1, 1] else 2
138138
self._supports_on_device_sampling = self.args.vocab_size // sampling_splits <= 64 * 1024
139139
if self._supports_on_device_sampling:
140-
self.sampling_prefill = SamplingGenerator(
141-
args=args,
142-
mesh_device=mesh_device,
143-
tt_ccl=self.tt_ccl,
144-
)
145140
self.sampling = SamplingGenerator(
146141
args=args,
147142
mesh_device=mesh_device,

0 commit comments

Comments (0)