
Commit aac9b48

no combine
1 parent 5f59815 commit aac9b48

File tree

1 file changed: +29 -8 lines

  • models/demos/gpt_oss/tt/experts_throughput

models/demos/gpt_oss/tt/experts_throughput/decode.py

Lines changed: 29 additions & 8 deletions
@@ -343,14 +343,20 @@ def decode_forward(
     # With output_shard_dim=2, output has tokens sharded on dim -2:
     # Output shape: [num_experts_per_tok, 1, tokens_per_device, H]
     # (each token gets outputs from k experts stacked in first dimension)
-    combine_output = ttnn.all_to_all_combine(
-        expert_output,
-        dispatch_metadata,
-        expert_mapping_tensors,
-        **combine_config.as_dict(),
-    )
-    ttnn.deallocate(expert_output)
-    ttnn.deallocate(dispatch_metadata)
+    # Debug: bypass all_to_all_combine only.
+    disable_combine = True
+    if disable_combine:
+        combine_output = expert_output
+        ttnn.deallocate(dispatch_metadata)
+    else:
+        combine_output = ttnn.all_to_all_combine(
+            expert_output,
+            dispatch_metadata,
+            expert_mapping_tensors,
+            **combine_config.as_dict(),
+        )
+        ttnn.deallocate(expert_output)
+        ttnn.deallocate(dispatch_metadata)
 
     # ==========================================================================
     # STEP 8: APPLY ROUTING WEIGHTS AND REDUCE ACROSS EXPERTS
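
Note on the bypass above: with output_shard_dim=2 the expert output is already laid out as [num_experts_per_tok, 1, tokens_per_device, H] on each device, so skipping the combine simply leaves each token's k expert rows stacked on dim 0. A minimal host-side NumPy sketch of the shapes involved (dimension values are illustrative, not taken from the model config; the real data movement is ttnn.all_to_all_combine's device kernel):

import numpy as np

# Illustrative sizes (assumed for this sketch only).
num_experts_per_tok, tokens_per_device, H = 4, 32, 2880

# expert_output as described by the comment in the hunk above:
# [num_experts_per_tok, 1, tokens_per_device, H]
expert_output = np.random.randn(num_experts_per_tok, 1, tokens_per_device, H)

# With disable_combine=True, combine_output is expert_output unchanged:
# each token still carries its k expert outputs stacked on dim 0.
combine_output = expert_output
assert combine_output.shape == (num_experts_per_tok, 1, tokens_per_device, H)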
@@ -371,6 +377,21 @@ def decode_forward(
     topk_weights_reshaped = ttnn.to_layout(topk_weights_rm, ttnn.TILE_LAYOUT)
     ttnn.deallocate(topk_weights_rm)
 
+    target_k = min(topk_weights_reshaped.shape[0], post_combine.shape[0])
+    target_tokens = min(topk_weights_reshaped.shape[2], post_combine.shape[2])
+    if post_combine.shape[0] != target_k or post_combine.shape[2] != target_tokens:
+        post_combine = ttnn.slice(
+            post_combine,
+            [0, 0, 0, 0],
+            [target_k, 1, target_tokens, post_combine.shape[3]],
+        )
+    if topk_weights_reshaped.shape[0] != target_k or topk_weights_reshaped.shape[2] != target_tokens:
+        topk_weights_reshaped = ttnn.slice(
+            topk_weights_reshaped,
+            [0, 0, 0, 0],
+            [target_k, 1, target_tokens, 1],
+        )
+
     # Weighted sum: sum_k(expert_output_k * routing_weight_k)
     weighted_output = ttnn.mul(post_combine, topk_weights_reshaped, memory_config=memory_config)
     ttnn.deallocate(post_combine)
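
The slicing guard above clips both operands to a common [target_k, 1, target_tokens, ...] shape before the broadcast multiply, which is what lets the bypassed (un-combined) expert output flow into STEP 8. A NumPy sketch of the same guard plus the sum_k(expert_output_k * routing_weight_k) reduction (sizes and the deliberate mismatch are illustrative; on device the reduction over the expert dimension happens in a later step after ttnn.mul):

import numpy as np

k, T, H = 4, 32, 2880                                # illustrative sizes
post_combine = np.random.randn(k + 1, 1, T + 8, H)   # deliberately oversized
topk_weights_reshaped = np.random.rand(k, 1, T, 1)

# Same guard as the diff: clip both tensors to the smaller k / token count.
target_k = min(topk_weights_reshaped.shape[0], post_combine.shape[0])
target_tokens = min(topk_weights_reshaped.shape[2], post_combine.shape[2])
post_combine = post_combine[:target_k, :, :target_tokens, :]
topk_weights_reshaped = topk_weights_reshaped[:target_k, :, :target_tokens, :]

# Weighted sum over experts: sum_k(expert_output_k * routing_weight_k).
weighted_output = (post_combine * topk_weights_reshaped).sum(axis=0)
assert weighted_output.shape == (1, target_tokens, H)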
