-
Notifications
You must be signed in to change notification settings - Fork 346
[gpt-oss] batched prefill and prefill tracing #37848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
1b53b80
0c67979
ffab842
38db9f3
c1e1d65
2c88901
387359e
a4915da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -36,6 +36,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from models.tt_transformers.demo.simple_text_demo import create_tt_page_table, load_inputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from models.tt_transformers.tt.common import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PagedAttentionConfig, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| copy_host_to_device, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_padded_prefill_len, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preprocess_inputs_prefill, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sample_host, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -242,7 +243,7 @@ def prepare_gpt_oss_generator_args( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"page_block_size": 64, "page_max_num_blocks_per_dp": 128 * 1024 // 64}, # page_params | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"temperature": 0, "top_p": 0.08}, # sampling_params (greedy decoding) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| True, # enable_decode_trace | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| False, # enable_prefill_trace | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| True, # enable_prefill_trace | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| True, # users_row_sharded | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| False, # long_context_mode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -572,33 +573,301 @@ def test_gpt_oss_demo( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"Prefill finished for {num_real_users} real users") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"First generated token (user 0): '{tokenizer.decode(prefilled_token[0])}'") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Standard batch prefill (matching tt_transformers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("Starting prefill warmup...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| profiler.start(f"compile_prefill", iteration=batch_idx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generator.prefill_forward_text( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_tokens_prefill_pt[:1], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| page_table=page_table, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kv_cache=tt_kv_cache, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_lens=decoding_pos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enable_trace=enable_prefill_trace, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warmup_prefill=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| profiler.end(f"compile_prefill", iteration=batch_idx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("Finished prefill warmup") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Row-parallel batched prefill: process 4 users at once (one per mesh row) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # This gives ~4x speedup over sequential per-user prefill | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_rows = mesh_device.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| users_per_row_prefill = global_batch_size // num_rows | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| users_per_row_per_iter = 1 # Users each mesh row processes per prefill iteration | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Increasing above 1 requires model changes: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # - attention/prefill.py: relax batch_size!=1 check, loop paged_fill_cache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # - model.py:process_output_prefill_batched: extract multiple logits per row | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert users_per_row_prefill % users_per_row_per_iter == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_prefill_iters = users_per_row_prefill // users_per_row_per_iter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| users_per_iter = num_rows * users_per_row_per_iter # Total users per iteration | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| users_per_iter = num_rows * users_per_row_per_iter # Total users per iteration |
github-code-quality[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
Outdated
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This statement is unreachable.
github-code-quality[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
Outdated
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
| k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) | |
| ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| ttnn.mul(v_cache, 0, output_tensor=v_cache) |
Outdated
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
| k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) | |
| ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| ttnn.mul(v_cache, 0, output_tensor=v_cache) |
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the traced batched prefill loop you execute the trace with blocking=False and then immediately read/convert the output. Without an explicit sync/block on the output tensor, this risks reading incomplete results depending on TTNN scheduling. Add a synchronization point (e.g., block on output CPU copy or ttnn.synchronize_device(mesh_device)) before processing the output.
| ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False) | |
| ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False) | |
| # Ensure trace execution has completed before reading from tt_out_trace. | |
| ttnn.synchronize_device(mesh_device) |
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.cat(batch_page_tables, dim=0) will fail if users in the same iteration have different num_blocks_needed (since each slice page_table[uid:uid+1, :num_blocks_needed] can have a different second dimension). To make batched prefill robust when prompts differ, pad/standardize to a single max_num_blocks for the iteration (e.g., fill extra entries with -1) before concatenation.
| batch_tokens_list = [] | |
| batch_page_tables = [] | |
| batch_last_token_idxs = [] | |
| for uid in user_indices: | |
| prefill_len = int(decoding_pos[uid]) | |
| padded_len = get_padded_prefill_len(prefill_len) | |
| user_tokens = torch.cat( | |
| [ | |
| input_tokens_prefill_pt[uid : uid + 1, :prefill_len], | |
| torch.zeros(1, padded_len - prefill_len, dtype=torch.long), | |
| ], | |
| dim=-1, | |
| ) | |
| batch_tokens_list.append(user_tokens) | |
| block_size = page_params["page_block_size"] | |
| num_blocks_needed = (padded_len + block_size - 1) // block_size | |
| batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed]) | |
| batch_last_token_idxs.append(prefill_len - 1) | |
| tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [total_users, padded_len] | |
| page_table_stacked = torch.cat(batch_page_tables, dim=0) # [total_users, num_blocks] | |
| padded_len = tokens_stacked.shape[1] | |
| # First pass: collect per-user metadata and track max sizes for padding | |
| per_user_tokens = [] | |
| per_user_padded_lens = [] | |
| per_user_page_tables = [] | |
| per_user_num_blocks = [] | |
| batch_last_token_idxs = [] | |
| max_padded_len = 0 | |
| max_num_blocks = 0 | |
| block_size = page_params["page_block_size"] | |
| for uid in user_indices: | |
| prefill_len = int(decoding_pos[uid]) | |
| padded_len = get_padded_prefill_len(prefill_len) | |
| max_padded_len = max(max_padded_len, padded_len) | |
| tokens_slice = input_tokens_prefill_pt[uid : uid + 1, :prefill_len] | |
| per_user_tokens.append(tokens_slice) | |
| per_user_padded_lens.append(padded_len) | |
| num_blocks_needed = (padded_len + block_size - 1) // block_size | |
| max_num_blocks = max(max_num_blocks, num_blocks_needed) | |
| page_table_slice = page_table[uid : uid + 1, :num_blocks_needed] | |
| per_user_page_tables.append(page_table_slice) | |
| per_user_num_blocks.append(num_blocks_needed) | |
| batch_last_token_idxs.append(prefill_len - 1) | |
| # Second pass: pad tokens and page tables to per-iteration maxima | |
| batch_tokens_list = [] | |
| batch_page_tables = [] | |
| for tokens_slice, padded_len, page_table_slice, num_blocks_needed in zip( | |
| per_user_tokens, | |
| per_user_padded_lens, | |
| per_user_page_tables, | |
| per_user_num_blocks, | |
| ): | |
| # Pad tokens to max_padded_len with zeros | |
| token_pad_len = max_padded_len - padded_len | |
| if token_pad_len > 0: | |
| token_padding = torch.zeros(1, token_pad_len, dtype=torch.long, device=tokens_slice.device) | |
| user_tokens = torch.cat([tokens_slice, token_padding], dim=-1) | |
| else: | |
| user_tokens = tokens_slice | |
| batch_tokens_list.append(user_tokens) | |
| # Pad page tables to max_num_blocks with -1 sentinel | |
| page_pad_blocks = max_num_blocks - num_blocks_needed | |
| if page_pad_blocks > 0: | |
| pad_values = torch.full( | |
| (1, page_pad_blocks), | |
| -1, | |
| dtype=page_table_slice.dtype, | |
| device=page_table_slice.device, | |
| ) | |
| user_page_table = torch.cat([page_table_slice, pad_values], dim=-1) | |
| else: | |
| user_page_table = page_table_slice | |
| batch_page_tables.append(user_page_table) | |
| tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [total_users, max_padded_len] | |
| page_table_stacked = torch.cat(batch_page_tables, dim=0) # [total_users, max_num_blocks] | |
| padded_len = max_padded_len |
github-code-quality[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
Outdated
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'k_cache' is unnecessary as it is redefined before this value is used.
| k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) | |
| ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| ttnn.mul(v_cache, 0, output_tensor=v_cache) |
Outdated
Copilot
AI
Feb 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
This assignment to 'v_cache' is unnecessary as it is redefined before this value is used.
| k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) | |
| ttnn.mul(k_cache, 0, output_tensor=k_cache) | |
| ttnn.mul(v_cache, 0, output_tensor=v_cache) |
Uh oh!
There was an error while loading. Please reload this page.