Commit 387359e

Ubuntu authored and committed
Enable get_last_token for batch>1 prefill, single-call paged_fill_cache
- model.py: When batch_size>1, extract each user's 32-token last-token tile via ttnn.slice and concat to [1,1,B*32,H] before norm+lm_head. Removes the batch>1 get_last_token=-1 override in ttnn_prefill_forward.
- attention/prefill.py: Replace per-user paged_fill_cache loop with single-call reshape approach (flatten batch into seq dim, heads into last dim, flatten page table). Matches llama_70b_galaxy pattern.
- text_demo.py: Remove batch>1 get_last_token override, use seq_len_per_user=32 when get_last_token is active.

batch128 users_per_row_per_iter=2 TTFT: 218ms -> 99ms (get_last_token fix)
Compile time: 5.24s -> 3.16s (single-call paged_fill_cache)
batch128 users_per_row_per_iter=1: unchanged at 91ms
1 parent 2c88901 commit 387359e
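
For context on the 32-token tile constraint the message refers to: the captured prefill trace can only bake in a fixed get_last_token offset if every user's last prompt token lands in the same 32-token tile; otherwise the demo falls back to get_last_token=-1 and keeps the full sequence output. A minimal sketch of that rule (plain Python; the helper name and example values are illustrative, not code from this commit):

# Hypothetical helper illustrating the tile-alignment rule; not a function in the repo.
def pick_get_last_token(last_token_idxs, tile=32):
    min_tile = (min(last_token_idxs) // tile) * tile
    max_tile = (max(last_token_idxs) // tile) * tile
    # All users in one tile -> slice only that tile; otherwise return -1 (full output).
    return min_tile if min_tile == max_tile else -1

assert pick_get_last_token([37, 45]) == 32   # both last tokens fall in tile [32, 64)
assert pick_get_last_token([30, 45]) == -1   # spans tiles [0, 32) and [32, 64)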

3 files changed: +45, -27 lines

models/demos/gpt_oss/demo/text_demo.py

Lines changed: 17 additions & 15 deletions
@@ -599,17 +599,14 @@ def test_gpt_oss_demo(
 
     # Compute fixed get_last_token for trace (all users must be in same 32-token tile)
     all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)]
-    if users_per_row_per_iter > 1:
-        fixed_get_last_token = -1  # Can't use get_last_token with batch>1
-    else:
-        fixed_get_last_token = (min(all_last_idxs) // 32) * 32
-        max_tile_start = (max(all_last_idxs) // 32) * 32
-        if fixed_get_last_token != max_tile_start:
-            logger.warning(
-                f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), "
-                f"using get_last_token=-1 (slower)"
-            )
-            fixed_get_last_token = -1
+    fixed_get_last_token = (min(all_last_idxs) // 32) * 32
+    max_tile_start = (max(all_last_idxs) // 32) * 32
+    if fixed_get_last_token != max_tile_start:
+        logger.warning(
+            f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), "
+            f"using get_last_token=-1 (slower)"
+        )
+        fixed_get_last_token = -1
 
     def _prepare_batch_host(user_indices):
         """Prepare host-side tokens + page_table for a batch of users."""
@@ -669,7 +666,7 @@ def _prepare_batch_host(user_indices):
             tt_logits,
             [idx % 32 for idx in last_w],
             users_per_row=users_per_row_per_iter,
-            seq_len_per_user=max_padded_len,
+            seq_len_per_user=32,
         )
         for row, uid in enumerate(warmup_indices):
             prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
@@ -753,7 +750,7 @@ def _prepare_batch_host(user_indices):
             tt_out_trace,
             [idx % 32 for idx in last_i],
             users_per_row=users_per_row_per_iter,
-            seq_len_per_user=max_padded_len,
+            seq_len_per_user=32,
         )
         for row, uid in enumerate(user_indices):
             prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
@@ -804,7 +801,10 @@ def _run_batched_prefill_iter(iter_idx, user_indices):
             batched_prefill=True,
         )
 
-        get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 if users_per_row_per_iter == 1 else -1
+        # Use get_last_token if all users' last tokens fall in the same 32-token tile
+        min_tile = (min(batch_last_token_idxs) // 32) * 32
+        max_tile = (max(batch_last_token_idxs) // 32) * 32
+        get_last_token_val = min_tile if min_tile == max_tile else -1
         tt_logits = model[model_id].ttnn_prefill_forward(
             tokens_embd,
             rot_mats_global=rot_mats_global,
@@ -818,13 +818,15 @@ def _run_batched_prefill_iter(iter_idx, user_indices):
 
         if get_last_token_val == -1:
             adjusted_last_idxs = batch_last_token_idxs
+            seq_len_for_output = padded_len
         else:
             adjusted_last_idxs = [idx % 32 for idx in batch_last_token_idxs]
+            seq_len_for_output = 32
         row_results = model[model_id].process_output_prefill_batched(
             tt_logits,
             adjusted_last_idxs,
             users_per_row=users_per_row_per_iter,
-            seq_len_per_user=padded_len,
+            seq_len_per_user=seq_len_for_output,
         )
         return row_results
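
Why the demo now passes idx % 32 together with seq_len_per_user=32: once get_last_token is active, each user's portion of the prefill output is just its 32-token tile, so the wanted logit sits at row*32 + (last_idx % 32). A torch-only sketch of that indexing (the host-tensor layout and the loop are assumptions about what process_output_prefill_batched does with these arguments, not its actual code):

import torch

# Assumed layout after the per-user tile extraction: users stacked along the
# sequence dim as 32-token tiles, i.e. [1, 1, B*32, vocab] on the host.
batch, vocab = 2, 8
flat_logits = torch.randn(1, 1, batch * 32, vocab)
last_token_idxs = [37, 45]                    # absolute positions, same tile [32, 64)

for row, last_idx in enumerate(last_token_idxs):
    offset = last_idx % 32                    # position inside the user's 32-token tile
    user_logits = flat_logits[0, 0, row * 32 + offset]
    next_token = torch.argmax(user_logits).item()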

models/demos/gpt_oss/tt/attention/prefill.py

Lines changed: 9 additions & 7 deletions
@@ -99,13 +99,15 @@ def prefill_forward(
     block_size = k_cache.shape[2]
     page_len = page_table.shape[-1] * block_size
     if batch_size > 1:
-        for b in range(batch_size):
-            k_b = ttnn.slice(tt_k, (b, 0, 0, 0), (b + 1, tt_k.shape[1], min(page_len, seq_len), tt_k.shape[3]))
-            v_b = ttnn.slice(tt_v, (b, 0, 0, 0), (b + 1, tt_v.shape[1], min(page_len, seq_len), tt_v.shape[3]))
-            ttnn.experimental.paged_fill_cache(k_cache, k_b, page_table, batch_idx=b)
-            ttnn.experimental.paged_fill_cache(v_cache, v_b, page_table, batch_idx=b)
-            k_b.deallocate(True)
-            v_b.deallocate(True)
+        # Flatten batch into seq dim, heads into last dim — single fill call, no per-user loop.
+        # Paged cache just maps sequence positions to physical pages.
+        k_fill = ttnn.reshape(tt_k, [1, 1, total_seq_len, -1])
+        v_fill = ttnn.reshape(tt_v, [1, 1, total_seq_len, -1])
+        page_table_flat = ttnn.reshape(page_table, [1, -1])
+        ttnn.experimental.paged_fill_cache(k_cache, k_fill, page_table_flat, batch_idx=0)
+        ttnn.experimental.paged_fill_cache(v_cache, v_fill, page_table_flat, batch_idx=0)
+        k_fill.deallocate(True)
+        v_fill.deallocate(True)
     else:
         tt_k_sliced = tt_k[:, :, :page_len, :] if page_len < tt_k.shape[2] else tt_k
         tt_v_sliced = tt_v[:, :, :page_len, :] if page_len < tt_v.shape[2] else tt_v
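
The single-call fill works because a paged KV cache only needs the mapping from flattened sequence positions to physical pages; concatenating the per-user page tables into one row keeps user b's positions pointing at user b's pages. A toy torch emulation of that idea (the fill loop mimics assumed paged_fill_cache semantics with made-up shapes; it is not the ttnn op):

import torch

block_size, blocks_per_user, batch = 32, 4, 2
seq_per_user = block_size * blocks_per_user           # 128 tokens per user
hidden = 8

page_table = torch.randperm(batch * blocks_per_user).reshape(batch, blocks_per_user)
kv = torch.randn(1, 1, batch * seq_per_user, hidden)  # users concatenated along the seq dim
cache = torch.zeros(batch * blocks_per_user, block_size, hidden)

page_table_flat = page_table.reshape(1, -1)
for p in range(kv.shape[2]):                          # one pass over the flattened sequence
    page = page_table_flat[0, p // block_size]        # physical page for this position
    cache[page, p % block_size] = kv[0, 0, p]

# User 1's token 5 lands in user 1's first page, exactly where a per-user fill would put it.
assert torch.equal(cache[page_table[1, 0], 5], kv[0, 0, seq_per_user + 5])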

models/demos/gpt_oss/tt/model.py

Lines changed: 19 additions & 5 deletions
@@ -278,12 +278,26 @@ def _forward_layers_and_head(
     logits = hidden_states
 
     if get_last_token != -1:
-        # The logits come from the shared method, slice them
         if len(logits.shape) == 3:
             logits = ttnn.unsqueeze(logits, dim=1)
-        logits_sliced = ttnn.slice(logits, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, logits.shape[-1]))
-        logits.deallocate(True)
-        logits = logits_sliced
+        if batch_size > 1:
+            # Batch>1: tokens are concatenated [1,1,B*S,H]. Extract each user's 32-token tile.
+            per_user_seq = logits.shape[2] // batch_size
+            tiles = []
+            for b in range(batch_size):
+                start = b * per_user_seq + get_last_token
+                tile = ttnn.slice(logits, (0, 0, start, 0), (1, 1, start + 32, logits.shape[-1]))
+                tiles.append(tile)
+            logits.deallocate(True)
+            logits = ttnn.concat(tiles, dim=2)  # [1, 1, B*32, H]
+            for t in tiles:
+                t.deallocate(True)
+        else:
+            logits_sliced = ttnn.slice(
+                logits, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, logits.shape[-1])
+            )
+            logits.deallocate(True)
+            logits = logits_sliced
     hidden_states = logits
 
     # Final norm and lm_head
@@ -366,7 +380,7 @@ def ttnn_prefill_forward(
         current_pos=None,  # No current_pos for prefill
         page_table=page_table,
         kv_cache=kv_cache,
-        get_last_token=get_last_token if batch_size == 1 else -1,  # Disable get_last_token for batch>1
+        get_last_token=get_last_token,
         is_decode=False,
         user_id=user_id,
         batch_size=batch_size,
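
A torch-only illustration of the batch>1 slice-and-concat above, assuming the prefill logits hold all users concatenated along the sequence dim as [1, 1, B*S, H]; the shapes and values below are made up for the example and this is not the ttnn code path itself:

import torch

batch_size, per_user_seq, hidden, get_last_token = 2, 128, 16, 32
logits = torch.randn(1, 1, batch_size * per_user_seq, hidden)   # [1, 1, B*S, H]

tiles = []
for b in range(batch_size):
    start = b * per_user_seq + get_last_token        # user b's 32-token tile starts here
    tiles.append(logits[:, :, start:start + 32, :])
out = torch.cat(tiles, dim=2)                        # [1, 1, B*32, H]

# Row b*32 + r of the result is user b's hidden state for position get_last_token + r,
# so norm + lm_head only run on B*32 rows instead of the full padded sequence.
assert torch.equal(out[0, 0, 1 * 32 + 3], logits[0, 0, 1 * per_user_seq + get_last_token + 3])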
