
Commit f8cbbd0

sraizada-tt and Ubuntu authored
[gpt-oss] batched prefill and prefill tracing (#37848)
galaxy demo: https://github.com/tenstorrent/tt-metal/actions/runs/22053977976

Co-authored-by: Ubuntu <ubuntu@UF-EV-B5-GWH01.maas>
1 parent 4ffd828 commit f8cbbd0

File tree: 5 files changed (+445, -54 lines)

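The prefill tracing added by this commit follows ttnn's capture-once, replay-many pattern, which the text_demo.py diff below spells out in full: the prefill program is captured on the first batch, and every later batch only refreshes the device input buffers before replaying the trace. Below is a minimal sketch of that pattern; `prepare_host_inputs`, `run_prefill`, and `read_back` are hypothetical stand-ins for the demo's `prepare_inputs_prefill`, `ttnn_prefill_forward`, and `process_output_prefill_batched` calls.

```python
# Sketch of the capture-once / replay-many tracing pattern used for batched prefill.
# prepare_host_inputs / run_prefill / read_back are hypothetical placeholders; the demo
# uses model.prepare_inputs_prefill, model.ttnn_prefill_forward and
# model.process_output_prefill_batched instead.
import ttnn
from models.tt_transformers.tt.common import copy_host_to_device


def traced_prefill(mesh_device, batches):
    results = []

    # Capture: allocate device input buffers, then record the prefill program once.
    dev_inputs = copy_host_to_device(prepare_host_inputs(batches[0]), mesh_device=mesh_device)
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    tt_out = run_prefill(dev_inputs)  # device ops issued here are recorded into the trace
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

    # Replay: overwrite the captured input buffers in place and re-execute the recorded
    # program, avoiding per-iteration host-side dispatch of individual ops.
    for batch in batches:
        copy_host_to_device(prepare_host_inputs(batch), device_tensors=dev_inputs, mesh_device=mesh_device)
        ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
        results.append(read_back(tt_out))

    ttnn.release_trace(mesh_device, trace_id)
    return results
```

The output tensor captured during tracing is reused on every replay, which is why the demo reads each batch's logits back before executing the next trace.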

models/demos/gpt_oss/demo/text_demo.py

Lines changed: 297 additions & 27 deletions
@@ -36,6 +36,7 @@
 from models.tt_transformers.demo.simple_text_demo import create_tt_page_table, load_inputs
 from models.tt_transformers.tt.common import (
     PagedAttentionConfig,
+    copy_host_to_device,
     get_padded_prefill_len,
     preprocess_inputs_prefill,
     sample_host,
@@ -242,7 +243,7 @@ def prepare_gpt_oss_generator_args(
         {"page_block_size": 64, "page_max_num_blocks_per_dp": 128 * 1024 // 64},  # page_params
         {"temperature": 0, "top_p": 0.08},  # sampling_params (greedy decoding)
         True,  # enable_decode_trace
-        False,  # enable_prefill_trace
+        True,  # enable_prefill_trace
         True,  # users_row_sharded
         False,  # long_context_mode
     ),
@@ -572,33 +573,302 @@ def test_gpt_oss_demo(
         logger.info(f"Prefill finished for {num_real_users} real users")
         logger.info(f"First generated token (user 0): '{tokenizer.decode(prefilled_token[0])}'")
     else:
-        # Standard batch prefill (matching tt_transformers)
-        logger.info("Starting prefill warmup...")
-        profiler.start(f"compile_prefill", iteration=batch_idx)
-        generator.prefill_forward_text(
-            input_tokens_prefill_pt[:1],
-            page_table=page_table,
-            kv_cache=tt_kv_cache,
-            prompt_lens=decoding_pos,
-            enable_trace=enable_prefill_trace,
-            warmup_prefill=False,
-        )
-        profiler.end(f"compile_prefill", iteration=batch_idx)
-        logger.info("Finished prefill warmup")
+        # Row-parallel batched prefill: process 4 users at once (one per mesh row)
+        # This gives ~4x speedup over sequential per-user prefill
+        num_rows = mesh_device.shape[0]
+        users_per_row_prefill = global_batch_size // num_rows
+        users_per_row_per_iter = 1  # Users each mesh row processes per prefill iteration
+        # Increasing above 1 requires model changes:
+        #   - attention/prefill.py: relax batch_size!=1 check, loop paged_fill_cache
+        #   - model.py:process_output_prefill_batched: extract multiple logits per row
+        assert users_per_row_prefill % users_per_row_per_iter == 0
+        num_prefill_iters = users_per_row_prefill // users_per_row_per_iter
+        model_id = 0  # data_parallel=1, single model
+
+        prefilled_token = torch.zeros(global_batch_size, dtype=torch.long)
+
+        if enable_prefill_trace:
+            # === TRACED BATCHED PREFILL ===
+            # Trace captures device program once, then replays with input buffer updates.
+            # Eliminates per-iteration host dispatch overhead.
+
+            # Uniform padded_len for all users (required for tracing: fixed tensor shapes)
+            max_padded_len = max(get_padded_prefill_len(int(decoding_pos[uid])) for uid in range(global_batch_size))
+            block_size = page_params["page_block_size"]
+            max_num_blocks = (max_padded_len + block_size - 1) // block_size
+
+            # Compute fixed get_last_token for trace (all users must be in same 32-token tile)
+            all_last_idxs = [int(decoding_pos[uid]) - 1 for uid in range(global_batch_size)]
+            fixed_get_last_token = (min(all_last_idxs) // 32) * 32
+            max_tile_start = (max(all_last_idxs) // 32) * 32
+            if fixed_get_last_token != max_tile_start:
+                logger.warning(
+                    f"Users span multiple 32-token tiles ({fixed_get_last_token} vs {max_tile_start}), "
+                    f"using get_last_token=-1 (slower)"
+                )
+                fixed_get_last_token = -1
+
+            def _prepare_batch_host(user_indices):
+                """Prepare host-side tokens + page_table for a batch of users."""
+                tokens_list, pt_list, last_idxs = [], [], []
+                for uid in user_indices:
+                    plen = int(decoding_pos[uid])
+                    toks = torch.cat(
+                        [
+                            input_tokens_prefill_pt[uid : uid + 1, :plen],
+                            torch.zeros(1, max_padded_len - plen, dtype=torch.long),
+                        ],
+                        dim=-1,
+                    )
+                    tokens_list.append(toks)
+                    pt_list.append(page_table[uid : uid + 1, :max_num_blocks])
+                    last_idxs.append(plen - 1)
+                return (torch.cat(tokens_list, dim=0), torch.cat(pt_list, dim=0), last_idxs)
+
+            # --- Warmup (compilation) ---
+            logger.info("Starting traced row-parallel prefill warmup (compilation)...")
+            warmup_indices = [
+                row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
+            ]
+            tokens_w, pt_w, last_w = _prepare_batch_host(warmup_indices)
+            tokens_w = tokens_w.reshape(num_rows, -1)  # [num_rows, N*S] for batch>1 concat
+
+            host_out = model[model_id].prepare_inputs_prefill(
+                tokens_w, page_table=pt_w, trace_enabled=True, batched_prefill=True
+            )
+            rot_global = host_out[1]  # device-resident, fixed across iterations
+            rot_local = host_out[2]  # None
+            host_inputs = (host_out[0], host_out[3], host_out[4])  # tokens, pt, cpt
+
+            profiler.start(f"compile_prefill", iteration=batch_idx)
+            dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
+            transformed = model[model_id].transform_and_embed_prefill_inputs_device(*dev_inputs)
+            tt_logits = model[model_id].ttnn_prefill_forward(
+                transformed[0],
+                rot_mats_global=rot_global,
+                rot_mats_local=rot_local,
+                user_id=0,
+                page_table=transformed[1],
+                get_last_token=fixed_get_last_token,
+                kv_cache=tt_kv_cache[model_id],
+                batch_size=users_per_row_per_iter,
+            )
+
+            if fixed_get_last_token == -1:
+                warmup_results = model[model_id].process_output_prefill_batched(
+                    tt_logits,
+                    last_w,
+                    users_per_row=users_per_row_per_iter,
+                    seq_len_per_user=max_padded_len,
+                )
+            else:
+                warmup_results = model[model_id].process_output_prefill_batched(
+                    tt_logits,
+                    [idx % 32 for idx in last_w],
+                    users_per_row=users_per_row_per_iter,
+                    seq_len_per_user=32,
+                )
+            for row, uid in enumerate(warmup_indices):
+                prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
+            profiler.end(f"compile_prefill", iteration=batch_idx)
+            logger.info("Finished traced row-parallel prefill warmup")
+
+            # Clear KV caches (warmup wrote to them)
+            for i in range(len(model)):
+                for layer_obj in model[i].layers:
+                    k_cache, v_cache = layer_obj.self_attn.layer_past
+                    ttnn.mul(k_cache, 0, output_tensor=k_cache)
+                    ttnn.mul(v_cache, 0, output_tensor=v_cache)
+
+            # --- Trace capture ---
+            logger.info("Capturing prefill trace...")
+            iter0_indices = [
+                row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
+            ]
+            tokens_0, pt_0, last_0 = _prepare_batch_host(iter0_indices)
+            tokens_0 = tokens_0.reshape(num_rows, -1)
+            host_out = model[model_id].prepare_inputs_prefill(
+                tokens_0, page_table=pt_0, trace_enabled=True, batched_prefill=True
+            )
+            host_inputs = (host_out[0], host_out[3], host_out[4])
+
+            trace_dev_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
+            trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
+            # Embed tokens on-device inside the trace (without deallocating input buffer,
+            # since we need to update it between trace executions)
+            tokens_embd = ttnn.embedding(
+                trace_dev_inputs[0],
+                model[model_id].embedding_weight,
+                layout=ttnn.TILE_LAYOUT,
+                dtype=ttnn.bfloat8_b,
+            )
+            if len(tokens_embd.shape) == 3:
+                tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)
+            tt_out_trace = model[model_id].ttnn_prefill_forward(
+                tokens_embd,
+                rot_mats_global=rot_global,
+                rot_mats_local=rot_local,
+                user_id=0,
+                page_table=trace_dev_inputs[1],
+                get_last_token=fixed_get_last_token,
+                kv_cache=tt_kv_cache[model_id],
+                batch_size=users_per_row_per_iter,
+            )
+            ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
+            logger.info("Prefill trace captured")
+
+            # --- Execute trace for all iterations ---
+            logger.info(
+                f"Starting traced row-parallel prefill ({num_prefill_iters} iters, "
+                f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..."
+            )
+            profiler.start(f"inference_prefill", iteration=batch_idx)
+            for iter_idx in range(num_prefill_iters):
+                user_indices = [
+                    row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
+                    for row in range(num_rows)
+                    for u in range(users_per_row_per_iter)
+                ]
+                tokens_i, pt_i, last_i = _prepare_batch_host(user_indices)
+                tokens_i = tokens_i.reshape(num_rows, -1)
+                host_out = model[model_id].prepare_inputs_prefill(
+                    tokens_i, page_table=pt_i, trace_enabled=True, batched_prefill=True
+                )
+                host_inputs = (host_out[0], host_out[3], host_out[4])
+                copy_host_to_device(host_inputs, device_tensors=trace_dev_inputs, mesh_device=mesh_device)
+                ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
+
+                if fixed_get_last_token == -1:
+                    row_results = model[model_id].process_output_prefill_batched(
+                        tt_out_trace,
+                        last_i,
+                        users_per_row=users_per_row_per_iter,
+                        seq_len_per_user=max_padded_len,
+                    )
+                else:
+                    row_results = model[model_id].process_output_prefill_batched(
+                        tt_out_trace,
+                        [idx % 32 for idx in last_i],
+                        users_per_row=users_per_row_per_iter,
+                        seq_len_per_user=32,
+                    )
+                for row, uid in enumerate(user_indices):
+                    prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
+                if iter_idx % 8 == 0:
+                    logger.info(f" Traced prefill batch {iter_idx+1}/{num_prefill_iters}")
+            profiler.end(f"inference_prefill", iteration=batch_idx)
+
+            ttnn.release_trace(mesh_device, trace_id)
+            logger.info(f"Traced row-parallel prefill finished ({num_prefill_iters} iterations)")
+
+        else:
+            # === NON-TRACED BATCHED PREFILL ===
+
+            # Helper to run one batched prefill iteration
+            def _run_batched_prefill_iter(iter_idx, user_indices):
+                batch_tokens_list = []
+                batch_page_tables = []
+                batch_last_token_idxs = []
+
+                for uid in user_indices:
+                    prefill_len = int(decoding_pos[uid])
+                    padded_len = get_padded_prefill_len(prefill_len)
+                    user_tokens = torch.cat(
+                        [
+                            input_tokens_prefill_pt[uid : uid + 1, :prefill_len],
+                            torch.zeros(1, padded_len - prefill_len, dtype=torch.long),
+                        ],
+                        dim=-1,
+                    )
+                    batch_tokens_list.append(user_tokens)
+                    block_size = page_params["page_block_size"]
+                    num_blocks_needed = (padded_len + block_size - 1) // block_size
+                    batch_page_tables.append(page_table[uid : uid + 1, :num_blocks_needed])
+                    batch_last_token_idxs.append(prefill_len - 1)
+
+                tokens_stacked = torch.cat(batch_tokens_list, dim=0)  # [total_users, padded_len]
+                page_table_stacked = torch.cat(batch_page_tables, dim=0)  # [total_users, num_blocks]
+                padded_len = tokens_stacked.shape[1]
+
+                # Reshape tokens for batch>1: concatenate per-row users along seq dim
+                tokens_for_model = tokens_stacked.reshape(num_rows, -1)  # [num_rows, N*padded_len]
+
+                (tokens_embd, rot_mats_global, rot_mats_local, page_table_tt, _) = model[
+                    model_id
+                ].prepare_inputs_prefill(
+                    tokens_for_model,
+                    page_table=page_table_stacked,
+                    batched_prefill=True,
+                )
+
+                # Use get_last_token if all users' last tokens fall in the same 32-token tile
+                min_tile = (min(batch_last_token_idxs) // 32) * 32
+                max_tile = (max(batch_last_token_idxs) // 32) * 32
+                get_last_token_val = min_tile if min_tile == max_tile else -1
+                tt_logits = model[model_id].ttnn_prefill_forward(
+                    tokens_embd,
+                    rot_mats_global=rot_mats_global,
+                    rot_mats_local=rot_mats_local,
+                    user_id=0,  # Must be 0: each device sees page_table[0] after row-sharding
+                    page_table=page_table_tt,
+                    get_last_token=get_last_token_val,
+                    kv_cache=tt_kv_cache[model_id],
+                    batch_size=users_per_row_per_iter,
+                )
+
+                if get_last_token_val == -1:
+                    adjusted_last_idxs = batch_last_token_idxs
+                    seq_len_for_output = padded_len
+                else:
+                    adjusted_last_idxs = [idx % 32 for idx in batch_last_token_idxs]
+                    seq_len_for_output = 32
+                row_results = model[model_id].process_output_prefill_batched(
+                    tt_logits,
+                    adjusted_last_idxs,
+                    users_per_row=users_per_row_per_iter,
+                    seq_len_per_user=seq_len_for_output,
+                )
+                return row_results
+
+            # Warmup: compile with first batch
+            logger.info("Starting row-parallel prefill warmup...")
+            profiler.start(f"compile_prefill", iteration=batch_idx)
+            warmup_user_indices = [
+                row * users_per_row_prefill + u for row in range(num_rows) for u in range(users_per_row_per_iter)
+            ]
+            warmup_results = _run_batched_prefill_iter(0, warmup_user_indices)
+            for row, uid in enumerate(warmup_user_indices):
+                prefilled_token[uid] = torch.argmax(warmup_results[row].view(-1)).item()
+            profiler.end(f"compile_prefill", iteration=batch_idx)
+            logger.info("Finished row-parallel prefill warmup")
+
+            # Clear KV caches before real prefill (warmup wrote to them)
+            for i in range(len(model)):
+                for layer_obj in model[i].layers:
+                    k_cache, v_cache = layer_obj.self_attn.layer_past
+                    ttnn.mul(k_cache, 0, output_tensor=k_cache)
+                    ttnn.mul(v_cache, 0, output_tensor=v_cache)
+
+            # Real prefill
+            logger.info(
+                f"Starting row-parallel batched prefill ({num_prefill_iters} iters, "
+                f"{users_per_row_per_iter} user/row/iter, {global_batch_size} users)..."
+            )
+            profiler.start(f"inference_prefill", iteration=batch_idx)
+            for iter_idx in range(num_prefill_iters):
+                user_indices = [
+                    row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
+                    for row in range(num_rows)
+                    for u in range(users_per_row_per_iter)
+                ]
+                row_results = _run_batched_prefill_iter(iter_idx, user_indices)
+                for row, uid in enumerate(user_indices):
+                    prefilled_token[uid] = torch.argmax(row_results[row].view(-1)).item()
+                if iter_idx % 8 == 0:
+                    logger.info(f" Prefilled batch {iter_idx+1}/{num_prefill_iters}")
+            profiler.end(f"inference_prefill", iteration=batch_idx)
+            logger.info(f"Row-parallel batched prefill finished ({num_prefill_iters} iterations)")
 
-        logger.info(f"Starting prefill...")
-        profiler.start(f"inference_prefill", iteration=batch_idx)
-        logits = generator.prefill_forward_text(
-            input_tokens_prefill_pt,
-            page_table=page_table,
-            kv_cache=tt_kv_cache,
-            prompt_lens=decoding_pos,
-            enable_trace=enable_prefill_trace,
-            warmup_prefill=False,  # we can warmup prefill ourselves above if we want to
-        )
-        prefilled_token = torch.argmax(logits, dim=-1)
-        profiler.end(f"inference_prefill", iteration=batch_idx)
-        logger.info(f"Prefill finished")
     logger.info(f"First generated token: '{tokenizer.decode(prefilled_token[0])}'")
 
     # Initialize generation state like tt_transformers
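The row-parallel batching above maps users to mesh rows with `row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u`, so each prefill iteration takes one slice of users spanning every row of the mesh. A small standalone illustration of that mapping, with hypothetical sizes (8 users, 4 mesh rows, 1 user per row per iteration; the demo derives these from `mesh_device.shape[0]` and the global batch size):

```python
# Illustration of the user-to-row index mapping used by the batched prefill loops above.
# Sizes are hypothetical for the sake of the example.
global_batch_size = 8
num_rows = 4
users_per_row_prefill = global_batch_size // num_rows  # 2 users assigned to each row in total
users_per_row_per_iter = 1  # 1 user per row per prefill iteration
num_prefill_iters = users_per_row_prefill // users_per_row_per_iter  # 2 iterations

for iter_idx in range(num_prefill_iters):
    user_indices = [
        row * users_per_row_prefill + iter_idx * users_per_row_per_iter + u
        for row in range(num_rows)
        for u in range(users_per_row_per_iter)
    ]
    print(iter_idx, user_indices)

# Output:
#   0 [0, 2, 4, 6]   <- one user from each mesh row, prefetched in parallel
#   1 [1, 3, 5, 7]
```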

models/demos/gpt_oss/tt/attention/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -108,7 +108,15 @@ def __init__(
         self.scaling = config.scaling
 
     def __call__(
-        self, hidden_states, rope_mats, position_idx=None, page_table=None, kv_cache=None, is_decode=True, user_id=0
+        self,
+        hidden_states,
+        rope_mats,
+        position_idx=None,
+        page_table=None,
+        kv_cache=None,
+        is_decode=True,
+        user_id=0,
+        batch_size=1,
     ):
         """
         Forward pass - automatically dispatches to decode or prefill.
@@ -169,4 +177,5 @@ def __call__(
             position_idx=position_idx,
             page_table=page_table,
             ccl_manager=self.ccl_manager,
+            batch_size=batch_size,
         )
