From 1b53b80043afbb3cf500550280d72f1bdbf90dc4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Feb 2026 10:49:10 +0000 Subject: [PATCH 1/6] Row-parallel batched prefill for GPT-OSS 120B batch128 on 4x8 Process 4 users simultaneously (one per mesh row) during prefill, reducing iterations from 128 to 32 for ~3.7x TTFT improvement (393ms -> 106ms). Changes: - model.py: Add batched_prefill mode to prepare_inputs_prefill() that shards tokens/page_table across rows via ShardTensor2dMesh(dims=(0,None)). Add process_output_prefill_batched() to extract per-row logits. - text_demo.py: Replace sequential 128-user prefill loop with 32-iteration row-parallel loop. Each iteration prefills 4 users (one per mesh row). Key insight: attention allreduce (axis=1) keeps rows independent while MoE all_to_all (axis=0) naturally handles cross-row token routing. No changes needed to model internals. Tested on commit 52b9c66ec6 (stable), 4x8 Galaxy: - Baseline TTFT: 393ms, Batched TTFT: 106ms (3.7x improvement) - Decode unchanged: ~221ms @ 4.52 tok/s/user - Output quality verified: identical correct outputs across all rows --- models/demos/gpt_oss/demo/text_demo.py | 105 +++++++++++++++++++------ models/demos/gpt_oss/tt/model.py | 53 +++++++++++-- 2 files changed, 131 insertions(+), 27 deletions(-) diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py index 1ca99ae4a5c4..f35398524cab 100644 --- a/models/demos/gpt_oss/demo/text_demo.py +++ b/models/demos/gpt_oss/demo/text_demo.py @@ -572,33 +572,94 @@ def test_gpt_oss_demo( logger.info(f"Prefill finished for {num_real_users} real users") logger.info(f"First generated token (user 0): '{tokenizer.decode(prefilled_token[0])}'") else: - # Standard batch prefill (matching tt_transformers) - logger.info("Starting prefill warmup...") + # Row-parallel batched prefill: process 4 users at once (one per mesh row) + # This gives ~4x speedup over sequential per-user prefill + num_rows = mesh_device.shape[0] + users_per_row_prefill = global_batch_size // num_rows + model_id = 0 # data_parallel=1, single model + + prefilled_token = torch.zeros(global_batch_size, dtype=torch.long) + + # Helper to run one batched prefill iteration + def _run_batched_prefill_iter(iter_idx, user_indices): + batch_tokens_list = [] + batch_page_tables = [] + batch_last_token_idxs = [] + + for uid in user_indices: + prefill_len = int(decoding_pos[uid]) + padded_len = get_padded_prefill_len(prefill_len) + user_tokens = torch.cat( + [ + input_tokens_prefill_pt[uid : uid + 1, :prefill_len], + torch.zeros(1, padded_len - prefill_len, dtype=torch.long), + ], + dim=-1, + ) + batch_tokens_list.append(user_tokens) + block_size = page_params["page_block_size"] + num_blocks_needed = (padded_len + block_size - 1) // block_size + batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed]) + batch_last_token_idxs.append(prefill_len - 1) + + tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [num_rows, padded_len] + page_table_stacked = torch.cat(batch_page_tables, dim=0) # [num_rows, num_blocks] + + (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[ + model_id + ].prepare_inputs_prefill( + tokens_stacked, + page_table=page_table_stacked, + batched_prefill=True, + ) + + get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 + tt_logits = model[model_id].ttnn_prefill_forward( + tokens_embd, + rot_mats_global=rot_mats_global, + rot_mats_local=rot_mats_local, + user_id=0, # Must be 0: each device sees page_table[0] after row-sharding + 
page_table=page_table_tt, + get_last_token=get_last_token_val, + kv_cache=tt_kv_cache[model_id], + ) + + row_results = model[model_id].process_output_prefill_batched( + tt_logits, [idx % 32 for idx in batch_last_token_idxs] + ) + return row_results + + # Warmup: compile with first batch + logger.info("Starting row-parallel prefill warmup...") profiler.start(f"compile_prefill", iteration=batch_idx) - generator.prefill_forward_text( - input_tokens_prefill_pt[:1], - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - enable_trace=enable_prefill_trace, - warmup_prefill=False, - ) + warmup_user_indices = [row * users_per_row_prefill for row in range(num_rows)] + warmup_results = _run_batched_prefill_iter(0, warmup_user_indices) + for row, uid in enumerate(warmup_user_indices): + prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() profiler.end(f"compile_prefill", iteration=batch_idx) - logger.info("Finished prefill warmup") + logger.info("Finished row-parallel prefill warmup") - logger.info(f"Starting prefill...") - profiler.start(f"inference_prefill", iteration=batch_idx) - logits = generator.prefill_forward_text( - input_tokens_prefill_pt, - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - enable_trace=enable_prefill_trace, - warmup_prefill=False, # we can warmup prefill ourselves above if we want to + # Clear KV caches before real prefill (warmup wrote to them) + for i in range(len(model)): + for layer_obj in model[i].layers: + k_cache, v_cache = layer_obj.self_attn.layer_past + k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) + v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + + # Real prefill + logger.info( + f"Starting row-parallel batched prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." ) - prefilled_token = torch.argmax(logits, dim=-1) + profiler.start(f"inference_prefill", iteration=batch_idx) + for iter_idx in range(users_per_row_prefill): + user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] + row_results = _run_batched_prefill_iter(iter_idx, user_indices) + for row, uid in enumerate(user_indices): + prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() + if iter_idx % 8 == 0: + logger.info(f" Prefilled batch {iter_idx+1}/{users_per_row_prefill}") profiler.end(f"inference_prefill", iteration=batch_idx) - logger.info(f"Prefill finished") + logger.info(f"Row-parallel batched prefill finished ({users_per_row_prefill} iterations)") logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'") # Initialize generation state like tt_transformers diff --git a/models/demos/gpt_oss/tt/model.py b/models/demos/gpt_oss/tt/model.py index 47bb42836c4e..e2187186fd3a 100644 --- a/models/demos/gpt_oss/tt/model.py +++ b/models/demos/gpt_oss/tt/model.py @@ -476,15 +476,34 @@ def prepare_inputs_prefill( trace_enabled=False, last_token_idx=None, global_user_id=None, + batched_prefill=False, ): - """Prepare inputs for prefill mode""" - # Embed the tokens - if tokens.dim() == 2: - tokens = tokens.reshape(1, 1, 1, -1) + """Prepare inputs for prefill mode + Args: + batched_prefill: If True, tokens is [num_rows, seq_len] and will be + sharded across mesh rows. Each row processes a different user. 
+ """ + # Embed the tokens device = None if trace_enabled else self.mesh_device - tokens = ttnn.from_torch(tokens, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT) + if batched_prefill: + # Row-parallel batched prefill: tokens is [num_rows, seq_len] + # Shard across mesh rows so each row gets one user's tokens [1, seq_len] + num_rows = tokens.shape[0] + seq_len_per_user = tokens.shape[1] + tokens = tokens.reshape(num_rows, 1, 1, seq_len_per_user) + tokens = ttnn.from_torch( + tokens, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=(0, None), mesh_shape=self.mesh_device.shape), + ) + else: + if tokens.dim() == 2: + tokens = tokens.reshape(1, 1, 1, -1) + tokens = ttnn.from_torch(tokens, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT) if not trace_enabled: tokens_embd = ttnn.embedding(tokens, self.embedding_weight, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) @@ -589,3 +608,27 @@ def process_output_prefill(self, tt_out, last_token_idx): torch_output = ttnn.to_torch(tt_output_tensor) result = torch_output[..., last_token_idx, : self.vocab_size] return result + + def process_output_prefill_batched(self, tt_out, last_token_idxs): + """Process row-parallel batched prefill output. + + Extracts logits from one device per row (first device of each row). + + Args: + tt_out: Multi-device output tensor + last_token_idxs: List of last_token_idx per row user + + Returns: + List of per-user logit tensors (one per row) + """ + num_cols = self.mesh_device.shape[1] + device_tensors = ttnn.get_device_tensors(tt_out) + results = [] + num_rows = self.mesh_device.shape[0] + for row in range(num_rows): + device_idx = row * num_cols # First device of each row + torch_output = ttnn.to_torch(device_tensors[device_idx]) + last_idx = last_token_idxs[row] if isinstance(last_token_idxs, list) else last_token_idxs + result = torch_output[..., last_idx, : self.vocab_size] + results.append(result) + return results From 0c67979f77f2e2ef887ba909ef00888ae45eeec7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Feb 2026 10:49:10 +0000 Subject: [PATCH 2/6] Row-parallel batched prefill for GPT-OSS 120B batch128 on 4x8 Process 4 users simultaneously (one per mesh row) during prefill, reducing iterations from 128 to 32 for ~3.7x TTFT improvement (393ms -> 106ms). Changes: - model.py: Add batched_prefill mode to prepare_inputs_prefill() that shards tokens/page_table across rows via ShardTensor2dMesh(dims=(0,None)). Add process_output_prefill_batched() to extract per-row logits. - text_demo.py: Replace sequential 128-user prefill loop with 32-iteration row-parallel loop. Each iteration prefills 4 users (one per mesh row). Key insight: attention allreduce (axis=1) keeps rows independent while MoE all_to_all (axis=0) naturally handles cross-row token routing. No changes needed to model internals. 
Tested on commit 52b9c66ec6 (stable), 4x8 Galaxy: - Baseline TTFT: 393ms, Batched TTFT: 106ms (3.7x improvement) - Decode unchanged: ~221ms @ 4.52 tok/s/user - Output quality verified: identical correct outputs across all rows --- models/demos/gpt_oss/demo/text_demo.py | 105 +++++++++++++++++++------ models/demos/gpt_oss/tt/model.py | 53 +++++++++++-- 2 files changed, 131 insertions(+), 27 deletions(-) diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py index 1ca99ae4a5c4..f35398524cab 100644 --- a/models/demos/gpt_oss/demo/text_demo.py +++ b/models/demos/gpt_oss/demo/text_demo.py @@ -572,33 +572,94 @@ def test_gpt_oss_demo( logger.info(f"Prefill finished for {num_real_users} real users") logger.info(f"First generated token (user 0): '{tokenizer.decode(prefilled_token[0])}'") else: - # Standard batch prefill (matching tt_transformers) - logger.info("Starting prefill warmup...") + # Row-parallel batched prefill: process 4 users at once (one per mesh row) + # This gives ~4x speedup over sequential per-user prefill + num_rows = mesh_device.shape[0] + users_per_row_prefill = global_batch_size // num_rows + model_id = 0 # data_parallel=1, single model + + prefilled_token = torch.zeros(global_batch_size, dtype=torch.long) + + # Helper to run one batched prefill iteration + def _run_batched_prefill_iter(iter_idx, user_indices): + batch_tokens_list = [] + batch_page_tables = [] + batch_last_token_idxs = [] + + for uid in user_indices: + prefill_len = int(decoding_pos[uid]) + padded_len = get_padded_prefill_len(prefill_len) + user_tokens = torch.cat( + [ + input_tokens_prefill_pt[uid : uid + 1, :prefill_len], + torch.zeros(1, padded_len - prefill_len, dtype=torch.long), + ], + dim=-1, + ) + batch_tokens_list.append(user_tokens) + block_size = page_params["page_block_size"] + num_blocks_needed = (padded_len + block_size - 1) // block_size + batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed]) + batch_last_token_idxs.append(prefill_len - 1) + + tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [num_rows, padded_len] + page_table_stacked = torch.cat(batch_page_tables, dim=0) # [num_rows, num_blocks] + + (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[ + model_id + ].prepare_inputs_prefill( + tokens_stacked, + page_table=page_table_stacked, + batched_prefill=True, + ) + + get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 + tt_logits = model[model_id].ttnn_prefill_forward( + tokens_embd, + rot_mats_global=rot_mats_global, + rot_mats_local=rot_mats_local, + user_id=0, # Must be 0: each device sees page_table[0] after row-sharding + page_table=page_table_tt, + get_last_token=get_last_token_val, + kv_cache=tt_kv_cache[model_id], + ) + + row_results = model[model_id].process_output_prefill_batched( + tt_logits, [idx % 32 for idx in batch_last_token_idxs] + ) + return row_results + + # Warmup: compile with first batch + logger.info("Starting row-parallel prefill warmup...") profiler.start(f"compile_prefill", iteration=batch_idx) - generator.prefill_forward_text( - input_tokens_prefill_pt[:1], - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - enable_trace=enable_prefill_trace, - warmup_prefill=False, - ) + warmup_user_indices = [row * users_per_row_prefill for row in range(num_rows)] + warmup_results = _run_batched_prefill_iter(0, warmup_user_indices) + for row, uid in enumerate(warmup_user_indices): + prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() 
profiler.end(f"compile_prefill", iteration=batch_idx) - logger.info("Finished prefill warmup") + logger.info("Finished row-parallel prefill warmup") - logger.info(f"Starting prefill...") - profiler.start(f"inference_prefill", iteration=batch_idx) - logits = generator.prefill_forward_text( - input_tokens_prefill_pt, - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - enable_trace=enable_prefill_trace, - warmup_prefill=False, # we can warmup prefill ourselves above if we want to + # Clear KV caches before real prefill (warmup wrote to them) + for i in range(len(model)): + for layer_obj in model[i].layers: + k_cache, v_cache = layer_obj.self_attn.layer_past + k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) + v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + + # Real prefill + logger.info( + f"Starting row-parallel batched prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." ) - prefilled_token = torch.argmax(logits, dim=-1) + profiler.start(f"inference_prefill", iteration=batch_idx) + for iter_idx in range(users_per_row_prefill): + user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] + row_results = _run_batched_prefill_iter(iter_idx, user_indices) + for row, uid in enumerate(user_indices): + prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() + if iter_idx % 8 == 0: + logger.info(f" Prefilled batch {iter_idx+1}/{users_per_row_prefill}") profiler.end(f"inference_prefill", iteration=batch_idx) - logger.info(f"Prefill finished") + logger.info(f"Row-parallel batched prefill finished ({users_per_row_prefill} iterations)") logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'") # Initialize generation state like tt_transformers diff --git a/models/demos/gpt_oss/tt/model.py b/models/demos/gpt_oss/tt/model.py index 47bb42836c4e..e2187186fd3a 100644 --- a/models/demos/gpt_oss/tt/model.py +++ b/models/demos/gpt_oss/tt/model.py @@ -476,15 +476,34 @@ def prepare_inputs_prefill( trace_enabled=False, last_token_idx=None, global_user_id=None, + batched_prefill=False, ): - """Prepare inputs for prefill mode""" - # Embed the tokens - if tokens.dim() == 2: - tokens = tokens.reshape(1, 1, 1, -1) + """Prepare inputs for prefill mode + Args: + batched_prefill: If True, tokens is [num_rows, seq_len] and will be + sharded across mesh rows. Each row processes a different user. 
+ """ + # Embed the tokens device = None if trace_enabled else self.mesh_device - tokens = ttnn.from_torch(tokens, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT) + if batched_prefill: + # Row-parallel batched prefill: tokens is [num_rows, seq_len] + # Shard across mesh rows so each row gets one user's tokens [1, seq_len] + num_rows = tokens.shape[0] + seq_len_per_user = tokens.shape[1] + tokens = tokens.reshape(num_rows, 1, 1, seq_len_per_user) + tokens = ttnn.from_torch( + tokens, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=(0, None), mesh_shape=self.mesh_device.shape), + ) + else: + if tokens.dim() == 2: + tokens = tokens.reshape(1, 1, 1, -1) + tokens = ttnn.from_torch(tokens, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT) if not trace_enabled: tokens_embd = ttnn.embedding(tokens, self.embedding_weight, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) @@ -589,3 +608,27 @@ def process_output_prefill(self, tt_out, last_token_idx): torch_output = ttnn.to_torch(tt_output_tensor) result = torch_output[..., last_token_idx, : self.vocab_size] return result + + def process_output_prefill_batched(self, tt_out, last_token_idxs): + """Process row-parallel batched prefill output. + + Extracts logits from one device per row (first device of each row). + + Args: + tt_out: Multi-device output tensor + last_token_idxs: List of last_token_idx per row user + + Returns: + List of per-user logit tensors (one per row) + """ + num_cols = self.mesh_device.shape[1] + device_tensors = ttnn.get_device_tensors(tt_out) + results = [] + num_rows = self.mesh_device.shape[0] + for row in range(num_rows): + device_idx = row * num_cols # First device of each row + torch_output = ttnn.to_torch(device_tensors[device_idx]) + last_idx = last_token_idxs[row] if isinstance(last_token_idxs, list) else last_token_idxs + result = torch_output[..., last_idx, : self.vocab_size] + results.append(result) + return results From ffab84235dedfe35b11ad6ece8e79c6c1b579407 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Feb 2026 12:29:22 +0000 Subject: [PATCH 3/6] Enable prefill trace for row-parallel batched prefill (batch128) --- models/demos/gpt_oss/demo/text_demo.py | 287 +++++++++++++++++++------ 1 file changed, 218 insertions(+), 69 deletions(-) diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py index f35398524cab..a88797d09613 100644 --- a/models/demos/gpt_oss/demo/text_demo.py +++ b/models/demos/gpt_oss/demo/text_demo.py @@ -36,6 +36,7 @@ from models.tt_transformers.demo.simple_text_demo import create_tt_page_table, load_inputs from models.tt_transformers.tt.common import ( PagedAttentionConfig, + copy_host_to_device, get_padded_prefill_len, preprocess_inputs_prefill, sample_host, @@ -242,7 +243,7 @@ def prepare_gpt_oss_generator_args( {"page_block_size": 64, "page_max_num_blocks_per_dp": 128 * 1024 // 64}, # page_params {"temperature": 0, "top_p": 0.08}, # sampling_params (greedy decoding) True, # enable_decode_trace - False, # enable_prefill_trace + True, # enable_prefill_trace True, # users_row_sharded False, # long_context_mode ), @@ -580,86 +581,234 @@ def test_gpt_oss_demo( prefilled_token = torch.zeros(global_batch_size, dtype=torch.long) - # Helper to run one batched prefill iteration - def _run_batched_prefill_iter(iter_idx, user_indices): - batch_tokens_list = [] - batch_page_tables = [] - batch_last_token_idxs = [] - - for uid in user_indices: - prefill_len = 
int(decoding_pos[uid]) - padded_len = get_padded_prefill_len(prefill_len) - user_tokens = torch.cat( - [ - input_tokens_prefill_pt[uid : uid + 1, :prefill_len], - torch.zeros(1, padded_len - prefill_len, dtype=torch.long), - ], - dim=-1, + if enable_prefill_trace: + # === TRACED BATCHED PREFILL === + # Trace captures device program once, then replays with input buffer updates. + # Eliminates per-iteration host dispatch overhead. + + # Uniform padded_len for all users (required for tracing: fixed tensor shapes) + max_padded_len = max(get_padded_prefill_len(int(decoding_pos[uid])) for uid in range(global_batch_size)) + block_size = page_params["page_block_size"] + max_num_blocks = (max_padded_len + block_size - 1) // block_size + + # Compute fixed get_last_token for trace (all users must be in same 32-token tile) + all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)] + fixed_get_last_token = (min(all_last_idxs) // 32) * 32 + max_tile_start = (max(all_last_idxs) // 32) * 32 + if fixed_get_last_token != max_tile_start: + logger.warning( + f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), " + f"using get_last_token=-1 (slower)" ) - batch_tokens_list.append(user_tokens) - block_size = page_params["page_block_size"] - num_blocks_needed = (padded_len + block_size - 1) // block_size - batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed]) - batch_last_token_idxs.append(prefill_len - 1) - - tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [num_rows, padded_len] - page_table_stacked = torch.cat(batch_page_tables, dim=0) # [num_rows, num_blocks] - - (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[ - model_id - ].prepare_inputs_prefill( - tokens_stacked, - page_table=page_table_stacked, - batched_prefill=True, + fixed_get_last_token = -1 + + def _prepare_batch_host(user_indices): + """Prepare host-side tokens + page_table for a batch of users.""" + tokens_list, pt_list, last_idxs = [], [], [] + for uid in user_indices: + plen = int(decoding_pos[uid]) + toks = torch.cat( + [ + input_tokens_prefill_pt[uid : uid + 1, :plen], + torch.zeros(1, max_padded_len - plen, dtype=torch.long), + ], + dim=-1, + ) + tokens_list.append(toks) + pt_list.append(page_table[uid : uid + 1, :max_num_blocks]) + last_idxs.append(plen - 1) + return (torch.cat(tokens_list, dim=0), torch.cat(pt_list, dim=0), last_idxs) + + # --- Warmup (compilation) --- + logger.info("Starting traced row-parallel prefill warmup (compilation)...") + warmup_indices = [row * users_per_row_prefill for row in range(num_rows)] + tokens_w, pt_w, last_w = _prepare_batch_host(warmup_indices) + + host_out = model[model_id].prepare_inputs_prefill( + tokens_w, page_table=pt_w, trace_enabled=True, batched_prefill=True ) + rot_global = host_out[1] # device-resident, fixed across iterations + rot_local = host_out[2] # None + host_inputs = (host_out[0], host_out[3], host_out[4]) # tokens, pt, cpt - get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 + profiler.start(f"compile_prefill", iteration=batch_idx) + dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device) + transformed = model[model_id].transform_and_embed_prefill_inputs_device(*dev_inputs) tt_logits = model[model_id].ttnn_prefill_forward( + transformed[0], + rot_mats_global=rot_global, + rot_mats_local=rot_local, + user_id=0, + page_table=transformed[1], + get_last_token=fixed_get_last_token, + kv_cache=tt_kv_cache[model_id], + ) + + if fixed_get_last_token == -1: + 
warmup_results = model[model_id].process_output_prefill_batched(tt_logits, last_w) + else: + warmup_results = model[model_id].process_output_prefill_batched( + tt_logits, [idx % 32 for idx in last_w] + ) + for row, uid in enumerate(warmup_indices): + prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() + profiler.end(f"compile_prefill", iteration=batch_idx) + logger.info("Finished traced row-parallel prefill warmup") + + # Clear KV caches (warmup wrote to them) + for i in range(len(model)): + for layer_obj in model[i].layers: + k_cache, v_cache = layer_obj.self_attn.layer_past + k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) + v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + + # --- Trace capture --- + logger.info("Capturing prefill trace...") + iter0_indices = [row * users_per_row_prefill for row in range(num_rows)] + tokens_0, pt_0, last_0 = _prepare_batch_host(iter0_indices) + host_out = model[model_id].prepare_inputs_prefill( + tokens_0, page_table=pt_0, trace_enabled=True, batched_prefill=True + ) + host_inputs = (host_out[0], host_out[3], host_out[4]) + + trace_dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device) + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + # Embed tokens on-device inside the trace (without deallocating input buffer, + # since we need to update it between trace executions) + tokens_embd = ttnn.embedding( + trace_dev_inputs[0], + model[model_id].embedding_weight, + layout=ttnn.TILE_LAYOUT, + dtype=ttnn.bfloat8_b, + ) + if len(tokens_embd.shape) == 3: + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + tt_out_trace = model[model_id].ttnn_prefill_forward( tokens_embd, - rot_mats_global=rot_mats_global, - rot_mats_local=rot_mats_local, - user_id=0, # Must be 0: each device sees page_table[0] after row-sharding - page_table=page_table_tt, - get_last_token=get_last_token_val, + rot_mats_global=rot_global, + rot_mats_local=rot_local, + user_id=0, + page_table=trace_dev_inputs[1], + get_last_token=fixed_get_last_token, kv_cache=tt_kv_cache[model_id], ) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + logger.info("Prefill trace captured") - row_results = model[model_id].process_output_prefill_batched( - tt_logits, [idx % 32 for idx in batch_last_token_idxs] + # --- Execute trace for all iterations --- + logger.info( + f"Starting traced row-parallel prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." 
) - return row_results + profiler.start(f"inference_prefill", iteration=batch_idx) + for iter_idx in range(users_per_row_prefill): + user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] + tokens_i, pt_i, last_i = _prepare_batch_host(user_indices) + host_out = model[model_id].prepare_inputs_prefill( + tokens_i, page_table=pt_i, trace_enabled=True, batched_prefill=True + ) + host_inputs = (host_out[0], host_out[3], host_out[4]) + copy_host_to_device(host_inputs, device_tensors=trace_dev_inputs, mesh_device=mesh_device) + ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False) - # Warmup: compile with first batch - logger.info("Starting row-parallel prefill warmup...") - profiler.start(f"compile_prefill", iteration=batch_idx) - warmup_user_indices = [row * users_per_row_prefill for row in range(num_rows)] - warmup_results = _run_batched_prefill_iter(0, warmup_user_indices) - for row, uid in enumerate(warmup_user_indices): - prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() - profiler.end(f"compile_prefill", iteration=batch_idx) - logger.info("Finished row-parallel prefill warmup") + if fixed_get_last_token == -1: + row_results = model[model_id].process_output_prefill_batched(tt_out_trace, last_i) + else: + row_results = model[model_id].process_output_prefill_batched( + tt_out_trace, [idx % 32 for idx in last_i] + ) + for row, uid in enumerate(user_indices): + prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() + if iter_idx % 8 == 0: + logger.info(f" Traced prefill batch {iter_idx+1}/{users_per_row_prefill}") + profiler.end(f"inference_prefill", iteration=batch_idx) + + ttnn.release_trace(mesh_device, trace_id) + logger.info(f"Traced row-parallel prefill finished ({users_per_row_prefill} iterations)") - # Clear KV caches before real prefill (warmup wrote to them) - for i in range(len(model)): - for layer_obj in model[i].layers: - k_cache, v_cache = layer_obj.self_attn.layer_past - k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) - v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + else: + # === NON-TRACED BATCHED PREFILL === + + # Helper to run one batched prefill iteration + def _run_batched_prefill_iter(iter_idx, user_indices): + batch_tokens_list = [] + batch_page_tables = [] + batch_last_token_idxs = [] + + for uid in user_indices: + prefill_len = int(decoding_pos[uid]) + padded_len = get_padded_prefill_len(prefill_len) + user_tokens = torch.cat( + [ + input_tokens_prefill_pt[uid : uid + 1, :prefill_len], + torch.zeros(1, padded_len - prefill_len, dtype=torch.long), + ], + dim=-1, + ) + batch_tokens_list.append(user_tokens) + block_size = page_params["page_block_size"] + num_blocks_needed = (padded_len + block_size - 1) // block_size + batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed]) + batch_last_token_idxs.append(prefill_len - 1) + + tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [num_rows, padded_len] + page_table_stacked = torch.cat(batch_page_tables, dim=0) # [num_rows, num_blocks] + + (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[ + model_id + ].prepare_inputs_prefill( + tokens_stacked, + page_table=page_table_stacked, + batched_prefill=True, + ) + + get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 + tt_logits = model[model_id].ttnn_prefill_forward( + tokens_embd, + rot_mats_global=rot_mats_global, + rot_mats_local=rot_mats_local, + user_id=0, # Must be 0: each device sees page_table[0] after row-sharding + 
page_table=page_table_tt, + get_last_token=get_last_token_val, + kv_cache=tt_kv_cache[model_id], + ) + + row_results = model[model_id].process_output_prefill_batched( + tt_logits, [idx % 32 for idx in batch_last_token_idxs] + ) + return row_results + + # Warmup: compile with first batch + logger.info("Starting row-parallel prefill warmup...") + profiler.start(f"compile_prefill", iteration=batch_idx) + warmup_user_indices = [row * users_per_row_prefill for row in range(num_rows)] + warmup_results = _run_batched_prefill_iter(0, warmup_user_indices) + for row, uid in enumerate(warmup_user_indices): + prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() + profiler.end(f"compile_prefill", iteration=batch_idx) + logger.info("Finished row-parallel prefill warmup") + + # Clear KV caches before real prefill (warmup wrote to them) + for i in range(len(model)): + for layer_obj in model[i].layers: + k_cache, v_cache = layer_obj.self_attn.layer_past + k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) + v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + + # Real prefill + logger.info( + f"Starting row-parallel batched prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." + ) + profiler.start(f"inference_prefill", iteration=batch_idx) + for iter_idx in range(users_per_row_prefill): + user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] + row_results = _run_batched_prefill_iter(iter_idx, user_indices) + for row, uid in enumerate(user_indices): + prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() + if iter_idx % 8 == 0: + logger.info(f" Prefilled batch {iter_idx+1}/{users_per_row_prefill}") + profiler.end(f"inference_prefill", iteration=batch_idx) + logger.info(f"Row-parallel batched prefill finished ({users_per_row_prefill} iterations)") - # Real prefill - logger.info( - f"Starting row-parallel batched prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." - ) - profiler.start(f"inference_prefill", iteration=batch_idx) - for iter_idx in range(users_per_row_prefill): - user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] - row_results = _run_batched_prefill_iter(iter_idx, user_indices) - for row, uid in enumerate(user_indices): - prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() - if iter_idx % 8 == 0: - logger.info(f" Prefilled batch {iter_idx+1}/{users_per_row_prefill}") - profiler.end(f"inference_prefill", iteration=batch_idx) - logger.info(f"Row-parallel batched prefill finished ({users_per_row_prefill} iterations)") logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'") # Initialize generation state like tt_transformers From 38db9f34e8b4b437f3734696609eded56848f964 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Feb 2026 14:31:50 +0000 Subject: [PATCH 4/6] Parameterize users_per_row_per_iter for batched prefill batch>1 support Add users_per_row_per_iter parameter (default 1) to control how many users each mesh row processes per prefill iteration. 
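For reference, a small sketch of the resulting user-to-iteration mapping (the helper
name is illustrative and not part of this patch):

    def users_for_iteration(iter_idx, num_rows, users_per_row_prefill, users_per_row_per_iter):
        # Row r owns the contiguous user block [r*users_per_row_prefill, (r+1)*users_per_row_prefill);
        # each iteration takes the next users_per_row_per_iter users from every row's block.
        return [
            row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
            for row in range(num_rows)
            for u in range(users_per_row_per_iter)
        ]

    # batch 128 on 4 rows with 2 users/row/iter: iteration 0 prefills users
    # [0, 1, 32, 33, 64, 65, 96, 97], and 16 iterations cover all 128 users.
    assert users_for_iteration(0, num_rows=4, users_per_row_prefill=32, users_per_row_per_iter=2) == [
        0, 1, 32, 33, 64, 65, 96, 97
    ]
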
With batch>1: - attention/prefill.py: flatten batch*seq for QKV matmul, reshape for SDPA, loop paged_fill_cache per user, slice RoPE to per-user seq_len - model.py: pass batch_size through forward chain, process_output extracts N results per row with correct seq offset - text_demo.py: concatenate per-row tokens along seq dim, generalize user_indices and loop counts for arbitrary users_per_row_per_iter Tested: batch128 passes at both users_per_row_per_iter=1 (TTFT 91ms, identical to baseline) and users_per_row_per_iter=2 (TTFT 219ms, correct outputs, slower due to disabled get_last_token optimization). --- models/demos/gpt_oss/demo/text_demo.py | 103 ++++++++++++++---- models/demos/gpt_oss/tt/attention/__init__.py | 11 +- models/demos/gpt_oss/tt/attention/prefill.py | 61 ++++++++--- models/demos/gpt_oss/tt/layer.py | 2 + models/demos/gpt_oss/tt/model.py | 38 +++++-- 5 files changed, 168 insertions(+), 47 deletions(-) diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py index a88797d09613..9fa7b992d48c 100644 --- a/models/demos/gpt_oss/demo/text_demo.py +++ b/models/demos/gpt_oss/demo/text_demo.py @@ -577,6 +577,13 @@ def test_gpt_oss_demo( # This gives ~4x speedup over sequential per-user prefill num_rows = mesh_device.shape[0] users_per_row_prefill = global_batch_size // num_rows + users_per_row_per_iter = 1 # Users each mesh row processes per prefill iteration + # Increasing above 1 requires model changes: + # - attention/prefill.py: relax batch_size!=1 check, loop paged_fill_cache + # - model.py:process_output_prefill_batched: extract multiple logits per row + assert users_per_row_prefill % users_per_row_per_iter == 0 + num_prefill_iters = users_per_row_prefill // users_per_row_per_iter + users_per_iter = num_rows * users_per_row_per_iter # Total users per iteration model_id = 0 # data_parallel=1, single model prefilled_token = torch.zeros(global_batch_size, dtype=torch.long) @@ -602,6 +609,9 @@ def test_gpt_oss_demo( ) fixed_get_last_token = -1 + if users_per_row_per_iter > 1: + fixed_get_last_token = -1 # Can't use get_last_token with batch>1 + def _prepare_batch_host(user_indices): """Prepare host-side tokens + page_table for a batch of users.""" tokens_list, pt_list, last_idxs = [], [], [] @@ -621,8 +631,11 @@ def _prepare_batch_host(user_indices): # --- Warmup (compilation) --- logger.info("Starting traced row-parallel prefill warmup (compilation)...") - warmup_indices = [row * users_per_row_prefill for row in range(num_rows)] + warmup_indices = [ + row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter) + ] tokens_w, pt_w, last_w = _prepare_batch_host(warmup_indices) + tokens_w = tokens_w.reshape(num_rows, -1) # [num_rows, N*S] for batch>1 concat host_out = model[model_id].prepare_inputs_prefill( tokens_w, page_table=pt_w, trace_enabled=True, batched_prefill=True @@ -642,13 +655,22 @@ def _prepare_batch_host(user_indices): page_table=transformed[1], get_last_token=fixed_get_last_token, kv_cache=tt_kv_cache[model_id], + batch_size=users_per_row_per_iter, ) if fixed_get_last_token == -1: - warmup_results = model[model_id].process_output_prefill_batched(tt_logits, last_w) + warmup_results = model[model_id].process_output_prefill_batched( + tt_logits, + last_w, + users_per_row=users_per_row_per_iter, + seq_len_per_user=max_padded_len, + ) else: warmup_results = model[model_id].process_output_prefill_batched( - tt_logits, [idx % 32 for idx in last_w] + tt_logits, + [idx % 32 for idx in last_w], + 
users_per_row=users_per_row_per_iter, + seq_len_per_user=max_padded_len, ) for row, uid in enumerate(warmup_indices): prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() @@ -664,8 +686,11 @@ def _prepare_batch_host(user_indices): # --- Trace capture --- logger.info("Capturing prefill trace...") - iter0_indices = [row * users_per_row_prefill for row in range(num_rows)] + iter0_indices = [ + row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter) + ] tokens_0, pt_0, last_0 = _prepare_batch_host(iter0_indices) + tokens_0 = tokens_0.reshape(num_rows, -1) host_out = model[model_id].prepare_inputs_prefill( tokens_0, page_table=pt_0, trace_enabled=True, batched_prefill=True ) @@ -691,18 +716,25 @@ def _prepare_batch_host(user_indices): page_table=trace_dev_inputs[1], get_last_token=fixed_get_last_token, kv_cache=tt_kv_cache[model_id], + batch_size=users_per_row_per_iter, ) ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) logger.info("Prefill trace captured") # --- Execute trace for all iterations --- logger.info( - f"Starting traced row-parallel prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." + f"Starting traced row-parallel prefill ({num_prefill_iters} iters, " + f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..." ) profiler.start(f"inference_prefill", iteration=batch_idx) - for iter_idx in range(users_per_row_prefill): - user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] + for iter_idx in range(num_prefill_iters): + user_indices = [ + row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u + for row in range(num_rows) + for u in range(users_per_row_per_iter) + ] tokens_i, pt_i, last_i = _prepare_batch_host(user_indices) + tokens_i = tokens_i.reshape(num_rows, -1) host_out = model[model_id].prepare_inputs_prefill( tokens_i, page_table=pt_i, trace_enabled=True, batched_prefill=True ) @@ -711,19 +743,27 @@ def _prepare_batch_host(user_indices): ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False) if fixed_get_last_token == -1: - row_results = model[model_id].process_output_prefill_batched(tt_out_trace, last_i) + row_results = model[model_id].process_output_prefill_batched( + tt_out_trace, + last_i, + users_per_row=users_per_row_per_iter, + seq_len_per_user=max_padded_len, + ) else: row_results = model[model_id].process_output_prefill_batched( - tt_out_trace, [idx % 32 for idx in last_i] + tt_out_trace, + [idx % 32 for idx in last_i], + users_per_row=users_per_row_per_iter, + seq_len_per_user=max_padded_len, ) for row, uid in enumerate(user_indices): prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() if iter_idx % 8 == 0: - logger.info(f" Traced prefill batch {iter_idx+1}/{users_per_row_prefill}") + logger.info(f" Traced prefill batch {iter_idx+1}/{num_prefill_iters}") profiler.end(f"inference_prefill", iteration=batch_idx) ttnn.release_trace(mesh_device, trace_id) - logger.info(f"Traced row-parallel prefill finished ({users_per_row_prefill} iterations)") + logger.info(f"Traced row-parallel prefill finished ({num_prefill_iters} iterations)") else: # === NON-TRACED BATCHED PREFILL === @@ -750,18 +790,22 @@ def _run_batched_prefill_iter(iter_idx, user_indices): batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed]) batch_last_token_idxs.append(prefill_len - 1) - tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [num_rows, padded_len] - page_table_stacked = torch.cat(batch_page_tables, 
dim=0) # [num_rows, num_blocks] + tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [total_users, padded_len] + page_table_stacked = torch.cat(batch_page_tables, dim=0) # [total_users, num_blocks] + padded_len = tokens_stacked.shape[1] + + # Reshape tokens for batch>1: concatenate per-row users along seq dim + tokens_for_model = tokens_stacked.reshape(num_rows, -1) # [num_rows, N*padded_len] (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[ model_id ].prepare_inputs_prefill( - tokens_stacked, + tokens_for_model, page_table=page_table_stacked, batched_prefill=True, ) - get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 + get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 if users_per_row_per_iter == 1 else -1 tt_logits = model[model_id].ttnn_prefill_forward( tokens_embd, rot_mats_global=rot_mats_global, @@ -770,17 +814,27 @@ def _run_batched_prefill_iter(iter_idx, user_indices): page_table=page_table_tt, get_last_token=get_last_token_val, kv_cache=tt_kv_cache[model_id], + batch_size=users_per_row_per_iter, ) + if get_last_token_val == -1: + adjusted_last_idxs = batch_last_token_idxs + else: + adjusted_last_idxs = [idx % 32 for idx in batch_last_token_idxs] row_results = model[model_id].process_output_prefill_batched( - tt_logits, [idx % 32 for idx in batch_last_token_idxs] + tt_logits, + adjusted_last_idxs, + users_per_row=users_per_row_per_iter, + seq_len_per_user=padded_len, ) return row_results # Warmup: compile with first batch logger.info("Starting row-parallel prefill warmup...") profiler.start(f"compile_prefill", iteration=batch_idx) - warmup_user_indices = [row * users_per_row_prefill for row in range(num_rows)] + warmup_user_indices = [ + row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter) + ] warmup_results = _run_batched_prefill_iter(0, warmup_user_indices) for row, uid in enumerate(warmup_user_indices): prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() @@ -796,18 +850,23 @@ def _run_batched_prefill_iter(iter_idx, user_indices): # Real prefill logger.info( - f"Starting row-parallel batched prefill ({users_per_row_prefill} iters for {global_batch_size} users)..." + f"Starting row-parallel batched prefill ({num_prefill_iters} iters, " + f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..." 
) profiler.start(f"inference_prefill", iteration=batch_idx) - for iter_idx in range(users_per_row_prefill): - user_indices = [row * users_per_row_prefill + iter_idx for row in range(num_rows)] + for iter_idx in range(num_prefill_iters): + user_indices = [ + row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u + for row in range(num_rows) + for u in range(users_per_row_per_iter) + ] row_results = _run_batched_prefill_iter(iter_idx, user_indices) for row, uid in enumerate(user_indices): prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() if iter_idx % 8 == 0: - logger.info(f" Prefilled batch {iter_idx+1}/{users_per_row_prefill}") + logger.info(f" Prefilled batch {iter_idx+1}/{num_prefill_iters}") profiler.end(f"inference_prefill", iteration=batch_idx) - logger.info(f"Row-parallel batched prefill finished ({users_per_row_prefill} iterations)") + logger.info(f"Row-parallel batched prefill finished ({num_prefill_iters} iterations)") logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'") diff --git a/models/demos/gpt_oss/tt/attention/__init__.py b/models/demos/gpt_oss/tt/attention/__init__.py index 52eab7bacc3c..c7963051a511 100644 --- a/models/demos/gpt_oss/tt/attention/__init__.py +++ b/models/demos/gpt_oss/tt/attention/__init__.py @@ -108,7 +108,15 @@ def __init__( self.scaling = config.scaling def __call__( - self, hidden_states, rope_mats, position_idx=None, page_table=None, kv_cache=None, is_decode=True, user_id=0 + self, + hidden_states, + rope_mats, + position_idx=None, + page_table=None, + kv_cache=None, + is_decode=True, + user_id=0, + batch_size=1, ): """ Forward pass - automatically dispatches to decode or prefill. @@ -169,4 +177,5 @@ def __call__( position_idx=position_idx, page_table=page_table, ccl_manager=self.ccl_manager, + batch_size=batch_size, ) diff --git a/models/demos/gpt_oss/tt/attention/prefill.py b/models/demos/gpt_oss/tt/attention/prefill.py index 56f8376c6faf..06c21c453f17 100644 --- a/models/demos/gpt_oss/tt/attention/prefill.py +++ b/models/demos/gpt_oss/tt/attention/prefill.py @@ -29,6 +29,7 @@ def prefill_forward( page_table, ccl_manager, user_id=0, + batch_size=1, ): """ Prefill forward pass - optimized for sequence processing (seq_len>1). @@ -51,18 +52,21 @@ def prefill_forward( Attention output [batch, seq_len, hidden_size] """ activation_dtype = ttnn.bfloat16 - _, batch_size, seq_len, hidden_size = hidden_states.shape + total_seq_len = hidden_states.shape[-2] + hidden_size = hidden_states.shape[-1] + seq_len = total_seq_len // batch_size # Per-user sequence length # Validate prefill mode if seq_len <= 1: raise ValueError(f"Prefill mode requires seq_len>1, got {seq_len}. 
Use decode mode for single tokens.") - if batch_size != 1: - raise NotImplementedError(f"Currently only batch_size=1 supported, got {batch_size}") - # QKV projection xqkv_fused = apply_qkv_projection(hidden_states, weights) + # Reshape for batch: [1, 1, B*S, QKV] -> [B, 1, S, QKV] + if batch_size > 1: + xqkv_fused = ttnn.reshape(xqkv_fused, [batch_size, 1, seq_len, -1]) + # Split into Q, K, V heads num_local_heads = mesh_config.shard_size(config.num_heads) num_local_kv_heads = mesh_config.shard_size(config.num_kv_heads) @@ -70,11 +74,15 @@ def prefill_forward( tt_q, tt_k, tt_v = split_qkv_heads_prefill(xqkv_fused, num_local_heads, num_local_kv_heads) xqkv_fused.deallocate(True) - # Apply RoPE + # Apply RoPE (use per-user seq_len positions) + if batch_size > 1: + rope_mats_sliced = [rope_mats[0][:, :, :seq_len, :], rope_mats[1][:, :, :seq_len, :]] + else: + rope_mats_sliced = rope_mats tt_q_orig = tt_q tt_k_orig = tt_k - tt_q = apply_rope(tt_q, rope_mats, transformation_mat, is_decode_mode=False) - tt_k = apply_rope(tt_k, rope_mats, transformation_mat, is_decode_mode=False) + tt_q = apply_rope(tt_q, rope_mats_sliced, transformation_mat, is_decode_mode=False) + tt_k = apply_rope(tt_k, rope_mats_sliced, transformation_mat, is_decode_mode=False) tt_q_orig.deallocate(True) tt_k_orig.deallocate(True) @@ -89,16 +97,34 @@ def prefill_forward( if page_table is not None: block_size = k_cache.shape[2] - page_len = page_table.shape[1] * block_size - tt_k_sliced = tt_k[:, :, :page_len, :] if page_len < tt_k.shape[2] else tt_k - tt_v_sliced = tt_v[:, :, :page_len, :] if page_len < tt_v.shape[2] else tt_v - ttnn.experimental.paged_fill_cache(k_cache, tt_k_sliced, page_table, batch_idx=user_id) - ttnn.experimental.paged_fill_cache(v_cache, tt_v_sliced, page_table, batch_idx=user_id) + page_len = page_table.shape[-1] * block_size + if batch_size > 1: + for b in range(batch_size): + k_b = ttnn.slice(tt_k, (b, 0, 0, 0), (b + 1, tt_k.shape[1], min(page_len, seq_len), tt_k.shape[3])) + v_b = ttnn.slice(tt_v, (b, 0, 0, 0), (b + 1, tt_v.shape[1], min(page_len, seq_len), tt_v.shape[3])) + ttnn.experimental.paged_fill_cache(k_cache, k_b, page_table, batch_idx=b) + ttnn.experimental.paged_fill_cache(v_cache, v_b, page_table, batch_idx=b) + k_b.deallocate(True) + v_b.deallocate(True) + else: + tt_k_sliced = tt_k[:, :, :page_len, :] if page_len < tt_k.shape[2] else tt_k + tt_v_sliced = tt_v[:, :, :page_len, :] if page_len < tt_v.shape[2] else tt_v + ttnn.experimental.paged_fill_cache(k_cache, tt_k_sliced, page_table, batch_idx=user_id) + ttnn.experimental.paged_fill_cache(v_cache, tt_v_sliced, page_table, batch_idx=user_id) else: # Non-paged attention - ttnn.fill_cache(k_cache, tt_k, batch_idx=user_id) - ttnn.fill_cache(v_cache, tt_v, batch_idx=user_id) + if batch_size > 1: + for b in range(batch_size): + k_b = ttnn.slice(tt_k, (b, 0, 0, 0), (b + 1, tt_k.shape[1], tt_k.shape[2], tt_k.shape[3])) + v_b = ttnn.slice(tt_v, (b, 0, 0, 0), (b + 1, tt_v.shape[1], tt_v.shape[2], tt_v.shape[3])) + ttnn.fill_cache(k_cache, k_b, batch_idx=b) + ttnn.fill_cache(v_cache, v_b, batch_idx=b) + k_b.deallocate(True) + v_b.deallocate(True) + else: + ttnn.fill_cache(k_cache, tt_k, batch_idx=user_id) + ttnn.fill_cache(v_cache, tt_v, batch_idx=user_id) # Scaled dot-product attention tt_sdpa_out = ttnn.transformer.scaled_dot_product_attention( @@ -120,11 +146,14 @@ def prefill_forward( tt_sdpa_out = concat_heads(tt_sdpa_out, is_decode_mode=False) tt_sdpa_out_pre_concat.deallocate(True) + # Flatten back for output projection: [B, 1, S, H] -> 
[1, 1, B*S, H] + if batch_size > 1: + tt_sdpa_out = ttnn.reshape(tt_sdpa_out, [1, 1, total_seq_len, -1]) + tt_out = apply_output_projection(tt_sdpa_out, weights, activation_dtype) # Note: apply_output_projection already deallocates its input tensor internally - # tt_out = ttnn.reshape(tt_out, (batch_size, seq_len, hidden_size)) # Tensor parallel allreduce - tt_out = apply_allreduce(tt_out, mesh_config, ccl_manager, batch_size, seq_len, hidden_size) + tt_out = apply_allreduce(tt_out, mesh_config, ccl_manager, 1, total_seq_len, hidden_size) return tt_out diff --git a/models/demos/gpt_oss/tt/layer.py b/models/demos/gpt_oss/tt/layer.py index 31aada54584d..c3b91c5b8ef1 100644 --- a/models/demos/gpt_oss/tt/layer.py +++ b/models/demos/gpt_oss/tt/layer.py @@ -96,6 +96,7 @@ def __call__( kv_cache=None, is_decode=True, user_id=0, + batch_size=1, ): # hidden_states: [1, 1, tokens/num_rows, hidden_size/num_columns] # residual: [1, 1, tokens/num_rows, hidden_size/num_columns] @@ -112,6 +113,7 @@ def __call__( kv_cache=kv_cache, is_decode=is_decode, user_id=user_id, + batch_size=batch_size, ) hidden_states_post_norm.deallocate(True) diff --git a/models/demos/gpt_oss/tt/model.py b/models/demos/gpt_oss/tt/model.py index e2187186fd3a..2024dd84ab81 100644 --- a/models/demos/gpt_oss/tt/model.py +++ b/models/demos/gpt_oss/tt/model.py @@ -234,7 +234,16 @@ def switch_mode(self, mode: Mode): return None def _forward_layers_and_head( - self, hidden_states, rope_mats, current_pos, page_table, kv_cache, get_last_token=-1, is_decode=True, user_id=0 + self, + hidden_states, + rope_mats, + current_pos, + page_table, + kv_cache, + get_last_token=-1, + is_decode=True, + user_id=0, + batch_size=1, ): """ Shared forward pass through decoder layers and final projection. @@ -264,6 +273,7 @@ def _forward_layers_and_head( kv_cache=layer_kv_cache, is_decode=is_decode, user_id=user_id, + batch_size=batch_size, ) logits = hidden_states @@ -335,6 +345,7 @@ def ttnn_prefill_forward( chunk_start_idx=None, get_last_token=-1, kv_cache=None, + batch_size=1, ): """Prefill forward pass - processes full sequences""" # Use provided rotation matrices or slice from rope_setup (matches tt-transformers) @@ -355,9 +366,10 @@ def ttnn_prefill_forward( current_pos=None, # No current_pos for prefill page_table=page_table, kv_cache=kv_cache, - get_last_token=get_last_token, + get_last_token=get_last_token if batch_size == 1 else -1, # Disable get_last_token for batch>1 is_decode=False, user_id=user_id, + batch_size=batch_size, ) return logits @@ -609,17 +621,20 @@ def process_output_prefill(self, tt_out, last_token_idx): result = torch_output[..., last_token_idx, : self.vocab_size] return result - def process_output_prefill_batched(self, tt_out, last_token_idxs): + def process_output_prefill_batched(self, tt_out, last_token_idxs, users_per_row=1, seq_len_per_user=None): """Process row-parallel batched prefill output. Extracts logits from one device per row (first device of each row). + Supports multiple users per row when users_per_row > 1. 
Args: tt_out: Multi-device output tensor - last_token_idxs: List of last_token_idx per row user + last_token_idxs: List of last_token_idx per user (length = num_rows * users_per_row) + users_per_row: Number of users per mesh row per iteration + seq_len_per_user: Per-user sequence length (required when users_per_row > 1) Returns: - List of per-user logit tensors (one per row) + List of per-user logit tensors (one per user) """ num_cols = self.mesh_device.shape[1] device_tensors = ttnn.get_device_tensors(tt_out) @@ -628,7 +643,14 @@ def process_output_prefill_batched(self, tt_out, last_token_idxs): for row in range(num_rows): device_idx = row * num_cols # First device of each row torch_output = ttnn.to_torch(device_tensors[device_idx]) - last_idx = last_token_idxs[row] if isinstance(last_token_idxs, list) else last_token_idxs - result = torch_output[..., last_idx, : self.vocab_size] - results.append(result) + for u in range(users_per_row): + user_flat_idx = row * users_per_row + u + last_idx = last_token_idxs[user_flat_idx] if isinstance(last_token_idxs, list) else last_token_idxs + if users_per_row > 1: + # Tokens are concatenated: user u's last token is at offset u*seq_len_per_user + last_idx + global_idx = u * seq_len_per_user + last_idx + result = torch_output[..., global_idx, : self.vocab_size] + else: + result = torch_output[..., last_idx, : self.vocab_size] + results.append(result) return results From 2c88901115177d41207c4cdcbc3b825d5dc33a83 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 16 Feb 2026 06:15:30 +0000 Subject: [PATCH 5/6] Address PR review comments in text_demo.py - Remove unnecessary k_cache/v_cache reassignments (in-place via output_tensor) - Remove unused users_per_iter variable - Restructure fixed_get_last_token logic to avoid unreachable code when users_per_row_per_iter=1 --- models/demos/gpt_oss/demo/text_demo.py | 27 +++++++++++++------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py index 9fa7b992d48c..dd8eb2983f2d 100644 --- a/models/demos/gpt_oss/demo/text_demo.py +++ b/models/demos/gpt_oss/demo/text_demo.py @@ -583,7 +583,6 @@ def test_gpt_oss_demo( # - model.py:process_output_prefill_batched: extract multiple logits per row assert users_per_row_prefill % users_per_row_per_iter == 0 num_prefill_iters = users_per_row_prefill // users_per_row_per_iter - users_per_iter = num_rows * users_per_row_per_iter # Total users per iteration model_id = 0 # data_parallel=1, single model prefilled_token = torch.zeros(global_batch_size, dtype=torch.long) @@ -600,17 +599,17 @@ def test_gpt_oss_demo( # Compute fixed get_last_token for trace (all users must be in same 32-token tile) all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)] - fixed_get_last_token = (min(all_last_idxs) // 32) * 32 - max_tile_start = (max(all_last_idxs) // 32) * 32 - if fixed_get_last_token != max_tile_start: - logger.warning( - f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), " - f"using get_last_token=-1 (slower)" - ) - fixed_get_last_token = -1 - if users_per_row_per_iter > 1: fixed_get_last_token = -1 # Can't use get_last_token with batch>1 + else: + fixed_get_last_token = (min(all_last_idxs) // 32) * 32 + max_tile_start = (max(all_last_idxs) // 32) * 32 + if fixed_get_last_token != max_tile_start: + logger.warning( + f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), " + f"using get_last_token=-1 
(slower)" + ) + fixed_get_last_token = -1 def _prepare_batch_host(user_indices): """Prepare host-side tokens + page_table for a batch of users.""" @@ -681,8 +680,8 @@ def _prepare_batch_host(user_indices): for i in range(len(model)): for layer_obj in model[i].layers: k_cache, v_cache = layer_obj.self_attn.layer_past - k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) - v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + ttnn.mul(k_cache, 0, output_tensor=k_cache) + ttnn.mul(v_cache, 0, output_tensor=v_cache) # --- Trace capture --- logger.info("Capturing prefill trace...") @@ -845,8 +844,8 @@ def _run_batched_prefill_iter(iter_idx, user_indices): for i in range(len(model)): for layer_obj in model[i].layers: k_cache, v_cache = layer_obj.self_attn.layer_past - k_cache = ttnn.mul(k_cache, 0, output_tensor=k_cache) - v_cache = ttnn.mul(v_cache, 0, output_tensor=v_cache) + ttnn.mul(k_cache, 0, output_tensor=k_cache) + ttnn.mul(v_cache, 0, output_tensor=v_cache) # Real prefill logger.info( From 387359eeaead71936878e2f909c089da29eb76b9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 16 Feb 2026 06:55:54 +0000 Subject: [PATCH 6/6] Enable get_last_token for batch>1 prefill, single-call paged_fill_cache - model.py: When batch_size>1, extract each user's 32-token last-token tile via ttnn.slice and concat to [1,1,B*32,H] before norm+lm_head. Removes the batch>1 get_last_token=-1 override in ttnn_prefill_forward. - attention/prefill.py: Replace per-user paged_fill_cache loop with single-call reshape approach (flatten batch into seq dim, heads into last dim, flatten page table). Matches llama_70b_galaxy pattern. - text_demo.py: Remove batch>1 get_last_token override, use seq_len_per_user=32 when get_last_token is active. batch128 users_per_row_per_iter=2 TTFT: 218ms -> 99ms (get_last_token fix) Compile time: 5.24s -> 3.16s (single-call paged_fill_cache) batch128 users_per_row_per_iter=1: unchanged at 91ms --- models/demos/gpt_oss/demo/text_demo.py | 32 +++++++++++--------- models/demos/gpt_oss/tt/attention/prefill.py | 16 +++++----- models/demos/gpt_oss/tt/model.py | 24 ++++++++++++--- 3 files changed, 45 insertions(+), 27 deletions(-) diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py index dd8eb2983f2d..a09c307cad3c 100644 --- a/models/demos/gpt_oss/demo/text_demo.py +++ b/models/demos/gpt_oss/demo/text_demo.py @@ -599,17 +599,14 @@ def test_gpt_oss_demo( # Compute fixed get_last_token for trace (all users must be in same 32-token tile) all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)] - if users_per_row_per_iter > 1: - fixed_get_last_token = -1 # Can't use get_last_token with batch>1 - else: - fixed_get_last_token = (min(all_last_idxs) // 32) * 32 - max_tile_start = (max(all_last_idxs) // 32) * 32 - if fixed_get_last_token != max_tile_start: - logger.warning( - f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), " - f"using get_last_token=-1 (slower)" - ) - fixed_get_last_token = -1 + fixed_get_last_token = (min(all_last_idxs) // 32) * 32 + max_tile_start = (max(all_last_idxs) // 32) * 32 + if fixed_get_last_token != max_tile_start: + logger.warning( + f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), " + f"using get_last_token=-1 (slower)" + ) + fixed_get_last_token = -1 def _prepare_batch_host(user_indices): """Prepare host-side tokens + page_table for a batch of users.""" @@ -669,7 +666,7 @@ def _prepare_batch_host(user_indices): tt_logits, [idx % 32 for 
idx in last_w], users_per_row=users_per_row_per_iter, - seq_len_per_user=max_padded_len, + seq_len_per_user=32, ) for row, uid in enumerate(warmup_indices): prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item() @@ -753,7 +750,7 @@ def _prepare_batch_host(user_indices): tt_out_trace, [idx % 32 for idx in last_i], users_per_row=users_per_row_per_iter, - seq_len_per_user=max_padded_len, + seq_len_per_user=32, ) for row, uid in enumerate(user_indices): prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item() @@ -804,7 +801,10 @@ def _run_batched_prefill_iter(iter_idx, user_indices): batched_prefill=True, ) - get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 if users_per_row_per_iter == 1 else -1 + # Use get_last_token if all users' last tokens fall in the same 32-token tile + min_tile = (min(batch_last_token_idxs) // 32) * 32 + max_tile = (max(batch_last_token_idxs) // 32) * 32 + get_last_token_val = min_tile if min_tile == max_tile else -1 tt_logits = model[model_id].ttnn_prefill_forward( tokens_embd, rot_mats_global=rot_mats_global, @@ -818,13 +818,15 @@ def _run_batched_prefill_iter(iter_idx, user_indices): if get_last_token_val == -1: adjusted_last_idxs = batch_last_token_idxs + seq_len_for_output = padded_len else: adjusted_last_idxs = [idx % 32 for idx in batch_last_token_idxs] + seq_len_for_output = 32 row_results = model[model_id].process_output_prefill_batched( tt_logits, adjusted_last_idxs, users_per_row=users_per_row_per_iter, - seq_len_per_user=padded_len, + seq_len_per_user=seq_len_for_output, ) return row_results diff --git a/models/demos/gpt_oss/tt/attention/prefill.py b/models/demos/gpt_oss/tt/attention/prefill.py index 06c21c453f17..2f4283b7151c 100644 --- a/models/demos/gpt_oss/tt/attention/prefill.py +++ b/models/demos/gpt_oss/tt/attention/prefill.py @@ -99,13 +99,15 @@ def prefill_forward( block_size = k_cache.shape[2] page_len = page_table.shape[-1] * block_size if batch_size > 1: - for b in range(batch_size): - k_b = ttnn.slice(tt_k, (b, 0, 0, 0), (b + 1, tt_k.shape[1], min(page_len, seq_len), tt_k.shape[3])) - v_b = ttnn.slice(tt_v, (b, 0, 0, 0), (b + 1, tt_v.shape[1], min(page_len, seq_len), tt_v.shape[3])) - ttnn.experimental.paged_fill_cache(k_cache, k_b, page_table, batch_idx=b) - ttnn.experimental.paged_fill_cache(v_cache, v_b, page_table, batch_idx=b) - k_b.deallocate(True) - v_b.deallocate(True) + # Flatten batch into seq dim, heads into last dim — single fill call, no per-user loop. + # Paged cache just maps sequence positions to physical pages. 
+ k_fill = ttnn.reshape(tt_k, [1, 1, total_seq_len, -1]) + v_fill = ttnn.reshape(tt_v, [1, 1, total_seq_len, -1]) + page_table_flat = ttnn.reshape(page_table, [1, -1]) + ttnn.experimental.paged_fill_cache(k_cache, k_fill, page_table_flat, batch_idx=0) + ttnn.experimental.paged_fill_cache(v_cache, v_fill, page_table_flat, batch_idx=0) + k_fill.deallocate(True) + v_fill.deallocate(True) else: tt_k_sliced = tt_k[:, :, :page_len, :] if page_len < tt_k.shape[2] else tt_k tt_v_sliced = tt_v[:, :, :page_len, :] if page_len < tt_v.shape[2] else tt_v diff --git a/models/demos/gpt_oss/tt/model.py b/models/demos/gpt_oss/tt/model.py index 2024dd84ab81..44f36af797f5 100644 --- a/models/demos/gpt_oss/tt/model.py +++ b/models/demos/gpt_oss/tt/model.py @@ -278,12 +278,26 @@ def _forward_layers_and_head( logits = hidden_states if get_last_token != -1: - # The logits come from the shared method, slice them if len(logits.shape) == 3: logits = ttnn.unsqueeze(logits, dim=1) - logits_sliced = ttnn.slice(logits, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, logits.shape[-1])) - logits.deallocate(True) - logits = logits_sliced + if batch_size > 1: + # Batch>1: tokens are concatenated [1,1,B*S,H]. Extract each user's 32-token tile. + per_user_seq = logits.shape[2] // batch_size + tiles = [] + for b in range(batch_size): + start = b * per_user_seq + get_last_token + tile = ttnn.slice(logits, (0, 0, start, 0), (1, 1, start + 32, logits.shape[-1])) + tiles.append(tile) + logits.deallocate(True) + logits = ttnn.concat(tiles, dim=2) # [1, 1, B*32, H] + for t in tiles: + t.deallocate(True) + else: + logits_sliced = ttnn.slice( + logits, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, logits.shape[-1]) + ) + logits.deallocate(True) + logits = logits_sliced hidden_states = logits # Final norm and lm_head @@ -366,7 +380,7 @@ def ttnn_prefill_forward( current_pos=None, # No current_pos for prefill page_table=page_table, kv_cache=kv_cache, - get_last_token=get_last_token if batch_size == 1 else -1, # Disable get_last_token for batch>1 + get_last_token=get_last_token, is_decode=False, user_id=user_id, batch_size=batch_size,
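
For illustration, a standalone sketch of the index arithmetic behind the batch>1
get_last_token path (the helper name is hypothetical and not part of these patches):

    # After the per-user 32-token tiles are concatenated along the sequence dim,
    # user u's last-token logits sit at position u*32 + (last_token_idx % 32) of the
    # [1, 1, users_per_row*32, H] output that process_output_prefill_batched reads.
    def last_token_offsets(last_token_idxs, tile=32):
        return [u * tile + (idx % tile) for u, idx in enumerate(last_token_idxs)]

    # e.g. two users on one row whose last tokens share the 96..127 tile
    # (last prompt indices 99 and 109): offsets 3 and 45 into a [1, 1, 64, H] tensor.
    assert last_token_offsets([99, 109]) == [3, 45]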