Conversation
/codeowners ping
Pull request overview
Adds a new experimental TTNN CCL op to support the MoE inference pipeline step that sparsifies dense expert contributions and returns tokens to their originating devices, using a fabric mux-based combine path.
Changes:
- Introduces `ttnn.experimental.selective_reduce_combine` (C++ op + device op + dataflow reader/writer kernels) and binds it via nanobind.
- Extends shared CCL kernel utilities for mux teardown and adds a bidirectional multicast atomic-inc helper for 1D ring.
- Adds Galaxy MoE nightly tests and adjusts Galaxy e2e pipeline configuration to run MoE tests separately.
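To make the data flow concrete, here is a rough host-side sketch of a call into the new op. The exact registered signature is not shown in this thread, so the argument list, order, and the wrapper function below are assumptions pieced together from the `tensor_args_t` and invoke parameters quoted later in the review.

```cpp
// Hypothetical call shape only -- not the actual registered signature.
#include <optional>
#include "ttnn/tensor/tensor.hpp"

ttnn::Tensor selective_reduce_combine_example(
    const ttnn::Tensor& dense_input_tensor,         // dense expert contributions from compute
    const ttnn::Tensor& dense_metadata_tensor,      // pre-computed routing metadata
    const ttnn::Tensor& dense_token_maps_tensor,    // token -> originating-device mapping
    const ttnn::Tensor& dense_token_counts_tensor)  // per-expert token counts
{
    // Sparsify the dense contributions and send each token segment back to its
    // originating device over the fabric mux. Scalar attributes such as
    // num_links/axis/topology are omitted here but are likely also required.
    return ttnn::experimental::selective_reduce_combine(
        dense_input_tensor,
        dense_metadata_tensor,
        dense_token_maps_tensor,
        dense_token_counts_tensor,
        /*memory_config=*/std::nullopt,
        /*optional_output_tensor=*/std::nullopt,
        /*optional_cross_device_semaphore=*/std::nullopt);
}
```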
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 16 comments.
Summary per file:
| File | Description |
|---|---|
| ttnn/cpp/ttnn/operations/experimental/experimental_nanobind.cpp | Minor cleanup in experimental nanobind module registration. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine_nanobind.hpp | Declares nanobind binding entrypoint for the new op. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine_nanobind.cpp | Adds Python binding + docstring for selective_reduce_combine. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine.hpp | Declares/registers the TTNN operation. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine.cpp | Implements host-side invoke forwarding into the prim/device op. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/selective_reduce_combine_device_operation.hpp | Defines the device operation interface and prim API. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/selective_reduce_combine_device_operation.cpp | Implements validation + output spec/tensor creation + prim launch. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/selective_reduce_combine_program_factory.cpp | Program factory building CBs, mux workers, and setting runtime args. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/kernels/dataflow/reader.cpp | Kernel to read token counts/maps and compute per-core work splits. |
| ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/kernels/dataflow/writer.cpp | Kernel to send token segments locally or over fabric mux and teardown. |
| ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_nanobind.cpp | Registers the new op under the experimental CCL nanobind module. |
| ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt | Adds new sources/kernels to the experimental CCL build target. |
| ttnn/CMakeLists.txt | Adds new nanobind source to TTNN build. |
| ttnn/cpp/ttnn/operations/ccl/common/kernels/moe_utils.hpp | Extends mux helper arg parsing/teardown; adds bidirectional atomic-inc helper. |
| ttnn/cpp/ttnn/operations/ccl/common/kernels/minimal_ccl_common.hpp | Extends perform_payload_send template to optionally skip flush. |
| tt_metal/hw/inc/api/debug/dprint_pages.h | Adds print_u32_pages helper for debugging. |
| tt_metal/fabric/hw/inc/linear/addrgen_api.h | Adds a ShardedAddrGen using-declaration and updates a helper signature. |
| tests/pipeline_reorg/galaxy_e2e_tests.yaml | Splits Galaxy CCL vs MoE tests; adds MoE-specific environment and timeout. |
| tests/nightly/tg/ccl/moe/test_selective_combine_6U.py | Adds correctness + perf/trace tests for selective reduce combine on Galaxy. |
| tests/nightly/t3000/ccl/test_all_to_all_combine.py | Exposes cluster-dimension helpers reused by the new MoE test. |
Quoted docstring:

    Generally expect num_data_parallel_dim=num_token_parallel_dim=4
    This can be though of as a logical grid by a physical grid is not required as long as the ordering is

Docstring typo: “though of” should be “thought of”.

Suggested change:

    This can be thought of as a logical grid by a physical grid is not required as long as the ordering is
Quoted test code:

    logger.info(f"Capturing Warmup iterations")
    trace_id_warmup = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    tt_out = op_func(max(1, num_iters // 4))
CodeOwners Group Analysis: this PR requires approval from one member of each of the owning groups. Summary: 6 pending groups, 0 approved groups.
Note: at least one approval from each group is sufficient.
Force-pushed from fa4e393 to b80fa57.
Quoted diff from `addrgen_api.h`:

    template <typename ShardingInfoType>
    -uint32_t get_page_size(const experimental::ShardedAddrGen<ShardingInfoType>& d) {
    +uint32_t get_page_size(const _ttnn_operations_experimental_ShardedAddrGen<ShardingInfoType>& d) {
I see a new using added, but it's a weird using. What's going on here? :)
Quoted code from `dprint_pages.h`:

    inline void print_u32_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0)
@jbaumanTT, @akerteszTT, can you look at this addition to hw/inc/api please?
Quoted code from `moe_utils.hpp`:

    // DPRINT << "OPENING MUX CORE: " << (uint32_t)args.fabric_mux_x << ", " << (uint32_t)args.fabric_mux_y
    // << "\n";
    // DPRINT << "CLOSING MUX CORE: " << (uint32_t)args.fabric_mux_x << ", " << (uint32_t)args.fabric_mux_y
    // << "\n";

    class SenderType = WorkerToFabricEdmSender>
    FORCE_INLINE void fabric_multicast_bidirectional_atomic_inc_ring_1d(
        std::array<SenderType, 4>& fabric_connections,
        volatile PACKET_HEADER_TYPE* packet_header_pos,
I am a little out of the loop: what is PACKET_HEADER_TYPE? Why is it all caps? Is it a macro?
Quoted code from `selective_reduce_combine.hpp`:

    namespace ttnn {
    namespace operations::experimental::ccl::moe {

    struct ExecuteSelectiveReduceCombine {
Please remove the struct with invoke and register_operation; see what we did here: #36303.
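For readers outside ttnn internals, here is a minimal sketch of the pattern being flagged and of the direction the comment points at. The registration call mirrors the usual ttnn `register_operation` idiom as I understand it, and the suggested alternative is an assumption about #36303, not a quote from it.

```cpp
// Sketch of the flagged pattern: a wrapper struct with a static invoke(),
// plus a register_operation call binding it to the public op name.
// Names mirror the quoted header; parameters and bodies are trimmed placeholders.
#include "ttnn/decorators.hpp"
#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::experimental::ccl::moe {
struct ExecuteSelectiveReduceCombine {
    static ttnn::Tensor invoke(const ttnn::Tensor& dense_input_tensor /*, ... */);
};
}  // namespace ttnn::operations::experimental::ccl::moe

namespace ttnn::experimental {
constexpr auto selective_reduce_combine = ttnn::register_operation<
    "ttnn::experimental::selective_reduce_combine",
    ttnn::operations::experimental::ccl::moe::ExecuteSelectiveReduceCombine>();
}  // namespace ttnn::experimental

// The suggested direction (assumed): drop the forwarding struct entirely and
// expose the prim directly, e.g.
//   namespace ttnn::experimental { using ttnn::prim::selective_reduce_combine; }
// so there is no Execute* layer whose only job is to forward arguments.
```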
Quoted code from `selective_reduce_combine_device_operation.hpp` (operation attributes):

    const uint32_t hidden_size;
    const uint32_t batch_size;
    const uint32_t seq_size;
    const uint32_t select_experts_k;
    const uint32_t experts;
    const uint32_t num_links;

    const std::optional<uint32_t> axis;
    tt::tt_fabric::Topology topology;

    const uint32_t num_token_parallel_cores;
    const uint32_t num_data_parallel_cores;
    const CoreRangeSet worker_core_range_set;
    const CoreRangeSet mux_core_range_set;
    const ttnn::MemoryConfig output_memory_config;
    const std::optional<GlobalSemaphore> optional_cross_device_semaphore;
These params should not really be const; if we want things const, the whole struct is passed as const.
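A small illustration of the const-ness point, using only placeholder fields: non-const members keep the attributes struct movable and assignable, and callers still get immutability by taking the whole struct by const reference.

```cpp
#include <cstdint>
#include <optional>

// Members left non-const: the struct keeps its default move/copy/assignment,
// which matters for caching and for storing attributes in containers.
struct operation_attributes_t {
    uint32_t hidden_size = 0;
    uint32_t num_links = 0;
    std::optional<uint32_t> axis;
    // ...
};

// Callers that must not mutate the attributes take the struct as const.
uint32_t links_used(const operation_attributes_t& attrs) {
    // attrs.num_links = 2;  // would not compile: attrs is const here
    return attrs.num_links;
}
```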
Quoted code from the same header (tensor args):

    struct tensor_args_t {
        const ttnn::Tensor dense_input_tensor;
        const ttnn::Tensor dense_metadata_tensor;
        const ttnn::Tensor dense_token_maps_tensor;
        const ttnn::Tensor dense_token_counts_tensor;
        const std::optional<ttnn::Tensor> optional_output_tensor;
    };
same note for const-ness
Quoted code from the same header (program factory):

    struct UnifiedSelectReduce {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle reader_kernel_id;
            tt::tt_metal::KernelHandle writer_kernel_id;
            std::vector<CoreCoord> cores;
            const GlobalSemaphore init_semaphore;
            const GlobalSemaphore cross_device_semaphore;
        };
        using cached_mesh_workload_t = ttnn::device_operation::AdaptedCachedMeshWorkload<shared_variables_t>;

        static cached_mesh_workload_t create_mesh_workload(
            const operation_attributes_t& operation_attributes,
            const ttnn::MeshCoordinateRangeSet& tensor_coords,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static ttnn::device_operation::CachedProgram<shared_variables_t> create_at(
            const operation_attributes_t& operation_attributes,
            const ttnn::MeshCoordinate& mesh_coordinate,
            const std::vector<ttnn::MeshCoordinate>& all_mesh_coordinates,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value,
            const GlobalSemaphore& init_semaphore,
            const GlobalSemaphore& cross_device_semaphore);

        static void override_runtime_arguments(
            cached_mesh_workload_t& cached_workload,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };
Please move the factory to its own .hpp/.cpp.
Quoted code from the same header:

    // Mandatory methods

    // Select the program factory based on the operation attributes and tensor args
    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
There is a PR from @dgomezTT that removes the need for this when there is a single program.
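For context, with a single factory this method is pure boilerplate, roughly of the shape below; the body is assumed, since only the declaration is quoted above.

```cpp
// Likely shape of the boilerplate when only one program factory exists
// (UnifiedSelectReduce and the *_t aliases are the types quoted earlier).
static program_factory_t select_program_factory(
    const operation_attributes_t&, const tensor_args_t&) {
    return UnifiedSelectReduce{};
}
```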
Quoted code from the same header:

    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

    // Empty as there doesn't seem to be any complicated hashing requirement
    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
There is a PR from @dgomezTT that removes the need for this when it simply calls cache_miss inside.
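Similarly, the cache-hit validation here presumably just reuses the cache-miss path; the delegation below is an assumed body illustrating the boilerplate that the referenced PR would make unnecessary.

```cpp
// Assumed body: on a program-cache hit, simply re-run the cache-miss validation.
static void validate_on_program_cache_hit(
    const operation_attributes_t& attributes, const tensor_args_t& tensors) {
    validate_on_program_cache_miss(attributes, tensors);
}
```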
Quoted code from `selective_reduce_combine.cpp` (host-side invoke):

    const std::optional<ttnn::MemoryConfig>& memory_config,
    const std::optional<ttnn::Tensor>& optional_output_tensor,
    const std::optional<GlobalSemaphore>& optional_cross_device_semaphore) {
    auto input_memory_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
That should ideally happen in the prim function, I think.
Then you can simply do a using of that prim function in the ttnn namespace, and that's it.
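A sketch of the suggested restructuring, under the assumption that the underlying prim lives at `ttnn::prim::selective_reduce_combine` (the name and trimmed signature below are illustrative, not quoted from the PR): resolve the memory-config default inside the prim, then re-export the prim in the public namespace instead of keeping a separate forwarding invoke.

```cpp
#include <optional>
#include "ttnn/tensor/tensor.hpp"

// Hypothetical prim with the default resolved inside it (signature trimmed).
namespace ttnn::prim {
inline ttnn::Tensor selective_reduce_combine(
    const ttnn::Tensor& dense_input_tensor,
    const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt) {
    // Default resolution lives in the prim, not in a host-side wrapper.
    const auto output_memory_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
    // Sketch only: the real prim would launch the device operation here,
    // passing output_memory_config along with the other attributes.
    return dense_input_tensor;  // placeholder so the sketch stands alone
}
}  // namespace ttnn::prim

// The public API then becomes a plain re-export of the prim.
namespace ttnn::experimental {
using ttnn::prim::selective_reduce_combine;
}  // namespace ttnn::experimental
```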
Ticket
#33832
#33274 (partial)
Problem description
Part of the MoE inference pipeline. Takes the dense expert-contribution output from compute, sparsifies it, and sends it back to the originating devices.
What's changed
New optimized a2a combine op that takes dense input. Uses pre-computed metadata, inputs sharded in L1, and a fabric mux over arbitrary cores and all links.
Average-case perf is ~108 us, but it scales poorly (linearly with the number of experts selected per row) in worse cases. The original op (with one link, at least) was ~3k us.
Checklist
Model tests