
[gpt-oss] batched prefill and prefill tracing #37848

Merged
sraizada-tt merged 8 commits into main from sraizada/gpt-oss-batched-prefill on Feb 16, 2026

Conversation

@sraizada-tt (Contributor) commented Feb 13, 2026

Ubuntu added 5 commits February 13, 2026 10:49
Process 4 users simultaneously (one per mesh row) during prefill,
reducing iterations from 128 to 32 for ~3.7x TTFT improvement
(393ms -> 106ms).

Changes:
- model.py: Add batched_prefill mode to prepare_inputs_prefill() that
  shards tokens/page_table across rows via ShardTensor2dMesh(dims=(0,None)).
  Add process_output_prefill_batched() to extract per-row logits.
- text_demo.py: Replace sequential 128-user prefill loop with 32-iteration
  row-parallel loop. Each iteration prefills 4 users (one per mesh row).

Key insight: attention allreduce (axis=1) keeps rows independent while
MoE all_to_all (axis=0) naturally handles cross-row token routing.
No changes needed to model internals.

Tested on commit 52b9c66 (stable), 4x8 Galaxy:
- Baseline TTFT: 393ms, Batched TTFT: 106ms (3.7x improvement)
- Decode unchanged: ~221ms @ 4.52 tok/s/user
- Output quality verified: identical correct outputs across all rows
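
A minimal torch-only sketch of the row-parallel split described above; the shapes and variable names here are illustrative, not the PR's exact code. The demo itself performs the split on device via ShardTensor2dMesh(dims=(0, None)) as the mesh mapper when moving tokens/page_table to the mesh, so each of the four mesh rows receives one user's slice:

    # Illustrative torch-only sketch; num_rows/padded_len/num_blocks are assumed values.
    import torch

    num_rows = 4                                 # mesh rows on a 4x8 Galaxy
    padded_len, num_blocks = 128, 4              # per-user prefill length and page-table blocks (illustrative)
    tokens_stacked = torch.zeros(num_rows, padded_len, dtype=torch.long)      # [4, S], one user per row
    page_table_stacked = torch.zeros(num_rows, num_blocks, dtype=torch.int32)

    # Sharding along dim 0 gives every mesh row an independent [1, S] token slice and its
    # own page-table rows; this mirrors what ShardTensor2dMesh(dims=(0, None)) does on device.
    token_shards = torch.chunk(tokens_stacked, num_rows, dim=0)
    page_shards = torch.chunk(page_table_stacked, num_rows, dim=0)
    assert all(t.shape == (1, padded_len) for t in token_shards)

    # Because the attention allreduce runs along axis=1 (within a row), no reduction ever mixes
    # users, while the MoE all_to_all along axis=0 still routes tokens across rows.
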
Add users_per_row_per_iter parameter (default 1) to control how many
users each mesh row processes per prefill iteration. With batch>1:
- attention/prefill.py: flatten batch*seq for QKV matmul, reshape for
  SDPA, loop paged_fill_cache per user, slice RoPE to per-user seq_len
- model.py: pass batch_size through forward chain, process_output
  extracts N results per row with correct seq offset
- text_demo.py: concatenate per-row tokens along seq dim, generalize
  user_indices and loop counts for arbitrary users_per_row_per_iter

Tested: batch128 passes at both users_per_row_per_iter=1 (TTFT 91ms,
identical to baseline) and users_per_row_per_iter=2 (TTFT 219ms,
correct outputs, slower due to disabled get_last_token optimization).
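
A minimal torch-only sketch of the flatten/reshape pattern this commit describes for users_per_row_per_iter > 1; the tensor names, head counts, and QKV packing order are assumptions for illustration, not the attention module's actual code:

    import torch
    import torch.nn.functional as F

    B, S = 2, 128                      # users per row per iteration, padded per-user seq len
    n_heads, head_dim = 8, 64          # illustrative head configuration
    hidden = n_heads * head_dim

    x = torch.randn(B, S, hidden)      # activations for B users on one mesh row
    wqkv = torch.randn(hidden, 3 * hidden)

    # 1) Flatten batch*seq so the QKV matmul sees a single [1, 1, B*S, H] activation.
    qkv = x.reshape(1, 1, B * S, hidden) @ wqkv                        # [1, 1, B*S, 3H]

    # 2) Un-flatten per user (and per head) before SDPA so users never attend to each other.
    q, k, v = qkv.reshape(B, S, 3, n_heads, head_dim).permute(2, 0, 3, 1, 4)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)      # [B, n_heads, S, head_dim]
    assert out.shape == (B, n_heads, S, head_dim)
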
…/tenstorrent/tt-metal into sraizada/gpt-oss-batched-prefill

# Conflicts:
#	models/demos/gpt_oss/demo/text_demo.py
#	models/demos/gpt_oss/tt/model.py
Copilot AI review requested due to automatic review settings February 13, 2026 14:51
@sraizada-tt sraizada-tt requested review from a team, handrewsTT and mtairum as code owners February 13, 2026 14:51
Copilot AI left a comment

Pull request overview

This PR extends the GPT-OSS TTNN demo/model path to support row-parallel batched prefill (one user per mesh row per iteration) and adds a prefill trace flow to reduce host dispatch overhead, enabling higher throughput for multi-user prefill on multi-row meshes.

Changes:

  • Thread a new batch_size parameter through model → decoder layer → attention, and update prefill attention to handle batched/flattened sequences.
  • Add row-sharded token preparation for batched prefill and a helper to extract per-row logits from multi-device outputs.
  • Rework the GPT-OSS text demo to run row-parallel batched prefill, with an optional traced execution path.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.

Summary per file:

  • models/demos/gpt_oss/tt/model.py: Adds batched prefill input handling and per-row output extraction; threads batch_size through the forward path.
  • models/demos/gpt_oss/tt/layer.py: Passes batch_size into attention for prefill.
  • models/demos/gpt_oss/tt/attention/prefill.py: Updates prefill attention to reshape/handle batch_size > 1 and the KV cache fill loops.
  • models/demos/gpt_oss/tt/attention/__init__.py: Extends the Attention call signature to accept and forward batch_size.
  • models/demos/gpt_oss/demo/text_demo.py: Implements row-parallel batched prefill and adds a prefill tracing execution path.
Comments suppressed due to low confidence (2)

models/demos/gpt_oss/demo/text_demo.py:508

  • This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.

        k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)

models/demos/gpt_oss/demo/text_demo.py:509

  • This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.

        v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)

activation_dtype = ttnn.bfloat16
_, batch_size, seq_len, hidden_size = hidden_states.shape
total_seq_len = hidden_states.shape[-2]
hidden_size = hidden_states.shape[-1]

Copilot AI (Feb 13, 2026)

seq_len is derived via integer division (total_seq_len // batch_size) and then used to reshape tensors. If total_seq_len is not exactly divisible by batch_size, this will silently produce an incorrect per-user seq_len and likely break the subsequent reshape/attention behavior. Add an explicit check (and a clear error) that total_seq_len % batch_size == 0 before computing seq_len (and consider validating batch_size > 0).

Suggested change

    hidden_size = hidden_states.shape[-1]
    # Validate batch_size and per-user sequence length shape
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    if total_seq_len % batch_size != 0:
        raise ValueError(
            f"total_seq_len ({total_seq_len}) must be divisible by batch_size ({batch_size}) "
            "to compute a valid per-user seq_len."
        )

Comment on lines +648 to +652

    last_idx = last_token_idxs[user_flat_idx] if isinstance(last_token_idxs, list) else last_token_idxs
    if users_per_row > 1:
        # Tokens are concatenated: user u's last token is at offset u*seq_len_per_user + last_idx
        global_idx = u * seq_len_per_user + last_idx
        result = torch_output[..., global_idx, : self.vocab_size]

Copilot AI (Feb 13, 2026)

process_output_prefill_batched will crash for users_per_row > 1 if seq_len_per_user is not provided (it is used in global_idx = u * seq_len_per_user + last_idx). Since the docstring advertises multi-user-per-row support, add an explicit validation (e.g., raise ValueError) when users_per_row > 1 and seq_len_per_user is None (and consider validating the length of last_token_idxs when it’s a list).

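A minimal sketch of the guard this comment asks for, assuming the function exposes users_per_row, seq_len_per_user, and last_token_idxs as described in the comment (names are taken from the comment, not verified against the final code):

    def _check_batched_prefill_output_args(users_per_row, seq_len_per_user, last_token_idxs):
        # seq_len_per_user is needed to compute u * seq_len_per_user + last_idx per user.
        if users_per_row > 1 and seq_len_per_user is None:
            raise ValueError("seq_len_per_user must be provided when users_per_row > 1.")
        # When per-user indices are passed as a list, there should be one entry per user.
        if isinstance(last_token_idxs, list) and len(last_token_idxs) < users_per_row:
            raise ValueError(
                f"expected at least {users_per_row} last_token_idxs, got {len(last_token_idxs)}."
            )
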
)
host_inputs = (host_out[0], host_out[3], host_out[4])
copy_host_to_device(host_inputs, device_tensors=trace_dev_inputs, mesh_device=mesh_device)
ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)

Copilot AI (Feb 13, 2026)

In the traced batched prefill loop you execute the trace with blocking=False and then immediately read/convert the output. Without an explicit sync/block on the output tensor, this risks reading incomplete results depending on TTNN scheduling. Add a synchronization point (e.g., block on output CPU copy or ttnn.synchronize_device(mesh_device)) before processing the output.

Suggested change

    ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
    # Ensure trace execution has completed before reading from tt_out_trace.
    ttnn.synchronize_device(mesh_device)

Comment on lines +773 to +796

    batch_tokens_list = []
    batch_page_tables = []
    batch_last_token_idxs = []

    for uid in user_indices:
        prefill_len = int(decoding_pos[uid])
        padded_len = get_padded_prefill_len(prefill_len)
        user_tokens = torch.cat(
            [
                input_tokens_prefill_pt[uid : uid + 1, :prefill_len],
                torch.zeros(1, padded_len - prefill_len, dtype=torch.long),
            ],
            dim=-1,
        )
        batch_tokens_list.append(user_tokens)
        block_size = page_params["page_block_size"]
        num_blocks_needed = (padded_len + block_size - 1) // block_size
        batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed])
        batch_last_token_idxs.append(prefill_len - 1)

    tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, padded_len]
    page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, num_blocks]
    padded_len = tokens_stacked.shape[1]


Copilot AI (Feb 13, 2026)

torch.cat(batch_page_tables, dim=0) will fail if users in the same iteration have different num_blocks_needed (since each slice page_table[uid:uid+1, :num_blocks_needed] can have a different second dimension). To make batched prefill robust when prompts differ, pad/standardize to a single max_num_blocks for the iteration (e.g., fill extra entries with -1) before concatenation.

Suggested change

    # First pass: collect per-user metadata and track max sizes for padding
    per_user_tokens = []
    per_user_page_tables = []
    per_user_num_blocks = []
    batch_last_token_idxs = []
    max_padded_len = 0
    max_num_blocks = 0
    block_size = page_params["page_block_size"]
    for uid in user_indices:
        prefill_len = int(decoding_pos[uid])
        padded_len = get_padded_prefill_len(prefill_len)
        max_padded_len = max(max_padded_len, padded_len)
        per_user_tokens.append(input_tokens_prefill_pt[uid : uid + 1, :prefill_len])
        num_blocks_needed = (padded_len + block_size - 1) // block_size
        max_num_blocks = max(max_num_blocks, num_blocks_needed)
        per_user_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed])
        per_user_num_blocks.append(num_blocks_needed)
        batch_last_token_idxs.append(prefill_len - 1)

    # Second pass: pad tokens and page tables to the per-iteration maxima
    batch_tokens_list = []
    batch_page_tables = []
    for tokens_slice, page_table_slice, num_blocks_needed in zip(
        per_user_tokens, per_user_page_tables, per_user_num_blocks
    ):
        # Pad tokens to max_padded_len with zeros
        token_pad_len = max_padded_len - tokens_slice.shape[-1]
        if token_pad_len > 0:
            token_padding = torch.zeros(1, token_pad_len, dtype=torch.long, device=tokens_slice.device)
            user_tokens = torch.cat([tokens_slice, token_padding], dim=-1)
        else:
            user_tokens = tokens_slice
        batch_tokens_list.append(user_tokens)

        # Pad page tables to max_num_blocks with a -1 sentinel
        page_pad_blocks = max_num_blocks - num_blocks_needed
        if page_pad_blocks > 0:
            pad_values = torch.full(
                (1, page_pad_blocks), -1, dtype=page_table_slice.dtype, device=page_table_slice.device
            )
            user_page_table = torch.cat([page_table_slice, pad_values], dim=-1)
        else:
            user_page_table = page_table_slice
        batch_page_tables.append(user_page_table)

    tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, max_padded_len]
    page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, max_num_blocks]
    padded_len = max_padded_len

Comment on lines 684 to 685

    k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
    v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)

Copilot AI (Feb 13, 2026)

These assignments to 'k_cache' and 'v_cache' are unnecessary, as both values are redefined before they are used.

Suggested change

    ttnn.mul(k_cache, 0, output_tensor=k_cache)
    ttnn.mul(v_cache, 0, output_tensor=v_cache)

Comment on lines 848 to 849

    k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
    v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)

Copilot AI (Feb 13, 2026)

These assignments to 'k_cache' and 'v_cache' are unnecessary, as both values are redefined before they are used.

Suggested change

    ttnn.mul(k_cache, 0, output_tensor=k_cache)
    ttnn.mul(v_cache, 0, output_tensor=v_cache)

# - model.py:process_output_prefill_batched: extract multiple logits per row
assert users_per_row_prefill % users_per_row_per_iter == 0
num_prefill_iters = users_per_row_prefill // users_per_row_per_iter
users_per_iter = num_rows * users_per_row_per_iter # Total users per iteration

Copilot AI (Feb 13, 2026)

Variable users_per_iter is not used.

Suggested change (remove this line)

    users_per_iter = num_rows * users_per_row_per_iter  # Total users per iteration

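For concreteness, the loop-count arithmetic in the snippet above, using the figures from this PR's description (128 users, 4 mesh rows on a 4x8 Galaxy); the variable names follow the demo code:

    num_rows = 4
    total_users = 128
    users_per_row_prefill = total_users // num_rows            # 32 users queued per mesh row

    for users_per_row_per_iter in (1, 2):
        assert users_per_row_prefill % users_per_row_per_iter == 0
        num_prefill_iters = users_per_row_prefill // users_per_row_per_iter
        print(users_per_row_per_iter, num_prefill_iters)       # 1 -> 32 iterations, 2 -> 16 iterations
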
fixed_get_last_token = -1

if users_per_row_per_iter > 1:
    fixed_get_last_token = -1  # Can't use get_last_token with batch>1

Copilot AI (Feb 13, 2026)

This statement is unreachable.

- Remove unnecessary k_cache/v_cache reassignments (in-place via output_tensor)
- Remove unused users_per_iter variable
- Restructure fixed_get_last_token logic to avoid unreachable code when
  users_per_row_per_iter=1
- model.py: When batch_size>1, extract each user's 32-token last-token
  tile via ttnn.slice and concat to [1,1,B*32,H] before norm+lm_head.
  Removes the batch>1 get_last_token=-1 override in ttnn_prefill_forward.
- attention/prefill.py: Replace per-user paged_fill_cache loop with
  single-call reshape approach (flatten batch into seq dim, heads into
  last dim, flatten page table). Matches llama_70b_galaxy pattern.
- text_demo.py: Remove batch>1 get_last_token override, use
  seq_len_per_user=32 when get_last_token is active.

batch128 users_per_row_per_iter=2 TTFT: 218ms -> 99ms (get_last_token fix)
Compile time: 5.24s -> 3.16s (single-call paged_fill_cache)
batch128 users_per_row_per_iter=1: unchanged at 91ms
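
A torch-only illustration of the last-token-tile gather this commit describes (the model does the equivalent on device with ttnn.slice and concat before norm + lm_head); the hidden size and user count here are assumptions for the example:

    import torch

    TILE = 32                        # last-token extraction works on 32-row tiles
    B, S, H = 2, 128, 2880           # users per row, padded per-user seq len, hidden size (illustrative)
    x = torch.randn(1, 1, B * S, H)  # flattened prefill activations for one mesh row

    tiles = []
    for u in range(B):
        # Last 32-token tile of user u's sequence within the flattened [1, 1, B*S, H] tensor.
        start = u * S + (S - TILE)
        tiles.append(x[:, :, start : start + TILE, :])

    last_tiles = torch.cat(tiles, dim=2)    # [1, 1, B*32, H], then norm + lm_head run on this
    assert last_tiles.shape == (1, 1, B * TILE, H)
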
@sraizada-tt sraizada-tt added this pull request to the merge queue Feb 16, 2026
Merged via the queue into main with commit f8cbbd0 Feb 16, 2026
101 checks passed
@sraizada-tt sraizada-tt deleted the sraizada/gpt-oss-batched-prefill branch February 16, 2026 10:15