@@ -153,13 +153,13 @@ def gpt_oss_experts_mlp_reference(
     Returns:
         Expert output tensor [num_experts_per_device, B, S, H]
     """
-    num_tokens = batch_size * seq_len
+    total_tokens = batch_size * seq_len
     num_experts = w1.shape[0]
     hidden_size = config.hidden_size
     intermediate_size = config.intermediate_size
 
     # Reshape input: [1, 1, B*S, H] -> [B*S, H]
-    x = post_dispatch.reshape(num_tokens, hidden_size)
+    x = post_dispatch.reshape(total_tokens, hidden_size)
 
     # Expand for all experts: [B*S, H] -> [num_experts, B*S, H]
     x_expanded = x.unsqueeze(0).expand(num_experts, -1, -1)
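
For orientation, here is a minimal sketch of the dense reference path these hunks touch, assuming per-expert weights shaped w1: [E, H, I] and w2: [E, I, H] and a SiLU activation (the weight layout and activation are assumptions; only the reshape/expand lines are visible in the hunk):

```python
import torch

def dense_experts_sketch(post_dispatch, w1, w2, hidden_size):
    # post_dispatch: [1, 1, B*S, H] -> flatten to [B*S, H]
    total_tokens = post_dispatch.shape[-2]
    x = post_dispatch.reshape(total_tokens, hidden_size)
    # Broadcast every token to every expert: [E, B*S, H]
    x_expanded = x.unsqueeze(0).expand(w1.shape[0], -1, -1)
    # Per-expert batched matmuls; the activation choice is an assumption
    h = torch.nn.functional.silu(torch.bmm(x_expanded, w1))  # [E, B*S, I]
    return torch.bmm(h, w2)                                  # [E, B*S, H]
```

The rename to total_tokens makes it explicit that the reference only ever sees the flattened B*S dimension.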
@@ -196,8 +196,7 @@ def gpt_oss_experts_mlp_ttnn(
     config: ThroughputExpertConfig,
     program_config: ThroughputProgramConfig,
     memory_config: ttnn.MemoryConfig,
-    batch_size: int,
-    seq_len: int,
+    total_tokens: int,
     mesh_device=None,
     save_intermediate: bool = False,
 ) -> ttnn.Tensor:
@@ -227,8 +226,7 @@ def gpt_oss_experts_mlp_ttnn(
         config=config,
         program_config=program_config,
         memory_config=memory_config,
-        batch_size=batch_size,
-        seq_len=seq_len,
+        total_tokens=total_tokens,
         mesh_device=mesh_device,
         save_intermediate=save_intermediate,
     )
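
Caller-side, the signature change means only the flattened token count crosses the API boundary. A hypothetical invocation under the new signature (the leading device tensors and the concrete shape values are placeholders; only the keyword arguments appear in this diff):

```python
batch_size, seq_len = 1, 128              # placeholder shapes
total_tokens = batch_size * seq_len       # the op no longer needs B and S separately

tt_output = gpt_oss_experts_mlp_ttnn(
    post_dispatch_tt,                     # placeholder device tensors, not shown in this hunk
    w1_tt,
    w2_tt,
    config=throughput_config,
    program_config=program_config,
    memory_config=memory_config,
    total_tokens=total_tokens,
    mesh_device=mesh_device,
)
```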
@@ -482,13 +480,13 @@ def _run_experts_mlp_test(
 
     # Create input tensor (post_dispatch output)
     # Shape: [1, 1, B*S, H]
-    num_tokens = batch_size * seq_len
-    post_dispatch_torch = torch.randn(1, 1, num_tokens, hidden_size, dtype=torch.bfloat16)
+    total_tokens = batch_size * seq_len
+    post_dispatch_torch = torch.randn(1, 1, total_tokens, hidden_size, dtype=torch.bfloat16)
 
     # Create sparsity tensor - for reference we'll compute dense
     # In practice sparsity indicates which (token_block, expert) pairs are active
     # For this test, we'll assume all tokens are active for all experts (dense case)
-    num_sparse_blocks = num_tokens // throughput_config.sparsity_block_size
+    num_sparse_blocks = total_tokens // throughput_config.sparsity_block_size
     num_experts_per_device = throughput_config.num_experts_per_device
 
     # Create full sparsity tensor (all ones = all active)
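
The dense ("all active") sparsity tensor the test builds is small enough to sketch in full. The shape matches the creation line further down in this diff; the block size, expert count, and batch/sequence values below are illustrative:

```python
import torch

batch_size, seq_len = 4, 256               # illustrative values
sparsity_block_size = 32                   # illustrative; the test reads this from throughput_config
num_experts_per_device = 8                 # illustrative; also from throughput_config

total_tokens = batch_size * seq_len                       # 1024
num_sparse_blocks = total_tokens // sparsity_block_size   # 1024 // 32 = 32 token blocks

# All ones => every (token_block, expert) pair is marked active, i.e. the dense case
sparsity_torch = torch.ones(num_sparse_blocks, 1, 1, num_experts_per_device, dtype=torch.bfloat16)
assert sparsity_torch.shape == (32, 1, 1, 8)
```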
@@ -536,8 +534,7 @@ def _run_experts_mlp_test(
         config=throughput_config,
         program_config=program_config,
         memory_config=memory_config,
-        batch_size=batch_size,
-        seq_len=seq_len,
+        total_tokens=total_tokens,
         mesh_device=mesh_device,
     )
 
@@ -572,8 +569,7 @@ def op_fn():
             config=throughput_config,
             program_config=program_config,
             memory_config=memory_config,
-            batch_size=batch_size,
-            seq_len=seq_len,
+            total_tokens=total_tokens,
             mesh_device=mesh_device,
         )
 
@@ -819,11 +815,11 @@ def test_gpt_oss_experts_mlp_single_device(
     w2_ref = state_dict["down_proj"]
 
     # Create input tensor
-    num_tokens = batch_size_per_device * seq_len
-    post_dispatch_torch = torch.randn(1, 1, num_tokens, hidden_size, dtype=torch.bfloat16)
+    total_tokens = batch_size_per_device * seq_len
+    post_dispatch_torch = torch.randn(1, 1, total_tokens, hidden_size, dtype=torch.bfloat16)
 
     # Create sparsity tensor
-    num_sparse_blocks = num_tokens // throughput_config.sparsity_block_size
+    num_sparse_blocks = total_tokens // throughput_config.sparsity_block_size
     num_experts_per_device = throughput_config.num_experts_per_device
     sparsity_torch = torch.ones(num_sparse_blocks, 1, 1, num_experts_per_device, dtype=torch.bfloat16)
 