
Commit a9b09e0

rdraskicTT, kpaigwar, sraizada-tt, and djordje-tt authored
Fix non-uniform seeding (#35906)
Co-authored-by: kpaigwar <kpaigwar@tenstorrent.com>
Co-authored-by: Stuti Raizada <sraizada@tenstorrent.com>
Co-authored-by: Djordje Ivanovic <divanovic@tenstorrent.com>
1 parent 558a196 commit a9b09e0

File tree

10 files changed: +155 -106 lines changed


models/common/sampling/generator.py

Lines changed: 26 additions & 19 deletions
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import random
+import secrets
 from dataclasses import dataclass, fields, replace
 from typing import List, Optional

@@ -60,6 +61,7 @@ def __init__(
         self._penalties_active = False

         self._trace_states: dict[_TraceKey, dict] = {}
+        self.seed_manager = SeedManager(self.tt_sampling)

     def _new_trace_state(self):
         return {"id": None, "input": None, "output": None, "kwargs": {}}
@@ -263,24 +265,6 @@ def sample(
             self.tt_penalties.update_output_tokens(tt_out)
         return tt_out

-    def reset_seed(self, seed):
-        for i, s in enumerate(seed):
-            if s is None:
-                # set to random seed to have variability while using tensor manual_seed
-                seed[i] = random.randint(0, 1000000)
-        seed = torch.tensor(seed)
-        user_ids = torch.arange(seed.shape[0])
-
-        user_ids_tt = ttnn.from_torch(
-            user_ids, device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT
-        )
-        seeds_tt = ttnn.from_torch(seed, device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
-
-        # reset seed for each user_id
-        ttnn.manual_seed(seeds=seeds_tt, user_ids=user_ids_tt, sub_core_grids=self.sub_core_grids)
-        seeds_tt.deallocate()
-        user_ids_tt.deallocate()
-

 def clamp(value, min_value, max_value):
     if value < min_value:
@@ -307,7 +291,7 @@ def format_sampling_params(sampling_params, max_batch_size):
         "presence_penalty": 0.0,
         "frequency_penalty": 0.0,
         "repetition_penalty": 1.0,
-        "seed": None,
+        "seed": random.randint(0, 1000000),  # set to random seed to have variability while using tensor manual_seed
     }
     target_len = max_batch_size
     assert target_len == 32, "Sampling only support batch_size=32"
@@ -369,3 +353,26 @@ def format_sampling_params(sampling_params, max_batch_size):
         if sampling_params.top_k[i] < 1:
             sampling_params.top_k[i] = 32  # k<1 means no restriction so set it to max k (32)
     return sampling_params
+
+
+class SeedManager:
+    def __init__(self, tt_sampling):
+        self.seeds = [secrets.randbits(64) for _ in range(32)]
+        self.rngs = [random.Random(seed) for seed in self.seeds]
+        self.tt_sampling = tt_sampling
+
+    def reset_seed(self, seeds, user_ids):
+        for i, user in enumerate(user_ids):
+            self.rngs[user].seed(seeds[i])
+            self.seeds[user] = seeds[i]
+
+    def get_new_values(self, empty_slots=range(32), replicate_seeds=False):
+        # get new seeds for each user in empty_slots otherwise 0
+        new_seeds = [rng.randint(0, 1000000) if i in empty_slots else 0 for i, rng in enumerate(self.rngs)]
+
+        if replicate_seeds:
+            assert len(empty_slots) == 1, "Cannot replicate seeds if empty_slots is not length 1"
+            new_seeds = 32 * [new_seeds[empty_slots[0]]]
+        # send new seeds to sampling module
+        new_seed_tt = ttnn.from_torch(torch.tensor(new_seeds), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
+        ttnn.copy_host_to_device_tensor(new_seed_tt, self.tt_sampling.seeds_tt_tensor)
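A minimal host-only sketch of how this SeedManager is driven from the generator. TTSamplingStub and the plain-list seeds_tt_tensor below are stand-ins invented for illustration; the real code copies a host tensor into the persistent device buffer:

import random
import secrets

class TTSamplingStub:
    # Stand-in for the on-device sampling module; seeds_tt_tensor is a plain list here.
    def __init__(self):
        self.seeds_tt_tensor = [0] * 32

class SeedManagerSketch:
    def __init__(self, tt_sampling):
        # one independent RNG stream per user slot, seeded from the OS entropy pool
        self.rngs = [random.Random(secrets.randbits(64)) for _ in range(32)]
        self.tt_sampling = tt_sampling

    def reset_seed(self, seeds, user_ids):
        # a request-supplied seed re-seeds only that user's stream
        for i, user in enumerate(user_ids):
            self.rngs[user].seed(seeds[i])

    def get_new_values(self, empty_slots=range(32)):
        # draw a fresh per-user seed each step; untouched slots get 0
        new_seeds = [rng.randint(0, 1000000) if i in empty_slots else 0 for i, rng in enumerate(self.rngs)]
        self.tt_sampling.seeds_tt_tensor = new_seeds  # real code: ttnn.copy_host_to_device_tensor

mgr = SeedManagerSketch(TTSamplingStub())
mgr.reset_seed(seeds=[1234], user_ids=[5])  # prefill for the user placed in slot 5
mgr.get_new_values(empty_slots=[5])         # push a fresh seed for that slot only
mgr.get_new_values()                        # decode path: refresh every slot each step

Re-seeding the device RNG from 32 independent host streams each step is what keeps per-user sampling varied across steps while remaining reproducible when a seed is supplied.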

models/common/sampling/tt_sampling.py

Lines changed: 20 additions & 1 deletion
@@ -146,6 +146,21 @@ def __init__(
         self.tt_log_probs = None
         self.log_probs_calculator = LogProbsCalculator(self.mesh_device, self.sub_core_grids, self.tt_ccl)

+        self.seeds_tt_tensor = ttnn.as_tensor(
+            torch.tensor(list(torch.arange(32)), dtype=torch.uint32),
+            dtype=ttnn.uint32,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
+            device=self.mesh_device,
+            memory_config=ttnn.DRAM_MEMORY_CONFIG,
+        )
+        self.user_ids_tt_tensor = ttnn.as_tensor(
+            torch.tensor(list(torch.arange(32)), dtype=torch.uint32),
+            dtype=ttnn.uint32,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
+            device=self.mesh_device,
+            memory_config=ttnn.DRAM_MEMORY_CONFIG,
+        )
+
     def _create_indices_tensors(self):
         """Create the indices tensors needed for distributed top-k operations."""
         # Create indices tensor for device offsets
@@ -395,7 +410,11 @@ def forward(
             topk_global_indices_interleaved, use_multicore=True, sub_core_grids=self.sub_core_grids
         )
         ttnn.deallocate(topk_global_indices_interleaved)
-
+        ttnn.manual_seed(
+            seeds=self.seeds_tt_tensor,
+            user_ids=self.user_ids_tt_tensor,
+            sub_core_grids=self.sub_core_grids,
+        )
        # Perform the actual sampling with top-k, top-p, and temperature
        tt_out_tok = ttnn.sampling(
            topk_values_gathered_bf16_interleaved,
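The device now keeps a persistent 32-entry seeds_tt_tensor that the host-side SeedManager overwrites before each step, and forward() re-applies ttnn.manual_seed from it right before ttnn.sampling. The per-slot streams this relies on can be illustrated with plain CPython random (illustration only, not the device RNG):

import random

# SeedManager keeps one random.Random per user slot, so slots draw independent per-step seeds.
rng_a, rng_b = random.Random(111), random.Random(222)
seeds_a = [rng_a.randint(0, 1000000) for _ in range(3)]
seeds_b = [rng_b.randint(0, 1000000) for _ in range(3)]
assert seeds_a != seeds_b  # different streams produce different seed sequences

# Re-seeding a slot (SeedManager.reset_seed) makes its sequence reproducible.
rng_a.seed(111)
assert [rng_a.randint(0, 1000000) for _ in range(3)] == seeds_a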

models/demos/llama3_70b_galaxy/tt/generator.py

Lines changed: 74 additions & 52 deletions
@@ -7,6 +7,7 @@
 from loguru import logger
 from typing import List
 from collections import defaultdict
+from dataclasses import fields, replace

 from llama_models.llama3.api.datatypes import (
     InterleavedTextMedia,
@@ -64,6 +65,18 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non
         self.trace_id_prefill = defaultdict(lambda: None)
         self.trace_inputs_prefill = defaultdict(lambda: None)
         self.trace_output_prefill = defaultdict(lambda: None)
+        # Create persistent buffer for accumulated logits (used for on-device sampling)
+        self.tt_logits_accumulated = [
+            ttnn.from_torch(
+                torch.zeros(1, 1, 1, self.model.args.padded_vocab_size // self.model_args.cluster_shape[0]),
+                mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+                dtype=ttnn.bfloat8_b,
+                device=self.mesh_device,
+                layout=ttnn.TILE_LAYOUT,
+            )
+            for _ in range(self.model_args.max_batch_size)
+        ]
+        self.tt_logits_accumulated_batched = []  # Temporary list for batched prefill
         self.prev_page_table = None
         self.prefill_traces_warmup = False
         self.trace_ids_decode = defaultdict(lambda: None)  # {return_logits: {device_id: trace_id}}
@@ -146,7 +159,7 @@ def prefill_forward_text(
                 kv_cache,
                 prompt_lens,
                 enable_trace,
-                sampling_params,
+                None,
                 empty_slots,
                 tt_out_logits_all_users,
             )
@@ -176,9 +189,7 @@
        if (
            batch >= 16
            and len(set(prefill_seq_lens)) == 1
-           and prefill_seq_lens[0] < 4 * 1024
-           and tt_out_logits_all_users is None
-           and not return_logits
+           and prefill_seq_lens[0] == 128
        ):
            use_batched_prefill = True

@@ -192,7 +203,6 @@
        do_device_sampling = (not return_logits) and (not save_logits_to_host)

        # Accumulate sharded logits (same format as decode, before all-gather) for on-device sampling.
-       tt_logits_accumulated = [] if do_device_sampling else None

        all_users = [0] if use_batched_prefill else empty_slots

@@ -255,6 +265,10 @@
                prefill_kwargs["tt_out_logits_saved"] = tt_out_logits_saved

            if enable_trace:
+               # For batched prefill, reset to empty list since we use extend()
+               # For non-batched prefill with device sampling, use persistent buffer from __init__
+               if use_batched_prefill and do_device_sampling:
+                   self.tt_logits_accumulated_batched = []
                tt_tok = self._easy_trace_prefill(**prefill_kwargs, prefill_seq_len=prefill_seq_len)
            else:
                tt_tok = self.prefill_forward_single_user_text(**prefill_kwargs)
@@ -278,49 +292,64 @@ def prefill_forward_text(
            tt_logits_list = self.model.process_output_prefill_logits(tt_tok, last_token_idx=last_token_idx)
            if use_batched_prefill:
                # Batched prefill: logits list has 32 entries ordered by slot position
-               tt_logits_accumulated.extend(tt_logits_list)
+               self.tt_logits_accumulated_batched.extend(tt_logits_list)
            else:
-               # Single user: logits list has 1 entry
-               tt_logits_accumulated.append(ttnn.clone(tt_logits_list[0]))
-
+               # Single user: logits list has 1 entry, copy into persistent buffer
+               ttnn.copy(input_a=tt_logits_list[0], input_b=self.tt_logits_accumulated[user_id])
        # On-device sampling for prefill
-       if do_device_sampling and tt_logits_accumulated:
+       if do_device_sampling:
            padded_batch = 32

-           # lm_head output is a list [logits_tensor], extract the tensor
-           logits_tensors = [logits[0] if isinstance(logits, list) else logits for logits in tt_logits_accumulated]
-
-           if use_batched_prefill:
-               # Batched prefill: logits already have 32 entries (one per slot), ordered by slot.
-               tt_logits_batch = ttnn.concat(logits_tensors, dim=2)
-           else:
-               # Non-batched prefill: we have `batch` logits, need to pad to 32.
-               # Logits are in batch order (same as tokens and sampling_params).
-               if len(logits_tensors) > 1:
-                   tt_logits_batch = ttnn.concat(logits_tensors, dim=2)
-               else:
-                   tt_logits_batch = logits_tensors[0]
-
-           # Pad to 32 users for sampling
-           num_users = len(logits_tensors)
-           if num_users < padded_batch:
-               padding_needed = padded_batch - num_users
-               padding_tensors = [logits_tensors[-1]] * padding_needed
-               tt_logits_batch = ttnn.concat([tt_logits_batch] + padding_tensors, dim=2)
+           # Use batched list for batched prefill, persistent buffer for non-batched
+           logits_source = self.tt_logits_accumulated_batched if use_batched_prefill else self.tt_logits_accumulated

+           # Concatenate along slot dimension -> [1, 1, 1[32], vocab_shard]
+           tt_logits_batch = ttnn.concat(logits_source, dim=2)
            # Sample using the sampling module
            # Logits are in sharded format (before all-gather), same as decode
            # sampling_params are already padded to 32 by format_sampling_params
            self.model.switch_mode("decode")

            # Setting sampling module up after switch to decode mode
            sampling_params = format_sampling_params(sampling_params, self.model_args.max_batch_size)
+
+           # Reorder sampling params so values sit in their slot positions (except seed).
+           def _scatter_params_to_slots(params, slots):
+               max_batch = self.model_args.max_batch_size
+
+               def _scatter_list(values):
+                   if not isinstance(values, list):
+                       return values
+                   values = list(values)
+                   # Broadcast single-entry lists to match user count
+                   if len(values) == 1 and len(slots) > 1:
+                       values = values * len(slots)
+                   user_vals = values[: len(slots)]
+                   filler = values[len(slots)] if len(values) > len(slots) else values[-1]
+                   scattered = [filler for _ in range(max_batch)]
+                   for val, slot_idx in zip(user_vals, slots):
+                       scattered[slot_idx] = val
+                   return scattered
+
+               updates = {}
+               for f in fields(SamplingParams):
+                   if f.name == "seed":
+                       # Seeds stay in original order; no reordering to slot indices.
+                       updates[f.name] = getattr(params, f.name)
+                       continue
+                   updates[f.name] = _scatter_list(getattr(params, f.name))
+               return replace(params, **updates)
+
+           sampling_params = _scatter_params_to_slots(sampling_params, empty_slots)
+           # print("sampling_params_scattered", sampling_params, "empty_slots", empty_slots)
            sampling_module = self.model.sampling
+
            sampling_module.reset_sampling_params(sampling_params)
            # if prompt_tokens is not None: # Guard for warmup
            sampling_module.reset_prompt_tokens(prefill_ids)
            sampling_module.reset_output_state()
-           sampling_module.reset_seed(sampling_params.seed)
+           sampling_module.seed_manager.reset_seed(sampling_params.seed, empty_slots)
+           sampling_module.seed_manager.get_new_values(empty_slots)
            tt_sampled, tt_log_probs = sampling_module.sample(
                tt_logits_batch,
                tt_out_tok=None,
@@ -333,14 +362,9 @@ def prefill_forward_text(

            sampled_tokens = ttnn.to_torch(ttnn.get_device_tensors(tt_sampled)[0])

-           if use_batched_prefill:
-               # Batched prefill: sampled_tokens has 32 entries ordered by slot.
-               sampled_tensor = sampled_tokens[0, 0, 0, :]  # Shape: [32]
-               output_toks = sampled_tensor[empty_slots].reshape(batch, 1, 1)
-           else:
-               # Non-batched prefill: first `batch` entries are our results in batch order.
-               for i in range(batch):
-                   output_toks[i] = sampled_tokens[0, 0, 0, i].item()
+           # sampled_tokens has 32 entries ordered by slot.
+           sampled_tensor = sampled_tokens[0, 0, 0, :]  # Shape: [32]
+           output_toks = sampled_tensor[empty_slots].reshape(batch, 1, 1)

        if return_logits:
            # TODO: the current solution runs the argmax even if we are returning logits
@@ -523,6 +547,7 @@ def decode_forward_text(
            "is_cur_pos_sharded": is_cur_pos_sharded,
            "is_page_table_sharded": is_page_table_sharded,
        }
+       self.model.sampling.seed_manager.get_new_values()
        if reset_inputs and sampling_params is not None:
            # If we have new inputs, we need to set up the sampling module again
            sampling_params = format_sampling_params(sampling_params, self.model_args.max_batch_size)
@@ -532,7 +557,6 @@
            if reset_batch:
                sampling_module.reset_prompt_tokens(prompt_tokens)
                sampling_module.reset_output_state(output_tokens)
-               sampling_module.reset_seed(sampling_params.seed)

        if tt_out_logits_saved is not None:
            decode_kwargs["tt_out_logits_saved"] = tt_out_logits_saved
@@ -834,18 +858,16 @@ def warmup_model_prefill(self, kv_cache, enable_trace, sampling_params) -> None:
        # page_table gets padded properly in prefill_forward_text
        # be sure to pad correctly for non traced sequences in future warmup calls
        page_table = torch.zeros(1, 1, dtype=torch.int32)
-       # in case of multiple sampling parameters, we need to warmup for each one
-       for s in sampling_params:
-           self.warmup_prefill_traces(
-               tokens=None,
-               page_table=page_table,
-               kv_cache=kv_cache,
-               prompt_lens=None,
-               enable_trace=enable_trace,
-               sampling_params=s,
-               empty_slots=None,
-               tt_out_logits_all_users=None,
-           )
+       self.warmup_prefill_traces(
+           tokens=None,
+           page_table=page_table,
+           kv_cache=kv_cache,
+           prompt_lens=None,
+           enable_trace=enable_trace,
+           sampling_params=None,
+           empty_slots=None,
+           tt_out_logits_all_users=None,
+       )

    ## Destructor (used to delete ttnn trace if exists)
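A standalone trace of the scatter step on plain Python lists, showing how per-user values end up at their slot positions. The helper name and example values below are illustrative, not part of the diff; seeds deliberately skip this step because SeedManager.reset_seed receives the slot indices separately:

def scatter_to_slots(values, slots, max_batch=32):
    # same idea as _scatter_list above: each user's value lands at its slot index,
    # remaining slots are filled with a filler value
    values = list(values)
    if len(values) == 1 and len(slots) > 1:
        values = values * len(slots)
    user_vals = values[: len(slots)]
    filler = values[len(slots)] if len(values) > len(slots) else values[-1]
    scattered = [filler] * max_batch
    for val, slot_idx in zip(user_vals, slots):
        scattered[slot_idx] = val
    return scattered

# Two active users sitting in slots 1 and 5, with per-user temperatures 0.7 and 1.0.
print(scatter_to_slots([0.7, 1.0], slots=[1, 5], max_batch=8))
# -> [1.0, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]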

models/demos/llama3_70b_galaxy/tt/llama_attention.py

Lines changed: 7 additions & 4 deletions
@@ -753,7 +753,9 @@ def forward_prefill(
             is_causal=True,
             scale=self.scale,
             compute_kernel_config=self.compute_kernel_config_hifi4,
-            program_config=self.model_config["SDPA_PROGCFG"](seq_len),
+            program_config=self.model_config["SDPA_PROGCFG"](
+                seq_len // batch_size if seq_len // batch_size == 128 else seq_len
+            ),
         )

         # deallocate keys and values
@@ -830,15 +832,16 @@
         ttnn.deallocate(attn_output_11SH)

         # Reduce-scatter
-        output_11SH = self.tt_ccl.line_all_reduce(
+        output_11SH_reduced = self.tt_ccl.line_all_reduce(
             output_11SH,
             cluster_axis=0,
             num_links=3,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
-            buffer_key="WO",
+            buffer_key="WO_AG" if seq_len <= 4096 else "WO",
         )
+        output_11SH.deallocate()

-        return output_11SH
+        return output_11SH_reduced

    def forward(
        self,
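The SDPA program-config guard appears to support the batched-prefill path: when a batch's 128-token sequences are packed along the sequence dimension, seq_len // batch_size equals 128 and the kernel is configured for the per-user chunk; otherwise the full length is used. A small arithmetic sketch under that assumption:

def sdpa_progcfg_len(seq_len, batch_size):
    # mirrors the condition added to forward_prefill
    per_user = seq_len // batch_size
    return per_user if per_user == 128 else seq_len

print(sdpa_progcfg_len(32 * 128, 32))  # 128  -> batched prefill, per-user chunk length
print(sdpa_progcfg_len(8192, 1))       # 8192 -> ordinary single-user prefill keeps the full length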
