@@ -145,6 +145,12 @@ def decode_forward(
     )
     total_tokens = tokens_per_device * num_dispatch_devices  # Global tokens across dispatch axis

+    # Debug: bypass all_to_all dispatch/combine and moe token remap to isolate hangs.
+    disable_moe_routing = True
+    if disable_moe_routing:
+        num_dispatch_devices = 1
+        total_tokens = tokens_per_device
+
     # ==========================================================================
     # STEP 1: PREPARE INPUTS FOR ALL_TO_ALL_DISPATCH
     # ==========================================================================
@@ -156,32 +162,37 @@ def decode_forward(
     # Shape is already [1, 1, tokens_per_device, H], just ensure it's correct
     hidden_rm = ttnn.reshape(hidden_rm, shape=(1, 1, tokens_per_device, config.hidden_size))

-    # Expert indices: [1, 1, tokens_per_device, K]
-    topk_indices_rm = ttnn.to_layout(topk_expert_indices, ttnn.ROW_MAJOR_LAYOUT)
-    ttnn.deallocate(topk_expert_indices)
-    topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
-
-    # ==========================================================================
-    # STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
-    # ==========================================================================
-    # Dispatch sends each token to the device(s) that own its assigned expert(s)
-    #
-    # Inputs (tokens on dim -2):
-    # - hidden_rm: [1, 1, tokens_per_device, H] - token embeddings
-    # - topk_indices_rm: [1, 1, tokens_per_device, K] - which experts each token routes to
-    # - expert_mapping_tensors: [1, 1, E, D] - one-hot mapping of expert -> device
-    #
-    # With output_concat_dim=2, outputs have seq_len scaled:
-    # - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
-    # - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
-    dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
-        hidden_rm,
-        topk_indices_rm,
-        expert_mapping_tensors,
-        **dispatch_config.as_dict(),
-    )
-    ttnn.deallocate(hidden_rm)
-    ttnn.deallocate(topk_indices_rm)
+    if disable_moe_routing:
+        ttnn.deallocate(topk_expert_indices)
+        dispatch_output = hidden_rm
+        dispatch_metadata = None
+    else:
+        # Expert indices: [1, 1, tokens_per_device, K]
+        topk_indices_rm = ttnn.to_layout(topk_expert_indices, ttnn.ROW_MAJOR_LAYOUT)
+        ttnn.deallocate(topk_expert_indices)
+        topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
+
+        # ==========================================================================
+        # STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
+        # ==========================================================================
+        # Dispatch sends each token to the device(s) that own its assigned expert(s)
+        #
+        # Inputs (tokens on dim -2):
+        # - hidden_rm: [1, 1, tokens_per_device, H] - token embeddings
+        # - topk_indices_rm: [1, 1, tokens_per_device, K] - which experts each token routes to
+        # - expert_mapping_tensors: [1, 1, E, D] - one-hot mapping of expert -> device
+        #
+        # With output_concat_dim=2, outputs have seq_len scaled:
+        # - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
+        # - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
+        dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
+            hidden_rm,
+            topk_indices_rm,
+            expert_mapping_tensors,
+            **dispatch_config.as_dict(),
+        )
+        ttnn.deallocate(hidden_rm)
+        ttnn.deallocate(topk_indices_rm)

     # ==========================================================================
     # STEP 3: MOE_EXPERT_TOKEN_REMAP - Create sparsity pattern
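To make the dispatch comment above concrete, here is a minimal single-host NumPy sketch of the routing idea; the sizes, the round-robin expert placement, and the variable names are assumptions for illustration, not the ttnn op or its real data layout.

import numpy as np

# Assumed sizes: tokens per device, hidden dim, top-k, experts, devices.
T, H, K, E, D = 8, 4, 2, 16, 4
hidden = np.random.randn(T, H).astype(np.float32)               # token embeddings [T, H]
topk_idx = np.random.randint(0, E, size=(T, K))                  # chosen experts per token [T, K]
expert_to_device = np.eye(D, dtype=np.int32)[np.arange(E) % D]   # one-hot expert -> device [E, D]

per_device_tokens = []
for d in range(D):
    owns = expert_to_device[:, d].astype(bool)   # experts hosted on device d
    goes_here = owns[topk_idx].any(axis=-1)      # token routes to d if any of its k experts live there
    per_device_tokens.append(hidden[goes_here])  # what device d would receive from this shard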
@@ -194,21 +205,24 @@ def decode_forward(
     # remap_topk_mask: [1, dispatch_rows, 1, num_experts]
     # -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
     # -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
-    remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
-    remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
-    # moe_expert_token_remap returns:
-    # - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
-    # - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
-    #
-    # The sparsity tensor tells sparse_matmul which expert blocks have tokens,
-    # avoiding computation on empty slots.
-    _, sparsity = ttnn.moe_expert_token_remap(
-        remap_mask,
-        expert_mapping_tensors,
-        dispatch_metadata,
-        reduction_size=config.sparsity_block_size,
-    )
-    ttnn.deallocate(remap_mask)
+    if disable_moe_routing:
+        sparsity = None
+    else:
+        remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
+        remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
+        # moe_expert_token_remap returns:
+        # - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
+        # - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
+        #
+        # The sparsity tensor tells sparse_matmul which expert blocks have tokens,
+        # avoiding computation on empty slots.
+        _, sparsity = ttnn.moe_expert_token_remap(
+            remap_mask,
+            expert_mapping_tensors,
+            dispatch_metadata,
+            reduction_size=config.sparsity_block_size,
+        )
+        ttnn.deallocate(remap_mask)

     # ==========================================================================
     # STEP 4: PREPARE DISPATCH OUTPUT FOR EXPERT COMPUTATION
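The block-sparsity idea behind the remap step can be sketched in plain NumPy; the sizes below are assumed, and the real op additionally carries the device dimension and per-token activation weights.

import numpy as np

tokens, experts_per_device, block = 64, 4, 16                    # block plays the role of sparsity_block_size
token_hits = np.random.rand(tokens, experts_per_device) < 0.1    # does token t use local expert e?

# A block of `block` consecutive tokens is "active" for an expert if any token in it
# routed there; sparse_matmul can then skip the all-zero blocks entirely.
blocks = token_hits.reshape(tokens // block, block, experts_per_device)
sparsity = blocks.any(axis=1).astype(np.float32)                 # [tokens/block, experts_per_device]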
@@ -233,6 +247,15 @@ def decode_forward(
         shape=(1, num_sparse_blocks, config.sparsity_block_size, config.hidden_size),
     )

+    if disable_moe_routing:
+        sparsity = ttnn.ones(
+            (1, 1, num_sparse_blocks, config.num_experts_per_device),
+            dtype=ttnn.bfloat16,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
+            device=mesh_device,
+            memory_config=dispatch_config.memory_config,
+        )
+
     memory_config = dispatch_config.memory_config

     # ==========================================================================
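The all-ones tensor above is the dense fallback: marking every block active makes the sparse expert matmul equivalent to just computing everything. A small NumPy sketch of that equivalence, with assumed shapes (the real kernel consumes the mask to skip blocks):

import numpy as np

blocks, block, H, F = 4, 16, 8, 8
x = np.random.randn(blocks, block, H).astype(np.float32)   # token blocks
w = np.random.randn(H, F).astype(np.float32)               # one expert's weight
mask = np.ones(blocks, dtype=bool)                          # every block marked active

# Skip inactive blocks, compute active ones; with an all-ones mask nothing is skipped.
out = np.stack([x[b] @ w if mask[b] else np.zeros((block, F), np.float32) for b in range(blocks)])
assert np.allclose(out, x @ w)                              # identical to the dense matmul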
@@ -343,14 +366,17 @@ def decode_forward(
     # With output_shard_dim=2, output has tokens sharded on dim -2:
     # Output shape: [num_experts_per_tok, 1, tokens_per_device, H]
     # (each token gets outputs from k experts stacked in first dimension)
-    combine_output = ttnn.all_to_all_combine(
-        expert_output,
-        dispatch_metadata,
-        expert_mapping_tensors,
-        **combine_config.as_dict(),
-    )
-    ttnn.deallocate(expert_output)
-    ttnn.deallocate(dispatch_metadata)
+    if disable_moe_routing:
+        combine_output = expert_output
+    else:
+        combine_output = ttnn.all_to_all_combine(
+            expert_output,
+            dispatch_metadata,
+            expert_mapping_tensors,
+            **combine_config.as_dict(),
+        )
+        ttnn.deallocate(expert_output)
+        ttnn.deallocate(dispatch_metadata)

     # ==========================================================================
     # STEP 8: APPLY ROUTING WEIGHTS AND REDUCE ACROSS EXPERTS
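STEP 8 itself falls outside this hunk, but given the combine output shape documented above, the weight-and-reduce it names amounts to a per-token weighted sum over the k expert outputs. A minimal NumPy sketch with assumed shapes and names, not the ttnn code:

import numpy as np

K, T, H = 2, 32, 8
combine_output = np.random.randn(K, 1, T, H).astype(np.float32)  # [K, 1, tokens_per_device, H]
routing_weights = np.random.rand(T, K).astype(np.float32)        # per-token weights for its k experts

weighted = combine_output[:, 0] * routing_weights.T[:, :, None]  # scale each expert's contribution
moe_output = weighted.sum(axis=0)                                # reduce across experts -> [T, H]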