diff --git a/models/demos/gpt_oss/demo/text_demo.py b/models/demos/gpt_oss/demo/text_demo.py
index 1ca99ae4a5c4..a09c307cad3c 100644
--- a/models/demos/gpt_oss/demo/text_demo.py
+++ b/models/demos/gpt_oss/demo/text_demo.py
@@ -36,6 +36,7 @@
 from models.tt_transformers.demo.simple_text_demo import create_tt_page_table, load_inputs
 from models.tt_transformers.tt.common import (
     PagedAttentionConfig,
+    copy_host_to_device,
     get_padded_prefill_len,
     preprocess_inputs_prefill,
     sample_host,
@@ -242,7 +243,7 @@ def prepare_gpt_oss_generator_args(
             {"page_block_size": 64, "page_max_num_blocks_per_dp": 128 * 1024 // 64},  # page_params
             {"temperature": 0, "top_p": 0.08},  # sampling_params (greedy decoding)
             True,  # enable_decode_trace
-            False,  # enable_prefill_trace
+            True,  # enable_prefill_trace
             True,  # users_row_sharded
             False,  # long_context_mode
         ),
@@ -572,33 +573,302 @@ def test_gpt_oss_demo(
             logger.info(f"Prefill finished for {num_real_users} real users")
             logger.info(f"First generated token (user 0): '{tokenizer.decode(prefilled_token[0])}'")
         else:
-            # Standard batch prefill (matching tt_transformers)
-            logger.info("Starting prefill warmup...")
-            profiler.start(f"compile_prefill", iteration=batch_idx)
-            generator.prefill_forward_text(
-                input_tokens_prefill_pt[:1],
-                page_table=page_table,
-                kv_cache=tt_kv_cache,
-                prompt_lens=decoding_pos,
-                enable_trace=enable_prefill_trace,
-                warmup_prefill=False,
-            )
-            profiler.end(f"compile_prefill", iteration=batch_idx)
-            logger.info("Finished prefill warmup")
+            # Row-parallel batched prefill: process 4 users at once (one per mesh row)
+            # This gives ~4x speedup over sequential per-user prefill
+            num_rows = mesh_device.shape[0]
+            users_per_row_prefill = global_batch_size // num_rows
+            users_per_row_per_iter = 1  # Users each mesh row processes per prefill iteration
+            # Increasing above 1 requires model changes:
+            #   - attention/prefill.py: relax batch_size!=1 check, loop paged_fill_cache
+            #   - model.py:process_output_prefill_batched: extract multiple logits per row
+            assert users_per_row_prefill % users_per_row_per_iter == 0
+            num_prefill_iters = users_per_row_prefill // users_per_row_per_iter
+            model_id = 0  # data_parallel=1, single model
+
+            prefilled_token = torch.zeros(global_batch_size, dtype=torch.long)
+
+            if enable_prefill_trace:
+                # === TRACED BATCHED PREFILL ===
+                # Trace captures device program once, then replays with input buffer updates.
+                # Eliminates per-iteration host dispatch overhead.
+
+                # Uniform padded_len for all users (required for tracing: fixed tensor shapes)
+                max_padded_len = max(get_padded_prefill_len(int(decoding_pos[uid])) for uid in range(global_batch_size))
+                block_size = page_params["page_block_size"]
+                max_num_blocks = (max_padded_len + block_size - 1) // block_size
+
+                # Compute fixed get_last_token for trace (all users must be in same 32-token tile)
+                all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)]
+                fixed_get_last_token = (min(all_last_idxs) // 32) * 32
+                max_tile_start = (max(all_last_idxs) // 32) * 32
+                if fixed_get_last_token != max_tile_start:
+                    logger.warning(
+                        f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), "
+                        f"using get_last_token=-1 (slower)"
+                    )
+                    fixed_get_last_token = -1
+
+                def _prepare_batch_host(user_indices):
+                    """Prepare host-side tokens + page_table for a batch of users."""
+                    tokens_list, pt_list, last_idxs = [], [], []
+                    for uid in user_indices:
+                        plen = int(decoding_pos[uid])
+                        toks = torch.cat(
+                            [
+                                input_tokens_prefill_pt[uid : uid + 1, :plen],
+                                torch.zeros(1, max_padded_len - plen, dtype=torch.long),
+                            ],
+                            dim=-1,
+                        )
+                        tokens_list.append(toks)
+                        pt_list.append(page_table[uid : uid + 1, :max_num_blocks])
+                        last_idxs.append(plen - 1)
+                    return (torch.cat(tokens_list, dim=0), torch.cat(pt_list, dim=0), last_idxs)
+
+                # --- Warmup (compilation) ---
+                logger.info("Starting traced row-parallel prefill warmup (compilation)...")
+                warmup_indices = [
+                    row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
+                ]
+                tokens_w, pt_w, last_w = _prepare_batch_host(warmup_indices)
+                tokens_w = tokens_w.reshape(num_rows, -1)  # [num_rows, N*S] for batch>1 concat
+
+                host_out = model[model_id].prepare_inputs_prefill(
+                    tokens_w, page_table=pt_w, trace_enabled=True, batched_prefill=True
+                )
+                rot_global = host_out[1]  # device-resident, fixed across iterations
+                rot_local = host_out[2]  # None
+                host_inputs = (host_out[0], host_out[3], host_out[4])  # tokens, pt, cpt
+
+                profiler.start(f"compile_prefill", iteration=batch_idx)
+                dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
+                transformed = model[model_id].transform_and_embed_prefill_inputs_device(*dev_inputs)
+                tt_logits = model[model_id].ttnn_prefill_forward(
+                    transformed[0],
+                    rot_mats_global=rot_global,
+                    rot_mats_local=rot_local,
+                    user_id=0,
+                    page_table=transformed[1],
+                    get_last_token=fixed_get_last_token,
+                    kv_cache=tt_kv_cache[model_id],
+                    batch_size=users_per_row_per_iter,
+                )
+
+                if fixed_get_last_token == -1:
+                    warmup_results = model[model_id].process_output_prefill_batched(
+                        tt_logits,
+                        last_w,
+                        users_per_row=users_per_row_per_iter,
+                        seq_len_per_user=max_padded_len,
+                    )
+                else:
+                    warmup_results = model[model_id].process_output_prefill_batched(
+                        tt_logits,
+                        [idx % 32 for idx in last_w],
+                        users_per_row=users_per_row_per_iter,
+                        seq_len_per_user=32,
+                    )
+                for row, uid in enumerate(warmup_indices):
+                    prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
+                profiler.end(f"compile_prefill", iteration=batch_idx)
+                logger.info("Finished traced row-parallel prefill warmup")
+
+                # Clear KV caches (warmup wrote to them)
+                for i in range(len(model)):
+                    for layer_obj in model[i].layers:
+                        k_cache, v_cache = layer_obj.self_attn.layer_past
+                        ttnn.mul(k_cache, 0, output_tensor=k_cache)
+                        ttnn.mul(v_cache, 0, output_tensor=v_cache)
+
+                # --- Trace capture ---
+                logger.info("Capturing prefill trace...")
+                iter0_indices = [
+                    row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
+                ]
+                tokens_0, pt_0, last_0 = _prepare_batch_host(iter0_indices)
+                tokens_0 = tokens_0.reshape(num_rows, -1)
+                host_out = model[model_id].prepare_inputs_prefill(
+                    tokens_0, page_table=pt_0, trace_enabled=True, batched_prefill=True
+                )
+                host_inputs = (host_out[0], host_out[3], host_out[4])
+
+                trace_dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
+                trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
+                # Embed tokens on-device inside the trace (without deallocating input buffer,
+                # since we need to update it between trace executions)
+                tokens_embd = ttnn.embedding(
+                    trace_dev_inputs[0],
+                    model[model_id].embedding_weight,
+                    layout=ttnn.TILE_LAYOUT,
+                    dtype=ttnn.bfloat8_b,
+                )
+                if len(tokens_embd.shape) == 3:
+                    tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)
+                tt_out_trace = model[model_id].ttnn_prefill_forward(
+                    tokens_embd,
+                    rot_mats_global=rot_global,
+                    rot_mats_local=rot_local,
+                    user_id=0,
+                    page_table=trace_dev_inputs[1],
+                    get_last_token=fixed_get_last_token,
+                    kv_cache=tt_kv_cache[model_id],
+                    batch_size=users_per_row_per_iter,
+                )
+                ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
+                logger.info("Prefill trace captured")
+
+                # --- Execute trace for all iterations ---
+                logger.info(
+                    f"Starting traced row-parallel prefill ({num_prefill_iters} iters, "
+                    f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..."
+                )
+                profiler.start(f"inference_prefill", iteration=batch_idx)
+                for iter_idx in range(num_prefill_iters):
+                    user_indices = [
+                        row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
+                        for row in range(num_rows)
+                        for u in range(users_per_row_per_iter)
+                    ]
+                    tokens_i, pt_i, last_i = _prepare_batch_host(user_indices)
+                    tokens_i = tokens_i.reshape(num_rows, -1)
+                    host_out = model[model_id].prepare_inputs_prefill(
+                        tokens_i, page_table=pt_i, trace_enabled=True, batched_prefill=True
+                    )
+                    host_inputs = (host_out[0], host_out[3], host_out[4])
+                    copy_host_to_device(host_inputs, device_tensors=trace_dev_inputs, mesh_device=mesh_device)
+                    ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
+
+                    if fixed_get_last_token == -1:
+                        row_results = model[model_id].process_output_prefill_batched(
+                            tt_out_trace,
+                            last_i,
+                            users_per_row=users_per_row_per_iter,
+                            seq_len_per_user=max_padded_len,
+                        )
+                    else:
+                        row_results = model[model_id].process_output_prefill_batched(
+                            tt_out_trace,
+                            [idx % 32 for idx in last_i],
+                            users_per_row=users_per_row_per_iter,
+                            seq_len_per_user=32,
+                        )
+                    for row, uid in enumerate(user_indices):
+                        prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
+                    if iter_idx % 8 == 0:
+                        logger.info(f" Traced prefill batch {iter_idx+1}/{num_prefill_iters}")
+                profiler.end(f"inference_prefill", iteration=batch_idx)
+
+                ttnn.release_trace(mesh_device, trace_id)
+                logger.info(f"Traced row-parallel prefill finished ({num_prefill_iters} iterations)")
+
+            else:
+                # === NON-TRACED BATCHED PREFILL ===
+
+                # Helper to run one batched prefill iteration
+                def _run_batched_prefill_iter(iter_idx, user_indices):
+                    batch_tokens_list = []
+                    batch_page_tables = []
+                    batch_last_token_idxs = []
+
+                    for uid in user_indices:
+                        prefill_len = int(decoding_pos[uid])
+                        padded_len = get_padded_prefill_len(prefill_len)
+                        user_tokens = torch.cat(
+                            [
+                                input_tokens_prefill_pt[uid : uid + 1, :prefill_len],
+                                torch.zeros(1, padded_len - prefill_len, dtype=torch.long),
+                            ],
+                            dim=-1,
+                        )
+                        batch_tokens_list.append(user_tokens)
+                        block_size = page_params["page_block_size"]
+                        num_blocks_needed = (padded_len + block_size - 1) // block_size
+                        batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed])
+                        batch_last_token_idxs.append(prefill_len - 1)
+
+                    tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, padded_len]
+                    page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, num_blocks]
+                    padded_len = tokens_stacked.shape[1]
+
+                    # Reshape tokens for batch>1: concatenate per-row users along seq dim
+                    tokens_for_model = tokens_stacked.reshape(num_rows, -1)  # [num_rows, N*padded_len]
+
+                    (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[
+                        model_id
+                    ].prepare_inputs_prefill(
+                        tokens_for_model,
+                        page_table=page_table_stacked,
+                        batched_prefill=True,
+                    )
+
+                    # Use get_last_token if all users' last tokens fall in the same 32-token tile
+                    min_tile = (min(batch_last_token_idxs) // 32) * 32
+                    max_tile = (max(batch_last_token_idxs) // 32) * 32
+                    get_last_token_val = min_tile if min_tile == max_tile else -1
+                    tt_logits = model[model_id].ttnn_prefill_forward(
+                        tokens_embd,
+                        rot_mats_global=rot_mats_global,
+                        rot_mats_local=rot_mats_local,
+                        user_id=0,  # Must be 0: each device sees page_table[0] after row-sharding
+                        page_table=page_table_tt,
+                        get_last_token=get_last_token_val,
+                        kv_cache=tt_kv_cache[model_id],
+                        batch_size=users_per_row_per_iter,
+                    )
+
+                    if get_last_token_val == -1:
+                        adjusted_last_idxs = batch_last_token_idxs
+                        seq_len_for_output = padded_len
+                    else:
+                        adjusted_last_idxs = [idx % 32 for idx in batch_last_token_idxs]
+                        seq_len_for_output = 32
+                    row_results = model[model_id].process_output_prefill_batched(
+                        tt_logits,
+                        adjusted_last_idxs,
+                        users_per_row=users_per_row_per_iter,
+                        seq_len_per_user=seq_len_for_output,
+                    )
+                    return row_results
+
+                # Warmup: compile with first batch
+                logger.info("Starting row-parallel prefill warmup...")
+                profiler.start(f"compile_prefill", iteration=batch_idx)
+                warmup_user_indices = [
+                    row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
+                ]
+                warmup_results = _run_batched_prefill_iter(0, warmup_user_indices)
+                for row, uid in enumerate(warmup_user_indices):
+                    prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
+                profiler.end(f"compile_prefill", iteration=batch_idx)
+                logger.info("Finished row-parallel prefill warmup")
+
+                # Clear KV caches before real prefill (warmup wrote to them)
+                for i in range(len(model)):
+                    for layer_obj in model[i].layers:
+                        k_cache, v_cache = layer_obj.self_attn.layer_past
+                        ttnn.mul(k_cache, 0, output_tensor=k_cache)
+                        ttnn.mul(v_cache, 0, output_tensor=v_cache)
+
+                # Real prefill
+                logger.info(
+                    f"Starting row-parallel batched prefill ({num_prefill_iters} iters, "
+                    f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..."
+                )
+                profiler.start(f"inference_prefill", iteration=batch_idx)
+                for iter_idx in range(num_prefill_iters):
+                    user_indices = [
+                        row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
+                        for row in range(num_rows)
+                        for u in range(users_per_row_per_iter)
+                    ]
+                    row_results = _run_batched_prefill_iter(iter_idx, user_indices)
+                    for row, uid in enumerate(user_indices):
+                        prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
+                    if iter_idx % 8 == 0:
+                        logger.info(f" Prefilled batch {iter_idx+1}/{num_prefill_iters}")
+                profiler.end(f"inference_prefill", iteration=batch_idx)
+                logger.info(f"Row-parallel batched prefill finished ({num_prefill_iters} iterations)")
 
-            logger.info(f"Starting prefill...")
-            profiler.start(f"inference_prefill", iteration=batch_idx)
-            logits = generator.prefill_forward_text(
-                input_tokens_prefill_pt,
-                page_table=page_table,
-                kv_cache=tt_kv_cache,
-                prompt_lens=decoding_pos,
-                enable_trace=enable_prefill_trace,
-                warmup_prefill=False,  # we can warmup prefill ourselves above if we want to
-            )
-            prefilled_token = torch.argmax(logits, dim=-1)
-            profiler.end(f"inference_prefill", iteration=batch_idx)
-            logger.info(f"Prefill finished")
             logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'")
 
         # Initialize generation state like tt_transformers
diff --git a/models/demos/gpt_oss/tt/attention/__init__.py b/models/demos/gpt_oss/tt/attention/__init__.py
index 52eab7bacc3c..c7963051a511 100644
--- a/models/demos/gpt_oss/tt/attention/__init__.py
+++ b/models/demos/gpt_oss/tt/attention/__init__.py
@@ -108,7 +108,15 @@ def __init__(
         self.scaling = config.scaling
 
     def __call__(
-        self, hidden_states, rope_mats, position_idx=None, page_table=None, kv_cache=None, is_decode=True, user_id=0
+        self,
+        hidden_states,
+        rope_mats,
+        position_idx=None,
+        page_table=None,
+        kv_cache=None,
+        is_decode=True,
+        user_id=0,
+        batch_size=1,
     ):
         """
         Forward pass - automatically dispatches to decode or prefill.
@@ -169,4 +177,5 @@ def __call__(
             position_idx=position_idx,
             page_table=page_table,
            ccl_manager=self.ccl_manager,
+            batch_size=batch_size,
         )
diff --git a/models/demos/gpt_oss/tt/attention/prefill.py b/models/demos/gpt_oss/tt/attention/prefill.py
index 56f8376c6faf..2f4283b7151c 100644
--- a/models/demos/gpt_oss/tt/attention/prefill.py
+++ b/models/demos/gpt_oss/tt/attention/prefill.py
@@ -29,6 +29,7 @@ def prefill_forward(
     page_table,
     ccl_manager,
     user_id=0,
+    batch_size=1,
 ):
     """
     Prefill forward pass - optimized for sequence processing (seq_len>1).
@@ -51,18 +52,21 @@ def prefill_forward(
         Attention output [batch, seq_len, hidden_size]
     """
     activation_dtype = ttnn.bfloat16
-    _, batch_size, seq_len, hidden_size = hidden_states.shape
+    total_seq_len = hidden_states.shape[-2]
+    hidden_size = hidden_states.shape[-1]
+    seq_len = total_seq_len // batch_size  # Per-user sequence length
 
     # Validate prefill mode
     if seq_len <= 1:
         raise ValueError(f"Prefill mode requires seq_len>1, got {seq_len}. Use decode mode for single tokens.")
 
-    if batch_size != 1:
-        raise NotImplementedError(f"Currently only batch_size=1 supported, got {batch_size}")
-
     # QKV projection
     xqkv_fused = apply_qkv_projection(hidden_states, weights)
 
+    # Reshape for batch: [1, 1, B*S, QKV] -> [B, 1, S, QKV]
+    if batch_size > 1:
+        xqkv_fused = ttnn.reshape(xqkv_fused, [batch_size, 1, seq_len, -1])
+
     # Split into Q, K, V heads
     num_local_heads = mesh_config.shard_size(config.num_heads)
     num_local_kv_heads = mesh_config.shard_size(config.num_kv_heads)
@@ -70,11 +74,15 @@ def prefill_forward(
     tt_q, tt_k, tt_v = split_qkv_heads_prefill(xqkv_fused, num_local_heads, num_local_kv_heads)
     xqkv_fused.deallocate(True)
 
-    # Apply RoPE
+    # Apply RoPE (use per-user seq_len positions)
+    if batch_size > 1:
+        rope_mats_sliced = [rope_mats[0][:, :, :seq_len, :], rope_mats[1][:, :, :seq_len, :]]
+    else:
+        rope_mats_sliced = rope_mats
     tt_q_orig = tt_q
     tt_k_orig = tt_k
-    tt_q = apply_rope(tt_q, rope_mats, transformation_mat, is_decode_mode=False)
-    tt_k = apply_rope(tt_k, rope_mats, transformation_mat, is_decode_mode=False)
+    tt_q = apply_rope(tt_q, rope_mats_sliced, transformation_mat, is_decode_mode=False)
+    tt_k = apply_rope(tt_k, rope_mats_sliced, transformation_mat, is_decode_mode=False)
     tt_q_orig.deallocate(True)
     tt_k_orig.deallocate(True)
 
@@ -89,16 +97,36 @@ def prefill_forward(
 
     if page_table is not None:
         block_size = k_cache.shape[2]
-        page_len = page_table.shape[1] * block_size
-        tt_k_sliced = tt_k[:, :, :page_len, :] if page_len < tt_k.shape[2] else tt_k
-        tt_v_sliced = tt_v[:, :, :page_len, :] if page_len < tt_v.shape[2] else tt_v
-        ttnn.experimental.paged_fill_cache(k_cache, tt_k_sliced, page_table, batch_idx=user_id)
-        ttnn.experimental.paged_fill_cache(v_cache, tt_v_sliced, page_table, batch_idx=user_id)
+        page_len = page_table.shape[-1] * block_size
+        if batch_size > 1:
+            # Flatten batch into seq dim, heads into last dim — single fill call, no per-user loop.
+            # Paged cache just maps sequence positions to physical pages.
+            k_fill = ttnn.reshape(tt_k, [1, 1, total_seq_len, -1])
+            v_fill = ttnn.reshape(tt_v, [1, 1, total_seq_len, -1])
+            page_table_flat = ttnn.reshape(page_table, [1, -1])
+            ttnn.experimental.paged_fill_cache(k_cache, k_fill, page_table_flat, batch_idx=0)
+            ttnn.experimental.paged_fill_cache(v_cache, v_fill, page_table_flat, batch_idx=0)
+            k_fill.deallocate(True)
+            v_fill.deallocate(True)
+        else:
+            tt_k_sliced = tt_k[:, :, :page_len, :] if page_len < tt_k.shape[2] else tt_k
+            tt_v_sliced = tt_v[:, :, :page_len, :] if page_len < tt_v.shape[2] else tt_v
+            ttnn.experimental.paged_fill_cache(k_cache, tt_k_sliced, page_table, batch_idx=user_id)
+            ttnn.experimental.paged_fill_cache(v_cache, tt_v_sliced, page_table, batch_idx=user_id)
     else:
         # Non-paged attention
-        ttnn.fill_cache(k_cache, tt_k, batch_idx=user_id)
-        ttnn.fill_cache(v_cache, tt_v, batch_idx=user_id)
+        if batch_size > 1:
+            for b in range(batch_size):
+                k_b = ttnn.slice(tt_k, (b, 0, 0, 0), (b + 1, tt_k.shape[1], tt_k.shape[2], tt_k.shape[3]))
+                v_b = ttnn.slice(tt_v, (b, 0, 0, 0), (b + 1, tt_v.shape[1], tt_v.shape[2], tt_v.shape[3]))
+                ttnn.fill_cache(k_cache, k_b, batch_idx=b)
+                ttnn.fill_cache(v_cache, v_b, batch_idx=b)
+                k_b.deallocate(True)
+                v_b.deallocate(True)
+        else:
+            ttnn.fill_cache(k_cache, tt_k, batch_idx=user_id)
+            ttnn.fill_cache(v_cache, tt_v, batch_idx=user_id)
 
     # Scaled dot-product attention
     tt_sdpa_out = ttnn.transformer.scaled_dot_product_attention(
@@ -120,11 +148,14 @@ def prefill_forward(
     tt_sdpa_out = concat_heads(tt_sdpa_out, is_decode_mode=False)
     tt_sdpa_out_pre_concat.deallocate(True)
 
+    # Flatten back for output projection: [B, 1, S, H] -> [1, 1, B*S, H]
+    if batch_size > 1:
+        tt_sdpa_out = ttnn.reshape(tt_sdpa_out, [1, 1, total_seq_len, -1])
+
     tt_out = apply_output_projection(tt_sdpa_out, weights, activation_dtype)
     # Note: apply_output_projection already deallocates its input tensor internally
-    # tt_out = ttnn.reshape(tt_out, (batch_size, seq_len, hidden_size))
 
     # Tensor parallel allreduce
-    tt_out = apply_allreduce(tt_out, mesh_config, ccl_manager, batch_size, seq_len, hidden_size)
+    tt_out = apply_allreduce(tt_out, mesh_config, ccl_manager, 1, total_seq_len, hidden_size)
 
     return tt_out
diff --git a/models/demos/gpt_oss/tt/layer.py b/models/demos/gpt_oss/tt/layer.py
index 31aada54584d..c3b91c5b8ef1 100644
--- a/models/demos/gpt_oss/tt/layer.py
+++ b/models/demos/gpt_oss/tt/layer.py
@@ -96,6 +96,7 @@ def __call__(
         kv_cache=None,
         is_decode=True,
         user_id=0,
+        batch_size=1,
     ):
         # hidden_states: [1, 1, tokens/num_rows, hidden_size/num_columns]
         # residual: [1, 1, tokens/num_rows, hidden_size/num_columns]
@@ -112,6 +113,7 @@ def __call__(
             kv_cache=kv_cache,
             is_decode=is_decode,
             user_id=user_id,
+            batch_size=batch_size,
         )
 
         hidden_states_post_norm.deallocate(True)
diff --git a/models/demos/gpt_oss/tt/model.py b/models/demos/gpt_oss/tt/model.py
index 47bb42836c4e..44f36af797f5 100644
--- a/models/demos/gpt_oss/tt/model.py
+++ b/models/demos/gpt_oss/tt/model.py
@@ -234,7 +234,16 @@ def switch_mode(self, mode: Mode):
         return None
 
     def _forward_layers_and_head(
-        self, hidden_states, rope_mats, current_pos, page_table, kv_cache, get_last_token=-1, is_decode=True, user_id=0
+        self,
+        hidden_states,
+        rope_mats,
+        current_pos,
+        page_table,
+        kv_cache,
+        get_last_token=-1,
+        is_decode=True,
+        user_id=0,
+        batch_size=1,
     ):
         """
         Shared forward pass through decoder layers and final projection.
@@ -264,16 +273,31 @@ def _forward_layers_and_head(
                 kv_cache=layer_kv_cache,
                 is_decode=is_decode,
                 user_id=user_id,
+                batch_size=batch_size,
             )
 
         logits = hidden_states
         if get_last_token != -1:
-            # The logits come from the shared method, slice them
             if len(logits.shape) == 3:
                 logits = ttnn.unsqueeze(logits, dim=1)
-            logits_sliced = ttnn.slice(logits, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, logits.shape[-1]))
-            logits.deallocate(True)
-            logits = logits_sliced
+            if batch_size > 1:
+                # Batch>1: tokens are concatenated [1,1,B*S,H]. Extract each user's 32-token tile.
+                per_user_seq = logits.shape[2] // batch_size
+                tiles = []
+                for b in range(batch_size):
+                    start = b * per_user_seq + get_last_token
+                    tile = ttnn.slice(logits, (0, 0, start, 0), (1, 1, start + 32, logits.shape[-1]))
+                    tiles.append(tile)
+                logits.deallocate(True)
+                logits = ttnn.concat(tiles, dim=2)  # [1, 1, B*32, H]
+                for t in tiles:
+                    t.deallocate(True)
+            else:
+                logits_sliced = ttnn.slice(
+                    logits, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, logits.shape[-1])
+                )
+                logits.deallocate(True)
+                logits = logits_sliced
             hidden_states = logits
 
         # Final norm and lm_head
@@ -335,6 +359,7 @@ def ttnn_prefill_forward(
         chunk_start_idx=None,
         get_last_token=-1,
         kv_cache=None,
+        batch_size=1,
     ):
         """Prefill forward pass - processes full sequences"""
         # Use provided rotation matrices or slice from rope_setup (matches tt-transformers)
@@ -358,6 +383,7 @@ def ttnn_prefill_forward(
             get_last_token=get_last_token,
            is_decode=False,
             user_id=user_id,
+            batch_size=batch_size,
         )
 
         return logits
@@ -476,15 +502,34 @@ def prepare_inputs_prefill(
         trace_enabled=False,
         last_token_idx=None,
         global_user_id=None,
+        batched_prefill=False,
     ):
-        """Prepare inputs for prefill mode"""
-        # Embed the tokens
-        if tokens.dim() == 2:
-            tokens = tokens.reshape(1, 1, 1, -1)
+        """Prepare inputs for prefill mode
+
+        Args:
+            batched_prefill: If True, tokens is [num_rows, seq_len] and will be
+                sharded across mesh rows. Each row processes a different user.
+        """
+        # Embed the tokens
         device = None if trace_enabled else self.mesh_device
-        tokens = ttnn.from_torch(tokens, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
+        if batched_prefill:
+            # Row-parallel batched prefill: tokens is [num_rows, seq_len]
+            # Shard across mesh rows so each row gets one user's tokens [1, seq_len]
+            num_rows = tokens.shape[0]
+            seq_len_per_user = tokens.shape[1]
+            tokens = tokens.reshape(num_rows, 1, 1, seq_len_per_user)
+            tokens = ttnn.from_torch(
+                tokens,
+                device=device,
+                dtype=ttnn.uint32,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
+                mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=(0, None), mesh_shape=self.mesh_device.shape),
+            )
+        else:
+            if tokens.dim() == 2:
+                tokens = tokens.reshape(1, 1, 1, -1)
+            tokens = ttnn.from_torch(tokens, device=device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
 
         if not trace_enabled:
             tokens_embd = ttnn.embedding(tokens, self.embedding_weight, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b)
@@ -589,3 +634,37 @@ def process_output_prefill(self, tt_out, last_token_idx):
         torch_output = ttnn.to_torch(tt_output_tensor)
         result = torch_output[..., last_token_idx, : self.vocab_size]
         return result
+
+    def process_output_prefill_batched(self, tt_out, last_token_idxs, users_per_row=1, seq_len_per_user=None):
+        """Process row-parallel batched prefill output.
+
+        Extracts logits from one device per row (first device of each row).
+        Supports multiple users per row when users_per_row > 1.
+
+        Args:
+            tt_out: Multi-device output tensor
+            last_token_idxs: List of last_token_idx per user (length = num_rows * users_per_row)
+            users_per_row: Number of users per mesh row per iteration
+            seq_len_per_user: Per-user sequence length (required when users_per_row > 1)
+
+        Returns:
+            List of per-user logit tensors (one per user)
+        """
+        num_cols = self.mesh_device.shape[1]
+        device_tensors = ttnn.get_device_tensors(tt_out)
+        results = []
+        num_rows = self.mesh_device.shape[0]
+        for row in range(num_rows):
+            device_idx = row * num_cols  # First device of each row
+            torch_output = ttnn.to_torch(device_tensors[device_idx])
+            for u in range(users_per_row):
+                user_flat_idx = row * users_per_row + u
+                last_idx = last_token_idxs[user_flat_idx] if isinstance(last_token_idxs, list) else last_token_idxs
+                if users_per_row > 1:
+                    # Tokens are concatenated: user u's last token is at offset u*seq_len_per_user + last_idx
+                    global_idx = u * seq_len_per_user + last_idx
+                    result = torch_output[..., global_idx, : self.vocab_size]
+                else:
+                    result = torch_output[..., last_idx, : self.vocab_size]
+                results.append(result)
+        return results
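Note (illustration only, not part of the patch): a minimal host-side sketch of the index bookkeeping the row-parallel prefill relies on, assuming an illustrative 4-row mesh, 32 users total, and one user per row per iteration. The helper names user_indices_for_iter and output_offset are hypothetical; they only restate the arithmetic used in the demo loop and in process_output_prefill_batched.

# Sketch of the row-parallel prefill index math (assumed example values, pure Python, no ttnn).
num_rows = 4                                             # assumed mesh rows
global_batch_size = 32                                   # assumed total users
users_per_row_prefill = global_batch_size // num_rows    # users handled by each row overall
users_per_row_per_iter = 1                               # users each row processes per iteration
num_prefill_iters = users_per_row_prefill // users_per_row_per_iter


def user_indices_for_iter(iter_idx):
    """Global user ids prefilled together in one iteration: one contiguous slot per mesh row."""
    return [
        row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
        for row in range(num_rows)
        for u in range(users_per_row_per_iter)
    ]


def output_offset(u, last_token_idx, seq_len_per_user):
    """Position of user u's last-token logits in a row's concatenated [1, 1, N*S, H] output,
    mirroring the offset computed in process_output_prefill_batched."""
    return u * seq_len_per_user + last_token_idx


if __name__ == "__main__":
    # Each iteration takes the next user from every row's block of users_per_row_prefill ids.
    assert user_indices_for_iter(0) == [0, 8, 16, 24]
    assert user_indices_for_iter(3) == [3, 11, 19, 27]
    # With one user per row per iteration, the output offset reduces to the last token index itself.
    assert output_offset(0, last_token_idx=17, seq_len_per_user=128) == 17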