 from loguru import logger
 from typing import List
 from collections import defaultdict
+from dataclasses import fields, replace

 from llama_models.llama3.api.datatypes import (
     InterleavedTextMedia,
@@ -64,6 +65,18 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non
         self.trace_id_prefill = defaultdict(lambda: None)
         self.trace_inputs_prefill = defaultdict(lambda: None)
         self.trace_output_prefill = defaultdict(lambda: None)
+        # Create persistent buffer for accumulated logits (used for on-device sampling)
+        self.tt_logits_accumulated = [
+            ttnn.from_torch(
+                torch.zeros(1, 1, 1, self.model.args.padded_vocab_size // self.model_args.cluster_shape[0]),
+                mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+                dtype=ttnn.bfloat8_b,
+                device=self.mesh_device,
+                layout=ttnn.TILE_LAYOUT,
+            )
+            for _ in range(self.model_args.max_batch_size)
+        ]
+        self.tt_logits_accumulated_batched = []  # Temporary list for batched prefill
         self.prev_page_table = None
         self.prefill_traces_warmup = False
         self.trace_ids_decode = defaultdict(lambda: None)  # {return_logits: {device_id: trace_id}}
@@ -146,7 +159,7 @@ def prefill_forward_text(
                 kv_cache,
                 prompt_lens,
                 enable_trace,
-                sampling_params,
+                None,
                 empty_slots,
                 tt_out_logits_all_users,
             )
@@ -176,9 +189,7 @@ def prefill_forward_text(
         if (
             batch >= 16
             and len(set(prefill_seq_lens)) == 1
-            and prefill_seq_lens[0] < 4 * 1024
-            and tt_out_logits_all_users is None
-            and not return_logits
+            and prefill_seq_lens[0] == 128
         ):
             use_batched_prefill = True
@@ -192,7 +203,6 @@ def prefill_forward_text(
         do_device_sampling = (not return_logits) and (not save_logits_to_host)

         # Accumulate sharded logits (same format as decode, before all-gather) for on-device sampling.
-        tt_logits_accumulated = [] if do_device_sampling else None

         all_users = [0] if use_batched_prefill else empty_slots

@@ -255,6 +265,10 @@ def prefill_forward_text(
                 prefill_kwargs["tt_out_logits_saved"] = tt_out_logits_saved

             if enable_trace:
+                # For batched prefill, reset to empty list since we use extend()
+                # For non-batched prefill with device sampling, use persistent buffer from __init__
+                if use_batched_prefill and do_device_sampling:
+                    self.tt_logits_accumulated_batched = []
                 tt_tok = self._easy_trace_prefill(**prefill_kwargs, prefill_seq_len=prefill_seq_len)
             else:
                 tt_tok = self.prefill_forward_single_user_text(**prefill_kwargs)
@@ -278,49 +292,64 @@ def prefill_forward_text(
                 tt_logits_list = self.model.process_output_prefill_logits(tt_tok, last_token_idx=last_token_idx)
                 if use_batched_prefill:
                     # Batched prefill: logits list has 32 entries ordered by slot position
-                    tt_logits_accumulated.extend(tt_logits_list)
+                    self.tt_logits_accumulated_batched.extend(tt_logits_list)
                 else:
-                    # Single user: logits list has 1 entry
-                    tt_logits_accumulated.append(ttnn.clone(tt_logits_list[0]))
-
+                    # Single user: logits list has 1 entry, copy into persistent buffer
+                    ttnn.copy(input_a=tt_logits_list[0], input_b=self.tt_logits_accumulated[user_id])
         # On-device sampling for prefill
-        if do_device_sampling and tt_logits_accumulated:
+        if do_device_sampling:
             padded_batch = 32

-            # lm_head output is a list [logits_tensor], extract the tensor
-            logits_tensors = [logits[0] if isinstance(logits, list) else logits for logits in tt_logits_accumulated]
-
-            if use_batched_prefill:
-                # Batched prefill: logits already have 32 entries (one per slot), ordered by slot.
-                tt_logits_batch = ttnn.concat(logits_tensors, dim=2)
-            else:
-                # Non-batched prefill: we have `batch` logits, need to pad to 32.
-                # Logits are in batch order (same as tokens and sampling_params).
-                if len(logits_tensors) > 1:
-                    tt_logits_batch = ttnn.concat(logits_tensors, dim=2)
-                else:
-                    tt_logits_batch = logits_tensors[0]
-
-                # Pad to 32 users for sampling
-                num_users = len(logits_tensors)
-                if num_users < padded_batch:
-                    padding_needed = padded_batch - num_users
-                    padding_tensors = [logits_tensors[-1]] * padding_needed
-                    tt_logits_batch = ttnn.concat([tt_logits_batch] + padding_tensors, dim=2)
+            # Use batched list for batched prefill, persistent buffer for non-batched
+            logits_source = self.tt_logits_accumulated_batched if use_batched_prefill else self.tt_logits_accumulated

+            # Concatenate along slot dimension -> [1, 1, 1[32], vocab_shard]
+            tt_logits_batch = ttnn.concat(logits_source, dim=2)
             # Sample using the sampling module
             # Logits are in sharded format (before all-gather), same as decode
             # sampling_params are already padded to 32 by format_sampling_params
             self.model.switch_mode("decode")

             # Setting sampling module up after switch to decode mode
             sampling_params = format_sampling_params(sampling_params, self.model_args.max_batch_size)
+
+            # Reorder sampling params so values sit in their slot positions (except seed).
+            def _scatter_params_to_slots(params, slots):
+                max_batch = self.model_args.max_batch_size
+
+                def _scatter_list(values):
+                    if not isinstance(values, list):
+                        return values
+                    values = list(values)
+                    # Broadcast single-entry lists to match user count
+                    if len(values) == 1 and len(slots) > 1:
+                        values = values * len(slots)
+                    user_vals = values[: len(slots)]
+                    filler = values[len(slots)] if len(values) > len(slots) else values[-1]
+                    scattered = [filler for _ in range(max_batch)]
+                    for val, slot_idx in zip(user_vals, slots):
+                        scattered[slot_idx] = val
+                    return scattered
+
+                updates = {}
+                for f in fields(SamplingParams):
+                    if f.name == "seed":
+                        # Seeds stay in original order; no reordering to slot indices.
+                        updates[f.name] = getattr(params, f.name)
+                        continue
+                    updates[f.name] = _scatter_list(getattr(params, f.name))
+                return replace(params, **updates)
+
+            sampling_params = _scatter_params_to_slots(sampling_params, empty_slots)
+            # print("sampling_params_scattered", sampling_params, "empty_slots", empty_slots)
             sampling_module = self.model.sampling
+
             sampling_module.reset_sampling_params(sampling_params)
             # if prompt_tokens is not None:  # Guard for warmup
             sampling_module.reset_prompt_tokens(prefill_ids)
             sampling_module.reset_output_state()
-            sampling_module.reset_seed(sampling_params.seed)
+            sampling_module.seed_manager.reset_seed(sampling_params.seed, empty_slots)
+            sampling_module.seed_manager.get_new_values(empty_slots)
             tt_sampled, tt_log_probs = sampling_module.sample(
                 tt_logits_batch,
                 tt_out_tok=None,
@@ -333,14 +362,9 @@ def prefill_forward_text(

             sampled_tokens = ttnn.to_torch(ttnn.get_device_tensors(tt_sampled)[0])

-            if use_batched_prefill:
-                # Batched prefill: sampled_tokens has 32 entries ordered by slot.
-                sampled_tensor = sampled_tokens[0, 0, 0, :]  # Shape: [32]
-                output_toks = sampled_tensor[empty_slots].reshape(batch, 1, 1)
-            else:
-                # Non-batched prefill: first `batch` entries are our results in batch order.
-                for i in range(batch):
-                    output_toks[i] = sampled_tokens[0, 0, 0, i].item()
+            # sampled_tokens has 32 entries ordered by slot.
+            sampled_tensor = sampled_tokens[0, 0, 0, :]  # Shape: [32]
+            output_toks = sampled_tensor[empty_slots].reshape(batch, 1, 1)

         if return_logits:
             # TODO: the current solution runs the argmax even if we are returning logits
@@ -523,6 +547,7 @@ def decode_forward_text(
             "is_cur_pos_sharded": is_cur_pos_sharded,
             "is_page_table_sharded": is_page_table_sharded,
         }
+        self.model.sampling.seed_manager.get_new_values()
         if reset_inputs and sampling_params is not None:
             # If we have new inputs, we need to set up the sampling module again
             sampling_params = format_sampling_params(sampling_params, self.model_args.max_batch_size)
@@ -532,7 +557,6 @@ def decode_forward_text(
             if reset_batch:
                 sampling_module.reset_prompt_tokens(prompt_tokens)
                 sampling_module.reset_output_state(output_tokens)
-                sampling_module.reset_seed(sampling_params.seed)

         if tt_out_logits_saved is not None:
             decode_kwargs["tt_out_logits_saved"] = tt_out_logits_saved
@@ -834,18 +858,16 @@ def warmup_model_prefill(self, kv_cache, enable_trace, sampling_params) -> None:
         # page_table gets padded properly in prefill_forward_text
         # be sure to pad correctly for non traced sequences in future warmup calls
         page_table = torch.zeros(1, 1, dtype=torch.int32)
-        # in case of multiple sampling parameters, we need to warmup for each one
-        for s in sampling_params:
-            self.warmup_prefill_traces(
-                tokens=None,
-                page_table=page_table,
-                kv_cache=kv_cache,
-                prompt_lens=None,
-                enable_trace=enable_trace,
-                sampling_params=s,
-                empty_slots=None,
-                tt_out_logits_all_users=None,
-            )
+        self.warmup_prefill_traces(
+            tokens=None,
+            page_table=page_table,
+            kv_cache=kv_cache,
+            prompt_lens=None,
+            enable_trace=enable_trace,
+            sampling_params=None,
+            empty_slots=None,
+            tt_out_logits_all_users=None,
+        )

     ## Destructor (used to delete ttnn trace if exists)

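For reference, the slot-scattering behavior added to prefill_forward_text can be checked host-side without any device code. Below is a minimal, self-contained sketch of the same idea, assuming a simplified SamplingParams stand-in (the field names here are illustrative; the repo's dataclass has more fields):

    from dataclasses import dataclass, fields, replace
    from typing import List, Union

    @dataclass
    class SamplingParams:
        # Simplified stand-in for illustration only.
        temperature: Union[float, List[float]]
        top_p: Union[float, List[float]]
        seed: Union[int, List[int]]

    def scatter_params_to_slots(params: SamplingParams, slots: List[int], max_batch: int) -> SamplingParams:
        # Place each user's value at its slot index; seeds keep their original order.
        def scatter_list(values):
            if not isinstance(values, list):
                return values  # scalar applies uniformly, nothing to scatter
            values = list(values)
            if len(values) == 1 and len(slots) > 1:
                values = values * len(slots)  # broadcast a single entry to all users
            user_vals = values[: len(slots)]
            filler = values[len(slots)] if len(values) > len(slots) else values[-1]
            scattered = [filler] * max_batch
            for val, slot_idx in zip(user_vals, slots):
                scattered[slot_idx] = val
            return scattered

        updates = {
            f.name: getattr(params, f.name) if f.name == "seed" else scatter_list(getattr(params, f.name))
            for f in fields(SamplingParams)
        }
        return replace(params, **updates)

    # Two users occupy slots 5 and 9 of a 32-wide batch.
    params = SamplingParams(temperature=[0.7, 1.0], top_p=0.9, seed=[1, 2])
    scattered = scatter_params_to_slots(params, slots=[5, 9], max_batch=32)
    assert scattered.temperature[5] == 0.7 and scattered.temperature[9] == 1.0
    assert scattered.top_p == 0.9   # scalars pass through untouched
    assert scattered.seed == [1, 2]  # seeds are not reordered

Sampled tokens come back from the device ordered by slot, so per-user results are recovered with the same slot list, mirroring sampled_tensor[empty_slots] in the diff above.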