322 changes: 295 additions & 27 deletions models/demos/gpt_oss/demo/text_demo.py
@@ -36,6 +36,7 @@
from models.tt_transformers.demo.simple_text_demo import create_tt_page_table, load_inputs
from models.tt_transformers.tt.common import (
PagedAttentionConfig,
copy_host_to_device,
get_padded_prefill_len,
preprocess_inputs_prefill,
sample_host,
@@ -242,7 +243,7 @@ def prepare_gpt_oss_generator_args(
{"page_block_size": 64, "page_max_num_blocks_per_dp": 128 * 1024 // 64}, # page_params
{"temperature": 0, "top_p": 0.08}, # sampling_params (greedy decoding)
True, # enable_decode_trace
False, # enable_prefill_trace
True, # enable_prefill_trace
True, # users_row_sharded
False, # long_context_mode
),
@@ -572,33 +573,300 @@ def test_gpt_oss_demo(
logger.info(f"Prefill finished for {num_real_users} real users")
logger.info(f"First generated token (user 0): '{tokenizer.decode(prefilled_token[0])}'")
else:
# Standard batch prefill (matching tt_transformers)
logger.info("Starting prefill warmup...")
profiler.start(f"compile_prefill", iteration=batch_idx)
generator.prefill_forward_text(
input_tokens_prefill_pt[:1],
page_table=page_table,
kv_cache=tt_kv_cache,
prompt_lens=decoding_pos,
enable_trace=enable_prefill_trace,
warmup_prefill=False,
)
profiler.end(f"compile_prefill", iteration=batch_idx)
logger.info("Finished prefill warmup")
# Row-parallel batched prefill: process 4 users at once (one per mesh row)
# This gives ~4x speedup over sequential per-user prefill
num_rows = mesh_device.shape[0]
users_per_row_prefill = global_batch_size // num_rows
users_per_row_per_iter = 1 # Users each mesh row processes per prefill iteration
# Increasing above 1 requires model changes:
# - attention/prefill.py: relax batch_size!=1 check, loop paged_fill_cache
# - model.py:process_output_prefill_batched: extract multiple logits per row
assert users_per_row_prefill % users_per_row_per_iter == 0
num_prefill_iters = users_per_row_prefill // users_per_row_per_iter
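# Worked example (illustrative numbers, not taken from this PR): with global_batch_size=32 on a
# 4-row mesh, users_per_row_prefill = 32 // 4 = 8; at 1 user per row per iteration this gives
# 8 prefill iterations, each prefilling one user on every mesh row in parallel.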
model_id = 0 # data_parallel=1, single model

prefilled_token = torch.zeros(global_batch_size, dtype=torch.long)

if enable_prefill_trace:
# === TRACED BATCHED PREFILL ===
# Trace captures device program once, then replays with input buffer updates.
# Eliminates per-iteration host dispatch overhead.

# Uniform padded_len for all users (required for tracing: fixed tensor shapes)
max_padded_len = max(get_padded_prefill_len(int(decoding_pos[uid])) for uid in range(global_batch_size))
block_size = page_params["page_block_size"]
max_num_blocks = (max_padded_len + block_size - 1) // block_size

# Compute fixed get_last_token for trace (all users must be in same 32-token tile)
all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)]
if users_per_row_per_iter > 1:
fixed_get_last_token = -1 # Can't use get_last_token with batch>1

Copilot AI (Feb 13, 2026): This statement is unreachable.

else:
fixed_get_last_token = (min(all_last_idxs) // 32) * 32
max_tile_start = (max(all_last_idxs) // 32) * 32
if fixed_get_last_token != max_tile_start:
logger.warning(
f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), "
f"using get_last_token=-1 (slower)"
)
fixed_get_last_token = -1
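# Illustration (hypothetical prompt lengths, not from this PR): prompts of 100 and 120 tokens have
# last indices 99 and 119, both in the 32-token tile starting at 96, so get_last_token=96 serves the
# whole batch; adding a 130-token prompt (last index 129, tile start 128) would hit the warning above
# and fall back to get_last_token=-1.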

def _prepare_batch_host(user_indices):
"""Prepare host-side tokens + page_table for a batch of users."""
tokens_list, pt_list, last_idxs = [], [], []
for uid in user_indices:
plen = int(decoding_pos[uid])
toks = torch.cat(
[
input_tokens_prefill_pt[uid : uid + 1, :plen],
torch.zeros(1, max_padded_len - plen, dtype=torch.long),
],
dim=-1,
)
tokens_list.append(toks)
pt_list.append(page_table[uid : uid + 1, :max_num_blocks])
last_idxs.append(plen - 1)
return (torch.cat(tokens_list, dim=0), torch.cat(pt_list, dim=0), last_idxs)

# --- Warmup (compilation) ---
logger.info("Starting traced row-parallel prefill warmup (compilation)...")
warmup_indices = [
row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
]
tokens_w, pt_w, last_w = _prepare_batch_host(warmup_indices)
tokens_w = tokens_w.reshape(num_rows, -1) # [num_rows, N*S] for batch>1 concat

host_out = model[model_id].prepare_inputs_prefill(
tokens_w, page_table=pt_w, trace_enabled=True, batched_prefill=True
)
rot_global = host_out[1] # device-resident, fixed across iterations
rot_local = host_out[2] # None
host_inputs = (host_out[0], host_out[3], host_out[4]) # tokens, pt, cpt

profiler.start(f"compile_prefill", iteration=batch_idx)
dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
transformed = model[model_id].transform_and_embed_prefill_inputs_device(*dev_inputs)
tt_logits = model[model_id].ttnn_prefill_forward(
transformed[0],
rot_mats_global=rot_global,
rot_mats_local=rot_local,
user_id=0,
page_table=transformed[1],
get_last_token=fixed_get_last_token,
kv_cache=tt_kv_cache[model_id],
batch_size=users_per_row_per_iter,
)

if fixed_get_last_token == -1:
warmup_results = model[model_id].process_output_prefill_batched(
tt_logits,
last_w,
users_per_row=users_per_row_per_iter,
seq_len_per_user=max_padded_len,
)
else:
warmup_results = model[model_id].process_output_prefill_batched(
tt_logits,
[idx % 32 for idx in last_w],
users_per_row=users_per_row_per_iter,
seq_len_per_user=max_padded_len,
)
for row, uid in enumerate(warmup_indices):
prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
profiler.end(f"compile_prefill", iteration=batch_idx)
logger.info("Finished traced row-parallel prefill warmup")

# Clear KV caches (warmup wrote to them)
for i in range(len(model)):
for layer_obj in model[i].layers:
k_cache, v_cache = layer_obj.self_attn.layer_past
ttnn.mul(k_cache, 0, output_tensor=k_cache)
ttnn.mul(v_cache, 0, output_tensor=v_cache)

# --- Trace capture ---
logger.info("Capturing prefill trace...")
iter0_indices = [
row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
]
tokens_0, pt_0, last_0 = _prepare_batch_host(iter0_indices)
tokens_0 = tokens_0.reshape(num_rows, -1)
host_out = model[model_id].prepare_inputs_prefill(
tokens_0, page_table=pt_0, trace_enabled=True, batched_prefill=True
)
host_inputs = (host_out[0], host_out[3], host_out[4])

trace_dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
# Embed tokens on-device inside the trace (without deallocating input buffer,
# since we need to update it between trace executions)
tokens_embd = ttnn.embedding(
trace_dev_inputs[0],
model[model_id].embedding_weight,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat8_b,
)
if len(tokens_embd.shape) == 3:
tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)
tt_out_trace = model[model_id].ttnn_prefill_forward(
tokens_embd,
rot_mats_global=rot_global,
rot_mats_local=rot_local,
user_id=0,
page_table=trace_dev_inputs[1],
get_last_token=fixed_get_last_token,
kv_cache=tt_kv_cache[model_id],
batch_size=users_per_row_per_iter,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
logger.info("Prefill trace captured")

# --- Execute trace for all iterations ---
logger.info(
f"Starting traced row-parallel prefill ({num_prefill_iters} iters, "
f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..."
)
profiler.start(f"inference_prefill", iteration=batch_idx)
for iter_idx in range(num_prefill_iters):
user_indices = [
row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
for row in range(num_rows)
for u in range(users_per_row_per_iter)
]
tokens_i, pt_i, last_i = _prepare_batch_host(user_indices)
tokens_i = tokens_i.reshape(num_rows, -1)
host_out = model[model_id].prepare_inputs_prefill(
tokens_i, page_table=pt_i, trace_enabled=True, batched_prefill=True
)
host_inputs = (host_out[0], host_out[3], host_out[4])
copy_host_to_device(host_inputs, device_tensors=trace_dev_inputs, mesh_device=mesh_device)
ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)

Copilot AI (Feb 13, 2026): In the traced batched prefill loop you execute the trace with blocking=False and then immediately read/convert the output. Without an explicit sync/block on the output tensor, this risks reading incomplete results depending on TTNN scheduling. Add a synchronization point (e.g., block on output CPU copy or ttnn.synchronize_device(mesh_device)) before processing the output.

Suggested change (keep the execute_trace call and sync before reading tt_out_trace):

ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
# Ensure trace execution has completed before reading from tt_out_trace.
ttnn.synchronize_device(mesh_device)
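
An alternative sketch, assuming ttnn.execute_trace also accepts blocking=True (implied by the blocking=False call above but not shown in this PR): replay the trace synchronously so the host-side readback below cannot race the device. This gives up overlap with any later host-side work but removes the need for a separate synchronization call.

ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=True)  # hypothetical blocking replay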

if fixed_get_last_token == -1:
row_results = model[model_id].process_output_prefill_batched(
tt_out_trace,
last_i,
users_per_row=users_per_row_per_iter,
seq_len_per_user=max_padded_len,
)
else:
row_results = model[model_id].process_output_prefill_batched(
tt_out_trace,
[idx % 32 for idx in last_i],
users_per_row=users_per_row_per_iter,
seq_len_per_user=max_padded_len,
)
for row, uid in enumerate(user_indices):
prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
if iter_idx % 8 == 0:
logger.info(f" Traced prefill batch {iter_idx+1}/{num_prefill_iters}")
profiler.end(f"inference_prefill", iteration=batch_idx)

ttnn.release_trace(mesh_device, trace_id)
logger.info(f"Traced row-parallel prefill finished ({num_prefill_iters} iterations)")

else:
# === NON-TRACED BATCHED PREFILL ===

# Helper to run one batched prefill iteration
def _run_batched_prefill_iter(iter_idx, user_indices):
batch_tokens_list = []
batch_page_tables = []
batch_last_token_idxs = []

for uid in user_indices:
prefill_len = int(decoding_pos[uid])
padded_len = get_padded_prefill_len(prefill_len)
user_tokens = torch.cat(
[
input_tokens_prefill_pt[uid : uid + 1, :prefill_len],
torch.zeros(1, padded_len - prefill_len, dtype=torch.long),
],
dim=-1,
)
batch_tokens_list.append(user_tokens)
block_size = page_params["page_block_size"]
num_blocks_needed = (padded_len + block_size - 1) // block_size
batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed])
batch_last_token_idxs.append(prefill_len - 1)

tokens_stacked = torch.cat(batch_tokens_list, dim=0) # [total_users, padded_len]
page_table_stacked = torch.cat(batch_page_tables, dim=0) # [total_users, num_blocks]
padded_len = tokens_stacked.shape[1]

Comment on lines +769 to +792

Copilot AI (Feb 13, 2026): torch.cat(batch_page_tables, dim=0) will fail if users in the same iteration have different num_blocks_needed (since each slice page_table[uid : uid + 1, :num_blocks_needed] can have a different second dimension). To make batched prefill robust when prompts differ, pad/standardize to a single max_num_blocks for the iteration (e.g., fill extra entries with -1) before concatenation.

Suggested change (replacing the per-user token/page-table collection and stacking above):

# First pass: collect per-user metadata and track max sizes for padding
per_user_tokens = []
per_user_padded_lens = []
per_user_page_tables = []
per_user_num_blocks = []
batch_last_token_idxs = []
max_padded_len = 0
max_num_blocks = 0
block_size = page_params["page_block_size"]
for uid in user_indices:
    prefill_len = int(decoding_pos[uid])
    padded_len = get_padded_prefill_len(prefill_len)
    max_padded_len = max(max_padded_len, padded_len)
    tokens_slice = input_tokens_prefill_pt[uid : uid + 1, :prefill_len]
    per_user_tokens.append(tokens_slice)
    per_user_padded_lens.append(padded_len)
    num_blocks_needed = (padded_len + block_size - 1) // block_size
    max_num_blocks = max(max_num_blocks, num_blocks_needed)
    page_table_slice = page_table[uid : uid + 1, :num_blocks_needed]
    per_user_page_tables.append(page_table_slice)
    per_user_num_blocks.append(num_blocks_needed)
    batch_last_token_idxs.append(prefill_len - 1)

# Second pass: pad tokens and page tables to per-iteration maxima
batch_tokens_list = []
batch_page_tables = []
for tokens_slice, padded_len, page_table_slice, num_blocks_needed in zip(
    per_user_tokens,
    per_user_padded_lens,
    per_user_page_tables,
    per_user_num_blocks,
):
    # Pad tokens to max_padded_len with zeros
    token_pad_len = max_padded_len - padded_len
    if token_pad_len > 0:
        token_padding = torch.zeros(1, token_pad_len, dtype=torch.long, device=tokens_slice.device)
        user_tokens = torch.cat([tokens_slice, token_padding], dim=-1)
    else:
        user_tokens = tokens_slice
    batch_tokens_list.append(user_tokens)
    # Pad page tables to max_num_blocks with -1 sentinel
    page_pad_blocks = max_num_blocks - num_blocks_needed
    if page_pad_blocks > 0:
        pad_values = torch.full(
            (1, page_pad_blocks),
            -1,
            dtype=page_table_slice.dtype,
            device=page_table_slice.device,
        )
        user_page_table = torch.cat([page_table_slice, pad_values], dim=-1)
    else:
        user_page_table = page_table_slice
    batch_page_tables.append(user_page_table)

tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, max_padded_len]
page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, max_num_blocks]
padded_len = max_padded_len
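
A more compact sketch of the same padding idea, assuming torch.nn.functional.pad and the -1 sentinel the suggestion proposes; the pad_and_stack helper is illustrative, not part of the PR. Each slice is padded from its true width straight to the per-iteration maximum before stacking:

import torch
import torch.nn.functional as F

def pad_and_stack(slices, width, fill):
    # Right-pad each [1, n] slice along the last dim to [1, width] with `fill`, then stack to [num_users, width].
    return torch.cat([F.pad(s, (0, width - s.shape[1]), value=fill) for s in slices], dim=0)

# Usage sketch: tokens padded with 0, page tables padded with the -1 sentinel.
# tokens_stacked = pad_and_stack(per_user_tokens, max_padded_len, fill=0)
# page_table_stacked = pad_and_stack(per_user_page_tables, max_num_blocks, fill=-1)
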
# Reshape tokens for batch>1: concatenate per-row users along seq dim
tokens_for_model = tokens_stacked.reshape(num_rows, -1) # [num_rows, N*padded_len]

(tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[
model_id
].prepare_inputs_prefill(
tokens_for_model,
page_table=page_table_stacked,
batched_prefill=True,
)

get_last_token_val = (max(batch_last_token_idxs) // 32) * 32 if users_per_row_per_iter == 1 else -1
tt_logits = model[model_id].ttnn_prefill_forward(
tokens_embd,
rot_mats_global=rot_mats_global,
rot_mats_local=rot_mats_local,
user_id=0, # Must be 0: each device sees page_table[0] after row-sharding
page_table=page_table_tt,
get_last_token=get_last_token_val,
kv_cache=tt_kv_cache[model_id],
batch_size=users_per_row_per_iter,
)

if get_last_token_val == -1:
adjusted_last_idxs = batch_last_token_idxs
else:
adjusted_last_idxs = [idx % 32 for idx in batch_last_token_idxs]
row_results = model[model_id].process_output_prefill_batched(
tt_logits,
adjusted_last_idxs,
users_per_row=users_per_row_per_iter,
seq_len_per_user=padded_len,
)
return row_results

# Warmup: compile with first batch
logger.info("Starting row-parallel prefill warmup...")
profiler.start(f"compile_prefill", iteration=batch_idx)
warmup_user_indices = [
row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
]
warmup_results = _run_batched_prefill_iter(0, warmup_user_indices)
for row, uid in enumerate(warmup_user_indices):
prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
profiler.end(f"compile_prefill", iteration=batch_idx)
logger.info("Finished row-parallel prefill warmup")

# Clear KV caches before real prefill (warmup wrote to them)
for i in range(len(model)):
for layer_obj in model[i].layers:
k_cache, v_cache = layer_obj.self_attn.layer_past
ttnn.mul(k_cache, 0, output_tensor=k_cache)
ttnn.mul(v_cache, 0, output_tensor=v_cache)

# Real prefill
logger.info(
f"Starting row-parallel batched prefill ({num_prefill_iters} iters, "
f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..."
)
profiler.start(f"inference_prefill", iteration=batch_idx)
for iter_idx in range(num_prefill_iters):
user_indices = [
row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
for row in range(num_rows)
for u in range(users_per_row_per_iter)
]
row_results = _run_batched_prefill_iter(iter_idx, user_indices)
for row, uid in enumerate(user_indices):
prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
if iter_idx % 8 == 0:
logger.info(f" Prefilled batch {iter_idx+1}/{num_prefill_iters}")
profiler.end(f"inference_prefill", iteration=batch_idx)
logger.info(f"Row-parallel batched prefill finished ({num_prefill_iters} iterations)")

logger.info(f"Starting prefill...")
profiler.start(f"inference_prefill", iteration=batch_idx)
logits = generator.prefill_forward_text(
input_tokens_prefill_pt,
page_table=page_table,
kv_cache=tt_kv_cache,
prompt_lens=decoding_pos,
enable_trace=enable_prefill_trace,
warmup_prefill=False, # we can warmup prefill ourselves above if we want to
)
prefilled_token = torch.argmax(logits, dim=-1)
profiler.end(f"inference_prefill", iteration=batch_idx)
logger.info(f"Prefill finished")
logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'")

# Initialize generation state like tt_transformers
11 changes: 10 additions & 1 deletion models/demos/gpt_oss/tt/attention/__init__.py
@@ -108,7 +108,15 @@ def __init__(
self.scaling = config.scaling

def __call__(
self, hidden_states, rope_mats, position_idx=None, page_table=None, kv_cache=None, is_decode=True, user_id=0
self,
hidden_states,
rope_mats,
position_idx=None,
page_table=None,
kv_cache=None,
is_decode=True,
user_id=0,
batch_size=1,
):
"""
Forward pass - automatically dispatches to decode or prefill.
@@ -169,4 +177,5 @@ def __call__(
position_idx=position_idx,
page_table=page_table,
ccl_manager=self.ccl_manager,
batch_size=batch_size,
)