Skip to content

Commit 6de9ea4

Browse files
committed
fix
1 parent 7e9a700 commit 6de9ea4

File tree

2 files changed

+18
-25
lines changed

2 files changed

+18
-25
lines changed

models/demos/gpt_oss/tests/fused_op_unit_tests/test_gpt_oss_experts_mlp.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ def gpt_oss_experts_mlp_reference(
153153
Returns:
154154
Expert output tensor [num_experts_per_device, B, S, H]
155155
"""
156-
num_tokens = batch_size * seq_len
156+
total_tokens = batch_size * seq_len
157157
num_experts = w1.shape[0]
158158
hidden_size = config.hidden_size
159159
intermediate_size = config.intermediate_size
160160

161161
# Reshape input: [1, 1, B*S, H] -> [B*S, H]
162-
x = post_dispatch.reshape(num_tokens, hidden_size)
162+
x = post_dispatch.reshape(total_tokens, hidden_size)
163163

164164
# Expand for all experts: [B*S, H] -> [num_experts, B*S, H]
165165
x_expanded = x.unsqueeze(0).expand(num_experts, -1, -1)
@@ -196,8 +196,7 @@ def gpt_oss_experts_mlp_ttnn(
196196
config: ThroughputExpertConfig,
197197
program_config: ThroughputProgramConfig,
198198
memory_config: ttnn.MemoryConfig,
199-
batch_size: int,
200-
seq_len: int,
199+
total_tokens: int,
201200
mesh_device=None,
202201
save_intermediate: bool = False,
203202
) -> ttnn.Tensor:
@@ -227,8 +226,7 @@ def gpt_oss_experts_mlp_ttnn(
227226
config=config,
228227
program_config=program_config,
229228
memory_config=memory_config,
230-
batch_size=batch_size,
231-
seq_len=seq_len,
229+
total_tokens=total_tokens,
232230
mesh_device=mesh_device,
233231
save_intermediate=save_intermediate,
234232
)
@@ -482,13 +480,13 @@ def _run_experts_mlp_test(
482480

483481
# Create input tensor (post_dispatch output)
484482
# Shape: [1, 1, B*S, H]
485-
num_tokens = batch_size * seq_len
486-
post_dispatch_torch = torch.randn(1, 1, num_tokens, hidden_size, dtype=torch.bfloat16)
483+
total_tokens = batch_size * seq_len
484+
post_dispatch_torch = torch.randn(1, 1, total_tokens, hidden_size, dtype=torch.bfloat16)
487485

488486
# Create sparsity tensor - for reference we'll compute dense
489487
# In practice sparsity indicates which (token_block, expert) pairs are active
490488
# For this test, we'll assume all tokens are active for all experts (dense case)
491-
num_sparse_blocks = num_tokens // throughput_config.sparsity_block_size
489+
num_sparse_blocks = total_tokens // throughput_config.sparsity_block_size
492490
num_experts_per_device = throughput_config.num_experts_per_device
493491

494492
# Create full sparsity tensor (all ones = all active)
@@ -536,8 +534,7 @@ def _run_experts_mlp_test(
536534
config=throughput_config,
537535
program_config=program_config,
538536
memory_config=memory_config,
539-
batch_size=batch_size,
540-
seq_len=seq_len,
537+
total_tokens=total_tokens,
541538
mesh_device=mesh_device,
542539
)
543540

@@ -572,8 +569,7 @@ def op_fn():
572569
config=throughput_config,
573570
program_config=program_config,
574571
memory_config=memory_config,
575-
batch_size=batch_size,
576-
seq_len=seq_len,
572+
total_tokens=total_tokens,
577573
mesh_device=mesh_device,
578574
)
579575

@@ -819,11 +815,11 @@ def test_gpt_oss_experts_mlp_single_device(
819815
w2_ref = state_dict["down_proj"]
820816

821817
# Create input tensor
822-
num_tokens = batch_size_per_device * seq_len
823-
post_dispatch_torch = torch.randn(1, 1, num_tokens, hidden_size, dtype=torch.bfloat16)
818+
total_tokens = batch_size_per_device * seq_len
819+
post_dispatch_torch = torch.randn(1, 1, total_tokens, hidden_size, dtype=torch.bfloat16)
824820

825821
# Create sparsity tensor
826-
num_sparse_blocks = num_tokens // throughput_config.sparsity_block_size
822+
num_sparse_blocks = total_tokens // throughput_config.sparsity_block_size
827823
num_experts_per_device = throughput_config.num_experts_per_device
828824
sparsity_torch = torch.ones(num_sparse_blocks, 1, 1, num_experts_per_device, dtype=torch.bfloat16)
829825

models/demos/gpt_oss/tt/experts_throughput/decode.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ def expert_mlp_forward(
108108
config: ThroughputExpertConfig,
109109
program_config: ThroughputProgramConfig,
110110
memory_config: ttnn.MemoryConfig,
111-
batch_size: int,
112-
seq_len: int,
111+
total_tokens: int,
113112
mesh_device=None,
114113
save_intermediate: bool = False,
115114
) -> ttnn.Tensor:
@@ -139,8 +138,7 @@ def expert_mlp_forward(
139138
"""
140139
# Reshape to sparse block format for matmul
141140
# Note: reshape returns a view - don't deallocate post_dispatch separately
142-
num_tokens = batch_size * seq_len
143-
num_sparse_blocks = num_tokens // config.sparsity_block_size
141+
num_sparse_blocks = total_tokens // config.sparsity_block_size
144142
reshaped_expert_input = ttnn.reshape(
145143
experts_input,
146144
shape=(1, num_sparse_blocks, config.sparsity_block_size, config.hidden_size),
@@ -174,7 +172,7 @@ def expert_mlp_forward(
174172

175173
# Up projection (w3): same shape as gate
176174
w3_out = ttnn.sparse_matmul(
177-
expert_input,
175+
reshaped_expert_input,
178176
weights.w3,
179177
sparsity=sparsity,
180178
memory_config=memory_config,
@@ -183,7 +181,7 @@ def expert_mlp_forward(
183181
is_input_b_sparse=True,
184182
output_tile=ttnn.Tile([config.sparsity_block_size, ttnn.TILE_SIZE]),
185183
)
186-
ttnn.deallocate(expert_input)
184+
ttnn.deallocate(reshaped_expert_input)
187185

188186
# Add up bias
189187
# w3_out shape: [1, num_sparse_blocks, 1, num_experts_per_device, block_size, intermediate]
@@ -306,7 +304,7 @@ def decode_forward(
306304
topk_expert_indices = ttnn.typecast(topk_expert_indices, dtype=ttnn.uint32)
307305
topk_expert_indices = ttnn.reshape(topk_expert_indices, (-1, 1, 1, config.num_experts_per_tok))
308306
topk_expert_indices = ttnn.typecast(topk_expert_indices, dtype=ttnn.uint16)
309-
307+
310308
topk_expert_weights = ttnn.reshape(topk_expert_weights, (-1, 1, 1, config.num_experts_per_tok))
311309

312310
num_dispatch_devices = (
@@ -410,8 +408,7 @@ def decode_forward(
410408
config=config,
411409
program_config=program_config,
412410
memory_config=dispatch_config.memory_config,
413-
batch_size=batch_size,
414-
seq_len=seq_len,
411+
total_tokens=total_tokens,
415412
mesh_device=mesh_device,
416413
save_intermediate=False,
417414
)

0 commit comments

Comments (0)