
Commit 5e92be5

Debugging
1 parent 6632dbd commit 5e92be5

5 files changed, +319 -19 lines changed


models/demos/llama3_70b_galaxy/tt/generator.py

Lines changed: 188 additions & 6 deletions
@@ -24,6 +24,7 @@
 )
 from models.common.sampling.generator import format_sampling_params
 from models.tt_transformers.tt.generator import SamplingParams
+from models.demos.llama3_70b_galaxy.tt.llama_attention import should_use_ring_distributed_sdpa


 def get_padded_prefill_len(seq_len: int) -> int:
@@ -97,7 +98,7 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non
         # Split sampling: decode trace captures transformer only, sampling runs separately
         self.enable_split_sampling = True  # Decode trace returns logits, sampling is separate
         self.model.enable_internal_trace = self.enable_split_sampling  # NEVER trace sampling - causes buffer corruption
-        self._disable_prefill_tracing = False  # Whether to disable prefill traces
+        self._disable_prefill_tracing = True  # Whether to disable prefill traces
         self._disable_decode_tracing = False  # Whether to disable decode traces
         self._trace_debug_seq = 0  # Monotonic seq for hang-debug logs (trace capture/replay ordering)

@@ -223,7 +224,9 @@ def prefill_forward_text(
         all_users = [0] if use_batched_prefill else empty_slots

         for id, user_id in enumerate(all_users):
-            logger.info(f"Prefilling User {user_id + 1}, use_batched_prefill: {use_batched_prefill}")
+            logger.info(
+                f"Prefilling User {user_id}, use_batched_prefill: {use_batched_prefill}, prompt_lens: {prompt_lens[id]}, prefill_seq_len: {prefill_seq_lens[id]}, num_cached_tokens: {num_cached_tokens_list[id]}"
+            )
             if use_batched_prefill:
                 user_id = empty_slots
                 last_token_idx = [(seq_len - 1) for seq_len in prompt_lens]
@@ -253,13 +256,36 @@ def prefill_forward_text(
             # Extract tokens skipping cached ones
             num_cached_tokens = num_cached_tokens_list[id]
             new_tokens_len = seq_len - num_cached_tokens
+            tail_start = max(0, seq_len - 8)
+            logger.info(
+                "[PREFILL_INPUT_DEBUG] user_id={} seq_len={} last_token_idx={} tail_tokens={}",
+                user_id,
+                seq_len,
+                last_token_idx,
+                tokens[id, tail_start:seq_len].tolist(),
+            )
             prefill_ids = torch.cat(
                 [
                     tokens[id : id + 1, num_cached_tokens:seq_len],  # Skip cached tokens
                     torch.zeros(1, prefill_seq_len - new_tokens_len).long(),
                 ],
                 dim=-1,
             )
+            # Extra debug: sanity-check padding and checksum for non-traced debugging
+            try:
+                pad_len = int(prefill_seq_len - new_tokens_len)
+                checksum = int(prefill_ids.sum().item())
+            except Exception:
+                pad_len = "unavailable"
+                checksum = "unavailable"
+            logger.info(
+                "[PREFILL_IDS_DEBUG] user_id={} new_tokens_len={} prefill_seq_len={} pad_len={} checksum={}",
+                user_id,
+                new_tokens_len,
+                prefill_seq_len,
+                pad_len,
+                checksum,
+            )

             if page_table is not None:
                 # For prefix caching, page_table includes both cached and new blocks
@@ -270,6 +296,23 @@ def prefill_forward_text(
                     user_id,
                     use_batched_prefill,  # Use full seq_len including cached
                 )
+                # Debug: summarize the active user's row (first few entries) for the prefill page table
+                try:
+                    row = page_table_user[user_id if not use_batched_prefill else 0, :].to(torch.int32)
+                    first = row[:16].tolist()
+                    neg = int((row < 0).sum().item())
+                    zero = int((row == 0).sum().item())
+                    logger.info(
+                        "[PREFILL_PAGETABLE_DEBUG] user_id={} prefill_len={} blocks={} first16={} neg={} zero={}",
+                        user_id if not use_batched_prefill else "batched",
+                        num_cached_tokens + prefill_seq_len,
+                        int(page_table_user.shape[1]),
+                        first,
+                        neg,
+                        zero,
+                    )
+                except Exception as e:
+                    logger.info("[PREFILL_PAGETABLE_DEBUG] failed to summarize page_table_user: {}", e)
                 # remove the first user from the page table
                 page_table = page_table[1:, :]

@@ -283,6 +326,20 @@ def prefill_forward_text(
                 "num_cached_tokens": num_cached_tokens_list[id] if not use_batched_prefill else 0,
             }

+            # Remember prompt/prefill metadata for later decode debugging (non-traced focus)
+            try:
+                if not hasattr(self, "_debug_last_prefill_meta"):
+                    self._debug_last_prefill_meta = {}
+                if not use_batched_prefill and isinstance(user_id, int):
+                    self._debug_last_prefill_meta[user_id] = {
+                        "seq_len": int(seq_len),
+                        "last_token_idx": int(last_token_idx),
+                        "prefill_seq_len": int(prefill_seq_len),
+                        "num_cached_tokens": int(num_cached_tokens),
+                    }
+            except Exception:
+                pass
+
             # If PCC check enabled or return_logits is True (we save output logits)
             if tt_out_logits_all_users is not None or return_logits:
                 tt_out_logits_saved = torch.zeros(1, self.model.args.padded_vocab_size)
@@ -477,7 +534,8 @@ def _easy_trace_prefill(
         Trace key is (prefill_seq_len, batch_size) only; page_table/chunk_page_table
         are padded to fixed shapes so one trace serves any num_cached_blocks.
         """
-        if isinstance(last_token_idx, (list, tuple)):
+        is_last_token_list = isinstance(last_token_idx, (list, tuple))
+        if is_last_token_list and batch_size == 1:
             last_token_idx = last_token_idx[user_id] if isinstance(user_id, int) else last_token_idx[0]

         # Extract single user's page table row for batch_size=1
@@ -508,11 +566,34 @@ def _easy_trace_prefill(
         page_table = _pad_or_create_page_table(page_table, max_blocks_prefill)
         chunk_page_table = _pad_or_create_page_table(chunk_page_table, chunk_blocks)

-        trace_key = f"{prefill_seq_len}_{batch_size}"
+        use_ring_sdpa = should_use_ring_distributed_sdpa(prefill_seq_len, batch_size, chunk_start_idx)
+        use_start_pos = "sp1" if (chunk_start_idx is not None and chunk_start_idx > 0) else "sp0"
+        trace_key = f"{prefill_seq_len}_{batch_size}_{'ring' if use_ring_sdpa else 'no_ring'}_{use_start_pos}"
+
+        def _table_preview(tensor):
+            if tensor is None:
+                return "None"
+            flat = tensor.flatten()
+            preview = flat[:8].tolist()
+            return f"shape={tuple(tensor.shape)}, first={preview}"
+
+        logger.info(
+            "[PREFILL_TRACE_DEBUG] key={} start_pos={} ring_sdpa={} page_table={} chunk_page_table={}",
+            trace_key,
+            chunk_start_idx,
+            use_ring_sdpa,
+            _table_preview(page_table),
+            _table_preview(chunk_page_table),
+        )

         # For prefix caching, the model output has only prefill_seq_len positions (the chunk).
         # get_last_token must be the relative index within the chunk (0..prefill_seq_len-1).
-        last_token_idx_for_trace = last_token_idx - num_cached_tokens
+        if is_last_token_list and batch_size > 1:
+            last_token_idx_for_trace = (
+                [idx - num_cached_tokens for idx in last_token_idx] if use_prefix_caching else list(last_token_idx)
+            )
+        else:
+            last_token_idx_for_trace = last_token_idx - num_cached_tokens

         if self.trace_id_prefill[trace_key] is None:
             trace_id, tt_out_trace, *device_inputs = self._capture_trace_prefill(
@@ -528,6 +609,15 @@ def _easy_trace_prefill(
             self.trace_id_prefill[trace_key] = trace_id
             self.trace_inputs_prefill[trace_key] = device_inputs
             self.trace_output_prefill[trace_key] = tt_out_trace
+        # Debug: print input tokens checksum before trace
+        tokens_checksum = tokens.sum().item() if hasattr(tokens, "sum") else "N/A"
+        logger.info(
+            "[PREFILL_TRACE_INPUT] trace_key={} tokens_shape={} tokens_checksum={} user_id={}",
+            trace_key,
+            tuple(tokens.shape),
+            tokens_checksum,
+            user_id,
+        )
         tt_out_trace = self._prefill_forward_trace_text(
             self.trace_id_prefill[trace_key],
             self.trace_inputs_prefill[trace_key],
@@ -539,8 +629,14 @@ def _easy_trace_prefill(
             batch_size=batch_size,
             start_pos=chunk_start_idx,  # For position_ids generation
         )
+        logger.info("[PREFILL_TRACE_OUTPUT] trace completed for key={}", trace_key)
         # Compute last_token_idx_relative for output processing
-        last_token_idx_for_output = last_token_idx - num_cached_tokens if use_prefix_caching else last_token_idx
+        if is_last_token_list and batch_size > 1:
+            last_token_idx_for_output = (
+                [idx - num_cached_tokens for idx in last_token_idx] if use_prefix_caching else list(last_token_idx)
+            )
+        else:
+            last_token_idx_for_output = last_token_idx - num_cached_tokens if use_prefix_caching else last_token_idx
         toks = self.model.process_output_prefill(
             tt_out_trace,
             last_token_idx=last_token_idx_for_output,
@@ -564,6 +660,14 @@ def _capture_trace_prefill(
         Captures a trace for the prefill_forward method with prefix caching support.
         Uses full rot mats + chunk_start_idx device tensor; slice is inside the trace.
         """
+        logger.info(
+            "[PREFILL_CAPTURE_DEBUG] tokens_shape={} last_token_idx={} user_id={} start_pos={} batch_size={}",
+            tuple(tokens.shape),
+            last_token_idx,
+            user_id,
+            start_pos,
+            batch_size,
+        )
         # Get host tensors (tokens, user_id, page_table, chunk_page_table, chunk_start_idx)
         host_inputs = self.model.prepare_prefill_inputs_host(
             tokens,
@@ -705,6 +809,54 @@ def decode_forward_text(
         if getattr(self, "_disable_decode_tracing", False):
             enable_trace = False

+        # Decode input debug (outside traced region)
+        try:
+            tok0 = tokens[:8, :].reshape(-1)[:8].tolist() if isinstance(tokens, torch.Tensor) else "unavailable"
+        except Exception:
+            tok0 = "unavailable"
+        try:
+            pos0 = start_pos[:8].tolist() if isinstance(start_pos, torch.Tensor) else start_pos
+        except Exception:
+            pos0 = "unavailable"
+        logger.info(
+            "[DECODE_INPUT_DEBUG] enable_trace={} tokens_shape={} tokens_first8={} start_pos_first8={}",
+            enable_trace,
+            tuple(tokens.shape) if hasattr(tokens, "shape") else "unavailable",
+            tok0,
+            pos0,
+        )
+        if page_table is not None and isinstance(page_table, torch.Tensor):
+            try:
+                row0 = page_table[0, :64].to(torch.int32)
+                logger.info(
+                    "[DECODE_PAGETABLE_DEBUG] page_table_shape={} row0_first16={} row0_neg={} row0_zero={} row0_max={}",
+                    tuple(page_table.shape),
+                    row0[:16].tolist(),
+                    int((row0 < 0).sum().item()),
+                    int((row0 == 0).sum().item()),
+                    int(row0.max().item()) if row0.numel() > 0 else "unavailable",
+                )
+            except Exception as e:
+                logger.info("[DECODE_PAGETABLE_DEBUG] failed to summarize page_table: {}", e)
+        # Compare decode start_pos[0] vs last prefill metadata, if available
+        try:
+            if (
+                hasattr(self, "_debug_last_prefill_meta")
+                and 0 in self._debug_last_prefill_meta
+                and isinstance(start_pos, torch.Tensor)
+            ):
+                meta = self._debug_last_prefill_meta[0]
+                logger.info(
+                    "[DECODE_PREFILL_XCHECK] start_pos0={} prefill_last_token_idx={} prefill_seq_len={} prompt_seq_len={} num_cached_tokens={}",
+                    int(start_pos[0].item()),
+                    meta.get("last_token_idx"),
+                    meta.get("prefill_seq_len"),
+                    meta.get("seq_len"),
+                    meta.get("num_cached_tokens"),
+                )
+        except Exception:
+            pass
+
         if sampling_params is None:
             return_logits = True
             reset_inputs = True
@@ -1012,6 +1164,22 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len, user_i
         # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly
         block_size = get_block_size(kv_cache)
         num_blocks = num_blocks_in_seq(prefill_len, block_size)
+        # Debug: inspect incoming page_table row 0 (or first batch row) before slicing
+        try:
+            if isinstance(page_table, torch.Tensor):
+                row0 = page_table[0, : min(64, page_table.shape[1])].to(torch.int32)
+                logger.info(
+                    "[PREFILL_PAGETABLE_SRC_DEBUG] prefill_len={} block_size={} num_blocks={} src_shape={} row0_first16={} row0_neg={} row0_zero={}",
+                    int(prefill_len),
+                    int(block_size),
+                    int(num_blocks),
+                    tuple(page_table.shape),
+                    row0[:16].tolist(),
+                    int((row0 < 0).sum().item()),
+                    int((row0 == 0).sum().item()),
+                )
+        except Exception as e:
+            logger.info("[PREFILL_PAGETABLE_SRC_DEBUG] failed: {}", e)
         page_table = page_table[:, :num_blocks]
         if page_table.shape[1] < num_blocks:
             # If page table is too short, pad it with -1
@@ -1024,6 +1192,20 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len, user_i
                 padded_page_table[user, :] = page_table[i, :]
         else:
             padded_page_table[user_id, :] = page_table[0, :]
+        # Debug: inspect resulting padded page_table for selected user
+        try:
+            if isinstance(padded_page_table, torch.Tensor) and (not use_batched_prefill) and isinstance(user_id, int):
+                row = padded_page_table[user_id, : min(64, padded_page_table.shape[1])].to(torch.int32)
+                logger.info(
+                    "[PREFILL_PAGETABLE_DST_DEBUG] user_id={} dst_shape={} row_first16={} row_neg={} row_zero={}",
+                    user_id,
+                    tuple(padded_page_table.shape),
+                    row[:16].tolist(),
+                    int((row < 0).sum().item()),
+                    int((row == 0).sum().item()),
+                )
+        except Exception:
+            pass
         return padded_page_table

     def warmup_model_prefill(self, kv_cache, enable_trace, sampling_params) -> None:
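
The logging added in this file follows one pattern: summarize tensors (shape, first few values, a simple checksum) instead of dumping them, keep every call outside the traced region, and wrap anything that touches a tensor in try/except so a malformed input can never turn a debug log into a crash. Below is a minimal standalone sketch of that pattern, assuming loguru and torch; the summarize_tensor helper is illustrative and not part of the repository.

    # Minimal sketch of the guarded tensor-summary logging used in this commit.
    # `summarize_tensor` is an illustrative helper, not an API from this repository.
    import torch
    from loguru import logger


    def summarize_tensor(t, n=8):
        """Return a small, log-friendly summary of a tensor; never raise."""
        try:
            if not isinstance(t, torch.Tensor):
                return "unavailable"
            flat = t.reshape(-1)
            return {
                "shape": tuple(t.shape),
                "first": flat[:n].tolist(),
                "checksum": int(flat.sum().item()),
            }
        except Exception:
            return "unavailable"


    tokens = torch.randint(0, 32000, (1, 128))
    # The summary is plain Python data, so formatting the log line cannot itself fail;
    # loguru substitutes the argument str.format-style at emit time.
    logger.info("[PREFILL_INPUT_DEBUG] tokens={}", summarize_tensor(tokens))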

models/demos/llama3_70b_galaxy/tt/llama_attention.py

Lines changed: 20 additions & 4 deletions
@@ -4,10 +4,15 @@

 import torch
 import ttnn
+from loguru import logger
 from models.common.lightweightmodule import LightweightModule
 from models.common.rmsnorm import RMSNorm


+def should_use_ring_distributed_sdpa(seq_len: int, batch_size: int, chunk_start_idx) -> bool:
+    return seq_len > 1024 and batch_size == 1 and (chunk_start_idx is None or chunk_start_idx == 0)
+
+
 class TtLlamaAttention(LightweightModule):
     def __init__(
         self,
@@ -734,8 +739,10 @@ def forward_prefill(

         user_id_for_mask = None  # Will be set if page_table is provided
         if page_table:
-            # If chunked prefill, use chunk_page_table if given, otherwise use page_table.
-            fill_page_table = chunk_page_table if chunk_page_table is not None else page_table
+            # Use chunk_page_table only for prefix-cached prefill (chunk_start_idx > 0).
+            # For non-prefix prefill, ignore chunk_page_table (trace may pass a dummy) and use page_table.
+            use_chunk_for_fill = chunk_start_idx is not None and chunk_start_idx > 0
+            fill_page_table = chunk_page_table if (use_chunk_for_fill and chunk_page_table is not None) else page_table

             # Each shard gets one row, which is locally at index 0
             ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill, fill_page_table, batch_idx=0)
@@ -759,7 +766,16 @@ def forward_prefill(

         # Run ring_distributed_sdpa for > 1k seqlen because we are seeing worse perf for <=1k seqlen as compared to regular SDPA
         # ring_distributed_sdpa needs seqlen//8 to be atleast one tile (32)
-        ring_distributed_sdpa = seq_len > 1024 and batch_size == 1 and (chunk_start_idx is None or chunk_start_idx == 0)
+        ring_distributed_sdpa = should_use_ring_distributed_sdpa(seq_len, batch_size, chunk_start_idx)
+        use_chunked_sdpa = chunk_start_idx is not None and chunk_start_idx > 0
+        logger.info(
+            "[PREFILL_SDPA_DEBUG] seq_len={} batch_size={} chunk_start_idx={} ring_sdpa={} chunked_sdpa={}",
+            seq_len,
+            batch_size,
+            chunk_start_idx,
+            ring_distributed_sdpa,
+            use_chunked_sdpa,
+        )

         if ring_distributed_sdpa:
             k_tensor = k_heads_1KSD_8b
@@ -781,7 +797,7 @@ def forward_prefill(
         else:
             # When using prefix caching (chunk_start_idx provided), use chunked SDPA with KV cache tensors.
             # Flexible path: chunk_start_idx_tensor so one trace works for any chunk_start at replay.
-            if chunk_start_idx is not None and chunk_start_idx > 0:
+            if use_chunked_sdpa:
                 assert page_table is not None, "page_table must be provided for prefix caching"
                 assert (
                     chunk_start_idx_tensor is not None
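
Factoring the condition into should_use_ring_distributed_sdpa lets the generator's prefill trace key and this module's kernel selection share a single predicate, so a trace captured with one SDPA variant is never replayed for a shape that needs the other. A standalone sketch of how the two pieces line up; make_prefill_trace_key is an illustrative stand-in for the logic in _easy_trace_prefill.

    # Sketch: one predicate drives both the SDPA kernel choice and the prefill trace key.
    # `make_prefill_trace_key` is illustrative; the predicate mirrors the helper added in this commit.


    def should_use_ring_distributed_sdpa(seq_len: int, batch_size: int, chunk_start_idx) -> bool:
        # Ring-distributed SDPA only pays off above 1k tokens, single user, no prefix-cache chunk offset.
        return seq_len > 1024 and batch_size == 1 and (chunk_start_idx is None or chunk_start_idx == 0)


    def make_prefill_trace_key(prefill_seq_len: int, batch_size: int, chunk_start_idx) -> str:
        use_ring_sdpa = should_use_ring_distributed_sdpa(prefill_seq_len, batch_size, chunk_start_idx)
        use_start_pos = "sp1" if (chunk_start_idx is not None and chunk_start_idx > 0) else "sp0"
        return f"{prefill_seq_len}_{batch_size}_{'ring' if use_ring_sdpa else 'no_ring'}_{use_start_pos}"


    # Same sequence length, different chunk offsets -> different traces and different SDPA paths.
    print(make_prefill_trace_key(2048, 1, None))   # 2048_1_ring_sp0
    print(make_prefill_trace_key(2048, 1, 1024))   # 2048_1_no_ring_sp1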

models/demos/llama3_70b_galaxy/tt/llama_ccl.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def __init__(
         self.rs_create_heads_buffers = self.get_decode_rs_create_heads_buffers()
         if mode == "prefill":
             # For some prefill seqlens we always allocate CCL buffers. Otherwise they will require barrier syncing
-            self.support_seqlens = [4096, 2048, 1024, 128]
+            self.support_seqlens = [1024, 128]
             if allocate_prefill_buffers:
                 self.persistent_buffers = (
                     self.get_ring_prefill_reduce_scatter_buffers()
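
The shortened support_seqlens list means only 1024- and 128-token prefills keep persistent CCL buffers; other lengths fall back to on-demand allocation, which per the comment requires barrier syncing. A rough, purely hypothetical sketch of that lookup pattern; PrefillBufferPool and its arguments are illustrative and not ttnn APIs.

    # Hypothetical sketch of persistent-vs-on-demand buffer selection keyed by prefill seqlen.
    # Names (PrefillBufferPool, allocate_buffer, barrier) are illustrative, not ttnn APIs.


    class PrefillBufferPool:
        def __init__(self, allocate_buffer, support_seqlens=(1024, 128)):
            self._allocate = allocate_buffer
            # Buffers for supported seqlens are created once, up front.
            self.persistent = {s: allocate_buffer(s) for s in support_seqlens}

        def get(self, seq_len, barrier):
            buf = self.persistent.get(seq_len)
            if buf is not None:
                return buf  # reuse without any synchronization
            # Unsupported seqlen: allocate on demand and synchronize before use.
            barrier()
            return self._allocate(seq_len)


    pool = PrefillBufferPool(allocate_buffer=lambda s: bytearray(s), support_seqlens=(1024, 128))
    buf = pool.get(1024, barrier=lambda: None)  # hits the persistent buffer, no barrier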
