Merged
50 changes: 34 additions & 16 deletions models/demos/gpt_oss/demo/text_demo.py
@@ -236,7 +236,7 @@ def prepare_gpt_oss_generator_args(
1, # data_parallel
128, # batch_size
1, # repeat_batches
8 * 1024, # max_seq_len
128 * 1024, # max_seq_len
200, # max_generated_tokens
{"page_block_size": 64, "page_max_num_blocks_per_dp": 128 * 1024 // 64}, # page_params
{"temperature": 0, "top_p": 0.08}, # sampling_params (greedy decoding)
@@ -827,29 +827,47 @@ def test_gpt_oss_demo(
if f"batch_{batch_size}" in perf_targets["ci"]:
if f"prefill_{prefill_pad_length}" in perf_targets["ci"][f"batch_{batch_size}"]:
if model_device_key in perf_targets["ci"][f"batch_{batch_size}"][f"prefill_{prefill_pad_length}"]:
current_ttft_target = perf_targets["ci"][f"batch_{batch_size}"][f"prefill_{prefill_pad_length}"][
perf_config = perf_targets["ci"][f"batch_{batch_size}"][f"prefill_{prefill_pad_length}"][
model_device_key
]["TTFT"]
]

# Parse TTFT target with tolerance
current_ttft_target = perf_config["TTFT"]
if isinstance(current_ttft_target, list):
high_tol_percentage = current_ttft_target[1]
ttft_tolerance = current_ttft_target[1]
current_ttft_target = current_ttft_target[0]
else:
high_tol_percentage = 1.15
ci_targets = {
ttft_tolerance = 1.15 # Default 15% tolerance

# Parse decode_tok_s_u target with tolerance
decode_tsu_target = perf_config["decode_tok_s_u"]
if isinstance(decode_tsu_target, list):
decode_tolerance = decode_tsu_target[1]
decode_tsu_target = decode_tsu_target[0]
else:
decode_tolerance = 1.15 # Default 15% tolerance

# Verify prefill performance with prefill-specific tolerance
prefill_targets = {
"prefill_time_to_token": current_ttft_target / 1000, # convert to seconds
"decode_t/s/u": perf_targets["ci"][f"batch_{batch_size}"][f"prefill_{prefill_pad_length}"][
model_device_key
]["decode_tok_s_u"],
"decode_t/s": perf_targets["ci"][f"batch_{batch_size}"][f"prefill_{prefill_pad_length}"][
model_device_key
]["decode_tok_s_u"]
* global_batch_size, # calculate from per-user rate
}
verify_perf(
measurements,
ci_targets,
high_tol_percentage=high_tol_percentage,
expected_measurements={k: True for k in ci_targets.keys()},
prefill_targets,
high_tol_percentage=ttft_tolerance,
expected_measurements={k: True for k in prefill_targets.keys()},
)

# Verify decode performance with decode-specific tolerance
decode_targets = {
"decode_t/s/u": decode_tsu_target,
"decode_t/s": decode_tsu_target * global_batch_size, # calculate from per-user rate
}
verify_perf(
measurements,
decode_targets,
high_tol_percentage=decode_tolerance,
expected_measurements={k: True for k in decode_targets.keys()},
)
else:
logger.warning(
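The updated test accepts either a bare number or a [target, tolerance] pair for both TTFT and decode_tok_s_u, falling back to a 15% tolerance for scalars. A standalone sketch of that parsing pattern (the helper name is hypothetical, not part of the demo code):

```python
from typing import Tuple, Union

def parse_target(raw: Union[int, float, list], default_tolerance: float = 1.15) -> Tuple[float, float]:
    """Return (target, tolerance) from either a scalar or a [target, tolerance] pair."""
    if isinstance(raw, list):
        return float(raw[0]), float(raw[1])
    return float(raw), default_tolerance

# Mirrors the diff: TTFT given as [400, 1.2], decode_tok_s_u given as a bare 12.
ttft_ms, ttft_tol = parse_target([400, 1.2])   # (400.0, 1.2)
tsu, tsu_tol = parse_target(12)                # (12.0, 1.15), default tolerance applied
```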
20 changes: 10 additions & 10 deletions models/demos/gpt_oss/perf_targets.json
@@ -30,37 +30,37 @@
"prefill_128": {
"T3K_gpt-oss-20b": {
"TTFT": [400, 1.2],
"decode_tok_s_u": 12,
"decode_tok_s_u": [12, 1.15],
"decode_tok_s": 2200
},
"T3K_gpt-oss-120b": {
"TTFT": [1700, 1.2],
"decode_tok_s_u": 8,
"decode_tok_s_u": [8, 1.15],
"decode_tok_s": 2200
},
"GLX_gpt-oss-20b": {
"TTFT": [500, 1.2],
"decode_tok_s_u": 10,
"decode_tok_s_u": [10, 1.15],
"decode_tok_s": 2200
},
"GLX_gpt-oss-120b": {
"TTFT": [1300, 1.2],
"decode_tok_s_u": 6.5,
"decode_tok_s_u": [6.5, 1.15],
"decode_tok_s": 2200
}
}
},
"batch_128": {
"prefill_128": {
"GLX_gpt-oss-20b": {
"TTFT": [250, 1.2],
"decode_tok_s_u": 4.5,
"decode_tok_s": 600
"TTFT": [250, 2.0],
"decode_tok_s_u": [5.0, 2.0],
"decode_tok_s": 650
},
"GLX_gpt-oss-120b": {
"TTFT": [650, 1.2],
"decode_tok_s_u": 3.0,
"decode_tok_s": 400
"TTFT": [400, 2.0],
"decode_tok_s_u": [3.0, 2.0],
"decode_tok_s": 380
}
}
}
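For reference, the targets above are keyed by batch size, prefill length, and model/device. A hedged sketch of how one entry resolves (path and values as shown in this diff; the snippet assumes it runs from the repo root):

```python
import json

# Illustrative lookup into the targets file shown above.
with open("models/demos/gpt_oss/perf_targets.json") as f:
    perf_targets = json.load(f)

entry = perf_targets["ci"]["batch_128"]["prefill_128"]["GLX_gpt-oss-20b"]
ttft_ms, ttft_tol = entry["TTFT"]               # [250, 2.0]: 250 ms target, 2.0x tolerance
tsu_target, tsu_tol = entry["decode_tok_s_u"]   # [5.0, 2.0]
decode_tok_s = entry["decode_tok_s"]            # 650 aggregate tokens/s target
```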
4 changes: 3 additions & 1 deletion models/demos/gpt_oss/tt/experts_throughput/config.py
@@ -88,7 +88,7 @@ class AllToAllDispatchConfig:
num_links: int = 1
topology: ttnn.Topology = field(default_factory=lambda: ttnn.Topology.Ring)
subdevice_id: Optional[int] = None
output_concat_dim: Optional[int] = 1 # 1 for decode 2 for prefill
output_concat_dim: Optional[int] = 2 # 2 for tokens on seq_len dim (decode and prefill)

def as_dict(self):
"""Convert to kwargs dict for ttnn.all_to_all_dispatch."""
@@ -115,6 +115,7 @@ class AllToAllCombineConfig:
memory_config: ttnn.MemoryConfig = field(default_factory=lambda: ttnn.L1_MEMORY_CONFIG)
num_links: int = 1
topology: ttnn.Topology = field(default_factory=lambda: ttnn.Topology.Ring)
output_shard_dim: int = 2 # 1 for batch dim, 2 for seq_len dim (prefer 2 for decode)

def as_dict(self):
"""Convert to kwargs dict for ttnn.all_to_all_combine."""
@@ -123,6 +124,7 @@ def as_dict(self):
"memory_config": self.memory_config,
"num_links": self.num_links,
"topology": self.topology,
"output_shard_dim": self.output_shard_dim,
}


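The config change adds output_shard_dim to the kwargs that as_dict() forwards to the combine op. A simplified, dependency-free mirror of that pattern (a sketch only: ttnn-specific fields are replaced with plain values, and the class name is hypothetical):

```python
from dataclasses import dataclass

@dataclass
class CombineConfigSketch:
    """Stand-in for AllToAllCombineConfig with ttnn types omitted."""
    num_links: int = 1
    topology: str = "ring"
    output_shard_dim: int = 2  # 1 = shard on batch dim, 2 = shard on seq_len dim (decode prefers 2)

    def as_dict(self):
        # Convert to kwargs for the combine call, mirroring the diff.
        return {
            "num_links": self.num_links,
            "topology": self.topology,
            "output_shard_dim": self.output_shard_dim,
        }

print(CombineConfigSketch().as_dict())  # kwargs would then be splatted into the op, e.g. op(**cfg.as_dict())
```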
111 changes: 55 additions & 56 deletions models/demos/gpt_oss/tt/experts_throughput/decode.py
@@ -112,60 +112,68 @@ def decode_forward(
Returns:
Output tensor [batch_size_per_device, 1, seq_len, hidden_size]
Copilot AI commented on Jan 29, 2026:

The docstring states the output tensor has shape [batch_size_per_device, 1, seq_len, hidden_size], but the actual implementation returns [1, 1, tokens_per_device, hidden_size], where tokens_per_device = batch_size_per_device * seq_len. The docstring should be updated to reflect the actual output shape, or the implementation should reshape the output back to match the documented shape if that is what callers expect.

Suggested change:
Output tensor [batch_size_per_device, 1, seq_len, hidden_size]
Output tensor [1, 1, tokens_per_device, hidden_size], where
    tokens_per_device = batch_size_per_device * seq_len.
"""
# Note: reshape returns views - don't deallocate originals
hidden_states = ttnn.reshape(hidden_states, (-1, 1, 1, config.hidden_size))
# ==========================================================================
# STEP 0: RESHAPE TO PUT TOKENS ON DIM -2 (seq_len dimension)
# ==========================================================================
# This optimization reduces reshapes by keeping tokens on seq_len dim throughout.
# Input typically comes as [B, 1, S, H] where B*S = total tokens per device.
# We reshape to [1, 1, tokens_per_device, H] so tokens are on dim -2.

# Get total tokens per device
input_shape = hidden_states.shape
tokens_per_device = input_shape[0] * input_shape[2] # B * S

# Reshape hidden states: put all tokens on dim -2
hidden_states = ttnn.reshape(hidden_states, (1, 1, tokens_per_device, config.hidden_size))

# typecast creates new tensors - safe to deallocate originals
topk_expert_indices_orig = topk_expert_indices
topk_expert_indices = ttnn.typecast(topk_expert_indices, dtype=ttnn.uint32)
ttnn.deallocate(topk_expert_indices_orig)

topk_expert_indices = ttnn.reshape(topk_expert_indices, (-1, 1, 1, config.num_experts_per_tok))
# Reshape indices: put all tokens on dim -2
topk_expert_indices = ttnn.reshape(topk_expert_indices, (1, 1, tokens_per_device, config.num_experts_per_tok))
topk_expert_indices_u32 = topk_expert_indices
topk_expert_indices = ttnn.typecast(topk_expert_indices, dtype=ttnn.uint16)
ttnn.deallocate(topk_expert_indices_u32)
topk_expert_weights = ttnn.reshape(topk_expert_weights, (-1, 1, 1, config.num_experts_per_tok))
topk_expert_weights = ttnn.reshape(topk_expert_weights, (1, 1, tokens_per_device, config.num_experts_per_tok))

seq_len = 1 # Decode mode always has seq_len=1
batch_size_per_device = hidden_states.shape[0]
num_dispatch_devices = (
mesh_device.shape[dispatch_config.cluster_axis]
if dispatch_config.cluster_axis is not None
else prod(mesh_device.shape)
)
batch_size = batch_size_per_device * num_dispatch_devices # Global batch across dispatch axis
total_tokens = tokens_per_device * num_dispatch_devices # Global tokens across dispatch axis

# ==========================================================================
# STEP 1: PREPARE INPUTS FOR ALL_TO_ALL_DISPATCH
# ==========================================================================
# all_to_all_dispatch requires ROW_MAJOR layout with shape [B, 1, S, H]
# Convert from TILE layout used by transformer layers
# With tokens on dim -2: [1, 1, tokens_per_device, H]
# to_layout creates new tensors - safe to deallocate originals
hidden_rm = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT)
ttnn.deallocate(hidden_states)
hidden_rm = ttnn.reshape(hidden_rm, shape=(batch_size_per_device, 1, seq_len, config.hidden_size))
# Shape is already [1, 1, tokens_per_device, H], just ensure it's correct
hidden_rm = ttnn.reshape(hidden_rm, shape=(1, 1, tokens_per_device, config.hidden_size))

# Expert indices need to be in ROW_MAJOR with shape [B, 1, S, K]
# where K = num_experts_per_tok (top-k experts selected per token)
# Expert indices: [1, 1, tokens_per_device, K]
topk_indices_rm = ttnn.to_layout(topk_expert_indices, ttnn.ROW_MAJOR_LAYOUT)
ttnn.deallocate(topk_expert_indices)
topk_indices_rm = ttnn.reshape(
topk_indices_rm, shape=(batch_size_per_device, 1, seq_len, config.num_experts_per_tok)
)
topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))

# ==========================================================================
# STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
# ==========================================================================
# Dispatch sends each token to the device(s) that own its assigned expert(s)
#
# Inputs:
# - hidden_rm: [B_per_device, 1, S, H] - token embeddings
# - topk_indices_rm: [B_per_device, 1, S, K] - which experts each token routes to
# Inputs (tokens on dim -2):
# - hidden_rm: [1, 1, tokens_per_device, H] - token embeddings
# - topk_indices_rm: [1, 1, tokens_per_device, K] - which experts each token routes to
# - expert_mapping_tensors: [1, 1, E, D] - one-hot mapping of expert -> device
#
# Outputs:
# - dispatch_output: [D, B_global, S, H] - tokens scattered to expert devices
# - dispatch_metadata: [D, B_global, S, K] - expert indices (for combine routing)
# With output_concat_dim=2, outputs have seq_len scaled:
# - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
# - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
hidden_rm,
topk_indices_rm,
@@ -181,12 +189,16 @@
# Converts global expert indices to local (per-device) indices and creates
# a sparsity mask for efficient sparse matmul.
#
# The remap_topk_mask is broadcast across batch dimension
# The remap_topk_mask is broadcast across the token dimension (now on dim -2)
# repeat creates a new tensor - safe to deallocate, but remap_topk_mask is reused externally
remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, batch_size_per_device, 1, 1)))
# remap_topk_mask: [1, dispatch_rows, 1, num_experts]
# -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
# -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
# moe_expert_token_remap returns:
# - mapping: [D, B, S, experts_per_device] - local expert activation weights
# - sparsity: [D, 1, B*S/reduction_size, experts_per_device] - which blocks are active
# - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
# - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
#
# The sparsity tensor tells sparse_matmul which expert blocks have tokens,
# avoiding computation on empty slots.
@@ -202,21 +214,20 @@
# STEP 4: PREPARE DISPATCH OUTPUT FOR EXPERT COMPUTATION
# ==========================================================================
# Reshape dispatch output for sparse matmul:
# From: [D, B, S, H] (ROW_MAJOR from dispatch)
# To: [1, B*S/block_size, block_size, H] (TILE for matmul)
# From: [D, 1, total_tokens, H] (ROW_MAJOR from dispatch with tokens on dim -2)
# To: [1, total_tokens/block_size, block_size, H] (TILE for matmul)
#
# The sparse matmul operates on blocks of tokens, with sparsity indicating
# which (token_block, expert) pairs need computation.
# Note: reshape returns view, but to_layout creates new tensor
post_dispatch = ttnn.reshape(dispatch_output, shape=(1, 1, batch_size * seq_len, config.hidden_size))
post_dispatch = ttnn.reshape(dispatch_output, shape=(1, 1, total_tokens, config.hidden_size))
post_dispatch_rm = post_dispatch
post_dispatch = ttnn.to_layout(post_dispatch, ttnn.TILE_LAYOUT)
ttnn.deallocate(post_dispatch_rm) # This deallocates dispatch_output via the view

# Reshape to sparse block format for matmul
# Note: reshape returns a view - don't deallocate post_dispatch separately
num_tokens = batch_size * seq_len
num_sparse_blocks = num_tokens // config.sparsity_block_size
num_sparse_blocks = total_tokens // config.sparsity_block_size
expert_input = ttnn.reshape(
post_dispatch,
shape=(1, num_sparse_blocks, config.sparsity_block_size, config.hidden_size),
@@ -306,17 +317,18 @@ def decode_forward(
expert_output_sparse = ttnn.squeeze(expert_output_sparse, 0)

# Reshape from sparse matmul output to format expected by combine:
# From: [B*S/block, experts, block, H]
# To: [experts_per_device, B_global, S, H] (ROW_MAJOR)
# From: [total_tokens/block, experts, block, H]
# To: [experts_per_device, 1, total_tokens, H] (ROW_MAJOR, tokens on dim -2)
#
# Permute to get experts_per_device as first dimension (what combine expects)
# permute creates a new tensor - safe to deallocate original
expert_output = ttnn.permute(expert_output_sparse, (1, 0, 2, 3))
ttnn.deallocate(expert_output_sparse)
# Note: reshape returns a view, to_layout creates new tensor
# With tokens on dim -2: [experts_per_device, 1, total_tokens, H]
expert_output = ttnn.reshape(
expert_output,
shape=(config.num_experts_per_device, batch_size, seq_len, config.hidden_size),
shape=(config.num_experts_per_device, 1, total_tokens, config.hidden_size),
)
expert_output_tiled = expert_output
expert_output = ttnn.to_layout(expert_output, ttnn.ROW_MAJOR_LAYOUT)
@@ -328,7 +340,8 @@
# Combine routes each expert output back to the device that owns the original token.
# Uses dispatch_metadata to know which token each output corresponds to.
#
# Output shape: [num_experts_per_tok, B_per_device, S, H]
# With output_shard_dim=2, output has tokens sharded on dim -2:
# Output shape: [num_experts_per_tok, 1, tokens_per_device, H]
# (each token gets outputs from k experts stacked in first dimension)
combine_output = ttnn.all_to_all_combine(
expert_output,
@@ -342,33 +355,19 @@
# ==========================================================================
# STEP 8: APPLY ROUTING WEIGHTS AND REDUCE ACROSS EXPERTS
# ==========================================================================
# Reshape combine output for weighted sum:
# Shape: [K, 1, B_per_device * S, H] where K = num_experts_per_tok
# Note: reshape returns view, to_layout creates new tensor
post_combine = ttnn.reshape(
combine_output,
shape=(config.num_experts_per_tok, 1, batch_size_per_device * seq_len, config.hidden_size),
)
post_combine_rm = post_combine
post_combine = ttnn.to_layout(post_combine, ttnn.TILE_LAYOUT)
ttnn.deallocate(post_combine_rm) # Deallocates combine_output via its reshape view
# Combine output already has tokens on dim -2: [K, 1, tokens_per_device, H]
# No reshape needed! Just convert to TILE layout.
post_combine = ttnn.to_layout(combine_output, ttnn.TILE_LAYOUT)
ttnn.deallocate(combine_output)

# Prepare routing weights for broadcasting:
# From: [B, 1, S, K] (original topk weights)
# To: [K, 1, B*S, H] (matches post_combine for element-wise multiply)
#
# Steps:
# 1. Repeat along hidden_size dimension
# 2. Permute to [K, 1, B*S, H]
# topk_expert_weights is [1, 1, tokens_per_device, K] (tokens on dim -2)
# We want [K, 1, tokens_per_device, 1] so it can broadcast across hidden_size.
# to_layout creates new tensor - safe to deallocate original
topk_weights_rm = ttnn.to_layout(topk_expert_weights, ttnn.ROW_MAJOR_LAYOUT)
ttnn.deallocate(topk_expert_weights)
# repeat creates new tensor - safe to deallocate original
topk_weights_repeated = ttnn.repeat(topk_weights_rm, ttnn.Shape((1, 1, config.hidden_size, 1)))
ttnn.deallocate(topk_weights_rm)
# permute creates new tensor - safe to deallocate original
topk_weights_rm = ttnn.permute(topk_weights_repeated, (3, 1, 0, 2))
ttnn.deallocate(topk_weights_repeated)
# permute to [K, 1, tokens_per_device, 1]
topk_weights_rm = ttnn.permute(topk_weights_rm, (3, 1, 2, 0))
topk_weights_reshaped = ttnn.to_layout(topk_weights_rm, ttnn.TILE_LAYOUT)
ttnn.deallocate(topk_weights_rm)

@@ -397,5 +396,5 @@ def decode_forward(
)
ttnn.deallocate(output)

# Final shape: [1, 1, B_per_device * S, H]
# Final shape: [1, 1, tokens_per_device, H] (tokens on dim -2)
return output_all_reduced
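To make the "tokens on dim -2" bookkeeping above concrete, here is a small shape walkthrough in plain Python (numbers are illustrative only; the real tensors are ttnn tensors on a mesh device):

```python
# Illustrative shape arithmetic for the decode path above (no ttnn required).
batch_size_per_device = 32
seq_len = 1                          # decode mode
hidden_size = 2880                   # example value only
num_dispatch_devices = 4             # devices along the dispatch cluster axis
num_experts_per_tok = 4              # example value only (K)

tokens_per_device = batch_size_per_device * seq_len       # tokens land on dim -2 per device
total_tokens = tokens_per_device * num_dispatch_devices   # global tokens after dispatch

input_shape = (batch_size_per_device, 1, seq_len, hidden_size)          # [B, 1, S, H]
step0 = (1, 1, tokens_per_device, hidden_size)                          # STEP 0 reshape
step2 = (num_dispatch_devices, 1, total_tokens, hidden_size)            # STEP 2 dispatch, output_concat_dim=2
step7 = (num_experts_per_tok, 1, tokens_per_device, hidden_size)        # STEP 7 combine, output_shard_dim=2
print(input_shape, step0, step2, step7)
```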
3 changes: 1 addition & 2 deletions models/demos/gpt_oss/tt/mlp.py
@@ -59,8 +59,7 @@ def __init__(
state_dict=experts_state_dict,
weight_dtype=ttnn.bfloat4_b,
dispatch_cluster_axis=0,
# decode_memory_config=ttnn.L1_MEMORY_CONFIG,
decode_memory_config=ttnn.DRAM_MEMORY_CONFIG, ## Change this back to L1 when test runs
decode_memory_config=ttnn.L1_MEMORY_CONFIG, # L1 for better decode throughput
tensor_cache_path=get_cache_file_name(tensor_cache_path, "experts"),
)
else: