 )
 from models.common.sampling.generator import format_sampling_params
 from models.tt_transformers.tt.generator import SamplingParams
+from models.demos.llama3_70b_galaxy.tt.llama_attention import should_use_ring_distributed_sdpa


 def get_padded_prefill_len(seq_len: int) -> int:
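For reference, prefill prompts are padded to a bucketed length before tracing. A minimal sketch of what a helper like get_padded_prefill_len might do; the rounding policy shown here (next power of two with a floor of 128) is an assumption for illustration, not the repository's actual rule:

def padded_prefill_len_sketch(seq_len: int) -> int:
    # Assumed policy: round up to the next power of two, floor of 128,
    # so only a handful of prefill shapes ever need a captured trace.
    padded = 128
    while padded < seq_len:
        padded *= 2
    return padded

# Example: a 1000-token prompt pads to 1024; a 5000-token prompt pads to 8192.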
@@ -97,7 +98,7 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non
         # Split sampling: decode trace captures transformer only, sampling runs separately
         self.enable_split_sampling = True  # Decode trace returns logits, sampling is separate
         self.model.enable_internal_trace = self.enable_split_sampling  # NEVER trace sampling - causes buffer corruption
-        self._disable_prefill_tracing = False  # Whether to disable prefill traces
+        self._disable_prefill_tracing = True  # Whether to disable prefill traces
         self._disable_decode_tracing = False  # Whether to disable decode traces
         self._trace_debug_seq = 0  # Monotonic seq for hang-debug logs (trace capture/replay ordering)

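Flipping _disable_prefill_tracing to True forces the prefill path to run untraced while decode tracing stays enabled. A minimal sketch of the gating pattern, assuming the prefill-side check mirrors the decode-side getattr check that appears later in this diff (resolve_enable_trace is an illustrative helper, not part of the file):

def resolve_enable_trace(self, enable_trace: bool, is_decode: bool) -> bool:
    # Decode gate (shown verbatim further down): _disable_decode_tracing overrides the caller's flag.
    if is_decode and getattr(self, "_disable_decode_tracing", False):
        return False
    # Prefill gate (assumed analogous): _disable_prefill_tracing forces untraced prefill.
    if not is_decode and getattr(self, "_disable_prefill_tracing", False):
        return False
    return enable_trace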
@@ -223,7 +224,9 @@ def prefill_forward_text(
         all_users = [0] if use_batched_prefill else empty_slots

         for id, user_id in enumerate(all_users):
-            logger.info(f"Prefilling User {user_id + 1}, use_batched_prefill: {use_batched_prefill}")
+            logger.info(
+                f"Prefilling User {user_id}, use_batched_prefill: {use_batched_prefill}, prompt_lens: {prompt_lens[id]}, prefill_seq_len: {prefill_seq_lens[id]}, num_cached_tokens: {num_cached_tokens_list[id]}"
+            )
             if use_batched_prefill:
                 user_id = empty_slots
                 last_token_idx = [(seq_len - 1) for seq_len in prompt_lens]
@@ -253,13 +256,36 @@ def prefill_forward_text(
             # Extract tokens skipping cached ones
             num_cached_tokens = num_cached_tokens_list[id]
             new_tokens_len = seq_len - num_cached_tokens
+            tail_start = max(0, seq_len - 8)
+            logger.info(
+                "[PREFILL_INPUT_DEBUG] user_id={} seq_len={} last_token_idx={} tail_tokens={}",
+                user_id,
+                seq_len,
+                last_token_idx,
+                tokens[id, tail_start:seq_len].tolist(),
+            )
             prefill_ids = torch.cat(
                 [
                     tokens[id : id + 1, num_cached_tokens:seq_len],  # Skip cached tokens
                     torch.zeros(1, prefill_seq_len - new_tokens_len).long(),
                 ],
                 dim=-1,
             )
+            # Extra debug: sanity-check padding and checksum for non-traced debugging
+            try:
+                pad_len = int(prefill_seq_len - new_tokens_len)
+                checksum = int(prefill_ids.sum().item())
+            except Exception:
+                pad_len = "unavailable"
+                checksum = "unavailable"
+            logger.info(
+                "[PREFILL_IDS_DEBUG] user_id={} new_tokens_len={} prefill_seq_len={} pad_len={} checksum={}",
+                user_id,
+                new_tokens_len,
+                prefill_seq_len,
+                pad_len,
+                checksum,
+            )

             if page_table is not None:
                 # For prefix caching, page_table includes both cached and new blocks
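A self-contained sketch of the input construction in this hunk: with prefix caching, only the uncached tail of the prompt is sent to prefill, padded out to the traced prefill length (tensor and argument names here are illustrative):

import torch

def build_prefill_ids_sketch(tokens: torch.Tensor, id: int, seq_len: int,
                             num_cached_tokens: int, prefill_seq_len: int) -> torch.Tensor:
    # Keep only the tokens that are not already in the KV cache, then pad to the traced shape.
    new_tokens_len = seq_len - num_cached_tokens
    return torch.cat(
        [
            tokens[id : id + 1, num_cached_tokens:seq_len],            # uncached tail of the prompt
            torch.zeros(1, prefill_seq_len - new_tokens_len).long(),   # zero padding to prefill_seq_len
        ],
        dim=-1,
    )

# Example: seq_len=1000, num_cached_tokens=256, prefill_seq_len=1024
# -> 744 real tokens followed by 280 zeros.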
@@ -270,6 +296,23 @@ def prefill_forward_text(
                     user_id,
                     use_batched_prefill,  # Use full seq_len including cached
                 )
+                # Debug: summarize the active user's row (first few entries) for the prefill page table
+                try:
+                    row = page_table_user[user_id if not use_batched_prefill else 0, :].to(torch.int32)
+                    first = row[:16].tolist()
+                    neg = int((row < 0).sum().item())
+                    zero = int((row == 0).sum().item())
+                    logger.info(
+                        "[PREFILL_PAGETABLE_DEBUG] user_id={} prefill_len={} blocks={} first16={} neg={} zero={}",
+                        user_id if not use_batched_prefill else "batched",
+                        num_cached_tokens + prefill_seq_len,
+                        int(page_table_user.shape[1]),
+                        first,
+                        neg,
+                        zero,
+                    )
+                except Exception as e:
+                    logger.info("[PREFILL_PAGETABLE_DEBUG] failed to summarize page_table_user: {}", e)
                 # remove the first user from the page table
                 page_table = page_table[1:, :]

@@ -283,6 +326,20 @@ def prefill_forward_text(
283326 "num_cached_tokens" : num_cached_tokens_list [id ] if not use_batched_prefill else 0 ,
284327 }
285328
329+ # Remember prompt/prefill metadata for later decode debugging (non-traced focus)
330+ try :
331+ if not hasattr (self , "_debug_last_prefill_meta" ):
332+ self ._debug_last_prefill_meta = {}
333+ if not use_batched_prefill and isinstance (user_id , int ):
334+ self ._debug_last_prefill_meta [user_id ] = {
335+ "seq_len" : int (seq_len ),
336+ "last_token_idx" : int (last_token_idx ),
337+ "prefill_seq_len" : int (prefill_seq_len ),
338+ "num_cached_tokens" : int (num_cached_tokens ),
339+ }
340+ except Exception :
341+ pass
342+
286343 # If PCC check enabled or return_logits is True (we save output logits)
287344 if tt_out_logits_all_users is not None or return_logits :
288345 tt_out_logits_saved = torch .zeros (1 , self .model .args .padded_vocab_size )
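A worked example of the stashed metadata for a single (non-batched) user, using hypothetical values and the power-of-two padding rule assumed in the earlier sketch:

# Hypothetical contents for a 1000-token prompt with 256 tokens already served from the prefix cache:
_debug_last_prefill_meta = {
    0: {
        "seq_len": 1000,           # full prompt length
        "last_token_idx": 999,     # seq_len - 1, absolute index of the last prompt token
        "prefill_seq_len": 1024,   # padded chunk length actually run through prefill (assumed padding rule)
        "num_cached_tokens": 256,  # tokens skipped because they were already in the KV cache
    },
}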
@@ -477,7 +534,8 @@ def _easy_trace_prefill(
         Trace key is (prefill_seq_len, batch_size) only; page_table/chunk_page_table
         are padded to fixed shapes so one trace serves any num_cached_blocks.
         """
-        if isinstance(last_token_idx, (list, tuple)):
+        is_last_token_list = isinstance(last_token_idx, (list, tuple))
+        if is_last_token_list and batch_size == 1:
             last_token_idx = last_token_idx[user_id] if isinstance(user_id, int) else last_token_idx[0]

         # Extract single user's page table row for batch_size=1
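A minimal sketch of the normalization introduced above: for single-user prefill a list-valued last_token_idx collapses to the active user's scalar, while for batched prefill the list is kept intact for the per-user handling later in the function (the helper name is illustrative):

def normalize_last_token_idx_sketch(last_token_idx, user_id, batch_size):
    # Mirrors the change above: only collapse to a scalar when batch_size == 1.
    is_list = isinstance(last_token_idx, (list, tuple))
    if is_list and batch_size == 1:
        return last_token_idx[user_id] if isinstance(user_id, int) else last_token_idx[0]
    return last_token_idx

# normalize_last_token_idx_sketch([999], user_id=0, batch_size=1)       -> 999
# normalize_last_token_idx_sketch([999, 511], user_id=None, batch_size=2) -> [999, 511]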
@@ -508,11 +566,34 @@ def _easy_trace_prefill(
         page_table = _pad_or_create_page_table(page_table, max_blocks_prefill)
         chunk_page_table = _pad_or_create_page_table(chunk_page_table, chunk_blocks)

-        trace_key = f"{prefill_seq_len}_{batch_size}"
+        use_ring_sdpa = should_use_ring_distributed_sdpa(prefill_seq_len, batch_size, chunk_start_idx)
+        use_start_pos = "sp1" if (chunk_start_idx is not None and chunk_start_idx > 0) else "sp0"
+        trace_key = f"{prefill_seq_len}_{batch_size}_{'ring' if use_ring_sdpa else 'no_ring'}_{use_start_pos}"
+
+        def _table_preview(tensor):
+            if tensor is None:
+                return "None"
+            flat = tensor.flatten()
+            preview = flat[:8].tolist()
+            return f"shape={tuple(tensor.shape)}, first={preview}"
+
+        logger.info(
+            "[PREFILL_TRACE_DEBUG] key={} start_pos={} ring_sdpa={} page_table={} chunk_page_table={}",
+            trace_key,
+            chunk_start_idx,
+            use_ring_sdpa,
+            _table_preview(page_table),
+            _table_preview(chunk_page_table),
+        )

         # For prefix caching, the model output has only prefill_seq_len positions (the chunk).
         # get_last_token must be the relative index within the chunk (0..prefill_seq_len-1).
-        last_token_idx_for_trace = last_token_idx - num_cached_tokens
+        if is_last_token_list and batch_size > 1:
+            last_token_idx_for_trace = (
+                [idx - num_cached_tokens for idx in last_token_idx] if use_prefix_caching else list(last_token_idx)
+            )
+        else:
+            last_token_idx_for_trace = last_token_idx - num_cached_tokens

         if self.trace_id_prefill[trace_key] is None:
             trace_id, tt_out_trace, *device_inputs = self._capture_trace_prefill(
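A sketch of the new trace-key scheme: traces are now bucketed not just by padded length and batch size, but also by whether ring-distributed SDPA is used and whether the chunk starts at a non-zero position (the predicate itself lives in should_use_ring_distributed_sdpa and is not reproduced here):

def prefill_trace_key_sketch(prefill_seq_len, batch_size, chunk_start_idx, use_ring_sdpa):
    # use_ring_sdpa would come from should_use_ring_distributed_sdpa(...)
    sp = "sp1" if (chunk_start_idx is not None and chunk_start_idx > 0) else "sp0"
    ring = "ring" if use_ring_sdpa else "no_ring"
    return f"{prefill_seq_len}_{batch_size}_{ring}_{sp}"

# prefill_trace_key_sketch(1024, 1, None, False) -> "1024_1_no_ring_sp0"
# prefill_trace_key_sketch(4096, 1, 2048, True)  -> "4096_1_ring_sp1"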
@@ -528,6 +609,15 @@ def _easy_trace_prefill(
             self.trace_id_prefill[trace_key] = trace_id
             self.trace_inputs_prefill[trace_key] = device_inputs
             self.trace_output_prefill[trace_key] = tt_out_trace
+        # Debug: print input tokens checksum before trace
+        tokens_checksum = tokens.sum().item() if hasattr(tokens, "sum") else "N/A"
+        logger.info(
+            "[PREFILL_TRACE_INPUT] trace_key={} tokens_shape={} tokens_checksum={} user_id={}",
+            trace_key,
+            tuple(tokens.shape),
+            tokens_checksum,
+            user_id,
+        )
         tt_out_trace = self._prefill_forward_trace_text(
             self.trace_id_prefill[trace_key],
             self.trace_inputs_prefill[trace_key],
@@ -539,8 +629,14 @@ def _easy_trace_prefill(
             batch_size=batch_size,
             start_pos=chunk_start_idx,  # For position_ids generation
         )
+        logger.info("[PREFILL_TRACE_OUTPUT] trace completed for key={}", trace_key)
         # Compute last_token_idx_relative for output processing
-        last_token_idx_for_output = last_token_idx - num_cached_tokens if use_prefix_caching else last_token_idx
+        if is_last_token_list and batch_size > 1:
+            last_token_idx_for_output = (
+                [idx - num_cached_tokens for idx in last_token_idx] if use_prefix_caching else list(last_token_idx)
+            )
+        else:
+            last_token_idx_for_output = last_token_idx - num_cached_tokens if use_prefix_caching else last_token_idx
         toks = self.model.process_output_prefill(
             tt_out_trace,
             last_token_idx=last_token_idx_for_output,
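A sketch of the relative-index computation used for both the traced run and process_output_prefill: with prefix caching, the model only produced prefill_seq_len positions for the chunk, so the absolute prompt index is shifted back by the number of cached tokens, in both the scalar and batched-list cases (the helper name is illustrative):

def relative_last_token_idx_sketch(last_token_idx, num_cached_tokens, use_prefix_caching, batch_size):
    if isinstance(last_token_idx, (list, tuple)) and batch_size > 1:
        return [idx - num_cached_tokens for idx in last_token_idx] if use_prefix_caching else list(last_token_idx)
    return last_token_idx - num_cached_tokens if use_prefix_caching else last_token_idx

# Example: absolute prompt index 999 with 256 cached tokens -> index 743 within the prefilled chunk.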
@@ -564,6 +660,14 @@ def _capture_trace_prefill(
         Captures a trace for the prefill_forward method with prefix caching support.
         Uses full rot mats + chunk_start_idx device tensor; slice is inside the trace.
         """
+        logger.info(
+            "[PREFILL_CAPTURE_DEBUG] tokens_shape={} last_token_idx={} user_id={} start_pos={} batch_size={}",
+            tuple(tokens.shape),
+            last_token_idx,
+            user_id,
+            start_pos,
+            batch_size,
+        )
         # Get host tensors (tokens, user_id, page_table, chunk_page_table, chunk_start_idx)
         host_inputs = self.model.prepare_prefill_inputs_host(
             tokens,
@@ -705,6 +809,54 @@ def decode_forward_text(
         if getattr(self, "_disable_decode_tracing", False):
             enable_trace = False

+        # Decode input debug (outside traced region)
+        try:
+            tok0 = tokens[:8, :].reshape(-1)[:8].tolist() if isinstance(tokens, torch.Tensor) else "unavailable"
+        except Exception:
+            tok0 = "unavailable"
+        try:
+            pos0 = start_pos[:8].tolist() if isinstance(start_pos, torch.Tensor) else start_pos
+        except Exception:
+            pos0 = "unavailable"
+        logger.info(
+            "[DECODE_INPUT_DEBUG] enable_trace={} tokens_shape={} tokens_first8={} start_pos_first8={}",
+            enable_trace,
+            tuple(tokens.shape) if hasattr(tokens, "shape") else "unavailable",
+            tok0,
+            pos0,
+        )
+        if page_table is not None and isinstance(page_table, torch.Tensor):
+            try:
+                row0 = page_table[0, :64].to(torch.int32)
+                logger.info(
+                    "[DECODE_PAGETABLE_DEBUG] page_table_shape={} row0_first16={} row0_neg={} row0_zero={} row0_max={}",
+                    tuple(page_table.shape),
+                    row0[:16].tolist(),
+                    int((row0 < 0).sum().item()),
+                    int((row0 == 0).sum().item()),
+                    int(row0.max().item()) if row0.numel() > 0 else "unavailable",
+                )
+            except Exception as e:
+                logger.info("[DECODE_PAGETABLE_DEBUG] failed to summarize page_table: {}", e)
+        # Compare decode start_pos[0] vs last prefill metadata, if available
+        try:
+            if (
+                hasattr(self, "_debug_last_prefill_meta")
+                and 0 in self._debug_last_prefill_meta
+                and isinstance(start_pos, torch.Tensor)
+            ):
+                meta = self._debug_last_prefill_meta[0]
+                logger.info(
+                    "[DECODE_PREFILL_XCHECK] start_pos0={} prefill_last_token_idx={} prefill_seq_len={} prompt_seq_len={} num_cached_tokens={}",
+                    int(start_pos[0].item()),
+                    meta.get("last_token_idx"),
+                    meta.get("prefill_seq_len"),
+                    meta.get("seq_len"),
+                    meta.get("num_cached_tokens"),
+                )
+        except Exception:
+            pass
+
         if sampling_params is None:
             return_logits = True
             reset_inputs = True
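The [DECODE_PREFILL_XCHECK] log is meant to expose off-by-one mismatches between prefill and the first decode step. A minimal sketch of the invariant it is checking, under the assumption that the first decode position should equal the prompt length (last prompt index + 1); the helper name and the invariant itself are illustrative, not confirmed by the source:

def check_decode_start_pos_sketch(start_pos0: int, prefill_meta: dict) -> bool:
    # Assumed invariant: the first decode step follows the last prompt token.
    expected = prefill_meta["last_token_idx"] + 1  # == prompt seq_len
    return start_pos0 == expected

# With the earlier example (seq_len=1000), the first decode step is expected at start_pos 1000.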
@@ -1012,6 +1164,22 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len, user_i
         # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly
         block_size = get_block_size(kv_cache)
         num_blocks = num_blocks_in_seq(prefill_len, block_size)
+        # Debug: inspect incoming page_table row 0 (or first batch row) before slicing
+        try:
+            if isinstance(page_table, torch.Tensor):
+                row0 = page_table[0, : min(64, page_table.shape[1])].to(torch.int32)
+                logger.info(
+                    "[PREFILL_PAGETABLE_SRC_DEBUG] prefill_len={} block_size={} num_blocks={} src_shape={} row0_first16={} row0_neg={} row0_zero={}",
+                    int(prefill_len),
+                    int(block_size),
+                    int(num_blocks),
+                    tuple(page_table.shape),
+                    row0[:16].tolist(),
+                    int((row0 < 0).sum().item()),
+                    int((row0 == 0).sum().item()),
+                )
+        except Exception as e:
+            logger.info("[PREFILL_PAGETABLE_SRC_DEBUG] failed: {}", e)
         page_table = page_table[:, :num_blocks]
         if page_table.shape[1] < num_blocks:
             # If page table is too short, pad it with -1
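A sketch of the slice-or-pad step around these lines: the per-user page table is trimmed to exactly the blocks covered by prefill_len and, if too short, extended with -1 sentinel entries. num_blocks_in_seq is assumed to be a ceiling division of prefill_len by block_size, and the padding body below is illustrative since the hunk cuts off before it:

import torch

def trim_or_pad_page_table_sketch(page_table: torch.Tensor, prefill_len: int, block_size: int) -> torch.Tensor:
    num_blocks = (prefill_len + block_size - 1) // block_size  # assumed ceil division
    page_table = page_table[:, :num_blocks]
    if page_table.shape[1] < num_blocks:
        # Pad missing block slots with the -1 sentinel used by the paged fill.
        pad = torch.full((page_table.shape[0], num_blocks - page_table.shape[1]), -1, dtype=page_table.dtype)
        page_table = torch.cat([page_table, pad], dim=-1)
    return page_table

# Example: prefill_len=1000, block_size=64 -> 16 blocks kept (or padded up to 16 with -1).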
@@ -1024,6 +1192,20 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len, user_i
                 padded_page_table[user, :] = page_table[i, :]
         else:
             padded_page_table[user_id, :] = page_table[0, :]
+        # Debug: inspect resulting padded page_table for selected user
+        try:
+            if isinstance(padded_page_table, torch.Tensor) and (not use_batched_prefill) and isinstance(user_id, int):
+                row = padded_page_table[user_id, : min(64, padded_page_table.shape[1])].to(torch.int32)
+                logger.info(
+                    "[PREFILL_PAGETABLE_DST_DEBUG] user_id={} dst_shape={} row_first16={} row_neg={} row_zero={}",
+                    user_id,
+                    tuple(padded_page_table.shape),
+                    row[:16].tolist(),
+                    int((row < 0).sum().item()),
+                    int((row == 0).sum().item()),
+                )
+        except Exception:
+            pass
         return padded_page_table

     def warmup_model_prefill(self, kv_cache, enable_trace, sampling_params) -> None: