
MoE: Selective reduce combine #37432

Open

amorrisonTT wants to merge 21 commits into main from amorrison/moe-selective-reduce-combine-mux-rebase

Conversation

@amorrisonTT
Contributor

@amorrisonTT amorrisonTT commented Feb 9, 2026

Ticket

#33832
#33274 (partial)

Problem description

Part of the MoE inference pipeline. Takes the dense expert contribution output from compute, sparsifies it, and sends it back to the originating devices.

What's changed

New optimized all-to-all (a2a) combine op that takes dense input. Uses pre-computed metadata, inputs sharded in L1, and a fabric mux over arbitrary worker cores and all links.

Average-case perf is ~108 us, but it scales poorly (linearly with the number of experts selected per row) in the worst cases. The original op (with one link, at least) was ~3k us.
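
For illustration, a minimal sketch of how the new op might be called from Python, based on the ttnn.experimental.selective_reduce_combine binding and the device-op tensor arguments (dense input, metadata, token maps, token counts) discussed below; the keyword names, defaults, and helper variables here are assumptions, not the final signature:

import ttnn

# Hypothetical call shape -- argument names mirror the device-op fields
# (dense_input_tensor, dense_metadata_tensor, dense_token_maps_tensor,
# dense_token_counts_tensor); the actual nanobind signature may differ.
sparse_output = ttnn.experimental.selective_reduce_combine(
    dense_input,         # dense expert contributions produced by compute
    dense_metadata,      # pre-computed routing metadata
    dense_token_maps,    # token -> originating-device mapping
    dense_token_counts,  # per-expert token counts
    num_links=num_links,           # fabric links to use (all links supported)
    axis=cluster_axis,             # optional cluster axis to combine along
    memory_config=out_mem_config,  # placement of the sparse output
)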

Checklist

  • All post-commit tests
  • Blackhole Post commit
  • New/Existing tests provide coverage for changes

  • Model tests

Contributor

@github-actions github-actions bot left a comment


⚠️ Clang-Tidy found issue(s) with the introduced code (1/2)

Contributor

@github-actions github-actions bot left a comment


⚠️ Clang-Tidy found issue(s) with the introduced code (2/2)

@amorrisonTT
Contributor Author

/codeowners ping

Contributor

Copilot AI left a comment


Pull request overview

Adds a new experimental TTNN CCL op to support the MoE inference pipeline step that sparsifies dense expert contributions and returns tokens to their originating devices, using a fabric mux-based combine path.

Changes:

  • Introduces ttnn.experimental.selective_reduce_combine (C++ op + device op + dataflow reader/writer kernels) and binds it via nanobind.
  • Extends shared CCL kernel utilities for mux teardown and adds a bidirectional multicast atomic-inc helper for 1D ring.
  • Adds Galaxy MoE nightly tests and adjusts Galaxy e2e pipeline configuration to run MoE tests separately.

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 16 comments.

File Description
ttnn/cpp/ttnn/operations/experimental/experimental_nanobind.cpp Minor cleanup in experimental nanobind module registration.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine_nanobind.hpp Declares nanobind binding entrypoint for the new op.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine_nanobind.cpp Adds Python binding + docstring for selective_reduce_combine.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine.hpp Declares/registers the TTNN operation.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/selective_reduce_combine.cpp Implements host-side invoke forwarding into the prim/device op.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/selective_reduce_combine_device_operation.hpp Defines the device operation interface and prim API.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/selective_reduce_combine_device_operation.cpp Implements validation + output spec/tensor creation + prim launch.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/selective_reduce_combine_program_factory.cpp Program factory building CBs, mux workers, and setting runtime args.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/kernels/dataflow/reader.cpp Kernel to read token counts/maps and compute per-core work splits.
ttnn/cpp/ttnn/operations/experimental/ccl/moe/selective_reduce_combine/device/kernels/dataflow/writer.cpp Kernel to send token segments locally or over fabric mux and teardown.
ttnn/cpp/ttnn/operations/experimental/ccl/ccl_experimental_nanobind.cpp Registers the new op under the experimental CCL nanobind module.
ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt Adds new sources/kernels to the experimental CCL build target.
ttnn/CMakeLists.txt Adds new nanobind source to TTNN build.
ttnn/cpp/ttnn/operations/ccl/common/kernels/moe_utils.hpp Extends mux helper arg parsing/teardown; adds bidirectional atomic-inc helper.
ttnn/cpp/ttnn/operations/ccl/common/kernels/minimal_ccl_common.hpp Extends perform_payload_send template to optionally skip flush.
tt_metal/hw/inc/api/debug/dprint_pages.h Adds print_u32_pages helper for debugging.
tt_metal/fabric/hw/inc/linear/addrgen_api.h Adds a ShardedAddrGen using-declaration and updates a helper signature.
tests/pipeline_reorg/galaxy_e2e_tests.yaml Splits Galaxy CCL vs MoE tests; adds MoE-specific environment and timeout.
tests/nightly/tg/ccl/moe/test_selective_combine_6U.py Adds correctness + perf/trace tests for selective reduce combine on Galaxy.
tests/nightly/t3000/ccl/test_all_to_all_combine.py Exposes cluster-dimension helpers reused by the new MoE test.



Generally expect num_data_parallel_dim=num_token_parallel_dim=4
This can be though of as a logical grid by a physical grid is not required as long as the ordering is

Copilot AI Feb 9, 2026


Docstring typo: “though of” should be “thought of”.

Suggested change
This can be though of as a logical grid by a physical grid is not required as long as the ordering is
This can be thought of as a logical grid by a physical grid is not required as long as the ordering is


logger.info(f"Capturing Warmup iterations")
trace_id_warmup = ttnn.begin_trace_capture(mesh_device, cq_id=0)
tt_out = op_func(max(1, num_iters // 4))

Copilot AI Feb 9, 2026


This assignment to 'tt_out' is unnecessary as it is redefined before this value is used.

Suggested change
tt_out = op_func(max(1, num_iters // 4))
op_func(max(1, num_iters // 4))

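
For context, a sketch of the warmup/trace-capture flow this test appears to use, with the unused assignment dropped as suggested above. The ttnn trace calls (begin_trace_capture, end_trace_capture, execute_trace, synchronize_device) are the public API; op_func, mesh_device, and num_iters come from the snippet, and the overall structure is an assumption about the rest of the test:

# Warmup capture: run a fraction of the iterations and discard the result.
trace_id_warmup = ttnn.begin_trace_capture(mesh_device, cq_id=0)
op_func(max(1, num_iters // 4))
ttnn.end_trace_capture(mesh_device, trace_id_warmup, cq_id=0)

# Main capture: this output is the one actually checked/profiled.
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
tt_out = op_func(num_iters)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

# Replay the warmup trace first, then the measured trace.
ttnn.execute_trace(mesh_device, trace_id_warmup, cq_id=0, blocking=False)
ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
ttnn.synchronize_device(mesh_device)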
@tenstorrent-github-bot

CodeOwners Group Analysis

This PR requires approval from one member of each of the following groups:

Summary: 6 pending groups, 0 approved groups

Group Information:




  • tt_metal/fabric/ (Group) - Members: Abhishek Agarwal, Allan Liu, Austin Ho, Ridvan Song, Sean Nijjar, Umair Bilal Cheema, Yu Gao | Pending approval

    📁 Files owned by this group (1 files)

  • tt_metal/hw/inc/ (Group) - Members: Almeet Bhullar, Arik Yaacob, Ata Tuzuner, John Bauman, Kevin Stevens, Nathan Sidwell, Rui Zhang, Vuk Vukomanovic | Pending approval

    📁 Files owned by this group (1 files)

Note: At least one approval from each group is sufficient.

@amorrisonTT amorrisonTT force-pushed the amorrison/moe-selective-reduce-combine-mux-rebase branch from fa4e393 to b80fa57 on February 12, 2026 at 18:55

template <typename ShardingInfoType>
uint32_t get_page_size(const experimental::ShardedAddrGen<ShardingInfoType>& d) {
uint32_t get_page_size(const _ttnn_operations_experimental_ShardedAddrGen<ShardingInfoType>& d) {
Member


I see a new using added, but it's a weird using.
What's going on here? :)

}
}

inline void print_u32_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) {
Member


@jbaumanTT , @akerteszTT can you look at this addition to hw/inc/api please?

Comment on lines +197 to +198
// DPRINT << "OPENING MUX CORE: " << (uint32_t)args.fabric_mux_x << ", " << (uint32_t)args.fabric_mux_y
// << "\n";
Member


remove?

Comment on lines +281 to +282
// DPRINT << "CLOSING MUX CORE: " << (uint32_t)args.fabric_mux_x << ", " << (uint32_t)args.fabric_mux_y
// << "\n";
Member


remove?

class SenderType = WorkerToFabricEdmSender>
FORCE_INLINE void fabric_multicast_bidirectional_atomic_inc_ring_1d(
std::array<SenderType, 4>& fabric_connections,
volatile PACKET_HEADER_TYPE* packet_header_pos,
Member


I am a little out of the loop: what is PACKET_HEADER_TYPE? Why is it all caps? Is it a macro?

namespace ttnn {
namespace operations::experimental::ccl::moe {

struct ExecuteSelectiveReduceCombine {
Member


please remove the struct with invoke and register_operation
see what we did here
#36303

Comment on lines +23 to +38
const uint32_t hidden_size;
const uint32_t batch_size;
const uint32_t seq_size;
const uint32_t select_experts_k;
const uint32_t experts;
const uint32_t num_links;

const std::optional<uint32_t> axis;
tt::tt_fabric::Topology topology;

const uint32_t num_token_parallel_cores;
const uint32_t num_data_parallel_cores;
const CoreRangeSet worker_core_range_set;
const CoreRangeSet mux_core_range_set;
const ttnn::MemoryConfig output_memory_config;
const std::optional<GlobalSemaphore> optional_cross_device_semaphore;
Member


These params should not really be const; if we want const-ness, the whole struct is passed as const.

Comment on lines +72 to +78
struct tensor_args_t {
const ttnn::Tensor dense_input_tensor;
const ttnn::Tensor dense_metadata_tensor;
const ttnn::Tensor dense_token_maps_tensor;
const ttnn::Tensor dense_token_counts_tensor;
const std::optional<ttnn::Tensor> optional_output_tensor;
};
Member


same note for const-ness

Comment on lines +84 to +115
struct UnifiedSelectReduce {
// Shared variables are the variables that are shared between the create and override_runtime_arguments methods
struct shared_variables_t {
tt::tt_metal::KernelHandle reader_kernel_id;
tt::tt_metal::KernelHandle writer_kernel_id;
std::vector<CoreCoord> cores;
const GlobalSemaphore init_semaphore;
const GlobalSemaphore cross_device_semaphore;
};
using cached_mesh_workload_t = ttnn::device_operation::AdaptedCachedMeshWorkload<shared_variables_t>;

static cached_mesh_workload_t create_mesh_workload(
const operation_attributes_t& operation_attributes,
const ttnn::MeshCoordinateRangeSet& tensor_coords,
const tensor_args_t& tensor_args,
tensor_return_value_t& tensor_return_value);

static ttnn::device_operation::CachedProgram<shared_variables_t> create_at(
const operation_attributes_t& operation_attributes,
const ttnn::MeshCoordinate& mesh_coordinate,
const std::vector<ttnn::MeshCoordinate>& all_mesh_coordinates,
const tensor_args_t& tensor_args,
tensor_return_value_t& tensor_return_value,
const GlobalSemaphore& init_semaphore,
const GlobalSemaphore& cross_device_semaphore);

static void override_runtime_arguments(
cached_mesh_workload_t& cached_workload,
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& tensor_return_value);
};
Member


please move factory to its own .hpp/.cpp

// Mandatory methods

// Select the program factory based on the operation attributes and tensor args
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
Member


there is a PR from @dgomezTT that removes the need for this when there is a single program

static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

// Empty as there doesn't seem to be any complicated hashing requirement
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
Member


there is a PR from @dgomezTT that removes the need for this when it simply calls cache_miss inside

const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& optional_output_tensor,
const std::optional<GlobalSemaphore>& optional_cross_device_semaphore) {
auto input_memory_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
Member


that should ideally happen in the prim function i think

Member


Then you can simply do a using of that prim function in the ttnn namespace and that's it.
