@@ -145,6 +145,12 @@ def decode_forward(
     )
     total_tokens = tokens_per_device * num_dispatch_devices  # Global tokens across dispatch axis

+    # Debug: bypass all_to_all_dispatch only.
+    disable_dispatch = True
+    if disable_dispatch:
+        num_dispatch_devices = 1
+        total_tokens = tokens_per_device
+
     # ==========================================================================
     # STEP 1: PREPARE INPUTS FOR ALL_TO_ALL_DISPATCH
     # ==========================================================================
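Note on the bypass above: collapsing num_dispatch_devices to 1 makes every downstream total_tokens computation local to one device. A minimal sketch of that arithmetic, with assumed illustrative values (32 tokens per device, 8 dispatch devices; neither number comes from this diff):

# Token-count arithmetic behind the disable_dispatch branch (values assumed).
tokens_per_device = 32
num_dispatch_devices = 8

disable_dispatch = True
if disable_dispatch:
    # Without all_to_all, each device only ever sees its own tokens.
    num_dispatch_devices = 1

total_tokens = tokens_per_device * num_dispatch_devices
assert total_tokens == (32 if disable_dispatch else 256)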
@@ -174,14 +180,18 @@ def decode_forward(
     # With output_concat_dim=2, outputs have seq_len scaled:
     # - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
     # - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
-    dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
-        hidden_rm,
-        topk_indices_rm,
-        expert_mapping_tensors,
-        **dispatch_config.as_dict(),
-    )
-    ttnn.deallocate(hidden_rm)
-    ttnn.deallocate(topk_indices_rm)
+    if disable_dispatch:
+        dispatch_output = hidden_rm
+        dispatch_metadata = topk_indices_rm
+    else:
+        dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
+            hidden_rm,
+            topk_indices_rm,
+            expert_mapping_tensors,
+            **dispatch_config.as_dict(),
+        )
+        ttnn.deallocate(hidden_rm)
+        ttnn.deallocate(topk_indices_rm)

     # ==========================================================================
     # STEP 3: MOE_EXPERT_TOKEN_REMAP - Create sparsity pattern
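In the bypass branch, dispatch_output and dispatch_metadata are aliases of hidden_rm and topk_indices_rm, so the ttnn.deallocate calls belong inside the else branch only; freeing the inputs on the bypass path would invalidate the aliased outputs. A sketch of that pattern, where dispatch_fn and free are hypothetical stand-ins for ttnn.all_to_all_dispatch and ttnn.deallocate:

def dispatch_or_bypass(hidden_rm, topk_indices_rm, bypass, dispatch_fn, free):
    # Hypothetical wrapper illustrating the alias-vs-free rule above.
    if bypass:
        # Outputs alias the inputs; they must NOT be freed here.
        return hidden_rm, topk_indices_rm
    dispatch_output, dispatch_metadata = dispatch_fn(hidden_rm, topk_indices_rm)
    free(hidden_rm)  # Safe: dispatch_fn returned fresh tensors.
    free(topk_indices_rm)
    return dispatch_output, dispatch_metadata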
@@ -195,7 +205,14 @@ def decode_forward(
     # -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
     # -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
     remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
-    remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
+    if disable_dispatch:
+        remap_mask = ttnn.slice(
+            remap_mask,
+            [0, 0, 0, 0],
+            [1, 1, tokens_per_device, config.num_experts],
+        )
+    else:
+        remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
     # moe_expert_token_remap returns:
     # - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
     # - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
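The branch above exists because the repeated mask has one row per dispatch device: with dispatch enabled, those rows fold into the sequence dimension via reshape; with dispatch bypassed, only the local row is valid, so it is sliced out instead. The shape logic, mirrored in NumPy with assumed sizes (8 rows, 32 tokens, 64 experts):

import numpy as np

dispatch_rows, tokens_per_device, num_experts = 8, 32, 64  # assumed sizes
total_tokens = dispatch_rows * tokens_per_device

# Stand-in for remap_mask after ttnn.repeat: one row per dispatch device.
remap_mask = np.ones((1, dispatch_rows, tokens_per_device, num_experts))

disable_dispatch = True
if disable_dispatch:
    # Keep only the local device's row, matching the ttnn.slice bounds above.
    local = remap_mask[:1, :1, :tokens_per_device, :num_experts]
    assert local.shape == (1, 1, tokens_per_device, num_experts)
else:
    # Fold the device rows into the sequence dim, matching ttnn.reshape above.
    folded = remap_mask.reshape(1, 1, total_tokens, num_experts)
    assert folded.shape == (1, 1, total_tokens, num_experts)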
@@ -371,6 +388,13 @@ def decode_forward(
     topk_weights_reshaped = ttnn.to_layout(topk_weights_rm, ttnn.TILE_LAYOUT)
     ttnn.deallocate(topk_weights_rm)

+    if topk_weights_reshaped.shape[2] != post_combine.shape[2]:
+        topk_weights_reshaped = ttnn.slice(
+            topk_weights_reshaped,
+            [0, 0, 0, 0],
+            [topk_weights_reshaped.shape[0], 1, post_combine.shape[2], 1],
+        )
+
     # Weighted sum: sum_k(expert_output_k * routing_weight_k)
     weighted_output = ttnn.mul(post_combine, topk_weights_reshaped, memory_config=memory_config)
     ttnn.deallocate(post_combine)
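The guard above trims the routing weights to post_combine's sequence length before the elementwise multiply, which can otherwise fail when the dispatch bypass leaves the two tensors disagreeing on dim 2. The broadcast it relies on, mirrored in NumPy with assumed shapes (seq_len 32, hidden 128, dispatch factor 8; none of these come from the diff):

import numpy as np

seq_len, hidden = 32, 128  # assumed sizes
post_combine = np.random.rand(1, 1, seq_len, hidden)
# Weights may still carry the dispatch-sized sequence dim; trim to match.
topk_weights = np.random.rand(1, 1, seq_len * 8, 1)
if topk_weights.shape[2] != post_combine.shape[2]:
    topk_weights = topk_weights[:, :1, : post_combine.shape[2], :1]

# [1, 1, seq_len, 1] broadcasts across the hidden dim of post_combine.
weighted_output = post_combine * topk_weights
assert weighted_output.shape == post_combine.shape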