Skip to content

Commit 9575244

Browse files
committed
no dispatch
1 parent 5f59815 commit 9575244

File tree

1 file changed

+33
-9
lines changed
  • models/demos/gpt_oss/tt/experts_throughput/decode.py

1 file changed

+33
-9
lines changed

models/demos/gpt_oss/tt/experts_throughput/decode.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ def decode_forward(
145145
)
146146
total_tokens = tokens_per_device * num_dispatch_devices # Global tokens across dispatch axis
147147

148+
# Debug: bypass all_to_all_dispatch only.
149+
disable_dispatch = True
150+
if disable_dispatch:
151+
num_dispatch_devices = 1
152+
total_tokens = tokens_per_device
153+
148154
# ==========================================================================
149155
# STEP 1: PREPARE INPUTS FOR ALL_TO_ALL_DISPATCH
150156
# ==========================================================================
@@ -174,14 +180,18 @@ def decode_forward(
174180
# With output_concat_dim=2, outputs have seq_len scaled:
175181
# - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
176182
# - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
177-
dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
178-
hidden_rm,
179-
topk_indices_rm,
180-
expert_mapping_tensors,
181-
**dispatch_config.as_dict(),
182-
)
183-
ttnn.deallocate(hidden_rm)
184-
ttnn.deallocate(topk_indices_rm)
183+
if disable_dispatch:
184+
dispatch_output = hidden_rm
185+
dispatch_metadata = topk_indices_rm
186+
else:
187+
dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
188+
hidden_rm,
189+
topk_indices_rm,
190+
expert_mapping_tensors,
191+
**dispatch_config.as_dict(),
192+
)
193+
ttnn.deallocate(hidden_rm)
194+
ttnn.deallocate(topk_indices_rm)
185195

186196
# ==========================================================================
187197
# STEP 3: MOE_EXPERT_TOKEN_REMAP - Create sparsity pattern
@@ -195,7 +205,14 @@ def decode_forward(
195205
# -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
196206
# -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
197207
remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
198-
remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
208+
if disable_dispatch:
209+
remap_mask = ttnn.slice(
210+
remap_mask,
211+
[0, 0, 0, 0],
212+
[1, 1, tokens_per_device, config.num_experts],
213+
)
214+
else:
215+
remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
199216
# moe_expert_token_remap returns:
200217
# - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
201218
# - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
@@ -371,6 +388,13 @@ def decode_forward(
371388
topk_weights_reshaped = ttnn.to_layout(topk_weights_rm, ttnn.TILE_LAYOUT)
372389
ttnn.deallocate(topk_weights_rm)
373390

391+
if topk_weights_reshaped.shape[2] != post_combine.shape[2]:
392+
topk_weights_reshaped = ttnn.slice(
393+
topk_weights_reshaped,
394+
[0, 0, 0, 0],
395+
[topk_weights_reshaped.shape[0], 1, post_combine.shape[2], 1],
396+
)
397+
374398
# Weighted sum: sum_k(expert_output_k * routing_weight_k)
375399
weighted_output = ttnn.mul(post_combine, topk_weights_reshaped, memory_config=memory_config)
376400
ttnn.deallocate(post_combine)

0 commit comments

Comments
 (0)