@@ -343,14 +343,22 @@ def decode_forward(
     # With output_shard_dim=2, output has tokens sharded on dim -2:
     # Output shape: [num_experts_per_tok, 1, tokens_per_device, H]
     # (each token gets outputs from k experts stacked in first dimension)
-    combine_output = ttnn.all_to_all_combine(
-        expert_output,
-        dispatch_metadata,
-        expert_mapping_tensors,
-        **combine_config.as_dict(),
-    )
-    ttnn.deallocate(expert_output)
-    ttnn.deallocate(dispatch_metadata)
+    # Debug: bypass all_to_all_combine only.
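+    # When bypassed, combine_output keeps expert_output's pre-combine shape,
+    # so the shape-alignment guard before the weighted sum clamps both operands.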
+    disable_combine = True
+    if disable_combine:
+        combine_output = expert_output
+        ttnn.deallocate(dispatch_metadata)
+    else:
+        combine_output = ttnn.all_to_all_combine(
+            expert_output,
+            dispatch_metadata,
+            expert_mapping_tensors,
+            **combine_config.as_dict(),
+        )
+        ttnn.deallocate(expert_output)
+        ttnn.deallocate(dispatch_metadata)
 
     # ==========================================================================
     # STEP 8: APPLY ROUTING WEIGHTS AND REDUCE ACROSS EXPERTS
@@ -371,6 +379,23 @@ def decode_forward(
     topk_weights_reshaped = ttnn.to_layout(topk_weights_rm, ttnn.TILE_LAYOUT)
     ttnn.deallocate(topk_weights_rm)
 
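+    # Align the k and token dims of post_combine and topk_weights_reshaped;
+    # they can disagree when all_to_all_combine is bypassed above.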
+    target_k = min(topk_weights_reshaped.shape[0], post_combine.shape[0])
+    target_tokens = min(topk_weights_reshaped.shape[2], post_combine.shape[2])
+    if post_combine.shape[0] != target_k or post_combine.shape[2] != target_tokens:
+        post_combine = ttnn.slice(
+            post_combine,
+            [0, 0, 0, 0],
+            [target_k, 1, target_tokens, post_combine.shape[3]],
+        )
+    if topk_weights_reshaped.shape[0] != target_k or topk_weights_reshaped.shape[2] != target_tokens:
+        topk_weights_reshaped = ttnn.slice(
+            topk_weights_reshaped,
+            [0, 0, 0, 0],
+            [target_k, 1, target_tokens, 1],
+        )
+
     # Weighted sum: sum_k(expert_output_k * routing_weight_k)
     weighted_output = ttnn.mul(post_combine, topk_weights_reshaped, memory_config=memory_config)
     ttnn.deallocate(post_combine)
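
Note: Step 8 computes sum_k(expert_output_k * routing_weight_k), broadcasting the
per-token routing weights over the hidden dimension. A minimal host-side sketch of
the same math, assuming plain torch tensors (shapes and names here are illustrative,
not the ttnn API):

    import torch

    # Example values: k experts per token, T tokens per device, hidden size H.
    k, T, H = 8, 32, 7168
    post_combine = torch.randn(k, 1, T, H)   # per-expert outputs stacked on dim 0
    topk_weights = torch.rand(k, 1, T, 1)    # routing weights, broadcast over H

    # Weighted sum across experts: sum_k(expert_output_k * routing_weight_k)
    weighted_output = (post_combine * topk_weights).sum(dim=0)  # -> [1, T, H]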