
Commit 29a2e9b

rm ops
1 parent 7242da9 commit 29a2e9b

File tree

  • models/demos/gpt_oss/tt/experts_throughput

1 file changed: +75 −49 lines

models/demos/gpt_oss/tt/experts_throughput/decode.py

Lines changed: 75 additions & 49 deletions
@@ -145,6 +145,12 @@ def decode_forward(
     )
     total_tokens = tokens_per_device * num_dispatch_devices  # Global tokens across dispatch axis
 
+    # Debug: bypass all_to_all dispatch/combine and moe token remap to isolate hangs.
+    disable_moe_routing = True
+    if disable_moe_routing:
+        num_dispatch_devices = 1
+        total_tokens = tokens_per_device
+
     # ==========================================================================
     # STEP 1: PREPARE INPUTS FOR ALL_TO_ALL_DISPATCH
     # ==========================================================================
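
Note: the flag above is hardcoded to True in this commit. A minimal sketch of an alternative, assuming a plain CPython environment and a hypothetical GPT_OSS_DISABLE_MOE_ROUTING variable, would read the toggle from the environment so the bypass can be flipped without editing decode.py:

import os

# Hypothetical env-var toggle for the debug bypass (not part of this commit).
disable_moe_routing = os.environ.get("GPT_OSS_DISABLE_MOE_ROUTING", "0") == "1"
if disable_moe_routing:
    # With dispatch bypassed, this device only ever sees its own tokens,
    # so the global token count collapses to the per-device count.
    num_dispatch_devices = 1
    total_tokens = tokens_per_device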
@@ -156,32 +162,37 @@ def decode_forward(
     # Shape is already [1, 1, tokens_per_device, H], just ensure it's correct
     hidden_rm = ttnn.reshape(hidden_rm, shape=(1, 1, tokens_per_device, config.hidden_size))
 
-    # Expert indices: [1, 1, tokens_per_device, K]
-    topk_indices_rm = ttnn.to_layout(topk_expert_indices, ttnn.ROW_MAJOR_LAYOUT)
-    ttnn.deallocate(topk_expert_indices)
-    topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
-
-    # ==========================================================================
-    # STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
-    # ==========================================================================
-    # Dispatch sends each token to the device(s) that own its assigned expert(s)
-    #
-    # Inputs (tokens on dim -2):
-    #   - hidden_rm: [1, 1, tokens_per_device, H] - token embeddings
-    #   - topk_indices_rm: [1, 1, tokens_per_device, K] - which experts each token routes to
-    #   - expert_mapping_tensors: [1, 1, E, D] - one-hot mapping of expert -> device
-    #
-    # With output_concat_dim=2, outputs have seq_len scaled:
-    #   - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
-    #   - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
-    dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
-        hidden_rm,
-        topk_indices_rm,
-        expert_mapping_tensors,
-        **dispatch_config.as_dict(),
-    )
-    ttnn.deallocate(hidden_rm)
-    ttnn.deallocate(topk_indices_rm)
+    if disable_moe_routing:
+        ttnn.deallocate(topk_expert_indices)
+        dispatch_output = hidden_rm
+        dispatch_metadata = None
+    else:
+        # Expert indices: [1, 1, tokens_per_device, K]
+        topk_indices_rm = ttnn.to_layout(topk_expert_indices, ttnn.ROW_MAJOR_LAYOUT)
+        ttnn.deallocate(topk_expert_indices)
+        topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
+
+        # ==========================================================================
+        # STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
+        # ==========================================================================
+        # Dispatch sends each token to the device(s) that own its assigned expert(s)
+        #
+        # Inputs (tokens on dim -2):
+        #   - hidden_rm: [1, 1, tokens_per_device, H] - token embeddings
+        #   - topk_indices_rm: [1, 1, tokens_per_device, K] - which experts each token routes to
+        #   - expert_mapping_tensors: [1, 1, E, D] - one-hot mapping of expert -> device
+        #
+        # With output_concat_dim=2, outputs have seq_len scaled:
+        #   - dispatch_output: [D, 1, total_tokens, H] - tokens scattered to expert devices
+        #   - dispatch_metadata: [D, 1, total_tokens, K] - expert indices (for combine routing)
+        dispatch_output, dispatch_metadata = ttnn.all_to_all_dispatch(
+            hidden_rm,
+            topk_indices_rm,
+            expert_mapping_tensors,
+            **dispatch_config.as_dict(),
+        )
+        ttnn.deallocate(hidden_rm)
+        ttnn.deallocate(topk_indices_rm)
 
     # ==========================================================================
     # STEP 3: MOE_EXPERT_TOKEN_REMAP - Create sparsity pattern
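
For reference, a host-side sketch of the routing that all_to_all_dispatch performs, following the shapes documented in the comments above. This is an illustrative torch reference model, not the ttnn kernel, and reference_dispatch is a name invented here:

import torch

def reference_dispatch(hidden, topk_idx, expert_to_device, num_devices):
    # hidden:           [T, H] token embeddings held by one dispatch device
    # topk_idx:         [T, K] expert ids chosen by the router
    # expert_to_device: [E]    owning device per expert (argmax of the one-hot [E, D] map)
    per_device = [[] for _ in range(num_devices)]
    for t in range(hidden.shape[0]):
        for e in topk_idx[t].tolist():
            d = int(expert_to_device[e])
            per_device[d].append(hidden[t])
    # ttnn pads each destination to the fixed [D, 1, total_tokens, H] layout;
    # here we simply return the per-device token lists.
    return [
        torch.stack(rows) if rows else hidden.new_zeros(0, hidden.shape[1])
        for rows in per_device
    ]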
@@ -194,21 +205,24 @@ def decode_forward(
     # remap_topk_mask: [1, dispatch_rows, 1, num_experts]
     #   -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
     #   -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
-    remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
-    remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
-    # moe_expert_token_remap returns:
-    #   - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
-    #   - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
-    #
-    # The sparsity tensor tells sparse_matmul which expert blocks have tokens,
-    # avoiding computation on empty slots.
-    _, sparsity = ttnn.moe_expert_token_remap(
-        remap_mask,
-        expert_mapping_tensors,
-        dispatch_metadata,
-        reduction_size=config.sparsity_block_size,
-    )
-    ttnn.deallocate(remap_mask)
+    if disable_moe_routing:
+        sparsity = None
+    else:
+        remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
+        remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
+        # moe_expert_token_remap returns:
+        #   - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
+        #   - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
+        #
+        # The sparsity tensor tells sparse_matmul which expert blocks have tokens,
+        # avoiding computation on empty slots.
+        _, sparsity = ttnn.moe_expert_token_remap(
+            remap_mask,
+            expert_mapping_tensors,
+            dispatch_metadata,
+            reduction_size=config.sparsity_block_size,
+        )
+        ttnn.deallocate(remap_mask)
 
     # ==========================================================================
     # STEP 4: PREPARE DISPATCH OUTPUT FOR EXPERT COMPUTATION
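
The reduction_size argument is what turns per-token expert activations into the block-level sparsity that sparse_matmul consumes. A host-side sketch of that reduction, an assumption based on the shapes described in the comments above rather than the ttnn op itself:

import torch

def reference_block_sparsity(token_expert_mask, reduction_size):
    # token_expert_mask: [tokens, experts_per_device] 0/1 activations on one device
    # returns:           [tokens // reduction_size, experts_per_device]
    tokens, experts = token_expert_mask.shape
    blocks = token_expert_mask.reshape(tokens // reduction_size, reduction_size, experts)
    # A block is active for a local expert if any token inside it routes there,
    # so sparse_matmul can skip (block, expert) pairs that received no tokens.
    return (blocks.sum(dim=1) > 0).to(token_expert_mask.dtype)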
@@ -233,6 +247,15 @@ def decode_forward(
         shape=(1, num_sparse_blocks, config.sparsity_block_size, config.hidden_size),
     )
 
+    if disable_moe_routing:
+        sparsity = ttnn.ones(
+            (1, 1, num_sparse_blocks, config.num_experts_per_device),
+            dtype=ttnn.bfloat16,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
+            device=mesh_device,
+            memory_config=dispatch_config.memory_config,
+        )
+
     memory_config = dispatch_config.memory_config
 
     # ==========================================================================
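
The all-ones sparsity above marks every (block, local expert) pair as active, so the sparse matmul degenerates to a dense pass over all local experts. Rough shape bookkeeping, with illustrative numbers that are assumptions rather than values from the config:

# Illustrative only: the concrete sizes are assumptions.
tokens_per_device = 32
sparsity_block_size = 32
total_tokens = tokens_per_device                           # routing disabled, so no scaling by D
num_sparse_blocks = total_tokens // sparsity_block_size    # -> 1
# sparsity has shape [1, 1, num_sparse_blocks, num_experts_per_device] and is all ones,
# so no expert block is skipped during the expert matmuls.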
@@ -343,14 +366,17 @@ def decode_forward(
     # With output_shard_dim=2, output has tokens sharded on dim -2:
     #   Output shape: [num_experts_per_tok, 1, tokens_per_device, H]
     #   (each token gets outputs from k experts stacked in first dimension)
-    combine_output = ttnn.all_to_all_combine(
-        expert_output,
-        dispatch_metadata,
-        expert_mapping_tensors,
-        **combine_config.as_dict(),
-    )
-    ttnn.deallocate(expert_output)
-    ttnn.deallocate(dispatch_metadata)
+    if disable_moe_routing:
+        combine_output = expert_output
+    else:
+        combine_output = ttnn.all_to_all_combine(
+            expert_output,
+            dispatch_metadata,
+            expert_mapping_tensors,
+            **combine_config.as_dict(),
+        )
+        ttnn.deallocate(expert_output)
+        ttnn.deallocate(dispatch_metadata)
 
     # ==========================================================================
     # STEP 8: APPLY ROUTING WEIGHTS AND REDUCE ACROSS EXPERTS
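
A host-side sketch of the STEP 8 reduction that follows: an assumed torch reference for the weighted sum over the K expert outputs, using the combine output shape documented above, not the ttnn implementation:

import torch

def reference_weighted_reduce(combine_output, topk_weights):
    # combine_output: [K, 1, tokens_per_device, H] per-expert outputs for each token
    # topk_weights:   [tokens_per_device, K] router probabilities for the chosen experts
    k, _, t, _ = combine_output.shape
    w = topk_weights.t().reshape(k, 1, t, 1)      # broadcast over the hidden dim
    # Weight each expert's output and sum over the K experts per token.
    return (combine_output * w).sum(dim=0).unsqueeze(0)   # [1, 1, tokens_per_device, H]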
