[gpt-oss] batched prefill and prefill tracing #37848
Conversation
Process 4 users simultaneously (one per mesh row) during prefill, reducing iterations from 128 to 32 for a ~3.7x TTFT improvement (393ms -> 106ms).

Changes:
- model.py: Add a batched_prefill mode to prepare_inputs_prefill() that shards tokens/page_table across rows via ShardTensor2dMesh(dims=(0, None)). Add process_output_prefill_batched() to extract per-row logits.
- text_demo.py: Replace the sequential 128-user prefill loop with a 32-iteration row-parallel loop. Each iteration prefills 4 users (one per mesh row).

Key insight: the attention allreduce (axis=1) keeps rows independent, while the MoE all_to_all (axis=0) naturally handles cross-row token routing. No changes are needed to the model internals.

Tested on commit 52b9c66 (stable), 4x8 Galaxy:
- Baseline TTFT: 393ms; batched TTFT: 106ms (3.7x improvement)
- Decode unchanged: ~221ms @ 4.52 tok/s/user
- Output quality verified: identical correct outputs across all rows
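A minimal sketch of the row-sharded input preparation described above, assuming ttnn.ShardTensor2dMesh accepts (mesh_device, dims, mesh_shape) as in other tt-metal demos; the helper name, dtypes, and layout below are illustrative rather than the PR's actual prepare_inputs_prefill() changes:

```python
import ttnn


def prepare_inputs_prefill_batched(tokens, page_table, mesh_device, mesh_shape=(4, 8)):
    # tokens: [num_rows, padded_seq_len], page_table: [num_rows, num_blocks].
    # dims=(0, None) shards dim 0 across mesh rows (one user per row) and
    # replicates across mesh columns; dtypes/layout are illustrative.
    mapper = ttnn.ShardTensor2dMesh(mesh_device, dims=(0, None), mesh_shape=mesh_shape)
    tokens_tt = ttnn.from_torch(
        tokens,
        dtype=ttnn.uint32,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=mesh_device,
        mesh_mapper=mapper,
    )
    page_table_tt = ttnn.from_torch(
        page_table,
        dtype=ttnn.int32,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=mesh_device,
        mesh_mapper=mapper,
    )
    return tokens_tt, page_table_tt
```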
Add a users_per_row_per_iter parameter (default 1) to control how many users each mesh row processes per prefill iteration. With batch > 1:
- attention/prefill.py: flatten batch*seq for the QKV matmul, reshape for SDPA, loop paged_fill_cache per user, slice RoPE to the per-user seq_len.
- model.py: pass batch_size through the forward chain; process_output extracts N results per row with the correct seq offset.
- text_demo.py: concatenate per-row tokens along the seq dim, generalize user_indices and loop counts for arbitrary users_per_row_per_iter.

Tested: batch128 passes at both users_per_row_per_iter=1 (TTFT 91ms, identical to baseline) and users_per_row_per_iter=2 (TTFT 219ms, correct outputs, slower due to the disabled get_last_token optimization).
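A rough, torch-level illustration of the batch > 1 shape handling listed for attention/prefill.py (flatten batch*seq for the QKV projection, un-flatten per user for SDPA); the sizes and head layout are invented for the example, and the real code operates on ttnn tensors:

```python
import torch

batch_size, seq_len, hidden = 2, 128, 512      # illustrative sizes only
n_heads, head_dim = 8, 64                      # illustrative head layout

x = torch.randn(1, batch_size, seq_len, hidden)
w_qkv = torch.randn(hidden, 3 * n_heads * head_dim)

# Flatten batch into the sequence dimension so the QKV projection is one large matmul.
x_flat = x.reshape(1, 1, batch_size * seq_len, hidden)
qkv = x_flat @ w_qkv                           # [1, 1, batch*seq, 3*n_heads*head_dim]

# Un-flatten per user and split heads before SDPA so attention stays within each user's sequence.
q, k, v = qkv.reshape(batch_size, seq_len, 3, n_heads, head_dim).permute(2, 0, 3, 1, 4)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                               # [batch, n_heads, seq, head_dim]
```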
…/tenstorrent/tt-metal into sraizada/gpt-oss-batched-prefill
# Conflicts:
#   models/demos/gpt_oss/demo/text_demo.py
#   models/demos/gpt_oss/tt/model.py
Pull request overview
This PR extends the GPT-OSS TTNN demo/model path to support row-parallel batched prefill (one user per mesh row per iteration) and adds a prefill trace flow to reduce host dispatch overhead, enabling higher throughput for multi-user prefill on multi-row meshes.
Changes:
- Thread a new batch_size parameter through model → decoder layer → attention, and update prefill attention to handle batched/flattened sequences.
- Add row-sharded token preparation for batched prefill and a helper to extract per-row logits from multi-device outputs.
- Rework the GPT-OSS text demo to run row-parallel batched prefill, with an optional traced execution path.
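A sketch of the traced execution path mentioned in the last bullet, using the TTNN trace API (begin_trace_capture / end_trace_capture / execute_trace / synchronize_device); run_prefill and copy_host_to_device are stand-ins for the demo's model call and host-to-device copy helper, and the real demo manages more tensors than shown:

```python
import ttnn


def make_traced_prefill(mesh_device, run_prefill, device_inputs, copy_host_to_device):
    """Sketch of the traced prefill flow: capture one forward pass, then replay it per iteration.

    run_prefill and copy_host_to_device are stand-ins for the demo's model call and
    host->device copy helper; only the ttnn trace calls are taken from the TTNN API.
    """
    # Capture the prefill forward pass into a trace so later iterations skip host dispatch.
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    tt_out_trace = run_prefill(*device_inputs)
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

    def run_iteration(host_inputs):
        # Refresh the captured device inputs, replay the trace, and sync before reading output.
        copy_host_to_device(host_inputs, device_tensors=device_inputs, mesh_device=mesh_device)
        ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
        ttnn.synchronize_device(mesh_device)
        return tt_out_trace

    # Release with ttnn.release_trace(mesh_device, trace_id) when finished.
    return run_iteration, trace_id
```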
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| models/demos/gpt_oss/tt/model.py | Adds batched prefill input handling and per-row output extraction; threads batch_size through forward path. |
| models/demos/gpt_oss/tt/layer.py | Passes batch_size into attention for prefill. |
| models/demos/gpt_oss/tt/attention/prefill.py | Updates prefill attention to reshape/handle batch_size > 1 and KV cache fill loops. |
| models/demos/gpt_oss/tt/attention/__init__.py | Extends the Attention call signature to accept and forward batch_size. |
| models/demos/gpt_oss/demo/text_demo.py | Implements row-parallel batched prefill and adds a prefill tracing execution path. |
Comments suppressed due to low confidence (2)
models/demos/gpt_oss/demo/text_demo.py:508

This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.

```python
k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
```
models/demos/gpt_oss/demo/text_demo.py:509

This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.

```python
v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
```diff
  activation_dtype = ttnn.bfloat16
- _, batch_size, seq_len, hidden_size = hidden_states.shape
+ total_seq_len = hidden_states.shape[-2]
+ hidden_size = hidden_states.shape[-1]
```
seq_len is derived via integer division (total_seq_len // batch_size) and then used to reshape tensors. If total_seq_len is not exactly divisible by batch_size, this will silently produce an incorrect per-user seq_len and likely break the subsequent reshape/attention behavior. Add an explicit check (and a clear error) that total_seq_len % batch_size == 0 before computing seq_len (and consider validating batch_size > 0).
Suggested change:

```python
hidden_size = hidden_states.shape[-1]
# Validate batch_size and per-user sequence length shape
if batch_size <= 0:
    raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
if total_seq_len % batch_size != 0:
    raise ValueError(
        f"total_seq_len ({total_seq_len}) must be divisible by batch_size ({batch_size}) "
        "to compute a valid per-user seq_len."
    )
```
```python
last_idx = last_token_idxs[user_flat_idx] if isinstance(last_token_idxs, list) else last_token_idxs
if users_per_row > 1:
    # Tokens are concatenated: user u's last token is at offset u*seq_len_per_user + last_idx
    global_idx = u * seq_len_per_user + last_idx
result = torch_output[..., global_idx, : self.vocab_size]
```
process_output_prefill_batched will crash for users_per_row > 1 if seq_len_per_user is not provided (it is used in global_idx = u * seq_len_per_user + last_idx). Since the docstring advertises multi-user-per-row support, add an explicit validation (e.g., raise ValueError) when users_per_row > 1 and seq_len_per_user is None (and consider validating the length of last_token_idxs when it’s a list).
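A minimal sketch of the guard this comment asks for, assuming the parameter names visible in the hunk above (users_per_row, seq_len_per_user, last_token_idxs); where exactly it sits inside process_output_prefill_batched is an assumption:

```python
# Hypothetical guard near the top of process_output_prefill_batched (names assumed from the hunk above).
if users_per_row > 1 and seq_len_per_user is None:
    raise ValueError(
        "seq_len_per_user must be provided when users_per_row > 1, since per-user "
        "offsets are computed as u * seq_len_per_user + last_idx."
    )
# Optionally also validate len(last_token_idxs) against the number of users when it is a list.
```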
```python
)
host_inputs = (host_out[0], host_out[3], host_out[4])
copy_host_to_device(host_inputs, device_tensors=trace_dev_inputs, mesh_device=mesh_device)
ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
```
In the traced batched prefill loop you execute the trace with blocking=False and then immediately read/convert the output. Without an explicit sync/block on the output tensor, this risks reading incomplete results depending on TTNN scheduling. Add a synchronization point (e.g., block on output CPU copy or ttnn.synchronize_device(mesh_device)) before processing the output.
Suggested change:

```python
ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
# Ensure trace execution has completed before reading from tt_out_trace.
ttnn.synchronize_device(mesh_device)
```
```python
batch_tokens_list = []
batch_page_tables = []
batch_last_token_idxs = []

for uid in user_indices:
    prefill_len = int(decoding_pos[uid])
    padded_len = get_padded_prefill_len(prefill_len)
    user_tokens = torch.cat(
        [
            input_tokens_prefill_pt[uid : uid + 1, :prefill_len],
            torch.zeros(1, padded_len - prefill_len, dtype=torch.long),
        ],
        dim=-1,
    )
    batch_tokens_list.append(user_tokens)
    block_size = page_params["page_block_size"]
    num_blocks_needed = (padded_len + block_size - 1) // block_size
    batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed])
    batch_last_token_idxs.append(prefill_len - 1)

tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, padded_len]
page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, num_blocks]
padded_len = tokens_stacked.shape[1]
```
torch.cat(batch_page_tables, dim=0) will fail if users in the same iteration have different num_blocks_needed (since each slice page_table[uid:uid+1, :num_blocks_needed] can have a different second dimension). To make batched prefill robust when prompts differ, pad/standardize to a single max_num_blocks for the iteration (e.g., fill extra entries with -1) before concatenation.
Suggested change:

```python
# First pass: collect per-user metadata and track max sizes for padding
per_user_tokens = []
per_user_padded_lens = []
per_user_page_tables = []
per_user_num_blocks = []
batch_last_token_idxs = []
max_padded_len = 0
max_num_blocks = 0
block_size = page_params["page_block_size"]
for uid in user_indices:
    prefill_len = int(decoding_pos[uid])
    padded_len = get_padded_prefill_len(prefill_len)
    max_padded_len = max(max_padded_len, padded_len)
    tokens_slice = input_tokens_prefill_pt[uid : uid + 1, :prefill_len]
    per_user_tokens.append(tokens_slice)
    per_user_padded_lens.append(padded_len)
    num_blocks_needed = (padded_len + block_size - 1) // block_size
    max_num_blocks = max(max_num_blocks, num_blocks_needed)
    page_table_slice = page_table[uid : uid + 1, :num_blocks_needed]
    per_user_page_tables.append(page_table_slice)
    per_user_num_blocks.append(num_blocks_needed)
    batch_last_token_idxs.append(prefill_len - 1)

# Second pass: pad tokens and page tables to per-iteration maxima
batch_tokens_list = []
batch_page_tables = []
for tokens_slice, padded_len, page_table_slice, num_blocks_needed in zip(
    per_user_tokens,
    per_user_padded_lens,
    per_user_page_tables,
    per_user_num_blocks,
):
    # Pad tokens to max_padded_len with zeros. Pad from the actual token count
    # (tokens_slice holds prefill_len tokens) so every row ends up exactly
    # max_padded_len wide.
    token_pad_len = max_padded_len - tokens_slice.shape[-1]
    if token_pad_len > 0:
        token_padding = torch.zeros(1, token_pad_len, dtype=torch.long, device=tokens_slice.device)
        user_tokens = torch.cat([tokens_slice, token_padding], dim=-1)
    else:
        user_tokens = tokens_slice
    batch_tokens_list.append(user_tokens)
    # Pad page tables to max_num_blocks with a -1 sentinel
    page_pad_blocks = max_num_blocks - num_blocks_needed
    if page_pad_blocks > 0:
        pad_values = torch.full(
            (1, page_pad_blocks),
            -1,
            dtype=page_table_slice.dtype,
            device=page_table_slice.device,
        )
        user_page_table = torch.cat([page_table_slice, pad_values], dim=-1)
    else:
        user_page_table = page_table_slice
    batch_page_tables.append(user_page_table)

tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, max_padded_len]
page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, max_num_blocks]
padded_len = max_padded_len
```
```python
k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
Suggested change:

```python
ttnn.mul(k_cache, 0, output_tensor=k_cache)
ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
```python
k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
Suggested change:

```python
ttnn.mul(k_cache, 0, output_tensor=k_cache)
ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
```python
k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
Suggested change:

```python
ttnn.mul(k_cache, 0, output_tensor=k_cache)
ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
```python
k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache)
v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
Suggested change:

```python
ttnn.mul(k_cache, 0, output_tensor=k_cache)
ttnn.mul(v_cache, 0, output_tensor=v_cache)
```
```python
# - model.py:process_output_prefill_batched: extract multiple logits per row
assert users_per_row_prefill % users_per_row_per_iter == 0
num_prefill_iters = users_per_row_prefill // users_per_row_per_iter
users_per_iter = num_rows * users_per_row_per_iter  # Total users per iteration
```
Variable users_per_iter is not used.
Suggested change: delete the unused assignment `users_per_iter = num_rows * users_per_row_per_iter  # Total users per iteration`.
```python
fixed_get_last_token = -1

if users_per_row_per_iter > 1:
    fixed_get_last_token = -1  # Can't use get_last_token with batch>1
```
This statement is unreachable.
- Remove unnecessary k_cache/v_cache reassignments (in-place via output_tensor)
- Remove unused users_per_iter variable
- Restructure fixed_get_last_token logic to avoid unreachable code when users_per_row_per_iter=1
- model.py: When batch_size>1, extract each user's 32-token last-token tile via ttnn.slice and concat to [1,1,B*32,H] before norm+lm_head. Removes the batch>1 get_last_token=-1 override in ttnn_prefill_forward.
- attention/prefill.py: Replace the per-user paged_fill_cache loop with a single-call reshape approach (flatten batch into the seq dim, heads into the last dim, flatten the page table). Matches the llama_70b_galaxy pattern.
- text_demo.py: Remove the batch>1 get_last_token override; use seq_len_per_user=32 when get_last_token is active.

batch128 users_per_row_per_iter=2 TTFT: 218ms -> 99ms (get_last_token fix)
Compile time: 5.24s -> 3.16s (single-call paged_fill_cache)
batch128 users_per_row_per_iter=1: unchanged at 91ms
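A sketch of the last-token-tile gather described in the first bullet, assuming a [1, 1, batch*seq, hidden] activation, 32-row tiles, and a tile-aligned per-user offset; the function name and arguments are illustrative, not the PR's actual model.py code:

```python
import ttnn


def gather_last_token_tiles(x, batch_size, seq_len_per_user, last_token_tile_offset, hidden, tile_height=32):
    """Hypothetical gather of each user's 32-row last-token tile before norm + lm_head.

    Assumes x has logical shape [1, 1, batch_size * seq_len_per_user, hidden] and that
    last_token_tile_offset is the tile-aligned offset of the last-token tile within each
    user's sequence (both assumptions; the PR's exact slicing is not shown in this thread).
    """
    tiles = []
    for u in range(batch_size):
        start = u * seq_len_per_user + last_token_tile_offset
        tiles.append(ttnn.slice(x, [0, 0, start, 0], [1, 1, start + tile_height, hidden]))
    return ttnn.concat(tiles, dim=2)  # [1, 1, batch_size * tile_height, hidden]
```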
galaxy demo: https://github.com/tenstorrent/tt-metal/actions/runs/22053977976