[gpt-oss] perf optimization: all to all ops with tokens on dim -2 #36720
handrewsTT merged 6 commits into main
Conversation
Pull request overview
This PR optimizes the gpt-oss MoE expert throughput implementation by reducing memory access latency and minimizing tensor reshape operations during decode.
Changes:
- Switched decode memory configuration from DRAM to L1 for improved throughput
- Refactored decode forward pass to maintain tokens on seq_len dimension (dim -2) throughout the pipeline, reducing reshape operations
- Updated all_to_all dispatch/combine configurations to use output_concat_dim=2 and output_shard_dim=2 for consistency with the new token dimension strategy
- Reduced prefill chunk_size from 2048 to 512 as a workaround for diverging outputs (GitHub issue #36335)
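The dim -2 strategy behind these changes can be illustrated with a plain-NumPy stand-in (ttnn and the real all_to_all ops are not used here; shapes and names are illustrative, not taken from the PR):

```python
import numpy as np

# Decode activations normally arrive as [batch_per_device, 1, seq_len, hidden].
batch_per_device, seq_len, hidden = 4, 1, 8
x = np.arange(batch_per_device * seq_len * hidden, dtype=np.float32).reshape(
    batch_per_device, 1, seq_len, hidden
)

# Keeping tokens on dim -2 means flattening batch*seq into a single token axis
# once, up front, so downstream ops can concat/shard on dim 2 without
# re-reshaping between every stage.
tokens = x.reshape(1, 1, batch_per_device * seq_len, hidden)

# For a 4D tensor, dim -2 and dim 2 are the same axis, which is why
# output_concat_dim=2 / output_shard_dim=2 line up with the token dimension.
assert tokens.shape[-2] == tokens.shape[2] == batch_per_device * seq_len

# A mock "all_to_all combine" that concatenates per-device token shards on dim 2:
shards = [tokens, tokens]  # pretend two devices each produced a token shard
combined = np.concatenate(shards, axis=2)
print(combined.shape)  # (1, 1, 8, 8)
```

The point of the sketch is only the axis arithmetic: once tokens live on dim 2, dispatch, matmul, and combine can all agree on that axis and the per-op reshapes disappear.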
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| models/demos/gpt_oss/tt/mlp.py | Changed decode_memory_config from DRAM to L1 for better decode throughput |
| models/demos/gpt_oss/tt/experts_throughput/prefill.py | Reduced chunk_size to 512 as temporary workaround for divergence issue |
| models/demos/gpt_oss/tt/experts_throughput/decode.py | Major refactor to keep tokens on seq_len dimension, reducing reshape operations and improving performance |
| models/demos/gpt_oss/tt/experts_throughput/config.py | Updated all_to_all configs to use dim 2 for both concat and shard operations, consistent with decode changes |
```python
while len(expert_output_sparse.shape) > 4:
    expert_output_sparse = ttnn.squeeze(expert_output_sparse, 0)
```
This while loop to squeeze extra dimensions seems defensive but may indicate uncertainty about the output shape of the sparse_matmul. Consider documenting what conditions would cause expert_output_sparse to have more than 4 dimensions, or if this can be simplified to a fixed number of squeeze operations if the shape is always predictable.
Suggested change:

```python
# sparse_matmul may introduce a leading singleton batch dimension, yielding:
# [1, total_tokens/block, experts, block, H]. We only expect at most one such
# dimension here; enforce that contract explicitly rather than squeezing in a loop.
if len(expert_output_sparse.shape) == 5:
    expert_output_sparse = ttnn.squeeze(expert_output_sparse, 0)
elif len(expert_output_sparse.shape) != 4:
    raise RuntimeError(
        f"Unexpected expert_output_sparse rank {len(expert_output_sparse.shape)}; "
        "expected 4D or 5D with a leading singleton batch dimension."
    )
```
```diff
 program_config: ThroughputProgramConfig,
 mesh_device,
-chunk_size: int = 2048,
+chunk_size: int = 512,  # TODO: increasing this causes diverging outputs for last mesh row (https://github.com/tenstorrent/tt-metal/issues/36335)
```
The chunk_size reduction from 2048 to 512 is a workaround for diverging outputs rather than an optimization. This TODO references issue #36335, suggesting this is a temporary fix. Consider adding a more prominent warning or tracking mechanism to ensure this gets reverted once the underlying issue is resolved, as the smaller chunk size may impact performance.
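For context, here is a minimal sketch of what chunk_size controls in a chunked prefill loop (pure NumPy; process_chunk stands in for the real per-chunk expert forward, and all names here are hypothetical, not the PR's actual functions):

```python
import numpy as np

def process_chunk(chunk: np.ndarray) -> np.ndarray:
    # Stand-in for the per-chunk MoE expert computation.
    return chunk * 2.0

def chunked_prefill(x: np.ndarray, chunk_size: int = 512) -> np.ndarray:
    """Process a [1, 1, seq_len, hidden] prefill in seq_len slices of chunk_size."""
    seq_len = x.shape[-2]
    outputs = []
    for start in range(0, seq_len, chunk_size):
        chunk = x[..., start:start + chunk_size, :]
        outputs.append(process_chunk(chunk))
    # Re-concatenate on the token dimension (dim -2).
    return np.concatenate(outputs, axis=-2)

x = np.ones((1, 1, 1024, 16), dtype=np.float32)
y = chunked_prefill(x, chunk_size=512)  # runs two chunks of 512 tokens each
print(y.shape)  # (1, 1, 1024, 16)
```

Halving chunk_size quadruples nothing numerically, but it does multiply loop iterations (and per-chunk launch overhead), which is why the reviewer flags the 2048 to 512 reduction as something to revert once issue #36335 is fixed.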
```diff
@@ -112,60 +112,68 @@
     Returns:
         Output tensor [batch_size_per_device, 1, seq_len, hidden_size]
```
The docstring states the output tensor has shape [batch_size_per_device, 1, seq_len, hidden_size], but the actual implementation returns [1, 1, tokens_per_device, hidden_size] where tokens_per_device = batch_size_per_device * seq_len. The docstring should be updated to reflect the actual output shape, or the implementation should reshape the output back to match the documented shape if that's what callers expect.
Suggested change:

```diff
-    Output tensor [batch_size_per_device, 1, seq_len, hidden_size]
+    Output tensor [1, 1, tokens_per_device, hidden_size], where
+    tokens_per_device = batch_size_per_device * seq_len.
```
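The shape relationship in the suggested docstring is easy to verify directly (NumPy stand-in; the dimension values are illustrative):

```python
import numpy as np

batch_size_per_device, seq_len, hidden_size = 8, 1, 64
tokens_per_device = batch_size_per_device * seq_len

# Documented shape vs. what the decode path actually returns after keeping
# tokens on dim -2: the same elements, flattened onto one token axis.
documented = np.zeros((batch_size_per_device, 1, seq_len, hidden_size))
actual = documented.reshape(1, 1, tokens_per_device, hidden_size)

assert actual.shape == (1, 1, tokens_per_device, hidden_size)
assert documented.size == actual.size  # no data is lost, only the layout changes
print(actual.shape)  # (1, 1, 8, 64)
```

Whether the docstring or the implementation should change depends on what callers expect; either way the two must agree, since a caller indexing dim 0 by batch would silently read the wrong tokens under the flattened layout.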
Force-pushed from 3e304a9 to 6f05076
…nstorrent#36720) Co-authored-by: handrewsTT &lt;handrews@tenstorrent.com&gt;